sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable

@@ -33,7 +34,7 @@ from sglang.srt.model_executor.forward_batch_info import (
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import get_available_gpu_memory, is_hip

 _is_hip = is_hip()

@@ -81,7 +82,9 @@ def patch_model(
             # tp_group.ca_comm = None
             yield torch.compile(
                 torch.no_grad()(model.forward),
-                mode="max-autotune-no-cudagraphs",
+                mode=os.environ.get(
+                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
+                ),
                 dynamic=False,
             )
         else:
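
The torch.compile mode used by patch_model can now be overridden through the SGLANG_TORCH_COMPILE_MODE environment variable added above. A minimal sketch of how a caller might set it; the variable name comes from the hunk, while the value "default" is only an illustration of one valid torch.compile mode:

import os

# Must be set in the server/engine process before patch_model() runs.
os.environ["SGLANG_TORCH_COMPILE_MODE"] = "default"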
@@ -117,7 +120,9 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         else:
             capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
     else:
-        capture_bs = list(range(1, 33))
+        # Since speculative decoding requires more cuda graph memory, we
+        # capture less.
+        capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]

     if _is_hip:
         capture_bs += [i * 8 for i in range(21, 33)]
@@ -125,16 +130,11 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
         # is very small. We add more values here to make sure we capture the maximum bs.
-        capture_bs = list(
-            sorted(
-                set(
-                    capture_bs
-                    + [model_runner.req_to_token_pool.size - 1]
-                    + [model_runner.req_to_token_pool.size]
-                )
-            )
-        )
+        capture_bs += [model_runner.req_to_token_pool.size - 1] + [
+            model_runner.req_to_token_pool.size
+        ]

+    capture_bs = list(sorted(set(capture_bs)))
     capture_bs = [
         bs
         for bs in capture_bs
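
With a speculative algorithm configured, the captured CUDA-graph batch sizes are now sparser above 8 and include a few large values. A quick way to inspect the resulting list, using the exact values from the hunk above and the same dedup-and-sort step as the new capture_bs = list(sorted(set(capture_bs))) line:

capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
print(sorted(set(capture_bs)))
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 13, 15, 17, 19, 21, 23, 25, 27, 29, 31, 64, 96, 128, 160]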
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size

@@ -219,7 +220,19 @@ class CudaGraphRunner:
         self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

         # Speculative_inference
-        if model_runner.spec_algorithm.is_eagle():
+        if (
+            model_runner.spec_algorithm.is_eagle3()
+            and not model_runner.is_draft_worker
+        ):
+            self.hidden_states = torch.zeros(
+                (
+                    self.max_num_token,
+                    3 * self.model_runner.model_config.hidden_size,
+                ),
+                dtype=self.model_runner.dtype,
+            )
+            self.model_runner.model.set_eagle3_layers_to_capture()
+        elif model_runner.spec_algorithm.is_eagle():
             self.hidden_states = torch.zeros(
                 (self.max_num_token, self.model_runner.model_config.hidden_size),
                 dtype=self.model_runner.dtype,
@@ -236,7 +249,7 @@ class CudaGraphRunner:
         if self.enable_dp_attention:
             self.gathered_buffer = torch.zeros(
                 (
-                    self.max_bs * self.dp_size,
+                    self.max_bs * self.dp_size * self.num_tokens_per_bs,
                     self.model_runner.model_config.hidden_size,
                 ),
                 dtype=self.model_runner.dtype,
@@ -276,13 +289,12 @@ class CudaGraphRunner:

     def can_run(self, forward_batch: ForwardBatch):
         if self.enable_dp_attention:
-            min_num_tokens, max_num_tokens = min(
-                forward_batch.global_num_tokens_cpu
-            ), max(forward_batch.global_num_tokens_cpu)
+            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
+
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                (min_num_tokens == max_num_tokens and max_num_tokens in self.graphs)
+                total_global_tokens in self.graphs
                 if self.disable_padding
-                else max_num_tokens <= self.max_bs
+                else total_global_tokens <= self.max_bs
             )
         else:
             is_bs_supported = (
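
Under DP attention, a captured graph is now selected by the total token count summed over data-parallel ranks instead of requiring every rank to carry the same count. An illustrative check with made-up numbers; the variable names mirror the hunk above, while the graphs dict and the max_bs value are placeholders:

graphs = {1: None, 2: None, 4: None, 8: None}        # captured graphs keyed by token count
global_num_tokens_cpu = [3, 1, 2, 2]                 # per-DP-rank token counts
total_global_tokens = sum(global_num_tokens_cpu)     # 8
ok_without_padding = total_global_tokens in graphs   # needs an exact graph
ok_with_padding = total_global_tokens <= 160         # only needs total <= max_bs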
@@ -304,6 +316,9 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
+            avail_mem = get_available_gpu_memory(
+                self.model_runner.device, self.model_runner.gpu_id, empty_cache=False
+            )
             # Reverse the order to enable better memory sharing across cuda graphs.
             capture_range = (
                 tqdm.tqdm(list(reversed(self.capture_bs)))
@@ -311,6 +326,16 @@ class CudaGraphRunner:
                 else reversed(self.capture_bs)
             )
             for bs in capture_range:
+                if get_tensor_model_parallel_rank() == 0:
+                    avail_mem = get_available_gpu_memory(
+                        self.model_runner.device,
+                        self.model_runner.gpu_id,
+                        empty_cache=False,
+                    )
+                    capture_range.set_description(
+                        f"Capturing batches ({avail_mem=:.2f} GB)"
+                    )
+
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -345,8 +370,18 @@ class CudaGraphRunner:
             mrope_positions = self.mrope_positions[:, :bs]

         if self.enable_dp_attention:
-            global_num_tokens = [bs] * self.tp_size
-            gathered_buffer = self.gathered_buffer[: bs * self.tp_size]
+            self.global_num_tokens_gpu.copy_(
+                torch.tensor(
+                    [
+                        num_tokens // self.dp_size + (i < bs % self.dp_size)
+                        for i in range(self.dp_size)
+                    ],
+                    dtype=torch.int32,
+                    device=input_ids.device,
+                )
+            )
+            global_num_tokens = self.global_num_tokens_gpu
+            gathered_buffer = self.gathered_buffer[:num_tokens]
         else:
             global_num_tokens = None
             gathered_buffer = None
@@ -371,7 +406,7 @@ class CudaGraphRunner:
             encoder_lens=encoder_lens,
             return_logprob=False,
             positions=positions,
-            global_num_tokens_cpu=global_num_tokens,
+            global_num_tokens_gpu=global_num_tokens,
             gathered_buffer=gathered_buffer,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
@@ -392,6 +427,9 @@ class CudaGraphRunner:

         # Run and capture
         def run_once():
+            # Clean intermediate result cache for DP attention
+            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+
             logits_output = forward(input_ids, forward_batch.positions, forward_batch)
             return logits_output.next_token_logits, logits_output.hidden_states

@@ -426,7 +464,7 @@ class CudaGraphRunner:
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

-    def replay(self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False):
+    def replay_prepare(self, forward_batch: ForwardBatch):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size
@@ -435,7 +473,7 @@ class CudaGraphRunner:
         # Pad
         if self.enable_dp_attention:
             index = bisect.bisect_left(
-                self.capture_bs, max(forward_batch.global_num_tokens_cpu)
+                self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
         else:
             index = bisect.bisect_left(self.capture_bs, raw_bs)
@@ -459,6 +497,8 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
+        if self.enable_dp_attention:
+            self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)

         if hasattr(forward_batch.spec_info, "hidden_states"):
             self.hidden_states[:raw_num_token] = forward_batch.spec_info.hidden_states
@@ -475,14 +515,31 @@ class CudaGraphRunner:
             seq_lens_cpu=self.seq_lens_cpu,
         )

+        # Store fields
+        self.raw_bs = raw_bs
+        self.raw_num_token = raw_num_token
+        self.bs = bs
+
+    def replay(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
+        if not skip_attn_backend_init:
+            self.replay_prepare(forward_batch)
+        else:
+            # In speculative decoding, these two fields are still needed.
+            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
+            self.positions[: self.raw_num_token].copy_(forward_batch.positions)
+
         # Replay
-        self.graphs[bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[bs]
+        self.graphs[self.bs].replay()
+        next_token_logits, hidden_states = self.output_buffers[self.bs]

         logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[:raw_num_token],
+            next_token_logits=next_token_logits[: self.raw_num_token],
             hidden_states=(
-                hidden_states[:raw_num_token] if hidden_states is not None else None
+                hidden_states[: self.raw_num_token]
+                if hidden_states is not None
+                else None
             ),
         )
         return logits_output
sglang/srt/model_executor/forward_batch_info.py

@@ -33,16 +33,17 @@ from dataclasses import dataclass
 from enum import IntEnum, auto
 from typing import TYPE_CHECKING, List, Optional, Union

+import numpy as np
 import torch
 import triton
 import triton.language as tl

 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend, next_power_of_2
+from sglang.srt.utils import get_compiler_backend

 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
-    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
+    from sglang.srt.managers.schedule_batch import ModelWorkerBatch, MultimodalInputs
     from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -175,7 +176,7 @@ class ForwardBatch:
     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

     # For multimodal
-    image_inputs: Optional[List[ImageInputs]] = None
+    mm_inputs: Optional[List[MultimodalInputs]] = None

     # Encoder-decoder
     encoder_cached: Optional[List[bool]] = None
@@ -241,7 +242,7 @@ class ForwardBatch:
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
             out_cache_loc=batch.out_cache_loc,
-            image_inputs=batch.image_inputs,
+            mm_inputs=batch.multimodal_inputs,
             encoder_cached=batch.encoder_cached,
             encoder_lens=batch.encoder_lens,
             encoder_lens_cpu=batch.encoder_lens_cpu,
@@ -263,15 +264,24 @@ class ForwardBatch:
             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )

+        # For DP attention
         if batch.global_num_tokens is not None:
             ret.global_num_tokens_cpu = batch.global_num_tokens
-            max_len = max(ret.global_num_tokens_cpu)
+            ret.global_num_tokens_gpu = torch.tensor(
+                batch.global_num_tokens, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            ret.global_num_tokens_for_logprob_cpu = batch.global_num_tokens_for_logprob
+            ret.global_num_tokens_for_logprob_gpu = torch.tensor(
+                batch.global_num_tokens_for_logprob, dtype=torch.int64
+            ).to(device, non_blocking=True)
+
+            sum_len = sum(batch.global_num_tokens)
             ret.gathered_buffer = torch.zeros(
-                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
+                (sum_len, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
                 device=device,
             )
-
         if ret.forward_mode.is_idle():
             ret.positions = torch.empty((0,), device=device)
             return ret
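
The DP-attention gather buffer is now sized by the sum of per-rank token counts rather than the per-rank maximum multiplied by tp_size, so ranks are no longer padded to the largest one. A toy comparison with invented numbers:

global_num_tokens = [5, 2, 7, 3]             # tokens contributed by each DP rank
tp_size = 4
old_rows = max(global_num_tokens) * tp_size  # 28 rows in the padded allocation
new_rows = sum(global_num_tokens)            # 17 rows, exactly what is gathered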
@@ -322,6 +332,53 @@ class ForwardBatch:

         return ret

+    def merge_mm_inputs(self) -> Optional[MultimodalInputs]:
+        """
+        Merge all image inputs in the batch into a single MultiModalInputs object.
+
+        Returns:
+            if none, current batch contains no image input
+
+        """
+        if not self.mm_inputs or all(x is None for x in self.mm_inputs):
+            return None
+
+        # Filter out None values
+        valid_inputs = [x for x in self.mm_inputs if x is not None]
+
+        # Start with the first valid image input
+        merged = valid_inputs[0]
+
+        # Merge remaining inputs
+        for mm_input in valid_inputs[1:]:
+            merged.merge(mm_input)
+
+        if isinstance(merged.pixel_values, np.ndarray):
+            merged.pixel_values = torch.from_numpy(merged.pixel_values)
+        if isinstance(merged.audio_features, np.ndarray):
+            merged.audio_features = torch.from_numpy(merged.audio_features)
+
+        return merged
+
+    def contains_image_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_image_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_audio_inputs(self) -> bool:
+        if self.mm_inputs is None:
+            return False
+        return any(
+            mm_input is not None and mm_input.contains_audio_inputs()
+            for mm_input in self.mm_inputs
+        )
+
+    def contains_mm_inputs(self) -> bool:
+        return self.contains_audio_inputs() or self.contains_image_inputs()
+
     def _compute_mrope_positions(
         self, model_runner: ModelRunner, batch: ModelWorkerBatch
     ):
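
A hypothetical call site for the new multimodal helpers; only the method names come from the hunk above, and forward_batch stands for any ForwardBatch built for a batch that may carry images or audio:

if forward_batch.contains_mm_inputs():
    merged = forward_batch.merge_mm_inputs()  # single MultimodalInputs for the whole batch, or None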
@@ -332,8 +389,8 @@ class ForwardBatch:
         for i, _ in enumerate(mrope_positions_list):
             mrope_position_delta = (
                 0
-                if batch.image_inputs[i] is None
-                else batch.image_inputs[i].mrope_position_delta
+                if batch.multimodal_inputs[i] is None
+                else batch.multimodal_inputs[i].mrope_position_delta
             )
             mrope_positions_list[i] = MRotaryEmbedding.get_next_input_positions(
                 mrope_position_delta,
@@ -342,13 +399,13 @@ class ForwardBatch:
             )
         elif self.forward_mode.is_extend():
             extend_start_loc_cpu = self.extend_start_loc.cpu().numpy()
-            for i, image_inputs in enumerate(batch.image_inputs):
+            for i, multimodal_inputs in enumerate(batch.multimodal_inputs):
                 extend_start_loc, extend_seq_len, extend_prefix_len = (
                     extend_start_loc_cpu[i],
                     batch.extend_seq_lens[i],
                     batch.extend_prefix_lens[i],
                 )
-                if image_inputs is None:
+                if multimodal_inputs is None:
                     # text only
                     mrope_positions = [
                         [
@@ -365,16 +422,25 @@ class ForwardBatch:
                             input_tokens=self.input_ids[
                                 extend_start_loc : extend_start_loc + extend_seq_len
                             ],
-                            image_grid_thw=image_inputs.image_grid_thws,
+                            image_grid_thw=multimodal_inputs.image_grid_thws,
+                            video_grid_thw=multimodal_inputs.video_grid_thws,
+                            image_token_id=multimodal_inputs.im_token_id,
+                            video_token_id=multimodal_inputs.video_token_id,
                             vision_start_token_id=hf_config.vision_start_token_id,
+                            vision_end_token_id=hf_config.vision_end_token_id,
                             spatial_merge_size=hf_config.vision_config.spatial_merge_size,
                             context_len=0,
+                            seq_len=len(self.input_ids),
+                            second_per_grid_ts=multimodal_inputs.second_per_grid_ts,
+                            tokens_per_second=hf_config.vision_config.tokens_per_second,
                         )
                     )
-                batch.image_inputs[i].mrope_position_delta = mrope_position_delta
+                batch.multimodal_inputs[i].mrope_position_delta = (
+                    mrope_position_delta
+                )
                 mrope_positions_list[i] = mrope_positions

-        self.mrope_positions = torch.concat(
+        self.mrope_positions = torch.cat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
         )
@@ -440,7 +506,7 @@ def compute_position_kernel(
 def compute_position_torch(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor
 ):
-    positions = torch.concat(
+    positions = torch.cat(
         [
             torch.arange(
                 prefix_len, prefix_len + extend_len, device=extend_prefix_lens.device
sglang/srt/model_executor/model_runner.py

@@ -145,10 +145,12 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "enable_deepep_moe": server_args.enable_deepep_moe,
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -187,9 +189,6 @@ class ModelRunner:
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
             self.apply_torch_tp()
-            self.torch_tp_applied = True
-        else:
-            self.torch_tp_applied = False

         # Init lora
         if server_args.lora_paths is not None:
@@ -209,6 +208,10 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()

+        # auxiliary hidden capture mode. TODO: expose this to server args?
+        if self.spec_algorithm.is_eagle3() and not self.is_draft_worker:
+            self.model.set_eagle3_layers_to_capture()
+
     def model_specific_adjustment(self):
         server_args = self.server_args

@@ -223,6 +226,9 @@ class ModelRunner:
                     "MLA optimization is turned on. Use flashinfer mla backend."
                 )
                 server_args.attention_backend = "flashinfer_mla"
+            elif server_args.enable_flashmla:
+                logger.info("MLA optimization is turned on. Use flashmla decode.")
+                server_args.attention_backend = "flashmla"
             else:
                 logger.info("MLA optimization is turned on. Use triton backend.")
                 server_args.attention_backend = "triton"
@@ -254,18 +260,41 @@ class ModelRunner:

         if self.model_config.hf_config.architectures == [
             "Qwen2VLForConditionalGeneration"
+        ] or self.model_config.hf_config.architectures == [
+            "Qwen2_5_VLForConditionalGeneration"
         ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            # TODO: qwen2-vl series does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen-vl series."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
+        if self.model_config.hf_config.architectures == ["DeepseekVL2ForCausalLM"]:
+            # TODO: deepseek-vl2 does not support radix cache now, set disable_radix_cache=True automatically
             logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                "Automatically turn off --chunked-prefill-size and disable radix cache for deepseek-vl2."
             )
             server_args.chunked_prefill_size = -1
             server_args.disable_radix_cache = True

+        if server_args.enable_deepep_moe:
+            logger.info("DeepEP is turned on.")
+            assert (
+                server_args.enable_dp_attention == True
+            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

-        torch.get_device_module(self.device).set_device(self.gpu_id)
+        try:
+            torch.get_device_module(self.device).set_device(self.gpu_id)
+        except Exception:
+            logger.warning(
+                f"Context: {self.device=} {self.gpu_id=} {os.environ.get('CUDA_VISIBLE_DEVICES')=} {self.tp_rank=} {self.tp_size=}"
+            )
+            raise
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -606,6 +635,8 @@ class ModelRunner:
             load_config=self.load_config,
             dtype=self.dtype,
             lora_backend=self.server_args.lora_backend,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
         )
         logger.info("LoRA manager ready.")

@@ -840,6 +871,23 @@ class ModelRunner:
             )

             self.attn_backend = FlashInferMLAAttnBackend(self)
+        elif self.server_args.attention_backend == "flashmla":
+            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+            self.attn_backend = FlashMLABackend(self)
+        elif self.server_args.attention_backend == "fa3":
+            assert torch.cuda.get_device_capability()[0] >= 9, (
+                "FlashAttention v3 Backend requires SM>=90. "
+                "Please use `--attention-backend flashinfer`."
+            )
+            logger.warning(
+                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+            )
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            self.attn_backend = FlashAttentionBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
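
The fa3 branch above asserts SM >= 90 before constructing FlashAttentionBackend. A quick local check one might run before requesting that backend; it is illustrative only and simply mirrors the assertion in the hunk:

import torch

major, minor = torch.cuda.get_device_capability()
print(f"sm{major}{minor}:", "fa3 supported" if major >= 9 else "fall back to flashinfer")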
@@ -1009,6 +1057,22 @@ class ModelRunner:
             return False
         return rope_scaling.get("type", None) == "mrope"

+    def save_remote_model(self, url: str):
+        from sglang.srt.model_loader.loader import RemoteModelLoader
+
+        logger.info(f"Saving model to {url}")
+        RemoteModelLoader.save_model(self.model, self.model_config.model_path, url)
+
+    def save_sharded_model(
+        self, path: str, pattern: Optional[str] = None, max_size: Optional[int] = None
+    ):
+        from sglang.srt.model_loader.loader import ShardedStateLoader
+
+        logger.info(
+            f"Save sharded model to {path} with pattern {pattern} and max_size {max_size}"
+        )
+        ShardedStateLoader.save_model(self.model, path, pattern, max_size)
+

 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
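
A hypothetical way the new ModelRunner save helpers could be invoked, assuming runner is an initialized ModelRunner; the path is a placeholder and the URL scheme is only a guess based on the new sglang/srt/connector modules (redis.py, s3.py) in the file list above:

runner.save_sharded_model(path="/tmp/sharded-ckpt", pattern=None, max_size=None)
runner.save_remote_model(url="redis://localhost:6379/model")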