sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -36,7 +36,7 @@ if TYPE_CHECKING:
 # This can prevent the server from being too conservative.
 # Note that this only clips the estimation in the scheduler but does not change the stop
 # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
-
+CLIP_MAX_NEW_TOKENS = int(
     os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
 )

@@ -305,7 +305,7 @@ class PrefillAdder:
             [
                 min(
                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
-
+                    CLIP_MAX_NEW_TOKENS,
                 )
                 * self.new_token_ratio
                 for r in running_batch.reqs
@@ -388,7 +388,7 @@ class PrefillAdder:
                 0,
                 req.extend_input_len,
                 (
-                    min(req.sampling_params.max_new_tokens,
+                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
                     if not truncated
                     else 0
                 ),
@@ -477,7 +477,7 @@ class PrefillAdder:
             self._update_prefill_budget(
                 0,
                 req.extend_input_len,
-                min(req.sampling_params.max_new_tokens,
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
             )
         else:
             if self.rem_chunk_tokens == 0:
@@ -499,7 +499,7 @@ class PrefillAdder:
         return self.add_one_req_ignore_eos(req, has_chunked_req)

         total_tokens = req.extend_input_len + min(
-            req.sampling_params.max_new_tokens,
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
         )

         # adjusting the input_tokens based on host_hit_length and page_size
@@ -544,7 +544,7 @@ class PrefillAdder:
                 input_tokens,
                 min(
                     req.sampling_params.max_new_tokens,
-
+                    CLIP_MAX_NEW_TOKENS,
                 ),
             )
         else:
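The hunks above all route the scheduler's estimate of remaining generation through the same environment-driven clamp. A minimal sketch of that clamping, assuming the env var shown in the diff; the helper name is hypothetical, and note the clamp only caps the scheduler's budget estimate, not the actual stop condition.

```python
import os

# Same env var and default as in the hunk above.
CLIP_MAX_NEW_TOKENS = int(
    os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS_ESTIMATION", "4096")
)


def estimated_remaining_tokens(max_new_tokens: int, num_generated: int) -> int:
    """Hypothetical helper: the scheduler's clipped estimate of tokens still to come."""
    return min(max_new_tokens - num_generated, CLIP_MAX_NEW_TOKENS)


# A request allowed 32768 new tokens that has produced 100 so far is budgeted
# as only 4096 remaining tokens, which keeps prefill admission less conservative.
assert estimated_remaining_tokens(32768, 100) == 4096
```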
sglang/srt/managers/scheduler.py
CHANGED
@@ -120,6 +120,7 @@ from sglang.srt.managers.scheduler_output_processor_mixin import (
     SchedulerOutputProcessorMixin,
 )
 from sglang.srt.managers.scheduler_profiler_mixin import SchedulerProfilerMixin
+from sglang.srt.managers.scheduler_recv_skipper import SchedulerRecvSkipper
 from sglang.srt.managers.scheduler_update_weights_mixin import (
     SchedulerUpdateWeightsMixin,
 )
@@ -129,6 +130,7 @@ from sglang.srt.managers.tp_worker_overlap_thread import TpModelWorkerClient
 from sglang.srt.managers.utils import DPBalanceMeta, validate_input_length
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
 from sglang.srt.mem_cache.hiradix_cache import HiRadixCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
@@ -472,8 +474,10 @@ class Scheduler(
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=server_args.enable_memory_saver
         )
+        self.offload_tags = set()
         self.init_profier()

+        self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
             SchedulerInputBlocker(noop=self.attn_tp_rank != 0)
             if get_bool_env_var("SGLANG_ENABLE_COLOCATED_BATCH_GEN")
@@ -608,14 +612,10 @@ class Scheduler(
                 hicache_ratio=server_args.hicache_ratio,
                 hicache_size=server_args.hicache_size,
                 hicache_write_policy=server_args.hicache_write_policy,
-                hicache_io_backend=(
-                    "direct"
-                    if server_args.attention_backend
-                    == "fa3"  # hot fix for incompatibility
-                    else server_args.hicache_io_backend
-                ),
+                hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
                 hicache_storage_backend=server_args.hicache_storage_backend,
+                hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
             )
             self.tp_worker.register_hicache_layer_transfer_counter(
                 self.tree_cache.cache_controller.layer_done_counter
@@ -631,7 +631,19 @@ class Scheduler(
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
-
+        elif self.enable_lora:
+            assert (
+                not self.enable_hierarchical_cache
+            ), "LoRA radix cache doesn't support hierarchical cache"
+            assert (
+                self.schedule_policy == "fcfs"
+            ), "LoRA radix cache only supports FCFS policy"
+            self.tree_cache = LoRARadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -946,6 +958,14 @@ class Scheduler(

     def recv_requests(self) -> List[Req]:
         """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
+
+        if self.recv_skipper is not None:
+            last_forward_mode = (
+                self.last_batch.forward_mode if self.last_batch is not None else None
+            )
+            if not self.recv_skipper.handle(last_forward_mode):
+                return []
+
         if self.pp_rank == 0:
             if self.attn_tp_rank == 0:
                 recv_reqs = []
@@ -1029,7 +1049,9 @@ class Scheduler(
         for recv_req in recv_reqs:
             # If it is a health check generation request and there are running requests, ignore it.
             if is_health_check_generate_req(recv_req) and (
-                self.chunked_req is not None
+                self.chunked_req is not None
+                or not self.running_batch.is_empty()
+                or len(self.offload_tags) > 0
             ):
                 self.return_health_check_ct += 1
                 continue
@@ -1090,7 +1112,7 @@ class Scheduler(
                 top_logprobs_num=recv_req.top_logprobs_num,
                 token_ids_logprob=recv_req.token_ids_logprob,
                 stream=recv_req.stream,
-
+                lora_id=recv_req.lora_id,
                 input_embeds=recv_req.input_embeds,
                 custom_logit_processor=recv_req.custom_logit_processor,
                 return_hidden_states=recv_req.return_hidden_states,
@@ -1534,18 +1556,15 @@ class Scheduler(
         self.chunked_req = adder.add_chunked_req(self.chunked_req)

         if self.enable_lora:
-            lora_set = set([req.
+            lora_set = set([req.lora_id for req in self.running_batch.reqs])

         # Get requests from the waiting queue to a new prefill batch
         for req in self.waiting_queue:
-
-
-
-
-
-                | set([req.lora_path])
-            )
-            > self.max_loras_per_batch
+
+            if self.enable_lora and not self.tp_worker.can_run_lora_batch(
+                lora_set
+                | set([req.lora_id for req in adder.can_run_list])
+                | set([req.lora_id])
             ):
                 self.running_batch.batch_is_full = True
                 break
@@ -1562,7 +1581,10 @@ class Scheduler(
                 break

             if self.enable_hicache_storage:
-                self.tree_cache.check_prefetch_progress(req.rid)
+                prefetch_done = self.tree_cache.check_prefetch_progress(req.rid)
+                if not prefetch_done:
+                    # skip staging requests that are ongoing prefetch
+                    continue

             req.init_next_round_input(self.tree_cache)
             res = adder.add_one_req(req, has_chunked_req=(self.chunked_req is not None))
sglang/srt/managers/scheduler_output_processor_mixin.py
CHANGED
@@ -571,8 +571,7 @@ class SchedulerOutputProcessorMixin:

                 req.send_decode_id_offset = len(decode_ids)
                 read_offsets.append(read_offset)
-
-                output_ids.append(req.output_ids[send_token_offset:])
+                output_ids.append(req.output_ids[send_token_offset:])
                 req.send_token_offset = len(req.output_ids)
                 skip_special_tokens.append(req.sampling_params.skip_special_tokens)
                 spaces_between_special_tokens.append(
sglang/srt/managers/scheduler_profiler_mixin.py
CHANGED
@@ -8,6 +8,18 @@ import torch

 from sglang.srt.managers.io_struct import ProfileReq, ProfileReqOutput, ProfileReqType
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.utils import is_npu
+
+_is_npu = is_npu()
+if _is_npu:
+    import torch_npu
+
+    patches = [
+        ["profiler.profile", torch_npu.profiler.profile],
+        ["profiler.ProfilerActivity.CUDA", torch_npu.profiler.ProfilerActivity.NPU],
+        ["profiler.ProfilerActivity.CPU", torch_npu.profiler.ProfilerActivity.CPU],
+    ]
+    torch_npu._apply_patches(patches)

 logger = logging.getLogger(__name__)

@@ -136,6 +148,13 @@ class SchedulerProfilerMixin:
                 activities=torchprof_activities,
                 with_stack=with_stack if with_stack is not None else True,
                 record_shapes=record_shapes if record_shapes is not None else False,
+                on_trace_ready=(
+                    None
+                    if not _is_npu
+                    else torch_npu.profiler.tensorboard_trace_handler(
+                        self.torch_profiler_output_dir
+                    )
+                ),
             )
             self.torch_profiler.start()
             self.profile_in_progress = True
@@ -166,15 +185,16 @@ class SchedulerProfilerMixin:
         logger.info("Stop profiling" + stage_suffix + "...")
         if self.torch_profiler is not None:
             self.torch_profiler.stop()
-
-
-
-
-
-
-
+            if not _is_npu:
+                self.torch_profiler.export_chrome_trace(
+                    os.path.join(
+                        self.torch_profiler_output_dir,
+                        self.profile_id
+                        + f"-TP-{self.tp_rank}"
+                        + stage_suffix
+                        + ".trace.json.gz",
+                    )
                 )
-            )
             torch.distributed.barrier(self.tp_cpu_group)

         if self.rpd_profiler is not None:
sglang/srt/managers/scheduler_recv_skipper.py
ADDED
@@ -0,0 +1,37 @@
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+from sglang.srt.server_args import ServerArgs
+
+
+class SchedulerRecvSkipper:
+    @staticmethod
+    def maybe_create(server_args: ServerArgs):
+        if server_args.scheduler_recv_interval <= 1:
+            return None
+        return SchedulerRecvSkipper(server_args)
+
+    def __init__(self, server_args: ServerArgs):
+        # Can be supported if needed, but may need e.g. `global_forward_mode`
+        assert not server_args.enable_dp_attention
+        self._counter = 0
+        self._threshold = server_args.scheduler_recv_interval
+
+    def handle(self, last_forward_mode: ForwardMode):
+        should_recv = False
+
+        last_weight = _WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, _DEFAULT_WEIGHT)
+        self._counter += last_weight
+
+        if self._counter >= self._threshold:
+            self._counter = 0
+            should_recv = True
+
+        return should_recv
+
+
+# All can be tuned if needed
+_DEFAULT_WEIGHT = 1000
+_WEIGHT_OF_FORWARD_MODE = {
+    ForwardMode.DECODE: 1,
+    ForwardMode.TARGET_VERIFY: 1,
+    None: 1,
+}
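To see what the new skipper does to the scheduler's polling cadence, here is a standalone mirror of its counter logic (not the sglang class itself; plain strings stand in for ForwardMode, and the weights are copied from the file above). Decode-like steps accumulate weight 1, so with `scheduler_recv_interval=4` only every fourth decode step polls for new requests, while any other step gets the large default weight and forces a recv on the next call.

```python
# Standalone mirror of SchedulerRecvSkipper's accounting, for illustration only.
_DEFAULT_WEIGHT = 1000
_WEIGHT_OF_FORWARD_MODE = {"DECODE": 1, "TARGET_VERIFY": 1, None: 1}


class RecvSkipperSketch:
    def __init__(self, interval: int):
        self._counter = 0
        self._threshold = interval

    def handle(self, last_forward_mode) -> bool:
        # Accumulate the weight of the last forward mode; recv once the threshold is hit.
        self._counter += _WEIGHT_OF_FORWARD_MODE.get(last_forward_mode, _DEFAULT_WEIGHT)
        if self._counter >= self._threshold:
            self._counter = 0
            return True
        return False


skipper = RecvSkipperSketch(interval=4)
decisions = [skipper.handle("DECODE") for _ in range(8)]
assert decisions == [False, False, False, True, False, False, False, True]
# A prefill/extend-like step is not in the weight table, so it triggers an immediate recv.
assert skipper.handle("EXTEND") is True
```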
sglang/srt/managers/scheduler_update_weights_mixin.py
CHANGED
@@ -78,6 +78,9 @@ class SchedulerUpdateWeightsMixin:
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

+        for tag in tags:
+            self.offload_tags.add(tag)
+
         if GPU_MEMORY_TYPE_KV_CACHE in tags:
             self.memory_saver_adapter.pause(GPU_MEMORY_TYPE_KV_CACHE)
             self.flush_cache()
@@ -97,6 +100,9 @@ class SchedulerUpdateWeightsMixin:
         if tags is None or len(tags) == 0:
             tags = [GPU_MEMORY_TYPE_WEIGHTS, GPU_MEMORY_TYPE_KV_CACHE]

+        for tag in tags:
+            self.offload_tags.remove(tag)
+
         if GPU_MEMORY_TYPE_WEIGHTS in tags:
             self.memory_saver_adapter.resume(GPU_MEMORY_TYPE_WEIGHTS)
             torch.distributed.barrier(self.tp_cpu_group)
sglang/srt/managers/template_manager.py
CHANGED
@@ -21,6 +21,7 @@ and code completion templates, eliminating global state and improving modularity
 import json
 import logging
 import os
+import re
 from typing import Optional

 from sglang.srt.code_completion_parser import (
@@ -54,6 +55,7 @@ class TemplateManager:
         self._chat_template_name: Optional[str] = None
         self._completion_template_name: Optional[str] = None
         self._jinja_template_content_format: Optional[str] = "openai"
+        self._force_reasoning: bool = False

     @property
     def chat_template_name(self) -> Optional[str]:
@@ -70,6 +72,31 @@ class TemplateManager:
         """Get the detected template content format ('string' or 'openai' or None)."""
         return self._jinja_template_content_format

+    @property
+    def force_reasoning(self) -> bool:
+        """
+        Check if the current chat template enforces reasoning/thinking.
+
+        Returns:
+            True if the template contains reasoning patterns like <think> tags
+        """
+        return self._force_reasoning
+
+    def _detect_reasoning_pattern(self, template: str) -> bool:
+        """
+        Detect if the chat template contains reasoning/thinking patterns.
+        """
+        if template is None:
+            return False
+
+        force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"
+        has_reasoning = re.search(force_reasoning_pattern, template) is not None
+
+        if has_reasoning:
+            logger.info("Detected the force reasoning pattern in chat template.")
+
+        return has_reasoning
+
     def load_chat_template(
         self, tokenizer_manager, chat_template_arg: Optional[str], model_path: str
     ) -> None:
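The detection is a plain regex over the chat-template source; since the pattern escapes the backslashes, it targets the literal backslash-n sequences as they appear inside Jinja string literals, not real newlines. A quick illustration with a made-up template fragment (the fragment itself is hypothetical, the pattern is the one above):

```python
import re

# Same pattern as in _detect_reasoning_pattern above.
force_reasoning_pattern = r"<\|im_start\|>assistant\\n<think>\\n"

# Hypothetical Jinja fragment; the raw string keeps "\n" as literal backslash-n,
# which is how it appears in the template source.
template = r"{% if add_generation_prompt %}{{ '<|im_start|>assistant\n<think>\n' }}{% endif %}"

assert re.search(force_reasoning_pattern, template) is not None  # force reasoning detected
assert re.search(force_reasoning_pattern, "{{ messages }}") is None
```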
@@ -93,7 +120,8 @@ class TemplateManager:
         hf_template = self._resolve_hf_chat_template(tokenizer_manager)
         if hf_template:
             # override the chat template
-            tokenizer_manager.tokenizer
+            if tokenizer_manager.tokenizer:
+                tokenizer_manager.tokenizer.chat_template = hf_template
             self._jinja_template_content_format = (
                 detect_jinja_template_content_format(hf_template)
             )
@@ -106,6 +134,12 @@ class TemplateManager:
             self._jinja_template_content_format = "string"
             logger.info("No chat template found, defaulting to 'string' content format")

+        # Detect reasoning pattern from chat template
+        if tokenizer_manager.tokenizer:
+            self._force_reasoning = self._detect_reasoning_pattern(
+                tokenizer_manager.tokenizer.chat_template
+            )
+
     def _load_explicit_chat_template(
         self, tokenizer_manager, chat_template_arg: str
     ) -> None:
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -269,10 +269,9 @@ class TokenizerManager:
         self.asyncio_tasks = set()

         # Health check
-        self.
+        self.server_status = ServerStatus.Starting
         self.gracefully_exit = False
         self.last_receive_tstamp = 0
-        self.server_status = ServerStatus.Starting

         # Dumping
         self.dump_requests_folder = ""  # By default do not dump
@@ -291,8 +290,8 @@ class TokenizerManager:
         self.model_update_result: Optional[Awaitable[UpdateWeightFromDiskReqOutput]] = (
             None
         )
-        self.
-        self.
+        self.is_pause = False
+        self.is_pause_cond = asyncio.Condition()

         # LoRA
         # Initialize the `LoRARegistry` with initial LoRA adapter paths provided in `server_args`.
@@ -476,16 +475,20 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()

-        async with self._is_updating_cond:
-            await self._is_updating_cond.wait_for(lambda: not self._is_updating)
-
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
                 f"Receive: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}"
             )

+        async with self.is_pause_cond:
+            await self.is_pause_cond.wait_for(lambda: not self.is_pause)
+
         async with self.model_update_lock.reader_lock:
+            if self.server_args.enable_lora and obj.lora_path:
+                # Look up the LoRA ID from the registry and start tracking ongoing LoRA requests.
+                obj.lora_id = await self.lora_registry.acquire(obj.lora_path)
+
             if obj.is_single:
                 tokenized_obj = await self._tokenize_one_request(obj)
                 state = self._send_one_request(obj, tokenized_obj, created_time)
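The acquire now happens under the model-update reader lock and stores the resolved ID on the request object; the matching release happens later, when the request is reported finished (see the hunks further down). A toy refcounting registry, purely illustrative and not the actual LoRARegistry beyond the acquire/release names shown in the diff:

```python
import asyncio
import uuid


class ToyLoRARegistry:
    """Hypothetical sketch: map user-facing lora_path names to stable IDs and count
    in-flight requests per adapter, so an adapter can only be unloaded once idle."""

    def __init__(self):
        self._path_to_id: dict = {}
        self._inflight: dict = {}

    async def acquire(self, lora_path: str) -> str:
        lora_id = self._path_to_id.setdefault(lora_path, uuid.uuid4().hex)
        self._inflight[lora_id] = self._inflight.get(lora_id, 0) + 1
        return lora_id

    async def release(self, lora_id: str) -> None:
        self._inflight[lora_id] -= 1


async def demo():
    registry = ToyLoRARegistry()
    lora_id = await registry.acquire("my-adapter")  # at request admission
    await registry.release(lora_id)                 # when the request finishes
    assert registry._inflight[lora_id] == 0


asyncio.run(demo())
```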
@@ -553,11 +556,6 @@ class TokenizerManager:
         else:
             mm_inputs = None

-        if self.server_args.enable_lora and obj.lora_path:
-            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
-            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
-
         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
@@ -665,7 +663,7 @@ class TokenizerManager:
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -750,7 +748,11 @@ class TokenizerManager:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
-                if
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -771,10 +773,6 @@ class TokenizerManager:
                 msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                 logger.info(msg)

-            # Mark ongoing LoRA request as finished.
-            if self.server_args.enable_lora and obj.lora_path:
-                await self.lora_registry.release(obj.lora_path)
-
             # Check if this was an abort/error created by scheduler
             if isinstance(out["meta_info"].get("finish_reason"), dict):
                 finish_reason = out["meta_info"]["finish_reason"]
@@ -793,6 +791,11 @@ class TokenizerManager:
                     # Delete the key to prevent resending abort request to the scheduler and
                     # to ensure aborted request state is cleaned up.
                     del self.rid_to_state[state.obj.rid]
+
+                    # Mark ongoing LoRA request as finished.
+                    if self.server_args.enable_lora and state.obj.lora_path:
+                        await self.lora_registry.release(state.obj.lora_id)
+
                     raise fastapi.HTTPException(
                         status_code=finish_reason["status_code"],
                         detail=finish_reason["message"],
@@ -805,7 +808,11 @@ class TokenizerManager:
             if obj.stream:
                 yield out
             else:
-                if
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, running)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -974,14 +981,14 @@ class TokenizerManager:
         await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)

     async def pause_generation(self):
-        async with self.
-            self.
+        async with self.is_pause_cond:
+            self.is_pause = True
             self.abort_request(abort_all=True)

     async def continue_generation(self):
-        async with self.
-        self.
-        self.
+        async with self.is_pause_cond:
+            self.is_pause = False
+            self.is_pause_cond.notify_all()

     async def update_weights_from_disk(
         self,
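The pause/continue pair is a standard asyncio.Condition handshake: the generate path blocks on `wait_for(lambda: not self.is_pause)` (see the `@@ -476` hunk above), `pause_generation` flips the flag and aborts in-flight requests, and `continue_generation` clears it and notifies all waiters. A self-contained sketch of the same pattern, outside of sglang (the class and method names here are invented):

```python
import asyncio


class PauseGate:
    """Minimal sketch of the is_pause / is_pause_cond handshake used above."""

    def __init__(self):
        self.is_pause = False
        self.is_pause_cond = asyncio.Condition()

    async def wait_until_resumed(self):
        async with self.is_pause_cond:
            await self.is_pause_cond.wait_for(lambda: not self.is_pause)

    async def pause(self):
        async with self.is_pause_cond:
            self.is_pause = True

    async def resume(self):
        async with self.is_pause_cond:
            self.is_pause = False
            self.is_pause_cond.notify_all()


async def main():
    gate = PauseGate()
    await gate.pause()

    async def request():
        await gate.wait_until_resumed()  # blocks while paused
        return "ran"

    task = asyncio.create_task(request())
    await asyncio.sleep(0.01)
    assert not task.done()               # still gated
    await gate.resume()
    assert await task == "ran"


asyncio.run(main())
```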
@@ -1121,6 +1128,7 @@ class TokenizerManager:
             new_adapter = LoRARef(
                 lora_name=obj.lora_name,
                 lora_path=obj.lora_path,
+                pinned=obj.pinned,
             )

             # Trigger the actual loading operation at the backend processes.
@@ -1178,7 +1186,7 @@ class TokenizerManager:

             return result
         except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False,
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1465,7 +1473,7 @@ class TokenizerManager:
         while True:
             remain_num_req = len(self.rid_to_state)

-            if self.
+            if self.server_status == ServerStatus.UnHealthy:
                 # if health check failed, we should exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
@@ -1548,8 +1556,17 @@ class TokenizerManager:

         if isinstance(recv_obj, BatchStrOut):
             state.text += recv_obj.output_strs[i]
+            if state.obj.stream:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids[state.last_output_offset :]
+                state.last_output_offset = len(state.output_ids)
+            else:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids.copy()
+
             out_dict = {
                 "text": state.text,
+                "output_ids": output_token_ids,
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
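With this bookkeeping, `output_ids` in the response is incremental for streaming requests (only the tokens since `last_output_offset`) and cumulative otherwise. A tiny sketch of the slicing, outside of TokenizerManager, with a hypothetical state object whose fields mirror those in the hunk:

```python
class _State:
    def __init__(self, stream: bool):
        self.stream = stream
        self.output_ids = []
        self.last_output_offset = 0


def collect(state: _State, new_ids: list) -> list:
    # Mirror of the two branches above: delta for streaming, full copy otherwise.
    state.output_ids.extend(new_ids)
    if state.stream:
        out = state.output_ids[state.last_output_offset:]
        state.last_output_offset = len(state.output_ids)
        return out
    return state.output_ids.copy()


s = _State(stream=True)
assert collect(s, [1, 2]) == [1, 2]   # first chunk
assert collect(s, [3]) == [3]         # only the delta is returned

ns = _State(stream=False)
assert collect(ns, [1, 2]) == [1, 2]
assert collect(ns, [3]) == [1, 2, 3]  # cumulative for non-streaming
```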
@@ -1582,6 +1599,10 @@ class TokenizerManager:
             meta_info["e2e_latency"] = state.finished_time - state.created_time
             del self.rid_to_state[rid]

+            # Mark ongoing LoRA request as finished.
+            if self.server_args.enable_lora and state.obj.lora_path:
+                asyncio.create_task(self.lora_registry.release(state.obj.lora_id))
+
         state.out_list.append(out_dict)
         state.event.set()

@@ -1947,10 +1968,6 @@ class ServerStatus(Enum):
     Up = "Up"
     Starting = "Starting"
     UnHealthy = "UnHealthy"
-    Crashed = "Crashed"
-
-    def is_healthy(self) -> bool:
-        return self == ServerStatus.Up


 def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
sglang/srt/managers/tp_worker.py
CHANGED
@@ -311,3 +311,6 @@ class TpModelWorker:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -288,6 +288,9 @@ class TpModelWorkerClient:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         return self.worker.unload_lora_adapter(recv_req)

+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.worker.can_run_lora_batch(lora_ids)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
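The scheduler-side check (the `@@ -1534` hunk above) forwards the set of LoRA IDs in the prospective batch through TpModelWorkerClient → TpModelWorker → LoRAManager.validate_lora_batch. The real validation lives in lora_manager.py and is not shown in this diff; the following is a hypothetical stand-in that only captures the `max_loras_per_batch` constraint the old inline check enforced.

```python
# Hypothetical stand-in for the validation the worker delegates to; the actual
# LoRAManager.validate_lora_batch may apply additional rules (e.g. pinned adapters).
def validate_lora_batch_sketch(lora_ids: set, max_loras_per_batch: int) -> bool:
    # `None` (the base model, i.e. no adapter) does not occupy an adapter slot.
    distinct_adapters = {lora_id for lora_id in lora_ids if lora_id is not None}
    return len(distinct_adapters) <= max_loras_per_batch


assert validate_lora_batch_sketch({None, "a", "b"}, max_loras_per_batch=2)
assert not validate_lora_batch_sketch({"a", "b", "c"}, max_loras_per_batch=2)
```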