sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -13,8 +13,10 @@
  # ==============================================================================
  """ModelRunner runs the forward passes of the models."""

+ import collections
  import datetime
  import gc
+ import inspect
  import json
  import logging
  import os
@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
  )
  from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
  from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
- from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_loader import get_model
  from sglang.srt.model_loader.loader import (
  DefaultModelLoader,
@@ -110,6 +112,8 @@ class ModelRunner:
  gpu_id: int,
  tp_rank: int,
  tp_size: int,
+ pp_rank: int,
+ pp_size: int,
  nccl_port: int,
  server_args: ServerArgs,
  is_draft_worker: bool = False,
@@ -123,6 +127,8 @@ class ModelRunner:
  self.gpu_id = gpu_id
  self.tp_rank = tp_rank
  self.tp_size = tp_size
+ self.pp_rank = pp_rank
+ self.pp_size = pp_size
  self.dist_port = nccl_port
  self.server_args = server_args
  self.is_draft_worker = is_draft_worker
@@ -148,25 +154,26 @@ class ModelRunner:
  global_server_args_dict.update(
  {
  "attention_backend": server_args.attention_backend,
- "sampling_backend": server_args.sampling_backend,
- "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
- "torchao_config": server_args.torchao_config,
+ "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+ "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+ "deepep_mode": server_args.deepep_mode,
+ "device": server_args.device,
+ "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+ "disable_radix_cache": server_args.disable_radix_cache,
  "enable_nan_detection": server_args.enable_nan_detection,
  "enable_dp_attention": server_args.enable_dp_attention,
  "enable_ep_moe": server_args.enable_ep_moe,
  "enable_deepep_moe": server_args.enable_deepep_moe,
- "deepep_mode": server_args.deepep_mode,
- "device": server_args.device,
- "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
- "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
- "disable_radix_cache": server_args.disable_radix_cache,
  "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
  "moe_dense_tp_size": server_args.moe_dense_tp_size,
- "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
- "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
  "n_share_experts_fusion": server_args.n_share_experts_fusion,
- "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+ "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+ "torchao_config": server_args.torchao_config,
+ "sampling_backend": server_args.sampling_backend,
+ "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+ "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
  "use_mla_backend": self.use_mla_backend,
+ "mm_attention_backend": server_args.mm_attention_backend,
  }
  )

@@ -183,6 +190,11 @@ class ModelRunner:
  # If it is a draft model, tp_group can be different
  self.initialize(min_per_gpu_memory)

+ # temporary cached values
+ self.support_pp = (
+ "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
+ )
+
  def initialize(self, min_per_gpu_memory: float):
  server_args = self.server_args
  self.memory_saver_adapter = TorchMemorySaverAdapter.create(
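Note: the new support_pp flag is plain inspect-based signature introspection; the runner only forwards pp_proxy_tensors to models whose forward() declares that keyword. A minimal standalone sketch of the same probe (the two model classes here are hypothetical, not part of sglang):

import inspect

class PipelineAwareModel:
    def forward(self, input_ids, positions, forward_batch, pp_proxy_tensors=None):
        ...

class LegacyModel:
    def forward(self, input_ids, positions, forward_batch):
        ...

def supports_pp(model) -> bool:
    # True only if forward() declares the pp_proxy_tensors keyword.
    return "pp_proxy_tensors" in inspect.signature(model.forward).parameters

assert supports_pp(PipelineAwareModel())
assert not supports_pp(LegacyModel())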
@@ -193,6 +205,12 @@ class ModelRunner:
  self.sampler = Sampler()
  self.load_model()

+ self.start_layer = getattr(self.model, "start_layer", 0)
+ self.end_layer = getattr(
+ self.model, "end_layer", self.model_config.num_hidden_layers
+ )
+ self.num_effective_layers = self.end_layer - self.start_layer
+
  # Apply torchao quantization
  torchao_applied = getattr(self.model, "torchao_applied", False)
  # In layered loading, torchao may have been applied
@@ -261,9 +279,10 @@ class ModelRunner:
  server_args.attention_backend = "fa3"
  else:
  server_args.attention_backend = "triton"
- logger.info(
- f"Attention backend not set. Use {server_args.attention_backend} backend by default."
- )
+ if self.should_log:
+ logger.info(
+ f"Attention backend not set. Use {server_args.attention_backend} backend by default."
+ )
  elif self.use_mla_backend:
  if server_args.device != "cpu":
  if server_args.attention_backend in [
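Note: the logger.info calls throughout this file are now gated on self.should_log so that only one worker prints them. The attribute itself is assigned outside this hunk; a plausible sketch of the pattern, assuming it simply means "log on TP rank 0 only", would be:

import logging

logger = logging.getLogger(__name__)

class QuietOnNonZeroRanks:
    # Assumption for illustration: only the first tensor-parallel rank logs.
    def __init__(self, tp_rank: int):
        self.should_log = tp_rank == 0

    def log_backend_choice(self, backend: str) -> None:
        if self.should_log:
            logger.info(f"Attention backend not set. Use {backend} backend by default.")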
@@ -273,9 +292,10 @@ class ModelRunner:
  "flashmla",
  "cutlass_mla",
  ]:
- logger.info(
- f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
- )
+ if self.should_log:
+ logger.info(
+ f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
+ )
  else:
  raise ValueError(
  f"Invalid attention backend for MLA: {server_args.attention_backend}"
@@ -294,9 +314,10 @@ class ModelRunner:
  server_args.attention_backend = "triton"

  if server_args.enable_double_sparsity:
- logger.info(
- "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
- )
+ if self.should_log:
+ logger.info(
+ "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+ )
  server_args.attention_backend = "triton"
  server_args.disable_cuda_graph = True
  if server_args.ds_heavy_channel_type is None:
@@ -307,23 +328,26 @@ class ModelRunner:

  if self.is_multimodal:
  self.mem_fraction_static *= 0.90
- logger.info(
- f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
- f"because this is a multimodal model."
- )
- logger.info(
- "Automatically turn off --chunked-prefill-size for multimodal model."
- )
+ if self.should_log:
+ logger.info(
+ f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+ f"because this is a multimodal model."
+ )
+ logger.info(
+ "Automatically turn off --chunked-prefill-size for multimodal model."
+ )
  server_args.chunked_prefill_size = -1

  if not self.use_mla_backend:
  server_args.disable_chunked_prefix_cache = True
  elif self.page_size > 1:
- logger.info("Disable chunked prefix cache when page size > 1.")
+ if self.should_log:
+ logger.info("Disable chunked prefix cache when page size > 1.")
  server_args.disable_chunked_prefix_cache = True

  if not server_args.disable_chunked_prefix_cache:
- logger.info("Chunked prefix cache is turned on.")
+ if self.should_log:
+ logger.info("Chunked prefix cache is turned on.")

  def init_torch_distributed(self):
  logger.info("Init torch distributed begin.")
@@ -344,6 +368,8 @@ class ModelRunner:
  backend = "hccl"
  elif self.device == "cpu":
  backend = "gloo"
+ elif self.device == "npu":
+ backend = "hccl"

  before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
  if not self.server_args.enable_p2p_check:
@@ -359,18 +385,22 @@ class ModelRunner:
  # Only initialize the distributed environment on the target model worker.
  init_distributed_environment(
  backend=backend,
- world_size=self.tp_size,
- rank=self.tp_rank,
+ world_size=self.tp_size * self.pp_size,
+ rank=self.tp_size * self.pp_rank + self.tp_rank,
  local_rank=self.gpu_id,
  distributed_init_method=dist_init_method,
  timeout=self.server_args.dist_timeout,
  )
- initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+ initialize_model_parallel(
+ tensor_model_parallel_size=self.tp_size,
+ pipeline_model_parallel_size=self.pp_size,
+ )
  initialize_dp_attention(
  enable_dp_attention=self.server_args.enable_dp_attention,
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
  dp_size=self.server_args.dp_size,
+ pp_size=self.server_args.pp_size,
  )

  min_per_gpu_memory = get_available_gpu_memory(
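Note: with pipeline parallelism enabled, the global world size becomes tp_size * pp_size and each worker's global rank is tp_size * pp_rank + tp_rank, i.e. ranks are laid out stage-major. A small sketch of that arithmetic (sizes chosen only for illustration):

def global_rank(tp_rank: int, pp_rank: int, tp_size: int) -> int:
    # All TP ranks of pipeline stage 0 come first, then stage 1, and so on.
    return tp_size * pp_rank + tp_rank

tp_size, pp_size = 4, 2
world_size = tp_size * pp_size                               # 8 processes in total
assert global_rank(tp_rank=3, pp_rank=0, tp_size=tp_size) == 3
assert global_rank(tp_rank=0, pp_rank=1, tp_size=tp_size) == 4
assert global_rank(tp_rank=3, pp_rank=1, tp_size=tp_size) == 7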
@@ -410,9 +440,10 @@ class ModelRunner:
  torch.set_num_threads(1)
  if self.device == "cuda":
  if torch.cuda.get_device_capability()[0] < 8:
- logger.info(
- "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
- )
+ if self.should_log:
+ logger.info(
+ "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
+ )
  self.server_args.dtype = "float16"
  self.model_config.dtype = torch.float16
  if torch.cuda.get_device_capability()[1] < 5:
@@ -448,10 +479,11 @@ class ModelRunner:
  self.model.load_kv_cache_scales(
  self.server_args.quantization_param_path
  )
- logger.info(
- "Loaded KV cache scaling factors from %s",
- self.server_args.quantization_param_path,
- )
+ if self.should_log:
+ logger.info(
+ "Loaded KV cache scaling factors from %s",
+ self.server_args.quantization_param_path,
+ )
  else:
  raise RuntimeError(
  "Using FP8 KV cache and scaling factors provided but "
@@ -526,12 +558,7 @@ class ModelRunner:
  return iter

  def model_load_weights(model, iter):
- model.load_weights(iter)
- for _, module in self.model.named_modules():
- quant_method = getattr(module, "quant_method", None)
- if quant_method is not None:
- with device_loading_context(module, target_device):
- quant_method.process_weights_after_loading(module)
+ DefaultModelLoader.load_weights_and_postprocess(model, iter, target_device)
  return model

  with set_default_torch_dtype(self.model_config.dtype):
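Note: the inline weight-postprocessing loop is replaced by a call to the new DefaultModelLoader.load_weights_and_postprocess static method (added in sglang/srt/model_loader/loader.py further down), so the weight-update path and the initial load now share one code path. A minimal caller, sketched with a hypothetical reload_weights helper and placeholder model/weights/device objects:

import torch

from sglang.srt.model_loader.loader import DefaultModelLoader

def reload_weights(model: torch.nn.Module, weight_iter, target_device: torch.device):
    # Load the checkpoint tensors, then let each module's quant_method repack or
    # re-quantize its parameters on the target device.
    DefaultModelLoader.load_weights_and_postprocess(model, weight_iter, target_device)
    return model.eval()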
@@ -692,16 +719,23 @@ class ModelRunner:
  self.device, self.gpu_id, distributed=self.tp_size > 1
  )
  if self.use_mla_backend:
+ num_layers = (
+ self.model_config.num_hidden_layers
+ if not self.is_draft_worker
+ else self.model_config.hf_config.num_nextn_predict_layers
+ )
+ # FIXME: pipeline parallelism is not compatible with mla backend
+ assert self.pp_size == 1
  cell_size = (
  (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
- * self.model_config.num_hidden_layers
+ * num_layers
  * torch._utils._element_size(self.kv_cache_dtype)
  )
  else:
  cell_size = (
  self.model_config.get_num_kv_heads(get_attention_tp_size())
  * self.model_config.head_dim
- * self.model_config.num_hidden_layers
+ * self.num_effective_layers
  * 2
  * torch._utils._element_size(self.kv_cache_dtype)
  )
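Note: cell_size is the per-token KV-cache footprint in bytes used for memory profiling. MLA stores a single compressed vector of width kv_lora_rank + qk_rope_head_dim per layer, while standard MHA stores K and V per head per layer (hence the factor 2), and with PP only the layers owned by this stage (num_effective_layers) are counted. A worked example with made-up shapes:

import torch

elem = torch.finfo(torch.bfloat16).bits // 8          # 2 bytes per element

# MLA-style cache (hypothetical DeepSeek-like shapes): one latent vector per layer.
kv_lora_rank, qk_rope_head_dim, num_layers = 512, 64, 61
mla_cell = (kv_lora_rank + qk_rope_head_dim) * num_layers * elem         # 70,272 bytes/token

# MHA-style cache: K and V per head, per local layer (factor 2).
kv_heads, head_dim, num_effective_layers = 8, 128, 16
mha_cell = kv_heads * head_dim * num_effective_layers * 2 * elem         # 65,536 bytes/token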
@@ -809,9 +843,15 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  kv_lora_rank=self.model_config.kv_lora_rank,
  qk_rope_head_dim=self.model_config.qk_rope_head_dim,
- layer_num=self.model_config.num_hidden_layers,
+ layer_num=(
+ self.model_config.num_hidden_layers
+ if not self.is_draft_worker
+ else self.model_config.hf_config.num_nextn_predict_layers
+ ), # PP is not compatible with mla backend
  device=self.device,
  enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
  )
  elif self.server_args.enable_double_sparsity:
  self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -820,10 +860,12 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
- layer_num=self.model_config.num_hidden_layers,
+ layer_num=self.num_effective_layers,
  device=self.device,
  heavy_channel_num=self.server_args.ds_heavy_channel_num,
  enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
  )
  else:
  self.token_to_kv_pool = MHATokenToKVPool(
@@ -832,9 +874,11 @@ class ModelRunner:
  dtype=self.kv_cache_dtype,
  head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
- layer_num=self.model_config.num_hidden_layers,
+ layer_num=self.num_effective_layers,
  device=self.device,
  enable_memory_saver=self.server_args.enable_memory_saver,
+ start_layer=self.start_layer,
+ end_layer=self.end_layer,
  )

  if self.token_to_kv_pool_allocator is None:
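Note: all three KV pools are now sized by num_effective_layers and carry start_layer/end_layer, so a pipeline stage only allocates cache for the layers it owns (start_layer/end_layer come from the model when it defines them, see the initialize() hunk above). A rough sketch of what a contiguous layer partition could look like; the even split below is an assumption for illustration, not necessarily the policy the models use:

def partition_layers(num_hidden_layers: int, pp_rank: int, pp_size: int) -> tuple[int, int]:
    # Even split across pipeline stages; the last stage absorbs the remainder.
    per_stage = num_hidden_layers // pp_size
    start = pp_rank * per_stage
    end = num_hidden_layers if pp_rank == pp_size - 1 else start + per_stage
    return start, end

# 32 layers over 3 stages -> (0, 10), (10, 20), (20, 32)
assert [partition_layers(32, r, 3) for r in range(3)] == [(0, 10), (10, 20), (20, 32)]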
@@ -918,8 +962,10 @@ class ModelRunner:

  self.attn_backend = FlashMLABackend(self)
  elif self.server_args.attention_backend == "fa3":
- assert torch.cuda.get_device_capability()[0] >= 9, (
- "FlashAttention v3 Backend requires SM>=90. "
+ assert (
+ torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
+ ) or torch.cuda.get_device_capability()[0] == 9, (
+ "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
  "Please use `--attention-backend flashinfer`."
  )
  from sglang.srt.layers.attention.flashattention_backend import (
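Note: the relaxed assertion admits SM 8.x GPUs for non-MLA models in addition to SM 9.x, instead of requiring SM>=90 outright. Restated as a predicate:

import torch

def fa3_supported(use_mla_backend: bool) -> bool:
    major, _ = torch.cuda.get_device_capability()
    # SM 8.x (e.g. A100) is accepted only without the MLA backend; SM 9.x (e.g. H100) always is.
    return (major == 8 and not use_mla_backend) or major == 9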
@@ -945,7 +991,7 @@ class ModelRunner:
  with open(self.server_args.ds_channel_config_path, "r") as f:
  channel_config = json.load(f)

- for i in range(self.model_config.num_hidden_layers):
+ for i in range(self.start_layer, self.end_layer):
  key = "model.layers." + str(i) + ".self_attn" + selected_channel
  self.sorted_channels.append(
  torch.tensor(channel_config[key])[
@@ -979,70 +1025,89 @@ class ModelRunner:
  )

  def apply_torch_tp(self):
- logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+ if self.should_log:
+ logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
  from sglang.srt.model_parallel import tensor_parallel

  device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
  tensor_parallel(self.model, device_mesh)

- def forward_decode(self, forward_batch: ForwardBatch):
+ def forward_decode(
+ self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+ ) -> LogitsProcessorOutput:
  self.attn_backend.init_forward_metadata(forward_batch)
+ # FIXME: add pp_proxy_tensors arg to all models
+ kwargs = {}
+ if self.support_pp:
+ kwargs["pp_proxy_tensors"] = pp_proxy_tensors
  return self.model.forward(
- forward_batch.input_ids, forward_batch.positions, forward_batch
+ forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
  )

  def forward_extend(
- self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
- ):
+ self,
+ forward_batch: ForwardBatch,
+ skip_attn_backend_init: bool = False,
+ pp_proxy_tensors=None,
+ ) -> LogitsProcessorOutput:
  if not skip_attn_backend_init:
  self.attn_backend.init_forward_metadata(forward_batch)

- if self.is_generation:
- if forward_batch.input_embeds is None:
- return self.model.forward(
- forward_batch.input_ids, forward_batch.positions, forward_batch
- )
- else:
- return self.model.forward(
- forward_batch.input_ids,
- forward_batch.positions,
- forward_batch,
- input_embeds=forward_batch.input_embeds.bfloat16(),
- )
- else:
- # Only embedding models have get_embedding parameter
- return self.model.forward(
- forward_batch.input_ids,
- forward_batch.positions,
- forward_batch,
- get_embedding=True,
- )
+ kwargs = {}
+ if self.support_pp:
+ kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+ if forward_batch.input_embeds is not None:
+ kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
+ if not self.is_generation:
+ kwargs["get_embedding"] = True
+ return self.model.forward(
+ forward_batch.input_ids,
+ forward_batch.positions,
+ forward_batch,
+ **kwargs,
+ )

- def forward_idle(self, forward_batch: ForwardBatch):
+ def forward_idle(
+ self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+ ) -> LogitsProcessorOutput:
+ kwargs = {}
+ if self.support_pp:
+ kwargs["pp_proxy_tensors"] = pp_proxy_tensors
  return self.model.forward(
- forward_batch.input_ids, forward_batch.positions, forward_batch
+ forward_batch.input_ids,
+ forward_batch.positions,
+ forward_batch,
+ **kwargs,
  )

  def forward(
- self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
- ) -> LogitsProcessorOutput:
- if (
+ self,
+ forward_batch: ForwardBatch,
+ skip_attn_backend_init: bool = False,
+ pp_proxy_tensors: Optional[PPProxyTensors] = None,
+ ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+ can_run_cuda_graph = bool(
  forward_batch.forward_mode.is_cuda_graph()
  and self.cuda_graph_runner
  and self.cuda_graph_runner.can_run(forward_batch)
- ):
+ )
+ if can_run_cuda_graph:
  return self.cuda_graph_runner.replay(
- forward_batch, skip_attn_backend_init=skip_attn_backend_init
+ forward_batch,
+ skip_attn_backend_init=skip_attn_backend_init,
+ pp_proxy_tensors=pp_proxy_tensors,
  )

  if forward_batch.forward_mode.is_decode():
- return self.forward_decode(forward_batch)
+ return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
  elif forward_batch.forward_mode.is_extend():
  return self.forward_extend(
- forward_batch, skip_attn_backend_init=skip_attn_backend_init
+ forward_batch,
+ skip_attn_backend_init=skip_attn_backend_init,
+ pp_proxy_tensors=pp_proxy_tensors,
  )
  elif forward_batch.forward_mode.is_idle():
- return self.forward_idle(forward_batch)
+ return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
  else:
  raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
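Note: all three forward_* paths now build a kwargs dict and only pass pp_proxy_tensors when the loaded model declares it (self.support_pp), which keeps models that have not yet been updated for pipeline parallelism working. A stripped-down sketch of the same dispatch pattern, using an illustrative runner rather than the real ModelRunner:

import inspect

class TinyRunner:
    def __init__(self, model, is_generation: bool = True):
        self.model = model
        self.is_generation = is_generation
        self.support_pp = (
            "pp_proxy_tensors" in inspect.signature(model.forward).parameters
        )

    def forward_extend(self, batch, pp_proxy_tensors=None):
        kwargs = {}
        if self.support_pp:
            # Only pipeline-aware models receive the proxy tensors.
            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
        if getattr(batch, "input_embeds", None) is not None:
            kwargs["input_embeds"] = batch.input_embeds
        if not self.is_generation:
            kwargs["get_embedding"] = True
        return self.model.forward(batch.input_ids, batch.positions, batch, **kwargs)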

@@ -1080,7 +1145,9 @@ class ModelRunner:
  [self.sample(values, forward_batch) for values in logits_output],
  axis=-1,
  )
-
+ sampling_info = forward_batch.sampling_info
+ if sampling_info.thinking_budgets is not None:
+ sampling_info.apply_thinking_budgets(logits_output.next_token_logits)
  self._preprocess_logits(logits_output, forward_batch.sampling_info)

  # Sample the next tokens
@@ -1091,6 +1158,8 @@ class ModelRunner:
  forward_batch.top_logprobs_nums,
  forward_batch.token_ids_logprobs,
  )
+ if sampling_info.thinking_budgets is not None:
+ sampling_info.update_thinking_budgets(next_token_ids)
  return next_token_ids

  @property
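Note: apply_thinking_budgets and update_thinking_budgets live in sglang/srt/sampling/sampling_batch_info.py (file 105 above) and are not shown in this hunk. As a hedged, self-contained illustration of the general idea (once a request exhausts its thinking budget, force the token that closes the reasoning section), a sketch might look like this; the token id and tensor layout are assumptions, not SGLang's actual implementation:

import torch

END_THINK_TOKEN_ID = 128003  # hypothetical id of the token that ends the thinking section

def apply_thinking_budgets(next_token_logits: torch.Tensor, budgets: torch.Tensor) -> None:
    # For requests whose remaining budget is exhausted, mask every logit except
    # the end-of-thinking token so it is selected with probability 1.
    exhausted = budgets <= 0
    next_token_logits[exhausted] = float("-inf")
    next_token_logits[exhausted, END_THINK_TOKEN_ID] = 0.0

def update_thinking_budgets(budgets: torch.Tensor, next_token_ids: torch.Tensor) -> None:
    # Spend one unit of budget for every token decoded while still thinking.
    budgets -= (next_token_ids != END_THINK_TOKEN_ID).to(budgets.dtype)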
sglang/srt/model_loader/loader.py CHANGED
@@ -374,20 +374,27 @@ class DefaultModelLoader(BaseModelLoader):
  self.load_config,
  )

- model.load_weights(self._get_all_weights(model_config, model))
+ self.load_weights_and_postprocess(
+ model, self._get_all_weights(model_config, model), target_device
+ )

- for _, module in model.named_modules():
- quant_method = getattr(module, "quant_method", None)
- if quant_method is not None:
- # When quant methods need to process weights after loading
- # (for repacking, quantizing, etc), they expect parameters
- # to be on the global target device. This scope is for the
- # case where cpu offloading is used, where we will move the
- # parameters onto device for processing and back off after.
- with device_loading_context(module, target_device):
- quant_method.process_weights_after_loading(module)
  return model.eval()

+ @staticmethod
+ def load_weights_and_postprocess(model, weights, target_device):
+ model.load_weights(weights)
+
+ for _, module in model.named_modules():
+ quant_method = getattr(module, "quant_method", None)
+ if quant_method is not None:
+ # When quant methods need to process weights after loading
+ # (for repacking, quantizing, etc), they expect parameters
+ # to be on the global target device. This scope is for the
+ # case where cpu offloading is used, where we will move the
+ # parameters onto device for processing and back off after.
+ with device_loading_context(module, target_device):
+ quant_method.process_weights_after_loading(module)
+

  class LayeredModelLoader(DefaultModelLoader):
  """Model loader that loads weights layer by layer so that one can quantize a
sglang/srt/models/clip.py CHANGED
@@ -151,20 +151,20 @@ class CLIPEncoderLayer(nn.Module):
  self.layer_norm1 = norm_layer(config.hidden_size)
  self.layer_norm2 = norm_layer(config.hidden_size)
  if attn_implementation == "sdpa":
- use_context_forward = False
+ qkv_backend = "sdpa"
  softmax_in_single_precision = False
  elif attn_implementation == "flash_attention_2":
+ qkv_backend = "triton_attn"
  softmax_in_single_precision = False
- use_context_forward = True
  elif attn_implementation == "eager":
+ qkv_backend = "sdpa"
  softmax_in_single_precision = True
- use_context_forward = False
  self.self_attn = VisionAttention(
  embed_dim=config.hidden_size,
  num_heads=config.num_attention_heads,
  projection_size=config.hidden_size,
  use_qkv_parallel=True,
- use_context_forward=use_context_forward,
+ qkv_backend=qkv_backend,
  softmax_in_single_precision=softmax_in_single_precision,
  flatten_batch=True,
  quant_config=quant_config,
@@ -532,7 +532,7 @@ class VisionTransformerBlock(nn.Module):
  num_heads=num_heads,
  projection_size=dim,
  use_qkv_parallel=True,
- use_context_forward=False,
+ qkv_backend="sdpa",
  softmax_in_single_precision=False,
  dropout=attn_drop,
  )
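Note: VisionAttention's boolean use_context_forward flag is replaced by an explicit qkv_backend string (the corresponding change is in sglang/srt/layers/attention/vision.py, file 37 above). The mapping CLIP now uses can be restated as a small helper; this is a summary of the branches above, not new behavior:

def pick_qkv_backend(attn_implementation: str) -> tuple[str, bool]:
    """Return (qkv_backend, softmax_in_single_precision) for a CLIP encoder layer."""
    if attn_implementation == "sdpa":
        return "sdpa", False
    if attn_implementation == "flash_attention_2":
        return "triton_attn", False
    if attn_implementation == "eager":
        return "sdpa", True
    raise ValueError(f"Unknown attention implementation: {attn_implementation}")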