sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/rotary_embedding.py

@@ -11,6 +11,7 @@ import triton
 import triton.language as tl

 from sglang.srt.custom_op import CustomOp
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -124,18 +125,34 @@ class RotaryEmbedding(CustomOp):
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)

+        self._apply_rotary_emb_wrapped = _apply_rotary_emb
+
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            self._forward_method = self.forward_native
+            self._apply_rotary_emb_wrapped = torch.compile(dynamic=True)(
+                self._apply_rotary_emb_wrapped
+            )
+
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         """Compute the inverse frequency."""
         # NOTE(woosuk): To exactly match the HF implementation, we need to
         # use CPU to compute the cache and then move it to GPU. However, we
         # create the cache on GPU for faster initialization. This may cause
         # a slight numerical difference between the HF implementation and ours.
+        init_device = (
+            "cpu" if get_global_server_args().rl_on_policy_target == "fsdp" else None
+        )
         inv_freq = 1.0 / (
             base
             ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
+                torch.arange(
+                    0, self.rotary_dim, 2, dtype=torch.float, device=init_device
+                )
+                / self.rotary_dim
             )
         )
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            inv_freq = inv_freq.cuda()
         return inv_freq

     def _compute_cos_sin_cache(self) -> torch.Tensor:
@@ -173,14 +190,16 @@ class RotaryEmbedding(CustomOp):
         query = query.view(num_tokens, -1, self.head_size)
         query_rot = query[..., : self.rotary_dim]
         query_pass = query[..., self.rotary_dim :]
-        query_rot = _apply_rotary_emb(query_rot, cos, sin, self.is_neox_style)
+        query_rot = self._apply_rotary_emb_wrapped(
+            query_rot, cos, sin, self.is_neox_style
+        )
         query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

         key_shape = key.shape
         key = key.view(num_tokens, -1, self.head_size)
         key_rot = key[..., : self.rotary_dim]
         key_pass = key[..., self.rotary_dim :]
-        key_rot = _apply_rotary_emb(key_rot, cos, sin, self.is_neox_style)
+        key_rot = self._apply_rotary_emb_wrapped(key_rot, cos, sin, self.is_neox_style)
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key

@@ -300,10 +319,20 @@ class RotaryEmbedding(CustomOp):
         query: torch.Tensor,
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
+        fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        # TODO: make a wrapper, and XPU will implement this kernel later.
-        self.cos_sin_cache = self.cos_sin_cache.to(query.device)
-        return self.forward_native(positions, query, key, offsets)
+        assert (
+            fused_set_kv_buffer_arg is None
+        ), "fused_set_kv_buffer_arg is not supported for xpu implementation"
+        positions = torch.add(positions, offsets) if offsets is not None else positions
+        return torch.ops.sgl_kernel.rotary_embedding(
+            positions,
+            query,
+            key,
+            self.head_size,
+            self.cos_sin_cache,
+            self.is_neox_style,
+        )


 class LinearScalingRotaryEmbedding(RotaryEmbedding):
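
The `_apply_rotary_emb` helper that the new `_apply_rotary_emb_wrapped` attribute points at is defined elsewhere in rotary_embedding.py and is not shown in this diff. A minimal eager-mode sketch of the rotation such a helper conventionally performs, together with the conditional `torch.compile` wrapping pattern from the hunk above; the function name, shapes, and the `use_compiled` flag are illustrative assumptions, not code from the package:

import torch

def apply_rotary_emb_sketch(
    x: torch.Tensor,    # [num_tokens, num_heads, rotary_dim]
    cos: torch.Tensor,  # [num_tokens, rotary_dim // 2]
    sin: torch.Tensor,  # [num_tokens, rotary_dim // 2]
    is_neox_style: bool,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        # NeoX layout: channel i pairs with channel i + rotary_dim // 2.
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        # Interleaved (GPT-J) layout: channel 2i pairs with channel 2i + 1.
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    return torch.stack((o1, o2), dim=-1).flatten(-2)

# The same conditional-compile pattern as the diff, in isolation:
use_compiled = True  # stands in for rl_on_policy_target == "fsdp"
apply_fn = apply_rotary_emb_sketch
if use_compiled:
    # dynamic=True avoids recompiling when the token count changes between calls.
    apply_fn = torch.compile(dynamic=True)(apply_rotary_emb_sketch)

q_rot = apply_fn(torch.randn(3, 4, 8), torch.randn(3, 4), torch.randn(3, 4), True)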
sglang/srt/layers/rotary_embedding.py (continued)

@@ -1058,6 +1087,7 @@ def _triton_mrope_forward(
     mrope_section_h: tl.constexpr,
     mrope_section_w: tl.constexpr,
     is_interleaved: tl.constexpr,
+    is_neox_style: tl.constexpr,
 ):
     # Adapted from
     # https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/ops/qwen2vl_mrope.py
@@ -1112,51 +1142,99 @@ def _triton_mrope_forward(
     # program instance (i.e. for the current token) separately
     # ####################################################################
     # left half of the head
-    first_half_q_offsets = (
-        tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_half_k_offsets = (
-        tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
-    )
-    first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
-    first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
-        tl.arange(0, pad_hd // 2)[None, :] < rd // 2
-    )
+    if is_neox_style:
+        first_half_q_offsets = (
+            tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_half_k_offsets = (
+            tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :]
+        )
+        first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )
+        first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & (
+            tl.arange(0, pad_hd // 2)[None, :] < rd // 2
+        )

-    q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
-        sin_row.dtype
-    )
+        q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    # right half of the head
-    second_half_q_offsets = first_half_q_offsets + (rd // 2)
-    second_half_k_offsets = first_half_k_offsets + (rd // 2)
-    second_q_mask = first_q_mask
-    second_k_mask = first_k_mask
+        # right half of the head
+        second_half_q_offsets = first_half_q_offsets + (rd // 2)
+        second_half_k_offsets = first_half_k_offsets + (rd // 2)
+        second_q_mask = first_q_mask
+        second_k_mask = first_k_mask
+
+        q_tile_2 = tl.load(
+            q_ptr + second_half_q_offsets, mask=second_q_mask, other=0
+        ).to(sin_row.dtype)
+        k_tile_2 = tl.load(
+            k_ptr + second_half_k_offsets, mask=second_k_mask, other=0
+        ).to(sin_row.dtype)
+
+        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
+        # Since cos and sin are now half-size,
+        # we use the same cos_row and sin_row for both halves
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+    else:
+        base_q = tl.arange(0, pad_n_qh)[:, None] * hd
+        base_k = tl.arange(0, pad_n_kh)[:, None] * hd
+        even_idx = 2 * tl.arange(0, pad_hd // 2)[None, :]
+        odd_idx = even_idx + 1
+
+        even_q_offsets = base_q + even_idx
+        odd_q_offsets = base_q + odd_idx
+        even_k_offsets = base_k + even_idx
+        odd_k_offsets = base_k + odd_idx
+
+        idx_mask = tl.arange(0, pad_hd // 2)[None, :] < (rd // 2)
+        qn_mask = tl.arange(0, pad_n_qh)[:, None] < n_qh
+        kn_mask = tl.arange(0, pad_n_kh)[:, None] < n_kh
+
+        even_q_mask = qn_mask & idx_mask
+        odd_q_mask = qn_mask & idx_mask
+        even_k_mask = kn_mask & idx_mask
+        odd_k_mask = kn_mask & idx_mask
+
+        q_tile_1 = tl.load(q_ptr + even_q_offsets, mask=even_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_1 = tl.load(k_ptr + even_k_offsets, mask=even_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to(
-        sin_row.dtype
-    )
-    k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to(
-        sin_row.dtype
-    )
+        q_tile_2 = tl.load(q_ptr + odd_q_offsets, mask=odd_q_mask, other=0).to(
+            sin_row.dtype
+        )
+        k_tile_2 = tl.load(k_ptr + odd_k_offsets, mask=odd_k_mask, other=0).to(
+            sin_row.dtype
+        )

-    # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
-    # Since cos and sin are now half-size,
-    # we use the same cos_row and sin_row for both halves
-    new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
-    tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask)
-    new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
-    tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask)
+        # y = [x_even, x_odd] * [cos, cos] + [-x_odd, x_even] * [sin, sin]
+        # NeoX-style rotary embedding:
+        # Each (even, odd) channel pair forms one rotation arm.
+        # cos_row and sin_row each have length rd//2, shared across all (even, odd) pairs.
+        new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row
+        tl.store(q_ptr + even_q_offsets, new_q_tile_1, mask=even_q_mask)
+        new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row
+        tl.store(q_ptr + odd_q_offsets, new_q_tile_2, mask=odd_q_mask)

-    new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
-    tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask)
-    new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
-    tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask)
+        new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row
+        tl.store(k_ptr + even_k_offsets, new_k_tile_1, mask=even_k_mask)
+        new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row
+        tl.store(k_ptr + odd_k_offsets, new_k_tile_2, mask=odd_k_mask)


 def triton_mrope(
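
The new `else` branch above rotates channel pairs (2i, 2i + 1) instead of the half-split pairs (i, i + rd // 2) used when `is_neox_style` is true; conventionally the half-split layout is called NeoX style and the even/odd layout interleaved (GPT-J) style. A small standalone illustration of the two pairings and the shared rotation formula; the values are arbitrary and the snippet is not part of the kernel:

import torch

rotary_dim = 8
idx = torch.arange(rotary_dim)

# Pairing used when is_neox_style is True (half-split layout):
half_pairs = list(zip(idx[: rotary_dim // 2].tolist(), idx[rotary_dim // 2 :].tolist()))
print(half_pairs)  # [(0, 4), (1, 5), (2, 6), (3, 7)]

# Pairing used by the new else branch (even/odd, interleaved layout):
interleaved_pairs = list(zip(idx[0::2].tolist(), idx[1::2].tolist()))
print(interleaved_pairs)  # [(0, 1), (2, 3), (4, 5), (6, 7)]

# In both branches each pair (a, b) is rotated with the same half-length
# cos_row / sin_row:
#   new_a = a * cos - b * sin
#   new_b = b * cos + a * sin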
sglang/srt/layers/rotary_embedding.py (continued)

@@ -1168,6 +1246,7 @@ def triton_mrope(
     head_size: int,
     rotary_dim: int,
     mrope_interleaved: bool,
+    is_neox_style: bool,
 ) -> tuple[torch.Tensor, torch.Tensor]:
     """The mrope triton kernel.

@@ -1218,6 +1297,7 @@ def triton_mrope(
         mrope_section[1],
         mrope_section[2],
         mrope_interleaved,
+        is_neox_style,
     )
     return q, k

@@ -1361,6 +1441,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         else:
             return self._forward_native(positions, query, key)

+    @torch.compile(dynamic=True, backend=get_compiler_backend())
     def _forward_triton(
         self,
         positions: torch.Tensor,
@@ -1379,6 +1460,7 @@ class MRotaryEmbedding(RotaryEmbedding):
         if positions.ndim == 2:
             assert self.mrope_section

+            torch._dynamo.graph_break()
             q, k = triton_mrope(
                 query,
                 key,
@@ -1388,7 +1470,9 @@ class MRotaryEmbedding(RotaryEmbedding):
                 self.head_size,
                 self.rotary_dim,
                 self.mrope_interleaved,
+                self.is_neox_style,
             )
+            torch._dynamo.graph_break()

             return q.reshape(query_shape), k.reshape(key_shape)
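
The `@torch.compile(dynamic=True, backend=get_compiler_backend())` decorator added to `_forward_triton` and the `torch._dynamo.graph_break()` calls around the `triton_mrope` launch make Dynamo's captured graphs stop before the hand-written kernel and resume after it. A self-contained sketch of that pattern; `opaque_kernel_launch` and `fused_step` are made-up names, not code from the package:

import torch

def opaque_kernel_launch(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a custom Triton kernel launch that Dynamo should not
    # fuse with the surrounding operations.
    return torch.sigmoid(x)

@torch.compile(dynamic=True)
def fused_step(x: torch.Tensor) -> torch.Tensor:
    y = x * 2                    # captured in the first compiled graph
    torch._dynamo.graph_break()  # end the current graph before the opaque launch
    y = opaque_kernel_launch(y)
    torch._dynamo.graph_break()  # end it again so the remaining ops start a fresh graph
    return y + 1

print(fused_step(torch.randn(4, 8)).shape)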
 
sglang/srt/layers/sampler.py

@@ -102,6 +102,14 @@ class Sampler(nn.Module):
         if return_logprob and SGLANG_RETURN_ORIGINAL_LOGPROB:
             probs_without_temp_scaling = torch.softmax(logits, dim=-1)

+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            logits_div_temperature = (
+                logits.bfloat16().div(sampling_info.temperatures).bfloat16()
+            )
+            logprobs_via_logsoftmax_kernel = torch.log_softmax(
+                logits_div_temperature, dim=-1
+            )
+
         # Post process logits
         logits.div_(sampling_info.temperatures)
         logits[:] = torch.softmax(logits, dim=-1)
@@ -148,8 +156,11 @@ class Sampler(nn.Module):
             )

         if return_logprob:
+            if get_global_server_args().rl_on_policy_target == "fsdp":
+                logprobs = logprobs_via_logsoftmax_kernel
+                del logprobs_via_logsoftmax_kernel
             # clamp to avoid -inf
-            if SGLANG_RETURN_ORIGINAL_LOGPROB:
+            elif SGLANG_RETURN_ORIGINAL_LOGPROB:
                 logprobs = torch.log(probs_without_temp_scaling).clamp(
                     min=torch.finfo(probs_without_temp_scaling.dtype).min
                 )
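
For context on the new `rl_on_policy_target == "fsdp"` branch: it derives logprobs with `torch.log_softmax` on the temperature-scaled logits instead of taking `torch.log()` of already-softmaxed probabilities, the path guarded by the `# clamp to avoid -inf` comment. A small, self-contained illustration of the numerical difference (the values are arbitrary):

import torch

logits = torch.tensor([[30.0, -80.0]])
temperature = 0.5

# log() applied after softmax underflows to -inf for very unlikely tokens ...
naive = torch.log(torch.softmax(logits / temperature, dim=-1))
# ... while log_softmax stays in log space and remains finite.
stable = torch.log_softmax(logits / temperature, dim=-1)

print(naive)   # tensor([[0., -inf]])
print(stable)  # tensor([[0., -220.]])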
sglang/srt/lora/lora_registry.py

@@ -205,3 +205,12 @@ class LoRARegistry:
         Returns the total number of LoRA adapters currently registered.
         """
         return len(self._registry)
+
+    def get_all_adapters(self) -> Dict[str, LoRARef]:
+        """
+        Returns a dictionary of all registered LoRA adapters.
+
+        Returns:
+            Dict[str, LoRARef]: A dictionary mapping LoRA names to LoRARef objects.
+        """
+        return dict(self._registry)
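
One note on the new accessor: `dict(self._registry)` returns a shallow copy, so callers get a snapshot they can iterate or mutate without touching the registry's internal mapping. A hypothetical usage sketch (the `registry` variable is an assumption):

# Assuming `registry` is an existing LoRARegistry instance:
adapters = registry.get_all_adapters()
for name, lora_ref in adapters.items():
    print(name, lora_ref)

# The copy is detached from the registry, so this does not unregister anything:
adapters.clear()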
sglang/srt/managers/async_mm_data_processor.py (new file)

@@ -0,0 +1,122 @@
+import asyncio
+import logging
+from concurrent.futures import ThreadPoolExecutor
+from functools import partial
+from typing import Any, Dict, List, Optional, Union
+
+logger = logging.getLogger(__name__)
+
+
+class AsyncMMDataProcessor:
+    """
+    Async wrapper for a multimodal processor.
+
+    Behavior:
+      - If the underlying processor exposes `process_mm_data_async`, call/await it directly.
+      - Otherwise, fall back to running a synchronous `process_mm_data` in a thread pool.
+      - Optionally guard per-call concurrency via an asyncio.Semaphore.
+      - Optionally enforce per-call timeout via asyncio.wait_for.
+    """
+
+    def __init__(
+        self,
+        mm_processor: Any,
+        *,
+        max_concurrent_calls: Optional[int] = None,
+        timeout_s: Optional[float] = None,
+    ) -> None:
+        """
+        Args:
+            mm_processor: An object exposing either
+                - async def process_mm_data_async(...): -> Dict[str, Any]
+                or
+                - def process_mm_data(...): -> Dict[str, Any]
+            max_concurrent_calls: Optional concurrency cap for per-call execution.
+            timeout_s: Optional timeout (seconds) for each `process()` call.
+        """
+        self.mm_processor = mm_processor
+        self.timeout_s = timeout_s
+
+        # Concurrency guard (None -> unlimited)
+        self.semaphore = (
+            asyncio.Semaphore(max_concurrent_calls) if max_concurrent_calls else None
+        )
+
+        # Detect async path; if missing, prepare a fallback executor for sync path
+        self._proc_async = getattr(mm_processor, "process_mm_data_async", None)
+        self.is_async = asyncio.iscoroutinefunction(self._proc_async)
+        self.fallback_exec: Optional[ThreadPoolExecutor] = (
+            ThreadPoolExecutor(max_workers=max_concurrent_calls)
+            if not self.is_async
+            else None
+        )
+
+    async def process(
+        self,
+        *,
+        image_data: Optional[List[Union[str, bytes]]] = None,
+        audio_data: Optional[List[Union[str, bytes]]] = None,
+        input_text_or_ids: Union[str, List[int], None] = None,
+        request_obj: Any,
+        **kwargs: Any,
+    ) -> Dict[str, Any]:
+        """
+        Public entrypoint: process a single multimodal request without blocking the event loop.
+        """
+
+        async def _invoke() -> Dict[str, Any]:
+            if self.is_async:
+                # Native async implementation
+                return await self._proc_async(
+                    image_data=image_data,
+                    audio_data=audio_data,
+                    input_text=input_text_or_ids,
+                    request_obj=request_obj,
+                    **kwargs,
+                )
+
+            # Synchronous fallback
+            sync_fn = getattr(self.mm_processor, "process_mm_data", None)
+            if not callable(sync_fn):
+                raise RuntimeError(
+                    "mm_processor has neither 'process_mm_data_async' nor 'process_mm_data'."
+                )
+            loop = asyncio.get_running_loop()
+            fn = partial(
+                sync_fn,
+                image_data=image_data,
+                audio_data=audio_data,
+                input_text=input_text_or_ids,
+                request_obj=request_obj,
+                **kwargs,
+            )
+            return await loop.run_in_executor(self.fallback_exec, fn)
+
+        # Apply optional concurrency guard
+        if self.semaphore is not None:
+            async with self.semaphore:
+                if self.timeout_s is not None:
+                    return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
+                return await _invoke()
+
+        # No concurrency guard
+        if self.timeout_s is not None:
+            return await asyncio.wait_for(_invoke(), timeout=self.timeout_s)
+        return await _invoke()
+
+    def shutdown(self) -> None:
+        """Gracefully shutdown resources owned by this wrapper."""
+        try:
+            if self.fallback_exec:
+                self.fallback_exec.shutdown(wait=False)
+        except Exception:
+            logger.exception(
+                "Error while shutting down fallback executor in AsyncMMDataProcessor"
+            )
+
+    def __del__(self):
+        # Best-effort shutdown
+        try:
+            self.shutdown()
+        except Exception:
+            pass
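
A sketch of how the new `AsyncMMDataProcessor` might be driven, using a dummy processor that only implements the synchronous `process_mm_data` path; the dummy class and the literal values are illustrative assumptions, not code from the package:

import asyncio

from sglang.srt.managers.async_mm_data_processor import AsyncMMDataProcessor

class DummyProcessor:
    def process_mm_data(self, *, image_data, audio_data, input_text, request_obj, **kwargs):
        # Pretend to do expensive preprocessing and return processed inputs.
        return {"input_text": input_text, "num_images": len(image_data or [])}

async def main():
    proc = AsyncMMDataProcessor(DummyProcessor(), max_concurrent_calls=4, timeout_s=30.0)
    out = await proc.process(
        image_data=["img0.png"],
        input_text_or_ids="describe the image",
        request_obj=None,
    )
    print(out)  # {'input_text': 'describe the image', 'num_images': 1}
    proc.shutdown()

asyncio.run(main())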
sglang/srt/managers/data_parallel_controller.py

@@ -34,13 +34,21 @@ from sglang.srt.managers.io_struct import (
     TokenizedGenerateReqInput,
     WatchLoadUpdateReq,
 )
-from sglang.srt.managers.schedule_batch import Req
+from sglang.srt.managers.schedule_batch import Req, RequestStage
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import (
     DP_ATTENTION_HANDSHAKE_PORT_DELTA,
     PortArgs,
     ServerArgs,
 )
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_get_proc_propagate_context,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     bind_port,
     configure_logger,
@@ -170,11 +178,22 @@ class DataParallelController:
     def handle_load_update_req(self, obj):
         self.dp_budget.update_budget(obj)

+    def dispatching_with_trace(self, req: Req):
+        if self.server_args.enable_trace:
+            trace_set_proc_propagate_context(req.rid, req.trace_context)
+            trace_slice_start(RequestStage.DC_DISPATCH, req.rid)
+            req.trace_context = trace_get_proc_propagate_context(req.rid)
+
+        self.dispatching(req)
+
+        if self.server_args.enable_trace:
+            trace_slice_end(RequestStage.DC_DISPATCH, req.rid, thread_finish_flag=True)
+
     def init_dispatcher(self):
         self._request_dispatcher = TypeBasedDispatcher(
             [
-                (TokenizedGenerateReqInput, self.dispatching),
-                (TokenizedEmbeddingReqInput, self.dispatching),
+                (TokenizedGenerateReqInput, self.dispatching_with_trace),
+                (TokenizedEmbeddingReqInput, self.dispatching_with_trace),
                 (BlockReqInput, self.send_to_all_workers),
                 (WatchLoadUpdateReq, self.handle_load_update_req),
             ]
@@ -487,6 +506,14 @@ def run_data_parallel_controller_process(
     pipe_writer,
 ):
     kill_itself_when_parent_died()
+    if server_args.enable_trace:
+        process_tracing_init(server_args.otlp_traces_endpoint, "sglang")
+        thread_label = "DP Controller"
+        if server_args.disaggregation_mode == "prefill":
+            thread_label = "Prefill DP Controller"
+        elif server_args.disaggregation_mode == "decode":
+            thread_label = "Decode DP Controller"
+        trace_set_thread_info(thread_label)
     setproctitle.setproctitle("sglang::data_parallel_controller")
     faulthandler.enable()
     configure_logger(server_args)
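
For orientation, the `TypeBasedDispatcher` used in `init_dispatcher` is built from `(request type, handler)` pairs. A generic sketch of that dispatch pattern, assuming a simple isinstance scan; this is not sglang's implementation:

from typing import Any, Callable, List, Tuple, Type

class TypeBasedDispatcherSketch:
    """Dispatch an incoming object to the first handler whose type matches."""

    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        self._mapping = mapping

    def __call__(self, obj: Any) -> Any:
        for req_type, handler in self._mapping:
            if isinstance(obj, req_type):
                return handler(obj)
        raise ValueError(f"Unsupported request type: {type(obj)}")

# Mirrors the registration in init_dispatcher() above (names from the diff):
# dispatcher = TypeBasedDispatcherSketch([
#     (TokenizedGenerateReqInput, self.dispatching_with_trace),
#     (TokenizedEmbeddingReqInput, self.dispatching_with_trace),
#     (BlockReqInput, self.send_to_all_workers),
#     (WatchLoadUpdateReq, self.handle_load_update_req),
# ])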
sglang/srt/managers/detokenizer_manager.py

@@ -235,6 +235,8 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                     new_text = ""
                 else:
                     new_text = find_printable_text(new_text)
+            else:
+                del self.decode_status[recv_obj.rids[i]]

             output_str = self.trim_matched_stop(
                 s.decoded_text + new_text,
@@ -273,6 +275,7 @@ class DetokenizerManager(MultiHttpWorkerDetokenizerMixin):
                 output_hidden_states=recv_obj.output_hidden_states,
                 placeholder_tokens_idx=None,
                 placeholder_tokens_val=None,
+                retraction_counts=recv_obj.retraction_counts,
                 token_steps=recv_obj.token_steps,
             )

sglang/srt/managers/io_struct.py

@@ -574,6 +574,7 @@ class GenerateReqInput(BaseReq):
             custom_labels=self.custom_labels,
             return_bytes=self.return_bytes,
             return_entropy=self.return_entropy,
+            http_worker_ipc=self.http_worker_ipc,
         )


@@ -694,6 +695,9 @@ class EmbeddingReqInput(BaseReq):
     # tracing context
     trace_context: Optional[Dict] = None

+    # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
+    dimensions: Optional[int] = None
+
     def normalize_batch_and_arguments(self):
         # at least one of text, input_ids, or image should be provided
         if self.text is None and self.input_ids is None and self.image_data is None:
@@ -759,6 +763,7 @@ class EmbeddingReqInput(BaseReq):
                 sampling_params=self.sampling_params[i],
                 rid=self.rid[i],
                 is_cross_encoder_request=True,
+                http_worker_ipc=self.http_worker_ipc,
             )

         return EmbeddingReqInput(
@@ -769,6 +774,8 @@ class EmbeddingReqInput(BaseReq):
             video_data=self.video_data[i] if self.video_data is not None else None,
             sampling_params=self.sampling_params[i],
             rid=self.rid[i],
+            dimensions=self.dimensions,
+            http_worker_ipc=self.http_worker_ipc,
         )


@@ -788,6 +795,8 @@ class TokenizedEmbeddingReqInput(BaseReq):
     data_parallel_rank: Optional[int] = None
     # Priority for the request
     priority: Optional[int] = None
+    # The number of dimensions the resulting output embeddings should have. It is applicable for Matryoshka Embeddings.
+    dimensions: Optional[int] = None


 @dataclass
@@ -851,6 +860,9 @@ class BatchTokenIDOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None

@@ -927,6 +939,9 @@ class BatchStrOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+
     # The trainer step id. Used to know which step's weights are used for sampling.
     token_steps: List[List[int]] = None

@@ -969,6 +984,9 @@ class BatchEmbeddingOutput(BaseBatchReq):
     placeholder_tokens_idx: List[Optional[List[int]]]
     placeholder_tokens_val: List[Optional[List[int]]]

+    # Number of times each request was retracted.
+    retraction_counts: List[int]
+

 @dataclass
 class ClearHiCacheReqInput(BaseReq):
@@ -1212,7 +1230,7 @@ class AbortReq(BaseReq):
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
-    abort_reason: Optional[str] = None
+    abort_message: Optional[str] = None

     def __post_init__(self):
         # FIXME: This is a hack to keep the same with the old code
@@ -1455,6 +1473,16 @@ class WatchLoadUpdateReq(BaseReq):
     loads: List[GetLoadReqOutput]


+@dataclass
+class SetInjectDumpMetadataReqInput(BaseReq):
+    dump_metadata: Dict[str, Any]
+
+
+@dataclass
+class SetInjectDumpMetadataReqOutput(BaseReq):
+    success: bool
+
+
 @dataclass
 class LazyDumpTensorsReqInput(BaseReq):
     pass
@@ -1486,6 +1514,3 @@ def _check_all_req_types():
         raise ValueError(
             f"{name} is a subclass of BaseReq but not follow the naming convention."
         )
-
-
-_check_all_req_types()
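
Background for the new `dimensions` field on the embedding requests: Matryoshka-style embedding models are trained so that a prefix of the output vector is itself a usable embedding, so a server can truncate to the requested number of dimensions and re-normalize. A minimal sketch of that post-processing step, illustrative only and not sglang's implementation:

import torch

def truncate_matryoshka_embedding(embedding: torch.Tensor, dimensions: int) -> torch.Tensor:
    # Keep the leading `dimensions` components, then L2-normalize again.
    truncated = embedding[..., :dimensions]
    return truncated / truncated.norm(dim=-1, keepdim=True).clamp_min(1e-12)

full = torch.randn(4096)
small = truncate_matryoshka_embedding(full, dimensions=256)
print(small.shape)  # torch.Size([256])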
sglang/srt/managers/multi_tokenizer_mixin.py

@@ -13,7 +13,12 @@ from __future__ import annotations
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-"""Mixin class and utils for multi-http-worker mode"""
+
+"""
+Mixin classes and utils for multi-http-worker mode
+This file uses multiple processes to handle requests and tokenization, reducing the overhead of python and http server.
+"""
+
 import asyncio
 import logging
 import multiprocessing as multiprocessing
@@ -329,6 +334,11 @@ def _handle_output_by_index(output, i):
             ),
             placeholder_tokens_idx=None,
             placeholder_tokens_val=None,
+            retraction_counts=(
+                [output.retraction_counts[i]]
+                if len(output.retraction_counts) > i
+                else None
+            ),
             token_steps=([output.token_steps[i]] if output.token_steps else None),
         )
     elif isinstance(output, BatchMultimodalOutput):
@@ -566,3 +576,14 @@ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
         logger.warning(
             "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
         )
+
+
+class SenderWrapper:
+    def __init__(self, port_args: PortArgs, send_to_scheduler: zmq.Socket):
+        self.port_args = port_args
+        self.send_to_scheduler = send_to_scheduler
+
+    def send_pyobj(self, obj):
+        if isinstance(obj, BaseReq):
+            obj.http_worker_ipc = self.port_args.tokenizer_ipc_name
+        self.send_to_scheduler.send_pyobj(obj)
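
Hypothetical usage of the new `SenderWrapper`: it mimics the scheduler socket's `send_pyobj` while stamping every `BaseReq` with the originating HTTP worker's tokenizer IPC name before forwarding. The variable names below are assumptions:

# `port_args`, `send_to_scheduler`, and `tokenized_req` are assumed to exist.
sender = SenderWrapper(port_args, send_to_scheduler)
sender.send_pyobj(tokenized_req)   # BaseReq: http_worker_ipc is stamped, then sent
sender.send_pyobj({"ping": True})  # non-BaseReq objects are forwarded unchanged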