sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
7
7
  import torch
8
8
  from torch.nn import Parameter
9
9
 
10
+ from sglang.srt.layers.utils import pad_or_narrow_weight
10
11
  from sglang.srt.utils import is_cpu
11
12
 
12
13
  __all__ = [
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
156
157
  )
157
158
  else:
158
159
  if not use_presharded_weights:
159
- loaded_weight = loaded_weight.narrow(
160
- self.output_dim, tp_rank * shard_size, shard_size
161
- )
160
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
161
+ start_idx = tp_rank * shard_size
162
+ end_idx = start_idx + shard_size
163
+ if end_idx > loaded_weight.shape[self.output_dim]:
164
+ loaded_weight = pad_or_narrow_weight(
165
+ loaded_weight, self.output_dim, start_idx, shard_size
166
+ )
167
+ else:
168
+ loaded_weight = loaded_weight.narrow(
169
+ self.output_dim, start_idx, shard_size
170
+ )
162
171
 
163
172
  assert param_data.shape == loaded_weight.shape
164
173
  param_data.copy_(loaded_weight)
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
258
267
 
259
268
  return
260
269
  else:
261
- loaded_weight = loaded_weight.narrow(
262
- self.input_dim, tp_rank * shard_size, shard_size
263
- )
270
+ # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
271
+ start_idx = tp_rank * shard_size
272
+ end_idx = start_idx + shard_size
273
+ if end_idx > loaded_weight.shape[self.input_dim]:
274
+ loaded_weight = pad_or_narrow_weight(
275
+ loaded_weight, self.input_dim, start_idx, shard_size
276
+ )
277
+ else:
278
+ loaded_weight = loaded_weight.narrow(
279
+ self.input_dim, start_idx, shard_size
280
+ )
264
281
 
265
282
  if len(loaded_weight.shape) == 0:
266
283
  loaded_weight = loaded_weight.reshape(1)
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
30
30
  from sglang.srt.layers.quantization.compressed_tensors.schemes import (
31
31
  CompressedTensorsScheme,
32
32
  CompressedTensorsW8A8Fp8,
33
+ CompressedTensorsW8A8Int8,
33
34
  CompressedTensorsW8A16Fp8,
34
35
  )
