sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (105)
  1. sglang/bench_one_batch.py +113 -17
  2. sglang/srt/configs/model_config.py +35 -0
  3. sglang/srt/conversation.py +9 -5
  4. sglang/srt/disaggregation/base/conn.py +5 -2
  5. sglang/srt/disaggregation/decode.py +6 -1
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
  7. sglang/srt/disaggregation/mooncake/conn.py +243 -135
  8. sglang/srt/disaggregation/prefill.py +2 -0
  9. sglang/srt/distributed/parallel_state.py +11 -9
  10. sglang/srt/entrypoints/context.py +244 -0
  11. sglang/srt/entrypoints/engine.py +4 -3
  12. sglang/srt/entrypoints/harmony_utils.py +370 -0
  13. sglang/srt/entrypoints/http_server.py +71 -0
  14. sglang/srt/entrypoints/openai/protocol.py +227 -1
  15. sglang/srt/entrypoints/openai/serving_chat.py +278 -42
  16. sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
  17. sglang/srt/entrypoints/openai/tool_server.py +174 -0
  18. sglang/srt/entrypoints/tool.py +87 -0
  19. sglang/srt/eplb/expert_location.py +5 -1
  20. sglang/srt/function_call/harmony_tool_parser.py +130 -0
  21. sglang/srt/hf_transformers_utils.py +30 -3
  22. sglang/srt/jinja_template_utils.py +8 -1
  23. sglang/srt/layers/attention/aiter_backend.py +5 -8
  24. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
  25. sglang/srt/layers/attention/triton_backend.py +85 -14
  26. sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
  28. sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
  29. sglang/srt/layers/attention/vision.py +13 -5
  30. sglang/srt/layers/communicator.py +21 -4
  31. sglang/srt/layers/dp_attention.py +12 -0
  32. sglang/srt/layers/linear.py +2 -7
  33. sglang/srt/layers/moe/cutlass_moe.py +20 -6
  34. sglang/srt/layers/moe/ep_moe/layer.py +77 -73
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +416 -35
  37. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
  38. sglang/srt/layers/moe/topk.py +12 -3
  39. sglang/srt/layers/moe/utils.py +16 -0
  40. sglang/srt/layers/quantization/__init__.py +22 -0
  41. sglang/srt/layers/quantization/fp4.py +557 -0
  42. sglang/srt/layers/quantization/fp8.py +3 -6
  43. sglang/srt/layers/quantization/fp8_utils.py +29 -0
  44. sglang/srt/layers/quantization/modelopt_quant.py +259 -64
  45. sglang/srt/layers/quantization/mxfp4.py +651 -0
  46. sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
  47. sglang/srt/layers/quantization/quark/__init__.py +0 -0
  48. sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
  49. sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  50. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
  51. sglang/srt/layers/quantization/quark/utils.py +107 -0
  52. sglang/srt/layers/quantization/unquant.py +60 -6
  53. sglang/srt/layers/quantization/w4afp8.py +1 -1
  54. sglang/srt/layers/rotary_embedding.py +225 -1
  55. sglang/srt/layers/utils.py +9 -0
  56. sglang/srt/layers/vocab_parallel_embedding.py +8 -3
  57. sglang/srt/lora/lora_manager.py +70 -14
  58. sglang/srt/lora/lora_registry.py +3 -2
  59. sglang/srt/lora/mem_pool.py +43 -5
  60. sglang/srt/managers/cache_controller.py +55 -30
  61. sglang/srt/managers/detokenizer_manager.py +1 -1
  62. sglang/srt/managers/io_struct.py +15 -3
  63. sglang/srt/managers/mm_utils.py +5 -11
  64. sglang/srt/managers/schedule_batch.py +28 -7
  65. sglang/srt/managers/scheduler.py +26 -12
  66. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
  67. sglang/srt/managers/scheduler_recv_skipper.py +37 -0
  68. sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
  69. sglang/srt/managers/template_manager.py +35 -1
  70. sglang/srt/managers/tokenizer_manager.py +24 -6
  71. sglang/srt/managers/tp_worker.py +3 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
  73. sglang/srt/mem_cache/hiradix_cache.py +53 -5
  74. sglang/srt/mem_cache/memory_pool_host.py +1 -1
  75. sglang/srt/mem_cache/multimodal_cache.py +33 -13
  76. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
  77. sglang/srt/model_executor/cuda_graph_runner.py +7 -6
  78. sglang/srt/model_executor/forward_batch_info.py +35 -14
  79. sglang/srt/model_executor/model_runner.py +19 -2
  80. sglang/srt/model_loader/weight_utils.py +10 -0
  81. sglang/srt/models/bailing_moe.py +425 -0
  82. sglang/srt/models/deepseek_v2.py +72 -33
  83. sglang/srt/models/ernie4.py +426 -0
  84. sglang/srt/models/ernie4_eagle.py +203 -0
  85. sglang/srt/models/gemma3n_mm.py +39 -0
  86. sglang/srt/models/glm4_moe.py +24 -12
  87. sglang/srt/models/gpt_oss.py +1134 -0
  88. sglang/srt/models/qwen2.py +6 -0
  89. sglang/srt/models/qwen2_moe.py +6 -0
  90. sglang/srt/models/qwen3_moe.py +32 -6
  91. sglang/srt/models/step3_vl.py +9 -0
  92. sglang/srt/models/transformers.py +2 -5
  93. sglang/srt/multimodal/processors/step3_vl.py +3 -1
  94. sglang/srt/reasoning_parser.py +18 -39
  95. sglang/srt/server_args.py +142 -7
  96. sglang/srt/two_batch_overlap.py +157 -5
  97. sglang/srt/utils.py +38 -2
  98. sglang/test/runners.py +2 -2
  99. sglang/test/test_utils.py +1 -1
  100. sglang/version.py +1 -1
  101. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/METADATA +16 -14
  102. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/RECORD +105 -84
  103. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/WHEEL +0 -0
  104. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/licenses/LICENSE +0 -0
  105. {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc0.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -43,6 +43,7 @@ I'm going to the park
 """
 
 import argparse
+import copy
 import dataclasses
 import itertools
 import json
@@ -84,12 +85,14 @@ class BenchArgs:
     batch_size: Tuple[int] = (1,)
     input_len: Tuple[int] = (1024,)
     output_len: Tuple[int] = (16,)
+    prompt_filename: str = ""
     result_filename: str = "result.jsonl"
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
     log_decode_step: int = 0
     profile: bool = False
+    profile_record_shapes: bool = False
     profile_filename_prefix: str = "profile"
 
     @staticmethod
@@ -104,6 +107,9 @@ class BenchArgs:
         parser.add_argument(
             "--output-len", type=int, nargs="+", default=BenchArgs.output_len
         )
+        parser.add_argument(
+            "--prompt-filename", type=str, default=BenchArgs.prompt_filename
+        )
         parser.add_argument(
             "--result-filename", type=str, default=BenchArgs.result_filename
         )
@@ -118,6 +124,11 @@
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
+        parser.add_argument(
+            "--profile-record-shapes",
+            action="store_true",
+            help="Record tensor shapes in profiling results.",
+        )
         parser.add_argument(
             "--profile-filename-prefix",
             type=str,
@@ -165,12 +176,16 @@ def load_model(server_args, port_args, tp_rank):
     return model_runner, tokenizer
 
 
-def prepare_inputs_for_correctness_test(bench_args, tokenizer):
-    prompts = [
-        "The capital of France is",
-        "The capital of the United Kindom is",
-        "Today is a sunny day and I like",
-    ]
+def prepare_inputs_for_correctness_test(bench_args, tokenizer, custom_prompts):
+    prompts = (
+        custom_prompts
+        if custom_prompts
+        else [
+            "The capital of France is",
+            "The capital of the United Kindom is",
+            "Today is a sunny day and I like",
+        ]
+    )
     input_ids = [tokenizer.encode(p) for p in prompts]
     sampling_params = SamplingParams(
         temperature=0,
@@ -211,8 +226,14 @@ def prepare_extend_inputs_for_correctness_test(
     return reqs
 
 
-def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
-    input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+def prepare_synthetic_inputs_for_latency_test(
+    batch_size, input_len, custom_inputs=None
+):
+    input_ids = (
+        custom_inputs
+        if custom_inputs
+        else np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
+    )
     sampling_params = SamplingParams(
         temperature=0,
         max_new_tokens=BenchArgs.output_len,
@@ -284,6 +305,30 @@ def _maybe_prepare_mlp_sync_batch(batch: ScheduleBatch, model_runner):
     )
 
 
+def _read_prompts_from_file(prompt_file, rank_print):
+    """Read custom prompts from the file specified by `--prompt-filename`."""
+    if not prompt_file:
+        return []
+    if not os.path.exists(prompt_file):
+        rank_print(
+            f"Custom prompt file {prompt_file} not found. Using default inputs..."
+        )
+        return []
+    with open(prompt_file, "r") as pf:
+        return pf.readlines()
+
+
+def _save_profile_trace_results(profiler, filename):
+    parent_dir = os.path.dirname(os.path.abspath(filename))
+    os.makedirs(parent_dir, exist_ok=True)
+    profiler.export_chrome_trace(filename)
+    print(
+        profiler.key_averages(group_by_input_shape=True).table(
+            sort_by="self_cpu_time_total"
+        )
+    )
+
+
 def correctness_test(
     server_args,
     port_args,
@@ -298,7 +343,10 @@
     model_runner, tokenizer = load_model(server_args, port_args, tp_rank)
 
     # Prepare inputs
-    input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
+    custom_prompts = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    input_ids, reqs = prepare_inputs_for_correctness_test(
+        bench_args, tokenizer, custom_prompts
+    )
     rank_print(f"\n{input_ids=}\n")
 
     if bench_args.cut_len > 0:
@@ -344,6 +392,7 @@ def latency_test_run_once(
     device,
     log_decode_step,
     profile,
+    profile_record_shapes,
     profile_filename_prefix,
 ):
     max_batch_size = model_runner.max_total_num_tokens // (input_len + output_len)
@@ -374,6 +423,7 @@
                 torch.profiler.ProfilerActivity.CUDA,
             ],
             with_stack=True,
+            record_shapes=profile_record_shapes,
         )
         profiler.start()
 
@@ -391,10 +441,30 @@
     measurement_results["prefill_latency"] = prefill_latency
     measurement_results["prefill_throughput"] = throughput
 
+    if profile:
+        profiler.stop()
+        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_prefill.trace.json.gz"
+        _save_profile_trace_results(profiler, profile_filename)
+        rank_print(
+            f"torch profiler chrome trace for prefill saved to {profile_filename}"
+        )
+
     # Decode
     decode_latencies = []
     for i in range(output_len - 1):
         synchronize(device)
+        if profile and i == output_len / 2:
+            profiler = None
+            profiler = torch.profiler.profile(
+                activities=[
+                    torch.profiler.ProfilerActivity.CPU,
+                    torch.profiler.ProfilerActivity.CUDA,
+                ],
+                with_stack=True,
+                record_shapes=profile_record_shapes,
+            )
+            profiler.start()
+
         tic = time.perf_counter()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         synchronize(device)
@@ -407,13 +477,13 @@
                 f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
-    if profile:
-        profiler.stop()
-        profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}.trace.json.gz"
-        parent_dir = os.path.dirname(os.path.abspath(profile_filename))
-        os.makedirs(parent_dir, exist_ok=True)
-        profiler.export_chrome_trace(profile_filename)
-        rank_print(f"torch profiler chrome trace saved to {profile_filename}")
+        if profile and i == output_len / 2:
+            profiler.stop()
+            profile_filename = f"{profile_filename_prefix}_batch{batch_size}_input{input_len}_output{output_len}_decode.trace.json.gz"
+            _save_profile_trace_results(profiler, profile_filename)
+            rank_print(
+                f"torch profiler chrome trace for decoding 1 token saved to {profile_filename}"
+            )
 
     # Record decode timing from 2nd output
     if output_len > 1:
@@ -469,17 +539,42 @@ def latency_test(
         server_args.device,
         log_decode_step=0,
         profile=False,
+        profile_record_shapes=False,
         profile_filename_prefix="",  # not used
     )
 
     rank_print("Benchmark ...")
 
+    custom_inputs = _read_prompts_from_file(bench_args.prompt_filename, rank_print)
+    custom_inputs = [tokenizer.encode(p.strip()) for p in custom_inputs]
+    custom_input_len = len(custom_inputs)
+
     # Run the sweep
     result_list = []
     for bs, il, ol in itertools.product(
         bench_args.batch_size, bench_args.input_len, bench_args.output_len
     ):
-        reqs = prepare_synthetic_inputs_for_latency_test(bs, il)
+        bs_aligned_inputs = []
+        if custom_inputs:
+            if custom_input_len == bs:
+                bs_aligned_inputs = custom_inputs
+            elif custom_input_len > bs:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is larger than batch_size ({bs}). "
+                    f"Using the first {bs} prompts."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs[:bs])
+            else:
+                rank_print(
+                    f"Custom input size ({custom_input_len}) is smaller than batch_size ({bs}). "
+                    f"Pad to the desired batch_size with the last prompt."
+                )
+                bs_aligned_inputs = copy.deepcopy(custom_inputs)
+                bs_aligned_inputs.extend(
+                    [bs_aligned_inputs[-1]] * (bs - custom_input_len)
+                )
+
+        reqs = prepare_synthetic_inputs_for_latency_test(bs, il, bs_aligned_inputs)
         ret = latency_test_run_once(
             bench_args.run_name,
             model_runner,
@@ -491,6 +586,7 @@
             server_args.device,
             bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
+            bench_args.profile_record_shapes if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
         if ret is not None:
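
The bench_one_batch.py changes above add file-based prompts (--prompt-filename, one prompt per line) and split the single profiler trace into separate prefill and decode traces (--profile-record-shapes optionally records tensor shapes). A minimal sketch, in plain Python with hypothetical sample values, of how the prompt list is aligned to each benchmarked batch size (the real code operates on tokenized inputs):

    prompts = ["p1", "p2", "p3"]   # read from --prompt-filename, one prompt per line
    bs = 5                         # a batch size from the sweep
    if len(prompts) >= bs:
        aligned = prompts[:bs]     # keep the first bs prompts
    else:                          # pad to bs by repeating the last prompt
        aligned = prompts + [prompts[-1]] * (bs - len(prompts))
    assert len(aligned) == bs      # ["p1", "p2", "p3", "p3", "p3"]
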
sglang/srt/configs/model_config.py CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.hf_transformers_utils import (
     get_context_length,
     get_generation_config,
     get_hf_text_config,
+    get_sparse_attention_config,
 )
 from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.server_args import ServerArgs
@@ -133,6 +134,11 @@ class ModelConfig:
 
         if is_draft_model and self.hf_config.architectures[0] == "MiMoForCausalLM":
             self.hf_config.architectures[0] = "MiMoMTP"
+        if (
+            is_draft_model
+            and self.hf_config.architectures[0] == "Ernie4_5_MoeForCausalLM"
+        ):
+            self.hf_config.architectures[0] = "Ernie4_5_MoeForCausalLMMTP"
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
@@ -270,6 +276,9 @@ class ModelConfig:
         # Verify quantization
         self._verify_quantization()
 
+        # Verify dual-chunk attention config
+        self._verify_dual_chunk_attention_config()
+
         # Cache attributes
         self.hf_eos_token_id = self.get_hf_eos_token_id()
 
@@ -297,6 +306,13 @@ class ModelConfig:
             **kwargs,
         )
 
+    def get_total_num_attention_heads(self) -> int:
+        return self.num_attention_heads
+
+    def get_num_attention_heads(self, tensor_parallel_size) -> int:
+        total_num_attention_heads = self.num_attention_heads
+        return max(1, total_num_attention_heads // tensor_parallel_size)
+
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -401,6 +417,8 @@ class ModelConfig:
             "fbgemm_fp8",
             "w8a8_fp8",
             "petit_nvfp4",
+            "quark",
+            "mxfp4",
         ]
         optimized_quantization_methods = [
             "fp8",
@@ -482,6 +500,23 @@ class ModelConfig:
                 self.quantization,
             )
 
+    def _verify_dual_chunk_attention_config(self) -> None:
+        if hasattr(self.hf_config, "dual_chunk_attention_config"):
+            # Try loading the sparse attention config
+            sparse_attn_config = get_sparse_attention_config(self.model_path)
+            if not sparse_attn_config:
+                return
+            self.hf_config.dual_chunk_attention_config["sparse_attention_config"] = (
+                sparse_attn_config
+            )
+            if (
+                "sparse_attention_enabled"
+                not in self.hf_config.dual_chunk_attention_config
+            ):
+                self.hf_config.dual_chunk_attention_config[
+                    "sparse_attention_enabled"
+                ] = True
+
     def get_hf_eos_token_id(self) -> Optional[Set[int]]:
         eos_ids = getattr(self.hf_config, "eos_token_id", None)
         if eos_ids is not None:
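
The two head-count helpers added above follow the vLLM config API they sit next to: the total comes straight from the HF config, and the per-rank count is the total divided by the tensor-parallel size, floored at 1. A worked example with assumed values:

    total_heads = 32              # e.g. a 32-head model
    max(1, total_heads // 4)      # tensor_parallel_size=4  -> 8 heads per rank
    max(1, total_heads // 64)     # tensor_parallel_size=64 -> 1; never drops below one head
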
sglang/srt/conversation.py CHANGED
@@ -30,8 +30,10 @@ import re
 from enum import IntEnum, auto
 from typing import Callable, Dict, List, Optional, Tuple, Union
 
+from typing_extensions import Literal
+
 from sglang.srt.entrypoints.openai.protocol import ChatCompletionRequest
-from sglang.srt.utils import read_system_prompt_from_file
+from sglang.srt.utils import ImageData, read_system_prompt_from_file
 
 
 class SeparatorStyle(IntEnum):
@@ -91,7 +93,7 @@ class Conversation:
     video_token: str = "<video>"
     audio_token: str = "<audio>"
 
-    image_data: Optional[List[str]] = None
+    image_data: Optional[List[ImageData]] = None
     video_data: Optional[List[str]] = None
     modalities: Optional[List[str]] = None
     stop_token_ids: Optional[int] = None
@@ -381,9 +383,9 @@ class Conversation:
         """Append a new message."""
         self.messages.append([role, message])
 
-    def append_image(self, image: str):
+    def append_image(self, image: str, detail: Literal["auto", "low", "high"]):
         """Append a new image."""
-        self.image_data.append(image)
+        self.image_data.append(ImageData(url=image, detail=detail))
 
     def append_video(self, video: str):
         """Append a new video."""
@@ -627,7 +629,9 @@ def generate_chat_conv(
                     real_content = image_token + real_content
                 else:
                     real_content += image_token
-                conv.append_image(content.image_url.url)
+                conv.append_image(
+                    content.image_url.url, content.image_url.detail
+                )
             elif content.type == "video_url":
                 real_content += video_token
                 conv.append_video(content.video_url.url)
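
The conversation changes above carry the OpenAI-style image detail hint ("auto", "low", or "high") through to the multimodal pipeline instead of storing bare URL strings. A minimal sketch of the new call shape, assuming an ImageData model with url and detail fields as used in the diff:

    conv.append_image("https://example.com/cat.png", detail="high")
    # conv.image_data now holds ImageData(url="https://example.com/cat.png", detail="high")
    # entries, where 0.4.10.post2 stored plain URL strings.
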
sglang/srt/disaggregation/base/conn.py CHANGED
@@ -25,10 +25,13 @@ class KVArgs:
     gpu_id: int
     # for different tp
     decode_tp_size: int
-    # for pp prefill
-    prefill_pp_size: int
     kv_head_num: int
     page_size: int
+    # for pp prefill
+    prefill_pp_size: int
+    pp_rank: int
+    # for system dp
+    system_dp_rank: int
 
 
 class KVPoll:
sglang/srt/disaggregation/decode.py CHANGED
@@ -44,6 +44,7 @@ from sglang.srt.disaggregation.utils import (
     poll_and_all_reduce,
     prepare_abort,
 )
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, ScheduleBatch
 from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -184,9 +185,13 @@ class DecodePreallocQueue:
         kv_args_class = get_kv_class(self.transfer_backend, KVClassType.KVARGS)
         kv_args = kv_args_class()
 
-        attn_tp_size = self.tp_size // self.dp_size
+        attn_tp_size = get_attention_tp_size()
         kv_args.engine_rank = self.tp_rank % (attn_tp_size)
+
         kv_args.decode_tp_size = attn_tp_size
+        # Note(shangming): pp is not supported on the decode side yet, so its rank is fixed to 0
+        kv_args.pp_rank = 0
+        kv_args.system_dp_rank = self.scheduler.dp_rank
         kv_args.prefill_pp_size = self.prefill_pp_size
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
             self.token_to_kv_pool.get_contiguous_buf_infos()
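
The DecodePreallocQueue change above replaces the hand-derived attention tensor-parallel size (tp_size // dp_size) with the value published by the dp_attention layer, keeping this queue in sync with however the attention TP group is actually sized. A worked example under assumed sizes, where the two derivations agree:

    tp_size, dp_size = 8, 2
    attn_tp_size = tp_size // dp_size     # old derivation -> 4
    tp_rank = 6
    engine_rank = tp_rank % attn_tp_size  # -> 2, the rank within its attention TP group
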
sglang/srt/disaggregation/decode_schedule_batch_mixin.py CHANGED
@@ -76,6 +76,9 @@ class ScheduleBatchDisaggregationDecodeMixin:
             req_pool_indices, dtype=torch.int64, device=self.device
         )
         self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.orig_seq_lens = torch.tensor(
+            seq_lens, dtype=torch.int32, device=self.device
+        )
         self.out_cache_loc = out_cache_loc
         self.seq_lens_sum = sum(seq_lens)