sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries, as they appear in those registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/multiplex/pdmux_context.py (new file)
@@ -0,0 +1,164 @@
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+import yaml
+
+STREAM_GROUPS = []
+SM_COUNTS = []
+SM_GROUP_NUM = 8  # Default number of SM groups
+CURRENT_STREAM_IDX = 0
+CURRENT_STREAM_GROUP = None
+
+
+@dataclass
+class PDMuxConfig:
+    sm_group_num: int = 8
+    manual_divisions: List[List[int]] = field(
+        default_factory=list
+    )  # [prefill_sm, decode_sm, decode_bs_threshold]
+    split_forward_token_budget: int = 65536
+    decode_bs_divisor: int = 36
+
+
+def load_pdmux_config(config_path: str) -> PDMuxConfig:
+    """Load pdmux configuration from YAML file into a dataclass."""
+    if not config_path:
+        return PDMuxConfig()
+
+    with open(config_path, "r") as f:
+        raw = yaml.safe_load(f)
+
+    if "sm_group_num" not in raw:
+        raise ValueError("Missing required field: sm_group_num")
+
+    if raw["sm_group_num"] < 3:
+        raise ValueError("sm_group_num must greater than 3")
+
+    manual_divisions = raw.get("manual_divisions", [])
+
+    expected = raw["sm_group_num"] - 2
+    if manual_divisions and len(manual_divisions) != expected:
+        raise ValueError(
+            f"manual_divisions must have {expected} entries, "
+            f"but got {len(manual_divisions)}"
+        )
+
+    return PDMuxConfig(
+        sm_group_num=raw["sm_group_num"],
+        manual_divisions=manual_divisions,
+        split_forward_token_budget=raw.get("split_forward_token_budget", 65536),
+        decode_bs_divisor=raw.get("decode_bs_divisor", 36),
+    )
+
+
+def get_arch_constraints(compute_capability):
+    major, minor = compute_capability
+    # green context constraints for different architectures
+    if major == 6:
+        return 1, 1  # min_per_part, multiple
+    elif major == 7:
+        return 2, 2
+    elif major == 8:
+        return 4, 2
+    elif major == 9 and minor >= 0:
+        return 8, 8
+    else:
+        raise ValueError(f"Unsupported compute capability: {major}.{minor}")
+
+
+def divide_sm(total_sms, compute_capability, groups):
+    """
+    :param total_sms: total sm count on a single GPU
+    :param compute_capability: (major, minor)
+    :return: SM partition group(prefill sm, decode sm)
+    """
+    min_per_part, multiple = get_arch_constraints(compute_capability)
+    possible_values = [
+        x
+        for x in range(min_per_part, total_sms - min_per_part + 1, multiple)
+        if x >= total_sms - x and total_sms - x >= 16
+    ]
+    if not possible_values:
+        raise ValueError(
+            f"No valid partitions found for total SMs {total_sms} "
+            f"with constraints (min per part: {min_per_part}, multiple: {multiple})"
+        )
+
+    if len(possible_values) >= groups:
+        step = max(1, len(possible_values) // groups)
+        selected_values = possible_values[::step][:groups]
+    else:
+        selected_values = possible_values
+
+    divisions = []
+    for part1 in selected_values:
+        part2 = total_sms - part1
+        divisions.append((part1, part2))
+
+    divisions.reverse()  # Reverse to have larger prefill SM first
+
+    return divisions
+
+
+def initialize_stream_groups(gpu_id: int, config: PDMuxConfig):
+    from sgl_kernel import spatial
+
+    global STREAM_GROUPS, SM_COUNTS, SM_GROUP_NUM, CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
+    # for pd_multiplexing, Init stream_groups
+    device = torch.cuda.current_device()
+    total_sm_count = spatial.get_sm_available(gpu_id)
+    # (prefill_sm_count, decode_sm_count)
+    if config.manual_divisions:
+        divisions = [
+            (prefill_sm, decode_sm)
+            for prefill_sm, decode_sm, _ in config.manual_divisions
+        ]
+    else:
+        divisions = divide_sm(
+            total_sm_count,
+            torch.cuda.get_device_capability(device),
+            config.sm_group_num - 2,
+        )
+
+    SM_COUNTS = []
+    SM_COUNTS.append((total_sm_count, 0))  # Normal stream for prefill
+    SM_COUNTS.extend(divisions)  # Add the divided SM counts
+    SM_COUNTS.append((0, total_sm_count))  # Normal stream for decode
+    STREAM_GROUPS = []
+    STREAM_GROUPS.append(
+        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
+    )  # Normal stream for prefill
+    for prefill_sm, decode_sm in divisions:
+        STREAM_GROUPS.append(
+            (spatial.create_greenctx_stream_by_value(prefill_sm, decode_sm, gpu_id))
+        )
+    STREAM_GROUPS.append(
+        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
+    )  # Normal stream for decode
+
+    CURRENT_STREAM_IDX = 0
+    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]
+
+
+def set_current_stream_idx(idx: int):
+    global CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
+    if idx < 0 or idx >= len(STREAM_GROUPS):
+        raise ValueError(f"Invalid stream index: {idx}")
+    CURRENT_STREAM_IDX = idx
+    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]
+
+
+def get_stream_groups() -> list[tuple[torch.cuda.Stream, torch.cuda.Stream]]:
+    """Get the stream groups."""
+    return STREAM_GROUPS
+
+
+def get_sm_counts() -> list[tuple[int, int]]:
+    """Get the SM counts."""
+    return SM_COUNTS
+
+
+def get_current_stream_idx() -> int:
+    """Get the current stream index."""
+    return CURRENT_STREAM_IDX
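For orientation, below is a minimal sketch of a YAML file that the new load_pdmux_config above would accept. The field names come from PDMuxConfig; the concrete numbers and the temporary-file plumbing are illustrative assumptions, not recommended settings.

import tempfile

from sglang.srt.multiplex.pdmux_context import PDMuxConfig, load_pdmux_config

# Hypothetical config; sm_group_num must be at least 3, and manual_divisions (if given)
# needs exactly sm_group_num - 2 entries of [prefill_sm, decode_sm, decode_bs_threshold].
example_yaml = """
sm_group_num: 4
manual_divisions:
  - [100, 32, 8]
  - [64, 68, 32]
split_forward_token_budget: 65536
decode_bs_divisor: 36
"""

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(example_yaml)

config = load_pdmux_config(f.name)
assert isinstance(config, PDMuxConfig)
print(config.sm_group_num, config.manual_divisions)

At runtime, initialize_stream_groups turns these settings into green-context stream pairs via sgl_kernel.spatial; when manual_divisions is empty, divide_sm picks the SM partitions automatically from the GPU's compute capability.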
sglang/srt/parser/conversation.py
@@ -101,6 +101,7 @@ class Conversation:
     stop_token_ids: Optional[int] = None
 
     audio_data: Optional[List[str]] = None
+    image_token_at_prefix: bool = False
 
     def get_prompt(self) -> str:
         """Get the prompt for generation."""
@@ -445,6 +446,7 @@ class Conversation:
            image_token=self.image_token,
            video_token=self.video_token,
            audio_token=self.audio_token,
+           image_token_at_prefix=self.image_token_at_prefix,
        )
 
    def dict(self):
@@ -512,6 +514,7 @@ def generate_embedding_convs(
            image_token=conv_template.image_token,
            video_token=conv_template.video_token,
            audio_token=conv_template.audio_token,
+           image_token_at_prefix=conv_template.image_token_at_prefix,
        )
        real_content = ""
 
@@ -578,6 +581,7 @@ def generate_chat_conv(
        image_token=conv.image_token,
        audio_token=conv.audio_token,
        video_token=conv.video_token,
+       image_token_at_prefix=conv.image_token_at_prefix,
    )
 
    if isinstance(request.messages, str):
@@ -627,7 +631,7 @@ def generate_chat_conv(
                        real_content += content.text
                    elif content.type == "image_url":
                        # NOTE: works for llava and intervl2_5
-                       if conv.
+                       if conv.image_token_at_prefix:
                            real_content = image_token + real_content
                        else:
                            real_content += image_token
@@ -820,6 +824,7 @@ register_conv_template(
        sep="<|im_end|>\n",
        stop_str=["<|im_end|>", "<|action_end|>"],
        image_token="<IMG_CONTEXT>",
+       image_token_at_prefix=True,
    )
 )
 
@@ -848,6 +853,7 @@ register_conv_template(
        sep_style=SeparatorStyle.NO_COLON_SINGLE,
        stop_str=["<|end▁of▁sentence|>"],
        image_token="<image>",
+       image_token_at_prefix=True,
    )
 )
 
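The behavioral change above is small: templates that set image_token_at_prefix now have the image token prepended to the accumulated text rather than appended. A standalone sketch (plain Python, not the sglang Conversation class) of the two placements:

def place_image_token(real_content: str, image_token: str, image_token_at_prefix: bool) -> str:
    # Mirrors the branch added in generate_chat_conv above.
    if image_token_at_prefix:
        return image_token + real_content
    return real_content + image_token

print(place_image_token("Describe this picture.", "<IMG_CONTEXT>", image_token_at_prefix=True))
# -> <IMG_CONTEXT>Describe this picture.   (the two templates updated above)
print(place_image_token("Describe this picture.", "<image>", image_token_at_prefix=False))
# -> Describe this picture.<image>         (default behavior, unchanged)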
sglang/srt/parser/reasoning_parser.py
@@ -249,6 +249,31 @@ class GptOssDetector(BaseReasoningFormatDetector):
        )
 
 
+class MiniMaxAppendThinkDetector(BaseReasoningFormatDetector):
+    """
+    Append `<think>` token to the beginning of the text.
+    """
+
+    def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False):
+        # scheduler.py need `reasoning_parser.detector.think_end_token`
+        super().__init__(
+            "<think>",
+            "</think>",
+            force_reasoning=force_reasoning,
+            stream_reasoning=stream_reasoning,
+        )
+        self.is_first_chunk = False
+
+    def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
+        if not self.is_first_chunk:
+            self.is_first_chunk = True
+            new_text = self.think_start_token + new_text
+        return StreamingParseResult(normal_text=new_text)
+
+    def detect_and_parse(self, text: str) -> StreamingParseResult:
+        return StreamingParseResult(normal_text=self.think_start_token + text)
+
+
 class ReasoningParser:
     """
     Parser that handles both streaming and non-streaming scenarios for extracting
@@ -268,6 +293,8 @@ class ReasoningParser:
        "kimi": KimiDetector,
        "qwen3": Qwen3Detector,
        "qwen3-thinking": Qwen3Detector,
+       "minimax": Qwen3Detector,
+       "minimax-append-think": MiniMaxAppendThinkDetector,
        "step3": DeepSeekR1Detector,
    }
 
@@ -285,7 +312,7 @@ class ReasoningParser:
            raise ValueError(f"Unsupported model type: {model_type}")
 
        # Special cases where we override force_reasoning
-       if model_type.lower() in {"qwen3-thinking", "gpt-oss"}:
+       if model_type.lower() in {"qwen3-thinking", "gpt-oss", "minimax"}:
            force_reasoning = True
 
        # Only pass force_reasoning if explicitly set, let detectors use their defaults
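The new "minimax-append-think" mode covers models whose output omits the opening <think> tag: the detector injects it in front of the first chunk (or of the whole text in the non-streaming path) and otherwise passes text through as normal output. A self-contained sketch (not the sglang classes) of that streaming behavior:

class AppendThinkSketch:
    think_start_token = "<think>"

    def __init__(self) -> None:
        self.is_first_chunk = False

    def parse_streaming_increment(self, new_text: str) -> str:
        # Only the very first chunk gets the opening tag injected.
        if not self.is_first_chunk:
            self.is_first_chunk = True
            new_text = self.think_start_token + new_text
        return new_text

parser = AppendThinkSketch()
chunks = ["Checking the constraints", " carefully.", "</think>Answer: 42"]
print("".join(parser.parse_streaming_increment(c) for c in chunks))
# -> <think>Checking the constraints carefully.</think>Answer: 42

The plain "minimax" entry instead reuses Qwen3Detector and is added to the force_reasoning special cases alongside qwen3-thinking and gpt-oss.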
sglang/srt/sampling/custom_logit_processor.py
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from functools import lru_cache
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
 
 import dill
 import orjson
@@ -126,3 +126,69 @@ class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
    THINKING_START_TOKEN_ID: int = 128798
    THINKING_END_TOKEN_ID: int = 128799
    NEW_LINE_TOKEN_ID: int = 201
+
+
+# Adapted from DeepSeek's implementation: https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/ngram_norepeat.py
+class DeepseekOCRNoRepeatNGramLogitProcessor(CustomLogitProcessor):
+    """Block n-gram repetitions within a sliding window for DeepSeek-OCR outputs."""
+
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        if not custom_param_list:
+            return logits
+
+        for batch_idx, params in enumerate(custom_param_list):
+            if not params:
+                continue
+
+            req = params.get("__req__")
+            if req is None:
+                continue
+
+            try:
+                ngram_size = int(params.get("ngram_size") or 0)
+                window_size = int(params.get("window_size") or 0)
+            except (TypeError, ValueError):
+                continue
+
+            if ngram_size <= 0 or window_size <= 0:
+                continue
+
+            sequence: List[int] = req.origin_input_ids + req.output_ids
+            if len(sequence) < ngram_size:
+                continue
+
+            search_start = max(0, len(sequence) - window_size)
+            search_end = len(sequence) - ngram_size + 1
+            if search_end <= search_start:
+                continue
+
+            if ngram_size > 1:
+                current_prefix = tuple(sequence[-(ngram_size - 1) :])
+            else:
+                current_prefix = tuple()
+
+            banned_tokens: Set[int] = set()
+            for idx in range(search_start, search_end):
+                ngram = sequence[idx : idx + ngram_size]
+                if ngram_size == 1 or tuple(ngram[:-1]) == current_prefix:
+                    banned_tokens.add(ngram[-1])
+
+            whitelist_ids = params.get("whitelist_token_ids") or []
+            try:
+                whitelist = {int(token_id) for token_id in whitelist_ids}
+            except (TypeError, ValueError):
+                whitelist = set()
+
+            banned_tokens.difference_update(whitelist)
+
+            if not banned_tokens:
+                continue
+
+            indices = list(banned_tokens)
+            logits[batch_idx, indices] = -float("inf")
+
+        return logits
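The rule enforced above is the standard no-repeat-n-gram constraint: within the last window_size tokens of prompt plus output, any n-gram whose first n-1 tokens equal the current suffix bans its final token from being sampled next (minus an optional whitelist). A toy illustration with plain Python lists, no sglang or torch required:

def banned_next_tokens(sequence, ngram_size, window_size):
    # Same sliding-window scan as the processor above, on plain ints.
    if len(sequence) < ngram_size:
        return set()
    search_start = max(0, len(sequence) - window_size)
    search_end = len(sequence) - ngram_size + 1
    prefix = tuple(sequence[-(ngram_size - 1):]) if ngram_size > 1 else tuple()
    banned = set()
    for idx in range(search_start, search_end):
        ngram = sequence[idx : idx + ngram_size]
        if ngram_size == 1 or tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned

# The suffix (7, 8) already occurred followed by 9, so 9 is banned as the next token.
print(banned_next_tokens([7, 8, 9, 1, 7, 8], ngram_size=3, window_size=6))  # {9}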
sglang/srt/sampling/penaltylib/frequency_penalty.py
@@ -1,9 +1,6 @@
 import torch
 
-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
 
 
 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
    Frequency penalizer penalizes tokens based on their frequency in the output.
    """
 
-   def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-       self.orchestrator = orchestrator
-       self._is_prepared = False
-
    def _is_required(self) -> bool:
        return any(
            req.sampling_params.frequency_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
            [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
            dim=0,
        )
+
+   def _teardown(self) -> None:
+       for name in ("frequency_penalties", "cumulated_frequency_penalties"):
+           if hasattr(self, name):
+               delattr(self, name)
sglang/srt/sampling/penaltylib/min_new_tokens.py
@@ -1,9 +1,6 @@
 import torch
 
-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
 
 
 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
    Min new tokens penalizer penalizes tokens based on the length of the output.
    """
 
-   def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-       self.orchestrator = orchestrator
-       self._is_prepared = False
-
    def _is_required(self) -> bool:
        return any(
            req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
@@ -92,3 +85,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
        self.len_output_tokens = torch.cat(
            [self.len_output_tokens, their.len_output_tokens], dim=0
        )
+
+   # Explicit resource cleanup to aid GC and free CUDA memory promptly
+   def _teardown(self) -> None:
+       for name in ("min_new_tokens", "stop_token_penalties", "len_output_tokens"):
+           if hasattr(self, name):
+               delattr(self, name)
sglang/srt/sampling/penaltylib/orchestrator.py
@@ -77,9 +77,8 @@ class BatchedPenalizerOrchestrator:
            return
 
        if len(keep_indices) == 0:
-
-
-           penalizer.teardown()
+           # No requests left in the batch, fully release orchestrator resources
+           self.release()
            return
 
        is_required = False
@@ -92,6 +91,23 @@
                penalizer.teardown()
        self.is_required = is_required
 
+   # Resource management helpers
+   def release(self) -> None:
+       """Release all penalizers and break references so GC can reclaim promptly."""
+       for penalizer in self.penalizers.values():
+           penalizer.teardown()
+       self.penalizers.clear()
+       # Break reference to ScheduleBatch
+       self._batch_ref = None
+       self.is_required = False
+
+   # Context manager support
+   def __enter__(self) -> "BatchedPenalizerOrchestrator":
+       return self
+
+   def __exit__(self, exc_type, exc, tb) -> None:
+       self.release()
+
    def merge(self, their: "BatchedPenalizerOrchestrator"):
        """
        Merge the penalizers of another orchestrator into this one.
@@ -116,6 +132,22 @@ class _BatchedPenalizer(abc.ABC):
    An abstract class for a batched penalizer.
    """
 
+   def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+       self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
+           weakref.ref(orchestrator)
+       )
+       self._is_prepared = False
+
+   @property
+   def orchestrator(self) -> BatchedPenalizerOrchestrator:
+       orch: Optional[BatchedPenalizerOrchestrator] = self._orchestrator_ref()
+       # This should never happen, but we need to handle it gracefully
+       if orch is None:
+           raise RuntimeError(
+               "BatchedPenalizerOrchestrator has been garbage-collected"
+           )
+       return orch
+
    def is_prepared(self) -> bool:
        return self._is_prepared
 
@@ -135,6 +167,7 @@ class _BatchedPenalizer(abc.ABC):
        return False
 
    def teardown(self):
+       self._teardown()
        self._is_prepared = False
 
    def cumulate_output_tokens(self, output_ids: torch.Tensor):
@@ -207,3 +240,10 @@ class _BatchedPenalizer(abc.ABC):
        Merge the penalizer with another penalizer.
        """
        pass
+
+   @abc.abstractmethod
+   def _teardown(self):
+       """
+       Teardown the penalizer.
+       """
+       pass
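Taken together, the penaltylib changes move __init__ into the abstract _BatchedPenalizer base class, replace the strong orchestrator attribute with a weak reference, and make every concrete penalizer (frequency and min-new-tokens above, presence below) implement _teardown() so its per-request tensors are dropped by teardown() and release(). A condensed sketch (plain Python, hypothetical names, not the sglang classes) of that ownership pattern:

import weakref

class PenalizerSketch:
    """Owns per-request state; keeps only a weak reference back to its owner."""

    def __init__(self, owner: "OrchestratorSketch") -> None:
        self._owner_ref = weakref.ref(owner)
        self.buffer = [0.0] * 4  # stands in for the per-request penalty tensors

    @property
    def owner(self) -> "OrchestratorSketch":
        owner = self._owner_ref()
        if owner is None:
            raise RuntimeError("owner has been garbage-collected")
        return owner

    def teardown(self) -> None:
        self.buffer = None  # mirrors _teardown(): drop the expensive state eagerly

class OrchestratorSketch:
    def __init__(self) -> None:
        # Strong references flow one way only: orchestrator -> penalizers.
        self.penalizers = {"presence": PenalizerSketch(self)}

    def release(self) -> None:
        for penalizer in self.penalizers.values():
            penalizer.teardown()
        self.penalizers.clear()

    def __enter__(self) -> "OrchestratorSketch":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.release()

with OrchestratorSketch() as orch:
    pass  # on exit, release() tears everything down, as in the new __exit__ above

Breaking the penalizer-to-orchestrator back-reference (and clearing _batch_ref in release(), per the diff above) lets the batch's penalty state be reclaimed promptly instead of waiting on Python's cycle collector.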
sglang/srt/sampling/penaltylib/presence_penalty.py
@@ -1,9 +1,6 @@
 import torch
 
-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer
 
 
 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
    Presence penalizer penalizes tokens based on their presence in the output.
    """
 
-   def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-       self.orchestrator = orchestrator
-       self._is_prepared = False
-
    def _is_required(self) -> bool:
        return any(
            req.sampling_params.presence_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
            [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
            dim=0,
        )
+
+   def _teardown(self) -> None:
+       for name in ("presence_penalties", "cumulated_presence_penalties"):
+           if hasattr(self, name):
+               delattr(self, name)