sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +113 -17
- sglang/srt/configs/model_config.py +35 -0
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +6 -1
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +243 -135
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +11 -9
- sglang/srt/entrypoints/context.py +244 -0
- sglang/srt/entrypoints/engine.py +4 -3
- sglang/srt/entrypoints/harmony_utils.py +370 -0
- sglang/srt/entrypoints/http_server.py +71 -0
- sglang/srt/entrypoints/openai/protocol.py +227 -1
- sglang/srt/entrypoints/openai/serving_chat.py +278 -42
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +174 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/harmony_tool_parser.py +130 -0
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +8 -1
- sglang/srt/layers/attention/aiter_backend.py +5 -8
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/vision.py +13 -5
- sglang/srt/layers/communicator.py +21 -4
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/linear.py +2 -7
- sglang/srt/layers/moe/cutlass_moe.py +20 -6
- sglang/srt/layers/moe/ep_moe/layer.py +77 -73
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/topk.py +12 -3
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +22 -0
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_utils.py +29 -0
- sglang/srt/layers/quantization/modelopt_quant.py +259 -64
- sglang/srt/layers/quantization/mxfp4.py +651 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/__init__.py +0 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +225 -1
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/lora_manager.py +70 -14
- sglang/srt/lora/lora_registry.py +3 -2
- sglang/srt/lora/mem_pool.py +43 -5
- sglang/srt/managers/cache_controller.py +55 -30
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +5 -11
- sglang/srt/managers/schedule_batch.py +28 -7
- sglang/srt/managers/scheduler.py +26 -12
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +24 -6
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/hiradix_cache.py +53 -5
- sglang/srt/mem_cache/memory_pool_host.py +1 -1
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +7 -6
- sglang/srt/model_executor/forward_batch_info.py +35 -14
- sglang/srt/model_executor/model_runner.py +19 -2
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +72 -33
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma3n_mm.py +39 -0
- sglang/srt/models/glm4_moe.py +24 -12
- sglang/srt/models/gpt_oss.py +1134 -0
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +18 -39
- sglang/srt/server_args.py +142 -7
- sglang/srt/two_batch_overlap.py +157 -5
- sglang/srt/utils.py +38 -2
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
@@ -556,7 +556,7 @@ class TokenizerManager:
         if self.server_args.enable_lora and obj.lora_path:
             # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
             # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
-            obj.
+            obj.lora_id = await self.lora_registry.acquire(obj.lora_path)

         self._validate_one_request(obj, input_ids)
         return self._create_tokenized_object(
@@ -665,7 +665,7 @@ class TokenizerManager:
             bootstrap_host=obj.bootstrap_host,
             bootstrap_port=obj.bootstrap_port,
             bootstrap_room=obj.bootstrap_room,
-
+            lora_id=obj.lora_id,
             input_embeds=input_embeds,
             session_params=session_params,
             custom_logit_processor=obj.custom_logit_processor,
@@ -750,7 +750,11 @@ class TokenizerManager:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
-                if
+                if (
+                    request is not None
+                    and not obj.background
+                    and await request.is_disconnected()
+                ):
                     # Abort the request for disconnected requests (non-streaming, waiting queue)
                     self.abort_request(obj.rid)
                     # Use exception to kill the whole call stack and asyncio task
@@ -773,7 +777,7 @@ class TokenizerManager:

         # Mark ongoing LoRA request as finished.
         if self.server_args.enable_lora and obj.lora_path:
-            await self.lora_registry.release(obj.
+            await self.lora_registry.release(obj.lora_id)

         # Check if this was an abort/error created by scheduler
         if isinstance(out["meta_info"].get("finish_reason"), dict):
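The acquire/release pair above amounts to per-adapter reference counting in the LoRA registry: a request pins its adapter on entry and unpins it on completion, so an adapter is only unloadable when its count reaches zero. A minimal sketch of that bookkeeping, with hypothetical names (`TinyLoRARegistry`, `_inflight`); the real `LoRARegistry` API may differ:

```python
import asyncio
import uuid


class TinyLoRARegistry:
    """Illustrative only: maps user-facing adapter names to stable IDs and
    counts in-flight requests per adapter, mirroring acquire()/release()."""

    def __init__(self):
        self._name_to_id: dict[str, str] = {}
        self._inflight: dict[str, int] = {}
        self._lock = asyncio.Lock()

    def register(self, lora_name: str) -> str:
        lora_id = self._name_to_id.setdefault(lora_name, uuid.uuid4().hex)
        self._inflight.setdefault(lora_id, 0)
        return lora_id

    async def acquire(self, lora_name: str) -> str:
        # Called when a request enters: pin the adapter and hand back its ID.
        async with self._lock:
            lora_id = self._name_to_id[lora_name]
            self._inflight[lora_id] += 1
            return lora_id

    async def release(self, lora_id: str) -> None:
        # Called when the request finishes: an adapter is safe to unload at 0.
        async with self._lock:
            self._inflight[lora_id] -= 1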
@@ -805,7 +809,11 @@ class TokenizerManager:
                 if obj.stream:
                     yield out
                 else:
-                    if
+                    if (
+                        request is not None
+                        and not obj.background
+                        and await request.is_disconnected()
+                    ):
                         # Abort the request for disconnected requests (non-streaming, running)
                         self.abort_request(obj.rid)
                         # Use exception to kill the whole call stack and asyncio task
@@ -1121,6 +1129,7 @@ class TokenizerManager:
         new_adapter = LoRARef(
             lora_name=obj.lora_name,
             lora_path=obj.lora_path,
+            pinned=obj.pinned,
         )

         # Trigger the actual loading operation at the backend processes.
@@ -1178,7 +1187,7 @@ class TokenizerManager:

             return result
         except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False,
+            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))

     async def get_weights_by_name(
         self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
@@ -1548,8 +1557,17 @@ class TokenizerManager:

         if isinstance(recv_obj, BatchStrOut):
             state.text += recv_obj.output_strs[i]
+            if state.obj.stream:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids[state.last_output_offset :]
+                state.last_output_offset = len(state.output_ids)
+            else:
+                state.output_ids.extend(recv_obj.output_ids[i])
+                output_token_ids = state.output_ids.copy()
+
             out_dict = {
                 "text": state.text,
+                "output_ids": output_token_ids,
                 "meta_info": meta_info,
             }
         elif isinstance(recv_obj, BatchTokenIDOut):
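The new `output_ids` bookkeeping returns only the tokens generated since the previous chunk when streaming, and the full accumulated list otherwise. The same offset logic, condensed into a standalone sketch with hypothetical names:

```python
def next_output_ids(all_ids: list[int], new_ids: list[int], last_offset: int, stream: bool):
    """Return (ids_for_this_chunk, new_offset); illustrative, not the sglang API."""
    all_ids.extend(new_ids)
    if stream:
        chunk = all_ids[last_offset:]       # only tokens produced since the last chunk
        return chunk, len(all_ids)
    return all_ids.copy(), last_offset      # non-streaming: the full list every time


ids: list[int] = []
chunk1, off = next_output_ids(ids, [1, 2, 3], 0, stream=True)   # -> [1, 2, 3]
chunk2, off = next_output_ids(ids, [4, 5], off, stream=True)    # -> [4, 5] only
```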
sglang/srt/managers/tp_worker.py
CHANGED
@@ -311,3 +311,6 @@ class TpModelWorker:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         result = self.model_runner.unload_lora_adapter(recv_req.to_ref())
         return result
+
+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.model_runner.lora_manager.validate_lora_batch(lora_ids)
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -288,6 +288,9 @@ class TpModelWorkerClient:
     def unload_lora_adapter(self, recv_req: UnloadLoRAAdapterReqInput):
         return self.worker.unload_lora_adapter(recv_req)

+    def can_run_lora_batch(self, lora_ids: list[str]) -> bool:
+        return self.worker.can_run_lora_batch(lora_ids)
+
     def __delete__(self):
         self.input_queue.put((None, None))
         self.copy_queue.put((None, None, None))
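The new `can_run_lora_batch` hook lets the scheduler ask whether a candidate batch's adapters can be served together before it schedules them. The exact check lives in `LoRAManager.validate_lora_batch`; a hedged sketch of what such a check could look like, assuming a hypothetical per-batch adapter limit and a set of currently loaded adapter IDs:

```python
def validate_lora_batch_sketch(
    lora_ids: list, loaded_ids: set, max_loras_per_batch: int
) -> bool:
    """Illustrative only: a batch is runnable if its distinct adapters are all
    loaded and fit into the per-batch adapter slots. `None` means 'no LoRA'."""
    distinct = {i for i in lora_ids if i is not None}
    return distinct <= loaded_ids and len(distinct) <= max_loras_per_batch


# Two requests share one adapter, one request uses the base model.
print(validate_lora_batch_sketch(["a", "a", None], {"a", "b"}, max_loras_per_batch=1))  # True
```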
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -2,11 +2,12 @@ import heapq
 import logging
 import threading
 import time
+from queue import Queue
 from typing import List, Optional

 import torch

-from sglang.srt.managers.cache_controller import HiCacheController
+from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import MatchResult
 from sglang.srt.mem_cache.memory_pool import (
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
         hicache_io_backend: str,
         hicache_mem_layout: str,
         hicache_storage_backend: Optional[str] = None,
+        hicache_storage_prefetch_policy: Optional[str] = "best_effort",
     ):

         if hicache_io_backend == "direct":
@@ -85,6 +87,13 @@ class HiRadixCache(RadixCache):
             prefetch_threshold=self.prefetch_threshold,
         )

+        self.prefetch_stop_policy = hicache_storage_prefetch_policy
+        # todo: customizable storage prefetch timeout
+        self.prefetch_timeout = 3  # seconds
+        logger.info(
+            f"HiCache storage prefetch policy: {hicache_storage_prefetch_policy}"
+        )
+
         # record the nodes with ongoing write through
         self.ongoing_write_through = {}
         # record the node segments with ongoing load back
@@ -385,9 +394,10 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node,
+                last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
                 del self.ongoing_prefetch[req_id]
+                self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
             else:
                 # the revoked operation already got terminated
                 pass
@@ -419,10 +429,41 @@ class HiRadixCache(RadixCache):
         host_node.release_host()
         del self.ongoing_backup[ack_id]

-    def
+    def can_terminate_prefetch(self, operation: PrefetchOperation):
+        can_terminate = True
+
+        if self.prefetch_stop_policy == "best_effort":
+            return can_terminate
+
+        completed = (
+            operation.completed_tokens == len(operation.hash_value) * self.page_size
+        )
+
+        if self.prefetch_stop_policy == "wait_complete":
+            can_terminate = completed
+        elif self.prefetch_stop_policy == "timeout":
+            can_terminate = completed or (
+                time.monotonic() - operation.start_time > self.prefetch_timeout
+            )
+        else:
+            # unknown prefetch stop policy, just return True
+            return True
+
+        if self.tp_world_size > 1:
+            can_terminate = torch.tensor(can_terminate, dtype=torch.int)
+            torch.distributed.all_reduce(
+                can_terminate,
+                op=torch.distributed.ReduceOp.MIN,
+                group=self.tp_group,
+            )
+            can_terminate = bool(can_terminate.item())
+
+        return can_terminate
+
+    def check_prefetch_progress(self, req_id: str) -> bool:
         if req_id not in self.ongoing_prefetch:
             # there is no ongoing prefetch for this request or it has been revoked
-            return
+            return True

         # todo: more policies for prefetch progress such as timeout
         # the current policy is to prefetch with best effort and terminate when queuing is over
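For the `timeout` and `wait_complete` policies, all tensor-parallel ranks must agree before a prefetch is terminated; the all-reduce with `MIN` above means the group proceeds only when every rank can terminate. A standalone sketch of that consensus step, assuming an already initialized `torch.distributed` process group:

```python
import torch
import torch.distributed as dist


def group_can_terminate(local_decision: bool, group=None) -> bool:
    """Every rank proposes a local yes/no; MIN over the group means
    'terminate only if all ranks say yes'. Illustrative sketch."""
    flag = torch.tensor(int(local_decision), dtype=torch.int)
    dist.all_reduce(flag, op=dist.ReduceOp.MIN, group=group)
    return bool(flag.item())
```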
@@ -430,13 +471,16 @@ class HiRadixCache(RadixCache):
             req_id
         ]

+        if not self.can_terminate_prefetch(operation):
+            return False
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )
         logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")

         min_completed_tokens = completed_tokens
-        if self.tp_world_size > 1:
+        if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
             # synchrnoize TP workers to make the same update to hiradix cache
             completed_tokens_tensor = torch.tensor(
                 min_completed_tokens, dtype=torch.int
@@ -464,6 +508,9 @@ class HiRadixCache(RadixCache):
         )
         last_host_node.release_host()
         del self.ongoing_prefetch[req_id]
+        self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
+
+        return True

     def match_prefix(self, key: List[int], **kwargs):
         empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
@@ -531,6 +578,7 @@ class HiRadixCache(RadixCache):
             host_indices,
             operation,
         )
+        self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)

     def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
         node.last_access_time = time.monotonic()
sglang/srt/mem_cache/memory_pool_host.py
CHANGED
@@ -618,7 +618,7 @@ class MLATokenToKVPoolHost(HostKVCache):
         elif self.layout == "page_first":
             transfer_kv_all_layer_mla_lf_pf(
                 src_layers=device_pool.data_ptrs,
-
+                dst=self.kv_buffer,
                 src_indices=device_indices,
                 dst_indices=host_indices,
                 item_size=self.token_stride_size,
sglang/srt/mem_cache/multimodal_cache.py
CHANGED
@@ -1,24 +1,46 @@
+import logging
+from collections import OrderedDict
 from typing import Dict

 import torch

+# Set up logging for cache behavior
+logger = logging.getLogger(__name__)
+

 class MultiModalCache:
-    """MultiModalCache is used to store vlm encoder results"""
+    """MultiModalCache is used to store vlm encoder results with LRU eviction"""

     def __init__(
         self,
         max_size: int,
     ):
         self.max_size = max_size
-        self.mm_cache:
+        self.mm_cache: OrderedDict[int, torch.Tensor] = OrderedDict()
         self.current_size = 0

+    def _allocate(self, embedding_size: int) -> bool:
+        """Allocate space by evicting least recently used entries"""
+        evictions = 0
+        while self.current_size + embedding_size > self.max_size and self.mm_cache:
+            _, old_embedding = self.mm_cache.popitem(last=False)
+            evicted_size = self._get_tensor_size(old_embedding)
+            self.current_size -= evicted_size
+            evictions += evicted_size
+
+        if evictions > 0:
+            logger.debug(
+                f"Cache eviction: evicted {evictions} bytes, remaining size: {self.current_size}/{self.max_size} bytes"
+            )
+
+        if self.current_size + embedding_size > self.max_size:
+            return False
+        return True
+
     def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
-        if mm_hash in self.mm_cache:
-            return True
         data_size = self._get_tensor_size(embedding)
-
+        # Lazy free cache if not enough space
+        if not self._allocate(data_size):
             return False
         self.mm_cache[mm_hash] = embedding
         self.current_size += data_size
@@ -28,14 +50,12 @@ class MultiModalCache:
         return mm_hash in self.mm_cache

     def get(self, mm_hash: int) -> torch.Tensor:
-
-
-
-
-        return
-
-        self.current_size -= self._get_tensor_size(old_embedding)
-        return True
+        """Get embedding and update LRU order"""
+        if mm_hash in self.mm_cache:
+            # Move to end (most recently used)
+            self.mm_cache.move_to_end(mm_hash)
+            return self.mm_cache[mm_hash]
+        return None

     def clear(self):
         self.mm_cache.clear()
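The rewrite turns the multimodal cache into a byte-budgeted LRU: `put` evicts from the front of the `OrderedDict` until the new embedding fits, and `get` moves a hit to the back. A small self-contained sketch of the same eviction order, using byte strings instead of tensors:

```python
from collections import OrderedDict


class ByteBudgetLRU:
    """Minimal LRU with a byte budget, mirroring the eviction order above."""

    def __init__(self, max_size: int):
        self.max_size = max_size
        self.size = 0
        self.items: OrderedDict[int, bytes] = OrderedDict()

    def put(self, key: int, value: bytes) -> bool:
        while self.size + len(value) > self.max_size and self.items:
            _, old = self.items.popitem(last=False)   # evict least recently used
            self.size -= len(old)
        if self.size + len(value) > self.max_size:
            return False                              # single item larger than the budget
        self.items[key] = value
        self.size += len(value)
        return True

    def get(self, key: int):
        if key in self.items:
            self.items.move_to_end(key)               # mark as most recently used
            return self.items[key]
        return None


cache = ByteBudgetLRU(max_size=8)
cache.put(1, b"aaaa"); cache.put(2, b"bbbb")
cache.get(1)                 # key 1 becomes most recently used
cache.put(3, b"cccc")        # evicts key 2, not key 1
assert cache.get(2) is None and cache.get(1) is not None
```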
sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py
CHANGED
@@ -96,6 +96,8 @@ class Hf3fsClient:
         )
         self.iov_r = make_iovec(self.shm_r, self.hf3fs_mount_point)
         self.iov_w = make_iovec(self.shm_w, self.hf3fs_mount_point)
+        self.shm_r.unlink()
+        self.shm_w.unlink()

         self.rlock = threading.RLock()
         self.wlock = threading.RLock()
@@ -176,8 +178,6 @@ class Hf3fsClient:
         del self.iov_w
         self.shm_r.close()
         self.shm_w.close()
-        self.shm_r.unlink()
-        self.shm_w.unlink()

     def flush(self) -> None:
         os.fsync(self.file)
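Unlinking the shared memory segments right after they are mapped, rather than at close time, means the kernel reclaims them even if the process dies before the cleanup path runs; the existing mapping stays valid until `close()`. A minimal standalone sketch of this pattern with Python's `multiprocessing.shared_memory` (assuming that is the mechanism behind `shm_r`/`shm_w`; POSIX semantics):

```python
from multiprocessing import shared_memory

# Create, unlink immediately, and keep using the existing mapping.
shm = shared_memory.SharedMemory(create=True, size=4096)
shm.unlink()             # drop the name now: no /dev/shm leak if we crash later
shm.buf[:5] = b"hello"   # the mapping itself remains valid until close()
assert bytes(shm.buf[:5]) == b"hello"
shm.close()              # release the mapping; the kernel frees the memory
```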
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -576,11 +576,11 @@ class CudaGraphRunner:
         )

         if self.model_runner.server_args.enable_lora:
-            # It is safe to capture CUDA graph using empty LoRA
-            # `--enable-lora` is set to True (and return immediately if the LoRA
-
+            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
+            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
+            lora_ids = [None] * bs
         else:
-
+            lora_ids = None

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -589,6 +589,7 @@ class CudaGraphRunner:
             req_pool_indices=req_pool_indices,
             seq_lens=seq_lens,
             next_token_logits_buffer=next_token_logits_buffer,
+            orig_seq_lens=seq_lens,
             req_to_token_pool=self.model_runner.req_to_token_pool,
             token_to_kv_pool=self.model_runner.token_to_kv_pool,
             attn_backend=self.model_runner.attn_backend,
@@ -607,11 +608,11 @@ class CudaGraphRunner:
             capture_hidden_mode=self.capture_hidden_mode,
             num_token_non_padded=self.num_token_non_padded,
             global_forward_mode=self.capture_forward_mode,
-
+            lora_ids=lora_ids,
         )
         self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

-        if
+        if lora_ids is not None:
             self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

         # Attention backend
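Capturing with a placeholder batch of empty LoRA ids works because CUDA graph capture records kernel launches against fixed buffers and replays them later with real data copied into those buffers. A generic, framework-level sketch of that capture/replay pattern (unrelated to the LoRA specifics above; requires a CUDA device):

```python
import torch

static_x = torch.zeros(8, 16, device="cuda")   # placeholder input; address must stay fixed
weight = torch.randn(16, 16, device="cuda")

# Warm up on a side stream before capture, as recommended by the PyTorch docs.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    _ = static_x @ weight
torch.cuda.current_stream().wait_stream(s)

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
    static_y = static_x @ weight               # kernel launches are recorded here

static_x.copy_(torch.randn(8, 16, device="cuda"))  # fill the captured input buffer
graph.replay()                                     # rerun the recorded kernels
result = static_y.clone()
```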
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -180,6 +180,9 @@ class ForwardBatch:
     # The sum of all sequence lengths
     seq_lens_sum: int

+    # The original sequence length without being chunked. Qwen-1M related.
+    orig_seq_lens: Optional[torch.Tensor] = None
+
     # Optional seq_lens on cpu
     seq_lens_cpu: Optional[torch.Tensor] = None

@@ -248,7 +251,7 @@ class ForwardBatch:
     encoder_out_cache_loc: Optional[torch.Tensor] = None

     # For LoRA
-
+    lora_ids: Optional[List[str]] = None

     # For input embeddings
     input_embeds: Optional[torch.Tensor] = None
@@ -321,13 +324,14 @@ class ForwardBatch:
             encoder_out_cache_loc=batch.encoder_out_cache_loc,
             seq_lens_sum=batch.seq_lens_sum,
             seq_lens_cpu=batch.seq_lens_cpu,
+            orig_seq_lens=batch.orig_seq_lens,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
             token_ids_logprobs=batch.token_ids_logprobs,
             is_extend_in_batch=batch.is_extend_in_batch,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             global_forward_mode=batch.global_forward_mode,
-
+            lora_ids=batch.lora_ids,
             sampling_info=batch.sampling_info,
             req_to_token_pool=model_runner.req_to_token_pool,
             token_to_kv_pool=model_runner.token_to_kv_pool,
@@ -420,16 +424,12 @@ class ForwardBatch:
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
             ret.extend_num_tokens = batch.extend_num_tokens
-
-
-
-
-
-
-            else:
-                positions, ret.extend_start_loc = compute_position_torch(
-                    ret.extend_prefix_lens, ret.extend_seq_lens
-                )
+            positions, ret.extend_start_loc = compute_position(
+                model_runner.server_args.attention_backend,
+                ret.extend_prefix_lens,
+                ret.extend_seq_lens,
+                ret.extend_num_tokens,
+            )
             if ret.positions is None:
                 ret.positions = positions
             ret.extend_prefix_lens_cpu = batch.extend_prefix_lens
@@ -632,8 +632,10 @@ class ForwardBatch:
         self.dp_padding_mode = dp_padding_mode

         if dp_padding_mode.is_max_len():
-            # when DP gather mode is all gather, we will use
-            #
+            # when DP gather mode is all gather, we will use
+            # all_gather_into_tensor to gather hidden states, where transferred
+            # tokens should be padded to the same length. We will also use
+            # reduce-scatter instead of all-reduce after MLP.
             max_num_tokens = max(global_num_tokens)
             global_num_tokens = [max_num_tokens] * sync_group_size
             buffer_len = max_num_tokens * sync_group_size
@@ -882,6 +884,25 @@ class PPProxyTensors:
         return f"PPProxyTensors(tensors={self.tensors})"


+def compute_position(
+    attn_backend: str,
+    extend_prefix_lens: torch.Tensor,
+    extend_seq_lens: torch.Tensor,
+    extend_seq_lens_sum: int,
+):
+    if support_triton(attn_backend):
+        positions, extend_start_loc = compute_position_triton(
+            extend_prefix_lens,
+            extend_seq_lens,
+            extend_seq_lens_sum,
+        )
+    else:
+        positions, extend_start_loc = compute_position_torch(
+            extend_prefix_lens, extend_seq_lens
+        )
+    return positions, extend_start_loc
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
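The new `compute_position` wrapper only chooses between the Triton and Torch implementations based on the attention backend; both compute the same thing. As a concrete illustration of those semantics (a sketch, not the exact sglang code): positions are the absolute token indices of each request's extended tokens, and `extend_start_loc` is where each request starts in the flattened token dimension.

```python
import torch


def compute_position_torch_sketch(extend_prefix_lens, extend_seq_lens):
    # Absolute position of every extended token, flattened over the batch.
    positions = torch.cat(
        [
            torch.arange(p, p + s)
            for p, s in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    # Start offset of each request inside the flattened token dimension.
    extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens
    return positions, extend_start_loc


pos, start = compute_position_torch_sketch(torch.tensor([4, 0]), torch.tensor([3, 2]))
print(pos)    # tensor([4, 5, 6, 0, 1])
print(start)  # tensor([0, 3])
```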
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -1443,19 +1443,36 @@ class ModelRunner:
             )

             return CutlassMLABackend(self)
-        elif
+        elif backend_str == "trtllm_mla":
             if not self.use_mla_backend:
                 raise ValueError("trtllm_mla backend can only be used with MLA models.")
             from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend

             return TRTLLMMLABackend(self)
-        elif
+        elif backend_str == "trtllm_mha":
+            if self.use_mla_backend:
+                raise ValueError(
+                    "trtllm_mha backend can only be used with non-MLA models."
+                )
+            from sglang.srt.layers.attention.trtllm_mha_backend import (
+                TRTLLMHAAttnBackend,
+            )
+
+            return TRTLLMHAAttnBackend(self)
+
+        elif backend_str == "intel_amx":
             from sglang.srt.layers.attention.intel_amx_backend import (
                 IntelAMXAttnBackend,
             )

             logger.info(f"Intel AMX attention backend is enabled.")
             return IntelAMXAttnBackend(self)
+        elif self.server_args.attention_backend == "dual_chunk_flash_attn":
+            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+                DualChunkFlashAttentionBackend,
+            )
+
+            return DualChunkFlashAttentionBackend(self)
         else:
             raise ValueError(f"Invalid attention backend: {backend_str}")

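With the two new branches, `trtllm_mha` and `dual_chunk_flash_attn` become selectable through the attention-backend server argument. A hedged usage sketch, assuming the offline engine forwards `ServerArgs` fields as keyword arguments (the model path is a placeholder):

```python
import sglang as sgl

# Assumption: Engine kwargs map onto ServerArgs, so `attention_backend`
# reaches the dispatch shown above. trtllm_mha is for non-MLA models.
llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",
    attention_backend="trtllm_mha",
)
print(llm.generate("Hello", {"max_new_tokens": 8}))
```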
sglang/srt/model_loader/weight_utils.py
CHANGED
@@ -843,6 +843,16 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
             return None
         return remapped_name

+    quark_scale_names = {
+        ".q_proj.output_scale": ".attn.q_scale",
+        ".k_proj.output_scale": ".attn.k_scale",
+        ".v_proj.output_scale": ".attn.v_scale",
+        "self_attn.prob_output_scale": ".attn.prob_scale",
+    }
+    for quark_scale_name, sglang_scale_name in quark_scale_names.items():
+        if name.endswith(quark_scale_name):
+            return name.replace(quark_scale_name, sglang_scale_name)
+
     # If there were no matches, return the untouched param name
     return name
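The new mapping rewrites Quark-style quantization scale names into the attention-module parameter names sglang expects; only the matched suffix is replaced. A quick illustration of the same suffix-match logic with a made-up layer prefix:

```python
quark_scale_names = {
    ".q_proj.output_scale": ".attn.q_scale",
    ".k_proj.output_scale": ".attn.k_scale",
    ".v_proj.output_scale": ".attn.v_scale",
    "self_attn.prob_output_scale": ".attn.prob_scale",
}


def remap(name: str) -> str:
    # Same suffix-match-and-replace logic as the snippet above.
    for quark_name, sglang_name in quark_scale_names.items():
        if name.endswith(quark_name):
            return name.replace(quark_name, sglang_name)
    return name


# "model.layers.0.self_attn.k_proj.output_scale" -> "model.layers.0.self_attn.attn.k_scale"
print(remap("model.layers.0.self_attn.k_proj.output_scale"))
```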