sglang 0.4.9.post4__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (84)
  1. sglang/lang/chat_template.py +21 -0
  2. sglang/srt/configs/internvl.py +3 -0
  3. sglang/srt/configs/model_config.py +4 -0
  4. sglang/srt/constrained/base_grammar_backend.py +10 -2
  5. sglang/srt/constrained/xgrammar_backend.py +7 -5
  6. sglang/srt/conversation.py +16 -1
  7. sglang/srt/debug_utils/__init__.py +0 -0
  8. sglang/srt/debug_utils/dump_comparator.py +131 -0
  9. sglang/srt/debug_utils/dumper.py +108 -0
  10. sglang/srt/debug_utils/text_comparator.py +172 -0
  11. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
  12. sglang/srt/disaggregation/mooncake/conn.py +16 -0
  13. sglang/srt/disaggregation/prefill.py +13 -1
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/openai/serving_chat.py +132 -79
  16. sglang/srt/function_call/ebnf_composer.py +10 -3
  17. sglang/srt/function_call/function_call_parser.py +2 -0
  18. sglang/srt/function_call/glm4_moe_detector.py +164 -0
  19. sglang/srt/function_call/qwen3_coder_detector.py +1 -0
  20. sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
  21. sglang/srt/layers/attention/vision.py +56 -8
  22. sglang/srt/layers/layernorm.py +26 -1
  23. sglang/srt/layers/logits_processor.py +14 -3
  24. sglang/srt/layers/moe/ep_moe/layer.py +172 -206
  25. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
  27. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
  28. sglang/srt/layers/moe/topk.py +84 -22
  29. sglang/srt/layers/multimodal.py +11 -8
  30. sglang/srt/layers/quantization/fp8.py +25 -247
  31. sglang/srt/layers/quantization/fp8_kernel.py +78 -48
  32. sglang/srt/layers/quantization/modelopt_quant.py +25 -10
  33. sglang/srt/layers/quantization/unquant.py +24 -76
  34. sglang/srt/layers/quantization/w4afp8.py +68 -17
  35. sglang/srt/lora/lora_registry.py +93 -29
  36. sglang/srt/managers/cache_controller.py +9 -7
  37. sglang/srt/managers/mm_utils.py +154 -35
  38. sglang/srt/managers/multimodal_processor.py +3 -14
  39. sglang/srt/managers/schedule_batch.py +14 -8
  40. sglang/srt/managers/scheduler.py +35 -1
  41. sglang/srt/managers/tokenizer_manager.py +37 -6
  42. sglang/srt/managers/tp_worker.py +3 -0
  43. sglang/srt/mem_cache/hiradix_cache.py +5 -2
  44. sglang/srt/model_executor/model_runner.py +68 -14
  45. sglang/srt/models/deepseek_v2.py +62 -28
  46. sglang/srt/models/glm4_moe.py +1035 -0
  47. sglang/srt/models/glm4_moe_nextn.py +167 -0
  48. sglang/srt/models/interns1.py +328 -0
  49. sglang/srt/models/internvl.py +143 -47
  50. sglang/srt/models/llava.py +9 -5
  51. sglang/srt/models/minicpmo.py +4 -1
  52. sglang/srt/models/qwen2_moe.py +2 -2
  53. sglang/srt/models/qwen3_moe.py +5 -2
  54. sglang/srt/multimodal/processors/base_processor.py +20 -6
  55. sglang/srt/multimodal/processors/clip.py +2 -2
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
  57. sglang/srt/multimodal/processors/gemma3.py +2 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +2 -2
  59. sglang/srt/multimodal/processors/internvl.py +21 -8
  60. sglang/srt/multimodal/processors/janus_pro.py +2 -2
  61. sglang/srt/multimodal/processors/kimi_vl.py +2 -2
  62. sglang/srt/multimodal/processors/llava.py +4 -4
  63. sglang/srt/multimodal/processors/minicpm.py +2 -3
  64. sglang/srt/multimodal/processors/mlama.py +2 -2
  65. sglang/srt/multimodal/processors/mllama4.py +18 -111
  66. sglang/srt/multimodal/processors/phi4mm.py +2 -2
  67. sglang/srt/multimodal/processors/pixtral.py +2 -2
  68. sglang/srt/multimodal/processors/qwen_audio.py +2 -2
  69. sglang/srt/multimodal/processors/qwen_vl.py +2 -2
  70. sglang/srt/multimodal/processors/vila.py +3 -1
  71. sglang/srt/reasoning_parser.py +2 -1
  72. sglang/srt/server_args.py +57 -6
  73. sglang/srt/utils.py +96 -1
  74. sglang/srt/weight_sync/utils.py +119 -0
  75. sglang/test/runners.py +4 -0
  76. sglang/test/test_utils.py +65 -5
  77. sglang/utils.py +19 -0
  78. sglang/version.py +1 -1
  79. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +4 -4
  80. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +83 -73
  81. sglang/srt/debug_utils.py +0 -74
  82. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
  83. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
  84. {sglang-0.4.9.post4.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/managers/mm_utils.py
+++ b/sglang/srt/managers/mm_utils.py
@@ -3,8 +3,9 @@ Multi-modality utils
 """
 
 import hashlib
+import pickle
 from abc import abstractmethod
-from typing import Callable, Dict, List, Optional, Tuple
+from typing import Any, Callable, Dict, List, Literal, Optional, Tuple
 
 import numpy as np
 import torch
@@ -27,6 +28,128 @@ from sglang.utils import logger
 # propagation that can cause some log messages (like 'server is fired up') to not appear
 # in the console when multimodal support is enabled.
 
+# TODO(mick): nccl
+# cuda_ipc: for intranode tensor sharing
+TensorTransportMode = Literal["cuda_ipc", "auto", "default"]
+
+
+class TransportProxyTensor(torch.Tensor):
+    """
+    A convenient torch.Tensor subclass that carries extra metadata and supports
+    efficient inter-process communications
+    """
+
+    @staticmethod
+    def __new__(
+        cls,
+        data: torch.Tensor,
+        name: Optional[str] = None,
+        fields: Optional[Dict[str, Any]] = None,
+        transport_mode: TensorTransportMode = "default",
+        *args,
+        **kwargs,
+    ):
+
+        if not isinstance(data, torch.Tensor):
+            raise TypeError(
+                f"Input 'data' must be a torch.Tensor, but got {type(data)}"
+            )
+
+        instance = data.as_subclass(cls)
+
+        instance._metadata = {
+            "name": name,
+            "fields": fields if fields is not None else {},
+            "transport_mode": transport_mode,
+        }
+
+        return instance
+
+    def __getstate__(self):
+        """
+        Called during pickling. Implements the serialization logic.
+        """
+        # acquire all serialize metadata from _metadata
+        state = {
+            "metadata": self._metadata,
+            "tensor_data": None,
+            "ipc_extra": None,
+        }
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and self.is_cuda:
+            try:
+                storage = self.untyped_storage()
+                handle = storage._share_cuda_()
+
+                state["ipc_extra"] = {
+                    "handle": handle,
+                    "shape": self.shape,
+                    "dtype": self.dtype,
+                    "stride": self.stride(),
+                    "device_index": self.device.index,
+                }
+                state["tensor_data"] = None
+            except Exception as e:
+                # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
+                state["metadata"]["transport_mode"] = "default"
+                state["tensor_data"] = self.as_subclass(torch.Tensor)
+        else:
+            state["metadata"]["transport_mode"] = "default"
+            state["tensor_data"] = self.as_subclass(torch.Tensor)
+
+        return state
+
+    def __setstate__(self, state: Dict[str, Any]):
+        """
+        Called during unpickling. Implements the deserialization logic.
+        """
+        self._metadata = state["metadata"]
+
+        transport_mode = self._metadata.get("transport_mode", "default")
+
+        if transport_mode == "cuda_ipc" and state["ipc_extra"] is not None:
+            ipc_extra = state["ipc_extra"]
+            handle, shape, dtype, stride, source_device_index = (
+                ipc_extra["handle"],
+                ipc_extra["shape"],
+                ipc_extra["dtype"],
+                ipc_extra["stride"],
+                ipc_extra["device_index"],
+            )
+
+            try:
+                target_device = torch.device(f"cuda:{source_device_index}")
+                with torch.cuda.device(target_device):
+                    storage = torch.UntypedStorage._new_shared_cuda(*handle)
+                    reconstructed_tensor = torch.empty(
+                        0, dtype=dtype, device=target_device
+                    ).set_(storage, storage_offset=0, size=shape, stride=stride)
+                self.set_(reconstructed_tensor)
+            except Exception as e:
+                print(f"Error: Failed to deserialize from CUDA IPC handle ({e}).")
+                raise e
+
+        elif state["tensor_data"] is not None:
+            self.set_(state["tensor_data"])
+        else:
+            raise pickle.UnpicklingError(
+                "Invalid state for TransportProxyTensor: no tensor data found."
+            )
+
+    @property
+    def name(self) -> Optional[str]:
+        return self._metadata.get("name")
+
+    @property
+    def fields(self) -> Dict[str, Any]:
+        return self._metadata.get("fields", {})
+
+    @property
+    def transport_mode(self) -> TensorTransportMode:
+        return self._metadata.get("transport_mode", "default")
+
 
 class MultiModalityDataPaddingPattern:
     """
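Note: this new TransportProxyTensor class is what the transport_mode plumbing later in this diff (multimodal_processor.py, tokenizer_manager.py) feeds into. Below is a minimal usage sketch, assuming only the code added above; the tensor values and metadata are illustrative, and the pickle round trip shown exercises the default CPU path rather than cuda_ipc:

    import pickle

    import torch

    from sglang.srt.managers.mm_utils import TransportProxyTensor

    # Wrap a feature tensor with metadata; "default" keeps the plain pickle path,
    # "cuda_ipc" only takes effect for CUDA tensors shared within one node.
    proxy = TransportProxyTensor(
        data=torch.randn(2, 3),
        name="pixel_values",
        fields={"modality": "image"},
        transport_mode="default",
    )

    # The metadata survives pickling via the __getstate__/__setstate__ hooks above.
    restored = pickle.loads(pickle.dumps(proxy))
    assert restored.name == "pixel_values"
    assert restored.fields["modality"] == "image"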
@@ -85,8 +208,8 @@ class MultiModalityDataPaddingPatternTokenPairs(MultiModalityDataPaddingPattern)
                 "No data_token_pairs provided, RadixAttention might be influenced."
             )
             return input_ids
-        start_token_ids = [s for s, _e in data_token_pairs]
-        end_tokens_ids = [e for _s, e in data_token_pairs]
+        start_token_ids = {s for s, _e in data_token_pairs}
+        end_tokens_ids = {e for _s, e in data_token_pairs}
 
         padded_ids = []
         last_idx = 0
@@ -135,7 +258,7 @@ class MultiModalityDataPaddingPatternMultimodalTokens(MultiModalityDataPaddingPa
         if not input_ids or not mm_inputs.mm_items:
             return input_ids
 
-        input_ids_tensor = torch.tensor(input_ids)
+        input_ids_tensor = torch.as_tensor(input_ids)
 
         # Create mapping of token_ids to pad_values for each modality
         token_to_pad_mapping = {}
@@ -211,7 +334,7 @@ def get_embedding_chunk(
             end_index += extend_end_index - start + 1
         elif extend_end_index > end:
             end_index += end - start + 1
-    # some models embedding is 3-dim, reshape it to 2-dim
+    # some models' embedding is 3-dim, reshape it to 2-dim
     embedding = embedding.reshape(-1, embedding.shape[-1])
     embedding_chunk = embedding[start_index:end_index]
    return embedding_chunk, start_index, end_index
@@ -428,7 +551,7 @@ def embed_mm_inputs(
             modality_id = modality.name.lower()
             embedder = getattr(multimodal_model, f"get_{modality_id}_feature", None)
             if len(items) != 0 and embedder is not None:
-                placeholder_tensor = torch.tensor(
+                placeholder_tensor = torch.as_tensor(
                     [item.pad_value for item in items],
                     device=input_ids.device,
                 )
@@ -473,11 +596,9 @@ def embed_mm_inputs(
     for embedding, mask in zip(embeddings, masks):
         if embedding is None or mask is None:
             continue
-        mask = mask.expand_as(inputs_embeds).to(inputs_embeds.device)
-        inputs_embeds = inputs_embeds.masked_scatter(
-            mask,
-            embedding.to(inputs_embeds.device, inputs_embeds.dtype),
-        )
+        # in-place update
+        indices = torch.where(mask.squeeze(dim=-1))[0]
+        inputs_embeds[indices] = embedding.to(inputs_embeds.device, inputs_embeds.dtype)
     return inputs_embeds
 
 
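Note: the last hunk above swaps masked_scatter for an in-place row assignment. A small, self-contained illustration of the equivalence with made-up shapes (not taken from the release):

    import torch

    inputs_embeds = torch.zeros(5, 4)
    mask = torch.tensor([False, True, True, False, True]).unsqueeze(-1)  # [seq_len, 1]
    embedding = torch.arange(12, dtype=torch.float32).reshape(3, 4)      # one row per True

    indices = torch.where(mask.squeeze(dim=-1))[0]   # -> tensor([1, 2, 4])
    inputs_embeds[indices] = embedding               # writes rows 1, 2 and 4 in place

    # The old path produced the same values via a broadcasted boolean mask.
    reference = torch.zeros(5, 4).masked_scatter(mask.expand(5, 4), embedding)
    assert torch.equal(inputs_embeds, reference)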
@@ -561,34 +682,36 @@ def get_multimodal_data_bounds(
         [bounds_count, 2]
     """
     # All the multimodal data in the batch should share the same special bound token ids.
-    start_tokens = [s for s, _e in token_pairs]
-    end_tokens = [e for _s, e in token_pairs]
+    start_tokens = {s for s, _e in token_pairs}
+    end_tokens = {e for _s, e in token_pairs}
 
     assert all(isinstance(t, int) for t in start_tokens)
     assert all(isinstance(t, int) for t in end_tokens)
 
     start_cond = torch.isin(
-        input_ids, torch.tensor(start_tokens, device=input_ids.device)
+        input_ids, torch.as_tensor(start_tokens, device=input_ids.device)
+    )
+    end_cond = torch.isin(
+        input_ids, torch.as_tensor(end_tokens, device=input_ids.device)
     )
-    end_cond = torch.isin(input_ids, torch.tensor(end_tokens, device=input_ids.device))
 
     (data_start_tokens,) = torch.where(start_cond)
     (data_end_tokens,) = torch.where(end_cond)
 
+    data_start_tokens_cpu = data_start_tokens.cpu().tolist()
+    data_end_tokens_cpu = data_end_tokens.cpu().tolist()
+
     # the im_start_id sometimes can be cached as prefix, but it is needed for the embedding of the multimodal data
-    if len(data_start_tokens) != len(data_end_tokens):
+    if len(data_start_tokens_cpu) != len(data_end_tokens_cpu):
         if (
-            len(data_start_tokens) + 1 == len(data_end_tokens)
-            and input_ids[0] in pad_values
-            and data_end_tokens[0] < data_start_tokens[0]
+            len(data_start_tokens_cpu) + 1 == len(data_end_tokens_cpu)
+            and input_ids[0].item() in pad_values
+            and data_end_tokens_cpu
+            and data_start_tokens_cpu
+            and data_end_tokens_cpu[0] < data_start_tokens_cpu[0]
         ):
-            data_start_tokens = torch.cat(
-                [
-                    torch.tensor([0], device=data_start_tokens.device),
-                    data_start_tokens,
-                ]
-            )
-    valid_mm_data_nums = min(len(data_start_tokens), len(data_end_tokens))
+            data_start_tokens_cpu.insert(0, 0)
+    valid_mm_data_nums = min(len(data_start_tokens_cpu), len(data_end_tokens_cpu))
 
     if valid_mm_data_nums == 0:
         return torch.zeros((0, 2), device=input_ids.device)
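Note: get_multimodal_data_bounds pairs start/end marker tokens and, after this change, does the pairing on plain CPU lists. A toy illustration of the pairing rule, using hypothetical marker token ids rather than any real model's:

    import torch

    IM_START, IM_END = 1001, 1002   # hypothetical marker ids
    input_ids = torch.tensor([5, IM_START, 7, 7, 7, IM_END, 9])

    start_pos = torch.where(torch.isin(input_ids, torch.as_tensor([IM_START])))[0].tolist()
    end_pos = torch.where(torch.isin(input_ids, torch.as_tensor([IM_END])))[0].tolist()

    # Keep only well-ordered pairs and trim the markers themselves, as the function does.
    bounds = [(s + 1, e - 1) for s, e in zip(start_pos, end_pos) if s < e]
    assert bounds == [(2, 4)]       # the image tokens sit at positions 2..4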
@@ -596,8 +719,8 @@ def get_multimodal_data_bounds(
     # Filter out pairs where start_token >= end_token
     valid_pairs = []
     for i in range(valid_mm_data_nums):
-        start_token = data_start_tokens[i]
-        end_token = data_end_tokens[i]
+        start_token = data_start_tokens_cpu[i]
+        end_token = data_end_tokens_cpu[i]
         if start_token < end_token:
             valid_pairs.append((start_token + 1, end_token - 1))
 
@@ -605,7 +728,7 @@
         return torch.zeros((0, 2), device=input_ids.device)
 
     # Convert valid pairs to tensor
-    valid_pairs_tensor = torch.tensor(valid_pairs, device=input_ids.device)
+    valid_pairs_tensor = torch.as_tensor(valid_pairs, device=input_ids.device)
     return valid_pairs_tensor
 
 
@@ -626,7 +749,7 @@ def tensor_hash(tensor_list) -> int:
     ]
     tensor = torch.concat(tensor_list)
     if tensor.is_cuda:
-        return gpu_tensor_hash(tensor)
+        return gpu_tensor_hash(tensor.cuda())
     tensor = tensor.detach().contiguous()
 
     if tensor.dtype == torch.bfloat16:
@@ -634,11 +757,7 @@
         tensor = tensor.float()
 
     assert isinstance(tensor, torch.Tensor)
-    if tensor.is_cuda:
-        # TODO: improve this
-        tensor_cpu = tensor.cpu()
-    else:
-        tensor_cpu = tensor
+    tensor_cpu = tensor.cpu()
 
     mv = memoryview(tensor_cpu.numpy())
     return data_hash(mv.tobytes())
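Note: the CPU branch of tensor_hash above boils down to hashing the tensor's raw bytes. A hedged stand-alone sketch of that pipeline; sglang's own data_hash helper is replaced here by a generic SHA-256 digest purely for illustration:

    import hashlib

    import torch

    def cpu_tensor_bytes_hash(t: torch.Tensor) -> int:
        # Same shape of pipeline as the CPU path above: detach -> contiguous ->
        # widen bfloat16 (numpy has no bfloat16) -> raw bytes -> integer digest.
        t = t.detach().contiguous()
        if t.dtype == torch.bfloat16:
            t = t.float()
        data = memoryview(t.cpu().numpy()).tobytes()
        return int.from_bytes(hashlib.sha256(data).digest()[:8], "big")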
--- a/sglang/srt/managers/multimodal_processor.py
+++ b/sglang/srt/managers/multimodal_processor.py
@@ -12,18 +12,6 @@ logger = logging.getLogger(__name__)
 PROCESSOR_MAPPING = {}
 
 
-class DummyMultimodalProcessor(BaseMultimodalProcessor):
-    def __init__(self):
-        pass
-
-    async def process_mm_data_async(self, *args, **kwargs):
-        return None
-
-
-def get_dummy_processor():
-    return DummyMultimodalProcessor()
-
-
 def import_processors():
     package_name = "sglang.srt.multimodal.processors"
     package = importlib.import_module(package_name)
@@ -49,11 +37,12 @@
 
 
 def get_mm_processor(
-    hf_config, server_args: ServerArgs, processor
+    hf_config, server_args: ServerArgs, processor, transport_mode
 ) -> BaseMultimodalProcessor:
     for model_cls, processor_cls in PROCESSOR_MAPPING.items():
         if model_cls.__name__ in hf_config.architectures:
-            return processor_cls(hf_config, server_args, processor)
+            return processor_cls(hf_config, server_args, processor, transport_mode)
+
     raise ValueError(
         f"No processor registered for architecture: {hf_config.architectures}.\n"
         f"Registered architectures: {[model_cls.__name__ for model_cls in PROCESSOR_MAPPING.keys()]}"
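Note: the hunk above changes the construction contract for registered processors (the per-model processors in the file list each change by +2 -2 accordingly). A hedged illustration of that contract only; the concrete sglang processor classes may handle the extra argument differently:

    # Whatever class is registered in PROCESSOR_MAPPING must now be constructible as
    #     processor_cls(hf_config, server_args, processor, transport_mode)
    class ExampleProcessor:  # hypothetical stand-in, not a real sglang class
        def __init__(self, hf_config, server_args, processor, transport_mode):
            self.hf_config = hf_config
            self.server_args = server_args
            self.processor = processor
            self.transport_mode = transport_mode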
--- a/sglang/srt/managers/schedule_batch.py
+++ b/sglang/srt/managers/schedule_batch.py
@@ -88,7 +88,8 @@ GLOBAL_SERVER_ARGS_KEYS = [
     "enable_deepep_moe",
     "deepep_mode",
     "enable_ep_moe",
-    "enable_flashinfer_moe",
+    "enable_flashinfer_cutlass_moe",
+    "enable_flashinfer_trtllm_moe",
     "enable_flashinfer_allreduce_fusion",
     "moe_dense_tp_size",
     "ep_dispatch_algorithm",
@@ -209,10 +210,11 @@ class MultimodalDataItem:
     hash: int = None
     pad_value: int = None
     offsets: Optional[list] = None
+
     # the raw features returned by processor, e.g. pixel_values or audio_features
     feature: Union[torch.Tensor, np.ndarray] = None
-
-    # the precomputed embeddings for the modality, e.g. image_emb for image, audio_emb for audio
+    # the precomputed embeddings, passed as final encoder embeddings
+    # One and only one of the feature and precomputed_embeddings will be empty
     precomputed_embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None
 
     # Model-specific data stored in a dictionary
@@ -1688,16 +1690,20 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         extend_prefix_lens = self.prefix_lens
         extend_logprob_start_lens = self.extend_logprob_start_lens
 
+        if self.forward_mode.is_decode_or_idle():
+            attention_backend_str = global_server_args_dict["decode_attention_backend"]
+        else:
+            attention_backend_str = global_server_args_dict["prefill_attention_backend"]
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "fa3"
+            attention_backend_str == "fa3"
             or (
                 global_server_args_dict["use_mla_backend"]
-                and global_server_args_dict["attention_backend"] == "flashinfer"
+                and attention_backend_str == "flashinfer"
             )
-            or global_server_args_dict["attention_backend"] == "flashmla"
-            or global_server_args_dict["attention_backend"] == "cutlass_mla"
-            or global_server_args_dict["attention_backend"] == "ascend"
+            or attention_backend_str == "flashmla"
+            or attention_backend_str == "cutlass_mla"
+            or attention_backend_str == "ascend"
             or global_server_args_dict["enable_two_batch_overlap"]
         ):
             seq_lens_cpu = (
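Note: the seq_lens_cpu condition above now keys off the backend chosen for the current forward mode instead of a single attention_backend value. A hedged restatement of the selection rule as a standalone helper, with a plain dict standing in for global_server_args_dict:

    def needs_seq_lens_cpu(args: dict, is_decode_or_idle: bool) -> bool:
        # Pick the backend string that will actually run this batch.
        backend = (
            args["decode_attention_backend"]
            if is_decode_or_idle
            else args["prefill_attention_backend"]
        )
        return (
            backend == "fa3"
            or (args["use_mla_backend"] and backend == "flashinfer")
            or backend in ("flashmla", "cutlass_mla", "ascend")
            or args["enable_two_batch_overlap"]
        )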
--- a/sglang/srt/managers/scheduler.py
+++ b/sglang/srt/managers/scheduler.py
@@ -458,7 +458,10 @@ class Scheduler(
         self.grammar_queue: List[Req] = []
         if not server_args.skip_tokenizer_init:
             self.grammar_backend = create_grammar_backend(
-                server_args, self.tokenizer, self.model_config.vocab_size
+                server_args,
+                self.tokenizer,
+                self.model_config.vocab_size,
+                self.model_config.hf_eos_token_id,
             )
         else:
             self.grammar_backend = None
@@ -2437,6 +2440,37 @@
                 req.grammar.cancel()
             req.set_finish_with_abort("Aborted by AbortReq.")
 
+        # Delete requests not in the waiting queue when PD disaggregation is enabled
+        if self.disaggregation_mode == DisaggregationMode.PREFILL:
+            # Abort requests that have not yet been bootstrapped
+            for i, req in enumerate(self.disagg_prefill_bootstrap_queue.queue):
+                logger.debug(f"Abort bootstrap queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+            # Abort in-flight requests
+            for i, req in enumerate(self.disagg_prefill_inflight_queue):
+                logger.debug(f"Abort inflight queue request. {req.rid=}")
+                if recv_req.abort_all or req.rid.startswith(recv_req.rid):
+                    if hasattr(req.disagg_kv_sender, "abort"):
+                        req.disagg_kv_sender.abort()
+
+        elif self.disaggregation_mode == DisaggregationMode.DECODE:
+            # Abort requests that have not yet finished preallocation
+            for i, decode_req in enumerate(self.disagg_decode_prealloc_queue.queue):
+                logger.debug(f"Abort prealloc queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
+            # Abort requests waiting for kvcache to release tree cache
+            for i, decode_req in enumerate(self.disagg_decode_transfer_queue.queue):
+                logger.debug(f"Abort transfer queue request. {decode_req.req.rid=}")
+                if recv_req.abort_all or decode_req.req.rid.startswith(recv_req.rid):
+                    if hasattr(decode_req.kv_receiver, "abort"):
+                        decode_req.kv_receiver.abort()
+
         # Delete requests in the running batch
         if self.cur_batch is self.running_batch or self.cur_batch is None:
             reqs = self.running_batch.reqs
--- a/sglang/srt/managers/tokenizer_manager.py
+++ b/sglang/srt/managers/tokenizer_manager.py
@@ -112,6 +112,7 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
     UpdateWeightsFromTensorReqOutput,
 )
+from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -166,6 +167,16 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
+def _determine_tensor_transport_mode(server_args: ServerArgs) -> TensorTransportMode:
+    is_cross_node = server_args.dist_init_addr
+
+    if is_cross_node:
+        # Fallback to default CPU transport for multi-node
+        return "default"
+    else:
+        return "cuda_ipc"
+
+
 class TokenizerManager:
     """TokenizerManager is a process that tokenizes the text."""
 
@@ -216,12 +227,13 @@ class TokenizerManager:
                 revision=server_args.revision,
                 use_fast=not server_args.disable_fast_image_processor,
             )
+            transport_mode = _determine_tensor_transport_mode(self.server_args)
 
             # We want to parallelize the image pre-processing so we create an executor for it
             # We create mm_processor for any skip_tokenizer_init to make sure we still encode
             # images even with skip_tokenizer_init=False.
             self.mm_processor = get_mm_processor(
-                self.model_config.hf_config, server_args, _processor
+                self.model_config.hf_config, server_args, _processor, transport_mode
             )
 
             if server_args.skip_tokenizer_init:
@@ -270,6 +282,11 @@
             None
         )
 
+        # Lock to serialize LoRA update operations.
+        # Please note that, unlike `model_update_lock`, this does not block inference, allowing
+        # LoRA updates and inference to overlap.
+        self.lora_update_lock = asyncio.Lock()
+
         # For pd disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
@@ -525,7 +542,8 @@
             mm_inputs = None
 
         if self.server_args.enable_lora and obj.lora_path:
-            # Replace the user-friendly LoRA names in `lora_path` with their corresponding unique LoRA IDs.
+            # Start tracking ongoing requests for LoRA adapters and replace the user-friendly LoRA names in
+            # `lora_path` with their corresponding unique LoRA IDs, as required for internal processing.
            obj.lora_path = await self.lora_registry.acquire(obj.lora_path)
 
         self._validate_one_request(obj, input_ids)
@@ -735,6 +753,10 @@
                     msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length, skip_names=skip_names)}, out={dataclass_to_string_truncated(out, max_length, skip_names=out_skip_names)}"
                     logger.info(msg)
 
+                # Mark ongoing LoRA request as finished.
+                if self.server_args.enable_lora and obj.lora_path:
+                    await self.lora_registry.release(obj.lora_path)
+
                 # Check if this was an abort/error created by scheduler
                 if isinstance(out["meta_info"].get("finish_reason"), dict):
                     finish_reason = out["meta_info"]["finish_reason"]
@@ -1041,16 +1063,18 @@
             obj.lora_path,
         )
 
-        async with self.model_update_lock.writer_lock:
+        async with self.lora_update_lock:
             # Generate new uniquely identifiable LoRARef object.
             new_adapter = LoRARef(
                 lora_name=obj.lora_name,
                 lora_path=obj.lora_path,
             )
 
-            # Register the new adapter in the registry.
+            # Trigger the actual loading operation at the backend processes.
             obj.lora_id = new_adapter.lora_id
             result = (await self.update_lora_adapter_communicator(obj))[0]
+
+            # Register the LoRA adapter only after loading is successful.
             if result.success:
                 await self.lora_registry.register(new_adapter)
 
@@ -1081,8 +1105,15 @@
             obj.lora_name,
         )
 
-        async with self.model_update_lock.writer_lock:
-            obj.lora_id = await self.lora_registry.unregister(obj.lora_name)
+        async with self.lora_update_lock:
+            # Unregister the LoRA adapter from the registry to stop new requests for this adapter
+            # from being started.
+            lora_id = await self.lora_registry.unregister(obj.lora_name)
+            obj.lora_id = lora_id
+
+            # Initiate the actual unloading operation at the backend processes only after all
+            # ongoing requests using this LoRA adapter are finished.
+            await self.lora_registry.wait_for_unload(lora_id)
             result = (await self.update_lora_adapter_communicator(obj))[0]
 
         return result
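Note: taken together with the lora_registry.py changes listed at the top, the TokenizerManager hunks above imply a ref-counted adapter lifecycle: acquire on request start, release on finish, and wait_for_unload draining in-flight users before the backend unload. A condensed sketch of that flow; run_inference is a hypothetical placeholder for the elided request handling:

    async def serve_one_request(tm, obj, input_ids):
        # Map friendly LoRA names to ids and start tracking the in-flight request.
        obj.lora_path = await tm.lora_registry.acquire(obj.lora_path)
        try:
            await run_inference(tm, obj, input_ids)   # hypothetical helper
        finally:
            # Mark the request as finished so pending unloads can proceed.
            await tm.lora_registry.release(obj.lora_path)

    async def unload_adapter(tm, obj):
        async with tm.lora_update_lock:               # does not block inference
            lora_id = await tm.lora_registry.unregister(obj.lora_name)
            obj.lora_id = lora_id
            await tm.lora_registry.wait_for_unload(lora_id)  # drain in-flight users first
            return (await tm.update_lora_adapter_communicator(obj))[0]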
--- a/sglang/srt/managers/tp_worker.py
+++ b/sglang/srt/managers/tp_worker.py
@@ -41,6 +41,7 @@ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj, set_random_seed
 
@@ -278,6 +279,8 @@ class TpModelWorker:
         return success, message
 
     def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
+
+        monkey_patch_torch_reductions()
         success, message = self.model_runner.update_weights_from_tensor(
             named_tensors=MultiprocessingSerializer.deserialize(
                 recv_req.serialized_named_tensors[self.tp_rank]
--- a/sglang/srt/mem_cache/hiradix_cache.py
+++ b/sglang/srt/mem_cache/hiradix_cache.py
@@ -365,10 +365,12 @@ class HiRadixCache(RadixCache):
         for _ in range(queue_size.item()):
             req_id = self.cache_controller.prefetch_revoke_queue.get()
             if req_id in self.ongoing_prefetch:
-                last_host_node, _, host_indices, _ = self.ongoing_prefetch[req_id]
+                last_host_node, _, _, _ = self.ongoing_prefetch[req_id]
                 last_host_node.release_host()
-                self.cache_controller.mem_pool_host.free(host_indices)
                 del self.ongoing_prefetch[req_id]
+            else:
+                # the revoked operation already got terminated
+                pass
 
     def check_backup_progress(self):
         queue_size = torch.tensor(
@@ -403,6 +405,7 @@
         last_host_node, token_ids, host_indices, operation = self.ongoing_prefetch[
             req_id
         ]
+
         completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
             operation
         )