sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
CHANGED
@@ -16,6 +16,7 @@ limitations under the License.
 import logging
 import math
 import threading
+import time
 from queue import Empty, Full, PriorityQueue, Queue
 from typing import TYPE_CHECKING, List, Optional
 
@@ -168,12 +169,13 @@ class StorageOperation:
         host_indices: torch.Tensor,
         token_ids: List[int],
         last_hash: Optional[str] = None,
+        hash_value: Optional[List[str]] = None,
     ):
         self.host_indices = host_indices
         self.token_ids = token_ids
         self.last_hash = last_hash
         self.completed_tokens = 0
-        self.hash_value = []
+        self.hash_value = hash_value if hash_value is not None else []
 
         self.id = StorageOperation.counter
         StorageOperation.counter += 1
@@ -195,6 +197,8 @@ class PrefetchOperation(StorageOperation):
         self._done_flag = False
         self._lock = threading.Lock()
 
+        self.start_time = time.monotonic()
+
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
@@ -243,12 +247,12 @@ class HiCacheController:
             self.storage_backend = HiCacheFile()
             self.get_hash_str = get_hash_str
         elif storage_backend == "nixl":
-            from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
+            from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
 
             self.storage_backend = HiCacheNixl()
             self.get_hash_str = get_hash_str
         elif storage_backend == "mooncake":
-            from sglang.srt.mem_cache.mooncake_store.mooncake_store import (
+            from sglang.srt.mem_cache.storage.mooncake_store.mooncake_store import (
                 MooncakeStore,
                 get_hash_str_mooncake,
             )
@@ -256,6 +260,7 @@ class HiCacheController:
             self.storage_backend = MooncakeStore()
             self.get_hash_str = get_hash_str_mooncake
             self.storage_backend.register_buffer(self.mem_pool_host.kv_buffer)
+            assert self.mem_pool_host.layout == "page_first"
         elif storage_backend == "hf3fs":
             from sglang.srt.distributed import get_tensor_model_parallel_rank
             from sglang.srt.mem_cache.storage.hf3fs.storage_hf3fs import (
@@ -278,6 +283,12 @@ class HiCacheController:
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
+            self.prefetch_capacity_limit = int(
+                0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
+            )
+            # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
+            self.prefetch_tokens_occupied = 0
+
             # create a new communication group for synchronizing storage operations across TP workers
             self.tp_world_size = torch.distributed.get_world_size(group=tp_group)
             if self.tp_world_size > 1:
@@ -424,7 +435,9 @@ class HiCacheController:
         if self.io_backend == "kernel":
            return host_indices.to(self.mem_pool_device.device), device_indices
         elif self.io_backend == "direct":
-
+            device_indices = device_indices.cpu()
+            host_indices, idx = host_indices.sort()
+            return host_indices, device_indices.index_select(0, idx)
         else:
             raise ValueError(f"Unsupported io backend")
 
@@ -525,7 +538,7 @@ class HiCacheController:
         host_indices: torch.Tensor,
         new_input_tokens: List[int],
         last_hash: Optional[str] = None,
-    ) ->
+    ) -> PrefetchOperation:
         """
         Prefetch KV caches from storage backend to host memory.
         """
@@ -561,10 +574,6 @@ class HiCacheController:
                 )
                 completed_tokens += self.page_size
             else:
-                # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
-                    operation.host_indices[operation.completed_tokens :]
-                )
                 break
 
     def mooncake_page_transfer(self, operation):
@@ -586,11 +595,31 @@ class HiCacheController:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 if self.is_mooncake_backend():
                     self.mooncake_page_transfer(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_transfer(operation, batch_size=128)
                 else:
                     self.generic_page_transfer(operation)
+
+                if self.tp_world_size > 1:
+                    # to ensure all TP workers release the host memory at the same time
+                    torch.distributed.barrier(group=self.prefetch_tp_group)
+                # operation terminated by controller, release pre-allocated memory
+                self.mem_pool_host.free(
+                    operation.host_indices[operation.completed_tokens :]
+                )
             except Empty:
                 continue
 
+    def prefetch_rate_limit_check(self) -> bool:
+        """
+        Rate limit the prefetching operations to avoid overwhelming the storage backend.
+        """
+        # cancel prefetch if too much memory is occupied
+        if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
+            return False
+        # todo: more sophisticated rate limiting based on storage backend performance
+        return True
+
     def prefetch_thread_func(self):
         """
         Manage prefetching operations from storage backend to host memory.
@@ -604,34 +633,38 @@ class HiCacheController:
                 if operation is None:
                     continue
 
-                last_hash = operation.last_hash
-                tokens_to_fetch = operation.token_ids
-
                 storage_hit_count = 0
-
-
-
-                last_hash =
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                if (
+                    operation.host_indices is not None
+                ) and self.prefetch_rate_limit_check():
+                    last_hash = operation.last_hash
+                    tokens_to_fetch = operation.token_ids
+
+                    remaining_tokens = len(tokens_to_fetch)
+                    hash_value = []
+                    while remaining_tokens >= self.page_size:
+                        last_hash = self.get_hash_str(
+                            tokens_to_fetch[
+                                storage_hit_count : storage_hit_count + self.page_size
+                            ],
+                            last_hash,
+                        )
+
+                        # todo, more unified interface
+                        if not self.is_mooncake_backend():
+                            if not self.storage_backend.exists(last_hash):
+                                break
+                        hash_value.append(last_hash)
+                        storage_hit_count += self.page_size
+                        remaining_tokens -= self.page_size
+
+                    if self.is_mooncake_backend():
+                        # deferring to batch exists for mooncake store
+                        exist_result = self.storage_backend.exists(hash_value)
+                        storage_hit_count = (
+                            sum(1 for v in exist_result.values() if v != 0)
+                            * self.page_size
+                        )
 
                 if self.tp_world_size > 1:
                     storage_hit_count_tensor = torch.tensor(
@@ -647,7 +680,8 @@ class HiCacheController:
                 if storage_hit_count < self.prefetch_threshold:
                     # not to prefetch if not enough benefits
                     self.prefetch_revoke_queue.put(operation.request_id)
-
+                    if operation.host_indices is not None:
+                        self.mem_pool_host.free(operation.host_indices)
                     logger.debug(
                         f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                     )
@@ -670,12 +704,12 @@ class HiCacheController:
         self,
         host_indices: torch.Tensor,
        token_ids: List[int],
-
+        hash_value: Optional[List[str]] = None,
     ) -> int:
         """
         Write KV caches from host memory to storage backend.
         """
-        operation = StorageOperation(host_indices, token_ids,
+        operation = StorageOperation(host_indices, token_ids, hash_value=hash_value)
         self.backup_queue.put(operation)
         return operation.id
 
@@ -730,26 +764,10 @@ class HiCacheController:
                 if operation is None:
                     continue
 
-                last_hash = operation.last_hash
-                tokens_to_backup = operation.token_ids
-
-                backup_hit_count = 0
-                remaining_tokens = len(tokens_to_backup)
-                hash_value = []
-                while remaining_tokens >= self.page_size:
-                    last_hash = self.get_hash_str(
-                        tokens_to_backup[
-                            backup_hit_count : backup_hit_count + self.page_size
-                        ],
-                        last_hash,
-                    )
-                    backup_hit_count += self.page_size
-                    hash_value.append(last_hash)
-                    remaining_tokens -= self.page_size
-                operation.hash_value = hash_value
-
                 if self.is_mooncake_backend():
                     self.mooncake_page_backup(operation)
+                elif self.storage_backend_type == "hf3fs":
+                    self.generic_page_backup(operation, batch_size=128)
                 else:
                     self.generic_page_backup(operation)
 
@@ -768,7 +786,6 @@ class HiCacheController:
                 self.ack_backup_queue.put(
                     (
                         operation.id,
-                        operation.hash_value[: min_completed_tokens // self.page_size],
                         min_completed_tokens,
                     )
                 )
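The cache_controller.py changes above add capacity-based throttling for prefetch: the controller derives a prefetch_capacity_limit from the host/device pool sizes and refuses new prefetch operations once prefetch_tokens_occupied reaches that limit. Below is a minimal standalone sketch of that check, using hypothetical helper names rather than the real HiCacheController internals.

# Illustrative sketch only; mirrors the 0.8-of-headroom check added above.
# host_pool_size, device_pool_size, and tokens_occupied are hypothetical stand-ins.
def make_capacity_limit(host_pool_size: int, device_pool_size: int) -> int:
    # Reserve at most 80% of the host-only headroom for in-flight prefetches.
    return int(0.8 * (host_pool_size - device_pool_size))

def can_start_prefetch(tokens_occupied: int, capacity_limit: int) -> bool:
    # Mirrors the idea of HiCacheController.prefetch_rate_limit_check: refuse new
    # prefetches once tokens already locked by pending prefetches hit the limit.
    return tokens_occupied < capacity_limit

limit = make_capacity_limit(host_pool_size=1_000_000, device_pool_size=200_000)
print(can_start_prefetch(tokens_occupied=500_000, capacity_limit=limit))  # True
print(can_start_prefetch(tokens_occupied=700_000, capacity_limit=limit))  # False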
sglang/srt/managers/detokenizer_manager.py
CHANGED
@@ -216,7 +216,7 @@ class DetokenizerManager:
                 rids=recv_obj.rids,
                 finished_reasons=recv_obj.finished_reasons,
                 output_strs=output_strs,
-                output_ids=
+                output_ids=recv_obj.output_ids,
                 prompt_tokens=recv_obj.prompt_tokens,
                 completion_tokens=recv_obj.completion_tokens,
                 cached_tokens=recv_obj.cached_tokens,
sglang/srt/managers/io_struct.py
CHANGED
@@ -26,6 +26,7 @@ from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.multimodal.mm_utils import has_valid_data
 from sglang.srt.sampling.sampling_params import SamplingParams
+from sglang.srt.utils import ImageData
 
 # Handle serialization of Image for pydantic
 if TYPE_CHECKING:
@@ -45,7 +46,7 @@ class SessionParams:
 
 # Type definitions for multimodal input data
 # Individual data item types for each modality
-ImageDataInputItem = Union[Image, str, Dict]
+ImageDataInputItem = Union[Image, str, ImageData, Dict]
 AudioDataInputItem = Union[str, Dict]
 VideoDataInputItem = Union[str, Dict]
 # Union type for any multimodal data item
@@ -98,23 +99,24 @@ class GenerateReqInput:
     stream: bool = False
     # Whether to log metrics for this request (e.g. health_generate calls do not log metrics)
     log_metrics: bool = True
+    # Whether to return hidden states
+    return_hidden_states: Union[List[bool], bool] = False
 
     # The modalities of the image data [image, multi-images, video]
     modalities: Optional[List[str]] = None
-    # The path to the LoRA
-    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
-
     # Session info for continual prompting
     session_params: Optional[Union[List[Dict], Dict]] = None
 
+    # The path to the LoRA adaptors
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    # The uid of LoRA adaptors, should be initialized by tokenizer manager
+    lora_id: Optional[Union[List[Optional[str]], Optional[str]]] = None
+
     # Custom logit processor for advanced sampling control. Must be a serialized instance
     # of `CustomLogitProcessor` in python/sglang/srt/sampling/custom_logit_processor.py
     # Use the processor's `to_str()` method to generate the serialized string.
     custom_logit_processor: Optional[Union[List[Optional[str]], str]] = None
 
-    # Whether to return hidden states
-    return_hidden_states: Union[List[bool], bool] = False
-
     # For disaggregated inference
     bootstrap_host: Optional[Union[List[str], str]] = None
     bootstrap_port: Optional[Union[List[Optional[int]], int]] = None
@@ -123,6 +125,9 @@ class GenerateReqInput:
     # For data parallel rank routing
     data_parallel_rank: Optional[int] = None
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def contains_mm_input(self) -> bool:
         return (
             has_valid_data(self.image_data)
@@ -450,6 +455,7 @@ class GenerateReqInput:
             log_metrics=self.log_metrics,
             modalities=self.modalities[i] if self.modalities else None,
            lora_path=self.lora_path[i] if self.lora_path is not None else None,
+            lora_id=self.lora_id[i] if self.lora_id is not None else None,
             custom_logit_processor=(
                 self.custom_logit_processor[i]
                 if self.custom_logit_processor is not None
@@ -500,7 +506,7 @@ class TokenizedGenerateReqInput:
     stream: bool
 
     # LoRA related
-
+    lora_id: Optional[str] = None  # None means just use the base model
     # The input embeds
     input_embeds: Optional[Union[List[List[List[float]]], List[List[float]]]] = None
 
@@ -557,6 +563,9 @@ class EmbeddingReqInput:
     # For cross-encoder requests
     is_cross_encoder_request: bool = False
 
+    # For background responses (OpenAI responses API)
+    background: bool = False
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -1073,6 +1082,8 @@ class LoadLoRAAdapterReqInput:
     lora_name: str
     # The path of loading.
    lora_path: str
+    # Whether to pin the LoRA adapter in memory.
+    pinned: bool = False
     # The unique identifier for the LoRA adapter, which automatically generated in the `TokenizerManager`.
     lora_id: Optional[str] = None
 
@@ -1081,6 +1092,7 @@ class LoadLoRAAdapterReqInput:
             lora_id=self.lora_id,
             lora_name=self.lora_name,
             lora_path=self.lora_path,
+            pinned=self.pinned,
         )
 
 
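The io_struct.py changes above add a pinned flag to LoadLoRAAdapterReqInput and a server-assigned lora_id alongside lora_path. A rough sketch of building a load request with the new flag follows; it assumes the keyword fields shown in this diff are enough to construct the dataclass, and the name and path below are hypothetical.

from sglang.srt.managers.io_struct import LoadLoRAAdapterReqInput

# Sketch only: lora_id is left unset because, per the diff, the TokenizerManager
# assigns it automatically.
req = LoadLoRAAdapterReqInput(
    lora_name="my-adapter",          # hypothetical adapter name
    lora_path="/models/my-adapter",  # hypothetical local path
    pinned=True,                     # new in 0.5.0rc1: keep the adapter resident in memory
)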
sglang/srt/managers/mm_utils.py
CHANGED
@@ -388,24 +388,18 @@ def _get_chunked_prefill_embedding(
         embedding_per_req = data_embedding_func(embedding_items_per_req)
         if not embedding_cache.put(embedding_items_hash, embedding_per_req):
             print_warning_once(
-                "Multimodal embedding cache is full.
-                "
+                "Multimodal embedding cache is full. This typically occurs when a single "
+                "embedding exceeds the cache size limit. Consider increasing the "
+                "`SGLANG_VLM_CACHE_SIZE_MB` environment variable or reducing the input "
+                "embedding size."
             )
 
-        embedding_per_req_chunk, _,
+        embedding_per_req_chunk, _, _ = get_embedding_chunk(
             embedding=embedding_per_req,
             extend_prefix_len=prefix_length[i],
             extend_seq_len=extend_length[i] if i < len(extend_length) else 0,
             items_offset=items_offset,
         )
-        # remove this item from cache if chunk reaches to the end
-        embedding_per_req_length = (
-            embedding_per_req.shape[0]
-            if embedding_per_req.dim() == 2
-            else embedding_per_req.shape[0] * embedding_per_req.shape[1]
-        )
-        if end_index == embedding_per_req_length:
-            embedding_cache.free(embedding_items_hash)
         embedding_list.append(embedding_per_req_chunk)
     if len(embedding_list) == 0:
         return None
@@ -620,8 +614,7 @@ def general_mm_embed_routine(
         input_ids: Input token IDs tensor
         forward_batch: Batch information for model forward pass
         language_model: Base language model to use
-
-        audio_data_embedding_func: Function to embed audio data
+        data_embedding_funcs: A dictionary mapping from modality type to the corresponding embedding function.
         placeholder_tokens: Token IDs for multimodal placeholders
         **kwargs: Additional arguments passed to language model
 
sglang/srt/managers/multimodal_processor.py
CHANGED
@@ -20,7 +20,7 @@ def import_processors():
         try:
             module = importlib.import_module(name)
         except Exception as e:
-            logger.warning(f"Ignore import error when loading {name}:
+            logger.warning(f"Ignore import error when loading {name}: {e}")
             continue
         all_members = inspect.getmembers(module, inspect.isclass)
         classes = [
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -37,6 +37,7 @@ import logging
 import threading
 from enum import Enum, auto
 from http import HTTPStatus
+from itertools import chain
 from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
 import numpy as np
@@ -51,13 +52,13 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
     ScheduleBatchDisaggregationDecodeMixin,
 )
 from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
-from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     SWATokenToKVPoolAllocator,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.chunk_cache import ChunkCache, SWAChunkCache
+from sglang.srt.mem_cache.lora_radix_cache import LoRAKey, LoRARadixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.metrics.collector import TimeStats
@@ -85,6 +86,7 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "disable_radix_cache",
     "enable_dp_attention",
     "enable_two_batch_overlap",
+    "tbo_token_distribution_threshold",
     "enable_dp_lm_head",
     "moe_a2a_backend",
     "deepep_mode",
@@ -107,8 +109,10 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "num_reserved_decode_tokens",
     "weight_loader_disable_mmap",
     "enable_triton_kernel_moe",
+    "enable_flashinfer_mxfp4_moe",
     "enable_multimodal",
     "enable_symm_mem",
+    "quantization",
 ]
 
 # Put some global args for easy access
@@ -423,7 +427,7 @@ class Req:
         token_ids_logprob: List[int] = None,
         stream: bool = False,
         origin_input_ids_unpadded: Optional[Tuple[int]] = None,
-
+        lora_id: Optional[str] = None,
         input_embeds: Optional[List[List[float]]] = None,
         token_type_ids: List[int] = None,
         session_id: Optional[str] = None,
@@ -467,7 +471,7 @@ class Req:
         self.sampling_params = sampling_params
         self.custom_logit_processor = custom_logit_processor
         self.return_hidden_states = return_hidden_states
-        self.
+        self.lora_id = lora_id
 
         # Memory pool info
         self.req_pool_idx: Optional[int] = None
@@ -636,14 +640,26 @@ class Req:
     ):
         self.fill_ids = self.origin_input_ids + self.output_ids
         if tree_cache is not None:
-            (
-
-
-
-
-
-
-
+            if isinstance(tree_cache, LoRARadixCache):
+                (
+                    self.prefix_indices,
+                    self.last_node,
+                    self.last_host_node,
+                    self.host_hit_length,
+                ) = tree_cache.match_prefix_with_lora_id(
+                    key=LoRAKey(
+                        lora_id=self.lora_id, token_ids=self.adjust_max_prefix_ids()
+                    ),
+                )
+            else:
+                (
+                    self.prefix_indices,
+                    self.last_node,
+                    self.last_host_node,
+                    self.host_hit_length,
+                ) = tree_cache.match_prefix(
+                    key=self.adjust_max_prefix_ids(),
+                )
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
@@ -845,6 +861,8 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
     # The sum of all sequence lengths
     seq_lens_sum: int = None
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: torch.Tensor = None  # shape: [b], int32
 
     # For DP attention
     global_num_tokens: Optional[List[int]] = None
@@ -917,8 +935,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         is_hybrid = False
         if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
-            assert
-            tree_cache
+            assert (
+                tree_cache is None
+                or isinstance(tree_cache, SWARadixCache)
+                or isinstance(tree_cache, SWAChunkCache)
             ), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
             is_hybrid = True
 
@@ -1128,6 +1148,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
         extend_num_tokens = sum(len(ids) for ids in input_ids)
         seq_lens = [len(r.fill_ids) for r in reqs]
+        orig_seq_lens = [max(len(r.fill_ids), len(r.origin_input_ids)) for r in reqs]
         prefix_lens = [len(r.prefix_indices) for r in reqs]
         extend_lens = [r.extend_input_len for r in reqs]
 
@@ -1138,10 +1159,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-        input_ids_tensor = torch.tensor(
+        input_ids_tensor = torch.tensor(
+            list(chain.from_iterable(input_ids)), dtype=torch.int64
+        ).to(self.device, non_blocking=True)
+        seq_lens_tensor = torch.tensor(seq_lens, dtype=torch.int64).to(
             self.device, non_blocking=True
         )
-
+        orig_seq_lens_tensor = torch.tensor(orig_seq_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
         prefix_lens_tensor = torch.tensor(
@@ -1257,6 +1281,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.input_ids = input_ids_tensor
         self.req_pool_indices = req_pool_indices_tensor
         self.seq_lens = seq_lens_tensor
+        self.orig_seq_lens = orig_seq_lens_tensor
         self.out_cache_loc = out_cache_loc
         self.input_embeds = (
             torch.tensor(input_embeds).to(self.device, non_blocking=True)
@@ -1504,6 +1529,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.forward_mode = ForwardMode.IDLE
         self.input_ids = torch.empty(0, dtype=torch.int64, device=self.device)
         self.seq_lens = torch.empty(0, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.empty(0, dtype=torch.int32, device=self.device)
         self.out_cache_loc = torch.empty(0, dtype=torch.int64, device=self.device)
         self.req_pool_indices = torch.empty(0, dtype=torch.int32, device=self.device)
         self.seq_lens_sum = 0
@@ -1558,9 +1584,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.enable_overlap:
             # Do not use in-place operations in the overlap mode
             self.seq_lens = self.seq_lens + 1
+            self.orig_seq_lens = self.orig_seq_lens + 1
         else:
             # A faster in-place version
             self.seq_lens.add_(1)
+            self.orig_seq_lens.add_(1)
         self.seq_lens_sum += bs
 
         # free memory
@@ -1624,6 +1652,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         self.multimodal_inputs = [self.multimodal_inputs[i] for i in keep_indices]
         self.req_pool_indices = self.req_pool_indices[keep_indices_device]
         self.seq_lens = self.seq_lens[keep_indices_device]
+        self.orig_seq_lens = self.orig_seq_lens[keep_indices_device]
         self.out_cache_loc = None
         self.seq_lens_sum = self.seq_lens.sum().item()
         self.output_ids = self.output_ids[keep_indices_device]
@@ -1656,6 +1685,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             [self.req_pool_indices, other.req_pool_indices]
         )
         self.seq_lens = torch.cat([self.seq_lens, other.seq_lens])
+        self.orig_seq_lens = torch.cat([self.orig_seq_lens, other.orig_seq_lens])
         self.out_cache_loc = None
         self.seq_lens_sum += other.seq_lens_sum
         if self.output_ids is not None:
@@ -1697,14 +1727,16 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            attention_backend_str
-
-
-
-
-
-
-
+            attention_backend_str
+            in [
+                "fa3",
+                "flashinfer",
+                "flashmla",
+                "cutlass_mla",
+                "ascend",
+                "trtllm_mha",
+                "aiter",
+            ]
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
@@ -1729,6 +1761,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             input_ids=self.input_ids,
             req_pool_indices=self.req_pool_indices,
             seq_lens=self.seq_lens,
+            orig_seq_lens=self.orig_seq_lens,
             out_cache_loc=self.out_cache_loc,
             seq_lens_cpu=seq_lens_cpu,
             seq_lens_sum=self.seq_lens_sum,
@@ -1750,7 +1783,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
             encoder_lens=self.encoder_lens,
             encoder_lens_cpu=self.encoder_lens_cpu,
             encoder_out_cache_loc=self.encoder_out_cache_loc,
-
+            lora_ids=[req.lora_id for req in self.reqs],
             sampling_info=self.sampling_info,
             input_embeds=self.input_embeds,
             token_type_ids=self.token_type_ids,
@@ -1891,11 +1924,14 @@ class ModelWorkerBatch:
     encoder_out_cache_loc: Optional[torch.Tensor]
 
     # For LoRA
-
+    lora_ids: Optional[List[str]]
 
     # Sampling info
     sampling_info: SamplingBatchInfo
 
+    # The original sequence lengths, Qwen-1M related
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # The input Embeds
     input_embeds: Optional[torch.Tensor] = None
 
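The schedule_batch.py hunks above switch the per-request token lists to a single itertools.chain.from_iterable pass before building input_ids_tensor. A small self-contained illustration of that flattening step, with toy token ids rather than sglang data:

from itertools import chain

# Sketch of the flattening used when building input_ids_tensor above:
# chain.from_iterable concatenates the per-request token lists in one pass,
# producing the flat list that the torch.tensor(...) call consumes.
input_ids = [[101, 7592], [101, 2088, 999]]  # hypothetical per-request token ids
flat = list(chain.from_iterable(input_ids))
print(flat)  # [101, 7592, 101, 2088, 999]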