35
36
  from sglang.srt.layers.quantization.compressed_tensors.utils import (
@@ -2,10 +2,12 @@
2
2
 
3
3
  from .compressed_tensors_scheme import CompressedTensorsScheme
4
4
  from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
5
+ from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
5
6
  from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
6
7
 
7
8
  __all__ = [
8
9
  "CompressedTensorsScheme",
9
10
  "CompressedTensorsW8A8Fp8",
10
11
  "CompressedTensorsW8A16Fp8",
12
+ "CompressedTensorsW8A8Int8",
11
13
  ]
@@ -0,0 +1,173 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from typing import Callable, Optional
5
+
6
+ import torch
7
+ from compressed_tensors.quantization import QuantizationStrategy
8
+ from torch.nn import Parameter
9
+
10
+ from sglang.srt.layers.parameter import (
11
+ ChannelQuantScaleParameter,
12
+ ModelWeightParameter,
13
+ PerTensorScaleParameter,
14
+ )
15
+ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
16
+ CompressedTensorsScheme,
17
+ )
18
+ from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
19
+ from sglang.srt.layers.quantization.utils import requantize_with_max_scale
20
+ from sglang.srt.utils import is_cuda
21
+
22
+ _is_cuda = is_cuda()
23
+ if _is_cuda:
24
+ from sgl_kernel import int8_scaled_mm
25
+
26
+
27
+ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
28
+
29
+ def __init__(
30
+ self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
31
+ ):
32
+ self.strategy = strategy
33
+ self.is_static_input_scheme = is_static_input_scheme
34
+ self.input_symmetric = input_symmetric
35
+
36
+ @classmethod
37
+ def get_min_capability(cls) -> int:
38
+ # lovelace and up
39
+ return 89
40
+
41
+ def process_weights_after_loading(self, layer) -> None:
42
+ # If per tensor, when we have a fused module (e.g. QKV) with per
43
+ # tensor scales (thus N scales being passed to the kernel),
44
+ # requantize so we can always run per channel
45
+ if self.strategy == QuantizationStrategy.TENSOR:
46
+ max_w_scale, weight = requantize_with_max_scale(
47
+ weight=layer.weight,
48
+ weight_scale=layer.weight_scale,
49
+ logical_widths=layer.logical_widths,
50
+ )
51
+
52
+ layer.weight = Parameter(weight.t(), requires_grad=False)
53
+ layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
54
+
55
+ # If channelwise, scales are already lined up, so just transpose.
56
+ elif self.strategy == QuantizationStrategy.CHANNEL:
57
+ weight = layer.weight
58
+ weight_scale = layer.weight_scale.data
59
+
60
+ layer.weight = Parameter(weight.t(), requires_grad=False)
61
+ # required by torch.compile to be torch.nn.Parameter
62
+ layer.weight_scale = Parameter(weight_scale, requires_grad=False)
63
+
64
+ else:
65
+ raise ValueError(f"Unknown quantization strategy {self.strategy}")
66
+
67
+ # INPUT SCALE
68
+ if self.is_static_input_scheme and hasattr(layer, "input_scale"):
69
+ if self.input_symmetric:
70
+ layer.input_scale = Parameter(
71
+ layer.input_scale.max(), requires_grad=False
72
+ )
73
+ else:
74
+ input_scale = layer.input_scale
75
+ input_zero_point = layer.input_zero_point
76
+
77
+ # reconstruct the ranges
78
+ int8_traits = torch.iinfo(torch.int8)
79
+ azps = input_zero_point.to(dtype=torch.int32)
80
+ range_max = (input_scale * (int8_traits.max - azps)).max()
81
+ range_min = (input_scale * (int8_traits.min - azps)).min()
82
+
83
+ scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
84
+
85
+ # AZP loaded as int8 but used as int32
86
+ azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
87
+
88
+ layer.input_scale = Parameter(scale, requires_grad=False)
89
+ layer.input_zero_point = Parameter(azp, requires_grad=False)
90
+ else:
91
+ layer.input_scale = None
92
+ layer.input_zero_point = None
93
+
94
+ # azp_adj is the AZP adjustment term, used to account for weights.
95
+ # It does not depend on scales or azp, so it is the same for
96
+ # static and dynamic quantization.
97
+ # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
98
+ # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
99
+ if not self.input_symmetric:
100
+ weight = layer.weight
101
+ azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
102
+ if self.is_static_input_scheme:
103
+ # cutlass_w8a8 requires azp to be folded into azp_adj
104
+ # in the per-tensor case
105
+ azp_adj = layer.input_zero_point * azp_adj
106
+ layer.azp_adj = Parameter(azp_adj, requires_grad=False)
107
+ else:
108
+ layer.azp_adj = None
109
+
110
+ def create_weights(
111
+ self,
112
+ layer: torch.nn.Module,
113
+ output_partition_sizes: list[int],
114
+ input_size_per_partition: int,
115
+ params_dtype: torch.dtype,
116
+ weight_loader: Callable,
117
+ **kwargs,
118
+ ):
119
+ output_size_per_partition = sum(output_partition_sizes)
120
+ layer.logical_widths = output_partition_sizes
121
+
122
+ # WEIGHT
123
+ weight = ModelWeightParameter(
124
+ data=torch.empty(
125
+ output_size_per_partition, input_size_per_partition, dtype=torch.int8
126
+ ),
127
+ input_dim=1,
128
+ output_dim=0,
129
+ weight_loader=weight_loader,
130
+ )
131
+
132
+ layer.register_parameter("weight", weight)
133
+
134
+ # WEIGHT SCALE
135
+ if self.strategy == QuantizationStrategy.CHANNEL:
136
+ weight_scale = ChannelQuantScaleParameter(
137
+ data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
138
+ output_dim=0,
139
+ weight_loader=weight_loader,
140
+ )
141
+ else:
142
+ assert self.strategy == QuantizationStrategy.TENSOR
143
+ weight_scale = PerTensorScaleParameter(
144
+ data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
145
+ weight_loader=weight_loader,
146
+ )
147
+ layer.register_parameter("weight_scale", weight_scale)
148
+
149
+ # INPUT SCALE
150
+ if self.is_static_input_scheme:
151
+ input_scale = PerTensorScaleParameter(
152
+ data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
153
+ )
154
+ layer.register_parameter("input_scale", input_scale)
155
+
156
+ if not self.input_symmetric:
157
+ # Note: compressed-tensors stores the zp using the same dtype
158
+ # as the weights
159
+ # AZP loaded as int8 but used as int32
160
+ input_zero_point = PerTensorScaleParameter(
161
+ data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
162
+ )
163
+ layer.register_parameter("input_zero_point", input_zero_point)
164
+
165
+ def apply_weights(
166
+ self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
167
+ ) -> torch.Tensor:
168
+ # TODO: add cutlass_scaled_mm_azp support
169
+ x_q, x_scale = per_token_quant_int8(x)
170
+
171
+ return int8_scaled_mm(
172
+ x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
173
+ )
@@ -1,8 +1,6 @@
1
1
  import logging
