sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/grpc/sglang_scheduler_pb2.pyi
CHANGED
@@ -3,7 +3,6 @@ import datetime
 from google.protobuf import timestamp_pb2 as _timestamp_pb2
 from google.protobuf import struct_pb2 as _struct_pb2
 from google.protobuf.internal import containers as _containers
-from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper
 from google.protobuf import descriptor as _descriptor
 from google.protobuf import message as _message
 from collections.abc import Iterable as _Iterable, Mapping as _Mapping
@@ -12,7 +11,7 @@ from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union
 DESCRIPTOR: _descriptor.FileDescriptor

 class SamplingParams(_message.Message):
-    __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "
+    __slots__ = ("temperature", "top_p", "top_k", "min_p", "frequency_penalty", "presence_penalty", "repetition_penalty", "max_new_tokens", "stop", "stop_token_ids", "skip_special_tokens", "spaces_between_special_tokens", "regex", "json_schema", "ebnf_grammar", "structural_tag", "lora_path", "n", "token_healing", "min_new_tokens", "ignore_eos", "no_stop_trim", "stream_interval", "logit_bias", "custom_params")
     class LogitBiasEntry(_message.Message):
         __slots__ = ("key", "value")
         KEY_FIELD_NUMBER: _ClassVar[int]
@@ -35,6 +34,7 @@ class SamplingParams(_message.Message):
     REGEX_FIELD_NUMBER: _ClassVar[int]
     JSON_SCHEMA_FIELD_NUMBER: _ClassVar[int]
     EBNF_GRAMMAR_FIELD_NUMBER: _ClassVar[int]
+    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
     LORA_PATH_FIELD_NUMBER: _ClassVar[int]
     N_FIELD_NUMBER: _ClassVar[int]
     TOKEN_HEALING_FIELD_NUMBER: _ClassVar[int]
@@ -43,7 +43,6 @@ class SamplingParams(_message.Message):
     NO_STOP_TRIM_FIELD_NUMBER: _ClassVar[int]
     STREAM_INTERVAL_FIELD_NUMBER: _ClassVar[int]
     LOGIT_BIAS_FIELD_NUMBER: _ClassVar[int]
-    STRUCTURAL_TAG_FIELD_NUMBER: _ClassVar[int]
     CUSTOM_PARAMS_FIELD_NUMBER: _ClassVar[int]
     temperature: float
     top_p: float
@@ -60,6 +59,7 @@ class SamplingParams(_message.Message):
     regex: str
     json_schema: str
     ebnf_grammar: str
+    structural_tag: str
     lora_path: str
     n: int
     token_healing: bool
@@ -68,9 +68,8 @@ class SamplingParams(_message.Message):
     no_stop_trim: bool
     stream_interval: int
     logit_bias: _containers.ScalarMap[str, float]
-    structural_tag: str
     custom_params: _struct_pb2.Struct
-    def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ...,
+    def __init__(self, temperature: _Optional[float] = ..., top_p: _Optional[float] = ..., top_k: _Optional[int] = ..., min_p: _Optional[float] = ..., frequency_penalty: _Optional[float] = ..., presence_penalty: _Optional[float] = ..., repetition_penalty: _Optional[float] = ..., max_new_tokens: _Optional[int] = ..., stop: _Optional[_Iterable[str]] = ..., stop_token_ids: _Optional[_Iterable[int]] = ..., skip_special_tokens: bool = ..., spaces_between_special_tokens: bool = ..., regex: _Optional[str] = ..., json_schema: _Optional[str] = ..., ebnf_grammar: _Optional[str] = ..., structural_tag: _Optional[str] = ..., lora_path: _Optional[str] = ..., n: _Optional[int] = ..., token_healing: bool = ..., min_new_tokens: _Optional[int] = ..., ignore_eos: bool = ..., no_stop_trim: bool = ..., stream_interval: _Optional[int] = ..., logit_bias: _Optional[_Mapping[str, float]] = ..., custom_params: _Optional[_Union[_struct_pb2.Struct, _Mapping]] = ...) -> None: ...

 class DisaggregatedParams(_message.Message):
     __slots__ = ("bootstrap_host", "bootstrap_port", "bootstrap_room")
@@ -83,7 +82,7 @@
     def __init__(self, bootstrap_host: _Optional[str] = ..., bootstrap_port: _Optional[int] = ..., bootstrap_room: _Optional[int] = ...) -> None: ...

 class GenerateRequest(_message.Message):
-    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "
+    __slots__ = ("request_id", "tokenized", "mm_inputs", "sampling_params", "return_logprob", "logprob_start_len", "top_logprobs_num", "token_ids_logprob", "return_hidden_states", "disaggregated_params", "custom_logit_processor", "timestamp", "log_metrics", "input_embeds", "lora_id", "data_parallel_rank", "stream")
     REQUEST_ID_FIELD_NUMBER: _ClassVar[int]
     TOKENIZED_FIELD_NUMBER: _ClassVar[int]
     MM_INPUTS_FIELD_NUMBER: _ClassVar[int]
@@ -100,7 +99,7 @@
     INPUT_EMBEDS_FIELD_NUMBER: _ClassVar[int]
     LORA_ID_FIELD_NUMBER: _ClassVar[int]
     DATA_PARALLEL_RANK_FIELD_NUMBER: _ClassVar[int]
-
+    STREAM_FIELD_NUMBER: _ClassVar[int]
     request_id: str
     tokenized: TokenizedInput
     mm_inputs: MultimodalInputs
@@ -117,8 +116,8 @@
     input_embeds: _containers.RepeatedScalarFieldContainer[float]
     lora_id: str
     data_parallel_rank: int
-
-    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ...,
+    stream: bool
+    def __init__(self, request_id: _Optional[str] = ..., tokenized: _Optional[_Union[TokenizedInput, _Mapping]] = ..., mm_inputs: _Optional[_Union[MultimodalInputs, _Mapping]] = ..., sampling_params: _Optional[_Union[SamplingParams, _Mapping]] = ..., return_logprob: bool = ..., logprob_start_len: _Optional[int] = ..., top_logprobs_num: _Optional[int] = ..., token_ids_logprob: _Optional[_Iterable[int]] = ..., return_hidden_states: bool = ..., disaggregated_params: _Optional[_Union[DisaggregatedParams, _Mapping]] = ..., custom_logit_processor: _Optional[str] = ..., timestamp: _Optional[_Union[datetime.datetime, _timestamp_pb2.Timestamp, _Mapping]] = ..., log_metrics: bool = ..., input_embeds: _Optional[_Iterable[float]] = ..., lora_id: _Optional[str] = ..., data_parallel_rank: _Optional[int] = ..., stream: bool = ...) -> None: ...

 class TokenizedInput(_message.Message):
     __slots__ = ("original_text", "input_ids")
@@ -161,52 +160,50 @@ class GenerateResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., chunk: _Optional[_Union[GenerateStreamChunk, _Mapping]] = ..., complete: _Optional[_Union[GenerateComplete, _Mapping]] = ..., error: _Optional[_Union[GenerateError, _Mapping]] = ...) -> None: ...

 class GenerateStreamChunk(_message.Message):
-    __slots__ = ("
-
-    TEXT_FIELD_NUMBER: _ClassVar[int]
+    __slots__ = ("token_ids", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "hidden_states", "input_logprobs", "index")
+    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
-
+    OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
-
-
-
-    text: str
+    INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    INDEX_FIELD_NUMBER: _ClassVar[int]
+    token_ids: _containers.RepeatedScalarFieldContainer[int]
     prompt_tokens: int
     completion_tokens: int
     cached_tokens: int
-
+    output_logprobs: OutputLogProbs
     hidden_states: _containers.RepeatedScalarFieldContainer[float]
-
-
-    def __init__(self,
+    input_logprobs: InputLogProbs
+    index: int
+    def __init__(self, token_ids: _Optional[_Iterable[int]] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., hidden_states: _Optional[_Iterable[float]] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...

 class GenerateComplete(_message.Message):
-    __slots__ = ("output_ids", "
-    class FinishReason(int, metaclass=_enum_type_wrapper.EnumTypeWrapper):
-        __slots__ = ()
-        STOP: _ClassVar[GenerateComplete.FinishReason]
-        LENGTH: _ClassVar[GenerateComplete.FinishReason]
-        EOS_TOKEN: _ClassVar[GenerateComplete.FinishReason]
-        STOP_STR: _ClassVar[GenerateComplete.FinishReason]
-        ABORT: _ClassVar[GenerateComplete.FinishReason]
-    STOP: GenerateComplete.FinishReason
-    LENGTH: GenerateComplete.FinishReason
-    EOS_TOKEN: GenerateComplete.FinishReason
-    STOP_STR: GenerateComplete.FinishReason
-    ABORT: GenerateComplete.FinishReason
+    __slots__ = ("output_ids", "finish_reason", "prompt_tokens", "completion_tokens", "cached_tokens", "output_logprobs", "all_hidden_states", "matched_token_id", "matched_stop_str", "input_logprobs", "index")
     OUTPUT_IDS_FIELD_NUMBER: _ClassVar[int]
-    OUTPUT_TEXT_FIELD_NUMBER: _ClassVar[int]
     FINISH_REASON_FIELD_NUMBER: _ClassVar[int]
-
+    PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
+    COMPLETION_TOKENS_FIELD_NUMBER: _ClassVar[int]
+    CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
+    OUTPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     ALL_HIDDEN_STATES_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_TOKEN_ID_FIELD_NUMBER: _ClassVar[int]
+    MATCHED_STOP_STR_FIELD_NUMBER: _ClassVar[int]
+    INPUT_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    INDEX_FIELD_NUMBER: _ClassVar[int]
     output_ids: _containers.RepeatedScalarFieldContainer[int]
-
-
-
+    finish_reason: str
+    prompt_tokens: int
+    completion_tokens: int
+    cached_tokens: int
+    output_logprobs: OutputLogProbs
     all_hidden_states: _containers.RepeatedCompositeFieldContainer[HiddenStates]
-
+    matched_token_id: int
+    matched_stop_str: str
+    input_logprobs: InputLogProbs
+    index: int
+    def __init__(self, output_ids: _Optional[_Iterable[int]] = ..., finish_reason: _Optional[str] = ..., prompt_tokens: _Optional[int] = ..., completion_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., output_logprobs: _Optional[_Union[OutputLogProbs, _Mapping]] = ..., all_hidden_states: _Optional[_Iterable[_Union[HiddenStates, _Mapping]]] = ..., matched_token_id: _Optional[int] = ..., matched_stop_str: _Optional[str] = ..., input_logprobs: _Optional[_Union[InputLogProbs, _Mapping]] = ..., index: _Optional[int] = ...) -> None: ...

 class GenerateError(_message.Message):
     __slots__ = ("message", "http_status_code", "details")
@@ -218,27 +215,39 @@ class GenerateError(_message.Message):
     details: str
     def __init__(self, message: _Optional[str] = ..., http_status_code: _Optional[str] = ..., details: _Optional[str] = ...) -> None: ...

-class 
-    __slots__ = ("token_logprobs", "token_ids", "top_logprobs"
+class OutputLogProbs(_message.Message):
+    __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
     TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
     TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
     TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
-    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
     token_logprobs: _containers.RepeatedScalarFieldContainer[float]
     token_ids: _containers.RepeatedScalarFieldContainer[int]
     top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
-
-
+    def __init__(self, token_logprobs: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
+
+class InputLogProbs(_message.Message):
+    __slots__ = ("token_logprobs", "token_ids", "top_logprobs")
+    TOKEN_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
+    TOP_LOGPROBS_FIELD_NUMBER: _ClassVar[int]
+    token_logprobs: _containers.RepeatedCompositeFieldContainer[InputTokenLogProb]
+    token_ids: _containers.RepeatedScalarFieldContainer[int]
+    top_logprobs: _containers.RepeatedCompositeFieldContainer[TopLogProbs]
+    def __init__(self, token_logprobs: _Optional[_Iterable[_Union[InputTokenLogProb, _Mapping]]] = ..., token_ids: _Optional[_Iterable[int]] = ..., top_logprobs: _Optional[_Iterable[_Union[TopLogProbs, _Mapping]]] = ...) -> None: ...
+
+class InputTokenLogProb(_message.Message):
+    __slots__ = ("value",)
+    VALUE_FIELD_NUMBER: _ClassVar[int]
+    value: float
+    def __init__(self, value: _Optional[float] = ...) -> None: ...

 class TopLogProbs(_message.Message):
-    __slots__ = ("values", "token_ids"
+    __slots__ = ("values", "token_ids")
     VALUES_FIELD_NUMBER: _ClassVar[int]
     TOKEN_IDS_FIELD_NUMBER: _ClassVar[int]
-    TOKEN_TEXTS_FIELD_NUMBER: _ClassVar[int]
     values: _containers.RepeatedScalarFieldContainer[float]
     token_ids: _containers.RepeatedScalarFieldContainer[int]
-
-    def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ..., token_texts: _Optional[_Iterable[str]] = ...) -> None: ...
+    def __init__(self, values: _Optional[_Iterable[float]] = ..., token_ids: _Optional[_Iterable[int]] = ...) -> None: ...

 class HiddenStates(_message.Message):
     __slots__ = ("values", "layer", "position")
@@ -283,20 +292,18 @@ class EmbedResponse(_message.Message):
     def __init__(self, request_id: _Optional[str] = ..., complete: _Optional[_Union[EmbedComplete, _Mapping]] = ..., error: _Optional[_Union[EmbedError, _Mapping]] = ...) -> None: ...

 class EmbedComplete(_message.Message):
-    __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "
+    __slots__ = ("embedding", "prompt_tokens", "cached_tokens", "embedding_dim", "batch_embeddings")
     EMBEDDING_FIELD_NUMBER: _ClassVar[int]
     PROMPT_TOKENS_FIELD_NUMBER: _ClassVar[int]
     CACHED_TOKENS_FIELD_NUMBER: _ClassVar[int]
     EMBEDDING_DIM_FIELD_NUMBER: _ClassVar[int]
-    GENERATION_TIME_FIELD_NUMBER: _ClassVar[int]
     BATCH_EMBEDDINGS_FIELD_NUMBER: _ClassVar[int]
     embedding: _containers.RepeatedScalarFieldContainer[float]
     prompt_tokens: int
     cached_tokens: int
     embedding_dim: int
-    generation_time: float
     batch_embeddings: _containers.RepeatedCompositeFieldContainer[Embedding]
-    def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ...,
+    def __init__(self, embedding: _Optional[_Iterable[float]] = ..., prompt_tokens: _Optional[int] = ..., cached_tokens: _Optional[int] = ..., embedding_dim: _Optional[int] = ..., batch_embeddings: _Optional[_Iterable[_Union[Embedding, _Mapping]]] = ...) -> None: ...

 class Embedding(_message.Message):
     __slots__ = ("values", "index")
sglang/srt/layers/activation.py
CHANGED
@@ -224,12 +224,13 @@ class XIELU(CustomOp):
                 self._xielu_cuda_fn = self._xielu_cuda
                 logger.warning_once(msg)
             except Exception as err:
-                logger.warning_once(
-                    "CUDA-fused xIELU not available (%s) –"
-                    " falling back to a Python version.\n"
-                    "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
-                    str(err),
-                )
+                pass
+                # logger.warning_once(
+                #     "CUDA-fused xIELU not available (%s) –"
+                #     " falling back to a Python version.\n"
+                #     "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+                #     str(err),
+                # )

     def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
         alpha_p = nn.functional.softplus(self.alpha_p)
sglang/srt/layers/attention/aiter_backend.py
CHANGED
@@ -4,18 +4,13 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """

-import math
-import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional

 import torch
 import triton
-import triton.language as tl

-from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput

 try:
     from aiter import (
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -619,7 +614,11 @@ class AiterAttnBackend(AttentionBackend):
         assert len(k.shape) == 3
         assert len(v.shape) == 3

-        if
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
             if kv_indices.shape[0] == 0:
                 o = flash_attn_varlen_func(
                     q,
@@ -884,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -896,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):

         kv_start_idx = None
@@ -980,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -993,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         bs = len(req_pool_indices)

@@ -1050,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices

         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
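Two themes run through this file's diff: speculative-decoding metadata is now typed against the generic SpecInput protocol from sglang.srt.speculative.spec_info (with helpers such as generate_draft_decode_kv_indices relocated to the new spec_utils module), and the prefill dispatch is tightened. The added is_target_verify()/is_draft_extend() checks only make sense if is_extend() also reports true for those speculative modes, so the new guard exists to keep speculative batches off the plain varlen-prefill path. A toy sketch of that dispatch logic; the enum below is illustrative, not sglang's actual ForwardMode:

```python
from enum import Enum, auto

class Mode(Enum):
    EXTEND = auto()
    TARGET_VERIFY = auto()  # speculative: verify draft tokens on the target model
    DRAFT_EXTEND = auto()   # speculative: extend the draft model
    DECODE = auto()

    def is_extend(self) -> bool:
        # Mirrors the semantics the guard relies on: speculative modes
        # still count as "extend" for batch-layout purposes.
        return self in (Mode.EXTEND, Mode.TARGET_VERIFY, Mode.DRAFT_EXTEND)

    def is_target_verify(self) -> bool:
        return self is Mode.TARGET_VERIFY

    def is_draft_extend(self) -> bool:
        return self is Mode.DRAFT_EXTEND

def pick_branch(mode: Mode) -> str:
    # The tightened guard from the hunk above: only a plain prefill takes
    # the varlen-prefill kernel; speculative shapes fall through.
    if mode.is_extend() and not mode.is_target_verify() and not mode.is_draft_extend():
        return "varlen-prefill"
    return "decode-or-speculative"

assert pick_branch(Mode.EXTEND) == "varlen-prefill"
assert pick_branch(Mode.TARGET_VERIFY) == "decode-or-speculative"
assert pick_branch(Mode.DECODE) == "decode-or-speculative"
```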
sglang/srt/layers/attention/ascend_backend.py
CHANGED
@@ -5,14 +5,15 @@ from typing import TYPE_CHECKING, List, Optional

 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var

 if TYPE_CHECKING:
@@ -35,6 +36,8 @@ class ForwardMetadata:
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
     seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None


 class AscendAttnBackend(AttentionBackend):
@@ -66,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
         self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -101,10 +107,6 @@ class AscendAttnBackend(AttentionBackend):
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()

         seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
-        if forward_batch.is_extend_in_batch:
-            seq_lens_list_cumsum[-1] = (
-                (seq_lens_list_cumsum[-1] - 1) // tp_size + 1
-            ) * tp_size
         self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum

         self.graph_mode = False
@@ -126,12 +128,16 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()

         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )

         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -146,7 +152,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -160,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)

+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata

         self.graph_mode = True
@@ -167,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0

+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int

+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -175,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -401,7 +483,7 @@ class AscendAttnBackend(AttentionBackend):
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)

         torch_npu.npu_fused_infer_attention_score.out(
@@ -436,7 +518,24 @@ class AscendAttnBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )

         if self.graph_mode:
             return self.forward_decode_graph(
                 q,