sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their respective public registries. It is provided for informational purposes only.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py

@@ -7,8 +7,12 @@ from typing import List, Optional, Tuple
 import torch
 from huggingface_hub import snapshot_download
 
-from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group
-from sglang.srt.layers.dp_attention import disable_dp_size
+from sglang.srt.distributed import (
+    GroupCoordinator,
+    get_tensor_model_parallel_world_size,
+    get_tp_group,
+    patch_tensor_parallel_group,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs
 from sglang.srt.managers.schedule_batch import (
@@ -57,7 +61,7 @@ logger = logging.getLogger(__name__)
 def draft_tp_context(tp_group: GroupCoordinator):
     # Draft model doesn't use dp and has its own tp group.
     # We disable mscclpp now because it doesn't support 2 comm groups.
-    with disable_dp_size(), patch_tensor_parallel_group(tp_group):
+    with patch_tensor_parallel_group(tp_group):
         yield
 
 
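Note: `draft_tp_context` now relies solely on `patch_tensor_parallel_group`, which follows the usual swap-and-restore context-manager pattern. A minimal sketch of that pattern for reference; the global name and helper below are illustrative, not sglang's actual internals:

    import contextlib

    _ACTIVE_TP_GROUP = None  # stand-in for the module-level TP group state

    @contextlib.contextmanager
    def patch_group(new_group):
        # Temporarily install `new_group` as the active TP group and
        # restore the previous group on exit, even if the body raises.
        global _ACTIVE_TP_GROUP
        old_group, _ACTIVE_TP_GROUP = _ACTIVE_TP_GROUP, new_group
        try:
            yield
        finally:
            _ACTIVE_TP_GROUP = old_group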
@@ -76,6 +80,7 @@ class EAGLEWorker(TpModelWorker):
         self.server_args = server_args
         self.topk = server_args.speculative_eagle_topk
         self.speculative_num_steps = server_args.speculative_num_steps
+        self.speculative_num_draft_tokens = server_args.speculative_num_draft_tokens
         self.enable_nan_detection = server_args.enable_nan_detection
         self.gpu_id = gpu_id
         self.device = server_args.device
@@ -166,6 +171,10 @@ class EAGLEWorker(TpModelWorker):
 
     def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
+
+        self.has_prefill_wrapper_verify = False
+        self.draft_extend_attn_backend = None
+
         if self.server_args.attention_backend == "flashinfer":
             if not global_server_args_dict["use_mla_backend"]:
                 from sglang.srt.layers.attention.flashinfer_backend import (
@@ -213,7 +222,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -229,7 +237,6 @@ class EAGLEWorker(TpModelWorker):
                 self.draft_model_runner,
                 skip_prefill=False,
             )
-            self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "flashmla":
             from sglang.srt.layers.attention.flashmla_backend import (
                 FlashMLAMultiStepDraftBackend,
@@ -240,8 +247,6 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
-            self.has_prefill_wrapper_verify = False
         else:
             raise ValueError(
                 f"EAGLE is not supported in attention backend {self.server_args.attention_backend}"
@@ -302,17 +307,27 @@ class EAGLEWorker(TpModelWorker):
         A tuple of the final logit output of the target model, next tokens accepted,
         the batch id (used for overlap schedule), and number of accepted tokens.
         """
-        if batch.forward_mode.is_decode():
+        if batch.forward_mode.is_extend() or batch.is_extend_in_batch:
+            logits_output, next_token_ids, bid, seq_lens_cpu = (
+                self.forward_target_extend(batch)
+            )
+            with self.draft_tp_context(self.draft_model_runner.tp_group):
+                self.forward_draft_extend(
+                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
+                )
+            return logits_output, next_token_ids, bid, 0, False
+        else:
             with self.draft_tp_context(self.draft_model_runner.tp_group):
                 spec_info = self.draft(batch)
             logits_output, verify_output, model_worker_batch, can_run_cuda_graph = (
                 self.verify(batch, spec_info)
             )
 
-            # If it is None, it means all requests are finished
-            if batch.spec_info.verified_id is not None:
+            if self.check_forward_draft_extend_after_decode(batch):
                 with self.draft_tp_context(self.draft_model_runner.tp_group):
-                    self.forward_draft_extend_after_decode(batch)
+                    self.forward_draft_extend_after_decode(
+                        batch,
+                    )
             return (
                 logits_output,
                 verify_output.verified_id,
@@ -320,22 +335,27 @@ class EAGLEWorker(TpModelWorker):
                 sum(verify_output.accept_length_per_req_cpu),
                 can_run_cuda_graph,
             )
-        elif batch.forward_mode.is_idle():
-            model_worker_batch = batch.get_model_worker_batch()
-            logits_output, next_token_ids, _ = (
-                self.target_worker.forward_batch_generation(model_worker_batch)
-            )
 
-            return logits_output, next_token_ids, model_worker_batch.bid, 0, False
-        else:
-            logits_output, next_token_ids, bid, seq_lens_cpu = (
-                self.forward_target_extend(batch)
-            )
-            with self.draft_tp_context(self.draft_model_runner.tp_group):
-                self.forward_draft_extend(
-                    batch, logits_output.hidden_states, next_token_ids, seq_lens_cpu
-                )
-            return logits_output, next_token_ids, bid, 0, False
+    def check_forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        local_need_forward = (
+            batch.spec_info.verified_id is not None
+            and batch.spec_info.verified_id.shape[0] > 0
+        )
+        if not self.server_args.enable_dp_attention:
+            return local_need_forward
+
+        global_need_forward = torch.tensor(
+            [
+                (local_need_forward),
+            ],
+            dtype=torch.int64,
+        )
+        torch.distributed.all_reduce(
+            global_need_forward, group=get_tp_group().cpu_group
+        )
+        global_need_forward_cnt = global_need_forward[0].item()
+        need_forward = global_need_forward_cnt > 0
+        return need_forward
 
     def forward_target_extend(
         self, batch: ScheduleBatch
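Note: with DP attention enabled, ranks can disagree on whether the draft-extend pass is needed (a rank whose `verified_id` is empty must still join its peers in the collective call). The new `check_forward_draft_extend_after_decode` resolves this with a "global any" over the TP CPU group. A self-contained sketch of the same pattern, assuming `torch.distributed` is initialized with a CPU-capable backend such as gloo:

    import torch
    import torch.distributed as dist

    def global_any(local_flag: bool, group=None) -> bool:
        # Sum-reduce 0/1 flags across the group; a positive total means at
        # least one rank set its flag, so every rank returns True together.
        flag = torch.tensor([int(local_flag)], dtype=torch.int64)
        dist.all_reduce(flag, op=dist.ReduceOp.SUM, group=group)
        return flag.item() > 0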
@@ -354,6 +374,7 @@ class EAGLEWorker(TpModelWorker):
         # We need the full hidden states to prefill the KV cache of the draft model.
         model_worker_batch = batch.get_model_worker_batch()
         model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        model_worker_batch.spec_num_draft_tokens = 1
         logits_output, next_token_ids, _ = self.target_worker.forward_batch_generation(
             model_worker_batch
         )
@@ -364,7 +385,7 @@ class EAGLEWorker(TpModelWorker):
             model_worker_batch.seq_lens_cpu,
         )
 
-    def draft(self, batch: ScheduleBatch):
+    def _draft_preprocess_decode(self, batch: ScheduleBatch):
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -466,10 +487,33 @@ class EAGLEWorker(TpModelWorker):
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         batch.return_hidden_states = False
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
+        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+
+    def _draft_preprocess_idle(self, batch: ScheduleBatch):
+        batch.spec_info = EagleDraftInput.create_idle_input(
+            device=self.device,
+            hidden_size=self.model_config.hidden_size,
+            dtype=self.model_config.dtype,
+            topk=self.topk,
+            capture_hidden_mode=CaptureHiddenMode.LAST,
+        )
+
+    def draft(self, batch: ScheduleBatch):
+        # Parse args
+        if batch.forward_mode.is_idle():
+            self._draft_preprocess_idle(batch)
+        else:
+            self._draft_preprocess_decode(batch)
+
+        spec_info = batch.spec_info
+
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        batch.return_hidden_states = False
 
         # Get forward batch
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.topk
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -481,12 +525,18 @@ class EAGLEWorker(TpModelWorker):
                 forward_batch
             )
         else:
-            # Initialize attention backend
-            self.draft_attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                # Initialize attention backend
+                self.draft_attn_backend.init_forward_metadata(forward_batch)
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)
 
-        self.token_to_kv_pool_allocator.restore_state(token_to_kv_pool_state_backup)
+        if batch.forward_mode.is_idle():
+            return EagleVerifyInput.create_idle_input(
+                self.topk,
+                self.speculative_num_steps,
+                self.speculative_num_draft_tokens,
+            )
 
         (
             tree_mask,
@@ -504,7 +554,7 @@ class EAGLEWorker(TpModelWorker):
             batch.seq_lens_sum,
             self.topk,
             self.speculative_num_steps,
-            self.server_args.speculative_num_draft_tokens,
+            self.speculative_num_draft_tokens,
         )
 
         return EagleVerifyInput(
@@ -584,11 +634,16 @@ class EAGLEWorker(TpModelWorker):
     def verify(self, batch: ScheduleBatch, spec_info: EagleVerifyInput):
         spec_info.prepare_for_verify(batch, self.page_size)
         batch.return_hidden_states = False
-        batch.forward_mode = ForwardMode.TARGET_VERIFY
+        batch.forward_mode = (
+            ForwardMode.TARGET_VERIFY
+            if not batch.forward_mode.is_idle()
+            else ForwardMode.IDLE
+        )
         batch.spec_info = spec_info
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=spec_info.seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
         assert model_worker_batch.capture_hidden_mode == spec_info.capture_hidden_mode
 
         if batch.has_grammar:
@@ -646,7 +701,9 @@ class EAGLEWorker(TpModelWorker):
             self.add_logprob_values(batch, res, logits_output)
 
         # Prepare the batch for the next draft forwards.
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not batch.forward_mode.is_idle() else ForwardMode.IDLE
+        )
         batch.spec_info = res.draft_input
 
         return logits_output, res, model_worker_batch, can_run_cuda_graph
@@ -743,6 +800,7 @@ class EAGLEWorker(TpModelWorker):
         model_worker_batch = batch.get_model_worker_batch(
             seq_lens_cpu_cache=seq_lens_cpu
         )
+        model_worker_batch.spec_num_draft_tokens = 1
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -759,13 +817,33 @@ class EAGLEWorker(TpModelWorker):
         req_pool_indices_backup = batch.req_pool_indices
         accept_length_backup = batch.spec_info.accept_length
         return_logprob_backup = batch.return_logprob
-
-        # Prepare metadata
-        batch.spec_info.prepare_extend_after_decode(
-            batch,
-            self.speculative_num_steps,
-        )
+        input_is_idle = batch.forward_mode.is_idle()
+        if not input_is_idle:
+            # Prepare metadata
+            if batch.spec_info.verified_id is not None:
+                batch.spec_info.prepare_extend_after_decode(
+                    batch,
+                    self.speculative_num_steps,
+                )
+        else:
+            batch = batch.copy()
+            batch.prepare_for_idle()
+            hidden_size = (
+                self.model_config.hidden_size * 3
+                if self.speculative_algorithm.is_eagle3()
+                else self.model_config.hidden_size
+            )
+            batch.spec_info = EagleDraftInput.create_idle_input(
+                device=self.device,
+                hidden_size=hidden_size,
+                dtype=self.model_config.dtype,
+                topk=self.topk,
+                capture_hidden_mode=CaptureHiddenMode.LAST,
+            )
+        batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+        assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
@@ -789,7 +867,10 @@ class EAGLEWorker(TpModelWorker):
             )
             forward_batch.spec_info.hidden_states = logits_output.hidden_states
         else:
-            self.draft_model_runner.attn_backend.init_forward_metadata(forward_batch)
+            if not forward_batch.forward_mode.is_idle():
+                self.draft_model_runner.attn_backend.init_forward_metadata(
+                    forward_batch
+                )
         logits_output = self.draft_model_runner.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
@@ -799,7 +880,9 @@ class EAGLEWorker(TpModelWorker):
 
         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
-        batch.forward_mode = ForwardMode.DECODE
+        batch.forward_mode = (
+            ForwardMode.DECODE if not input_is_idle else ForwardMode.IDLE
+        )
         batch.seq_lens = seq_lens_backup
         batch.req_pool_indices = req_pool_indices_backup
         batch.spec_info.accept_length = accept_length_backup
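Note: `forward_draft_extend_after_decode` mutates the batch in place for the draft pass and then restores the backed-up fields (`forward_mode`, `seq_lens`, `req_pool_indices`, `accept_length`). A simplified sketch of that save/mutate/restore idiom, with try/finally added for illustration (the diffed code restores unconditionally at the end of the method instead):

    def run_with_scratch_batch_state(batch, forward_fn, scratch_mode):
        # Back up the fields the draft pass overwrites.
        forward_mode_backup = batch.forward_mode
        seq_lens_backup = batch.seq_lens
        try:
            batch.forward_mode = scratch_mode  # e.g. an extend mode for the draft pass
            return forward_fn(batch)
        finally:
            # Hand the batch back to the scheduler in its original state.
            batch.forward_mode = forward_mode_backup
            batch.seq_lens = seq_lens_backup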
sglang/srt/torch_memory_saver_adapter.py

@@ -1,11 +1,13 @@
 import logging
+import threading
+import time
 from abc import ABC
-from contextlib import contextmanager
+from contextlib import contextmanager, nullcontext
 
 try:
     import torch_memory_saver
 
-    _primary_memory_saver = torch_memory_saver.TorchMemorySaver()
+    _memory_saver = torch_memory_saver.torch_memory_saver
     import_error = None
 except ImportError as e:
     import_error = e
@@ -38,13 +40,13 @@ class TorchMemorySaverAdapter(ABC):
     def configure_subprocess(self):
         raise NotImplementedError
 
-    def region(self):
+    def region(self, tag: str):
         raise NotImplementedError
 
-    def pause(self):
+    def pause(self, tag: str):
         raise NotImplementedError
 
-    def resume(self):
+    def resume(self, tag: str):
         raise NotImplementedError
 
     @property
@@ -53,21 +55,23 @@ class TorchMemorySaverAdapter(ABC):
 
 
 class _TorchMemorySaverAdapterReal(TorchMemorySaverAdapter):
+    """Adapter for TorchMemorySaver with tag-based control"""
+
     def configure_subprocess(self):
         return torch_memory_saver.configure_subprocess()
 
-    def region(self):
-        return _primary_memory_saver.region()
+    def region(self, tag: str):
+        return _memory_saver.region(tag=tag)
 
-    def pause(self):
-        return _primary_memory_saver.pause()
+    def pause(self, tag: str):
+        return _memory_saver.pause(tag=tag)
 
-    def resume(self):
-        return _primary_memory_saver.resume()
+    def resume(self, tag: str):
+        return _memory_saver.resume(tag=tag)
 
     @property
     def enabled(self):
-        return _primary_memory_saver.enabled
+        return _memory_saver is not None and _memory_saver.enabled
 
 
 class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
@@ -76,13 +80,13 @@ class _TorchMemorySaverAdapterNoop(TorchMemorySaverAdapter):
         yield
 
     @contextmanager
-    def region(self):
+    def region(self, tag: str):
         yield
 
-    def pause(self):
+    def pause(self, tag: str):
         pass
 
-    def resume(self):
+    def resume(self, tag: str):
         pass
 
     @property
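Note: the adapter now forwards a `tag` to torch_memory_saver's process-wide singleton, so separately tagged allocations (for example weights vs. KV cache) can be paused and resumed independently. A hedged usage sketch; it assumes the adapter's `create(enable=...)` factory (not shown in this diff), and the tag name is illustrative:

    from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter

    # create(enable=False) is assumed to yield the no-op adapter, so this
    # sketch runs without the optional torch_memory_saver dependency.
    saver = TorchMemorySaverAdapter.create(enable=False)

    # Allocations made inside a tagged region can later be released and
    # restored by that tag without touching other regions.
    with saver.region(tag="kv_cache"):
        pass  # allocate tagged tensors here

    saver.pause(tag="kv_cache")   # unmap GPU memory for this tag only
    saver.resume(tag="kv_cache")  # map it back before the next use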
sglang/srt/two_batch_overlap.py

@@ -346,7 +346,10 @@ class TboForwardBatchPreparer:
         )
 
         # TODO improve, e.g. unify w/ `init_raw`
-        if global_server_args_dict["moe_dense_tp_size"] == 1:
+        if (
+            global_server_args_dict["moe_dense_tp_size"] == 1
+            and batch.gathered_buffer is not None
+        ):
             sum_len = end_token_index - start_token_index
             gathered_buffer = torch.zeros(
                 (sum_len, batch.gathered_buffer.shape[1]),