2
2
 
3
- import torch
4
-
5
- from sglang.srt.utils import get_bool_env_var, get_device_sm
3
+ from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
6
4
 
7
5
  logger = logging.getLogger(__name__)
8
6
 
@@ -15,18 +13,12 @@ def _compute_enable_deep_gemm():
15
13
  try:
16
14
  import deep_gemm
17
15
  except ImportError:
18
- logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
19
16
  return False
20
17
 
21
18
  return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
22
19
 
23
20
 
24
- def _is_blackwell_arch() -> bool:
25
- major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
26
- return major == 10
27
-
28
-
29
21
  ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
30
22
 
31
- DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
23
+ DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell()
32
24
  DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
@@ -358,8 +358,8 @@ class Fp8LinearMethod(LinearMethodBase):
358
358
  return
359
359
  else:
360
360
  weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
361
- layer.weight = Parameter(weight, requires_grad=False)
362
- layer.weight_scale_inv = Parameter(weight_scale, requires_grad=False)
361
+ layer.weight.data = weight.data
362
+ layer.weight_scale_inv.data = weight_scale.data
363
363
  else:
364
364
  layer.weight = Parameter(layer.weight.data, requires_grad=False)
365
365
 
@@ -732,7 +732,7 @@ def apply_fp8_linear(
732
732
  # final solution should be: 1. add support to per-tensor activation scaling.
733
733
  # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
734
734
  if _is_hip and weight_scale.numel() == 1:
735
- qinput, x_scale = ops.scaled_fp8_quant(
735
+ qinput, x_scale = scaled_fp8_quant(
736
736
  input_2d,
737
737
  input_scale,
738
738
  use_per_token_if_dynamic=use_per_token_if_dynamic,
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
47
47
  CombineInput,
48
48
  StandardDispatchOutput,
49
49
  )
50
+ from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
50
51
 
51
52
  if is_cuda():
52
53
  from sgl_kernel import scaled_fp4_quant
@@ -77,6 +78,13 @@ logger = logging.getLogger(__name__)
77
78
  CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
78
79
  "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
79
80
  )
81
+ USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
82
+ "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
83
+ )
84
+ # TODO make it true by default when the DeepEP PR is merged
85
+ CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
86
+ "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
87
+ )
80
88
 
81
89
  # Supported activation schemes for the current configuration
82
90
  ACTIVATION_SCHEMES = ["static"]
