sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/multiplex/pdmux_context.py
@@ -0,0 +1,164 @@
+from dataclasses import dataclass, field
+from typing import List
+
+import torch
+import yaml
+
+STREAM_GROUPS = []
+SM_COUNTS = []
+SM_GROUP_NUM = 8  # Default number of SM groups
+CURRENT_STREAM_IDX = 0
+CURRENT_STREAM_GROUP = None
+
+
+@dataclass
+class PDMuxConfig:
+    sm_group_num: int = 8
+    manual_divisions: List[List[int]] = field(
+        default_factory=list
+    )  # [prefill_sm, decode_sm, decode_bs_threshold]
+    split_forward_token_budget: int = 65536
+    decode_bs_divisor: int = 36
+
+
+def load_pdmux_config(config_path: str) -> PDMuxConfig:
+    """Load pdmux configuration from YAML file into a dataclass."""
+    if not config_path:
+        return PDMuxConfig()
+
+    with open(config_path, "r") as f:
+        raw = yaml.safe_load(f)
+
+    if "sm_group_num" not in raw:
+        raise ValueError("Missing required field: sm_group_num")
+
+    if raw["sm_group_num"] < 3:
+        raise ValueError("sm_group_num must be at least 3")
+
+    manual_divisions = raw.get("manual_divisions", [])
+
+    expected = raw["sm_group_num"] - 2
+    if manual_divisions and len(manual_divisions) != expected:
+        raise ValueError(
+            f"manual_divisions must have {expected} entries, "
+            f"but got {len(manual_divisions)}"
+        )
+
+    return PDMuxConfig(
+        sm_group_num=raw["sm_group_num"],
+        manual_divisions=manual_divisions,
+        split_forward_token_budget=raw.get("split_forward_token_budget", 65536),
+        decode_bs_divisor=raw.get("decode_bs_divisor", 36),
+    )
+
+
+def get_arch_constraints(compute_capability):
+    major, minor = compute_capability
+    # green context constraints for different architectures
+    if major == 6:
+        return 1, 1  # min_per_part, multiple
+    elif major == 7:
+        return 2, 2
+    elif major == 8:
+        return 4, 2
+    elif major == 9 and minor >= 0:
+        return 8, 8
+    else:
+        raise ValueError(f"Unsupported compute capability: {major}.{minor}")
+
+
+def divide_sm(total_sms, compute_capability, groups):
+    """
+    :param total_sms: total SM count on a single GPU
+    :param compute_capability: (major, minor)
+    :return: SM partition groups (prefill_sm, decode_sm)
+    """
+    min_per_part, multiple = get_arch_constraints(compute_capability)
+    possible_values = [
+        x
+        for x in range(min_per_part, total_sms - min_per_part + 1, multiple)
+        if x >= total_sms - x and total_sms - x >= 16
+    ]
+    if not possible_values:
+        raise ValueError(
+            f"No valid partitions found for total SMs {total_sms} "
+            f"with constraints (min per part: {min_per_part}, multiple: {multiple})"
+        )
+
+    if len(possible_values) >= groups:
+        step = max(1, len(possible_values) // groups)
+        selected_values = possible_values[::step][:groups]
+    else:
+        selected_values = possible_values
+
+    divisions = []
+    for part1 in selected_values:
+        part2 = total_sms - part1
+        divisions.append((part1, part2))
+
+    divisions.reverse()  # Reverse to have larger prefill SM first
+
+    return divisions
+
+
+def initialize_stream_groups(gpu_id: int, config: PDMuxConfig):
+    from sgl_kernel import spatial
+
+    global STREAM_GROUPS, SM_COUNTS, SM_GROUP_NUM, CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
+    # for pd_multiplexing, Init stream_groups
+    device = torch.cuda.current_device()
+    total_sm_count = spatial.get_sm_available(gpu_id)
+    # (prefill_sm_count, decode_sm_count)
+    if config.manual_divisions:
+        divisions = [
+            (prefill_sm, decode_sm)
+            for prefill_sm, decode_sm, _ in config.manual_divisions
+        ]
+    else:
+        divisions = divide_sm(
+            total_sm_count,
+            torch.cuda.get_device_capability(device),
+            config.sm_group_num - 2,
+        )
+
+    SM_COUNTS = []
+    SM_COUNTS.append((total_sm_count, 0))  # Normal stream for prefill
+    SM_COUNTS.extend(divisions)  # Add the divided SM counts
+    SM_COUNTS.append((0, total_sm_count))  # Normal stream for decode
+    STREAM_GROUPS = []
+    STREAM_GROUPS.append(
+        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
+    )  # Normal stream for prefill
+    for prefill_sm, decode_sm in divisions:
+        STREAM_GROUPS.append(
+            (spatial.create_greenctx_stream_by_value(prefill_sm, decode_sm, gpu_id))
+        )
+    STREAM_GROUPS.append(
+        (torch.cuda.Stream(gpu_id), torch.cuda.Stream(gpu_id))
+    )  # Normal stream for decode
+
+    CURRENT_STREAM_IDX = 0
+    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]
+
+
+def set_current_stream_idx(idx: int):
+    global CURRENT_STREAM_IDX, CURRENT_STREAM_GROUP
+    if idx < 0 or idx >= len(STREAM_GROUPS):
+        raise ValueError(f"Invalid stream index: {idx}")
+    CURRENT_STREAM_IDX = idx
+    CURRENT_STREAM_GROUP = STREAM_GROUPS[CURRENT_STREAM_IDX]
+
+
+def get_stream_groups() -> list[tuple[torch.cuda.Stream, torch.cuda.Stream]]:
+    """Get the stream groups."""
+    return STREAM_GROUPS
+
+
+def get_sm_counts() -> list[tuple[int, int]]:
+    """Get the SM counts."""
+    return SM_COUNTS
+
+
+def get_current_stream_idx() -> int:
+    """Get the current stream index."""
+    return CURRENT_STREAM_IDX
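For reference, a minimal sketch of how a pdmux config file maps onto load_pdmux_config above. The keys mirror PDMuxConfig; the numeric values, the temporary file path, and the choice of sm_group_num are purely illustrative. Note the constraints enforced by the loader: sm_group_num must be at least 3, and manual_divisions, when supplied, needs sm_group_num - 2 entries of [prefill_sm, decode_sm, decode_bs_threshold].

# Hypothetical usage sketch (not part of the diff): build an illustrative pdmux
# config and round-trip it through load_pdmux_config from the new module.
import yaml

from sglang.srt.multiplex.pdmux_context import load_pdmux_config

example_cfg = {
    "sm_group_num": 4,                # >= 3; two groups are the pure prefill/decode streams
    "manual_divisions": [             # optional; needs sm_group_num - 2 entries
        [100, 32, 64],                # [prefill_sm, decode_sm, decode_bs_threshold], illustrative
        [84, 48, 128],
    ],
    "split_forward_token_budget": 65536,
    "decode_bs_divisor": 36,
}

with open("/tmp/pdmux_config.yaml", "w") as f:
    yaml.safe_dump(example_cfg, f)

config = load_pdmux_config("/tmp/pdmux_config.yaml")
print(config.sm_group_num, config.manual_divisions)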
sglang/srt/parser/conversation.py
@@ -101,6 +101,7 @@ class Conversation:
     stop_token_ids: Optional[int] = None

     audio_data: Optional[List[str]] = None
+    image_token_at_prefix: bool = False

     def get_prompt(self) -> str:
         """Get the prompt for generation."""
@@ -445,6 +446,7 @@ class Conversation:
             image_token=self.image_token,
             video_token=self.video_token,
             audio_token=self.audio_token,
+            image_token_at_prefix=self.image_token_at_prefix,
         )

     def dict(self):
@@ -512,6 +514,7 @@ def generate_embedding_convs(
         image_token=conv_template.image_token,
         video_token=conv_template.video_token,
         audio_token=conv_template.audio_token,
+        image_token_at_prefix=conv_template.image_token_at_prefix,
     )
     real_content = ""

@@ -578,6 +581,7 @@ def generate_chat_conv(
         image_token=conv.image_token,
         audio_token=conv.audio_token,
         video_token=conv.video_token,
+        image_token_at_prefix=conv.image_token_at_prefix,
     )

     if isinstance(request.messages, str):
@@ -627,7 +631,7 @@ def generate_chat_conv(
                     real_content += content.text
                 elif content.type == "image_url":
                     # NOTE: works for llava and intervl2_5
-                    if conv.name in ["internvl-2-5"]:
+                    if conv.image_token_at_prefix:
                         real_content = image_token + real_content
                     else:
                         real_content += image_token
@@ -820,6 +824,7 @@ register_conv_template(
         sep="<|im_end|>\n",
         stop_str=["<|im_end|>", "<|action_end|>"],
         image_token="<IMG_CONTEXT>",
+        image_token_at_prefix=True,
    )
)

@@ -848,6 +853,7 @@ register_conv_template(
         sep_style=SeparatorStyle.NO_COLON_SINGLE,
         stop_str=["<|end▁of▁sentence|>"],
         image_token="<image>",
+        image_token_at_prefix=True,
    )
)

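The behavioral change above is small but easy to miss: templates that set image_token_at_prefix=True (previously hard-coded to the "internvl-2-5" template name) place the image placeholder before the accumulated text instead of appending it after. A toy sketch of the two branches, using made-up strings rather than the real Conversation machinery:

# Toy illustration (not part of the diff) of the image_token_at_prefix branch in
# generate_chat_conv; only the ordering logic mirrors the code above.
def place_image_token(real_content: str, image_token: str, image_token_at_prefix: bool) -> str:
    if image_token_at_prefix:
        return image_token + real_content   # prefix placement, as the InternVL-style templates expect
    return real_content + image_token       # default: append after the text

print(place_image_token("Describe this picture.", "<IMG_CONTEXT>", True))
# <IMG_CONTEXT>Describe this picture.
print(place_image_token("Describe this picture.", "<image>", False))
# Describe this picture.<image>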
sglang/srt/parser/reasoning_parser.py
@@ -249,6 +249,31 @@ class GptOssDetector(BaseReasoningFormatDetector):
         )


+class MiniMaxAppendThinkDetector(BaseReasoningFormatDetector):
+    """
+    Append `<think>` token to the beginning of the text.
+    """
+
+    def __init__(self, stream_reasoning: bool = True, force_reasoning: bool = False):
+        # scheduler.py needs `reasoning_parser.detector.think_end_token`
+        super().__init__(
+            "<think>",
+            "</think>",
+            force_reasoning=force_reasoning,
+            stream_reasoning=stream_reasoning,
+        )
+        self.is_first_chunk = False
+
+    def parse_streaming_increment(self, new_text: str) -> StreamingParseResult:
+        if not self.is_first_chunk:
+            self.is_first_chunk = True
+            new_text = self.think_start_token + new_text
+        return StreamingParseResult(normal_text=new_text)
+
+    def detect_and_parse(self, text: str) -> StreamingParseResult:
+        return StreamingParseResult(normal_text=self.think_start_token + text)
+
+
 class ReasoningParser:
     """
     Parser that handles both streaming and non-streaming scenarios for extracting
@@ -268,6 +293,8 @@ class ReasoningParser:
         "kimi": KimiDetector,
         "qwen3": Qwen3Detector,
         "qwen3-thinking": Qwen3Detector,
+        "minimax": Qwen3Detector,
+        "minimax-append-think": MiniMaxAppendThinkDetector,
         "step3": DeepSeekR1Detector,
     }

@@ -285,7 +312,7 @@ class ReasoningParser:
             raise ValueError(f"Unsupported model type: {model_type}")

         # Special cases where we override force_reasoning
-        if model_type.lower() in {"qwen3-thinking", "gpt-oss"}:
+        if model_type.lower() in {"qwen3-thinking", "gpt-oss", "minimax"}:
             force_reasoning = True

         # Only pass force_reasoning if explicitly set, let detectors use their defaults
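Two MiniMax entries are registered: "minimax" reuses Qwen3Detector with force_reasoning enabled, while "minimax-append-think" uses the new detector, which does not split reasoning from the answer itself; it only re-inserts the leading <think> marker so downstream consumers see a well-formed reasoning block. A rough sketch of that behavior; instantiating the detector directly like this bypasses ReasoningParser and is simplified for illustration.

# Not part of the diff: illustrative behavior of the "minimax-append-think" detector.
from sglang.srt.parser.reasoning_parser import MiniMaxAppendThinkDetector

detector = MiniMaxAppendThinkDetector()

# Non-streaming: the whole text gets a leading <think> marker.
result = detector.detect_and_parse("step 1 ... </think> final answer")
print(result.normal_text)  # "<think>step 1 ... </think> final answer"

# Streaming: only the first chunk is prefixed.
print(detector.parse_streaming_increment("step 1 ").normal_text)  # "<think>step 1 "
print(detector.parse_streaming_increment("step 2").normal_text)   # "step 2"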
sglang/srt/sampling/custom_logit_processor.py
@@ -1,7 +1,7 @@
 import json
 from abc import ABC, abstractmethod
 from functools import lru_cache
-from typing import TYPE_CHECKING, Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set

 import dill
 import orjson
@@ -126,3 +126,69 @@ class DeepSeekR1ThinkingBudgetLogitProcessor(ThinkingBudgetLogitProcessor):
     THINKING_START_TOKEN_ID: int = 128798
     THINKING_END_TOKEN_ID: int = 128799
     NEW_LINE_TOKEN_ID: int = 201
+
+
+# Adapted from DeepSeek's implementation: https://github.com/deepseek-ai/DeepSeek-OCR/blob/main/DeepSeek-OCR-master/DeepSeek-OCR-vllm/process/ngram_norepeat.py
+class DeepseekOCRNoRepeatNGramLogitProcessor(CustomLogitProcessor):
+    """Block n-gram repetitions within a sliding window for DeepSeek-OCR outputs."""
+
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        if not custom_param_list:
+            return logits
+
+        for batch_idx, params in enumerate(custom_param_list):
+            if not params:
+                continue
+
+            req = params.get("__req__")
+            if req is None:
+                continue
+
+            try:
+                ngram_size = int(params.get("ngram_size") or 0)
+                window_size = int(params.get("window_size") or 0)
+            except (TypeError, ValueError):
+                continue
+
+            if ngram_size <= 0 or window_size <= 0:
+                continue
+
+            sequence: List[int] = req.origin_input_ids + req.output_ids
+            if len(sequence) < ngram_size:
+                continue
+
+            search_start = max(0, len(sequence) - window_size)
+            search_end = len(sequence) - ngram_size + 1
+            if search_end <= search_start:
+                continue
+
+            if ngram_size > 1:
+                current_prefix = tuple(sequence[-(ngram_size - 1) :])
+            else:
+                current_prefix = tuple()
+
+            banned_tokens: Set[int] = set()
+            for idx in range(search_start, search_end):
+                ngram = sequence[idx : idx + ngram_size]
+                if ngram_size == 1 or tuple(ngram[:-1]) == current_prefix:
+                    banned_tokens.add(ngram[-1])
+
+            whitelist_ids = params.get("whitelist_token_ids") or []
+            try:
+                whitelist = {int(token_id) for token_id in whitelist_ids}
+            except (TypeError, ValueError):
+                whitelist = set()
+
+            banned_tokens.difference_update(whitelist)
+
+            if not banned_tokens:
+                continue
+
+            indices = list(banned_tokens)
+            logits[batch_idx, indices] = -float("inf")
+
+        return logits
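The core rule the processor applies is: a token is banned if emitting it would complete an n-gram that already occurred inside the trailing window. A standalone rewrite of just that rule on a plain token list (not the shipped class, no tensors or request objects), to make the indexing easier to follow:

# Standalone sketch (not part of the diff) of the n-gram banning rule used above.
from typing import List, Set


def banned_next_tokens(sequence: List[int], ngram_size: int, window_size: int) -> Set[int]:
    if ngram_size <= 0 or window_size <= 0 or len(sequence) < ngram_size:
        return set()
    search_start = max(0, len(sequence) - window_size)
    search_end = len(sequence) - ngram_size + 1
    if search_end <= search_start:
        return set()
    # The (ngram_size - 1) most recent tokens; the banned set is every token that
    # followed this prefix anywhere in the window.
    prefix = tuple(sequence[-(ngram_size - 1):]) if ngram_size > 1 else tuple()
    banned: Set[int] = set()
    for idx in range(search_start, search_end):
        ngram = sequence[idx: idx + ngram_size]
        if ngram_size == 1 or tuple(ngram[:-1]) == prefix:
            banned.add(ngram[-1])
    return banned


# The sequence ends in (7, 8); the trigram (7, 8, 9) already occurred, so 9 is banned next.
print(banned_next_tokens([7, 8, 9, 1, 7, 8], ngram_size=3, window_size=6))  # {9}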
sglang/srt/sampling/penaltylib/frequency_penalty.py
@@ -1,9 +1,6 @@
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedFrequencyPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
     Frequency penalizer penalizes tokens based on their frequency in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.frequency_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedFrequencyPenalizer(_BatchedPenalizer):
             [self.cumulated_frequency_penalties, their.cumulated_frequency_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("frequency_penalties", "cumulated_frequency_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)
sglang/srt/sampling/penaltylib/min_new_tokens.py
@@ -1,9 +1,6 @@
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
     Min new tokens penalizer penalizes tokens based on the length of the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.min_new_tokens > 0 for req in self.orchestrator.reqs()
@@ -92,3 +85,9 @@ class BatchedMinNewTokensPenalizer(_BatchedPenalizer):
         self.len_output_tokens = torch.cat(
             [self.len_output_tokens, their.len_output_tokens], dim=0
         )
+
+    # Explicit resource cleanup to aid GC and free CUDA memory promptly
+    def _teardown(self) -> None:
+        for name in ("min_new_tokens", "stop_token_penalties", "len_output_tokens"):
+            if hasattr(self, name):
+                delattr(self, name)
sglang/srt/sampling/penaltylib/orchestrator.py
@@ -77,9 +77,8 @@ class BatchedPenalizerOrchestrator:
             return

         if len(keep_indices) == 0:
-            self.is_required = False
-            for penalizer in self.penalizers.values():
-                penalizer.teardown()
+            # No requests left in the batch, fully release orchestrator resources
+            self.release()
             return

         is_required = False
@@ -92,6 +91,23 @@ class BatchedPenalizerOrchestrator:
                 penalizer.teardown()
         self.is_required = is_required

+    # Resource management helpers
+    def release(self) -> None:
+        """Release all penalizers and break references so GC can reclaim promptly."""
+        for penalizer in self.penalizers.values():
+            penalizer.teardown()
+        self.penalizers.clear()
+        # Break reference to ScheduleBatch
+        self._batch_ref = None
+        self.is_required = False
+
+    # Context manager support
+    def __enter__(self) -> "BatchedPenalizerOrchestrator":
+        return self
+
+    def __exit__(self, exc_type, exc, tb) -> None:
+        self.release()
+

     def merge(self, their: "BatchedPenalizerOrchestrator"):
         """
@@ -116,6 +132,22 @@ class _BatchedPenalizer(abc.ABC):
     An abstract class for a batched penalizer.
     """

+    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
+        self._orchestrator_ref: weakref.ReferenceType[BatchedPenalizerOrchestrator] = (
+            weakref.ref(orchestrator)
+        )
+        self._is_prepared = False
+
+    @property
+    def orchestrator(self) -> BatchedPenalizerOrchestrator:
+        orch: Optional[BatchedPenalizerOrchestrator] = self._orchestrator_ref()
+        # This should never happen, but we need to handle it gracefully
+        if orch is None:
+            raise RuntimeError(
+                "BatchedPenalizerOrchestrator has been garbage-collected"
+            )
+        return orch
+
     def is_prepared(self) -> bool:
         return self._is_prepared

@@ -135,6 +167,7 @@ class _BatchedPenalizer(abc.ABC):
         return False

     def teardown(self):
+        self._teardown()
         self._is_prepared = False

     def cumulate_output_tokens(self, output_ids: torch.Tensor):
@@ -207,3 +240,10 @@ class _BatchedPenalizer(abc.ABC):
         Merge the penalizer with another penalizer.
         """
         pass
+
+    @abc.abstractmethod
+    def _teardown(self):
+        """
+        Teardown the penalizer.
+        """
+        pass
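The structural change here is that each penalizer now keeps only a weak reference back to its orchestrator, while the orchestrator keeps the strong references; dropping the orchestrator therefore frees the whole group (and its tensors) by plain reference counting instead of waiting on the cycle collector. A minimal sketch of that ownership pattern with generic classes (not the real penalizer types):

# Minimal sketch (not the real classes) of the parent-owns-children, child-holds-weakref pattern.
import weakref


class Parent:
    def __init__(self):
        self.children = [Child(self) for _ in range(2)]   # strong references downward

    def release(self):
        self.children.clear()


class Child:
    def __init__(self, parent: "Parent"):
        self._parent_ref = weakref.ref(parent)            # no strong back-reference

    @property
    def parent(self) -> "Parent":
        parent = self._parent_ref()
        if parent is None:
            raise RuntimeError("parent has been garbage-collected")
        return parent


p = Parent()
child = p.children[0]
del p                    # parent is reclaimed immediately; no reference cycle keeps it alive
try:
    child.parent
except RuntimeError as e:
    print(e)             # parent has been garbage-collected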
sglang/srt/sampling/penaltylib/presence_penalty.py
@@ -1,9 +1,6 @@
 import torch

-from sglang.srt.sampling.penaltylib.orchestrator import (
-    BatchedPenalizerOrchestrator,
-    _BatchedPenalizer,
-)
+from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer


 class BatchedPresencePenalizer(_BatchedPenalizer):
@@ -11,10 +8,6 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
     Presence penalizer penalizes tokens based on their presence in the output.
     """

-    def __init__(self, orchestrator: BatchedPenalizerOrchestrator):
-        self.orchestrator = orchestrator
-        self._is_prepared = False
-
     def _is_required(self) -> bool:
         return any(
             req.sampling_params.presence_penalty != 0.0
@@ -63,3 +56,8 @@ class BatchedPresencePenalizer(_BatchedPenalizer):
             [self.cumulated_presence_penalties, their.cumulated_presence_penalties],
             dim=0,
         )
+
+    def _teardown(self) -> None:
+        for name in ("presence_penalties", "cumulated_presence_penalties"):
+            if hasattr(self, name):
+                delattr(self, name)