sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
Files changed (57)
  1. sglang/bench_serving.py +23 -3
  2. sglang/srt/configs/deepseekvl2.py +10 -1
  3. sglang/srt/configs/model_config.py +5 -16
  4. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  5. sglang/srt/distributed/parallel_state.py +32 -5
  6. sglang/srt/entrypoints/http_server.py +7 -1
  7. sglang/srt/entrypoints/verl_engine.py +2 -0
  8. sglang/srt/function_call_parser.py +0 -1
  9. sglang/srt/layers/attention/flashattention_backend.py +218 -79
  10. sglang/srt/layers/dp_attention.py +12 -1
  11. sglang/srt/layers/moe/topk.py +30 -3
  12. sglang/srt/layers/quantization/__init__.py +134 -165
  13. sglang/srt/layers/quantization/awq.py +200 -0
  14. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  15. sglang/srt/layers/quantization/gptq.py +30 -40
  16. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  17. sglang/srt/layers/rotary_embedding.py +12 -0
  18. sglang/srt/lora/backend/base_backend.py +4 -4
  19. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  20. sglang/srt/lora/backend/triton_backend.py +5 -8
  21. sglang/srt/lora/layers.py +19 -33
  22. sglang/srt/lora/lora_manager.py +20 -7
  23. sglang/srt/lora/mem_pool.py +12 -6
  24. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  25. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  26. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  27. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  28. sglang/srt/lora/utils.py +6 -0
  29. sglang/srt/managers/io_struct.py +4 -2
  30. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  31. sglang/srt/managers/schedule_batch.py +1 -0
  32. sglang/srt/managers/scheduler.py +25 -19
  33. sglang/srt/managers/tokenizer_manager.py +0 -1
  34. sglang/srt/managers/tp_worker.py +3 -0
  35. sglang/srt/model_executor/cuda_graph_runner.py +9 -8
  36. sglang/srt/model_executor/model_runner.py +9 -6
  37. sglang/srt/model_loader/loader.py +11 -1
  38. sglang/srt/model_loader/weight_utils.py +6 -3
  39. sglang/srt/models/clip.py +563 -0
  40. sglang/srt/models/deepseek_janus_pro.py +2 -2
  41. sglang/srt/models/deepseek_v2.py +151 -26
  42. sglang/srt/models/gemma3_causal.py +12 -2
  43. sglang/srt/models/gemma3_mm.py +6 -0
  44. sglang/srt/openai_api/adapter.py +88 -87
  45. sglang/srt/openai_api/protocol.py +10 -5
  46. sglang/srt/patch_torch.py +71 -0
  47. sglang/srt/server_args.py +21 -11
  48. sglang/srt/speculative/eagle_worker.py +1 -1
  49. sglang/srt/utils.py +33 -0
  50. sglang/test/runners.py +27 -2
  51. sglang/test/test_utils.py +1 -1
  52. sglang/version.py +1 -1
  53. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
  54. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
  55. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
  56. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
  57. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/triton_ops/sgemm_lora_a.py CHANGED
@@ -12,8 +12,9 @@ def _sgemm_lora_a_kernel(
     weights,
     output,
     # Matrix dimensions
-    N,  # r
+    N,  # stack_num * r
     K,  # input_dim
+    stack_num,
     # Strides
     x_stride_0,
     x_stride_1,
@@ -22,10 +22,11 @@ def _sgemm_lora_a_kernel(
     w_stride_2,
     output_stride_0,
     output_stride_1,
-    # Information on sequence lengths and weight id
+    # Information on sequence lengths,ranks and weight id
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
@@ -43,6 +45,9 @@ def _sgemm_lora_a_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    # Adjust N (stack_num * max_rank) according to the specific LoRA adapter
+    N = tl.minimum(N, rank * stack_num)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -91,11 +96,15 @@ def _sgemm_lora_a_kernel(
 
 
 def sgemm_lora_a_fwd(
-    x: torch.Tensor, weights: torch.Tensor, batch_info: LoRABatchInfo
+    x: torch.Tensor,
+    weights: torch.Tensor,
+    batch_info: LoRABatchInfo,
+    stack_num: int = 1,
 ) -> torch.Tensor:
     # x: (s, input_dim)
-    # weights: (num_lora, r, input_dim)
-    # output: (s, r)
+    # weights: (num_lora, stack_num * r, input_dim)
+    # output: (s, stack_num * r)
+    # stack_num: run_qkv_lora: 3, run_gate_up_lora: 2
     # when called by run_qkv_lora, the weights.shape[-2] will be 3 * r
     # input_dim is much larger than r
 
@@ -126,6 +135,7 @@ def sgemm_lora_a_fwd(
         output,
         R,
         K,
+        stack_num,
         x.stride(0),
         x.stride(1),
         weights.stride(0),
@@ -136,6 +146,7 @@ def sgemm_lora_a_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_R,
         BLOCK_K,
sglang/srt/lora/triton_ops/sgemm_lora_b.py CHANGED
@@ -26,13 +26,14 @@ def _sgemm_lora_b_kernel(
     seg_lens,
     seg_indptr,
     weight_indices,
+    lora_ranks,
     # Meta parameters
     BLOCK_S: tl.constexpr,
     BLOCK_N: tl.constexpr,
     BLOCK_K: tl.constexpr,
     # For fused output scaling and adding
     fuse_scaling_add,
-    scaling,
+    scalings,
 ):
     # x: (s, K), s is the sum of sequence lengths
     # weights: (num_lora, N, K)
@@ -45,6 +46,10 @@ def _sgemm_lora_b_kernel(
     seg_len = tl.load(seg_lens + batch_id)
     w_index = tl.load(weight_indices + batch_id)
     seg_start = tl.load(seg_indptr + batch_id)
+    rank = tl.load(lora_ranks + w_index)
+    scaling = tl.load(scalings + w_index)
+    # Adjust K (rank) according to the specific LoRA adapter
+    K = tl.minimum(K, rank)
 
     # The tile in output matrix will have (pid_s, pid_n) as id
     num_pid_n = tl.cdiv(N, BLOCK_N)
@@ -100,12 +105,11 @@ def sgemm_lora_b_fwd(
     weights: torch.Tensor,
     batch_info: LoRABatchInfo,
     base_output: torch.Tensor = None,
-    scaling: float = 1.0,
 ) -> torch.Tensor:
-    # x: (s, r)
-    # weights: (num_lora, output_dim, r)
+    # x: (s, max_r)
+    # weights: (num_lora, output_dim, max_r)
     # output: (s, output_dim)
-    # output_dim is much larger than r
+    # output_dim is much larger than max_r
 
     assert x.is_contiguous()
     assert weights.is_contiguous()
@@ -150,10 +154,11 @@ def sgemm_lora_b_fwd(
         batch_info.seg_lens,
         batch_info.seg_indptr,
         batch_info.weight_indices,
+        batch_info.lora_ranks,
         BLOCK_S,
         BLOCK_N,
         BLOCK_R,
         fuse_scaling_add,
-        scaling,
+        batch_info.scalings,
     )
     return output
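Taken together, the A and B kernels now read the effective rank and the scaling of each request's adapter from `batch_info` instead of assuming one shared rank and a scalar `scaling` argument, which is what lets adapters of different ranks share the same padded weight buffers. A dense PyTorch sketch of the per-segment computation the B kernel performs (illustration only; this helper is not part of the package, and argument names simply mirror the kernel parameters above):

```python
import torch

def sgemm_lora_b_reference(x, weights, seg_indptr, weight_indices,
                           lora_ranks, scalings, base_output=None):
    # x: (s, max_r), weights: (num_lora, output_dim, max_r)
    s = x.shape[0]
    output_dim = weights.shape[1]
    out = base_output.clone() if base_output is not None else x.new_zeros(s, output_dim)
    for i, w in enumerate(weight_indices.tolist()):
        start, end = int(seg_indptr[i]), int(seg_indptr[i + 1])
        r = int(lora_ranks[w])      # effective rank of this sequence's adapter
        scale = float(scalings[w])  # this adapter's LoRA scaling factor
        # Columns beyond r are padding in both x and weights and are ignored.
        out[start:end] += scale * (x[start:end, :r] @ weights[w, :, :r].T)
    return out
```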
sglang/srt/lora/utils.py CHANGED
@@ -25,6 +25,12 @@ class LoRABatchInfo:
     # The index of lora adapter used by each sequence, in shape (bs,)
     weight_indices: torch.Tensor
 
+    # ranks of each lora adapter, in shape (lora_num,)
+    lora_ranks: torch.Tensor
+
+    # scaling of each lora adapter, in shape (lora_num,)
+    scalings: torch.Tensor
+
 
 class LoRAType(Enum):
     LORA_A = 0
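Both new fields are indexed by adapter slot, the same index space as `weight_indices`. A rough sketch of how they could be populated (the `adapters` list and its keys are hypothetical; the real population logic lives in `lora_manager.py` and `mem_pool.py`, which also changed in this release but are not shown here):

```python
import torch

# Hypothetical adapter configs occupying pool slots 0..lora_num-1.
adapters = [{"r": 8, "lora_alpha": 16}, {"r": 64, "lora_alpha": 64}]

lora_ranks = torch.tensor([a["r"] for a in adapters], dtype=torch.int64)
scalings = torch.tensor(
    [a["lora_alpha"] / a["r"] for a in adapters], dtype=torch.float32
)
# LoRABatchInfo.lora_ranks / .scalings would then hold these (lora_num,) tensors,
# moved to the device the Triton kernels run on.
```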
sglang/srt/managers/io_struct.py CHANGED
@@ -20,7 +20,7 @@ import copy
 import uuid
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Literal, Optional, Union
 
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -650,7 +650,7 @@ class ProfileReqInput:
     # If it is set, profiling is automatically stopped after this step, and
     # the caller doesn't need to run stop_profile.
     num_steps: Optional[int] = None
-    activities: Optional[List[str]] = None
+    activities: Optional[List[Literal["CPU", "GPU", "MEM", "CUDA_PROFILER"]]] = None
 
 
 class ProfileReqType(Enum):
@@ -675,6 +675,8 @@ class ProfileReq:
     output_dir: Optional[str] = None
     num_steps: Optional[int] = None
     activities: Optional[List[str]] = None
+    with_stack: Optional[bool] = None
+    record_shapes: Optional[bool] = None
 
 
 @dataclass
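For context, profiling can be driven over HTTP via the server's `/start_profile` endpoint, whose request body is a `ProfileReqInput`. A hedged client sketch (whether `with_stack`/`record_shapes` are also accepted on the HTTP request depends on the `http_server.py` change in this release, which is not shown here):

```python
import requests

base_url = "http://127.0.0.1:30000"  # a running sglang server

resp = requests.post(
    f"{base_url}/start_profile",
    json={
        "output_dir": "/tmp/sglang_traces",
        "num_steps": 10,
        # Must now be a subset of the allowed literals.
        "activities": ["CPU", "GPU"],
    },
)
print(resp.status_code, resp.text)
```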
sglang/srt/managers/multimodal_processors/clip.py ADDED
@@ -0,0 +1,63 @@
+import asyncio
+from typing import List, Union
+
+from sglang.srt.managers.multimodal_processors.base_processor import (
+    BaseMultimodalProcessor,
+    get_global_processor,
+)
+from sglang.srt.models.clip import CLIPModel
+from sglang.srt.utils import load_image
+
+
+class ClipImageProcessor(BaseMultimodalProcessor):
+    models = [CLIPModel]
+
+    def __init__(self, hf_config, server_args, _processor):
+        super().__init__(hf_config, server_args, _processor)
+
+    @staticmethod
+    def _process_single_image_task(images, input_text):
+        # input_ids', 'attention_mask', 'pixel_values', 'aspect_ratio_ids', 'aspect_ratio_mask', 'cross_attention_mask'
+        return get_global_processor()(
+            images=images, text=input_text, return_tensors="pt"
+        )
+
+    async def _process_single_image(self, images, input_text):
+        if self.executor is not None:
+            loop = asyncio.get_event_loop()
+            image_inputs = await loop.run_in_executor(
+                self.executor,
+                ClipImageProcessor._process_single_image_task,
+                images,
+                input_text,
+            )
+        else:
+            image_inputs = self._processor(
+                images=images, text=[input_text], return_tensors="pt"
+            )
+
+        return image_inputs
+
+    async def process_mm_data_async(
+        self, image_data: List[Union[str, bytes]], input_text, *args, **kwargs
+    ):
+        if not image_data:
+            return None
+
+        if isinstance(input_text, list):
+            assert len(input_text) and isinstance(input_text[0], int)
+            input_text = self._processor.tokenizer.decode(input_text)
+
+        if not isinstance(image_data, list):
+            image_data = [image_data]
+
+        if len(image_data) > 0:
+            images = [load_image(image)[0] for image in image_data]
+        else:
+            images = load_image(image_data[0])[0]
+
+        image_inputs = await self._process_single_image(images, input_text)
+        image_inputs["data_hashes"] = [hash(str(image_data))]
+        image_inputs["input_ids"] = image_inputs["input_ids"].tolist()[0]
+
+        return image_inputs
sglang/srt/managers/schedule_batch.py CHANGED
@@ -1376,6 +1376,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if (
             global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
+            or global_server_args_dict["attention_backend"] == "fa3"
         ):
             decode_seq_lens = self.seq_lens.cpu()
         else:
sglang/srt/managers/scheduler.py CHANGED
@@ -379,7 +379,7 @@ class Scheduler(
         # Init profiler
         self.torch_profiler = None
         self.torch_profiler_output_dir: Optional[str] = None
-        self.torch_profiler_activities: Optional[List[str]] = None
+        self.profiler_activities: Optional[List[str]] = None
         self.profiler_target_forward_ct: Optional[int] = None
 
         # Init metrics stats
@@ -1186,7 +1186,7 @@ class Scheduler(
             ret = None
 
         # Handle DP attention
-        if self.server_args.enable_dp_attention:
+        if self.server_args.enable_dp_attention or self.server_args.enable_sp_layernorm:
             ret, _ = self.prepare_dp_attn_batch(ret)
 
         return ret
@@ -1703,18 +1703,12 @@ class Scheduler(
     def save_remote_model(self, params):
         url = params["url"]
 
-        if isinstance(self.tp_worker, TpModelWorkerClient):
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker
 
         worker.model_runner.save_remote_model(url)
 
     def save_sharded_model(self, params):
-        if isinstance(self.tp_worker, TpModelWorkerClient):
-            worker = self.tp_worker.worker
-        else:
-            worker = self.tp_worker
+        worker = self.tp_worker.worker
 
         worker.model_runner.save_sharded_model(
             path=params["path"],
@@ -1813,7 +1807,11 @@ class Scheduler(
     def profile(self, recv_req: ProfileReq):
         if recv_req.type == ProfileReqType.START_PROFILE:
             return self.start_profile(
-                recv_req.output_dir, recv_req.num_steps, recv_req.activities
+                recv_req.output_dir,
+                recv_req.num_steps,
+                recv_req.activities,
+                recv_req.with_stack,
+                recv_req.record_shapes,
             )
         else:
             return self.stop_profile()
@@ -1823,8 +1821,10 @@ class Scheduler(
         output_dir: Optional[str],
         num_steps: Optional[int],
         activities: Optional[List[str]],
+        with_stack: Optional[bool],
+        record_shapes: Optional[bool],
     ) -> None:
-        if self.torch_profiler_activities:
+        if self.profiler_activities:
             return ProfileReqOutput(
                 success=False,
                 message="Profiling is already in progress. Call /stop_profile first.",
@@ -1836,7 +1836,7 @@ class Scheduler(
             activities = ["CPU", "GPU"]
 
         self.torch_profiler_output_dir = output_dir
-        self.torch_profiler_activities = activities
+        self.profiler_activities = activities
         logger.info(
             "Profiling starts. Traces will be saved to: %s",
             self.torch_profiler_output_dir,
@@ -1853,13 +1853,17 @@ class Scheduler(
         if torchprof_activities:
             self.torch_profiler = torch.profiler.profile(
                 activities=torchprof_activities,
-                with_stack=True,
+                with_stack=with_stack if with_stack is not None else True,
+                record_shapes=record_shapes if record_shapes is not None else False,
             )
            self.torch_profiler.start()
 
         if "MEM" in activities:
             torch.cuda.memory._record_memory_history(max_entries=100000)
 
+        if "CUDA_PROFILER" in activities:
+            torch.cuda.cudart().cudaProfilerStart()
+
         if num_steps:
             self.profiler_target_forward_ct = self.forward_ct + num_steps
             # The caller will be notified when reaching profiler_target_forward_ct
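The two new knobs map directly onto `torch.profiler.profile`. As a standalone illustration (outside sglang) of the scheduler's defaults, stacks on and shapes off unless the request overrides them:

```python
import torch
from torch.profiler import ProfilerActivity, profile

activities = [ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(ProfilerActivity.CUDA)

with profile(
    activities=activities,
    with_stack=True,      # per-op Python/C++ stacks; larger traces
    record_shapes=False,  # per-op input shapes; enable when chasing shape issues
) as prof:
    x = torch.randn(1024, 1024)
    (x @ x).sum()

prof.export_chrome_trace("/tmp/trace.json")
```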
@@ -1868,7 +1872,7 @@ class Scheduler(
         return ProfileReqOutput(success=True, message="Succeeded")
 
     def stop_profile(self) -> None:
-        if self.torch_profiler_activities is None:
+        if self.profiler_activities is None:
             return
 
         logger.info("Stop profiling...")
@@ -1881,21 +1885,24 @@ class Scheduler(
                 )
             )
 
-        if "MEM" in self.torch_profiler_activities:
+        if "MEM" in self.profiler_activities:
             memory_profile_path = os.path.join(
-                self.torch_profiler_trace_dir,
+                self.torch_profiler_output_dir,
                 str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
             )
             torch.cuda.memory._dump_snapshot(memory_profile_path)
             torch.cuda.memory._record_memory_history(enabled=None)
 
+        if "CUDA_PROFILER" in self.profiler_activities:
+            torch.cuda.cudart().cudaProfilerStop()
+
         logger.info(
             "Profiling done. Traces are saved to: %s",
             self.torch_profiler_output_dir,
         )
         self.torch_profiler = None
         self.torch_profiler_output_dir = None
-        self.torch_profiler_activities = None
+        self.profiler_activities = None
 
         if self.profiler_target_forward_ct:
             self.send_to_tokenizer.send_pyobj(
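The `CUDA_PROFILER` activity only brackets the profiled window with `cudaProfilerStart`/`cudaProfilerStop`; it produces no trace by itself. The expected pairing (an assumption based on the usual workflow, not stated in this diff) is an external profiler such as Nsight Systems launched with `nsys profile --capture-range=cudaProfilerApi`, so only the bracketed steps are recorded:

```python
import torch

def run_profiled_steps():
    # Placeholder for whatever work runs between start_profile and stop_profile.
    x = torch.randn(1024, 1024, device="cuda")
    for _ in range(10):
        x = x @ x

# The same bracket pattern the scheduler now emits for the "CUDA_PROFILER" activity.
torch.cuda.cudart().cudaProfilerStart()
run_profiled_steps()
torch.cuda.cudart().cudaProfilerStop()
```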
@@ -1963,7 +1970,6 @@ def run_scheduler_process(
     dp_rank: Optional[int],
     pipe_writer,
 ):
-
     # Generate the prefix
     if dp_rank is None:
         prefix = f" TP{tp_rank}"
sglang/srt/managers/tokenizer_manager.py CHANGED
@@ -261,7 +261,6 @@ class TokenizerManager:
         self.start_profile_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
-        self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
         self.get_internal_state_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
sglang/srt/managers/tp_worker.py CHANGED
@@ -132,6 +132,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        # A reference make this class has the same member as TpModelWorkerClient
+        self.worker = self
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -124,8 +124,8 @@ def get_batch_sizes_to_capture(model_runner: ModelRunner):
         # capture less.
         capture_bs = list(range(1, 9)) + list(range(9, 33, 2)) + [64, 96, 128, 160]
 
-        if _is_hip:
-            capture_bs += [i * 8 for i in range(21, 33)]
+        if _is_hip:
+            capture_bs += [i * 8 for i in range(21, 33)]
 
     if max(capture_bs) > model_runner.req_to_token_pool.size:
         # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
@@ -174,6 +174,7 @@ class CudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.enable_dp_attention = model_runner.server_args.enable_dp_attention
+        self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm
         self.speculative_algorithm = model_runner.server_args.speculative_algorithm
         self.tp_size = model_runner.server_args.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -245,8 +246,8 @@ class CudaGraphRunner:
             )
         else:
             self.encoder_lens = None
-
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
+            # TODO(ch-wan): SP layernorm should use a different logic to manage gathered_buffer
             self.gathered_buffer = torch.zeros(
                 (
                     self.max_bs * self.dp_size * self.num_tokens_per_bs,
@@ -288,7 +289,7 @@ class CudaGraphRunner:
         self.model_runner.token_to_kv_pool.capture_mode = False
 
     def can_run(self, forward_batch: ForwardBatch):
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
            total_global_tokens = sum(forward_batch.global_num_tokens_cpu)
 
             is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
@@ -369,7 +370,7 @@ class CudaGraphRunner:
             encoder_lens = None
         mrope_positions = self.mrope_positions[:, :bs]
 
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
                     [
@@ -471,7 +472,7 @@ class CudaGraphRunner:
         raw_num_token = raw_bs * self.num_tokens_per_bs
 
         # Pad
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             index = bisect.bisect_left(
                 self.capture_bs, sum(forward_batch.global_num_tokens_cpu)
             )
@@ -497,7 +498,7 @@ class CudaGraphRunner:
             self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
         if forward_batch.mrope_positions is not None:
             self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
-        if self.enable_dp_attention:
+        if self.enable_dp_attention or self.enable_sp_layernorm:
             self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu)
 
         if hasattr(forward_batch.spec_info, "hidden_states"):
sglang/srt/model_executor/model_runner.py CHANGED
@@ -64,6 +64,7 @@ from sglang.srt.model_loader.loader import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -229,6 +230,10 @@ class ModelRunner:
             elif server_args.enable_flashmla:
                 logger.info("MLA optimization is turned on. Use flashmla decode.")
                 server_args.attention_backend = "flashmla"
+            elif server_args.attention_backend == "fa3":
+                logger.info(
+                    f"MLA optimization is turned on. Use flash attention 3 backend."
+                )
             else:
                 logger.info("MLA optimization is turned on. Use triton backend.")
                 server_args.attention_backend = "triton"
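With this branch, MLA models can be served with the FlashAttention 3 backend selected explicitly. A hedged launch sketch using the offline `Engine` API (the model path and keyword set are illustrative; FA3 remains beta, with its limitations logged in `model_runner.py` below):

```python
import sglang as sgl

# Illustrative only: request the FlashAttention 3 backend for an MLA model.
# Keyword arguments mirror ServerArgs fields; adjust model and parallelism to your setup.
llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",
    attention_backend="fa3",
    trust_remote_code=True,
)
print(llm.generate("The capital of France is", {"max_new_tokens": 8}))
```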
@@ -280,9 +285,6 @@ class ModelRunner:
 
         if server_args.enable_deepep_moe:
             logger.info("DeepEP is turned on.")
-            assert (
-                server_args.enable_dp_attention == True
-            ), "Currently DeepEP is bind to Attention DP. Set '--enable-dp-attention --enable-deepep-moe'"
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
@@ -881,7 +883,7 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             logger.warning(
-                "FlashAttention v3 Backend is in Beta. Multimodal, Page > 1, FP8, MLA and Speculative Decoding are not supported."
+                "FlashAttention v3 Backend is in Beta. Multimodal, FP8, and Speculative Decoding are not supported."
             )
             from sglang.srt.layers.attention.flashattention_backend import (
                 FlashAttentionBackend,
@@ -1082,8 +1084,9 @@ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tenso
 
 def _unwrap_tensor(tensor, tp_rank):
     if isinstance(tensor, LocalSerializedTensor):
-        return tensor.get(tp_rank)
-    return tensor
+        monkey_patch_torch_reductions()
+        tensor = tensor.get(tp_rank)
+    return tensor.to(torch.cuda.current_device())
 
 
 @dataclass
sglang/srt/model_loader/loader.py CHANGED
@@ -14,7 +14,6 @@ from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
-import gguf
 import huggingface_hub
 import numpy as np
 import torch
@@ -1155,6 +1154,17 @@ class GGUFModelLoader(BaseModelLoader):
         See "Standardized tensor names" in
         https://github.com/ggerganov/ggml/blob/master/docs/gguf.md for details.
         """
+
+        # only load the gguf module when needed
+        try:
+            import gguf
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install gguf via `pip install gguf` to use gguf quantizer."
+            ) from err
+
         config = model_config.hf_config
         model_type = config.model_type
         # hack: ggufs have a different name than transformers
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -22,7 +22,6 @@ from typing import (
 )
 
 import filelock
-import gguf
 import huggingface_hub.constants
 import numpy as np
 import safetensors.torch
@@ -93,7 +92,7 @@ def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
 ) -> None:
-    loaded = torch.load(pt_filename, map_location="cpu")
+    loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
     shared = _shared_pointers(loaded)
@@ -381,7 +380,7 @@ def np_cache_weights_iterator(
             disable=not enable_tqdm,
             bar_format=_BAR_FORMAT,
         ):
-            state = torch.load(bin_file, map_location="cpu")
+            state = torch.load(bin_file, map_location="cpu", weights_only=True)
             for name, param in state.items():
                 param_path = os.path.join(np_folder, name)
                 with open(param_path, "wb") as f:
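`weights_only=True` restricts `torch.load`'s unpickler to tensors and other allow-listed types, so loading a pickled `.bin` checkpoint can no longer execute arbitrary code during deserialization. A standalone illustration (the file path is hypothetical):

```python
import torch

# Safe load: only tensors and primitive containers are deserialized.
state = torch.load("pytorch_model.bin", map_location="cpu", weights_only=True)

# A checkpoint that pickles arbitrary Python objects would now raise an error
# instead of silently running their reduction hooks during load.
for name, param in state.items():
    print(name, tuple(param.shape))
```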
@@ -464,6 +463,8 @@ def pt_weights_iterator(
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
     expected_gguf_keys = set(gguf_to_hf_name_map.keys())
     exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
@@ -479,6 +480,8 @@ def gguf_quant_weights_iterator(
     them to torch tensors
     """
 
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
 
     for tensor in reader.tensors: