sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,37 +1,30 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
-
import asyncio
|
4
3
|
import dataclasses
|
5
4
|
import logging
|
6
|
-
import
|
7
|
-
import socket
|
5
|
+
import os
|
8
6
|
import struct
|
9
7
|
import threading
|
8
|
+
import time
|
10
9
|
import uuid
|
11
10
|
from collections import defaultdict
|
12
|
-
from
|
13
|
-
from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
|
11
|
+
from typing import Dict, List, Optional, Set
|
14
12
|
|
15
13
|
import numpy as np
|
16
14
|
import numpy.typing as npt
|
17
15
|
import requests
|
18
|
-
import zmq
|
19
|
-
from aiohttp import web
|
20
16
|
|
21
|
-
from sglang.srt.disaggregation.base.conn import
|
17
|
+
from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
|
22
18
|
from sglang.srt.disaggregation.common.conn import (
|
23
19
|
CommonKVBootstrapServer,
|
24
20
|
CommonKVManager,
|
25
21
|
CommonKVReceiver,
|
22
|
+
CommonKVSender,
|
26
23
|
)
|
27
24
|
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
28
25
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
29
26
|
from sglang.srt.server_args import ServerArgs
|
30
|
-
from sglang.srt.utils import
|
31
|
-
format_tcp_address,
|
32
|
-
get_local_ip_auto,
|
33
|
-
is_valid_ipv6_address,
|
34
|
-
)
|
27
|
+
from sglang.srt.utils import get_int_env_var
|
35
28
|
|
36
29
|
logger = logging.getLogger(__name__)
|
37
30
|
|
@@ -113,8 +106,14 @@ class TransferStatus:
|
|
113
106
|
def is_done(self):
|
114
107
|
if self.num_kvs_expected is None:
|
115
108
|
return False
|
109
|
+
# Check for failure state
|
110
|
+
if self.num_kvs_expected == -1:
|
111
|
+
return True # Failed transfers are considered "done"
|
116
112
|
return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
|
117
113
|
|
114
|
+
def is_failed(self):
|
115
|
+
return self.num_kvs_expected == -1
|
116
|
+
|
118
117
|
|
119
118
|
class NixlKVManager(CommonKVManager):
|
120
119
|
def __init__(
|
@@ -134,26 +133,133 @@ class NixlKVManager(CommonKVManager):
|
|
134
133
|
"to run SGLang with NixlTransferEngine."
|
135
134
|
) from e
|
136
135
|
self.agent = nixl_agent(str(uuid.uuid4()))
|
137
|
-
self.local_ip = get_local_ip_auto()
|
138
|
-
self.server_socket = zmq.Context().socket(zmq.PULL)
|
139
|
-
if is_valid_ipv6_address(self.local_ip):
|
140
|
-
self.server_socket.setsockopt(zmq.IPV6, 1)
|
141
136
|
self.register_buffer_to_engine()
|
142
137
|
|
143
138
|
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
144
|
-
self.request_status: Dict[int, KVPoll] = {}
|
145
|
-
self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
|
146
|
-
self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
|
147
139
|
self._start_bootstrap_thread()
|
148
140
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
149
141
|
self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
|
150
142
|
TransferStatus
|
151
143
|
)
|
144
|
+
self.heartbeat_failures = {}
|
145
|
+
self.session_pool = defaultdict(requests.Session)
|
146
|
+
self.session_pool_lock = threading.Lock()
|
147
|
+
self.addr_to_rooms_tracker = defaultdict(set)
|
148
|
+
self.connection_lock = threading.Lock()
|
149
|
+
|
150
|
+
# Heartbeat interval should be at least 2 seconds
|
151
|
+
self.heartbeat_interval = max(
|
152
|
+
float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
|
153
|
+
)
|
154
|
+
# Heartbeat failure should be at least 1
|
155
|
+
self.max_failures = max(
|
156
|
+
get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
|
157
|
+
)
|
158
|
+
self._start_heartbeat_checker_thread()
|
152
159
|
else:
|
153
160
|
raise ValueError(
|
154
161
|
f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
|
155
162
|
)
|
156
163
|
|
164
|
+
def _start_heartbeat_checker_thread(self):
|
165
|
+
"""
|
166
|
+
Start the heartbeat checker thread for Decode worker.
|
167
|
+
TODO (smor): unite nixl heartbeat checker with mooncake's.
|
168
|
+
"""
|
169
|
+
|
170
|
+
def heartbeat_checker():
|
171
|
+
while True:
|
172
|
+
time.sleep(self.heartbeat_interval)
|
173
|
+
with self.connection_lock:
|
174
|
+
addresses = list(self.prefill_dp_size_table.keys())
|
175
|
+
|
176
|
+
for bootstrap_addr in addresses:
|
177
|
+
session = None
|
178
|
+
try:
|
179
|
+
with self.session_pool_lock:
|
180
|
+
session = self.session_pool[bootstrap_addr]
|
181
|
+
response = session.get(
|
182
|
+
f"http://{bootstrap_addr}/health",
|
183
|
+
timeout=(2, 3),
|
184
|
+
headers={"Connection": "keep-alive"},
|
185
|
+
)
|
186
|
+
if response.status_code == 200:
|
187
|
+
self.heartbeat_failures[bootstrap_addr] = 0
|
188
|
+
|
189
|
+
current_rooms = self.addr_to_rooms_tracker[
|
190
|
+
bootstrap_addr
|
191
|
+
].copy()
|
192
|
+
|
193
|
+
for bootstrap_room in current_rooms:
|
194
|
+
# Remove successful transfers from the tracker
|
195
|
+
if bootstrap_room not in self.transfer_statuses:
|
196
|
+
self.addr_to_rooms_tracker[bootstrap_addr].discard(
|
197
|
+
bootstrap_room
|
198
|
+
)
|
199
|
+
else:
|
200
|
+
logger.info(
|
201
|
+
f"Attempting to reconnect to {bootstrap_addr}..."
|
202
|
+
)
|
203
|
+
self.heartbeat_failures[bootstrap_addr] = (
|
204
|
+
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
|
205
|
+
)
|
206
|
+
with self.session_pool_lock:
|
207
|
+
if bootstrap_addr in self.session_pool:
|
208
|
+
del self.session_pool[bootstrap_addr]
|
209
|
+
except Exception:
|
210
|
+
logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
|
211
|
+
self.heartbeat_failures[bootstrap_addr] = (
|
212
|
+
self.heartbeat_failures.get(bootstrap_addr, 0) + 1
|
213
|
+
)
|
214
|
+
|
215
|
+
if (
|
216
|
+
self.heartbeat_failures.get(bootstrap_addr, 0)
|
217
|
+
>= self.max_failures
|
218
|
+
):
|
219
|
+
self._handle_node_failure(bootstrap_addr)
|
220
|
+
with self.session_pool_lock:
|
221
|
+
if bootstrap_addr in self.session_pool:
|
222
|
+
del self.session_pool[bootstrap_addr]
|
223
|
+
|
224
|
+
threading.Thread(target=heartbeat_checker, daemon=True).start()
|
225
|
+
|
226
|
+
def _handle_node_failure(self, failed_bootstrap_addr):
|
227
|
+
"""Handle failure of a prefill node."""
|
228
|
+
with self.connection_lock:
|
229
|
+
keys_to_remove = [
|
230
|
+
k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
|
231
|
+
]
|
232
|
+
for k in keys_to_remove:
|
233
|
+
del self.connection_pool[k]
|
234
|
+
if failed_bootstrap_addr in self.prefill_tp_size_table:
|
235
|
+
del self.prefill_tp_size_table[failed_bootstrap_addr]
|
236
|
+
if failed_bootstrap_addr in self.prefill_dp_size_table:
|
237
|
+
del self.prefill_dp_size_table[failed_bootstrap_addr]
|
238
|
+
if failed_bootstrap_addr in self.prefill_pp_size_table:
|
239
|
+
del self.prefill_pp_size_table[failed_bootstrap_addr]
|
240
|
+
|
241
|
+
possible_affected_rooms = self.addr_to_rooms_tracker.get(
|
242
|
+
failed_bootstrap_addr, []
|
243
|
+
)
|
244
|
+
if failed_bootstrap_addr in self.addr_to_rooms_tracker:
|
245
|
+
del self.addr_to_rooms_tracker[failed_bootstrap_addr]
|
246
|
+
|
247
|
+
# Mark all pending transfers associated with the failed node as failed
|
248
|
+
affected_rooms = []
|
249
|
+
for room in possible_affected_rooms:
|
250
|
+
if (
|
251
|
+
room in self.transfer_statuses
|
252
|
+
and not self.transfer_statuses[room].is_done()
|
253
|
+
):
|
254
|
+
# Mark the transfer as failed by setting a special state
|
255
|
+
self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
|
256
|
+
affected_rooms.append(room)
|
257
|
+
|
258
|
+
logger.error(
|
259
|
+
f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
|
260
|
+
f"{len(affected_rooms)} transfers affected"
|
261
|
+
)
|
262
|
+
|
157
263
|
def check_status(self, bootstrap_room: int):
|
158
264
|
return self.request_status[bootstrap_room]
|
159
265
|
|
@@ -166,6 +272,9 @@ class NixlKVManager(CommonKVManager):
|
|
166
272
|
self.request_status[bootstrap_room], status
|
167
273
|
)
|
168
274
|
|
275
|
+
def record_failure(self, bootstrap_room: int, failure_reason: str):
|
276
|
+
pass
|
277
|
+
|
169
278
|
def register_buffer_to_engine(self):
|
170
279
|
kv_addrs = []
|
171
280
|
for kv_data_ptr, kv_data_len in zip(
|
@@ -438,7 +547,7 @@ class NixlKVManager(CommonKVManager):
|
|
438
547
|
notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
|
439
548
|
decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
|
440
549
|
|
441
|
-
if decode_tp_size == self.
|
550
|
+
if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
|
442
551
|
kv_xfer_handle = self.send_kvcache(
|
443
552
|
req.agent_name,
|
444
553
|
kv_indices,
|
@@ -455,7 +564,7 @@ class NixlKVManager(CommonKVManager):
|
|
455
564
|
chunked_dst_kv_indice,
|
456
565
|
self.decode_kv_args_table[req.agent_name].gpu_id,
|
457
566
|
notif,
|
458
|
-
prefill_tp_size=self.
|
567
|
+
prefill_tp_size=self.attn_tp_size,
|
459
568
|
decode_tp_size=decode_tp_size,
|
460
569
|
decode_tp_rank=self.decode_kv_args_table[
|
461
570
|
req.agent_name
|
@@ -505,9 +614,6 @@ class NixlKVManager(CommonKVManager):
|
|
505
614
|
return False
|
506
615
|
return self.transfer_statuses[room].is_done()
|
507
616
|
|
508
|
-
def _bind_server_socket(self):
|
509
|
-
self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
|
510
|
-
|
511
617
|
def _start_bootstrap_thread(self):
|
512
618
|
self._bind_server_socket()
|
513
619
|
|
@@ -548,7 +654,7 @@ class NixlKVManager(CommonKVManager):
|
|
548
654
|
threading.Thread(target=bootstrap_thread).start()
|
549
655
|
|
550
656
|
|
551
|
-
class NixlKVSender(
|
657
|
+
class NixlKVSender(CommonKVSender):
|
552
658
|
|
553
659
|
def __init__(
|
554
660
|
self,
|
@@ -558,20 +664,10 @@ class NixlKVSender(BaseKVSender):
|
|
558
664
|
dest_tp_ranks: List[int],
|
559
665
|
pp_rank: int,
|
560
666
|
):
|
561
|
-
|
562
|
-
self.bootstrap_room = bootstrap_room
|
563
|
-
self.aux_index = None
|
564
|
-
self.bootstrap_server_url = bootstrap_addr
|
667
|
+
super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
|
565
668
|
self.xfer_handles = []
|
566
669
|
self.has_sent = False
|
567
670
|
self.chunk_id = 0
|
568
|
-
self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
|
569
|
-
# inner state
|
570
|
-
self.curr_idx = 0
|
571
|
-
|
572
|
-
def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
|
573
|
-
self.num_kv_indices = num_kv_indices
|
574
|
-
self.aux_index = aux_index
|
575
671
|
|
576
672
|
def send(
|
577
673
|
self,
|
@@ -621,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
621
717
|
self.conclude_state = None
|
622
718
|
super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
|
623
719
|
|
720
|
+
# Track this room with its bootstrap address for heartbeat monitoring
|
721
|
+
if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
|
722
|
+
self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
|
723
|
+
self.bootstrap_room
|
724
|
+
)
|
725
|
+
|
624
726
|
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
625
727
|
for bootstrap_info in self.bootstrap_infos:
|
626
728
|
logger.debug(
|
@@ -655,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
655
757
|
|
656
758
|
self.kv_mgr.update_transfer_status()
|
657
759
|
if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
|
658
|
-
|
760
|
+
# Check if the transfer failed
|
761
|
+
if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
|
762
|
+
self.conclude_state = KVPoll.Failed
|
763
|
+
logger.error(
|
764
|
+
f"Transfer for room {self.bootstrap_room} failed due to node failure"
|
765
|
+
)
|
766
|
+
else:
|
767
|
+
self.conclude_state = KVPoll.Success
|
659
768
|
del self.kv_mgr.transfer_statuses[self.bootstrap_room]
|
660
|
-
return
|
769
|
+
return self.conclude_state # type: ignore
|
661
770
|
return KVPoll.WaitingForInput # type: ignore
|
662
771
|
|
663
772
|
def _register_kv_args(self):
|
@@ -21,6 +21,7 @@ from __future__ import annotations
|
|
21
21
|
|
22
22
|
import logging
|
23
23
|
import threading
|
24
|
+
import time
|
24
25
|
from collections import deque
|
25
26
|
from http import HTTPStatus
|
26
27
|
from typing import TYPE_CHECKING, List, Optional, Type
|
@@ -42,7 +43,12 @@ from sglang.srt.disaggregation.utils import (
|
|
42
43
|
poll_and_all_reduce,
|
43
44
|
prepare_abort,
|
44
45
|
)
|
45
|
-
from sglang.srt.managers.schedule_batch import
|
46
|
+
from sglang.srt.managers.schedule_batch import (
|
47
|
+
FINISH_LENGTH,
|
48
|
+
Req,
|
49
|
+
RequestStage,
|
50
|
+
ScheduleBatch,
|
51
|
+
)
|
46
52
|
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
47
53
|
from sglang.srt.utils import (
|
48
54
|
DynamicGradMode,
|
@@ -170,6 +176,7 @@ class PrefillBootstrapQueue:
|
|
170
176
|
pp_rank=self.pp_rank,
|
171
177
|
)
|
172
178
|
self._process_req(req)
|
179
|
+
req.add_latency(RequestStage.PREFILL_PREPARE)
|
173
180
|
self.queue.append(req)
|
174
181
|
|
175
182
|
def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
|
@@ -256,8 +263,11 @@ class PrefillBootstrapQueue:
|
|
256
263
|
|
257
264
|
num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
|
258
265
|
req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
|
266
|
+
|
259
267
|
bootstrapped_reqs.append(req)
|
260
268
|
indices_to_remove.add(i)
|
269
|
+
req.time_stats.wait_queue_entry_time = time.perf_counter()
|
270
|
+
req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
|
261
271
|
|
262
272
|
self.queue = [
|
263
273
|
entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
|
@@ -399,11 +409,11 @@ class SchedulerDisaggregationPrefillMixin:
|
|
399
409
|
for i, (req, next_token_id) in enumerate(
|
400
410
|
zip(batch.reqs, next_token_ids, strict=True)
|
401
411
|
):
|
402
|
-
req: Req
|
403
412
|
if req.is_chunked <= 0:
|
404
413
|
# There is no output_ids for prefill
|
405
414
|
req.output_ids.append(next_token_id)
|
406
415
|
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
416
|
+
req.add_latency(RequestStage.PREFILL_FORWARD)
|
407
417
|
self.disagg_prefill_inflight_queue.append(req)
|
408
418
|
if (
|
409
419
|
logits_output is not None
|
@@ -412,9 +422,16 @@ class SchedulerDisaggregationPrefillMixin:
|
|
412
422
|
last_hidden_index = (
|
413
423
|
hidden_state_offset + extend_input_len_per_req[i] - 1
|
414
424
|
)
|
415
|
-
req.
|
416
|
-
|
417
|
-
)
|
425
|
+
req.output_topk_p = batch.spec_info.topk_p[i]
|
426
|
+
req.output_topk_index = batch.spec_info.topk_index[i]
|
427
|
+
if self.spec_algorithm.is_eagle3():
|
428
|
+
req.hidden_states_tensor = (
|
429
|
+
batch.spec_info.hidden_states[i].cpu().clone()
|
430
|
+
)
|
431
|
+
else:
|
432
|
+
req.hidden_states_tensor = (
|
433
|
+
logits_output.hidden_states[last_hidden_index].cpu().clone()
|
434
|
+
)
|
418
435
|
hidden_state_offset += extend_input_len_per_req[i]
|
419
436
|
else:
|
420
437
|
req.hidden_states_tensor = None
|
@@ -434,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
434
451
|
)
|
435
452
|
logprob_pt += num_input_logprobs
|
436
453
|
self.send_kv_chunk(req, last_chunk=True)
|
454
|
+
req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
|
437
455
|
|
438
456
|
if req.grammar is not None:
|
439
457
|
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
@@ -531,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
|
|
531
549
|
else:
|
532
550
|
assert False, f"Unexpected polling state {poll=}"
|
533
551
|
|
552
|
+
for req in done_reqs:
|
553
|
+
req.time_stats.completion_time = time.perf_counter()
|
554
|
+
|
534
555
|
# Stream requests which have finished transfer
|
535
556
|
self.stream_output(
|
536
557
|
done_reqs,
|
@@ -539,6 +560,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
539
560
|
)
|
540
561
|
for req in done_reqs:
|
541
562
|
req: Req
|
563
|
+
req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
|
542
564
|
self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
|
543
565
|
req.metadata_buffer_index = -1
|
544
566
|
|
@@ -667,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
667
689
|
self.running_mbs = [
|
668
690
|
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
669
691
|
]
|
670
|
-
bids = [None] * self.pp_size
|
671
692
|
pp_outputs: Optional[PPProxyTensors] = None
|
672
693
|
|
673
694
|
# Either success or failed
|
@@ -739,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
739
760
|
# send the outputs to the next step
|
740
761
|
if self.pp_group.is_last_rank:
|
741
762
|
if self.cur_batch:
|
742
|
-
next_token_ids
|
743
|
-
result.next_token_ids,
|
744
|
-
result.bid,
|
745
|
-
)
|
763
|
+
next_token_ids = result.next_token_ids
|
746
764
|
pp_outputs = PPProxyTensors(
|
747
765
|
{
|
748
766
|
"next_token_ids": next_token_ids,
|
@@ -779,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
779
797
|
next_token_ids=next_pp_outputs["next_token_ids"],
|
780
798
|
extend_input_len_per_req=None,
|
781
799
|
extend_logprob_start_len_per_req=None,
|
782
|
-
bid=bids[next_mb_id],
|
783
800
|
can_run_cuda_graph=result.can_run_cuda_graph,
|
784
801
|
)
|
785
802
|
self.process_batch_result_disagg_prefill(
|
@@ -796,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
|
|
796
813
|
|
797
814
|
# carry the outputs to the next stage
|
798
815
|
if not self.pp_group.is_last_rank:
|
799
|
-
if self.cur_batch:
|
800
|
-
bids[mb_id] = result.bid
|
801
816
|
if pp_outputs:
|
802
817
|
# send the outputs from the last round to let the next stage worker run post processing
|
803
818
|
self.pp_group.send_tensor_dict(
|
@@ -816,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
|
|
816
831
|
|
817
832
|
# send out proxy tensors to the next stage
|
818
833
|
if self.cur_batch:
|
834
|
+
# FIXME(lsyin): remove this assert
|
835
|
+
assert result.pp_hidden_states_proxy_tensors.tensors is not None
|
819
836
|
self.pp_group.send_tensor_dict(
|
820
|
-
result.pp_hidden_states_proxy_tensors,
|
837
|
+
result.pp_hidden_states_proxy_tensors.tensors,
|
821
838
|
all_gather_group=self.attn_tp_group,
|
822
839
|
)
|
823
840
|
|
@@ -5,7 +5,7 @@ import random
|
|
5
5
|
from collections import deque
|
6
6
|
from contextlib import nullcontext
|
7
7
|
from enum import Enum
|
8
|
-
from typing import TYPE_CHECKING,
|
8
|
+
from typing import TYPE_CHECKING, Optional, Type
|
9
9
|
|
10
10
|
import numpy as np
|
11
11
|
import torch
|
@@ -85,7 +85,7 @@ class MetadataBuffers:
|
|
85
85
|
self,
|
86
86
|
size: int,
|
87
87
|
hidden_size: int,
|
88
|
-
|
88
|
+
hidden_states_dtype: torch.dtype,
|
89
89
|
max_top_logprobs_num: int = 128,
|
90
90
|
custom_mem_pool: torch.cuda.MemPool = None,
|
91
91
|
):
|
@@ -107,7 +107,9 @@ class MetadataBuffers:
|
|
107
107
|
# We transfer the metadata of first output token to decode
|
108
108
|
# The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
|
109
109
|
self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
|
110
|
-
|
110
|
+
self.cached_tokens = torch.zeros(
|
111
|
+
(size, 16), dtype=torch.int32, device=device
|
112
|
+
)
|
111
113
|
self.output_token_logprobs_val = torch.zeros(
|
112
114
|
(size, 16), dtype=torch.float32, device=device
|
113
115
|
)
|
@@ -120,33 +122,49 @@ class MetadataBuffers:
|
|
120
122
|
self.output_top_logprobs_idx = torch.zeros(
|
121
123
|
(size, max_top_logprobs_num), dtype=torch.int32, device=device
|
122
124
|
)
|
125
|
+
# For PD + spec decode
|
126
|
+
self.output_topk_p = torch.zeros(
|
127
|
+
(size, 16), dtype=torch.float32, device=device
|
128
|
+
)
|
129
|
+
self.output_topk_index = torch.zeros(
|
130
|
+
(size, 16), dtype=torch.int64, device=device
|
131
|
+
)
|
123
132
|
self.output_hidden_states = torch.zeros(
|
124
|
-
(size, hidden_size), dtype=
|
133
|
+
(size, hidden_size), dtype=hidden_states_dtype, device=device
|
125
134
|
)
|
126
135
|
|
127
136
|
def get_buf_infos(self):
|
128
137
|
ptrs = [
|
129
138
|
self.output_ids.data_ptr(),
|
139
|
+
self.cached_tokens.data_ptr(),
|
130
140
|
self.output_token_logprobs_val.data_ptr(),
|
131
141
|
self.output_token_logprobs_idx.data_ptr(),
|
132
142
|
self.output_top_logprobs_val.data_ptr(),
|
133
143
|
self.output_top_logprobs_idx.data_ptr(),
|
144
|
+
self.output_topk_p.data_ptr(),
|
145
|
+
self.output_topk_index.data_ptr(),
|
134
146
|
self.output_hidden_states.data_ptr(),
|
135
147
|
]
|
136
148
|
data_lens = [
|
137
149
|
self.output_ids.nbytes,
|
150
|
+
self.cached_tokens.nbytes,
|
138
151
|
self.output_token_logprobs_val.nbytes,
|
139
152
|
self.output_token_logprobs_idx.nbytes,
|
140
153
|
self.output_top_logprobs_val.nbytes,
|
141
154
|
self.output_top_logprobs_idx.nbytes,
|
155
|
+
self.output_topk_p.nbytes,
|
156
|
+
self.output_topk_index.nbytes,
|
142
157
|
self.output_hidden_states.nbytes,
|
143
158
|
]
|
144
159
|
item_lens = [
|
145
160
|
self.output_ids[0].nbytes,
|
161
|
+
self.cached_tokens[0].nbytes,
|
146
162
|
self.output_token_logprobs_val[0].nbytes,
|
147
163
|
self.output_token_logprobs_idx[0].nbytes,
|
148
164
|
self.output_top_logprobs_val[0].nbytes,
|
149
165
|
self.output_top_logprobs_idx[0].nbytes,
|
166
|
+
self.output_topk_p[0].nbytes,
|
167
|
+
self.output_topk_index[0].nbytes,
|
150
168
|
self.output_hidden_states[0].nbytes,
|
151
169
|
]
|
152
170
|
return ptrs, data_lens, item_lens
|
@@ -154,16 +172,20 @@ class MetadataBuffers:
|
|
154
172
|
def get_buf(self, idx: int):
|
155
173
|
return (
|
156
174
|
self.output_ids[idx],
|
175
|
+
self.cached_tokens[idx],
|
157
176
|
self.output_token_logprobs_val[idx],
|
158
177
|
self.output_token_logprobs_idx[idx],
|
159
178
|
self.output_top_logprobs_val[idx],
|
160
179
|
self.output_top_logprobs_idx[idx],
|
180
|
+
self.output_topk_p[idx],
|
181
|
+
self.output_topk_index[idx],
|
161
182
|
self.output_hidden_states[idx],
|
162
183
|
)
|
163
184
|
|
164
185
|
def set_buf(self, req: Req):
|
165
186
|
|
166
187
|
self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
|
188
|
+
self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
|
167
189
|
if req.return_logprob:
|
168
190
|
if req.output_token_logprobs_val: # not none or empty list
|
169
191
|
self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
|
@@ -186,8 +208,17 @@ class MetadataBuffers:
|
|
186
208
|
] = torch.tensor(
|
187
209
|
req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
|
188
210
|
)
|
189
|
-
#
|
211
|
+
# For PD + spec decode
|
190
212
|
if req.hidden_states_tensor is not None:
|
213
|
+
# speculative_eagle_topk should not be greater than 16 currently
|
214
|
+
topk = req.output_topk_p.size(0)
|
215
|
+
|
216
|
+
self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
|
217
|
+
req.output_topk_p
|
218
|
+
)
|
219
|
+
self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
|
220
|
+
req.output_topk_index
|
221
|
+
)
|
191
222
|
self.output_hidden_states[req.metadata_buffer_index].copy_(
|
192
223
|
req.hidden_states_tensor
|
193
224
|
)
|
@@ -0,0 +1,16 @@
|
|
1
|
+
MiB = 1024 * 1024
|
2
|
+
|
3
|
+
SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
|
4
|
+
9: {
|
5
|
+
2: 64 * MiB, # 64 MB
|
6
|
+
4: 32 * MiB, # 32 MB
|
7
|
+
6: 64 * MiB, # 64 MB
|
8
|
+
8: 64 * MiB, # 64 MB
|
9
|
+
},
|
10
|
+
10: {
|
11
|
+
2: 64 * MiB, # 64 MB
|
12
|
+
4: 32 * MiB, # 32 MB
|
13
|
+
6: 128 * MiB, # 128 MB
|
14
|
+
8: 128 * MiB, # 128 MB
|
15
|
+
},
|
16
|
+
}
|
@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
|
|
18
18
|
|
19
19
|
from sglang.srt.utils import (
|
20
20
|
format_tcp_address,
|
21
|
-
|
21
|
+
get_local_ip_auto,
|
22
22
|
get_open_port,
|
23
23
|
is_valid_ipv6_address,
|
24
24
|
)
|
@@ -191,7 +191,9 @@ class MessageQueue:
|
|
191
191
|
self.n_remote_reader = n_remote_reader
|
192
192
|
|
193
193
|
if connect_ip is None:
|
194
|
-
connect_ip =
|
194
|
+
connect_ip = (
|
195
|
+
get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
|
196
|
+
)
|
195
197
|
|
196
198
|
context = Context()
|
197
199
|
|