sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries and is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/parameter.py
CHANGED
```diff
@@ -7,6 +7,7 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.layers.utils import pad_or_narrow_weight
 from sglang.srt.utils import is_cpu
 
 __all__ = [
@@ -156,9 +157,17 @@ class _ColumnvLLMParameter(BasevLLMParameter):
             )
         else:
             if not use_presharded_weights:
-
-
-
+                # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+                start_idx = tp_rank * shard_size
+                end_idx = start_idx + shard_size
+                if end_idx > loaded_weight.shape[self.output_dim]:
+                    loaded_weight = pad_or_narrow_weight(
+                        loaded_weight, self.output_dim, start_idx, shard_size
+                    )
+                else:
+                    loaded_weight = loaded_weight.narrow(
+                        self.output_dim, start_idx, shard_size
+                    )
 
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
@@ -258,9 +267,17 @@ class RowvLLMParameter(BasevLLMParameter):
 
             return
         else:
-
-
-
+            # Padding for special case like qwen2_5_VL's mlp which is not 8-aligned
+            start_idx = tp_rank * shard_size
+            end_idx = start_idx + shard_size
+            if end_idx > loaded_weight.shape[self.input_dim]:
+                loaded_weight = pad_or_narrow_weight(
+                    loaded_weight, self.input_dim, start_idx, shard_size
+                )
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, start_idx, shard_size
+                )
 
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
```
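The helper `pad_or_narrow_weight` is newly added in `sglang/srt/layers/utils.py` (+23 lines in the file list above), but its body is not part of this hunk. Below is a minimal sketch of what such a helper could look like, assuming it zero-pads the missing tail of the final tensor-parallel shard; only the call signature is taken from the diff, the internals and the zero-padding strategy are assumptions.

```python
import torch
import torch.nn.functional as F


def pad_or_narrow_weight(
    loaded_weight: torch.Tensor, dim: int, start_idx: int, shard_size: int
) -> torch.Tensor:
    """Take a TP shard along `dim`; zero-pad the tail if the tensor is too short.

    Hypothetical re-implementation for illustration only.
    Assumes start_idx <= loaded_weight.shape[dim].
    """
    available = loaded_weight.shape[dim] - start_idx
    if available >= shard_size:
        return loaded_weight.narrow(dim, start_idx, shard_size)
    # Keep what exists, then zero-pad the end of `dim` up to shard_size.
    shard = loaded_weight.narrow(dim, start_idx, available)
    pad = [0, 0] * loaded_weight.dim()
    # F.pad's pad tuple runs from the last dimension backwards.
    pad[2 * (loaded_weight.dim() - 1 - dim) + 1] = shard_size - available
    return F.pad(shard, pad)
```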
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
CHANGED
```diff
@@ -30,6 +30,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A8Int8,
     CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
```
sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
CHANGED
```diff
@@ -2,10 +2,12 @@
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a8_int8 import CompressedTensorsW8A8Int8
 from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
 
 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
     "CompressedTensorsW8A16Fp8",
+    "CompressedTensorsW8A8Int8",
 ]
```
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
ADDED
```diff
@@ -0,0 +1,173 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+from torch.nn import Parameter
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.layers.quantization.utils import requantize_with_max_scale
+from sglang.srt.utils import is_cuda
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import int8_scaled_mm
+
+
+class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
+
+    def __init__(
+        self, strategy: str, is_static_input_scheme: bool, input_symmetric: bool
+    ):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+        self.input_symmetric = input_symmetric
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # lovelace and up
+        return 89
+
+    def process_weights_after_loading(self, layer) -> None:
+        # If per tensor, when we have a fused module (e.g. QKV) with per
+        # tensor scales (thus N scales being passed to the kernel),
+        # requantize so we can always run per channel
+        if self.strategy == QuantizationStrategy.TENSOR:
+            max_w_scale, weight = requantize_with_max_scale(
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                logical_widths=layer.logical_widths,
+            )
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            layer.weight_scale = Parameter(max_w_scale, requires_grad=False)
+
+        # If channelwise, scales are already lined up, so just transpose.
+        elif self.strategy == QuantizationStrategy.CHANNEL:
+            weight = layer.weight
+            weight_scale = layer.weight_scale.data
+
+            layer.weight = Parameter(weight.t(), requires_grad=False)
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = Parameter(weight_scale, requires_grad=False)
+
+        else:
+            raise ValueError(f"Unknown quantization strategy {self.strategy}")
+
+        # INPUT SCALE
+        if self.is_static_input_scheme and hasattr(layer, "input_scale"):
+            if self.input_symmetric:
+                layer.input_scale = Parameter(
+                    layer.input_scale.max(), requires_grad=False
+                )
+            else:
+                input_scale = layer.input_scale
+                input_zero_point = layer.input_zero_point
+
+                # reconstruct the ranges
+                int8_traits = torch.iinfo(torch.int8)
+                azps = input_zero_point.to(dtype=torch.int32)
+                range_max = (input_scale * (int8_traits.max - azps)).max()
+                range_min = (input_scale * (int8_traits.min - azps)).min()
+
+                scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
+
+                # AZP loaded as int8 but used as int32
+                azp = (int8_traits.min - range_min / scale).to(dtype=torch.int32)
+
+                layer.input_scale = Parameter(scale, requires_grad=False)
+                layer.input_zero_point = Parameter(azp, requires_grad=False)
+        else:
+            layer.input_scale = None
+            layer.input_zero_point = None
+
+        # azp_adj is the AZP adjustment term, used to account for weights.
+        # It does not depend on scales or azp, so it is the same for
+        # static and dynamic quantization.
+        # For more details, see csrc/quantization/cutlass_w8a8/Epilogues.md
+        # https://github.com/vllm-project/vllm/blob/8d59dbb00044a588cab96bcdc028006ed922eb06/csrc/quantization/cutlass_w8a8/Epilogues.md
+        if not self.input_symmetric:
+            weight = layer.weight
+            azp_adj = weight.sum(dim=0, keepdim=True, dtype=torch.int32)
+            if self.is_static_input_scheme:
+                # cutlass_w8a8 requires azp to be folded into azp_adj
+                # in the per-tensor case
+                azp_adj = layer.input_zero_point * azp_adj
+            layer.azp_adj = Parameter(azp_adj, requires_grad=False)
+        else:
+            layer.azp_adj = None
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        output_partition_sizes: list[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition, input_size_per_partition, dtype=torch.int8
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        else:
+            assert self.strategy == QuantizationStrategy.TENSOR
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(1, dtype=torch.float32), weight_loader=weight_loader
+            )
+            layer.register_parameter("input_scale", input_scale)
+
+            if not self.input_symmetric:
+                # Note: compressed-tensors stores the zp using the same dtype
+                # as the weights
+                # AZP loaded as int8 but used as int32
+                input_zero_point = PerTensorScaleParameter(
+                    data=torch.empty(1, dtype=torch.int8), weight_loader=weight_loader
+                )
+                layer.register_parameter("input_zero_point", input_zero_point)
+
+    def apply_weights(
+        self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor]
+    ) -> torch.Tensor:
+        # TODO: add cutlass_scaled_mm_azp support
+        x_q, x_scale = per_token_quant_int8(x)
+
+        return int8_scaled_mm(
+            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
+        )
```
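The asymmetric static-input branch in `process_weights_after_loading` merges several per-sub-module `(input_scale, input_zero_point)` pairs into a single scale and zero point that cover the union of their real-valued ranges. A self-contained numeric sketch of that arithmetic (the values are made up):

```python
import torch

int8_traits = torch.iinfo(torch.int8)  # min=-128, max=127

# Example per-sub-module asymmetric quantization parameters.
input_scale = torch.tensor([0.02, 0.05, 0.03])
input_zero_point = torch.tensor([3, -7, 0], dtype=torch.int8)

azps = input_zero_point.to(torch.int32)
range_max = (input_scale * (int8_traits.max - azps)).max()  # largest representable real value
range_min = (input_scale * (int8_traits.min - azps)).min()  # smallest representable real value

# One combined scale/zero-point pair that spans the merged range.
scale = (range_max - range_min) / (int8_traits.max - int8_traits.min)
azp = (int8_traits.min - range_min / scale).to(torch.int32)

print(scale.item(), azp.item())  # ~0.05 and -7 for these example values
```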
sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
CHANGED
```diff
@@ -1,8 +1,6 @@
 import logging
 
-import torch
-
-from sglang.srt.utils import get_bool_env_var, get_device_sm
+from sglang.srt.utils import get_bool_env_var, get_device_sm, is_blackwell
 
 logger = logging.getLogger(__name__)
 
@@ -15,18 +13,12 @@ def _compute_enable_deep_gemm():
     try:
         import deep_gemm
     except ImportError:
-        logger.warning("Failed to import deep_gemm, disable ENABLE_JIT_DEEPGEMM.")
         return False
 
     return get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true")
 
 
-def _is_blackwell_arch() -> bool:
-    major, minor = torch.cuda.get_device_capability(torch.cuda.current_device())
-    return major == 10
-
-
 ENABLE_JIT_DEEPGEMM = _compute_enable_deep_gemm()
 
-DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and _is_blackwell_arch()
+DEEPGEMM_BLACKWELL = ENABLE_JIT_DEEPGEMM and is_blackwell()
 DEEPGEMM_SCALE_UE8M0 = DEEPGEMM_BLACKWELL
```
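`is_blackwell()` now comes from `sglang.srt.utils` instead of the local `_is_blackwell_arch` helper. A minimal sketch of what the shared helper presumably checks, mirroring the removed code; the real utility may add extra guards (e.g. CUDA availability), so treat this as an assumption:

```python
import torch


def is_blackwell() -> bool:
    # Blackwell GPUs report compute capability major == 10,
    # exactly what the removed _is_blackwell_arch checked.
    major, _minor = torch.cuda.get_device_capability(torch.cuda.current_device())
    return major == 10
```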
sglang/srt/layers/quantization/fp8.py
CHANGED
```diff
@@ -358,8 +358,8 @@ class Fp8LinearMethod(LinearMethodBase):
                 return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
-                layer.weight =
-                layer.weight_scale_inv =
+                layer.weight.data = weight.data
+                layer.weight_scale_inv.data = weight_scale.data
         else:
             layer.weight = Parameter(layer.weight.data, requires_grad=False)
 
```
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
```diff
@@ -732,7 +732,7 @@ def apply_fp8_linear(
     # final solution should be: 1. add support to per-tensor activation scaling.
     # 2. solve the torch.compile error from weight_scale.numel() == 1 and x_scale.numel() > 1 (below line#308)
     if _is_hip and weight_scale.numel() == 1:
-        qinput, x_scale =
+        qinput, x_scale = scaled_fp8_quant(
             input_2d,
             input_scale,
             use_per_token_if_dynamic=use_per_token_if_dynamic,
```
sglang/srt/layers/quantization/modelopt_quant.py
CHANGED
```diff
@@ -47,6 +47,7 @@ if TYPE_CHECKING:
         CombineInput,
         StandardDispatchOutput,
     )
+    from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
 
 if is_cuda():
     from sgl_kernel import scaled_fp4_quant
@@ -77,6 +78,13 @@ logger = logging.getLogger(__name__)
 CUTEDSL_MOE_SCALAR_INPUT_SCALE = get_bool_env_var(
     "SGLANG_CUTEDSL_MOE_SCALAR_INPUT_SCALE", "true"
 )
+USE_CUTLASS_BACKEND_FOR_FP4_GEMM = get_bool_env_var(
+    "SGLANG_USE_CUTLASS_BACKEND_FOR_FP4_GEMM"
+)
+# TODO make it true by default when the DeepEP PR is merged
+CUTEDSL_MOE_NVFP4_DISPATCH = get_bool_env_var(
+    "SGLANG_CUTEDSL_MOE_NVFP4_DISPATCH", "false"
+)
 
 # Supported activation schemes for the current configuration
 ACTIVATION_SCHEMES = ["static"]
@@ -844,14 +852,25 @@ class ModelOptFp4LinearMethod(LinearMethodBase):
         if enable_flashinfer_fp4_gemm:
             w = layer.weight.T
             w_scale_interleaved = layer.weight_scale_interleaved.T
-
-
-
-
-
-
-
-
+        if USE_CUTLASS_BACKEND_FOR_FP4_GEMM:
+            out = fp4_gemm(
+                x_fp4,
+                w,
+                x_scale_interleaved,
+                w_scale_interleaved,
+                layer.alpha,
+                output_dtype,
+                backend="cutlass",
+            )
+        else:
+            out = fp4_gemm(
+                x_fp4,
+                w,
+                x_scale_interleaved,
+                w_scale_interleaved,
+                layer.alpha,
+                output_dtype,
+            )
         if bias is not None:
             out = out + bias
         return out.view(*output_shape)
@@ -1220,6 +1239,10 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
             w13_input_scale = _slice_scale(w13_input_scale)
             w2_input_scale = _slice_scale(w2_input_scale)
+
+            if CUTEDSL_MOE_NVFP4_DISPATCH:
+                assert torch.all(w13_input_scale == w13_input_scale[0])
+                w13_input_scale = w13_input_scale[0]
         else:
             w13_input_scale = layer.w13_input_scale.max(dim=1).values.to(torch.float32)
             w2_input_scale = layer.w2_input_scale
@@ -1446,6 +1469,7 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
         x: torch.Tensor,
         masked_m: torch.Tensor,
         moe_runner_config: MoeRunnerConfig,
+        down_gemm_overlap_args: Optional["DownGemmOverlapArgs"],
     ) -> torch.Tensor:
         assert (
             moe_runner_config.activation == "silu"
@@ -1462,7 +1486,9 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
 
         out = flashinfer_cutedsl_moe_masked(
             hidden_states=x,
-            input_global_scale=
+            input_global_scale=(
+                None if CUTEDSL_MOE_NVFP4_DISPATCH else layer.w13_input_scale_quant
+            ),
             w1=layer.w13_weight,
             w1_blockscale=layer.w13_blockscale_swizzled,
             w1_alpha=layer.g1_alphas,
@@ -1471,5 +1497,14 @@ class ModelOptNvFp4FusedMoEMethod(FusedMoEMethodBase):
             w2_blockscale=layer.w2_blockscale_swizzled,
             w2_alpha=layer.g2_alphas,
             masked_m=masked_m,
+            **(
+                dict(
+                    down_sm_count=down_gemm_overlap_args.num_sms,
+                    down_signals=down_gemm_overlap_args.signal,
+                    down_start_event=down_gemm_overlap_args.start_event,
+                )
+                if down_gemm_overlap_args is not None
+                else {}
+            ),
         )
         return out
```
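The last hunk forwards the optional down-GEMM overlap arguments through a conditional `**` splat, so the kernel call stays valid when no overlap config is supplied. A standalone sketch of that pattern; the dataclass fields are inferred from the attribute names used above and are otherwise assumptions:

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class DownGemmOverlapArgs:
    # Field names inferred from the call site (num_sms, signal, start_event).
    num_sms: int
    signal: Any = None
    start_event: Any = None


def call_kernel(overlap: Optional[DownGemmOverlapArgs], **kwargs) -> dict:
    # Build the extra keyword arguments only when an overlap config exists.
    extra = (
        dict(
            down_sm_count=overlap.num_sms,
            down_signals=overlap.signal,
            down_start_event=overlap.start_event,
        )
        if overlap is not None
        else {}
    )
    # Stand-in for flashinfer_cutedsl_moe_masked(**kwargs, **extra).
    return {**kwargs, **extra}


print(call_kernel(None, masked_m=3))
print(call_kernel(DownGemmOverlapArgs(num_sms=16), masked_m=3))
```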
sglang/srt/layers/quantization/mxfp4.py
CHANGED
```diff
@@ -731,8 +731,8 @@ class Mxfp4MoEMethod(FusedMoEMethodBase):
         quant_info = TritonMoeQuantInfo(
             w13_weight=layer.w13_weight,
             w2_weight=layer.w2_weight,
-
-
+            b13=getattr(layer, "w13_weight_bias", None),
+            b2=getattr(layer, "w2_weight_bias", None),
         )
         return self.runner.run(dispatch_output, quant_info)
 
@@ -843,10 +843,18 @@ class Mxfp4DynamicQuantMoEMethod(FusedMoEMethodBase):
         topk_weights = topk_weights.to(
             torch.float32
         )  # aiter's moe_sorting requires topk_weights to be FP32
+
+        if hasattr(torch, "float4_e2m1fn_x2"):
+            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
+            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
+        else:
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
+
         output = fused_moe(
             x,
-
-
+            w13_weight,
+            w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
```
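Both this file and the Quark MoE path below guard the reinterpretation of packed FP4 weights behind `hasattr(torch, "float4_e2m1fn_x2")`, since that packed two-values-per-byte dtype only exists in recent PyTorch builds; older builds keep passing the raw storage through unchanged. Restated as a tiny helper for clarity:

```python
import torch


def as_packed_fp4(w: torch.Tensor) -> torch.Tensor:
    # torch.float4_e2m1fn_x2 packs two FP4 (e2m1) values per byte and is only
    # available in newer PyTorch releases, hence the hasattr guard.
    if hasattr(torch, "float4_e2m1fn_x2"):
        return w.view(torch.float4_e2m1fn_x2)
    return w
```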
sglang/srt/layers/quantization/quark/quark_moe.py
CHANGED
```diff
@@ -12,7 +12,7 @@ from aiter.utility.fp4_utils import e8m0_shuffle
 
 from sglang.srt.layers.moe import MoeRunnerConfig
 from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
-from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
+from sglang.srt.utils import get_bool_env_var, is_hip, mxfp_supported, set_weight_attrs
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -23,6 +23,8 @@ if TYPE_CHECKING:
 
 logger = logging.getLogger(__name__)
 
+_is_hip = is_hip()
+
 __all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
 
 OCP_MX_BLOCK_SIZE = 32
@@ -182,11 +184,22 @@ class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
         topk_output = dispatch_output.topk_output
         moe_runner_config = self.moe_runner_config
         topk_weights, topk_ids, _ = topk_output
+        if _is_hip:
+            topk_weights = topk_weights.to(
+                torch.float32
+            )  # aiter's moe_sorting requires topk_weights to be FP32
+
+        if hasattr(torch, "float4_e2m1fn_x2"):
+            w13_weight = layer.w13_weight.view(torch.float4_e2m1fn_x2)
+            w2_weight = layer.w2_weight.view(torch.float4_e2m1fn_x2)
+        else:
+            w13_weight = layer.w13_weight
+            w2_weight = layer.w2_weight
 
         output = fused_moe(
             x,
-
-
+            w13_weight,
+            w2_weight,
             topk_weights,
             topk_ids,
             quant_type=QuantType.per_1x32,
```
sglang/srt/layers/quantization/w4afp8.py
CHANGED
```diff
@@ -19,10 +19,6 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import is_npu, set_weight_attrs
 
-_is_npu = is_npu()
-if not _is_npu:
-    from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
-
 if TYPE_CHECKING:
     from sglang.srt.layers.moe import MoeRunnerConfig
     from sglang.srt.layers.moe.ep_moe.layer import EPMoE
```
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
```diff
@@ -393,13 +393,23 @@ class W8A8Int8LinearMethod(LinearMethodBase):
                 x.dtype,
                 True,  # is_vnni
             )
-
         x_q, x_scale = per_token_quant_int8(x)
 
-
-
+        x_q_2d = x_q.view(-1, x_q.shape[-1])
+        x_scale_2d = x_scale.view(-1, x_scale.shape[-1])
+        output_shape = [*x_q.shape[:-1], layer.weight.shape[1]]
+
+        output = int8_scaled_mm(
+            x_q_2d,
+            layer.weight,
+            x_scale_2d,
+            layer.weight_scale,
+            out_dtype=x.dtype,
+            bias=bias,
         )
 
+        return output.view(output_shape)
+
 
 class W8A8Int8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
@@ -638,6 +648,7 @@ class NPU_W8A8LinearMethodImpl:
         layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
         layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
         layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
 
 
 class NPU_W8A8LinearMethodMTImpl:
@@ -830,6 +841,7 @@ class NPU_W8A8DynamicLinearMethodImpl:
         layer.weight_scale.data = layer.weight_scale.data.flatten()
         layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
         layer.weight_offset.data = layer.weight_offset.data.flatten()
+        layer.weight.data = torch_npu.npu_format_cast(layer.weight.data, 29)
 
 
 class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
```
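The reshaping in the first hunk exists because `int8_scaled_mm` operates on 2-D inputs: arbitrary leading batch dimensions are flattened before the GEMM and restored afterwards. A tiny shape-only illustration, with a plain matmul standing in for the int8 kernel:

```python
import torch

x = torch.randn(2, 5, 64)          # (batch, seq, hidden)
weight = torch.randn(64, 128)      # (hidden, out)

x_2d = x.view(-1, x.shape[-1])     # (10, 64): flatten leading dims for the 2-D GEMM
out_2d = x_2d @ weight             # stands in for int8_scaled_mm(x_q_2d, weight, ...)
out = out_2d.view(*x.shape[:-1], weight.shape[1])  # restore to (2, 5, 128)
print(out.shape)
```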