sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py CHANGED
@@ -1,13 +1,6 @@
 from typing import Any, Callable, Dict, List, Optional
 
 import torch
-
-from sglang.srt.utils import is_cuda_available, set_weight_attrs
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import int8_scaled_mm
-
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
@@ -18,6 +11,11 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
+from sglang.srt.utils import is_cuda, set_weight_attrs
+
+_is_cuda = is_cuda()
+if _is_cuda:
+    from sgl_kernel import int8_scaled_mm
 
 
 class W8A8Int8Config(QuantizationConfig):
@@ -233,6 +231,7 @@ class W8A8Int8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -248,6 +247,7 @@
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
sglang/srt/layers/radix_attention.py CHANGED
@@ -13,6 +13,7 @@
 # ==============================================================================
 """Radix attention."""
 
+from enum import Enum
 from typing import Optional
 
 from torch import nn
@@ -22,6 +23,18 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
+class AttentionType(Enum):
+    """
+    Attention type.
+    Use string to be compatible with `torch.compile`.
+    """
+
+    # Decoder attention between previous layer Q/K/V
+    DECODER = "decoder"
+    # Encoder attention between previous layer Q/K/V
+    ENCODER_ONLY = "encoder_only"
+
+
 class RadixAttention(nn.Module):
     """
     The attention layer implementation.
@@ -39,6 +52,7 @@ class RadixAttention(nn.Module):
         sliding_window_size: int = -1,
         is_cross_attention: bool = False,
         quant_config: Optional[QuantizationConfig] = None,
+        attn_type=AttentionType.DECODER,
         prefix: str = "",
         use_irope: bool = False,
     ):
@@ -64,6 +78,7 @@ class RadixAttention(nn.Module):
         self.quant_method = quant_config.get_quant_method(self, prefix=prefix)
         if self.quant_method is not None:
             self.quant_method.create_weights(self)
+        self.attn_type = attn_type
 
     def forward(
         self,
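Note (not part of the diff): the new attn_type field gives attention backends a way to tell encoder-only layers (presumably relevant to the newly added sglang/srt/models/bert.py) apart from decoder layers. A minimal standalone sketch of dispatching on it; the needs_kv_cache helper and the caching rule it encodes are illustrative assumptions, not sglang code.

from enum import Enum


class AttentionType(Enum):
    DECODER = "decoder"
    ENCODER_ONLY = "encoder_only"


def needs_kv_cache(attn_type: AttentionType) -> bool:
    # Assumption for illustration: encoder-only layers attend only within the
    # current sequence, so nothing has to persist in the KV cache across steps.
    return attn_type is AttentionType.DECODER


print(needs_kv_cache(AttentionType.ENCODER_ONLY))  # False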
sglang/srt/layers/rotary_embedding.py CHANGED
@@ -8,13 +8,14 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda
 
-_is_cuda_available = is_cuda_available()
-if _is_cuda_available:
+_is_cuda = is_cuda()
+
+if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
 else:
-    from vllm import _custom_ops as ops
+    from vllm._custom_ops import rotary_embedding as vllm_rotary_embedding
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -81,7 +82,7 @@ class RotaryEmbedding(CustomOp):
 
         cache = self._compute_cos_sin_cache()
         # NOTE(ByronHsu): cache needs to be in FP32 for numerical stability
-        if not _is_cuda_available:
+        if not _is_cuda:
             cache = cache.to(dtype)
         self.cos_sin_cache: torch.Tensor
         self.register_buffer("cos_sin_cache", cache, persistent=False)
@@ -148,7 +149,7 @@
         key: torch.Tensor,
         offsets: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if _is_cuda_available and (self.head_size in [64, 128, 256, 512]):
+        if _is_cuda and (self.head_size in [64, 128, 256, 512]):
             apply_rope_with_cos_sin_cache_inplace(
                 positions=positions,
                 query=query,
@@ -159,7 +160,7 @@
             )
         else:
             self.cos_sin_cache = self.cos_sin_cache.to(query.device, dtype=query.dtype)
-            ops.rotary_embedding(
+            vllm_rotary_embedding(
                 positions,
                 query,
                 key,
@@ -651,7 +652,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
     def forward(self, *args, **kwargs):
        if torch.compiler.is_compiling():
            return self.forward_native(*args, **kwargs)
-       if _is_cuda_available:
+       if _is_cuda:
            return self.forward_cuda(*args, **kwargs)
        else:
            return self.forward_native(*args, **kwargs)
sglang/srt/layers/sampler.py CHANGED
@@ -10,9 +10,9 @@ from sglang.srt.layers.dp_attention import get_attention_tp_group
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda_available
+from sglang.srt.utils import crash_on_warnings, get_bool_env_var, is_cuda
 
-if is_cuda_available():
+if is_cuda():
     from sgl_kernel import (
         min_p_sampling_from_probs,
         top_k_renorm_prob,
@@ -93,28 +93,23 @@ class Sampler(nn.Module):
                 ).clamp(min=torch.finfo(probs.dtype).min)
 
                 max_top_k_round, batch_size = 32, probs.shape[0]
-                uniform_samples = torch.rand(
-                    (max_top_k_round, batch_size), device=probs.device
-                )
                 if sampling_info.need_min_p_sampling:
                     probs = top_k_renorm_prob(probs, sampling_info.top_ks)
                     probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                     batch_next_token_ids = min_p_sampling_from_probs(
-                        probs, uniform_samples, sampling_info.min_ps
+                        probs, sampling_info.min_ps
                     )
                 else:
-                    batch_next_token_ids, success = top_k_top_p_sampling_from_probs(
+                    # Check Nan will throw exception, only check when crash_on_warnings is True
+                    check_nan = self.use_nan_detection and crash_on_warnings()
+                    batch_next_token_ids = top_k_top_p_sampling_from_probs(
                         probs,
-                        uniform_samples,
                         sampling_info.top_ks,
                         sampling_info.top_ps,
                         filter_apply_order="joint",
+                        check_nan=check_nan,
                     )
 
-                    if self.use_nan_detection and not torch.all(success):
-                        logger.warning("Detected errors during sampling!")
-                        batch_next_token_ids = torch.zeros_like(batch_next_token_ids)
-
             elif global_server_args_dict["sampling_backend"] == "pytorch":
                 # A slower fallback implementation with torch native operations.
                 batch_next_token_ids = top_k_top_p_min_p_sampling_from_probs_torch(
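Note (not part of the diff): with the new sgl_kernel API, the kernel draws its own randomness and NaN checking is requested via check_nan instead of returning a success mask. The torch-native fallback mentioned in the last context line implements joint top-k/top-p filtering; below is a self-contained sketch of that standard algorithm for reference, not sglang's exact implementation (the function name is made up).

import torch


def top_k_top_p_sample(probs: torch.Tensor, top_k: int, top_p: float) -> torch.Tensor:
    # Sort probabilities in descending order per row.
    sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
    cumsum = sorted_probs.cumsum(dim=-1)
    # Drop tokens once the cumulative mass before them already exceeds top_p,
    # and everything beyond rank top_k.
    mask = (cumsum - sorted_probs) > top_p
    mask[..., top_k:] = True
    filtered = sorted_probs.masked_fill(mask, 0.0)
    filtered = filtered / filtered.sum(dim=-1, keepdim=True)
    choice = torch.multinomial(filtered, num_samples=1)
    return sorted_idx.gather(-1, choice).squeeze(-1)


probs = torch.softmax(torch.randn(2, 32), dim=-1)
print(top_k_top_p_sample(probs, top_k=5, top_p=0.9))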
sglang/srt/lora/backend/base_backend.py CHANGED
@@ -75,7 +75,7 @@ class BaseLoRABackend:
         qkv_lora_a: torch.Tensor,
         qkv_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for QKV Layer.
 
@@ -98,7 +98,7 @@ class BaseLoRABackend:
         gate_up_lora_a: torch.Tensor,
         gate_up_lora_b: Union[torch.Tensor, Tuple[torch.Tensor]],
         *args,
-        **kwargs
+        **kwargs,
     ) -> torch.Tensor:
         """Run the lora pass for gate_up_proj, usually attached to MergedColumnParallelLayer.
 
@@ -115,3 +115,19 @@
 
     def set_batch_info(self, batch_info: LoRABatchInfo):
         self.batch_info = batch_info
+
+
+def get_backend_from_name(name: str) -> BaseLoRABackend:
+    """
+    Get corresponding backend class from backend's name
+    """
+    if name == "triton":
+        from sglang.srt.lora.backend.triton_backend import TritonLoRABackend
+
+        return TritonLoRABackend
+    elif name == "flashinfer":
+        from sglang.srt.lora.backend.flashinfer_backend import FlashInferLoRABackend
+
+        return FlashInferLoRABackend
+    else:
+        raise ValueError(f"Invalid backend: {name}")
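Note (not part of the diff): get_backend_from_name previously lived in sglang/srt/lora/backend/__init__.py (removed in this release, see the file list) and now resolves LoRA backend classes from base_backend.py. A hedged usage sketch, assuming an sglang installation matching this diff:

from sglang.srt.lora.backend.base_backend import get_backend_from_name

backend_cls = get_backend_from_name("triton")  # -> TritonLoRABackend class
print(backend_cls.__name__)

try:
    get_backend_from_name("does-not-exist")
except ValueError as exc:
    print(exc)  # Invalid backend: does-not-exist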
sglang/srt/lora/backend/flashinfer_backend.py CHANGED
@@ -2,7 +2,7 @@ from typing import Tuple
 
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.utils import LoRABatchInfo
 from sglang.srt.utils import is_flashinfer_available
 
sglang/srt/lora/backend/triton_backend.py CHANGED
@@ -1,6 +1,6 @@
 import torch
 
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.triton_ops import (
     gate_up_lora_b_fwd,
     qkv_lora_b_fwd,
sglang/srt/lora/layers.py CHANGED
@@ -16,7 +16,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 
 
 class BaseLayerWithLoRA(nn.Module):
sglang/srt/lora/lora.py CHANGED
@@ -27,7 +27,7 @@ from torch import nn
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_loader.loader import DefaultModelLoader
 
sglang/srt/lora/lora_manager.py CHANGED
@@ -22,7 +22,7 @@ import torch
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
-from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
+from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_from_name
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
sglang/srt/managers/data_parallel_controller.py CHANGED
@@ -30,6 +30,7 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import bind_port, configure_logger, get_zmq_socket
 from sglang.utils import get_exception_traceback
 
@@ -174,6 +175,10 @@ class DataParallelController:
         if not server_args.enable_dp_attention:
             logger.info(f"Launch DP{dp_rank} starting at GPU #{base_gpu_id}.")
 
+        memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=server_args.enable_memory_saver
+        )
+
         # Launch tensor parallel scheduler processes
         scheduler_pipe_readers = []
         tp_size_per_node = server_args.tp_size // server_args.nnodes
@@ -208,7 +213,8 @@ class DataParallelController:
                 target=run_scheduler_process,
                 args=(server_args, rank_port_args, gpu_id, tp_rank, dp_rank, writer),
             )
-            proc.start()
+            with memory_saver_adapter.configure_subprocess():
+                proc.start()
             self.scheduler_procs.append(proc)
             scheduler_pipe_readers.append(reader)
 
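Note (not part of the diff): the scheduler subprocess is now started inside the adapter's configure_subprocess() context manager, so the child process is launched under the memory-saver configuration when server_args.enable_memory_saver is set. A hedged sketch of the same pattern, assuming an sglang installation matching this diff; scheduler_main is a stand-in for run_scheduler_process:

import multiprocessing as mp

from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter


def scheduler_main():
    # Stand-in for run_scheduler_process.
    print("scheduler process running")


if __name__ == "__main__":
    adapter = TorchMemorySaverAdapter.create(enable=True)
    proc = mp.Process(target=scheduler_main)
    # Assumption: configure_subprocess() sets up whatever state the child
    # process must inherit; see torch_memory_saver_adapter for details.
    with adapter.configure_subprocess():
        proc.start()
    proc.join()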
sglang/srt/managers/detokenizer_manager.py CHANGED
@@ -14,7 +14,6 @@
 """DetokenizerManager is a process that detokenizes the token ids."""
 
 import dataclasses
-import json
 import logging
 import os
 import signal
sglang/srt/managers/io_struct.py CHANGED
@@ -96,8 +96,8 @@ class GenerateReqInput:
     return_hidden_states: bool = False
 
     # For disaggregated inference
-    bootstrap_host: Optional[str] = None
-    bootstrap_room: Optional[int] = None
+    bootstrap_host: Optional[Union[List[str], str]] = None
+    bootstrap_room: Optional[Union[List[int], int]] = None
 
     def normalize_batch_and_arguments(self):
         """
@@ -397,6 +397,12 @@ class GenerateReqInput:
                 else None
             ),
             return_hidden_states=self.return_hidden_states,
+            bootstrap_host=(
+                self.bootstrap_host[i] if self.bootstrap_host is not None else None
+            ),
+            bootstrap_room=(
+                self.bootstrap_room[i] if self.bootstrap_room is not None else None
+            ),
         )
 
 
@@ -665,10 +671,15 @@ class BatchEmbeddingOut:
 
 
 @dataclass
-class FlushCacheReq:
+class FlushCacheReqInput:
     pass
 
 
+@dataclass
+class FlushCacheReqOutput:
+    success: bool
+
+
 @dataclass
 class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
@@ -834,6 +845,7 @@ class ProfileReq:
     activities: Optional[List[str]] = None
     with_stack: Optional[bool] = None
    record_shapes: Optional[bool] = None
+    profile_id: Optional[str] = None
 
 
 @dataclass
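Note (not part of the diff): with the widened bootstrap_host/bootstrap_room types above, a batched request can carry one bootstrap target per sub-request, and the @@ -397 hunk forwards element i when the batch is split. A toy sketch of that per-index split; MiniReq is a simplified stand-in for GenerateReqInput, not sglang code:

from dataclasses import dataclass
from typing import List, Optional, Union


@dataclass
class MiniReq:
    # Simplified stand-in holding only the disaggregation fields.
    bootstrap_host: Optional[Union[List[str], str]] = None
    bootstrap_room: Optional[Union[List[int], int]] = None

    def split(self, i: int) -> "MiniReq":
        return MiniReq(
            bootstrap_host=(
                self.bootstrap_host[i] if self.bootstrap_host is not None else None
            ),
            bootstrap_room=(
                self.bootstrap_room[i] if self.bootstrap_room is not None else None
            ),
        )


batch = MiniReq(bootstrap_host=["prefill-0", "prefill-1"], bootstrap_room=[1001, 1002])
print(batch.split(0), batch.split(1))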
sglang/srt/managers/mm_utils.py CHANGED
@@ -1,7 +1,8 @@
 """
-Multi-modality utils
+Multi-modality utils
 """
 
+import logging
 from abc import abstractmethod
 from typing import Callable, List, Optional, Tuple
 
@@ -12,11 +13,11 @@ from sglang.srt.managers.schedule_batch import (
     MultimodalDataItem,
     MultimodalInputs,
     global_server_args_dict,
-    logger,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import print_warning_once
-from sglang.utils import logger
+
+logger = logging.getLogger(__name__)
 
 
 class MultiModalityDataPaddingPattern:
sglang/srt/managers/multimodal_processor.py CHANGED
@@ -5,8 +5,6 @@ import logging
 import pkgutil
 from functools import lru_cache
 
-from transformers import PROCESSOR_MAPPING
-
 from sglang.srt.managers.multimodal_processors.base_processor import (
     BaseMultimodalProcessor,
 )
sglang/srt/managers/multimodal_processors/base_processor.py CHANGED
@@ -8,8 +8,6 @@ from typing import List, Optional
 
 import numpy as np
 import PIL
-from decord import VideoReader, cpu
-from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality
@@ -102,6 +100,9 @@ class BaseMultimodalProcessor(ABC):
         """
         estimate the total frame count from all visual input
         """
+        # Lazy import because decord is not available on some arm platforms.
+        from decord import VideoReader, cpu
+
         # Before processing inputs
         estimated_frames_list = []
         for image in image_data:
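Note (not part of the diff): the lazy import above defers the decord dependency to call time, so the module can still be imported on platforms without a decord wheel. The same pattern in self-contained form; count_video_frames is a hypothetical helper, while VideoReader/cpu are decord's documented API:

def count_video_frames(path: str) -> int:
    # Import decord only when a video is actually processed.
    from decord import VideoReader, cpu

    return len(VideoReader(path, ctx=cpu(0)))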
sglang/srt/managers/schedule_batch.py CHANGED
@@ -67,7 +67,6 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
@@ -77,12 +76,11 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
-    "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
+    "moe_dense_tp_size": ServerArgs.moe_dense_tp_size,
     "chunked_prefill_size": ServerArgs.chunked_prefill_size,
     "n_share_experts_fusion": ServerArgs.n_share_experts_fusion,
-    "disable_shared_experts_fusion": ServerArgs.disable_shared_experts_fusion,
     "disable_chunked_prefix_cache": ServerArgs.disable_chunked_prefix_cache,
 }
 
@@ -541,6 +539,11 @@ class Req:
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None
 
+        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
+        # This is because kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -573,6 +576,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
            )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
     def adjust_max_prefix_ids(self):
@@ -1481,7 +1492,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
                 global_server_args_dict["use_mla_backend"]
                 and global_server_args_dict["attention_backend"] == "flashinfer"
             )
-            or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "flashmla"
            or global_server_args_dict["attention_backend"] == "fa3"
        ):
            seq_lens_cpu = self.seq_lens.cpu()