sglang 0.4.6__py3-none-any.whl → 0.4.6.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/check_env.py +3 -3
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/kimi_vl.py +38 -0
  5. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  6. sglang/srt/configs/model_config.py +15 -0
  7. sglang/srt/conversation.py +122 -1
  8. sglang/srt/disaggregation/decode.py +8 -2
  9. sglang/srt/disaggregation/fake/__init__.py +1 -0
  10. sglang/srt/disaggregation/fake/conn.py +88 -0
  11. sglang/srt/disaggregation/prefill.py +12 -3
  12. sglang/srt/disaggregation/utils.py +16 -2
  13. sglang/srt/entrypoints/engine.py +52 -21
  14. sglang/srt/entrypoints/http_server.py +27 -2
  15. sglang/srt/function_call_parser.py +97 -0
  16. sglang/srt/hf_transformers_utils.py +2 -0
  17. sglang/srt/layers/attention/cutlass_mla_backend.py +278 -0
  18. sglang/srt/layers/attention/flashinfer_backend.py +107 -82
  19. sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -16
  20. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  21. sglang/srt/layers/attention/utils.py +1 -1
  22. sglang/srt/layers/dp_attention.py +5 -2
  23. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +1 -3
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -8
  41. sglang/srt/layers/moe/fused_moe_triton/layer.py +15 -17
  42. sglang/srt/layers/quantization/__init__.py +2 -2
  43. sglang/srt/layers/quantization/deep_gemm.py +1 -1
  44. sglang/srt/layers/quantization/fp8.py +20 -22
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/utils.py +35 -0
  47. sglang/srt/lora/layers.py +35 -9
  48. sglang/srt/lora/lora_manager.py +84 -35
  49. sglang/srt/managers/data_parallel_controller.py +52 -34
  50. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  51. sglang/srt/managers/schedule_batch.py +34 -15
  52. sglang/srt/managers/scheduler.py +273 -67
  53. sglang/srt/managers/scheduler_output_processor_mixin.py +26 -10
  54. sglang/srt/managers/tp_worker.py +52 -17
  55. sglang/srt/managers/tp_worker_overlap_thread.py +18 -7
  56. sglang/srt/mem_cache/memory_pool.py +70 -36
  57. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  58. sglang/srt/model_executor/forward_batch_info.py +31 -1
  59. sglang/srt/model_executor/model_runner.py +123 -58
  60. sglang/srt/models/deepseek_nextn.py +1 -257
  61. sglang/srt/models/deepseek_v2.py +78 -18
  62. sglang/srt/models/kimi_vl.py +308 -0
  63. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  64. sglang/srt/models/llama.py +92 -30
  65. sglang/srt/models/llama4.py +2 -1
  66. sglang/srt/models/llama_eagle.py +4 -1
  67. sglang/srt/models/llama_eagle3.py +4 -1
  68. sglang/srt/models/qwen2_moe.py +8 -3
  69. sglang/srt/models/qwen2_vl.py +0 -12
  70. sglang/srt/models/qwen3_moe.py +8 -3
  71. sglang/srt/openai_api/adapter.py +49 -8
  72. sglang/srt/openai_api/protocol.py +13 -1
  73. sglang/srt/reasoning_parser.py +25 -1
  74. sglang/srt/server_args.py +83 -24
  75. sglang/srt/speculative/eagle_worker.py +3 -2
  76. sglang/srt/utils.py +91 -9
  77. sglang/test/runners.py +4 -0
  78. sglang/test/send_one.py +84 -28
  79. sglang/test/test_utils.py +67 -0
  80. sglang/version.py +1 -1
  81. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/METADATA +5 -4
  82. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/RECORD +85 -60
  83. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/WHEEL +1 -1
  84. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/licenses/LICENSE +0 -0
  85. {sglang-0.4.6.dist-info → sglang-0.4.6.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py

@@ -16,6 +16,7 @@
 from __future__ import annotations

 import bisect
+import inspect
 import os
 from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
@@ -33,12 +34,14 @@ from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
+    PPProxyTensors,
 )
 from sglang.srt.patch_torch import monkey_patch_torch_compile
 from sglang.srt.utils import (
     get_available_gpu_memory,
     get_device_memory_capacity,
     is_hip,
+    rank0_log,
 )

 if TYPE_CHECKING:
@@ -135,7 +138,7 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):

     gpu_mem = get_device_memory_capacity()
     # Batch size of each rank will not become so large when DP is on
-    if gpu_mem is not None and gpu_mem > 81920 and server_args.dp_size == 1:
+    if gpu_mem is not None and gpu_mem > 96 * 1024:
         capture_bs += list(range(160, 257, 8))

     if max(capture_bs) > model_runner.req_to_token_pool.size:
@@ -188,10 +191,11 @@ class CudaGraphRunner:
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
+        self.pp_size = model_runner.server_args.pp_size

         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
-
+        rank0_log(f"Capture cuda graph bs {self.capture_bs}")
         self.capture_forward_mode = ForwardMode.DECODE
         self.capture_hidden_mode = CaptureHiddenMode.NULL
         self.num_tokens_per_bs = 1
@@ -220,6 +224,9 @@ class CudaGraphRunner:
         if self.enable_torch_compile:
             set_torch_compile_config()

+        if self.model_runner.server_args.lora_paths is not None:
+            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
+
         # Graph inputs
         with torch.device("cuda"):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
@@ -231,6 +238,19 @@ class CudaGraphRunner:
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)

+            # pipeline parallelism
+            if self.pp_size > 1:
+                self.pp_proxy_tensors = {
+                    "hidden_states": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                    "residual": torch.zeros(
+                        (self.max_bs, self.model_runner.model_config.hidden_size),
+                        dtype=torch.bfloat16,
+                    ),
+                }
+
             # Speculative_inference
             if (
                 model_runner.spec_algorithm.is_eagle3()
@@ -381,6 +401,12 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]

+        # pipeline parallelism
+        if self.pp_size > 1:
+            pp_proxy_tensors = PPProxyTensors(
+                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
+            )
+
         if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -403,6 +429,13 @@ class CudaGraphRunner:
         self.capture_hidden_mode = (
             spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
         )
+        if self.model_runner.server_args.lora_paths is not None:
+            # Currently, if the lora_path in `lora_paths` is None, the lora backend will use a
+            # different logic to handle lora, so we need to set `lora_paths` to a list of non-None
+            # values if lora is enabled.
+            lora_paths = [next(iter(self.model_runner.server_args.lora_paths))] * bs
+        else:
+            lora_paths = None

         forward_batch = ForwardBatch(
             forward_mode=self.capture_forward_mode,
@@ -424,8 +457,12 @@ class CudaGraphRunner:
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
             capture_hidden_mode=self.capture_hidden_mode,
+            lora_paths=lora_paths,
         )

+        if lora_paths is not None:
+            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)
+
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
             bs,
@@ -442,8 +479,20 @@ class CudaGraphRunner:
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None

-            logits_output = forward(input_ids, forward_batch.positions, forward_batch)
-            return logits_output.next_token_logits, logits_output.hidden_states
+            kwargs = {}
+            if (
+                self.pp_size > 1
+                and "pp_proxy_tensors" in inspect.signature(forward).parameters
+            ):
+                kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+
+            logits_output_or_pp_proxy_tensors = forward(
+                input_ids,
+                forward_batch.positions,
+                forward_batch,
+                **kwargs,
+            )
+            return logits_output_or_pp_proxy_tensors

         for _ in range(2):
             torch.cuda.synchronize()
@@ -476,7 +525,11 @@
             self.capture_hidden_mode = hidden_mode_from_spec_info
             self.capture()

-    def replay_prepare(self, forward_batch: ForwardBatch):
+    def replay_prepare(
+        self,
+        forward_batch: ForwardBatch,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ):
         self.recapture_if_needed(forward_batch)

         raw_bs = forward_batch.batch_size
@@ -505,6 +558,11 @@
             self.seq_lens_cpu.fill_(1)
             self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)

+        if pp_proxy_tensors:
+            for key in self.pp_proxy_tensors.keys():
+                dim = pp_proxy_tensors[key].shape[0]
+                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])
+
         if self.is_encoder_decoder:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
@@ -533,10 +591,13 @@
         self.bs = bs

     def replay(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
         if not skip_attn_backend_init:
-            self.replay_prepare(forward_batch)
+            self.replay_prepare(forward_batch, pp_proxy_tensors)
         else:
             # In speculative decoding, these two fields are still needed.
             self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
@@ -544,17 +605,19 @@

         # Replay
         self.graphs[self.bs].replay()
-        next_token_logits, hidden_states = self.output_buffers[self.bs]
-
-        logits_output = LogitsProcessorOutput(
-            next_token_logits=next_token_logits[: self.raw_num_token],
-            hidden_states=(
-                hidden_states[: self.raw_num_token]
-                if hidden_states is not None
-                else None
-            ),
-        )
-        return logits_output
+        output = self.output_buffers[self.bs]
+        if isinstance(output, LogitsProcessorOutput):
+            return LogitsProcessorOutput(
+                next_token_logits=output.next_token_logits[: self.raw_num_token],
+                hidden_states=(
+                    output.hidden_states[: self.raw_num_token]
+                    if output.hidden_states is not None
+                    else None
+                ),
+            )
+        else:
+            assert isinstance(output, PPProxyTensors)
+            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

     def get_spec_info(self, num_tokens: int):
         spec_info = None
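The capture path above only forwards pp_proxy_tensors into the model when the model's forward actually declares that keyword, which it checks with inspect.signature. A minimal sketch of that guard pattern, with hypothetical stand-in forward functions that are not part of this diff:

import inspect

def forward_legacy(input_ids, positions, forward_batch):
    # Hypothetical model forward that predates pipeline parallelism.
    return "logits"

def forward_pp(input_ids, positions, forward_batch, pp_proxy_tensors=None):
    # Hypothetical model forward that accepts proxy tensors from the previous stage.
    return pp_proxy_tensors or "logits"

def call_model(forward, input_ids, positions, forward_batch, pp_proxy_tensors):
    # Same guard as in the capture code: pass the kwarg only if the callee declares it,
    # so models that were never updated for PP keep working unchanged.
    kwargs = {}
    if "pp_proxy_tensors" in inspect.signature(forward).parameters:
        kwargs["pp_proxy_tensors"] = pp_proxy_tensors
    return forward(input_ids, positions, forward_batch, **kwargs)

print(call_model(forward_legacy, [1], [0], None, {"hidden_states": None}))  # "logits"
print(call_model(forward_pp, [1], [0], None, {"hidden_states": None}))      # the proxy dict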
sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations

 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Dict, List, Optional, Union

 import torch
 import triton
@@ -585,6 +585,36 @@ class ForwardBatch:
         self.prepare_chunked_kv_indices(device)


+class PPProxyTensors:
+    # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103
+    tensors: Dict[str, torch.Tensor]
+
+    def __init__(self, tensors):
+        # manually define this function, so that
+        # Dynamo knows `IntermediateTensors()` comes from this file.
+        # Otherwise, dataclass will generate this function by evaluating
+        # a string, and we will lose the information about the source file.
+        self.tensors = tensors
+
+    def __getitem__(self, key: Union[str, slice]):
+        if isinstance(key, str):
+            return self.tensors[key]
+        elif isinstance(key, slice):
+            return self.__class__({k: v[key] for k, v in self.tensors.items()})
+
+    def __setitem__(self, key: str, value: torch.Tensor):
+        self.tensors[key] = value
+
+    def __len__(self):
+        return len(self.tensors)
+
+    def __eq__(self, other: object):
+        return isinstance(other, self.__class__) and self
+
+    def __repr__(self) -> str:
+        return f"PPProxyTensors(tensors={self.tensors})"
+
+
 def compute_position_triton(
     extend_prefix_lens: torch.Tensor, extend_seq_lens: torch.Tensor, extend_seq_lens_sum
 ):
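PPProxyTensors is a thin dict-of-tensors wrapper whose slice indexing slices every tensor and returns a new wrapper; this is what CudaGraphRunner.replay relies on when trimming padded cuda-graph outputs back to the real batch size. A short usage sketch (shapes are made up; the import path is the one introduced by this diff):

import torch

from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

# Two per-token activations handed from one pipeline stage to the next.
proxy = PPProxyTensors(
    {
        "hidden_states": torch.zeros(8, 4096, dtype=torch.bfloat16),
        "residual": torch.zeros(8, 4096, dtype=torch.bfloat16),
    }
)
print(len(proxy))                    # 2 entries
print(proxy["hidden_states"].shape)  # torch.Size([8, 4096])

trimmed = proxy[:3]                  # slicing returns a new PPProxyTensors
print(trimmed["residual"].shape)     # torch.Size([3, 4096])

proxy["extra"] = torch.ones(8)       # __setitem__ stores another tensor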
sglang/srt/model_executor/model_runner.py

@@ -13,8 +13,10 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""

+import collections
 import datetime
 import gc
+import inspect
 import json
 import logging
 import os
@@ -59,7 +61,7 @@ from sglang.srt.mem_cache.memory_pool import (
 )
 from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import (
     DefaultModelLoader,
@@ -110,6 +112,8 @@ class ModelRunner:
         gpu_id: int,
         tp_rank: int,
         tp_size: int,
+        pp_rank: int,
+        pp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
@@ -123,6 +127,8 @@ class ModelRunner:
         self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
+        self.pp_rank = pp_rank
+        self.pp_size = pp_size
         self.dist_port = nccl_port
         self.server_args = server_args
         self.is_draft_worker = is_draft_worker
@@ -148,24 +154,24 @@ class ModelRunner:
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
-                "sampling_backend": server_args.sampling_backend,
-                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "torchao_config": server_args.torchao_config,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "deepep_mode": server_args.deepep_mode,
+                "device": server_args.device,
+                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+                "disable_radix_cache": server_args.disable_radix_cache,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "enable_deepep_moe": server_args.enable_deepep_moe,
-                "deepep_mode": server_args.deepep_mode,
-                "device": server_args.device,
-                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
-                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
-                "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
                 "moe_dense_tp_size": server_args.moe_dense_tp_size,
-                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
-                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
                 "n_share_experts_fusion": server_args.n_share_experts_fusion,
-                "disable_chunked_prefix_cache": server_args.disable_chunked_prefix_cache,
+                "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
+                "torchao_config": server_args.torchao_config,
+                "sampling_backend": server_args.sampling_backend,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "use_mla_backend": self.use_mla_backend,
             }
         )
@@ -183,6 +189,11 @@ class ModelRunner:
         # If it is a draft model, tp_group can be different
         self.initialize(min_per_gpu_memory)

+        # temporary cached values
+        self.support_pp = (
+            "pp_proxy_tensors" in inspect.signature(self.model.forward).parameters
+        )
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -193,6 +204,12 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

+        self.start_layer = getattr(self.model, "start_layer", 0)
+        self.end_layer = getattr(
+            self.model, "end_layer", self.model_config.num_hidden_layers
+        )
+        self.num_effective_layers = self.end_layer - self.start_layer
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -271,6 +288,7 @@ class ModelRunner:
                 "fa3",
                 "triton",
                 "flashmla",
+                "cutlass_mla",
             ]:
                 logger.info(
                     f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
@@ -358,18 +376,22 @@ class ModelRunner:
         # Only initialize the distributed environment on the target model worker.
         init_distributed_environment(
             backend=backend,
-            world_size=self.tp_size,
-            rank=self.tp_rank,
+            world_size=self.tp_size * self.pp_size,
+            rank=self.tp_size * self.pp_rank + self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=dist_init_method,
             timeout=self.server_args.dist_timeout,
         )
-        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        initialize_model_parallel(
+            tensor_model_parallel_size=self.tp_size,
+            pipeline_model_parallel_size=self.pp_size,
+        )
         initialize_dp_attention(
             enable_dp_attention=self.server_args.enable_dp_attention,
             tp_rank=self.tp_rank,
             tp_size=self.tp_size,
             dp_size=self.server_args.dp_size,
+            pp_size=self.server_args.pp_size,
         )

         min_per_gpu_memory = get_available_gpu_memory(
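The new world-size and rank arithmetic places each pipeline stage's tensor-parallel group on a contiguous block of global ranks: world_size = tp_size * pp_size and rank = tp_size * pp_rank + tp_rank. A tiny worked example with hypothetical sizes (not taken from this diff):

tp_size, pp_size = 4, 2          # hypothetical parallelism sizes
world_size = tp_size * pp_size   # 8 processes in total

for pp_rank in range(pp_size):
    for tp_rank in range(tp_size):
        rank = tp_size * pp_rank + tp_rank
        print(f"pp_rank={pp_rank}, tp_rank={tp_rank} -> global rank {rank}")
# pp_rank=0 owns global ranks 0-3, pp_rank=1 owns global ranks 4-7.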
@@ -691,16 +713,23 @@ class ModelRunner:
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         if self.use_mla_backend:
+            num_layers = (
+                self.model_config.num_hidden_layers
+                if not self.is_draft_worker
+                else self.model_config.hf_config.num_nextn_predict_layers
+            )
+            # FIXME: pipeline parallelism is not compatible with mla backend
+            assert self.pp_size == 1
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
-                * self.model_config.num_hidden_layers
+                * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
-                * self.model_config.num_hidden_layers
+                * self.num_effective_layers
                 * 2
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
@@ -808,9 +837,15 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=(
+                    self.model_config.num_hidden_layers
+                    if not self.is_draft_worker
+                    else self.model_config.hf_config.num_nextn_predict_layers
+                ),  # PP is not compatible with mla backend
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
@@ -819,10 +854,12 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -831,9 +868,11 @@ class ModelRunner:
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
-                layer_num=self.model_config.num_hidden_layers,
+                layer_num=self.num_effective_layers,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
             )

         if self.token_to_kv_pool_allocator is None:
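Sizing the KV pools with self.num_effective_layers (plus the new start_layer/end_layer bounds) means each pipeline stage only allocates cache for the layers it actually executes. A rough back-of-the-envelope sketch of the per-token saving, assuming stages split the layers evenly; the model dimensions are made up and not from this diff:

import torch

num_hidden_layers, pp_size = 32, 2      # hypothetical model / parallelism sizes
head_num, head_dim = 8, 128
elem = torch._utils._element_size(torch.bfloat16)

num_effective_layers = num_hidden_layers // pp_size  # assuming an even split per stage

cell_full = head_num * head_dim * num_hidden_layers * 2 * elem    # K and V, all layers
cell_stage = head_num * head_dim * num_effective_layers * 2 * elem
print(cell_full, cell_stage)  # KV bytes per token: whole model vs. one pipeline stage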
@@ -917,8 +956,10 @@ class ModelRunner:

             self.attn_backend = FlashMLABackend(self)
         elif self.server_args.attention_backend == "fa3":
-            assert torch.cuda.get_device_capability()[0] >= 9, (
-                "FlashAttention v3 Backend requires SM>=90. "
+            assert (
+                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
+            ) or torch.cuda.get_device_capability()[0] == 9, (
+                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
                 "Please use `--attention-backend flashinfer`."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
@@ -926,6 +967,12 @@ class ModelRunner:
             )

             self.attn_backend = FlashAttentionBackend(self)
+        elif self.server_args.attention_backend == "cutlass_mla":
+            from sglang.srt.layers.attention.cutlass_mla_backend import (
+                CutlassMLABackend,
+            )
+
+            self.attn_backend = CutlassMLABackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -938,7 +985,7 @@ class ModelRunner:
         with open(self.server_args.ds_channel_config_path, "r") as f:
             channel_config = json.load(f)

-        for i in range(self.model_config.num_hidden_layers):
+        for i in range(self.start_layer, self.end_layer):
             key = "model.layers." + str(i) + ".self_attn" + selected_channel
             self.sorted_channels.append(
                 torch.tensor(channel_config[key])[
@@ -968,7 +1015,7 @@ class ModelRunner:
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
-            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+            f"mem usage={(before_mem - after_mem):.2f} GB. avail mem={after_mem:.2f} GB."
         )

     def apply_torch_tp(self):
@@ -978,64 +1025,82 @@ class ModelRunner:
         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)

-    def forward_decode(self, forward_batch: ForwardBatch):
+    def forward_decode(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
         self.attn_backend.init_forward_metadata(forward_batch)
+        # FIXME: add pp_proxy_tensors arg to all models
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids, forward_batch.positions, forward_batch, **kwargs
         )

     def forward_extend(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ):
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors=None,
+    ) -> LogitsProcessorOutput:
         if not skip_attn_backend_init:
             self.attn_backend.init_forward_metadata(forward_batch)

-        if self.is_generation:
-            if forward_batch.input_embeds is None:
-                return self.model.forward(
-                    forward_batch.input_ids, forward_batch.positions, forward_batch
-                )
-            else:
-                return self.model.forward(
-                    forward_batch.input_ids,
-                    forward_batch.positions,
-                    forward_batch,
-                    input_embeds=forward_batch.input_embeds.bfloat16(),
-                )
-        else:
-            # Only embedding models have get_embedding parameter
-            return self.model.forward(
-                forward_batch.input_ids,
-                forward_batch.positions,
-                forward_batch,
-                get_embedding=True,
-            )
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
+        if forward_batch.input_embeds is not None:
+            kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
+        if not self.is_generation:
+            kwargs["get_embedding"] = True
+        return self.model.forward(
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
+        )

-    def forward_idle(self, forward_batch: ForwardBatch):
+    def forward_idle(
+        self, forward_batch: ForwardBatch, pp_proxy_tensors=None
+    ) -> LogitsProcessorOutput:
+        kwargs = {}
+        if self.support_pp:
+            kwargs["pp_proxy_tensors"] = pp_proxy_tensors
         return self.model.forward(
-            forward_batch.input_ids, forward_batch.positions, forward_batch
+            forward_batch.input_ids,
+            forward_batch.positions,
+            forward_batch,
+            **kwargs,
         )

     def forward(
-        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
-    ) -> LogitsProcessorOutput:
-        if (
+        self,
+        forward_batch: ForwardBatch,
+        skip_attn_backend_init: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
+        can_run_cuda_graph = bool(
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
-        ):
+        )
+        if can_run_cuda_graph:
             return self.cuda_graph_runner.replay(
-                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
             )

         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(forward_batch)
+            return self.forward_decode(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         elif forward_batch.forward_mode.is_extend():
             return self.forward_extend(
-                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+                forward_batch,
+                skip_attn_backend_init=skip_attn_backend_init,
+                pp_proxy_tensors=pp_proxy_tensors,
             )
         elif forward_batch.forward_mode.is_idle():
-            return self.forward_idle(forward_batch)
+            return self.forward_idle(forward_batch, pp_proxy_tensors=pp_proxy_tensors)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
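With these changes ModelRunner.forward returns a LogitsProcessorOutput on the last pipeline stage and a PPProxyTensors on earlier stages, so callers have to branch on the result type. A hedged sketch of that branching; the import paths are assumed from the sglang source tree rather than stated in this diff:

from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.model_executor.forward_batch_info import PPProxyTensors

def handle_stage_output(output):
    if isinstance(output, LogitsProcessorOutput):
        # Last stage: logits are ready for sampling.
        return output.next_token_logits
    # Earlier stages: the raw tensors get shipped to the next stage instead,
    # e.g. {"hidden_states": ..., "residual": ...}.
    assert isinstance(output, PPProxyTensors)
    return output.tensors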