sglang 0.4.1.post2__py3-none-any.whl → 0.4.1.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  71. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  77. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  78. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  79. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  80. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  81. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  82. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  83. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  84. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  85. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
  86. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  89. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  94. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  101. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  102. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  103. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  104. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  105. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  106. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  107. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  108. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  109. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  112. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  113. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  114. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  115. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  116. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  117. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  118. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  119. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  120. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  121. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  122. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  123. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  124. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  125. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  126. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  127. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  128. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  129. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  130. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  131. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  132. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  133. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  134. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  135. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  136. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  137. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  138. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  139. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  140. sglang/srt/layers/quantization/fp8.py +2 -2
  141. sglang/srt/layers/sampler.py +57 -21
  142. sglang/srt/layers/torchao_utils.py +17 -3
  143. sglang/srt/managers/detokenizer_manager.py +2 -0
  144. sglang/srt/managers/io_struct.py +12 -3
  145. sglang/srt/managers/schedule_batch.py +26 -2
  146. sglang/srt/managers/schedule_policy.py +159 -90
  147. sglang/srt/managers/scheduler.py +71 -27
  148. sglang/srt/managers/tokenizer_manager.py +29 -20
  149. sglang/srt/managers/tp_worker.py +16 -4
  150. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  151. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  152. sglang/srt/model_executor/forward_batch_info.py +33 -8
  153. sglang/srt/model_executor/model_runner.py +63 -61
  154. sglang/srt/models/deepseek_v2.py +34 -7
  155. sglang/srt/models/grok.py +97 -26
  156. sglang/srt/openai_api/adapter.py +0 -17
  157. sglang/srt/openai_api/protocol.py +3 -3
  158. sglang/srt/sampling/sampling_batch_info.py +21 -0
  159. sglang/srt/sampling/sampling_params.py +9 -1
  160. sglang/srt/server.py +9 -5
  161. sglang/srt/server_args.py +109 -51
  162. sglang/srt/speculative/build_eagle_tree.py +347 -0
  163. sglang/srt/speculative/eagle_utils.py +618 -0
  164. sglang/srt/speculative/eagle_worker.py +170 -0
  165. sglang/srt/speculative/spec_info.py +5 -0
  166. sglang/srt/utils.py +15 -2
  167. sglang/version.py +1 -1
  168. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  169. sglang-0.4.1.post4.dist-info/RECORD +329 -0
  170. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  171. sglang-0.4.1.post2.dist-info/RECORD +0 -197
  172. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  173. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -17,7 +17,7 @@ import gc
 import json
 import logging
 import time
-from typing import Optional
+from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
@@ -48,8 +48,8 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
@@ -75,6 +75,7 @@ class ModelRunner:
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        is_draft_worker: bool = False,
     ):
         # Parse args
         self.model_config = model_config
@@ -85,8 +86,12 @@
         self.tp_size = tp_size
         self.dist_port = nccl_port
         self.server_args = server_args
+        self.is_draft_worker = is_draft_worker
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
+        self.spec_algorithm = SpeculativeAlgorithm.from_string(
+            server_args.speculative_algorithm
+        )
 
         # Model-specific adjustment
         if (
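The runner now resolves its speculative-decoding mode once, from the --speculative-algorithm server argument. A minimal sketch of the enum-with-from_string pattern this relies on (the member names below are illustrative, not the actual contents of sglang/srt/speculative/spec_info.py):

from enum import IntEnum, auto
from typing import Optional


class SpeculativeAlgorithmSketch(IntEnum):
    # Hypothetical members; the real enum ships in spec_info.py (new in this release).
    NONE = auto()
    EAGLE = auto()

    def is_none(self) -> bool:
        return self == SpeculativeAlgorithmSketch.NONE

    @staticmethod
    def from_string(name: Optional[str]) -> "SpeculativeAlgorithmSketch":
        # An unset --speculative-algorithm flag maps to NONE.
        table = {None: SpeculativeAlgorithmSketch.NONE, "EAGLE": SpeculativeAlgorithmSketch.EAGLE}
        return table[name]


assert SpeculativeAlgorithmSketch.from_string(None).is_none()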
@@ -192,9 +197,9 @@
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
-        # ToDO(liangan1):Just use gloo to bypass the initilization fail
-        # Need to use xccl for xpu backend in the future
         elif self.device == "xpu":
+            # TODO(liangan1):Just use gloo to bypass the initilization fail
+            # Need to use xccl for xpu backend in the future
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
@@ -206,14 +211,18 @@
         else:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
-        init_distributed_environment(
-            backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
-            local_rank=self.gpu_id,
-            distributed_init_method=dist_init_method,
-        )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
+        if not self.is_draft_worker:
+            # Only initilzie the distributed environment on the target model worker.
+            init_distributed_environment(
+                backend=backend,
+                world_size=self.tp_size,
+                rank=self.tp_rank,
+                local_rank=self.gpu_id,
+                distributed_init_method=dist_init_method,
+            )
+            initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
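With is_draft_worker threaded through, a speculative-decoding setup can build a second ModelRunner for the draft model inside the already-initialized process group. A rough wiring sketch under that assumption (argument names before tp_size are not shown in this diff and are assumed; configs and port are placeholders; the real orchestration lives in the new eagle_worker.py):

# Target runner: sets up NCCL / model-parallel groups as before.
target_runner = ModelRunner(
    model_config=target_config, mem_fraction_static=0.8, gpu_id=0, tp_rank=0,
    tp_size=1, nccl_port=port, server_args=server_args,
)
# Draft runner: is_draft_worker=True skips init_distributed_environment()
# and initialize_model_parallel(), reusing the groups created above.
draft_runner = ModelRunner(
    model_config=draft_config, mem_fraction_static=0.8, gpu_id=0, tp_rank=0,
    tp_size=1, nccl_port=port, server_args=server_args, is_draft_worker=True,
)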
@@ -408,7 +417,6 @@
         target_dtype = (
             dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
         )
-        current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype
 
         assert (
             self._model_update_group is not None
@@ -429,9 +437,9 @@
             logger.error(error_msg)
             return False, error_msg
 
-    def update_weights_from_tensor(self, name, tensor: torch.Tensor):
-        self.model.load_weights([(name, tensor)])
-        return True, "Success"  # TODO error handling
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
+        self.model.load_weights(named_tensors)
+        return True, "Success"
 
     def get_weights_by_name(
         self, name: str, truncate_size: int = 100
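update_weights_from_tensor now accepts a batch of (name, tensor) pairs instead of a single pair, so several parameters can be swapped in one call. A hedged usage sketch against an already-constructed runner (the parameter names and shapes are placeholders):

import torch

hidden_size, vocab_size = 4096, 32000  # illustrative dimensions

# 0.4.1.post2: runner.update_weights_from_tensor("lm_head.weight", new_tensor)
# 0.4.1.post4: one call, many tensors
ok, msg = runner.update_weights_from_tensor(
    [
        ("lm_head.weight", torch.zeros(vocab_size, hidden_size)),
        ("model.embed_tokens.weight", torch.zeros(vocab_size, hidden_size)),
    ]
)
assert ok, msg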
@@ -507,6 +515,28 @@
         )
 
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+
+        if max_num_reqs is None:
+            max_num_reqs = min(
+                max(
+                    int(
+                        self.max_total_num_tokens / self.model_config.context_len * 512
+                    ),
+                    2048,
+                ),
+                4096,
+            )
+
+        if not self.spec_algorithm.is_none():
+            if self.is_draft_worker:
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+            else:
+                self.server_args.draft_runner_cache_size = (
+                    self.max_total_num_tokens
+                    + max_num_reqs * self.server_args.speculative_num_steps
+                    + 100
+                )
+
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
                 logging.warning(
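When speculative decoding is enabled, the target worker now sizes the draft runner's KV cache from its own profiled capacity, the derived request cap, and the number of speculative steps. A small worked example with assumed numbers (none of these are measured values):

max_total_num_tokens = 500_000          # profiled target KV-cache capacity
context_len = 32_768
speculative_num_steps = 5               # --speculative-num-steps

max_num_reqs = min(max(int(max_total_num_tokens / context_len * 512), 2048), 4096)
# int(500_000 / 32_768 * 512) = 7812, clamped into [2048, 4096] -> 4096
draft_runner_cache_size = max_total_num_tokens + max_num_reqs * speculative_num_steps + 100
print(max_num_reqs, draft_runner_cache_size)  # 4096 520580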
@@ -521,17 +551,6 @@
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
-        if max_num_reqs is None:
-            max_num_reqs = min(
-                max(
-                    int(
-                        self.max_total_num_tokens / self.model_config.context_len * 512
-                    ),
-                    2048,
-                ),
-                4096,
-            )
-
         self.req_to_token_pool = ReqToTokenPool(
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
@@ -651,10 +670,6 @@
             tensor_parallel(self.model, device_mesh)
 
     def forward_decode(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
-        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
         self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
@@ -684,14 +699,18 @@
         )
 
     def forward_idle(self, forward_batch: ForwardBatch):
-        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
-            return self.cuda_graph_runner.replay(forward_batch)
-
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
     def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+        if (
+            forward_batch.forward_mode.is_cuda_graph()
+            and self.cuda_graph_runner
+            and self.cuda_graph_runner.can_run(forward_batch)
+        ):
+            return self.cuda_graph_runner.replay(forward_batch)
+
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
@@ -704,6 +723,7 @@
     def sample(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        # Apply logit bias
         sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
@@ -714,35 +734,17 @@
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
             sampling_info.update_penalties()
-        logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
-
-        # Sample the next tokens.
-        next_token_ids = self.sampler(logits, sampling_info)
+        sampling_info.apply_logits_bias(logits_output.next_token_logits)
+
+        # Sample the next tokens
+        next_token_ids = self.sampler(
+            logits_output,
+            sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+        )
         return next_token_ids
 
-    def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # Apply logit_bias
-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits.add_(sampling_info.linear_penalties)
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        # Apply regex vocab_mask
-        if sampling_info.vocab_mask is not None:
-            sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
-
-        return logits
-
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.
sglang/srt/models/deepseek_v2.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     input_to_float8,
+    normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
@@ -55,7 +56,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_flashinfer_available, is_hip
+
+is_hip_ = is_hip()
 
 if is_flashinfer_available():
     from flashinfer import bmm_fp8
@@ -573,7 +576,13 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if self.w_kc.dtype == torch.float8_e4m3fn:
+        if self.w_kc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            q_nope_out = torch.bmm(
+                q_nope.to(torch.bfloat16).transpose(0, 1),
+                self.w_kc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = input_to_float8(
                 q_nope.transpose(0, 1), torch.float8_e4m3fn
             )
@@ -598,7 +607,13 @@
         attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if self.w_vc.dtype == torch.float8_e4m3fn:
+        if self.w_vc.dtype == torch.float8_e4m3fnuz:
+            # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+            attn_bmm_output = torch.bmm(
+                attn_output.to(torch.bfloat16).transpose(0, 1),
+                self.w_vc.to(torch.bfloat16) * self.w_scale,
+            )
+        elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = input_to_float8(
                 attn_output.transpose(0, 1), torch.float8_e4m3fn
             )
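On ROCm the MLA projection weights are stored as float8_e4m3fnuz, for which no bmm_fp8 kernel exists yet, so both branches above fall back to an upcast batched matmul. A standalone sketch of that shape contract with made-up dimensions (the real code upcasts the fp8 weight to bfloat16; values here are placeholders):

import torch

# Stand-ins: q_nope is (seq, num_heads, qk_nope_head_dim),
# w_kc is (num_heads, qk_nope_head_dim, kv_lora_rank).
seq, num_heads, nope_dim, kv_lora_rank = 8, 16, 128, 512
q_nope = torch.randn(seq, num_heads, nope_dim)
w_kc = torch.randn(num_heads, nope_dim, kv_lora_rank)
w_scale = 0.5  # per-tensor dequantization scale

q_nope_out = torch.bmm(q_nope.transpose(0, 1), w_kc * w_scale)
print(q_nope_out.shape)  # torch.Size([16, 8, 512]) -> (num_heads, seq, kv_lora_rank)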
@@ -940,15 +955,25 @@
                 w = self_attn.kv_b_proj.weight
                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
                 # This may affect the accuracy of fp8 model.
-                if (
-                    hasattr(self.quant_config, "weight_block_size")
-                    and w.dtype == torch.float8_e4m3fn
+                if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+                    torch.float8_e4m3fn,
+                    torch.float8_e4m3fnuz,
                 ):
                     weight_block_size = self.quant_config.weight_block_size
                     if weight_block_size is not None:
                         assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                        if is_hip_:
+                            weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                                weight=w,
+                                weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                                input_scale=None,
+                            )
+                        else:
+                            weight = w
+                            weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
                         w, scale = block_quant_to_tensor_quant(
-                            w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                            weight, weight_scale, weight_block_size
                         )
                         self_attn.w_scale = scale
                 w_kc, w_vc = w.unflatten(
@@ -961,6 +986,8 @@
                     and self_attn.w_scale is None
                 ):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+                    if is_hip_:
+                        self_attn.w_scale *= 2.0
 
 
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
sglang/srt/models/grok.py CHANGED
@@ -16,13 +16,16 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import GeluAndMul
@@ -42,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
@@ -347,6 +351,16 @@ class Grok1ForCausalLM(nn.Module):
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
+        # Monkey patch _prepare_weights to load pre-sharded weights
+        if (
+            self.config.num_local_experts > 0
+            and get_tensor_model_parallel_world_size() > 1
+        ):
+            self.use_presharded_weights = True
+            setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+        else:
+            self.use_presharded_weights = False
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -359,7 +373,15 @@ class Grok1ForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+        use_presharded_weights: bool | None = None,
+    ):
+        if use_presharded_weights is None:
+            use_presharded_weights = self.use_presharded_weights
+        num_experts = self.config.num_local_experts
+
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -375,10 +397,23 @@ class Grok1ForCausalLM(nn.Module):
             ckpt_gate_proj_name="w1",
             ckpt_down_proj_name="w2",
             ckpt_up_proj_name="w3",
-            num_experts=self.config.num_local_experts,
+            num_experts=num_experts,
         )
 
         params_dict = dict(self.named_parameters())
+        all_names = set(params_dict.keys())
+        hit_names = set()
+
+        def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
+            if name not in params_dict:
+                return
+
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            weight_loader(param, loaded_weight, *args, **kwargs)
+
+            hit_names.add(name)
+
         for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                 continue
@@ -391,9 +426,7 @@ class Grok1ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
 
-                param = params_dict[name]
-                weight_loader = param.weight_loader
-                weight_loader(param, loaded_weight, shard_id)
+                load_weight_wrapper(name, loaded_weight, shard_id)
                 break
             else:
                 for mapping in expert_params_mapping:
@@ -402,38 +435,76 @@
                         continue
                     name = name.replace(weight_name, param_name)
 
-                    if (
-                        name.endswith(".bias") or name.endswith("_bias")
-                    ) and name not in params_dict:
-                        continue
+                    if use_presharded_weights:
+                        extra_kwargs = {
+                            "use_presharded_weights": use_presharded_weights
+                        }
+                    else:
+                        extra_kwargs = {}
 
-                    param = params_dict[name]
-                    weight_loader = param.weight_loader
-                    weight_loader(
-                        param,
+                    load_weight_wrapper(
+                        name,
                         loaded_weight,
                         name,
                         shard_id=shard_id,
                         expert_id=expert_id,
+                        **extra_kwargs,
                     )
                     break
                 else:
                     # Skip loading extra bias for GPTQ models.
-                    if (
-                        name.endswith(".bias") or name.endswith("_bias")
-                    ) and name not in params_dict:
-                        continue
-                    # Skip loading kv_scale from ckpts towards new design.
-                    if name.endswith(".kv_scale") and name not in params_dict:
+                    if name.endswith(".bias") and name not in params_dict:
                         continue
                     if name is None:
                         continue
 
-                    param = params_dict[name]
-                    weight_loader = getattr(
-                        param, "weight_loader", default_weight_loader
-                    )
-                    weight_loader(param, loaded_weight)
+                    load_weight_wrapper(name=name, loaded_weight=loaded_weight)
+
+
+old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
+
+
+def _prepare_presharded_weights(
+    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+) -> Tuple[str, List[str], bool]:
+    import glob
+    import os
+
+    if get_tensor_model_parallel_world_size() == 1:
+        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
+
+    if not os.path.isdir(model_name_or_path):
+        from sglang.srt.model_loader.weight_utils import download_weights_from_hf
+
+        allow_patterns = ["*.safetensors", "*.bin"]
+        hf_folder = download_weights_from_hf(
+            model_name_or_path,
+            self.load_config.download_dir,
+            allow_patterns,
+            revision,
+            ignore_patterns=self.load_config.ignore_patterns,
+        )
+    else:
+        hf_folder = model_name_or_path
+
+    tp_rank = get_tensor_model_parallel_rank()
+
+    # The old format
+    allow_patterns = [f"*-{tp_rank:03d}.bin"]
+
+    # The new format
+    allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
+
+    hf_weights_files: List[str] = []
+    for pattern in allow_patterns:
+        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+    if hf_weights_files[0].endswith("safetensors"):
+        use_safetensors = True
+    else:
+        use_safetensors = False
+
+    return hf_folder, hf_weights_files, use_safetensors
 
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
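The monkey-patched loader only changes which checkpoint files each tensor-parallel rank opens; everything else about weight loading stays the same. A small sketch of the rank-specific patterns it builds, exercised against a made-up directory listing:

import fnmatch

tp_rank = 1
# Same patterns as _prepare_presharded_weights above.
allow_patterns = [
    f"*-{tp_rank:03d}.bin",                 # old pre-sharded format
    f"*-TP-{tp_rank:03d}.safetensors",      # new per-rank format
    "*-TP-common.safetensors",              # shards shared by all ranks
]

# Hypothetical 2-way tensor-parallel export.
files = ["model-TP-000.safetensors", "model-TP-001.safetensors", "model-TP-common.safetensors"]
picked = [f for f in files for p in allow_patterns if fnmatch.fnmatch(f, p)]
print(picked)  # ['model-TP-001.safetensors', 'model-TP-common.safetensors']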
sglang/srt/openai_api/adapter.py CHANGED
@@ -696,14 +696,6 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
 
 async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    if "extra_body" in request_json:
-        extra = request_json["extra_body"]
-        if "ebnf" in extra:
-            request_json["ebnf"] = extra["ebnf"]
-        if "regex" in extra:
-            request_json["regex"] = extra["regex"]
-        # remove extra_body to avoid pydantic conflict
-        del request_json["extra_body"]
     all_requests = [CompletionRequest(**request_json)]
     adapted_request, request = v1_generate_request(all_requests)
 
@@ -1176,15 +1168,6 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
 
 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
-    if "extra_body" in request_json:
-        extra = request_json["extra_body"]
-        # For example, if 'ebnf' is given:
-        if "ebnf" in extra:
-            request_json["ebnf"] = extra["ebnf"]
-        if "regex" in extra:
-            request_json["regex"] = extra["regex"]
-        # remove extra_body to avoid pydantic conflict
-        del request_json["extra_body"]
     all_requests = [ChatCompletionRequest(**request_json)]
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 
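The server no longer unwraps a nested extra_body object: the official openai Python client already merges extra_body keys into the top level of the request JSON, and the request models (see the protocol changes below) now declare ebnf and regex directly. A hedged client-side sketch (server URL, model name, and pattern are placeholders):

import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Answer with yes or no: is the sky blue?"}],
    # The client merges extra_body into the top-level body, so the server
    # receives a plain "regex" field on ChatCompletionRequest.
    extra_body={"regex": "(yes|no)"},
)
print(resp.choices[0].message.content)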
sglang/srt/openai_api/protocol.py CHANGED
@@ -171,15 +171,15 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
-    regex: Optional[str] = None
     json_schema: Optional[str] = None
+    regex: Optional[str] = None
+    ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    ebnf: Optional[str] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -315,13 +315,13 @@ class ChatCompletionRequest(BaseModel):
     min_p: float = 0.0
     min_tokens: int = 0
     regex: Optional[str] = None
+    ebnf: Optional[str] = None
     repetition_penalty: float = 1.0
     stop_token_ids: Optional[List[int]] = None
     no_stop_trim: bool = False
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-    ebnf: Optional[str] = None
 
 
 class FunctionResponse(BaseModel):
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -232,3 +232,24 @@ class SamplingBatchInfo:
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device
         )
+
+    def apply_logits_bias(self, logits: torch.Tensor):
+        # Apply logit_bias
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)
+
+        # min-token, presence, frequency
+        if self.linear_penalties is not None:
+            logits.add_(self.linear_penalties)
+
+        # repetition
+        if self.scaling_penalties is not None:
+            logits[:] = torch.where(
+                logits > 0,
+                logits / self.scaling_penalties,
+                logits * self.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if self.vocab_mask is not None:
+            self.apply_mask(logits=logits, vocab_mask=self.vocab_mask)
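apply_logits_bias consolidates the in-place logit adjustments that previously lived on ModelRunner (note the logits[:] assignment, which keeps the mutation in place). A tiny numeric sketch of the repetition branch with made-up values, one request over a four-token vocabulary:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])
scaling_penalties = torch.full((1, 4), 1.2)  # repetition penalty broadcast over the vocab

# Positive logits are divided by the penalty and non-positive ones multiplied,
# so repeated tokens always become less likely.
logits[:] = torch.where(logits > 0, logits / scaling_penalties, logits * scaling_penalties)
print(logits)  # tensor([[ 1.6667, -1.2000,  0.4167, -3.6000]])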
sglang/srt/sampling/sampling_params.py CHANGED
@@ -19,6 +19,14 @@ _SAMPLING_EPS = 1e-6
 
 
 class SamplingParams:
+    """
+    The sampling parameters.
+
+    See docs/references/sampling_params.md or
+    https://sgl-project.github.io/references/sampling_params.html
+    for the documentation.
+    """
+
     def __init__(
         self,
         max_new_tokens: int = 128,
@@ -33,9 +41,9 @@
         repetition_penalty: float = 1.0,
         min_new_tokens: int = 0,
         spaces_between_special_tokens: bool = True,
-        regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        regex: Optional[str] = None,
         ebnf: Optional[str] = None,
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
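SamplingParams now carries a docstring pointing at the published reference, and regex/ebnf sit next to json_schema in the constructor. A hedged sketch of passing these fields through the native /generate HTTP API, whose sampling_params keys mirror the constructor arguments above (URL and prompt are placeholders):

import requests

payload = {
    "text": "The capital of France is",
    "sampling_params": {
        "max_new_tokens": 8,
        "temperature": 0.0,
        # Use exactly one constrained-decoding field per request: regex, ebnf, or json_schema.
        "regex": "(Paris|Lyon|Marseille)",
    },
}
resp = requests.post("http://127.0.0.1:30000/generate", json=payload)
print(resp.json()["text"])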
sglang/srt/server.py CHANGED
@@ -27,7 +27,9 @@ import signal
 import threading
 import time
 from http import HTTPStatus
-from typing import AsyncIterator, Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Tuple, Union
+
+import torch
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -78,6 +80,7 @@ from sglang.srt.openai_api.adapter import (
 from sglang.srt.openai_api.protocol import ModelCard, ModelList
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     add_api_key_middleware,
     add_prometheus_middleware,
     assert_pkg_version,
@@ -872,9 +875,11 @@ class Engine:
             tokenizer_manager.update_weights_from_distributed(obj, None)
         )
 
-    def update_weights_from_tensor(self, name, tensor):
+    def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
         """Update weights from distributed source."""
-        obj = UpdateWeightsFromTensorReqInput(name=name, tensor=tensor)
+        obj = UpdateWeightsFromTensorReqInput(
+            serialized_named_tensors=MultiprocessingSerializer.serialize(named_tensors)
+        )
         loop = asyncio.get_event_loop()
         return loop.run_until_complete(
             tokenizer_manager.update_weights_from_tensor(obj, None)
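The Engine front end mirrors the ModelRunner change, serializing the named tensors with MultiprocessingSerializer before they cross into the scheduler process. A hedged usage sketch (model path, parameter name, and shape are placeholders; assumes the offline Engine entry point exposed as sglang.Engine):

import torch
import sglang as sgl

engine = sgl.Engine(model_path="meta-llama/Llama-3.1-8B-Instruct")  # placeholder model
new_weight = torch.zeros(4096, 4096)  # placeholder; must match the parameter's real shape/dtype
engine.update_weights_from_tensor([("model.layers.0.self_attn.o_proj.weight", new_weight)])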
@@ -910,10 +915,9 @@ class Runtime:
         atexit.register(self.shutdown)
 
         # Pre-allocate ports
-        for port in range(10000, 40000):
+        for port in range(self.server_args.port, 40000):
             if is_port_available(port):
                 break
-            port += 1
         self.server_args.port = port
 
         self.url = self.server_args.url()