sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/two_batch_overlap.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
from __future__ import annotations
|
2
2
|
|
3
|
+
import copy
|
3
4
|
import dataclasses
|
4
5
|
import logging
|
5
6
|
from dataclasses import replace
|
@@ -17,15 +18,21 @@ from sglang.srt.layers.moe.token_dispatcher import DeepEPDispatcher
|
|
17
18
|
from sglang.srt.layers.moe.utils import DeepEPMode
|
18
19
|
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
19
20
|
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
|
20
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
|
21
|
+
from sglang.srt.model_executor.forward_batch_info import (
|
22
|
+
ForwardBatch,
|
23
|
+
ForwardMode,
|
24
|
+
compute_position,
|
25
|
+
)
|
21
26
|
from sglang.srt.operations import execute_operations, execute_overlapped_operations
|
22
27
|
from sglang.srt.operations_strategy import OperationsStrategy
|
23
28
|
from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
|
24
|
-
from sglang.srt.utils import BumpAllocator, get_bool_env_var
|
29
|
+
from sglang.srt.utils import BumpAllocator, get_bool_env_var, is_hip
|
25
30
|
|
26
31
|
if TYPE_CHECKING:
|
27
32
|
from sglang.srt.layers.moe.token_dispatcher import DispatchOutput
|
28
33
|
|
34
|
+
_is_hip = is_hip()
|
35
|
+
|
29
36
|
_tbo_debug = get_bool_env_var("SGLANG_TBO_DEBUG")
|
30
37
|
|
31
38
|
logger = logging.getLogger(__name__)
|
@@ -58,7 +65,7 @@ def compute_split_seq_index(
|
|
58
65
|
) -> Optional[int]:
|
59
66
|
if forward_mode == ForwardMode.EXTEND:
|
60
67
|
assert extend_lens is not None
|
61
|
-
return _split_array_by_half_sum(extend_lens)
|
68
|
+
return _split_extend_seqs(extend_lens)
|
62
69
|
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
63
70
|
assert token_num_per_seq is not None
|
64
71
|
return (num_tokens // token_num_per_seq) // 2
|
@@ -69,7 +76,43 @@ def compute_split_seq_index(
|
|
69
76
|
raise NotImplementedError()
|
70
77
|
|
71
78
|
|
72
|
-
def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
79
|
+
def _is_two_chunk_split_enabled(extend_lens: Optional[Sequence[int]]) -> bool:
    """Decide whether the extend batch should be split inside a sequence.

    The batch is first split at a sequence boundary with
    ``_split_array_by_balanced_sum``. If the left part of that split holds
    less than ``threshold`` or more than ``1 - threshold`` of all tokens, the
    boundary split is considered too unbalanced and this returns True.

    Args:
        extend_lens: Per-sequence extend lengths, or None when not available.

    Returns:
        False when ``extend_lens`` is None; otherwise whether the left share
        of tokens falls outside ``[threshold, 1 - threshold]``.

    Raises:
        AssertionError: If the configured
            ``tbo_token_distribution_threshold`` is greater than 0.5.
    """
    if extend_lens is None:
        return False

    threshold = global_server_args_dict["tbo_token_distribution_threshold"]
    # A threshold above 0.5 would make the two bounds below overlap.
    assert threshold <= 0.5, f"{threshold=}"

    vanilla_split_seq_index = _split_array_by_balanced_sum(extend_lens)
    left_sum = sum(extend_lens[:vanilla_split_seq_index])
    overall_sum = sum(extend_lens)

    too_small = left_sum < overall_sum * threshold
    too_large = left_sum > overall_sum * (1 - threshold)
    return too_small or too_large
|
91
|
+
|
92
|
+
|
93
|
+
def _split_extend_seqs(arr: Sequence[int]) -> int:
    """Pick the sequence index at which an extend batch is split in two.

    Uses ``_split_array_by_cum_less_than_half`` when
    ``_is_two_chunk_split_enabled`` reports that a boundary-only split would
    be too unbalanced, and ``_split_array_by_balanced_sum`` otherwise.

    Args:
        arr: Per-sequence extend lengths.

    Returns:
        The split sequence index chosen by the selected strategy.
    """
    if _is_two_chunk_split_enabled(arr):
        return _split_array_by_cum_less_than_half(arr)

    return _split_array_by_balanced_sum(arr)
|
98
|
+
|
99
|
+
|
100
|
+
def _split_array_by_cum_less_than_half(arr: Sequence[int]) -> int:
|
101
|
+
left_sum = 0
|
102
|
+
overall_sum = sum(arr)
|
103
|
+
half_sum = overall_sum // 2
|
104
|
+
chosen_index = 0
|
105
|
+
|
106
|
+
for i in range(len(arr)):
|
107
|
+
left_sum += arr[i]
|
108
|
+
if left_sum > half_sum:
|
109
|
+
chosen_index = i
|
110
|
+
break
|
111
|
+
|
112
|
+
return chosen_index
|
113
|
+
|
114
|
+
|
115
|
+
def _split_array_by_balanced_sum(arr: Sequence[int]) -> int:
|
73
116
|
overall_sum = sum(arr)
|
74
117
|
left_sum = 0
|
75
118
|
min_diff = float("inf")
|
@@ -88,6 +131,34 @@ def _split_array_by_half_sum(arr: Sequence[int]) -> int:
|
|
88
131
|
return best_index
|
89
132
|
|
90
133
|
|
134
|
+
def _update_device_and_sum_field_from_cpu_field(
    batch: ForwardBatch,
    cpu_field: str,
    device_field: str,
    sum_field: Optional[str] = None,
) -> None:
    """Refresh a device-side field (and optional sum) from its CPU twin.

    Reads ``batch.<cpu_field>``, copies it to the device named by
    ``global_server_args_dict["device"]`` and stores the result in
    ``batch.<device_field>``. When ``sum_field`` is given, the sum of the CPU
    value is stored in ``batch.<sum_field>`` as well.

    The call is a no-op when either field is missing or None, or when the CPU
    value is neither a ``torch.Tensor`` nor a ``list``.

    Args:
        batch: Object whose attributes are read and updated in place.
        cpu_field: Name of the attribute holding the CPU-side value.
        device_field: Name of the attribute to overwrite with the device copy.
        sum_field: Optional name of the attribute that receives the sum.
    """
    cpu_value = getattr(batch, cpu_field, None)
    old_device_value = getattr(batch, device_field, None)
    if (
        cpu_value is None
        or old_device_value is None
        or not isinstance(cpu_value, (torch.Tensor, list))
    ):
        return

    is_tensor = isinstance(cpu_value, torch.Tensor)

    # Lists are materialized with the dtype of the value being replaced so the
    # refreshed device field keeps the same dtype as before.
    host_tensor = (
        cpu_value
        if is_tensor
        else torch.tensor(cpu_value, dtype=old_device_value.dtype)
    )
    new_device_value = host_tensor.to(
        device=global_server_args_dict["device"], non_blocking=True
    )
    setattr(batch, device_field, new_device_value)

    if sum_field is not None:
        sum_value = cpu_value.sum().item() if is_tensor else sum(cpu_value)
        setattr(batch, sum_field, sum_value)
|
160
|
+
|
161
|
+
|
91
162
|
def _compute_mask_offset(seq_index: int, spec_info: Optional[EagleVerifyInput]) -> int:
|
92
163
|
if seq_index == 0:
|
93
164
|
return 0
|
@@ -181,6 +252,8 @@ def compute_split_token_index(
|
|
181
252
|
) -> int:
|
182
253
|
if forward_mode == ForwardMode.EXTEND:
|
183
254
|
assert extend_seq_lens is not None
|
255
|
+
if _is_two_chunk_split_enabled(extend_seq_lens):
|
256
|
+
return sum(extend_seq_lens) // 2
|
184
257
|
return sum(extend_seq_lens[:split_seq_index])
|
185
258
|
elif forward_mode.is_target_verify() or forward_mode.is_decode():
|
186
259
|
assert token_num_per_seq is not None
|
@@ -388,9 +461,15 @@ class TboForwardBatchPreparer:
|
|
388
461
|
|
389
462
|
tbo_split_token_index = cls._compute_split_token_index(batch)
|
390
463
|
|
464
|
+
is_enable_two_chunk = (
|
465
|
+
batch.forward_mode == ForwardMode.EXTEND
|
466
|
+
and _is_two_chunk_split_enabled(batch.extend_seq_lens_cpu)
|
467
|
+
)
|
468
|
+
|
391
469
|
if _tbo_debug:
|
392
470
|
logger.info(
|
393
471
|
f"TboForwardBatchPreparer.prepare "
|
472
|
+
f"is_enable_two_chunk={is_enable_two_chunk} "
|
394
473
|
f"tbo_split_seq_index={batch.tbo_split_seq_index} "
|
395
474
|
f"tbo_split_token_index={tbo_split_token_index} "
|
396
475
|
f"extend_seq_lens={batch.extend_seq_lens_cpu} "
|
@@ -410,7 +489,11 @@ class TboForwardBatchPreparer:
|
|
410
489
|
start_token_index=0,
|
411
490
|
end_token_index=tbo_split_token_index,
|
412
491
|
start_seq_index=0,
|
413
|
-
end_seq_index=batch.tbo_split_seq_index,
|
492
|
+
end_seq_index=(
|
493
|
+
batch.tbo_split_seq_index + 1
|
494
|
+
if is_enable_two_chunk
|
495
|
+
else batch.tbo_split_seq_index
|
496
|
+
),
|
414
497
|
output_attn_backend=attn_backend_child_a,
|
415
498
|
out_num_token_non_padded=out_num_token_non_padded_a,
|
416
499
|
)
|
@@ -424,9 +507,79 @@ class TboForwardBatchPreparer:
|
|
424
507
|
out_num_token_non_padded=out_num_token_non_padded_b,
|
425
508
|
)
|
426
509
|
|
510
|
+
if is_enable_two_chunk:
|
511
|
+
cls.derive_fields_related_to_seq_len_for_two_chunk(
|
512
|
+
batch,
|
513
|
+
child_a=child_a,
|
514
|
+
child_b=child_b,
|
515
|
+
tbo_split_seq_index=batch.tbo_split_seq_index,
|
516
|
+
)
|
517
|
+
|
427
518
|
assert batch.tbo_children is None
|
428
519
|
batch.tbo_children = [child_a, child_b]
|
429
520
|
|
521
|
+
@classmethod
def derive_fields_related_to_seq_len_for_two_chunk(
    cls,
    batch: ForwardBatch,
    *,
    child_a: ForwardBatch,
    child_b: ForwardBatch,
    tbo_split_seq_index: int,
):
    """Fix up length fields of both children when one sequence is shared.

    In two-chunk mode the sequence at ``tbo_split_seq_index`` is the last
    sequence of ``child_a`` and the first sequence of ``child_b``. This
    method splits that sequence's extend tokens so that ``child_a`` ends up
    with exactly ``sum(batch.extend_seq_lens_cpu) // 2`` extend tokens, then
    refreshes the dependent CPU/device fields on both children in place.

    Args:
        batch: Parent batch; only ``extend_seq_lens_cpu`` is read from it.
        child_a: Left child, mutated in place.
        child_b: Right child, mutated in place.
        tbo_split_seq_index: Index (in the parent) of the shared sequence.

    Raises:
        AssertionError: If ``child_a.extend_num_tokens`` does not equal half
            of the parent's total extend tokens after the update.
    """
    extend_seq_lens_cpu = batch.extend_seq_lens_cpu
    overall_seq_lens_sum = sum(extend_seq_lens_cpu)
    half_seq_lens_sum = overall_seq_lens_sum // 2
    # Tokens of the shared sequence that stay in the left child ...
    left_last_seq_token_num = half_seq_lens_sum - sum(
        extend_seq_lens_cpu[:tbo_split_seq_index]
    )
    # ... and the remainder of that sequence, which goes to the right child.
    right_first_seq_token_num = (
        extend_seq_lens_cpu[tbo_split_seq_index] - left_last_seq_token_num
    )

    # making deepcopy to be extra safe
    child_a.extend_seq_lens_cpu = copy.deepcopy(child_a.extend_seq_lens_cpu)
    child_a.extend_seq_lens_cpu[-1] = left_last_seq_token_num
    child_b.extend_seq_lens_cpu = copy.deepcopy(child_b.extend_seq_lens_cpu)
    child_b.extend_seq_lens_cpu[0] = right_first_seq_token_num
    for child in [child_a, child_b]:
        _update_device_and_sum_field_from_cpu_field(
            batch=child,
            cpu_field="extend_seq_lens_cpu",
            device_field="extend_seq_lens",
            sum_field="extend_num_tokens",
        )

    assert (
        child_a.extend_num_tokens == half_seq_lens_sum
    ), f"{child_a.extend_num_tokens=}, {half_seq_lens_sum=}"

    # The left child only sees the shared sequence up to the split point, so
    # its last sequence length is prefix + the truncated extend length.
    child_a.seq_lens_cpu = copy.deepcopy(child_a.seq_lens_cpu)
    child_a.seq_lens_cpu[-1] = (
        child_a.extend_seq_lens_cpu[-1] + child_a.extend_prefix_lens_cpu[-1]
    )
    _update_device_and_sum_field_from_cpu_field(
        batch=child_a,
        cpu_field="seq_lens_cpu",
        device_field="seq_lens",
        sum_field="seq_lens_sum",
    )

    # For the right child, the tokens handled by the left child count as
    # additional prefix of its first sequence.
    child_b.extend_prefix_lens_cpu = copy.deepcopy(child_b.extend_prefix_lens_cpu)
    child_b.extend_prefix_lens_cpu[0] += left_last_seq_token_num
    _update_device_and_sum_field_from_cpu_field(
        batch=child_b,
        cpu_field="extend_prefix_lens_cpu",
        device_field="extend_prefix_lens",
        sum_field=None,
    )
    # Only the second value returned by compute_position is kept; it becomes
    # the right child's extend_start_loc for the updated lengths.
    _, child_b.extend_start_loc = compute_position(
        global_server_args_dict["attention_backend"],
        child_b.extend_prefix_lens,
        child_b.extend_seq_lens,
        child_b.extend_num_tokens,
    )
|
582
|
+
|
430
583
|
@classmethod
|
431
584
|
def filter_batch(
|
432
585
|
cls,
|
@@ -468,7 +621,7 @@ class TboForwardBatchPreparer:
|
|
468
621
|
"extend_prefix_lens_cpu",
|
469
622
|
"extend_seq_lens_cpu",
|
470
623
|
"extend_logprob_start_lens_cpu",
|
471
|
-
"
|
624
|
+
"lora_ids",
|
472
625
|
]:
|
473
626
|
old_value = getattr(batch, key)
|
474
627
|
if old_value is None:
|
@@ -510,6 +663,7 @@ class TboForwardBatchPreparer:
|
|
510
663
|
"padded_static_len",
|
511
664
|
"mrope_positions", # only used by qwen2-vl, thus not care
|
512
665
|
"split_index", # for split prefill
|
666
|
+
"orig_seq_lens", # only used by qwen-1m, thus not care
|
513
667
|
]:
|
514
668
|
output_dict[key] = getattr(batch, key)
|
515
669
|
if not batch.forward_mode.is_target_verify():
|
@@ -670,9 +824,15 @@ def _model_forward_tbo(
|
|
670
824
|
)
|
671
825
|
del inputs
|
672
826
|
|
673
|
-
|
674
|
-
|
675
|
-
|
827
|
+
context = (
|
828
|
+
empty_context()
|
829
|
+
if _is_hip
|
830
|
+
else deep_gemm_wrapper.configure_deep_gemm_num_sms(
|
831
|
+
operations_strategy.deep_gemm_num_sms
|
832
|
+
)
|
833
|
+
)
|
834
|
+
|
835
|
+
with context:
|
676
836
|
outputs_arr = execute_overlapped_operations(
|
677
837
|
inputs_arr=inputs_arr,
|
678
838
|
operations_arr=[operations_strategy.operations] * 2,
|
sglang/srt/utils.py
CHANGED
@@ -41,9 +41,11 @@ import tempfile
|
|
41
41
|
import threading
|
42
42
|
import time
|
43
43
|
import traceback
|
44
|
+
import uuid
|
44
45
|
import warnings
|
45
46
|
from collections import OrderedDict, defaultdict
|
46
47
|
from contextlib import contextmanager
|
48
|
+
from dataclasses import dataclass
|
47
49
|
from functools import lru_cache
|
48
50
|
from importlib.metadata import PackageNotFoundError, version
|
49
51
|
from importlib.util import find_spec
|
@@ -84,6 +86,7 @@ from torch.library import Library
|
|
84
86
|
from torch.profiler import ProfilerActivity, profile, record_function
|
85
87
|
from torch.utils._contextlib import _DecoratorContextManager
|
86
88
|
from triton.runtime.cache import FileCacheManager
|
89
|
+
from typing_extensions import Literal
|
87
90
|
|
88
91
|
from sglang.srt.metrics.func_timer import enable_func_timer
|
89
92
|
|
@@ -231,6 +234,10 @@ def is_flashinfer_available():
|
|
231
234
|
return importlib.util.find_spec("flashinfer") is not None and is_cuda()
|
232
235
|
|
233
236
|
|
237
|
+
def random_uuid() -> str:
    """Return a random UUID4 as a 32-character lowercase hex string."""
    # UUID.hex is already a str, so no extra str() conversion is needed.
    return uuid.uuid4().hex
|
239
|
+
|
240
|
+
|
234
241
|
_ENABLE_TORCH_INFERENCE_MODE = get_bool_env_var(
|
235
242
|
"SGLANG_ENABLE_TORCH_INFERENCE_MODE", "false"
|
236
243
|
)
|
@@ -736,9 +743,18 @@ def load_audio(
|
|
736
743
|
return audio
|
737
744
|
|
738
745
|
|
746
|
+
@dataclass
class ImageData:
    """An image reference together with a requested detail level."""

    # Image source; load_image() substitutes this value for the ImageData
    # object before doing any loading.
    url: str
    # Requested detail level; defaults to "auto".
    # NOTE(review): how consumers interpret this is not visible here - confirm.
    detail: Optional[Literal["auto", "low", "high"]] = "auto"
|
750
|
+
|
751
|
+
|
739
752
|
def load_image(
|
740
|
-
image_file: Union[Image.Image, str, bytes],
|
753
|
+
image_file: Union[Image.Image, str, ImageData, bytes],
|
741
754
|
) -> tuple[Image.Image, tuple[int, int]]:
|
755
|
+
if isinstance(image_file, ImageData):
|
756
|
+
image_file = image_file.url
|
757
|
+
|
742
758
|
image = image_size = None
|
743
759
|
if isinstance(image_file, Image.Image):
|
744
760
|
image = image_file
|
@@ -762,7 +778,7 @@ def load_image(
|
|
762
778
|
elif isinstance(image_file, str):
|
763
779
|
image = Image.open(BytesIO(pybase64.b64decode(image_file, validate=True)))
|
764
780
|
else:
|
765
|
-
raise ValueError(f"Invalid image: {
|
781
|
+
raise ValueError(f"Invalid image: {image_file}")
|
766
782
|
|
767
783
|
return image, image_size
|
768
784
|
|
@@ -799,7 +815,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
|
799
815
|
vr = VideoReader(tmp_file.name, ctx=ctx)
|
800
816
|
elif video_file.startswith("data:"):
|
801
817
|
_, encoded = video_file.split(",", 1)
|
802
|
-
video_bytes =
|
818
|
+
video_bytes = pybase64.b64decode(encoded)
|
803
819
|
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
804
820
|
tmp_file.write(video_bytes)
|
805
821
|
tmp_file.close()
|
@@ -807,7 +823,7 @@ def load_video(video_file: Union[str, bytes], use_gpu: bool = True):
|
|
807
823
|
elif os.path.isfile(video_file):
|
808
824
|
vr = VideoReader(video_file, ctx=ctx)
|
809
825
|
else:
|
810
|
-
video_bytes =
|
826
|
+
video_bytes = pybase64.b64decode(video_file)
|
811
827
|
tmp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".mp4")
|
812
828
|
tmp_file.write(video_bytes)
|
813
829
|
tmp_file.close()
|
@@ -2113,6 +2129,10 @@ def next_power_of_2(n: int):
|
|
2113
2129
|
return 1 << (n - 1).bit_length() if n > 0 else 1
|
2114
2130
|
|
2115
2131
|
|
2132
|
+
def round_up(x: int, y: int) -> int:
    """Round ``x`` up to the nearest multiple of ``y``.

    Args:
        x: Value to round.
        y: Step size; expected to be positive.

    Returns:
        For ``y > 0``, the smallest multiple of ``y`` that is ``>= x``.

    Raises:
        ZeroDivisionError: If ``y`` is 0.
    """
    # (x - 1) // y + 1 is ceil(x / y) for integers when y > 0.
    return ((x - 1) // y + 1) * y
|
2134
|
+
|
2135
|
+
|
2116
2136
|
setattr(triton, "next_power_of_2", next_power_of_2)
|
2117
2137
|
|
2118
2138
|
|
@@ -2832,6 +2852,17 @@ def parse_module_path(module_path, function_name, create_dummy):
|
|
2832
2852
|
return final_module, None
|
2833
2853
|
|
2834
2854
|
|
2855
|
+
def mxfp_supported() -> bool:
    """
    Returns whether the current platform supports MX types.

    Only HIP builds of torch are considered: the result is True when the
    ``gcnArchName`` of device 0 contains ``"gfx95"``. On non-HIP builds the
    function returns False without touching any device.
    """
    if not torch.version.hip:
        return False

    gcn_arch = torch.cuda.get_device_properties(0).gcnArchName
    return any(gfx in gcn_arch for gfx in ["gfx95"])
|
2864
|
+
|
2865
|
+
|
2835
2866
|
# LoRA-related constants and utilities
|
2836
2867
|
SUPPORTED_LORA_TARGET_MODULES = [
|
2837
2868
|
"q_proj",
|
@@ -2929,4 +2960,9 @@ class ConcurrentCounter:
|
|
2929
2960
|
This suspends the calling coroutine without blocking the thread, allowing
|
2930
2961
|
other tasks to run while waiting. When the counter becomes zero, the coroutine resumes.
|
2931
2962
|
"""
|
2932
|
-
self.wait_for(lambda count: count == 0)
|
2963
|
+
await self.wait_for(lambda count: count == 0)
|
2964
|
+
|
2965
|
+
|
2966
|
+
@lru_cache(maxsize=1)
def is_triton_kernels_available() -> bool:
    """Return True if a ``triton_kernels`` module spec can be found.

    The lookup uses ``importlib.util.find_spec`` and the result is cached by
    ``lru_cache``, so the module search runs at most once per process.
    """
    return importlib.util.find_spec("triton_kernels") is not None
|
@@ -0,0 +1,106 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import List, Tuple
|
3
|
+
|
4
|
+
import torch
|
5
|
+
|
6
|
+
|
7
|
+
@dataclass
class FlattenedTensorMetadata:
    """Metadata for a tensor in a flattened bucket"""

    # Name the tensor was registered under.
    name: str
    # Original shape; used to reshape the slice during reconstruction.
    shape: torch.Size
    # Original dtype; reconstruction casts back to it when the flat buffer's
    # dtype differs.
    dtype: torch.dtype
    # Half-open range [start_idx, end_idx) of this tensor in the flat buffer.
    start_idx: int
    end_idx: int
    # Number of elements of the tensor.
    numel: int


class FlattenedTensorBucket:
    """
    A bucket that flattens multiple tensors into a single tensor for efficient processing
    while preserving all metadata needed for reconstruction.
    """

    def __init__(
        self,
        named_tensors: List[Tuple[str, torch.Tensor]] = None,
        flattened_tensor: torch.Tensor = None,
        metadata: List[FlattenedTensorMetadata] = None,
    ):
        """
        Initialize a tensor bucket from a list of named tensors OR from pre-flattened data.

        When ``named_tensors`` is given it takes precedence and the other two
        arguments are ignored.

        Args:
            named_tensors: List of (name, tensor) tuples (for creating new bucket)
            flattened_tensor: Pre-flattened tensor (for reconstruction)
            metadata: Pre-computed metadata (for reconstruction)

        Raises:
            ValueError: If ``named_tensors`` is an empty list, or if it is
                None and ``flattened_tensor`` or ``metadata`` is missing.
        """
        if named_tensors is not None:
            # Validate before building any state.
            if not named_tensors:
                raise ValueError("Cannot create empty tensor bucket")

            self.metadata: List[FlattenedTensorMetadata] = []
            flattened_tensors: List[torch.Tensor] = []

            # Running offset of the next tensor inside the flat buffer.
            current_idx = 0
            for name, tensor in named_tensors:
                flattened = tensor.flatten()
                numel = flattened.numel()

                flattened_tensors.append(flattened)
                self.metadata.append(
                    FlattenedTensorMetadata(
                        name=name,
                        shape=tensor.shape,
                        dtype=tensor.dtype,
                        start_idx=current_idx,
                        end_idx=current_idx + numel,
                        numel=numel,
                    )
                )
                current_idx += numel

            # Concatenate all flattened tensors into one 1-D buffer.
            # NOTE(review): tensors of different dtypes end up in a single
            # buffer dtype here; reconstruct_tensors() casts each slice back
            # to its recorded dtype - confirm this is lossless for callers.
            self.flattened_tensor: torch.Tensor = torch.cat(flattened_tensors, dim=0)
        else:
            # Initialize from pre-flattened data
            if flattened_tensor is None or metadata is None:
                raise ValueError(
                    "Must provide either named_tensors or both flattened_tensor and metadata"
                )
            self.flattened_tensor = flattened_tensor
            self.metadata = metadata

    def get_flattened_tensor(self) -> torch.Tensor:
        """Get the flattened tensor containing all bucket tensors"""
        return self.flattened_tensor

    def get_metadata(self) -> List[FlattenedTensorMetadata]:
        """Get metadata for all tensors in the bucket"""
        return self.metadata

    def reconstruct_tensors(self) -> List[Tuple[str, torch.Tensor]]:
        """
        Reconstruct the original tensors from the flattened tensor.

        Each entry of the metadata selects the slice
        ``[start_idx, end_idx)`` of the flat buffer, which is reshaped to the
        recorded shape and cast to the recorded dtype when it differs.

        Returns:
            List of (name, tensor) tuples in metadata order.
        """
        reconstructed: List[Tuple[str, torch.Tensor]] = []

        for meta in self.metadata:
            tensor = self.flattened_tensor[meta.start_idx : meta.end_idx].reshape(
                meta.shape
            )

            # Cast back only when the buffer dtype differs from the original.
            if tensor.dtype != meta.dtype:
                tensor = tensor.to(meta.dtype)

            reconstructed.append((meta.name, tensor))

        return reconstructed
|