@@ -844,14 +852,25 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
844
852
  if enable_flashinfer_fp4_gemm:
845
853
  w = layer.weight.T
846
854
  w_scale_interleaved = layer.weight_scale_interleaved.T
847
- out = fp4_gemm(
848
- x_fp4,
849
- w,
850
- x_scale_interleaved,
851
- w_scale_interleaved,
852
- layer.alpha,
853
- output_dtype,
854
- )
855
+ if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
856
+ out = fp4_gemm(
857
+ x_fp4,
858
+ w,
859
+ x_scale_interleaved,
860
+ w_scale_interleaved,
861
+ layer.alpha,
862
+ output_dtype,
863
+ backend="cutlass",
864
+ )
865
+ else:
866
+ out = fp4_gemm(
867
+ x_fp4,
868
+ w,
869
+ x_scale_interleaved,
870
+ w_scale_interleaved,
871
+ layer.alpha,
872
+ output_dtype,
873
+ )
855
874
  if bias is not None:
856
875
  out = out + bias
857
876
  return out.view(*output_shape)
@@ -1220,6 +1239,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
1220
1239
 
1221
1240
  w13_input_scale = _slice_scale(w13_input_scale)
1222
1241
  w2_input_scale = _slice_scale(w2_input_scale)
1242
+
1243
+ if CUTEDSL_MOE_NVFP4_DISPATCH:
1244
+ assert torch.all(w13_input_scale == w13_input_scale[0])
1245
+ w13_input_scale = w13_input_scale[0]
1223
1246
  else:
1224
1247
  w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
1225
1248
  w2_input_scale = layer.w2_input_scale
@@ -1446,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
1446
1469
  x: torch.Tensor,
1447
1470
  masked_m: torch.Tensor,
1448
1471
  moe_runner_config: MoeRunnerConfig,
1472
+ down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
1449
1473
  ) -> torch.Tensor:
