sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (219)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +208 -295
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  64. sglang/srt/layers/moe/topk.py +13 -4
  65. sglang/srt/layers/quantization/__init__.py +111 -7
  66. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  69. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  71. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  72. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  73. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  75. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  80. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  82. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  83. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  87. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  89. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  90. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  91. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  92. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  93. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  94. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  95. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  96. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  97. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  98. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  99. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  100. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  101. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
  102. sglang/srt/layers/quantization/fp8.py +69 -28
  103. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  104. sglang/srt/layers/quantization/gptq.py +416 -0
  105. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  106. sglang/srt/layers/quantization/int8_utils.py +73 -0
  107. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  108. sglang/srt/layers/radix_attention.py +1 -0
  109. sglang/srt/layers/rotary_embedding.py +0 -1
  110. sglang/srt/layers/sampler.py +76 -31
  111. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  112. sglang/srt/lora/lora.py +17 -1
  113. sglang/srt/lora/lora_config.py +5 -0
  114. sglang/srt/lora/lora_manager.py +1 -3
  115. sglang/srt/managers/cache_controller.py +193 -62
  116. sglang/srt/managers/configure_logging.py +2 -1
  117. sglang/srt/managers/data_parallel_controller.py +6 -2
  118. sglang/srt/managers/detokenizer_manager.py +124 -102
  119. sglang/srt/managers/image_processor.py +2 -1
  120. sglang/srt/managers/io_struct.py +143 -6
  121. sglang/srt/managers/schedule_batch.py +238 -197
  122. sglang/srt/managers/schedule_policy.py +29 -29
  123. sglang/srt/managers/scheduler.py +681 -259
  124. sglang/srt/managers/session_controller.py +6 -2
  125. sglang/srt/managers/tokenizer_manager.py +224 -68
  126. sglang/srt/managers/tp_worker.py +15 -4
  127. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  128. sglang/srt/mem_cache/chunk_cache.py +18 -11
  129. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  130. sglang/srt/mem_cache/memory_pool.py +44 -18
  131. sglang/srt/mem_cache/radix_cache.py +58 -47
  132. sglang/srt/metrics/collector.py +94 -36
  133. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  134. sglang/srt/model_executor/forward_batch_info.py +49 -16
  135. sglang/srt/model_executor/model_runner.py +209 -28
  136. sglang/srt/model_loader/loader.py +3 -3
  137. sglang/srt/model_loader/weight_utils.py +36 -14
  138. sglang/srt/models/baichuan.py +31 -6
  139. sglang/srt/models/chatglm.py +39 -7
  140. sglang/srt/models/commandr.py +29 -5
  141. sglang/srt/models/dbrx.py +31 -5
  142. sglang/srt/models/deepseek.py +43 -6
  143. sglang/srt/models/deepseek_nextn.py +32 -19
  144. sglang/srt/models/deepseek_v2.py +265 -29
  145. sglang/srt/models/exaone.py +19 -9
  146. sglang/srt/models/gemma.py +22 -8
  147. sglang/srt/models/gemma2.py +25 -12
  148. sglang/srt/models/gemma2_reward.py +5 -1
  149. sglang/srt/models/gpt2.py +28 -13
  150. sglang/srt/models/gpt_bigcode.py +27 -5
  151. sglang/srt/models/granite.py +21 -9
  152. sglang/srt/models/grok.py +21 -4
  153. sglang/srt/models/internlm2.py +36 -6
  154. sglang/srt/models/internlm2_reward.py +5 -1
  155. sglang/srt/models/llama.py +26 -9
  156. sglang/srt/models/llama_classification.py +5 -1
  157. sglang/srt/models/llama_eagle.py +17 -4
  158. sglang/srt/models/llama_embedding.py +5 -1
  159. sglang/srt/models/llama_reward.py +7 -2
  160. sglang/srt/models/llava.py +19 -3
  161. sglang/srt/models/llavavid.py +10 -1
  162. sglang/srt/models/minicpm.py +26 -2
  163. sglang/srt/models/minicpm3.py +39 -3
  164. sglang/srt/models/minicpmv.py +45 -14
  165. sglang/srt/models/mixtral.py +20 -9
  166. sglang/srt/models/mixtral_quant.py +50 -8
  167. sglang/srt/models/mllama.py +57 -11
  168. sglang/srt/models/olmo.py +34 -6
  169. sglang/srt/models/olmo2.py +34 -13
  170. sglang/srt/models/olmoe.py +26 -4
  171. sglang/srt/models/phi3_small.py +29 -10
  172. sglang/srt/models/qwen.py +26 -3
  173. sglang/srt/models/qwen2.py +26 -4
  174. sglang/srt/models/qwen2_5_vl.py +46 -8
  175. sglang/srt/models/qwen2_eagle.py +17 -5
  176. sglang/srt/models/qwen2_moe.py +44 -6
  177. sglang/srt/models/qwen2_rm.py +78 -0
  178. sglang/srt/models/qwen2_vl.py +39 -8
  179. sglang/srt/models/stablelm.py +32 -5
  180. sglang/srt/models/torch_native_llama.py +5 -2
  181. sglang/srt/models/xverse.py +21 -9
  182. sglang/srt/models/xverse_moe.py +45 -7
  183. sglang/srt/models/yivl.py +2 -1
  184. sglang/srt/openai_api/adapter.py +109 -24
  185. sglang/srt/openai_api/protocol.py +17 -1
  186. sglang/srt/reasoning_parser.py +154 -0
  187. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  188. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  189. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  190. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  191. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  192. sglang/srt/sampling/sampling_batch_info.py +79 -157
  193. sglang/srt/sampling/sampling_params.py +16 -13
  194. sglang/srt/server_args.py +136 -52
  195. sglang/srt/speculative/build_eagle_tree.py +2 -8
  196. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  197. sglang/srt/speculative/eagle_utils.py +92 -58
  198. sglang/srt/speculative/eagle_worker.py +186 -94
  199. sglang/srt/speculative/spec_info.py +1 -13
  200. sglang/srt/utils.py +43 -17
  201. sglang/srt/warmup.py +47 -0
  202. sglang/test/few_shot_gsm8k.py +4 -1
  203. sglang/test/runners.py +389 -126
  204. sglang/test/send_one.py +88 -0
  205. sglang/test/test_block_fp8_ep.py +361 -0
  206. sglang/test/test_programs.py +1 -1
  207. sglang/test/test_utils.py +138 -84
  208. sglang/utils.py +50 -60
  209. sglang/version.py +1 -1
  210. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  211. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
  212. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  213. sglang/bench_latency.py +0 -1
  214. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  215. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  216. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  217. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  218. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  219. {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+from functools import lru_cache
 from typing import Optional
 
 import torch
@@ -18,6 +19,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.utils import add_prefix
 
 
 def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
@@ -121,20 +123,20 @@ class VisionAttention(nn.Module):
                 head_size=self.head_size,
                 total_num_heads=num_heads,
                 quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
+                prefix=add_prefix("qkv_proj", prefix),
             )
         else:
             self.qkv_proj = ColumnParallelLinear(
                 input_size=embed_dim,
                 output_size=3 * projection_size,
                 quant_config=quant_config,
-                prefix=f"{prefix}.qkv_proj",
+                prefix=add_prefix("qkv_proj", prefix),
             )
         self.proj = RowParallelLinear(
             input_size=embed_dim,
             output_size=embed_dim,
             quant_config=quant_config,
-            prefix=f"{prefix}.out_proj",
+            prefix=add_prefix("out_proj", prefix),
         )
 
     def forward(
@@ -223,9 +225,6 @@ class VisionSdpaAttention(nn.Module):
 
     """
 
-    # TODO: Should it be released after used?
-    _mask_cache = {}
-
     def __init__(
         self,
         head_size: int,
@@ -239,75 +238,61 @@ class VisionSdpaAttention(nn.Module):
         self.use_full_precision_softmax = use_full_precision_softmax
         self.dropout = dropout
 
-    def generate_patch_attention_mask(
-        self,
-        s: int,
-        bsz: int,
-        device,
-        cu_seqlens: Optional[torch.Tensor],
-        flatten_batch: bool = False,
-        dtype=torch.bfloat16,
-    ) -> torch.Tensor:
-        r"""
-        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
-
-        When `flatten_batch` is True:
-        - All sequences in the batch are flattened into a single dimension
-        - `s` represents the total number of tokens across all sequences in the batch
-        - Returns a unified mask of shape `(1, 1, s, s)`
-
-        When `flatten_batch` is False:
-        - Each sequence has its own attention mask
-        - `s` represents the maximum sequence length in the batch
-        - Returns separate masks of shape `(b, 1, s, s)`
-
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _generate_mask_cache(
+        s: int, flatten_batch: bool, cu_seqlens: tuple
+    ) -> torch.BoolTensor:
+        """
+        Generate a boolean attention mask with caching mechanism.
         Args:
-            flatten_batch: (bool):
-                If True, treats all sequences in the batch as a single flattened sequence
-                If False, generates separate masks for each sequence
-
+            s: sequence length
+            flatten_batch: whether to flatten batch dimension
+            cu_seqlens: tuple of cumulative sequence lengths
         Returns:
-            Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+            attention mask tensor
         """
-
-        cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))
-
-        if cache_key in VisionSdpaAttention._mask_cache:
-            cached_mask = VisionSdpaAttention._mask_cache[cache_key]
-            # print(f"cache hit for key: {cache_key}")
-            return cached_mask.to(device=device, dtype=dtype)
-
-        if cu_seqlens is None:
-            raise ValueError("Internal Error: cu_seqlens cannot be None")
-
         if flatten_batch:
-            mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
+            mask = torch.zeros([1, s, s], dtype=torch.bool)
             for i in range(1, len(cu_seqlens)):
                 start = cu_seqlens[i - 1]
                 end = cu_seqlens[i]
-                mask[
-                    ...,
-                    start:end,
-                    start:end,
-                ] = True
+                mask[..., start:end, start:end] = True
         else:
             # [1, 1, 1, s]
-            row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
+            row_indices = torch.arange(s).view(1, 1, 1, s)
             # [1, 1, s, 1]
-            col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
+            col_indices = torch.arange(s).view(1, 1, s, 1)
             # [b, 1, 1, 1]
-            seq_lens = (
-                (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
-            )
+            seq_lens = torch.tensor(
+                [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])],
+            ).view(-1, 1, 1, 1)
 
             mask = (row_indices < seq_lens) & (col_indices < seq_lens)
 
-        # Convert to attention mask format (False -> 0, True -> -inf)
-        mask = (~mask).to(dtype) * torch.finfo(dtype).min
+        return mask
+
+    def generate_patch_attention_mask(
+        self,
+        s: int,
+        cu_seqlens: Optional[torch.Tensor],
+        flatten_batch: bool = False,
+    ) -> Optional[torch.Tensor]:
+        r"""
+        Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+        Args:
+            s: sequence length
+            cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
+            flatten_batch: whether to flatten batch dimension
+        Returns:
+            attention mask tensor or None
+        """
+        if cu_seqlens is None:
+            return None
 
-        VisionSdpaAttention._mask_cache[cache_key] = mask
+        cu_seqlens_tuple = tuple(cu_seqlens.cpu().tolist())
 
-        return mask
+        return self._generate_mask_cache(s, flatten_batch, cu_seqlens_tuple)
 
     def forward(
         self,
@@ -330,15 +315,23 @@ class VisionSdpaAttention(nn.Module):
         # [b, 1, s, s]
         if attention_mask is None:
             attention_mask = self.generate_patch_attention_mask(
-                s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
+                s, cu_seqlens, flatten_batch=self.flatten_batch
             )
+
+        if attention_mask is None:
+            if self.use_full_precision_softmax:
+                raise RuntimeError("Empty attention mask")
+        else:
+            attention_mask = attention_mask.to(device=q.device)
+
         q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
-        # [b, 1, s]
+
         if self.use_full_precision_softmax:
             scale = self.head_size**-0.5
             k_transposed = rearrange(k, "b h s d -> b h d s")
             attn_weights = torch.matmul(q, k_transposed) * scale
             del k, k_transposed
+            attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
             attn_weights = attn_weights + attention_mask
             del attention_mask
             # full-precision
@@ -354,7 +347,12 @@ class VisionSdpaAttention(nn.Module):
             # SDPA
             # [b, h, s, head_size]
             output = F.scaled_dot_product_attention(
-                q, k, v, attention_mask, dropout_p=self.dropout
+                q,
+                k,
+                v,
+                attn_mask=attention_mask,
+                dropout_p=self.dropout,
+                is_causal=False,
             )
 
             # [b, h, s, head_size] --> [b * s, h, head_size]
@@ -380,7 +378,6 @@ class VisionTritonAttention(nn.Module):
         v: torch.Tensor,
         _bsz: int,
         cu_seqlens: Optional[torch.Tensor],
-        **kwargs,
     ) -> torch.Tensor:
         r"""
         Args:
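
Note on the vision.py change above: the hand-rolled `_mask_cache` dict is replaced by an `lru_cache`-decorated static method keyed on hashable arguments (sequence length, flatten flag, and `cu_seqlens` converted to a tuple), with device placement deferred to the caller and the -inf conversion applied only on the full-precision path. A minimal standalone sketch of that caching pattern, using illustrative names rather than sglang's API:

from functools import lru_cache
from typing import Optional, Tuple

import torch


@lru_cache(maxsize=128)
def _block_diag_mask(s: int, cu_seqlens: Tuple[int, ...]) -> torch.Tensor:
    # Build a boolean block-diagonal mask on CPU; cached by (s, cu_seqlens).
    mask = torch.zeros(1, 1, s, s, dtype=torch.bool)
    for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]):
        mask[..., start:end, start:end] = True
    return mask


def patch_attention_mask(s: int, cu_seqlens: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    if cu_seqlens is None:
        return None  # SDPA accepts attn_mask=None
    # Tensors are unhashable, so convert to a tuple before hitting the cache.
    return _block_diag_mask(s, tuple(cu_seqlens.cpu().tolist()))


if __name__ == "__main__":
    cu = torch.tensor([0, 3, 5])
    first = patch_attention_mask(5, cu)
    # A second call with equal cu_seqlens is served from the cache.
    assert patch_attention_mask(5, cu) is first
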
sglang/srt/layers/dp_attention.py

@@ -1,6 +1,21 @@
+from __future__ import annotations
+
+import functools
+from typing import TYPE_CHECKING, Union
+
 import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    tensor_model_parallel_all_reduce,
+)
 
-from sglang.srt.distributed import GroupCoordinator, get_tp_group
+if TYPE_CHECKING:
+    from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
@@ -69,3 +84,129 @@ def get_attention_dp_rank():
 def get_attention_dp_size():
     assert _DP_SIZE is not None, "dp attention not initialized!"
     return _DP_SIZE
+
+
+def get_dp_local_info(forward_batch: ForwardBatch):
+    dp_rank = get_attention_dp_rank()
+
+    if forward_batch.dp_local_start_pos is None:
+        cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
+        if dp_rank == 0:
+            local_start_pos = torch.zeros_like(cumtokens[0])
+        else:
+            local_start_pos = cumtokens[dp_rank - 1]
+        local_num_tokens = forward_batch.global_num_tokens_gpu[dp_rank]
+
+        forward_batch.dp_local_start_pos = local_start_pos
+        forward_batch.dp_local_num_tokens = local_num_tokens
+
+    return forward_batch.dp_local_start_pos, forward_batch.dp_local_num_tokens
+
+
+@triton.jit
+def memcpy_triton_kernel(
+    dst_ptr,
+    src_ptr,
+    offset_ptr,
+    sz_ptr,
+    offset_src,
+    chunk_size,  # multiplied for offset and sz
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0).to(tl.int64)
+    offset = tl.load(offset_ptr).to(tl.int64) * chunk_size
+    sz = tl.load(sz_ptr).to(tl.int64) * chunk_size
+
+    start_index = pid * BLOCK_SIZE
+    offs = tl.arange(0, BLOCK_SIZE)
+    mask = start_index + offs < sz
+
+    if offset_src:
+        data = tl.load(src_ptr + offset + start_index + offs, mask=mask)
+        tl.store(dst_ptr + start_index + offs, data, mask=mask)
+    else:
+        data = tl.load(src_ptr + start_index + offs, mask=mask)
+        tl.store(dst_ptr + offset + start_index + offs, data, mask=mask)
+
+
+def prod(x):
+    return functools.reduce(lambda a, b: a * b, x, 1)
+
+
+def memcpy_triton(dst, src, dim, offset, sz, offset_src):
+    max_size = min(src.numel(), dst.numel())
+    assert dim == 0, "dim != 0 unsupported"
+    assert src.shape[1:] == dst.shape[1:], "src and dst must have same shape"
+    chunk_size = prod(src.shape[1:])
+    BLOCK_SIZE = 8192
+    grid = (triton.cdiv(max_size, BLOCK_SIZE),)
+
+    memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
+
+
+def dp_gather(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+    layer_id: Union[str, int],
+):
+    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
+    global_tokens.fill_(0)
+    assert local_tokens.is_contiguous()
+    assert global_tokens.is_contiguous()
+    if local_tokens.shape[0] > 0 and (
+        layer_id != "embedding" or get_attention_tp_rank() == 0
+    ):
+        assert (
+            global_tokens.storage().data_ptr() != local_tokens.storage().data_ptr()
+        ), "aliasing between global_tokens and local_tokens not allowed"
+        memcpy_triton(
+            global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
+        )
+
+    # Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
+    NUM_GPUS_PER_NODE = 8
+    if (
+        not local_tokens.dtype.is_floating_point
+        and get_tensor_model_parallel_world_size() <= NUM_GPUS_PER_NODE
+    ):
+        torch.ops.sglang.inplace_all_reduce(
+            global_tokens, group_name=get_tp_group().unique_name
+        )
+    else:
+        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+
+
+def dp_scatter(
+    local_tokens: torch.Tensor,  # output
+    global_tokens: torch.Tensor,  # input
+    forward_batch: ForwardBatch,
+):
+    # local_num_tokens is not necessarily the same as local_tokens.shape[0],
+    # since local_tokens may be padded for cuda graph
+    local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+    local_tokens.fill_(0)
+    assert local_tokens.is_contiguous()
+    assert global_tokens.is_contiguous()
+    if local_tokens.shape[0] > 0:
+        assert (
+            local_tokens.untyped_storage().data_ptr()
+            != global_tokens.untyped_storage().data_ptr()
+        ), "aliasing between local_tokens and global_tokens not allowed"
+        memcpy_triton(
+            local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
+        )
+
+
+def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
+    def do_logits_dp_scatter(logits: torch.Tensor):
+        local_logits = torch.empty(
+            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
+            dtype=logits.dtype,
+            device=logits.device,
+        )
+        dp_scatter(local_logits, logits, forward_batch)
+        return local_logits
+
+    return do_logits_dp_scatter
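
Conceptually, the new dp_attention.py helpers place each data-parallel rank's local tokens into a zero-filled global buffer at that rank's cumulative-token offset, all-reduce the buffers so every rank sees the full token set, and later slice a rank's own span back out. A CPU-only sketch of that data movement, with the all-reduce emulated by a plain sum (illustrative helpers only; the real code uses the Triton copy kernel and sglang's TP communicators):

import torch


def local_info(global_num_tokens: torch.Tensor, dp_rank: int):
    # Offset of this rank's span = total tokens of all preceding ranks.
    cum = torch.cumsum(global_num_tokens, dim=0)
    start = 0 if dp_rank == 0 else int(cum[dp_rank - 1])
    return start, int(global_num_tokens[dp_rank])


def dp_gather_cpu(local_shards, global_num_tokens, hidden):
    total = int(global_num_tokens.sum())
    buffers = []
    for rank, shard in enumerate(local_shards):
        start, num = local_info(global_num_tokens, rank)
        buf = torch.zeros(total, hidden)
        buf[start : start + num] = shard[:num]  # memcpy_triton(..., offset_src=False)
        buffers.append(buf)
    return torch.stack(buffers).sum(dim=0)  # stands in for the all-reduce


def dp_scatter_cpu(global_tokens, global_num_tokens, dp_rank):
    start, num = local_info(global_num_tokens, dp_rank)
    return global_tokens[start : start + num].clone()  # memcpy_triton(..., offset_src=True)


if __name__ == "__main__":
    counts = torch.tensor([2, 3])
    shards = [torch.ones(2, 4), 2 * torch.ones(3, 4)]
    gathered = dp_gather_cpu(shards, counts, hidden=4)
    assert torch.equal(dp_scatter_cpu(gathered, counts, dp_rank=1), shards[1])
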
sglang/srt/layers/layernorm.py

@@ -69,7 +69,7 @@ class RMSNorm(CustomOp):
 
         variance = x.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = x.to(orig_dtype) * self.weight
+        x = (x * self.weight).to(orig_dtype)
         if residual is None:
             return x
         else:
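
The layernorm.py change keeps the weight multiply in the normalization's working precision and casts back to the original dtype only afterwards, instead of casting first and multiplying in the lower-precision dtype. A small standalone sketch of the difference between the two orderings (illustrative tensors, float16 as the example input dtype):

import torch

torch.manual_seed(0)
x = torch.randn(4, 8, dtype=torch.float16)
weight = torch.randn(8, dtype=torch.float16)
eps = 1e-6

orig_dtype = x.dtype
xf = x.to(torch.float32)
xf = xf * torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + eps)

old = xf.to(orig_dtype) * weight    # pre-change: cast down, then scale in fp16
new = (xf * weight).to(orig_dtype)  # post-change: scale (promoted to fp32), then cast down

# The two orderings differ only by low-order fp16 rounding error.
print((old.float() - new.float()).abs().max())
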
sglang/srt/layers/linear.py

@@ -38,6 +38,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "AWQLinearMethod",
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
+    "BlockInt8LinearMethod",
     "MarlinLinearMethod",
     "QQQLinearMethod",
     "GPTQMarlin24LinearMethod",
@@ -425,13 +426,14 @@ class ColumnParallelLinear(LinearBase):
         from sglang.srt.layers.parameter import _ColumnvLLMParameter
 
         if isinstance(param, _ColumnvLLMParameter):
-            # FIXME: why would we need this special case?
             param.load_column_parallel_weight(
                 loaded_weight,
                 tp_rank=self.tp_rank,
                 use_presharded_weights=self.use_presharded_weights,
             )
         else:
+            # FIXME: This branch is needed to load deepseek v3 awq.
+            # However, we should fix this and avoid the branching here.
             param.load_column_parallel_weight(loaded_weight)
 
     def forward(self, input_):