sglang 0.4.1.post2-py3-none-any.whl → 0.4.1.post4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/srt/layers/attention/__init__.py +14 -5
  3. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  4. sglang/srt/layers/attention/flashinfer_backend.py +211 -81
  5. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  6. sglang/srt/layers/attention/triton_backend.py +20 -11
  7. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  8. sglang/srt/layers/logits_processor.py +167 -212
  9. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  10. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  11. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  12. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  13. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  71. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  77. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  78. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  79. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  80. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  81. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  82. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  83. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  84. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  85. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +178 -0
  86. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  89. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  94. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  95. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  96. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  97. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +175 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +187 -29
  101. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -6
  102. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  103. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  104. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  105. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  106. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  107. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  108. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  109. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  112. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  113. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  114. sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  115. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  116. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  117. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  118. sglang/srt/layers/quantization/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  119. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  120. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  121. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  122. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  123. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  124. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  125. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  126. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  127. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  128. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  129. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  130. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  131. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  132. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  133. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  134. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  135. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  136. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  137. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  138. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  139. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  140. sglang/srt/layers/quantization/fp8.py +2 -2
  141. sglang/srt/layers/sampler.py +57 -21
  142. sglang/srt/layers/torchao_utils.py +17 -3
  143. sglang/srt/managers/detokenizer_manager.py +2 -0
  144. sglang/srt/managers/io_struct.py +12 -3
  145. sglang/srt/managers/schedule_batch.py +26 -2
  146. sglang/srt/managers/schedule_policy.py +159 -90
  147. sglang/srt/managers/scheduler.py +71 -27
  148. sglang/srt/managers/tokenizer_manager.py +29 -20
  149. sglang/srt/managers/tp_worker.py +16 -4
  150. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  151. sglang/srt/model_executor/cuda_graph_runner.py +118 -73
  152. sglang/srt/model_executor/forward_batch_info.py +33 -8
  153. sglang/srt/model_executor/model_runner.py +63 -61
  154. sglang/srt/models/deepseek_v2.py +34 -7
  155. sglang/srt/models/grok.py +97 -26
  156. sglang/srt/openai_api/adapter.py +0 -17
  157. sglang/srt/openai_api/protocol.py +3 -3
  158. sglang/srt/sampling/sampling_batch_info.py +21 -0
  159. sglang/srt/sampling/sampling_params.py +9 -1
  160. sglang/srt/server.py +9 -5
  161. sglang/srt/server_args.py +109 -51
  162. sglang/srt/speculative/build_eagle_tree.py +347 -0
  163. sglang/srt/speculative/eagle_utils.py +618 -0
  164. sglang/srt/speculative/eagle_worker.py +170 -0
  165. sglang/srt/speculative/spec_info.py +5 -0
  166. sglang/srt/utils.py +15 -2
  167. sglang/version.py +1 -1
  168. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/METADATA +9 -8
  169. sglang-0.4.1.post4.dist-info/RECORD +329 -0
  170. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/WHEEL +1 -1
  171. sglang-0.4.1.post2.dist-info/RECORD +0 -197
  172. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/LICENSE +0 -0
  173. {sglang-0.4.1.post2.dist-info → sglang-0.4.1.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/torch_native_backend.py

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 import torch
 from torch.nn.functional import scaled_dot_product_attention
@@ -23,43 +23,6 @@ class TorchNativeAttnBackend(AttentionBackend):
         """Init the metadata for a forward pass."""
         pass
 
-    def init_cuda_graph_state(self, max_bs: int):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_capture_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def init_forward_metadata_replay_cuda_graph(
-        self,
-        bs: int,
-        req_pool_indices: torch.Tensor,
-        seq_lens: torch.Tensor,
-        seq_lens_sum: int,
-        encoder_lens: Optional[torch.Tensor] = None,
-    ):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
-    def get_cuda_graph_seq_len_fill_value(self):
-        # TODO: Support CUDA graph
-        raise ValueError(
-            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
-        )
-
     def _run_sdpa_forward_extend(
         self,
         query: torch.Tensor,
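Note: the four CUDA-graph methods removed above were identical stubs that all raised the same ValueError. One plausible reading, given the companion change to sglang/srt/layers/attention/__init__.py (+14 -5 in the file list), is that such failing defaults now live on the AttentionBackend base class, so each backend no longer needs its own copy. A hypothetical sketch of that arrangement; the base-class bodies below are an assumption, not the package's verbatim code:

    # Hypothetical: shared CUDA-graph stubs on the backend base class,
    # which would make per-backend copies like the ones deleted above redundant.
    from abc import ABC


    class AttentionBackend(ABC):
        def init_cuda_graph_state(self, max_bs: int):
            # Backends that support CUDA graphs override this.
            raise NotImplementedError()

        def get_cuda_graph_seq_len_fill_value(self):
            # Fill value used to pad seq_lens during graph replay.
            raise NotImplementedError()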
sglang/srt/layers/attention/triton_backend.py

@@ -1,15 +1,16 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
 
 
 class TritonAttnBackend(AttentionBackend):
@@ -80,11 +81,17 @@ class TritonAttnBackend(AttentionBackend):
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
+        num_tokens: int,
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
-        # NOTE: encoder_lens expected to be zeros or None
+        assert encoder_lens is None, "Not supported"
+        assert forward_mode.is_decode(), "Not supported"
+        assert spec_info is None, "Not supported"
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
@@ -96,7 +103,9 @@ class TritonAttnBackend(AttentionBackend):
         req_pool_indices: torch.Tensor,
         seq_lens: torch.Tensor,
         seq_lens_sum: int,
-        encoder_lens=None,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
     ):
         # NOTE: encoder_lens expected to be zeros or None
         self.cuda_graph_start_loc.zero_()
@@ -107,9 +116,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_extend(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
@@ -146,9 +155,9 @@ class TritonAttnBackend(AttentionBackend):
 
     def forward_decode(
         self,
-        q,
-        k,
-        v,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache=True,
sglang/srt/layers/attention/triton_ops/decode_attention.py

@@ -406,6 +406,10 @@ def _decode_grouped_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
+    # [TODO] work around shmem limit on MI3xx
+    if is_hip_ and Lk >= 576:
+        BLOCK = 16
+
     if Lk == 576:
         BLOCK_DMODEL = 512
         BLOCK_DPE = 64
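The four added lines cap the tile size on AMD MI3xx GPUs when the head dimension is large (Lk >= 576 covers the MLA layout, which is split into BLOCK_DMODEL=512 plus BLOCK_DPE=64 just below). A rough back-of-envelope for why, assuming fp32 staging of one [BLOCK, Lk] key tile and the roughly 64 KB of LDS available per workgroup on MI300-class GPUs; the compiled Triton kernel's actual shared-memory usage differs, so treat these numbers as estimates only:

    # Rough LDS estimate for staging one K tile of shape [BLOCK, Lk] in fp32.
    def tile_bytes(block: int, lk: int, itemsize: int = 4) -> int:
        return block * lk * itemsize

    print(tile_bytes(64, 576))  # 147456 bytes (~144 KB): over a 64 KB LDS budget
    print(tile_bytes(16, 576))  # 36864 bytes (~36 KB): fits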
sglang/srt/layers/logits_processor.py

@@ -17,6 +17,8 @@ import dataclasses
 from typing import List, Optional, Union
 
 import torch
+import triton
+import triton.language as tl
 from torch import nn
 from vllm.distributed import (
     get_tensor_model_parallel_world_size,
@@ -33,76 +35,77 @@ from sglang.srt.model_executor.forward_batch_info import (
 
 @dataclasses.dataclass
 class LogitsProcessorOutput:
+    ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The logits of the next tokens. shape: [#seq, vocab_size]
     next_token_logits: torch.Tensor
-    # The logprobs of the next tokens. shape: [#seq, vocab_size]
-    next_token_logprobs: torch.Tensor = None
+    # Used by speculative decoding (EAGLE)
+    # The last hidden layers
+    hidden_states: Optional[torch.Tensor] = None
+
+    ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
+    # The logprobs of the next tokens. shape: [#seq]
+    next_token_logprobs: Optional[torch.Tensor] = None
+    # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
+    next_token_top_logprobs_val: Optional[List] = None
+    next_token_top_logprobs_idx: Optional[List] = None
 
+    ## Part 3: Prefill-only. This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
     # The normlaized logprobs of prompts. shape: [#seq]
     normalized_prompt_logprobs: torch.Tensor = None
-    # The logprobs of input tokens. shape: [#token, vocab_size]
+    # The logprobs of input tokens. shape: [#token]
    input_token_logprobs: torch.Tensor = None
-
-    # The logprob and id of the top-k tokens in input positions. shape [#seq, #token, k]
+    # The logprobs and ids of the top-k tokens in input positions. shape: [#seq, #token, k]
     input_top_logprobs_val: List = None
     input_top_logprobs_idx: List = None
-    # The logprob and id of the top-k tokens in output positions. shape [#seq, #token, k]
-    output_top_logprobs_val: List = None
-    output_top_logprobs_idx: List = None
-
-    # Used by speculative decoding (EAGLE)
-    # The output of transformer layers
-    hidden_states: Optional[torch.Tensor] = None
 
 
 @dataclasses.dataclass
 class LogitsMetadata:
     forward_mode: ForwardMode
-    top_logprobs_nums: Optional[List[int]]
-
-    return_logprob: bool = False
-    return_top_logprob: bool = False
+    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
 
+    extend_return_logprob: bool = False
+    extend_return_top_logprob: bool = False
     extend_seq_lens: Optional[torch.Tensor] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
-
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
     extend_logprob_pruned_lens_cpu: Optional[List[int]] = None
-
-    capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    top_logprobs_nums: Optional[List[int]] = None
 
     @classmethod
     def from_forward_batch(cls, forward_batch: ForwardBatch):
-        extend_logprob_pruned_lens_cpu = None
-
-        if forward_batch.return_logprob:
-            return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
-            if forward_batch.forward_mode.is_extend():
-                extend_logprob_pruned_lens_cpu = [
-                    extend_len - start_len
-                    for extend_len, start_len in zip(
-                        forward_batch.extend_seq_lens_cpu,
-                        forward_batch.extend_logprob_start_lens_cpu,
-                    )
-                ]
-        else:
-            return_top_logprob = False
-
         if forward_batch.spec_info:
             capture_hidden_mode = forward_batch.spec_info.capture_hidden_mode
         else:
             capture_hidden_mode = CaptureHiddenMode.NULL
 
+        if forward_batch.forward_mode.is_extend() and forward_batch.return_logprob:
+            extend_return_logprob = True
+            extend_return_top_logprob = any(
+                x > 0 for x in forward_batch.top_logprobs_nums
+            )
+            extend_logprob_pruned_lens_cpu = [
+                extend_len - start_len
+                for extend_len, start_len in zip(
+                    forward_batch.extend_seq_lens_cpu,
+                    forward_batch.extend_logprob_start_lens_cpu,
+                )
+            ]
+        else:
+            extend_return_logprob = extend_return_top_logprob = (
+                extend_logprob_pruned_lens_cpu
+            ) = False
+
         return cls(
             forward_mode=forward_batch.forward_mode,
-            top_logprobs_nums=forward_batch.top_logprobs_nums,
-            return_logprob=forward_batch.return_logprob,
-            return_top_logprob=return_top_logprob,
+            capture_hidden_mode=capture_hidden_mode,
+            extend_return_logprob=extend_return_logprob,
+            extend_return_top_logprob=extend_return_top_logprob,
             extend_seq_lens=forward_batch.extend_seq_lens,
             extend_seq_lens_cpu=forward_batch.extend_seq_lens_cpu,
             extend_logprob_start_lens_cpu=forward_batch.extend_logprob_start_lens_cpu,
             extend_logprob_pruned_lens_cpu=extend_logprob_pruned_lens_cpu,
+            top_logprobs_nums=forward_batch.top_logprobs_nums,
         )
 
 
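For reference, extend_logprob_pruned_lens_cpu above is just the element-wise difference between each request's extend length and its logprob start offset; a request whose prompt logprobs are fully skipped ends up with a pruned length of 0 and later receives empty top-logprob lists. A small worked example, with values invented for illustration:

    extend_seq_lens_cpu = [5, 3, 8]            # tokens extended per request
    extend_logprob_start_lens_cpu = [2, 0, 8]  # skip logprobs before this offset

    pruned = [
        extend_len - start_len
        for extend_len, start_len in zip(
            extend_seq_lens_cpu, extend_logprob_start_lens_cpu
        )
    ]
    print(pruned)  # [3, 3, 0]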
@@ -129,7 +132,6 @@ class LogitsProcessor(nn.Module):
     ):
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-        assert isinstance(logits_metadata, LogitsMetadata)
 
         # Get the last hidden states and last logits for the next token prediction
@@ -142,18 +144,13 @@ class LogitsProcessor(nn.Module):
             last_index = torch.cumsum(logits_metadata.extend_seq_lens, dim=0) - 1
             last_hidden = hidden_states[last_index]
 
+        # Compute logits
         last_logits = self._get_logits(last_hidden, lm_head)
-        if self.do_tensor_parallel_all_gather:
-            last_logits = tensor_model_parallel_all_gather(last_logits)
-        last_logits = last_logits[:, : self.config.vocab_size].float()
-
-        if self.final_logit_softcapping:
-            last_logits.div_(self.final_logit_softcapping)
-            torch.tanh(last_logits, out=last_logits)
-            last_logits.mul_(self.final_logit_softcapping)
-
-        # Return only last_logits if logprob is not requested
-        if not logits_metadata.return_logprob:
+        if (
+            not logits_metadata.extend_return_logprob
+            or logits_metadata.capture_hidden_mode.need_capture()
+        ):
+            # Decode mode or extend mode without return_logprob.
             return LogitsProcessorOutput(
                 next_token_logits=last_logits,
                 hidden_states=(
@@ -167,95 +164,60 @@ class LogitsProcessor(nn.Module):
                 ),
             )
         else:
-            last_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                last_logits, logits_metadata
+            # Slice the requested tokens to compute logprob
+            pt, pruned_states, pruned_input_ids = 0, [], []
+            for start_len, extend_len in zip(
+                logits_metadata.extend_logprob_start_lens_cpu,
+                logits_metadata.extend_seq_lens_cpu,
+            ):
+                pruned_states.append(hidden_states[pt + start_len : pt + extend_len])
+                pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
+                pt += extend_len
+
+            # Compute the logits of all required tokens
+            pruned_states = torch.cat(pruned_states)
+            del hidden_states
+            input_token_logits = self._get_logits(pruned_states, lm_head)
+            del pruned_states
+
+            # Normalize the logprob w/o temperature, top-p
+            input_logprobs = input_token_logits
+            input_logprobs = self.compute_temp_top_p_normalized_logprobs(
+                input_logprobs, logits_metadata
             )
 
-            if logits_metadata.forward_mode.is_decode():
-                if logits_metadata.return_top_logprob:
-                    output_top_logprobs_val, output_top_logprobs_idx = (
-                        self.get_top_logprobs(last_logprobs, logits_metadata)[2:4]
-                    )
-                else:
-                    output_top_logprobs_val = output_top_logprobs_idx = None
-                return LogitsProcessorOutput(
-                    next_token_logits=last_logits,
-                    next_token_logprobs=last_logprobs,
-                    output_top_logprobs_val=output_top_logprobs_val,
-                    output_top_logprobs_idx=output_top_logprobs_idx,
-                )
+            # Get the logprob of top-k tokens
+            if logits_metadata.extend_return_top_logprob:
+                (
+                    input_top_logprobs_val,
+                    input_top_logprobs_idx,
+                ) = self.get_top_logprobs(input_logprobs, logits_metadata)
             else:
-                # Slice the requested tokens to compute logprob
-                pt, states, pruned_input_ids = 0, [], []
-                for start_len, extend_len in zip(
-                    logits_metadata.extend_logprob_start_lens_cpu,
-                    logits_metadata.extend_seq_lens_cpu,
-                ):
-                    states.append(hidden_states[pt + start_len : pt + extend_len])
-                    pruned_input_ids.append(input_ids[pt + start_len : pt + extend_len])
-                    pt += extend_len
-
-                # Compute the logits and logprobs for all required tokens
-                states = torch.cat(states, dim=0)
-                all_logits = self._get_logits(states, lm_head)
-                if self.do_tensor_parallel_all_gather:
-                    all_logits = tensor_model_parallel_all_gather(all_logits)
-
-                # The LM head's weights may be zero-padded for parallelism. Remove any
-                # extra logits that this padding may have produced.
-                all_logits = all_logits[:, : self.config.vocab_size].float()
-
-                if self.final_logit_softcapping:
-                    all_logits.div_(self.final_logit_softcapping)
-                    torch.tanh(all_logits, out=all_logits)
-                    all_logits.mul_(self.final_logit_softcapping)
-
-                all_logprobs = all_logits
-                del all_logits, hidden_states
-
-                all_logprobs = self.compute_temp_top_p_normalized_logprobs(
-                    all_logprobs, logits_metadata
-                )
-
-                # Get the logprob of top-k tokens
-                if logits_metadata.return_top_logprob:
-                    (
-                        input_top_logprobs_val,
-                        input_top_logprobs_idx,
-                        output_top_logprobs_val,
-                        output_top_logprobs_idx,
-                    ) = self.get_top_logprobs(all_logprobs, logits_metadata)
-                else:
-                    input_top_logprobs_val = input_top_logprobs_idx = (
-                        output_top_logprobs_val
-                    ) = output_top_logprobs_idx = None
-
-                # Compute the normalized logprobs for the requested tokens.
-                # Note that we pad a zero at the end for easy batching.
-                input_token_logprobs = all_logprobs[
-                    torch.arange(all_logprobs.shape[0], device="cuda"),
-                    torch.cat(
-                        [
-                            torch.cat(pruned_input_ids)[1:],
-                            torch.tensor([0], device="cuda"),
-                        ]
-                    ),
-                ]
-                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
-                    input_token_logprobs,
-                    logits_metadata,
-                )
+                input_top_logprobs_val = input_top_logprobs_idx = None
+
+            # Compute the normalized logprobs for the requested tokens.
+            # Note that we pad a zero at the end for easy batching.
+            input_token_logprobs = input_logprobs[
+                torch.arange(input_logprobs.shape[0], device="cuda"),
+                torch.cat(
+                    [
+                        torch.cat(pruned_input_ids)[1:],
+                        torch.tensor([0], device="cuda"),
+                    ]
+                ),
+            ]
+            normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                input_token_logprobs,
+                logits_metadata,
+            )
 
-            return LogitsProcessorOutput(
-                next_token_logits=last_logits,
-                next_token_logprobs=last_logprobs,
-                normalized_prompt_logprobs=normalized_prompt_logprobs,
-                input_token_logprobs=input_token_logprobs,
-                input_top_logprobs_val=input_top_logprobs_val,
-                input_top_logprobs_idx=input_top_logprobs_idx,
-                output_top_logprobs_val=output_top_logprobs_val,
-                output_top_logprobs_idx=output_top_logprobs_idx,
-            )
+            return LogitsProcessorOutput(
+                next_token_logits=last_logits,
+                normalized_prompt_logprobs=normalized_prompt_logprobs,
+                input_token_logprobs=input_token_logprobs,
+                input_top_logprobs_val=input_top_logprobs_val,
+                input_top_logprobs_idx=input_top_logprobs_idx,
+            )
 
     def _get_logits(
@@ -269,9 +231,19 @@ class LogitsProcessor(nn.Module):
             # GGUF models
             logits = lm_head.linear_method.apply(lm_head, hidden_states, embedding_bias)
 
-        # Optional scaling factor
         if self.logit_scale is not None:
-            logits.mul_(self.logit_scale)  # In-place multiply
+            logits.mul_(self.logit_scale)
+
+        if self.do_tensor_parallel_all_gather:
+            logits = tensor_model_parallel_all_gather(logits)
+
+        # Compute the normalized logprobs for the requested tokens.
+        # Note that we pad a zero at the end for easy batching.
+        logits = logits[:, : self.config.vocab_size].float()
+
+        if self.final_logit_softcapping:
+            fused_softcap(logits, self.final_logit_softcapping)
+
         return logits
 
     @staticmethod
@@ -302,90 +274,73 @@ class LogitsProcessor(nn.Module):
         values = ret.values.tolist()
         indices = ret.indices.tolist()
 
-        if logits_metadata.forward_mode.is_decode():
-            output_top_logprobs_val = []
-            output_top_logprobs_idx = []
-            for i, k in enumerate(logits_metadata.top_logprobs_nums):
-                output_top_logprobs_val.append(values[i][:k])
-                output_top_logprobs_idx.append(indices[i][:k])
-            return None, None, output_top_logprobs_val, output_top_logprobs_idx
-        else:
-            input_top_logprobs_val, input_top_logprobs_idx = [], []
-            output_top_logprobs_val, output_top_logprobs_idx = [], []
+        input_top_logprobs_val, input_top_logprobs_idx = [], []
 
-            pt = 0
-            for k, pruned_len in zip(
-                logits_metadata.top_logprobs_nums,
-                logits_metadata.extend_logprob_pruned_lens_cpu,
-            ):
-                if pruned_len <= 0:
-                    input_top_logprobs_val.append([])
-                    input_top_logprobs_idx.append([])
-                    output_top_logprobs_val.append([])
-                    output_top_logprobs_idx.append([])
-                    continue
-
-                input_top_logprobs_val.append(
-                    [values[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                input_top_logprobs_idx.append(
-                    [indices[pt + j][:k] for j in range(pruned_len - 1)]
-                )
-                output_top_logprobs_val.append(
-                    list(
-                        values[pt + pruned_len - 1][:k],
-                    )
-                )
-                output_top_logprobs_idx.append(
-                    list(
-                        indices[pt + pruned_len - 1][:k],
-                    )
-                )
-                pt += pruned_len
+        pt = 0
+        for k, pruned_len in zip(
+            logits_metadata.top_logprobs_nums,
+            logits_metadata.extend_logprob_pruned_lens_cpu,
+        ):
+            if pruned_len <= 0:
+                input_top_logprobs_val.append([])
+                input_top_logprobs_idx.append([])
+                continue
 
-            return (
-                input_top_logprobs_val,
-                input_top_logprobs_idx,
-                output_top_logprobs_val,
-                output_top_logprobs_idx,
+            input_top_logprobs_val.append(
+                [values[pt + j][:k] for j in range(pruned_len - 1)]
             )
+            input_top_logprobs_idx.append(
+                [indices[pt + j][:k] for j in range(pruned_len - 1)]
+            )
+            pt += pruned_len
+
+        return input_top_logprobs_val, input_top_logprobs_idx
 
     @staticmethod
     def compute_temp_top_p_normalized_logprobs(
         last_logits: torch.Tensor, logits_metadata: LogitsMetadata
     ) -> torch.Tensor:
+        # TODO: Implement the temp and top-p normalization
         return torch.nn.functional.log_softmax(last_logits, dim=-1)
 
 
-def test():
-    all_logprobs = torch.tensor(
-        #       s              s              s
-        [[0, 1, 2, 3], [1, 2, 3, 4], [2, 3, 4, 5], [3, 4, 5, 6], [4, 5, 6, 7]],
-        dtype=torch.float32,
-        device="cuda",
+@triton.jit
+def fused_softcap_kernel(
+    full_logits_ptr,
+    softcapping_value,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    # Load values
+    x = tl.load(full_logits_ptr + offsets, mask=mask)
+
+    # Perform operations in-place
+    x = x / softcapping_value
+
+    # Manual tanh implementation using exp
+    exp2x = tl.exp(2 * x)
+    x = (exp2x - 1) / (exp2x + 1)
+
+    x = x * softcapping_value
+
+    # Store result
+    tl.store(full_logits_ptr + offsets, x, mask=mask)
+
+
+def fused_softcap(full_logits, final_logit_softcapping):
+    n_elements = full_logits.numel()
+    BLOCK_SIZE = 1024
+    grid = ((n_elements + BLOCK_SIZE - 1) // BLOCK_SIZE, 1, 1)
+
+    fused_softcap_kernel[grid](
+        full_logits_ptr=full_logits,
+        softcapping_value=final_logit_softcapping,
+        n_elements=n_elements,
+        BLOCK_SIZE=BLOCK_SIZE,
    )
-    seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
-    input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-
-    token_logprobs = all_logprobs[
-        torch.arange(all_logprobs.shape[0], device="cuda"),
-        torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
-    ]
-    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)
-
-    len_cumsum = torch.cumsum(seq_lens, dim=0)
-    start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
-    end = start + seq_lens - 2
-    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]
-
-    # assert logprobs == [2, _, 2, 4, _]
-    print("token logprobs", token_logprobs)
-    print("start", start)
-    print("end", end)
-    print("sum_logp", sum_logp)
-
-
-if __name__ == "__main__":
-    test()
+    return full_logits
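The new fused_softcap above replaces the removed three-step in-place sequence (div_, tanh, mul_) with a single Triton pass computing cap * tanh(x / cap), using an exp-based tanh. A quick equivalence check against a PyTorch reference; this needs a CUDA device and the fused_softcap defined above, and uses a loose tolerance since the kernel computes tanh via exp:

    import torch

    def softcap_reference(x: torch.Tensor, cap: float) -> torch.Tensor:
        return cap * torch.tanh(x / cap)

    x = torch.randn(8, 32000, device="cuda", dtype=torch.float32)
    expected = softcap_reference(x, 30.0)
    actual = fused_softcap(x.clone(), 30.0)  # mutates its argument in place
    assert torch.allclose(actual, expected, atol=1e-4, rtol=1e-4)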