sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -1,5 +1,7 @@
 from __future__ import annotations
 
+import enum
+
 # Copyright 2023-2024 SGLang Team
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -35,10 +37,11 @@ import copy
 import dataclasses
 import logging
 import threading
+import time
 from enum import Enum, auto
 from http import HTTPStatus
 from itertools import chain
-from typing import TYPE_CHECKING, Any,
+from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -51,6 +54,7 @@ from sglang.srt.disaggregation.base import BaseKVSender
 from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
@@ -58,10 +62,10 @@ from sglang.srt.mem_cache.allocator import (
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
-from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
 from sglang.srt.mem_cache.memory_pool import HybridReqToTokenPool, ReqToTokenPool
+from sglang.srt.mem_cache.radix_cache import RadixKey
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
-from sglang.srt.metrics.collector import TimeStats
+from sglang.srt.metrics.collector import SchedulerMetricsCollector, TimeStats
 from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -70,8 +74,7 @@ from sglang.srt.utils import flatten_nested_list, support_triton
 
 if TYPE_CHECKING:
     from sglang.srt.configs.model_config import ModelConfig
-    from sglang.srt.speculative.
-    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+    from sglang.srt.speculative.spec_info import SpecInput, SpeculativeAlgorithm
 
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
@@ -86,6 +89,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_flashinfer_cutlass_moe_fp4_allgather",
     "disable_radix_cache",
     "enable_dp_lm_head",
+    "enable_fp32_lm_head",
     "flashinfer_mxfp4_moe_precision",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
@@ -107,6 +111,9 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_symm_mem",
     "enable_custom_logit_processor",
     "disaggregation_mode",
+    "enable_deterministic_inference",
+    "nsa_prefill",
+    "nsa_decode",
 ]
 
 # Put some global args for easy access
@@ -407,6 +414,23 @@ class MultimodalInputs:
     # other args would be kept intact
 
 
+class RequestStage(str, enum.Enum):
+    # prefill
+    PREFILL_WAITING = "prefill_waiting"
+
+    # disaggregation prefill
+    PREFILL_PREPARE = "prefill_prepare"
+    PREFILL_BOOTSTRAP = "prefill_bootstrap"
+    PREFILL_FORWARD = "prefill_forward"
+    PREFILL_TRANSFER_KV_CACHE = "prefill_transfer_kv_cache"
+
+    # disaggregation decode
+    DECODE_PREPARE = "decode_prepare"
+    DECODE_BOOTSTRAP = "decode_bootstrap"
+    DECODE_WAITING = "decode_waiting"
+    DECODE_TRANSFERRED = "decode_transferred"
+
+
 class Req:
     """The input and output status of a request."""
 
@@ -431,8 +455,12 @@ class Req:
         bootstrap_host: Optional[str] = None,
         bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
+        disagg_mode: Optional[DisaggregationMode] = None,
         data_parallel_rank: Optional[int] = None,
         vocab_size: Optional[int] = None,
+        priority: Optional[int] = None,
+        metrics_collector: Optional[SchedulerMetricsCollector] = None,
+        extra_key: Optional[str] = None,
     ):
         # Input and output info
         self.rid = rid
@@ -465,6 +493,14 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
+
+        # extra key for classifying the request (e.g. cache_salt)
+        if lora_id is not None:
+            extra_key = (
+                extra_key or ""
+            ) + lora_id  # lora_id is concatenated to the extra key
+
+        self.extra_key = extra_key
         self.lora_id = lora_id
 
         # Memory pool info
@@ -483,6 +519,7 @@ class Req:
         self.stream = stream
         self.eos_token_ids = eos_token_ids
         self.vocab_size = vocab_size
+        self.priority = priority
 
         # For incremental decoding
         # ----- | --------- read_ids -------|
@@ -512,6 +549,8 @@ class Req:
         self.host_hit_length = 0
         # The node to lock until for swa radix tree lock ref
         self.swa_uuid_for_lock: Optional[int] = None
+        # The prefix length of the last prefix matching
+        self.last_matched_prefix_len: int = 0
 
         # Whether or not if it is chunked. It increments whenever
         # it is chunked, and decrement whenever chunked request is
@@ -573,6 +612,8 @@ class Req:
         ) = None
         self.hidden_states: List[List[float]] = []
         self.hidden_states_tensor = None  # Note: use tensor instead of list to transfer hidden_states when PD + MTP
+        self.output_topk_p = None
+        self.output_topk_index = None
 
         # Embedding (return values)
         self.embedding = None
@@ -590,10 +631,10 @@ class Req:
         self.spec_verify_ct = 0
 
         # For metrics
-        self.
+        self.metrics_collector = metrics_collector
+        self.time_stats: TimeStats = TimeStats(disagg_mode=disagg_mode)
         self.has_log_time_stats: bool = False
-        self.
-        self.queue_time_end = None
+        self.last_tic = time.monotonic()
 
         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
@@ -624,7 +665,21 @@ class Req:
     @property
    def is_prefill_only(self) -> bool:
         """Check if this request is prefill-only (no token generation needed)."""
-
+        # NOTE: when spec is enabled, prefill_only optimizations are disabled
+        return (
+            self.sampling_params.max_new_tokens == 0
+            and global_server_args_dict["speculative_algorithm"] is None
+        )
+
+    def add_latency(self, stage: RequestStage):
+        if self.metrics_collector is None:
+            return
+
+        now = time.monotonic()
+        self.metrics_collector.observe_per_stage_req_latency(
+            stage.value, now - self.last_tic
+        )
+        self.last_tic = now
 
     def extend_image_inputs(self, image_inputs):
         if self.multimodal_inputs is None:
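Note: the `RequestStage` enum and `Req.add_latency` added above implement a simple stopwatch for per-stage request latency: each call attributes the time elapsed since the previous call (or since construction) to the stage that just finished, then resets `last_tic`. A minimal, self-contained sketch of the pattern; `FakeCollector` is a stand-in for `SchedulerMetricsCollector`, and only the method name `observe_per_stage_req_latency` is taken from this diff:

    import time
    from collections import defaultdict

    class FakeCollector:
        """Stand-in for SchedulerMetricsCollector (assumption, not the real class)."""
        def __init__(self):
            self.per_stage = defaultdict(list)

        def observe_per_stage_req_latency(self, stage: str, latency: float):
            self.per_stage[stage].append(latency)

    class MiniReq:
        def __init__(self, collector):
            self.metrics_collector = collector
            self.last_tic = time.monotonic()  # stopwatch starts at construction

        def add_latency(self, stage: str):
            if self.metrics_collector is None:
                return  # metrics disabled: keep the hot path cheap
            now = time.monotonic()
            self.metrics_collector.observe_per_stage_req_latency(stage, now - self.last_tic)
            self.last_tic = now

    collector = FakeCollector()
    req = MiniReq(collector)
    time.sleep(0.01)                    # pretend the request waited in the queue
    req.add_latency("prefill_waiting")  # interval since construction -> waiting stage
    time.sleep(0.02)                    # pretend the prefill forward pass ran
    req.add_latency("prefill_forward")  # interval since last call -> forward stage
    print(dict(collector.per_stage))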
@@ -642,26 +697,17 @@ class Req:
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-
-
-
-
-
-
-
-
-
-
-
-            else:
-                (
-                    self.prefix_indices,
-                    self.last_node,
-                    self.last_host_node,
-                    self.host_hit_length,
-                ) = tree_cache.match_prefix(
-                    key=self.adjust_max_prefix_ids(),
-                )
+            (
+                self.prefix_indices,
+                self.last_node,
+                self.last_host_node,
+                self.host_hit_length,
+            ) = tree_cache.match_prefix(
+                key=RadixKey(
+                    token_ids=self.adjust_max_prefix_ids(), extra_key=self.extra_key
+                ),
+            )
+            self.last_matched_prefix_len = len(self.prefix_indices)
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
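Note: `match_prefix` now takes a `RadixKey(token_ids, extra_key)` instead of raw token ids, which is what lets this release delete `LoRARadixCache` (see the removed `lora_radix_cache.py` in the file list): requests with different `extra_key` values (a cache_salt, or the `lora_id` concatenated in the constructor above) simply live in disjoint cache namespaces. A rough dict-based illustration of that partitioning; the real structure is a radix tree, and everything below is a sketch, not sglang API:

    from typing import Optional

    class ToyPrefixCache:
        def __init__(self):
            self.store = {}  # (extra_key, token prefix) -> cached value

        def insert(self, token_ids, extra_key: Optional[str], value):
            self.store[(extra_key or "", tuple(token_ids))] = value

        def match_prefix(self, token_ids, extra_key: Optional[str]):
            # Longest stored prefix under the same extra_key; other keys are invisible.
            for end in range(len(token_ids), 0, -1):
                hit = self.store.get((extra_key or "", tuple(token_ids[:end])))
                if hit is not None:
                    return token_ids[:end], hit
            return [], None

    cache = ToyPrefixCache()
    cache.insert([1, 2, 3], extra_key=None, value="base-model KV")
    cache.insert([1, 2, 3], extra_key="lora-A", value="lora-A KV")

    print(cache.match_prefix([1, 2, 3, 4], extra_key=None)[1])      # base-model KV
    print(cache.match_prefix([1, 2, 3, 4], extra_key="lora-A")[1])  # lora-A KV
    print(cache.match_prefix([1, 2, 3, 4], extra_key="lora-B")[1])  # None: no sharing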
@@ -794,10 +840,10 @@ class Req:
             return
 
         if self.bootstrap_room is not None:
-            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.
+            prefix = f"Req Time Stats(rid={self.rid}, bootstrap_room={self.bootstrap_room}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
         else:
-            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.
-        logger.info(f"{prefix}: {self.time_stats}")
+            prefix = f"Req Time Stats(rid={self.rid}, input len={len(self.origin_input_ids)}, output len={len(self.output_ids)}, type={self.time_stats.disagg_mode_str()})"
+        logger.info(f"{prefix}: {self.time_stats.convert_to_duration()}")
         self.has_log_time_stats = True
 
     def set_finish_with_abort(self, error_msg: str):
@@ -820,10 +866,6 @@ class Req:
         )
 
 
-# Batch id
-bid = 0
-
-
 @dataclasses.dataclass
 class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     """Store all information of a batch on the scheduler."""
@@ -860,6 +902,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     token_type_ids: torch.Tensor = None  # shape: [b], int64
     req_pool_indices: torch.Tensor = None  # shape: [b], int64
     seq_lens: torch.Tensor = None  # shape: [b], int64
+    seq_lens_cpu: torch.Tensor = None  # shape: [b], int64
     # The output locations of the KV cache
     out_cache_loc: torch.Tensor = None  # shape: [b], int64
     output_ids: torch.Tensor = None  # shape: [b], int64
@@ -915,7 +958,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-    spec_info: Optional[
+    # spec_info: Optional[SpecInput] = None
+    spec_info: Optional[SpecInput] = None
 
     # Whether to return hidden states
     return_hidden_states: bool = False
@@ -1014,7 +1058,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def alloc_paged_token_slots_extend(
         self,
         prefix_lens: torch.Tensor,
+        prefix_lens_cpu: torch.Tensor,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         extend_num_tokens: int,
         backup_state: bool = False,
@@ -1022,7 +1068,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         # Over estimate the number of tokens: assume each request needs a new page.
         num_tokens = (
             extend_num_tokens
-            + len(
+            + len(seq_lens_cpu) * self.token_to_kv_pool_allocator.page_size
         )
         self._evict_tree_cache_if_needed(num_tokens)
 
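Note: the over-estimate reserves one full page per request on top of the exact token count, since in the worst case every request's prefix ends mid-page and its extension has to open a fresh page. A worked example with illustrative numbers:

    # Eviction budget computed by alloc_paged_token_slots_extend (illustrative numbers).
    page_size = 16
    extend_lens = [5, 30, 1]              # new tokens per request in the batch
    extend_num_tokens = sum(extend_lens)  # 36 tokens actually needed

    # Worst case: each of the 3 requests opens a partially filled new page,
    # so ask the evictor for one extra page per request.
    num_tokens = extend_num_tokens + len(extend_lens) * page_size
    print(num_tokens)  # 36 + 3 * 16 = 84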
@@ -1030,7 +1076,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             state = self.token_to_kv_pool_allocator.backup_state()
 
         out_cache_loc = self.token_to_kv_pool_allocator.alloc_extend(
-            prefix_lens,
+            prefix_lens,
+            prefix_lens_cpu,
+            seq_lens,
+            seq_lens_cpu,
+            last_loc,
+            extend_num_tokens,
         )
         if out_cache_loc is None:
             error_msg = (
@@ -1049,6 +1100,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
     def alloc_paged_token_slots_decode(
         self,
         seq_lens: torch.Tensor,
+        seq_lens_cpu: torch.Tensor,
         last_loc: torch.Tensor,
         backup_state: bool = False,
     ):
@@ -1059,7 +1111,9 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if backup_state:
             state = self.token_to_kv_pool_allocator.backup_state()
 
-        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+        out_cache_loc = self.token_to_kv_pool_allocator.alloc_decode(
+            seq_lens, seq_lens_cpu, last_loc
+        )
         if out_cache_loc is None:
             error_msg = (
                 f"Decode out of memory. Try to lower your batch size.\n"
@@ -1128,6 +1182,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        self.seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
 
         if not decoder_out_cache_loc:
             self.out_cache_loc = torch.zeros(0, dtype=torch.int64).to(
@@ -1176,12 +1231,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
+        seq_lens_cpu = torch.tensor(seq_lens, dtype=torch.int64)
         orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
         prefix_lens_tensor = torch.tensor(
             prefix_lens, dtype=torch.int64, device=self.device
         )
+        prefix_lens_cpu_tensor = torch.tensor(prefix_lens, dtype=torch.int64)
 
         token_type_ids_tensor = None
         if len(token_type_ids) > 0:
@@ -1308,13 +1365,19 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             prefix_lens_tensor,
         )
         out_cache_loc = self.alloc_paged_token_slots_extend(
-            prefix_lens_tensor,
+            prefix_lens_tensor,
+            prefix_lens_cpu_tensor,
+            seq_lens_tensor,
+            seq_lens_cpu,
+            last_loc,
+            extend_num_tokens,
         )
 
         # Set fields
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.seq_lens_cpu = seq_lens_cpu
         self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
@@ -1457,7 +1520,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         )
 
         retracted_reqs = []
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
         first_iter = True
         while first_iter or (
             not self.check_decode_mem(selected_indices=sorted_indices)
@@ -1484,37 +1546,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
             retracted_reqs.append(req)
-
-            if server_args.disaggregation_mode == "decode":
-                req.offload_kv_cache(
-                    self.req_to_token_pool, self.token_to_kv_pool_allocator
-                )
-
-            if isinstance(self.tree_cache, ChunkCache):
-                # ChunkCache does not have eviction
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-            else:
-                # TODO: apply more fine-grained retraction
-                last_uncached_pos = (
-                    len(req.prefix_indices) // server_args.page_size
-                ) * server_args.page_size
-                token_indices = self.req_to_token_pool.req_to_token[
-                    req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
-                ]
-                self.token_to_kv_pool_allocator.free(token_indices)
-                self.req_to_token_pool.free(req.req_pool_idx)
-
-            # release the last node
-            if self.is_hybrid:
-                self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
-            else:
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-            req.reset_for_retract()
+            self.release_req(idx, len(sorted_indices), server_args)
 
         if len(retracted_reqs) == 0:
             # Corner case: only one request left
@@ -1533,7 +1565,45 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         ) / total_max_new_tokens
         new_estimate_ratio = min(1.0, new_estimate_ratio)
 
-        return retracted_reqs, new_estimate_ratio
+        return retracted_reqs, new_estimate_ratio, []
+
+    def release_req(self, idx: int, remaing_req_count: int, server_args: ServerArgs):
+        req = self.reqs[idx]
+        seq_lens_cpu = self.seq_lens_cpu.numpy()
+
+        if server_args.disaggregation_mode == "decode":
+            req.offload_kv_cache(
+                self.req_to_token_pool, self.token_to_kv_pool_allocator
+            )
+        if isinstance(self.tree_cache, ChunkCache):
+            # ChunkCache does not have eviction
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+        else:
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = (
+                len(req.prefix_indices) // server_args.page_size
+            ) * server_args.page_size
+            token_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, last_uncached_pos : seq_lens_cpu[idx]
+            ]
+            self.token_to_kv_pool_allocator.free(token_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+
+        # release the last node
+        if self.is_hybrid:
+            self.tree_cache.dec_lock_ref(req.last_node, req.swa_uuid_for_lock)
+        else:
+            self.tree_cache.dec_lock_ref(req.last_node)
+
+        # NOTE(lsyin): we should use the newly evictable memory instantly.
+        num_tokens = remaing_req_count * global_config.retract_decode_steps
+        self._evict_tree_cache_if_needed(num_tokens)
+
+        req.reset_for_retract()
 
     def prepare_encoder_info_decode(self):
         # Reset the encoder cached status
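Note: in the radix-cache branch of `release_req`, `last_uncached_pos` floors the matched prefix length to a page boundary, so pages that still hold cached prefix tokens stay resident and only the strictly uncached tail is freed. Illustrative arithmetic:

    # last_uncached_pos arithmetic from release_req (illustrative numbers).
    page_size = 16
    prefix_len = 37  # tokens of this request found in the radix cache
    seq_len = 90     # tokens currently held for the request

    last_uncached_pos = (prefix_len // page_size) * page_size  # 32, floored to a page
    print(last_uncached_pos, seq_len - last_uncached_pos)      # slots [32, 90) freed: 58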
@@ -1543,6 +1613,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.seq_lens_cpu = torch.empty(0, dtype=torch.int64)
         self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
@@ -1557,7 +1628,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.DECODE
         bs = len(self.reqs)
 
-        if
+        if (
+            self.spec_algorithm.is_eagle()
+            or self.spec_algorithm.is_standalone()
+            or self.spec_algorithm.is_ngram()
+        ):
             # if spec decoding is used, the decode batch is prepared inside
             # `forward_batch_speculative_generation` after running draft models.
             return
@@ -1598,10 +1673,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.seq_lens_cpu = self.seq_lens_cpu + 1
             self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.seq_lens_cpu.add_(1)
             self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
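Note: the `seq_lens_cpu` additions threaded through these hunks all follow one discipline: every mutation of the device tensor `seq_lens` is mirrored on a host copy, so readers (e.g. `get_model_worker_batch` below, which previously called `self.seq_lens.cpu()`) never block on a device-to-host sync. A small sketch of the pattern, assuming CUDA-style async semantics:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    seq_lens_cpu = torch.tensor([7, 12, 3], dtype=torch.int64)
    seq_lens = seq_lens_cpu.to(device, non_blocking=True)

    # Decode step: advance both copies (in-place variant from the non-overlap path).
    seq_lens.add_(1)
    seq_lens_cpu.add_(1)

    # Filtering: index both copies with the same kept indices.
    keep = torch.tensor([0, 2])
    seq_lens = seq_lens[keep.to(device)]
    seq_lens_cpu = seq_lens_cpu[keep]

    # Length reads stay on the host; no .cpu() round-trip needed.
    print(seq_lens_cpu.tolist())  # [8, 4]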
@@ -1620,7 +1697,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             self.req_pool_indices, self.seq_lens - 2
         ]
         self.out_cache_loc = self.alloc_paged_token_slots_decode(
-            self.seq_lens, last_loc
+            self.seq_lens, self.seq_lens_cpu, last_loc
         )
 
         self.req_to_token_pool.write(
@@ -1666,6 +1743,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.seq_lens_cpu = self.seq_lens_cpu[keep_indices]
         self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
@@ -1683,7 +1761,14 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         self.sampling_info.filter_batch(keep_indices, keep_indices_device)
         if self.spec_info:
-
+            if chunked_req_to_exclude is not None and len(chunked_req_to_exclude) > 0:
+                has_been_filtered = False
+            else:
+                has_been_filtered = True
+            self.spec_info.filter_batch(
+                new_indices=keep_indices_device,
+                has_been_filtered=has_been_filtered,
+            )
 
     def merge_batch(self, other: "ScheduleBatch"):
         # Penalizer orchestrator must be merged before Batch.reqs is merged. This is because
@@ -1699,6 +1784,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.seq_lens_cpu = torch.cat([self.seq_lens_cpu, other.seq_lens_cpu])
         self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
@@ -1742,15 +1828,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.sampling_info.grammars = None
 
         seq_lens_cpu = (
-            seq_lens_cpu_cache
-            if seq_lens_cpu_cache is not None
-            else self.seq_lens.cpu()
+            seq_lens_cpu_cache if seq_lens_cpu_cache is not None else self.seq_lens_cpu
         )
 
-        global bid
-        bid += 1
         return ModelWorkerBatch(
-            bid=bid,
             forward_mode=self.forward_mode,
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
@@ -1870,8 +1951,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
 @dataclasses.dataclass
 class ModelWorkerBatch:
-    # The batch id
-    bid: int
     # The forward mode
     forward_mode: ForwardMode
     # The input ids
@@ -1932,7 +2011,9 @@ class ModelWorkerBatch:
 
     # Speculative decoding
     spec_algorithm: SpeculativeAlgorithm = None
-
+
+    spec_info: Optional[SpecInput] = None
+
     # If set, the output of the batch contains the hidden states of the run.
     capture_hidden_mode: CaptureHiddenMode = None
     hicache_consumer_index: int = -1