sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py ADDED
@@ -0,0 +1,27 @@
+ from typing import Optional, Tuple
+
+ import torch
+
+
+ def normalize_e4m3fn_to_e4m3fnuz(
+     weight: torch.Tensor,
+     weight_scale: torch.Tensor,
+     input_scale: Optional[torch.Tensor] = None,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+     assert weight.dtype == torch.float8_e4m3fn
+     # The bits pattern 10000000(-128) represents zero in e4m3fn
+     # but NaN in e4m3fnuz. So here we set it to 0.
+     # https://onnx.ai/onnx/technical/float8.html
+     weight_as_int8 = weight.view(torch.int8)
+     ROCM_FP8_NAN_AS_INT = -128
+     weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+     weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+     # For the same bits representation, e4m3fnuz value is half of
+     # the e4m3fn value, so we should double the scaling factor to
+     # get the same dequantized value.
+     # https://onnx.ai/onnx/technical/float8.html
+     weight_scale = weight_scale * 2.0
+     if input_scale is not None:
+         input_scale = input_scale * 2.0
+     return weight, weight_scale, input_scale
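This new helper rewrites FP8 weights from the e4m3fn encoding to the e4m3fnuz encoding used on ROCm: the -128 bit pattern (zero in e4m3fn, NaN in e4m3fnuz) is cleared, and the scales are doubled because the same bit pattern represents half the value in e4m3fnuz. A minimal usage sketch, assuming PyTorch 2.1+ float8 dtypes and that the function lives in sglang/srt/layers/quantization/fp8_utils.py as the file list above suggests:

    import torch
    from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz

    # Illustrative fp8 weight and per-tensor scale; real values come from a checkpoint.
    weight = torch.randn(128, 256).to(torch.float8_e4m3fn)
    weight_scale = torch.tensor(0.05)

    weight_fnuz, scale_fnuz, _ = normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale)
    assert weight_fnuz.dtype == torch.float8_e4m3fnuz
    assert torch.isclose(scale_fnuz, weight_scale * 2.0)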
sglang/srt/layers/radix_attention.py CHANGED
@@ -48,11 +48,13 @@ class RadixAttention(nn.Module):
          self.sliding_window_size = sliding_window_size or -1
          self.is_cross_attention = is_cross_attention

-     def forward(self, q, k, v, forward_batch: ForwardBatch):
+     def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
          if k is not None:
              # For cross-layer sharing, kv can be None
              assert v is not None
              k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
              v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-         return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
+         return forward_batch.attn_backend.forward(
+             q, k, v, self, forward_batch, save_kv_cache
+         )
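RadixAttention.forward now threads a save_kv_cache flag (default True) through to the attention backend, so existing call sites keep writing new k/v into the pool, while callers whose tokens are already cached can skip the write. A hedged sketch of the calling convention; the backend below is a toy stand-in, not the real sglang backend API:

    import torch

    class ToyAttentionBackend:
        # Illustrative stand-in for an attention backend that honors save_kv_cache.
        def __init__(self):
            self.kv_writes = 0

        def forward(self, q, k, v, layer, forward_batch, save_kv_cache=True):
            if k is not None and save_kv_cache:
                self.kv_writes += 1  # a real backend would write k/v into the KV pool here
            return q  # a real backend would run the attention kernel here

    backend = ToyAttentionBackend()
    q = torch.randn(4, 8)
    backend.forward(q, q, q, layer=None, forward_batch=None)                       # cache written
    backend.forward(q, q, q, layer=None, forward_batch=None, save_kv_cache=False)  # write skipped
    assert backend.kv_writes == 1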
sglang/srt/layers/sampler.py CHANGED
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
      probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
      probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
      sampled_index = torch.multinomial(probs_sort, num_samples=1)
+     # int32 range is enough to represent the token ids
+     probs_idx = probs_idx.to(torch.int32)
      batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
      return batch_next_token_ids
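The added cast narrows probs_idx to int32 before the gather: token ids fit comfortably in int32, and torch.gather only constrains the dtype of the index tensor (sampled_index), not of the source being gathered from. A small illustrative sketch of the pattern:

    import torch

    probs = torch.tensor([[0.1, 0.7, 0.2]])
    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
    sampled_index = torch.multinomial(probs_sort, num_samples=1)  # int64 indices

    # Narrow the token-id table; the gather output inherits the int32 dtype.
    probs_idx = probs_idx.to(torch.int32)
    next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
    assert next_token_ids.dtype == torch.int32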
sglang/srt/layers/torchao_utils.py CHANGED
@@ -2,23 +2,24 @@
  Common utilities for torchao.
  """

- from typing import Dict, Set
-
  import torch


- def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-     """Quantize a Tensor with torchao quantization specified by torchao_config
+ def apply_torchao_config_to_model(
+     model: torch.nn.Module, torchao_config: str, filter_fn=None
+ ):
+     """Quantize a modelwith torchao quantization specified by torchao_config

      Args:
-        `param`: weight parameter of the linear module
-        `torchao_config`: type of quantization and their arguments we want to use to
-         quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+        `model`: a model to be quantized based on torchao_config
+        `torchao_config` (str): type of quantization and their arguments we want to use to
+         quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
          128
      """
      # Lazy import to suppress some warnings
      from torchao.quantization import (
          float8_dynamic_activation_float8_weight,
+         float8_weight_only,
          int4_weight_only,
          int8_dynamic_activation_int8_weight,
          int8_weight_only,
@@ -26,12 +27,17 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
      )
      from torchao.quantization.observer import PerRow, PerTensor

-     dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-     dummy_linear.weight = param
-     if "int8wo" in torchao_config:
-         quantize_(dummy_linear, int8_weight_only())
+     if filter_fn is None:
+
+         def filter_fn(module, fqn):
+             return "proj" in fqn
+
+     if torchao_config == "" or torchao_config is None:
+         return model
+     elif "int8wo" in torchao_config:
+         quantize_(model, int8_weight_only(), filter_fn=filter_fn)
      elif "int8dq" in torchao_config:
-         quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+         quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
      elif "int4wo" in torchao_config:
          group_size = int(torchao_config.split("-")[-1])
          assert group_size in [
@@ -40,13 +46,11 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
              128,
              256,
          ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-         quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+         quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
      elif "fp8wo" in torchao_config:
-         from torchao.quantization import float8_weight_only
-
          # this requires newer hardware
          # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-         quantize_(dummy_linear, float8_weight_only())
+         quantize_(model, float8_weight_only(), filter_fn=filter_fn)
      elif "fp8dq" in torchao_config:
          granularity = torchao_config.split("-")[-1]
          GRANULARITY_MAP = {
@@ -57,39 +61,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
              granularity in GRANULARITY_MAP
          ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
          quantize_(
-             dummy_linear,
+             model,
              float8_dynamic_activation_float8_weight(
                  granularity=GRANULARITY_MAP[granularity]
              ),
+             filter_fn=filter_fn,
          )
      else:
          raise ValueError(f"Unexpected config: {torchao_config}")

-     return dummy_linear.weight
-
-
- def apply_torchao_config_(
-     self: torch.nn.Module,
-     params_dict: Dict[str, torch.Tensor],
-     param_suffixes: Set[str],
- ) -> None:
-     """A util function used for quantizing the weight parameters after they are loaded if
-     self.torchao_config is specified
-
-     Args:
-         `self`: the model we want to quantize
-         `params_dict`: dictionary mapping from param_name to the parameter Tensor
-         `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-     Returns:
-         None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-     """
-     if self.torchao_config:
-         for param_suffix in param_suffixes:
-             for name in params_dict:
-                 param = params_dict[name]
-                 if param_suffix in name and param.ndim == 2:
-                     params_dict[name] = torchao_quantize_param_data(
-                         param, self.torchao_config
-                     )
-         self.load_state_dict(params_dict, assign=True)
+     return model
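The refactor replaces per-parameter quantization through a throwaway nn.Linear with a single whole-model quantize_ call, applied to modules whose fully qualified name contains "proj" unless a custom filter_fn is passed. A hedged usage sketch, assuming torchao is installed and that the import path matches this diff; the toy module below is illustrative:

    import torch
    from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model

    class ToyBlock(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.up_proj = torch.nn.Linear(64, 256, bias=False)  # matched by the default filter_fn
            self.norm = torch.nn.LayerNorm(64)                   # left in high precision

    model = ToyBlock()
    # int8 weight-only quantization for every module whose name contains "proj".
    apply_torchao_config_to_model(model, "int8wo")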
sglang/srt/layers/vocab_parallel_embedding.py CHANGED
@@ -222,6 +222,7 @@ class VocabParallelEmbedding(torch.nn.Module):
          enable_tp: bool = True,
      ):
          super().__init__()
+         self.quant_config = quant_config

          self.enable_tp = enable_tp
          if self.enable_tp:
sglang/srt/lora/lora.py CHANGED
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
- from vllm.model_executor.model_loader.loader import DefaultModelLoader

  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+ from sglang.srt.model_loader.loader import DefaultModelLoader


  class BaseLayerWithLoRA(nn.Module):
sglang/srt/managers/io_struct.py CHANGED
@@ -352,7 +352,7 @@ class FlushCacheReq:


  @dataclass
- class UpdateWeightReqInput:
+ class UpdateWeightFromDiskReqInput:
      # The model path with the new weights
      model_path: str
      # The format to load the weights
@@ -360,11 +360,57 @@ class UpdateWeightReqInput:


  @dataclass
- class UpdateWeightReqOutput:
+ class UpdateWeightFromDiskReqOutput:
      success: bool
      message: str


+ @dataclass
+ class UpdateWeightsFromDistributedReqInput:
+     name: str
+     dtype: str
+     shape: List[int]
+
+
+ @dataclass
+ class UpdateWeightsFromDistributedReqOutput:
+     success: bool
+     message: str
+
+
+ @dataclass
+ class InitWeightsUpdateGroupReqInput:
+     # The master address
+     master_address: str
+     # The master port
+     master_port: int
+     # The rank offset
+     rank_offset: int
+     # The world size
+     world_size: int
+     # The group name
+     group_name: str = "weight_update_group"
+     # The backend
+     backend: str = "nccl"
+
+
+ @dataclass
+ class InitWeightsUpdateGroupReqOutput:
+     success: bool
+     message: str
+
+
+ @dataclass
+ class GetWeightsByNameReqInput:
+     name: str
+     truncate_size: int = 100
+
+
+ @dataclass
+ class GetWeightsByNameReqOutput:
+     parameter: list
+
+
  @dataclass
  class AbortReq:
      # The request id
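These dataclasses extend the weight-update protocol beyond disk reloads: a trainer process can first ask the server to join a process group, then broadcast tensors by name, and read parameters back for verification. A hedged sketch of how a client might construct the new request objects (field values are illustrative; the RPC plumbing that carries them is not shown in this hunk):

    from sglang.srt.managers.io_struct import (
        GetWeightsByNameReqInput,
        InitWeightsUpdateGroupReqInput,
        UpdateWeightsFromDistributedReqInput,
    )

    # 1. Ask the inference server to join an NCCL group rooted at the trainer.
    init_req = InitWeightsUpdateGroupReqInput(
        master_address="10.0.0.1",   # illustrative address
        master_port=29500,
        rank_offset=1,               # inference ranks start after the trainer rank
        world_size=2,
        group_name="weight_update_group",
        backend="nccl",
    )

    # 2. Announce one tensor to be broadcast over that group.
    update_req = UpdateWeightsFromDistributedReqInput(
        name="model.layers.0.self_attn.q_proj.weight",  # illustrative parameter name
        dtype="bfloat16",
        shape=[4096, 4096],
    )

    # 3. Read back a truncated view of a named parameter for verification.
    get_req = GetWeightsByNameReqInput(name="lm_head.weight", truncate_size=8)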
sglang/srt/managers/schedule_batch.py CHANGED
@@ -58,6 +58,7 @@ global_server_args_dict = {
      "torchao_config": ServerArgs.torchao_config,
      "enable_nan_detection": ServerArgs.enable_nan_detection,
      "enable_dp_attention": ServerArgs.enable_dp_attention,
+     "enable_ep_moe": ServerArgs.enable_ep_moe,
  }


@@ -743,20 +744,24 @@ class ScheduleBatch:
          extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
              self.device, non_blocking=True
          )
-         write_req_to_token_pool_triton[(bs,)](
-             self.req_to_token_pool.req_to_token,
-             self.req_pool_indices,
-             pre_lens,
-             self.seq_lens,
-             extend_lens,
-             self.out_cache_loc,
-             self.req_to_token_pool.req_to_token.shape[1],
-         )
-         # The triton kernel is equivalent to the following python code.
-         # self.req_to_token_pool.write(
-         #     (req.req_pool_idx, slice(pre_len, seq_len)),
-         #     out_cache_loc[pt : pt + req.extend_input_len],
-         # )
+         if global_server_args_dict["attention_backend"] != "torch_native":
+             write_req_to_token_pool_triton[(bs,)](
+                 self.req_to_token_pool.req_to_token,
+                 self.req_pool_indices,
+                 pre_lens,
+                 self.seq_lens,
+                 extend_lens,
+                 self.out_cache_loc,
+                 self.req_to_token_pool.req_to_token.shape[1],
+             )
+         else:
+             pt = 0
+             for i in range(bs):
+                 self.req_to_token_pool.write(
+                     (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
+                     self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                 )
+                 pt += self.extend_lens[i]
          # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

          if self.model_config.is_encoder_decoder:
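With the torch_native attention backend there is no Triton kernel, so the token-pool write falls back to a per-request Python loop: for request i, the newly allocated slots in out_cache_loc are written into row req_pool_indices[i] at positions [pre_len, seq_len). A self-contained sketch of that indexing pattern on a toy tensor (names mirror the diff, values are made up):

    import torch

    req_to_token = torch.zeros(4, 16, dtype=torch.int64)  # toy request-to-token pool

    req_pool_indices = [2, 0]               # pool rows owned by the two requests
    pre_lens = [3, 0]                       # tokens already cached per request
    seq_lens = [7, 5]                       # total tokens after this extend step
    extend_lens = [4, 5]                    # seq_len - pre_len
    out_cache_loc = torch.arange(100, 109)  # newly allocated KV slots, concatenated

    pt = 0
    for i in range(len(req_pool_indices)):
        row = req_pool_indices[i]
        req_to_token[row, pre_lens[i]:seq_lens[i]] = out_cache_loc[pt:pt + extend_lens[i]]
        pt += extend_lens[i]

    assert req_to_token[2, 3:7].tolist() == [100, 101, 102, 103]
    assert req_to_token[0, 0:5].tolist() == [104, 105, 106, 107, 108]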
sglang/srt/managers/schedule_policy.py CHANGED
@@ -142,7 +142,7 @@ class PrefillAdder:

          self.req_states = None
          self.can_run_list = []
-         self.new_inflight_req = None
+         self.new_being_chunked_req = None
          self.log_hit_tokens = 0
          self.log_input_tokens = 0

@@ -182,7 +182,7 @@ class PrefillAdder:
          self.log_hit_tokens += prefix_len
          self.log_input_tokens += extend_input_len

-     def add_inflight_req(self, req: Req):
+     def add_being_chunked_req(self, req: Req):
          truncated = req.extend_input_len > self.rem_chunk_tokens
          req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
          req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -269,10 +269,13 @@ class PrefillAdder:
          else:
              # Chunked prefill
              trunc_len = self.rem_chunk_tokens
+             if trunc_len == 0:
+                 return AddReqResult.OTHER
+
              req.extend_input_len = trunc_len
              req.fill_ids = req.fill_ids[:trunc_len]
              self.can_run_list.append(req)
-             self.new_inflight_req = req
+             self.new_being_chunked_req = req
              self._prefill_one_req(0, trunc_len, 0)

          return self.budget_state()
@@ -326,7 +329,7 @@ class PrefillAdder:
          req.extend_input_len = trunc_len
          req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
          self.can_run_list.append(req)
-         self.new_inflight_req = req
+         self.new_being_chunked_req = req
          self.tree_cache.inc_lock_ref(req.last_node)
          self._prefill_one_req(prefix_len, trunc_len, 0)