1450
1474
  assert (
1451
1475
  moe_runner_config.activation == "silu"
@@ -1462,7 +1486,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
1462
1486
 
1463
1487
  out = flashinfer_cutedsl_moe_masked(
1464
1488
  hidden_states=x,
1465
- input_global_scale=layer.w13_input_scale_quant,
1489
+ input_global_scale=(
1490
+ None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
1491
+ ),
1466
1492
  w1=layer.w13_weight,
1467
1493
  w1_blockscale=layer.w13_blockscale_swizzled,
1468
1494
  w1_alpha=layer.g1_alphas,
@@ -1471,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
1471
1497
  w2_blockscale=layer.w2_blockscale_swizzled,
1472
1498
  w2_alpha=layer.g2_alphas,
1473
1499
  masked_m=masked_m,
1500
+ **(
1501
+ dict(
1502
+ down_sm_count=down_gemm_overlap_args.num_sms,
1503
+ down_signals=down_gemm_overlap_args.signal,
1504
+ down_start_event=down_gemm_overlap_args.start_event,
1505
+ )
1506
+ if down_gemm_overlap_args is not None
1507
+ else {}
1508
+ ),
1474
1509
  )
1475
1510
  return out
@@ -731,8 +731,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
731
731
  quant_info = TritonMoeQuantInfo(
732
732
  w13_weight=layer.w13_weight,
733
733
  w2_weight=layer.w2_weight,
734
- w13_weight_bias=layer.w13_weight_bias,
735
- w2_weight_bias=layer.w2_weight_bias,
734
+ b13=getattr(layer, "w13_weight_bias", None),
735
+ b2=getattr(layer, "w2_weight_bias", None),
736
736
  )
737
737
  return self.runner.run(dispatch_output, quant_info)
738
738
 
@@ -843,10 +843,18 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
843
843
  topk_weights = topk_weights.to(
844
844
  torch.float32
845
845
  ) # aiter's moe_sorting requires topk_weights to be FP32
846
+
847
+ if hasattr(torch, "float4_e2m1fn_x2"):
848
+ w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
849
+ w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
850
+ else:
851
+ w13_weight = layer.w13_weight
852
+ w2_weight = layer.w2_weight
853
+
846
854
  output = fused_moe(
847
855
  x,
848
- layer.w13_weight,
849
- layer.w2_weight,
856
+ w13_weight,
857
+ w2_weight,
850
858
  topk_weights,
851
859
  topk_ids,
852
860
  quant_type=QuantType.per_1x32,
@@ -12,7 +12,7 @@ from aiter.utility.fp4_utils import e8m0_shuffle
12
12
 
13
13
  from sglang.srt.layers.moe import MoeRunnerConfig
14
14
  from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
15
- from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
15
+ from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
16
16
 
17
17
  if TYPE_CHECKING:
18
18
  from sglang.srt.layers.moe.token_dispatcher import (
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
23
23
 
24
24
  logger = logging.getLogger(__name__)
25
25
 
26
+ _is_hip = is_hip()
27
+
26
28
  __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
27
29
 
28
30
  OCP_MX_BLOCK_SIZE = 32
@@ -182,11 +184,22 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
182
184
  topk_output = dispatch_output.topk_output
183
185
  moe_runner_config = self.moe_runner_config
184
186
  topk_weights, topk_ids, _ = topk_output
187
+ if _is_hip:
188
+ topk_weights = topk_weights.to(
189
+ torch.float32
190
+ ) # aiter's moe_sorting requires topk_weights to be FP32
191
+
192
+ if hasattr(torch, "float4_e2m1fn_x2"):
193
+ w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
194
+ w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
195
+ else:
196
+ w13_weight = layer.w13_weight
197
+ w2_weight = layer.w2_weight
185
198
 
186
199
  output = fused_moe(
187
200
  x,
188
- layer.w13_weight,
189
- layer.w2_weight,
201
+ w13_weight,
202
+ w2_weight,
190
203
  topk_weights,
191
204
  topk_ids,
192
205
  quant_type=QuantType.per_1x32,
@@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
19
19
  from sglang.srt.layers.quantization.utils import is_layer_skipped
20
20
  from sglang.srt.utils import is_npu, set_weight_attrs
21
21
 
22
- _is_npu = is_npu()
23
- if not _is_npu:
24
- from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
25
-
26
22
  if TYPE_CHECKING:
27
23
  from sglang.srt.layers.moe import MoeRunnerConfig
28
24
  from sglang.srt.layers.moe.ep_moe.layer import EPMoE
@@ -393,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
393
393
  x.dtype,
394
394
  True, # is_vnni
395
395
  )
396
-
397
396
  x_q, x_scale = per_token_quant_int8(x)
398
397
 
399
- return int8_scaled_mm(
400
- x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
398
+ x_q_2d = x_q.view(-1, x_q.shape[-1])
399
+ x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
400
+ output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
401
+
402
+ output = int8_scaled_mm(
403
+ x_q_2d,
404
+ layer.weight,
405
+ x_scale_2d,
406
+ layer.weight_scale,
407
+ out_dtype=x.dtype,
408
+ bias=bias,
401
409
  )
402
410
 
411
+ return output.view(output_shape)
412
+
403
413
 
404
414
  class W8A8Int8MoEMethod(FusedMoEMethodBase):
405
415
  """MoE method for INT8.
@@ -638,6 +648,7 @@ class NPU_W8A8LinearMethodImpl:
638
648
  layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
639
649
  layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
640
650
  layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
651
+ layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
641
652
 
642
653
 
643
654
  class NPU_W8A8LinearMethodMTImpl:
@@ -830,6 +841,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
830
841
  layer.weight_scale.data = layer.weight_scale.data.flatten()
831
842
  layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
832
843
  layer.weight_offset.data = layer.weight_offset.data.flatten()
844
+ layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
833
845
 
834
846
 
835
847
  class NPU_W8A8DynamicLinearMethod(LinearMethodBase):