sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
@@ -250,26 +250,21 @@ class HiCacheController:
         self.write_policy = write_policy
         self.page_size = page_size
         self.io_backend = io_backend
-
         self.enable_storage = False

-        # todo: move backend initialization to storage backend module
         if storage_backend is not None:
             self.storage_backend_type = storage_backend
             from sglang.srt.mem_cache.hicache_storage import get_hash_str

             self.get_hash_str = get_hash_str
-
             self.storage_config = self._generate_storage_config(
                 model_name, storage_backend_extra_config
             )
-            # In MLA backend, only one rank needs to backup the KV cache
+            # for MLA models, only one rank needs to backup the KV cache
             self.backup_skip = (
                 self.storage_config.is_mla_model
-                # todo: for load balancing, decide which rank to backup the KV cache by hash value
+                # todo: load balancing
                 and self.storage_config.tp_rank != 0
-                # todo: support other storage backends
-                and self.storage_backend_type in ["file", "mooncake"]
             )

             if storage_backend == "file":
@@ -309,12 +304,15 @@ class HiCacheController:
                 raise NotImplementedError(
                     f"Unsupported storage backend: {storage_backend}"
                 )
+
             self.enable_storage = True
             # todo: threshold policy for prefetching
             self.prefetch_threshold = max(prefetch_threshold, self.page_size)
             self.prefetch_capacity_limit = int(
                 0.8 * (self.mem_pool_host.size - self.mem_pool_device.size)
             )
+            # granularity of batch storage IO operations, in number of pages
+            self.storage_batch_size = 128
             # tracking the number of tokens locked in prefetching, updated by the main scheduler thread
             self.prefetch_tokens_occupied = 0

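Note: the two knobs added here interact: `prefetch_capacity_limit` caps how many tokens in-flight prefetches may pin in host memory, while `storage_batch_size` fixes a single granularity for all storage IO batches (replacing the per-backend batch sizes removed further down). A quick arithmetic sketch with made-up pool sizes:

```python
# Illustration only: made-up pool sizes, mirroring the arithmetic above.
page_size = 64
mem_pool_host_size = 1_000_000    # host pool capacity in tokens (hypothetical)
mem_pool_device_size = 200_000    # device pool capacity in tokens (hypothetical)

# At most 80% of the host-exclusive region may be pinned by in-flight prefetches.
prefetch_capacity_limit = int(0.8 * (mem_pool_host_size - mem_pool_device_size))
assert prefetch_capacity_limit == 640_000

# One storage IO batch covers storage_batch_size pages, i.e. 8192 tokens here.
storage_batch_size = 128
assert storage_batch_size * page_size == 8192
```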
@@ -325,12 +323,6 @@ class HiCacheController:
             self.prefetch_tp_group = torch.distributed.new_group(
                 group_ranks, backend="gloo"
             )
-            self.prefetch_io_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )
-            self.backup_tp_group = torch.distributed.new_group(
-                group_ranks, backend="gloo"
-            )

         self.load_cache_event = load_cache_event
         self.layer_done_counter = LayerDoneCounter(self.mem_pool_device.layer_num)
@@ -380,6 +372,7 @@ class HiCacheController:

         self.prefetch_revoke_queue = Queue()
         self.ack_backup_queue = Queue()
+        self.host_mem_release_queue = Queue()

         self.prefetch_thread.start()
         self.backup_thread.start()
@@ -618,7 +611,11 @@ class HiCacheController:
             operation.mark_done()
             return operation.completed_tokens, operation.hash_value

-    # zero copy
+    def append_host_mem_release(self, host_indices: torch.Tensor):
+        chunks = host_indices.split(self.mem_pool_host.page_size)
+        for chunk in chunks:
+            self.host_mem_release_queue.put(chunk)
+
     def _3fs_zero_copy_page_get(self, operation, hash_values, host_indices):
         hashes, dsts = self.mem_pool_host.get_buffer_with_hash(
             hash_values, host_indices
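Note: `append_host_mem_release` replaces the direct `mem_pool_host.free(...)` calls below; freed host indices are queued in page-sized chunks instead of being released inline from the IO threads. The consumer side is not shown in this diff, so the drain loop below is an assumption (hypothetical `drain_host_mem_release`), paired with the producer as written:

```python
import queue

import torch

page_size = 64
host_mem_release_queue: "queue.Queue[torch.Tensor]" = queue.Queue()

def append_host_mem_release(host_indices: torch.Tensor) -> None:
    # Producer, as in the diff: enqueue page-sized chunks instead of freeing inline.
    for chunk in host_indices.split(page_size):
        host_mem_release_queue.put(chunk)

def drain_host_mem_release(free_fn) -> int:
    # Hypothetical consumer: hand queued chunks back to the host pool.
    released = 0
    while True:
        try:
            chunk = host_mem_release_queue.get_nowait()
        except queue.Empty:
            return released
        free_fn(chunk)
        released += chunk.numel()

append_host_mem_release(torch.arange(160))
assert drain_host_mem_release(lambda _: None) == 160
```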
@@ -631,11 +628,11 @@ class HiCacheController:
                 f"Prefetch operation {operation.request_id} failed to retrieve page {hashes}."
             )

-    # zero copy
     def _mooncake_page_get(self, operation, hash_values, host_indices):
         key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
             hash_values,
             host_indices,
+            self.storage_config.tp_rank,
         )
         get_result = self.storage_backend.batch_get(
             key_strs,
@@ -649,9 +646,7 @@ class HiCacheController:
         if get_result != 0:
             operation.increment(get_result * self.page_size)

-    # non-zero copy
     def _generic_page_get(self, operation, hash_values, host_indices):
-        # todo: zero copy
         dummy_page_dst = [self.mem_pool_host.get_dummy_flat_data_page()] * len(
             hash_values
         )
@@ -674,22 +669,19 @@ class HiCacheController:

     def _page_transfer(self, operation):
         # Select the get function and batch size
-        if self.is_mooncake_backend():
+        if self.storage_backend_type == "mooncake":
             get_func = self._mooncake_page_get
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                get_func = self._3fs_zero_copy_page_get
-            elif self.mem_pool_host.layout == "layer_first":
-                get_func = self._generic_page_get
-            batch_size = 128
+        elif (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        ):
+            get_func = self._3fs_zero_copy_page_get
         else:
             get_func = self._generic_page_get
-            batch_size = 8

         # Transfer batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
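Note: with the per-backend `batch_size` gone, every transfer loop now strides by `self.storage_batch_size` pages, and each batch of k page hashes pairs with exactly k * page_size host indices. A minimal sketch of that slicing, with plain lists standing in for tensors:

```python
# Plain lists standing in for tensors; same slicing as the loop above.
page_size = 4
storage_batch_size = 3
hash_value = [f"h{i}" for i in range(7)]                 # 7 pages
host_indices = list(range(len(hash_value) * page_size))  # 28 token slots

for i in range(0, len(hash_value), storage_batch_size):
    batch_hashes = hash_value[i : i + storage_batch_size]
    batch_host_indices = host_indices[
        i * page_size : (i + len(batch_hashes)) * page_size
    ]
    # Every page hash maps to exactly page_size host slots, including the
    # final short batch (1 page -> 4 slots).
    assert len(batch_host_indices) == len(batch_hashes) * page_size
```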
@@ -703,10 +695,9 @@ class HiCacheController:
             ):
                 break  # Some operations fail or operation terminated by controller
         # release pre-allocated memory
-        self.mem_pool_host.free(operation.host_indices[operation.completed_tokens :])
-
-    def is_mooncake_backend(self):
-        return self.storage_backend_type == "mooncake"
+        self.append_host_mem_release(
+            operation.host_indices[operation.completed_tokens :]
+        )

     def prefetch_io_aux_func(self):
         """
@@ -716,47 +707,49 @@ class HiCacheController:
             try:
                 operation = self.prefetch_buffer.get(block=True, timeout=1)
                 self._page_transfer(operation)
-
-                if self.tp_world_size > 1:
-                    # to ensure all TP workers release the host memory at the same time
-                    torch.distributed.barrier(group=self.prefetch_io_tp_group)
                 # operation terminated by controller, release pre-allocated memory
-                self.mem_pool_host.free(
+                self.append_host_mem_release(
                     operation.host_indices[operation.completed_tokens :]
                 )
             except Empty:
                 continue

-    def prefetch_rate_limit_check(self) -> bool:
+    def prefetch_rate_limited(self) -> bool:
         """
         Rate limit the prefetching operations to avoid overwhelming the storage backend.
         """
         # cancel prefetch if too much memory is occupied
         if self.prefetch_tokens_occupied >= self.prefetch_capacity_limit:
-            return False
+            return True
         # todo: more sophisticated rate limiting based on storage backend performance
-        return True
+        return False

-    def _generic_storage_hit_query(self, operation) -> tuple[list[str], int]:
+    def _storage_hit_query(self, operation) -> tuple[list[str], int]:
         last_hash = operation.last_hash
         tokens_to_fetch = operation.token_ids

         storage_query_count = 0
-        remaining_tokens = len(tokens_to_fetch)
         hash_value = []
-        while remaining_tokens >= self.page_size:
-            last_hash = self.get_hash_str(
-                tokens_to_fetch[
-                    storage_query_count : storage_query_count + self.page_size
-                ],
-                last_hash,
+
+        for start in range(
+            0, len(tokens_to_fetch), self.page_size * self.storage_batch_size
+        ):
+            end = min(
+                start + self.page_size * self.storage_batch_size, len(tokens_to_fetch)
             )
-            hash_value.append(last_hash)
-            storage_query_count += self.page_size
-            remaining_tokens -= self.page_size
-        # deferring to batch exists
-        hit_page_num = self.storage_backend.batch_exists(hash_value)
-        return hash_value[:hit_page_num], hit_page_num * self.page_size
+            batch_tokens = tokens_to_fetch[start:end]
+            batch_hashes = []
+            for i in range(0, len(batch_tokens), self.page_size):
+                last_hash = self.get_hash_str(
+                    batch_tokens[i : i + self.page_size], last_hash
+                )
+                batch_hashes.append(last_hash)
+            hit_page_num = self.storage_backend.batch_exists(batch_hashes)
+            hash_value.extend(batch_hashes[:hit_page_num])
+            storage_query_count += hit_page_num * self.page_size
+            if hit_page_num < len(batch_hashes):
+                break
+        return hash_value, storage_query_count

     def prefetch_thread_func(self):
         """
@@ -771,13 +764,7 @@ class HiCacheController:
             if operation is None:
                 continue

-            if (
-                operation.host_indices is not None
-            ) and self.prefetch_rate_limit_check():
-                hash_value, storage_hit_count = self._generic_storage_hit_query(
-                    operation
-                )
-
+            hash_value, storage_hit_count = self._storage_hit_query(operation)
             if self.tp_world_size > 1:
                 storage_hit_count_tensor = torch.tensor(
                     storage_hit_count, dtype=torch.int
@@ -792,8 +779,7 @@ class HiCacheController:
             if storage_hit_count < self.prefetch_threshold:
                 # not to prefetch if not enough benefits
                 self.prefetch_revoke_queue.put(operation.request_id)
-                if operation.host_indices is not None:
-                    self.mem_pool_host.free(operation.host_indices)
+                self.append_host_mem_release(operation.host_indices)
                 logger.debug(
                     f"Revoking prefetch for request {operation.request_id} due to insufficient hits ({storage_hit_count})."
                 )
@@ -802,7 +788,9 @@ class HiCacheController:
                     : (storage_hit_count // self.page_size)
                 ]
                 # free the pre-allocated memory for pages that are not hit
-                self.mem_pool_host.free(operation.host_indices[storage_hit_count:])
+                self.append_host_mem_release(
+                    operation.host_indices[storage_hit_count:]
+                )
                 operation.host_indices = operation.host_indices[:storage_hit_count]
                 logger.debug(
                     f"Prefetching {len(operation.hash_value)} pages for request {operation.request_id}."
@@ -838,6 +826,7 @@ class HiCacheController:
         key_strs, buffer_ptrs, buffer_sizes = self.mem_pool_host.get_buffer_meta(
             hash_values,
             host_indices,
+            self.storage_config.tp_rank,
         )
         success = self.storage_backend.batch_set(
             key_strs,
@@ -856,21 +845,18 @@ class HiCacheController:
     # Backup batch by batch
     def _page_backup(self, operation):
         # Select the set function and batch size
-        if self.is_mooncake_backend():
+        if self.storage_backend_type == "mooncake":
             backup_set_func = self._mooncake_page_set
-            batch_size = 128
-        elif self.storage_backend_type == "hf3fs":
-            if self.mem_pool_host.layout == "page_first":
-                backup_set_func = self._3fs_zero_copy_page_set
-            elif self.mem_pool_host.layout == "layer_first":
-                backup_set_func = self._generic_page_set
-            batch_size = 128
+        elif (
+            self.storage_backend_type == "hf3fs"
+            and self.mem_pool_host.layout == "page_first"
+        ):
+            backup_set_func = self._3fs_zero_copy_page_set
         else:
             backup_set_func = self._generic_page_set
-            batch_size = 8
         # Backup batch by batch
-        for i in range(0, len(operation.hash_value), batch_size):
-            batch_hashes = operation.hash_value[i : i + batch_size]
+        for i in range(0, len(operation.hash_value), self.storage_batch_size):
+            batch_hashes = operation.hash_value[i : i + self.storage_batch_size]
             batch_host_indices = operation.host_indices[
                 i * self.page_size : (i + len(batch_hashes)) * self.page_size
             ]
@@ -896,27 +882,7 @@ class HiCacheController:

                 if not self.backup_skip:
                     self._page_backup(operation)
-                    min_completed_tokens = operation.completed_tokens
-                else:
-                    min_completed_tokens = len(operation.token_ids)
-
-                if self.tp_world_size > 1:
-                    completed_tokens_tensor = torch.tensor(
-                        min_completed_tokens, dtype=torch.int
-                    )
-                    torch.distributed.all_reduce(
-                        completed_tokens_tensor,
-                        op=torch.distributed.ReduceOp.MIN,
-                        group=self.backup_tp_group,
-                    )
-                    min_completed_tokens = completed_tokens_tensor.item()
-
-                self.ack_backup_queue.put(
-                    (
-                        operation.id,
-                        min_completed_tokens,
-                    )
-                )
+                self.ack_backup_queue.put(operation.id)

             except Empty:
                 continue
sglang/srt/managers/detokenizer_manager.py
@@ -32,7 +32,9 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     FreezeGCReq,
+    MultiTokenizerRegisterReq,
 )
+from sglang.srt.managers.multi_tokenizer_mixin import MultiTokenizerMixin
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
@@ -67,7 +69,7 @@ class DecodeStatus:
     sent_offset: int = 0


-class DetokenizerManager:
+class DetokenizerManager(MultiTokenizerMixin):
    """DetokenizerManager is a process that detokenizes the token ids."""

    def __init__(
@@ -102,6 +104,7 @@
                 (BatchEmbeddingOut, self.handle_batch_embedding_out),
                 (BatchTokenIDOut, self.handle_batch_token_id_out),
                 (BatchMultimodalDecodeReq, self.handle_multimodal_decode_req),
+                (MultiTokenizerRegisterReq, lambda x: x),
                 (FreezeGCReq, self.handle_freeze_gc_req),
             ]
         )
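Note: the handler table above maps a request type to a callable, so `(MultiTokenizerRegisterReq, lambda x: x)` simply passes registration requests through unchanged. A minimal type-keyed dispatcher in the same spirit (illustrative; sglang's actual dispatcher class is not shown in this diff):

```python
from typing import Any, Callable, List, Tuple, Type

class TypeDispatcher:
    """Illustrative type-keyed dispatcher, in the spirit of the table above."""

    def __init__(self, mapping: List[Tuple[Type, Callable[[Any], Any]]]):
        self._mapping = dict(mapping)

    def __call__(self, obj: Any) -> Any:
        handler = self._mapping.get(type(obj))
        if handler is None:
            raise ValueError(f"no handler for {type(obj).__name__}")
        return handler(obj)

# An identity handler, like (MultiTokenizerRegisterReq, lambda x: x), just
# bounces the request back so the sender gets an acknowledgement.
dispatch = TypeDispatcher([(str, lambda x: x), (int, lambda x: x + 1)])
assert dispatch("register") == "register" and dispatch(1) == 2
```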
@@ -285,8 +288,12 @@ def run_detokenizer_process(

     try:
         manager = DetokenizerManager(server_args, port_args)
-        manager.event_loop()
+        if server_args.tokenizer_worker_num > 1:
+            manager.multi_tokenizer_manager_event_loop()
+        else:
+            manager.event_loop()
     except Exception:
+        manager.clear_tokenizer_mapping()
         traceback = get_exception_traceback()
         logger.error(f"DetokenizerManager hit an exception: {traceback}")
         parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/io_struct.py
@@ -814,6 +814,16 @@ class BatchEmbeddingOut:
     cached_tokens: List[int]


+@dataclass
+class ClearHiCacheReqInput:
+    pass
+
+
+@dataclass
+class ClearHiCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class FlushCacheReqInput:
     pass
@@ -973,6 +983,11 @@ class AbortReq:
     abort_all: bool = False
     # The finished reason data
     finished_reason: Optional[Dict[str, Any]] = None
+    # used in MultiTokenzierManager mode
+    rids: Optional[Union[List[str], str]] = None
+
+    def __post_init__(self):
+        self.rids = self.rid


 @dataclass
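Note: `AbortReq` keeps its existing single-request `rid` field and mirrors it into `rids` after construction, so multi-tokenizer code can uniformly read `rids` whether one or many requests are aborted. A pared-down sketch of the pattern (fields reduced for illustration):

```python
from dataclasses import dataclass
from typing import List, Optional, Union

@dataclass
class AbortReqSketch:
    # Pared-down stand-in for AbortReq; only what the pattern needs.
    rid: str = ""
    abort_all: bool = False
    rids: Optional[Union[List[str], str]] = None

    def __post_init__(self):
        # Mirror the single rid so downstream code can always read `rids`.
        self.rids = self.rid

assert AbortReqSketch(rid="req-42").rids == "req-42"
```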
@@ -1173,6 +1188,18 @@ class LoRAUpdateResult:
 LoadLoRAAdapterReqOutput = UnloadLoRAAdapterReqOutput = LoRAUpdateResult


+@dataclass
+class MultiTokenizerRegisterReq:
+    rids: Optional[Union[List[str], str]] = None
+    ipc_name: Optional[str] = None
+
+
+@dataclass
+class MultiTokenizerWarpper:
+    worker_id: int
+    obj: Optional[Any] = None
+
+
 class BlockReqType(Enum):
     BLOCK = 1
     UNBLOCK = 2
sglang/srt/managers/mm_utils.py
@@ -20,9 +20,11 @@ from sglang.srt.managers.schedule_batch import (
 )
 from sglang.srt.mem_cache.multimodal_cache import MultiModalCache
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import flatten_nested_list, print_warning_once
+from sglang.srt.utils import flatten_nested_list, is_npu, print_warning_once
 from sglang.utils import logger

+_is_npu = is_npu()
+
 # NOTE: Using the shared logger from sglang.utils instead of creating a module-specific logger
 # to ensure consistent logging behavior across the codebase. This prevents issues with log
 # propagation that can cause some log messages (like 'server is fired up') to not appear
@@ -486,6 +488,8 @@ def get_embedding_and_mask(
     if embedding is None:
         return None, None
     # 2. Get mask
+    if _is_npu:
+        torch.npu.current_stream().synchronize()
     special_multimodal_mask = _get_multimodal_mask(input_ids, placeholder_tensor)
     # 3. Adjust embedding length if needed
     embedding = _adjust_embedding_length(embedding, special_multimodal_mask, logger)
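Note: on Ascend NPUs the stream is synchronized before the mask is computed, presumably so pending async ops on the embedding complete first. A device-agnostic version of the same guard pattern (helper name hypothetical; `torch.npu` requires the `torch_npu` plugin):

```python
import torch

def _maybe_sync_current_stream(is_npu: bool) -> None:
    # Hypothetical helper mirroring the guard above: block until pending
    # async ops on the current stream finish before computing the mask.
    if is_npu:
        torch.npu.current_stream().synchronize()  # requires the torch_npu plugin
    elif torch.cuda.is_available():
        torch.cuda.current_stream().synchronize()
```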