sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py
ADDED
@@ -0,0 +1,27 @@
+from typing import Optional, Tuple
+
+import torch
+
+
+def normalize_e4m3fn_to_e4m3fnuz(
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    input_scale: Optional[torch.Tensor] = None,
+) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+    assert weight.dtype == torch.float8_e4m3fn
+    # The bits pattern 10000000(-128) represents zero in e4m3fn
+    # but NaN in e4m3fnuz. So here we set it to 0.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_as_int8 = weight.view(torch.int8)
+    ROCM_FP8_NAN_AS_INT = -128
+    weight_as_int8[weight_as_int8 == ROCM_FP8_NAN_AS_INT] = 0
+    weight = weight_as_int8.view(torch.float8_e4m3fnuz)
+
+    # For the same bits representation, e4m3fnuz value is half of
+    # the e4m3fn value, so we should double the scaling factor to
+    # get the same dequantized value.
+    # https://onnx.ai/onnx/technical/float8.html
+    weight_scale = weight_scale * 2.0
+    if input_scale is not None:
+        input_scale = input_scale * 2.0
+    return weight, weight_scale, input_scale
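The new helper re-encodes OCP float8_e4m3fn weights into the float8_e4m3fnuz layout used on ROCm, doubling the scales so dequantized values are unchanged. A minimal sketch of how it might be called (illustrative shapes and scale; assumes a PyTorch build that exposes both fp8 dtypes):

```python
import torch

from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz

# Illustrative fp8 weight and per-tensor scale, as loaded from an fp8 checkpoint.
weight = torch.randn(128, 256).to(torch.float8_e4m3fn)
weight_scale = torch.tensor(0.05)

# Re-encode the bits as e4m3fnuz and double the scale so dequantization matches.
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(weight, weight_scale)

assert weight.dtype == torch.float8_e4m3fnuz
assert torch.isclose(weight_scale, torch.tensor(0.10))
```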
sglang/srt/layers/radix_attention.py
CHANGED
@@ -48,11 +48,13 @@ class RadixAttention(nn.Module):
         self.sliding_window_size = sliding_window_size or -1
         self.is_cross_attention = is_cross_attention

-    def forward(self, q, k, v, forward_batch: ForwardBatch):
+    def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
         if k is not None:
             # For cross-layer sharing, kv can be None
             assert v is not None
             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

-        return forward_batch.attn_backend.forward(q, k, v, self, forward_batch)
+        return forward_batch.attn_backend.forward(
+            q, k, v, self, forward_batch, save_kv_cache
+        )
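RadixAttention.forward gains a save_kv_cache flag that is passed straight through to the attention backend, letting a layer run attention without writing new KV-cache entries (for example when it reuses another layer's cache). A sketch of the backend-side signature this implies; the class below is an illustrative stub, not one of the real backends:

```python
import torch


class ToyAttentionBackend:
    """Illustrative stand-in for an sglang attention backend."""

    def forward(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        if save_kv_cache and k is not None:
            # A real backend would write k/v into the paged KV pool here.
            pass
        # A real backend would run extend/decode attention kernels; we just
        # return a placeholder with one output row per query token.
        return torch.zeros_like(q)
```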
sglang/srt/layers/sampler.py
CHANGED
@@ -111,5 +111,7 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     sampled_index = torch.multinomial(probs_sort, num_samples=1)
+    # int32 range is enough to represent the token ids
+    probs_idx = probs_idx.to(torch.int32)
     batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
     return batch_next_token_ids
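The extra cast keeps the gathered token ids in int32, which is plenty for any vocabulary and avoids carrying int64 indices further down the pipeline. The same pattern in isolation (toy shapes, outside the sampler):

```python
import torch

# Toy sorted probabilities and their original vocab indices, e.g. from torch.sort.
probs = torch.softmax(torch.randn(2, 8), dim=-1)
probs_sort, probs_idx = probs.sort(dim=-1, descending=True)

sampled_index = torch.multinomial(probs_sort, num_samples=1)

# int32 range is enough to represent the token ids.
probs_idx = probs_idx.to(torch.int32)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
assert batch_next_token_ids.dtype == torch.int32
```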
sglang/srt/layers/torchao_utils.py
CHANGED
@@ -2,23 +2,24 @@
 Common utilities for torchao.
 """

-from typing import Dict, Set
-
 import torch


-def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
-    """Quantize a Tensor with torchao quantization specified by torchao_config
+def apply_torchao_config_to_model(
+    model: torch.nn.Module, torchao_config: str, filter_fn=None
+):
+    """Quantize a modelwith torchao quantization specified by torchao_config

     Args:
-       `param`: weight parameter of the linear module
-       `torchao_config`: type of quantization and their arguments we want to use to
-       quantize the Tensor, e.g. int4wo-128 means int4 weight only quantization with group_size
+       `model`: a model to be quantized based on torchao_config
+       `torchao_config` (str): type of quantization and their arguments we want to use to
+       quantize the model, e.g. int4wo-128 means int4 weight only quantization with group_size
        128
     """
     # Lazy import to suppress some warnings
     from torchao.quantization import (
         float8_dynamic_activation_float8_weight,
+        float8_weight_only,
         int4_weight_only,
         int8_dynamic_activation_int8_weight,
         int8_weight_only,
@@ -26,12 +27,17 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
     )
     from torchao.quantization.observer import PerRow, PerTensor

-    dummy_linear = torch.nn.Linear(param.shape[1], param.shape[0], bias=False)
-    dummy_linear.weight = param
-    if "int8wo" in torchao_config:
-        quantize_(dummy_linear, int8_weight_only())
+    if filter_fn is None:
+
+        def filter_fn(module, fqn):
+            return "proj" in fqn
+
+    if torchao_config == "" or torchao_config is None:
+        return model
+    elif "int8wo" in torchao_config:
+        quantize_(model, int8_weight_only(), filter_fn=filter_fn)
     elif "int8dq" in torchao_config:
-        quantize_(dummy_linear, int8_dynamic_activation_int8_weight())
+        quantize_(model, int8_dynamic_activation_int8_weight(), filter_fn=filter_fn)
     elif "int4wo" in torchao_config:
         group_size = int(torchao_config.split("-")[-1])
         assert group_size in [
@@ -40,13 +46,11 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             128,
             256,
         ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
-        quantize_(dummy_linear, int4_weight_only(group_size=group_size))
+        quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
     elif "fp8wo" in torchao_config:
-        from torchao.quantization import float8_weight_only
-
         # this requires newer hardware
         # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
-        quantize_(dummy_linear, float8_weight_only())
+        quantize_(model, float8_weight_only(), filter_fn=filter_fn)
     elif "fp8dq" in torchao_config:
         granularity = torchao_config.split("-")[-1]
         GRANULARITY_MAP = {
@@ -57,39 +61,13 @@ def torchao_quantize_param_data(param: torch.Tensor, torchao_config: str):
             granularity in GRANULARITY_MAP
         ), f"Supported granularity are: {GRANULARITY_MAP.keys()}, got {granularity}"
         quantize_(
-            dummy_linear,
+            model,
             float8_dynamic_activation_float8_weight(
                 granularity=GRANULARITY_MAP[granularity]
             ),
+            filter_fn=filter_fn,
         )
     else:
         raise ValueError(f"Unexpected config: {torchao_config}")

-    return dummy_linear.weight
-
-
-def apply_torchao_config_(
-    self: torch.nn.Module,
-    params_dict: Dict[str, torch.Tensor],
-    param_suffixes: Set[str],
-) -> None:
-    """A util function used for quantizing the weight parameters after they are loaded if
-    self.torchao_config is specified
-
-    Args:
-      `self`: the model we want to quantize
-      `params_dict`: dictionary mapping from param_name to the parameter Tensor
-      `param_suffixes`: a set of suffixes, we'll quantize the Tensor matching these suffixes
-
-    Returns:
-      None, the `params_dict` is modified inplace and the weights of `self` model are quantized
-    """
-    if self.torchao_config:
-        for param_suffix in param_suffixes:
-            for name in params_dict:
-                param = params_dict[name]
-                if param_suffix in name and param.ndim == 2:
-                    params_dict[name] = torchao_quantize_param_data(
-                        param, self.torchao_config
-                    )
-        self.load_state_dict(params_dict, assign=True)
+    return model
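Instead of quantizing individual weight tensors after load (the removed torchao_quantize_param_data / apply_torchao_config_ pair), quantization is now applied to the whole model through torchao's quantize_ with a module filter that defaults to names containing "proj". A hedged sketch of driving the new entry point on a toy module (assumes torchao is installed; the module and shapes are illustrative):

```python
import torch

from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model


class ToyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Names contain "proj", so the default filter_fn selects these modules.
        self.up_proj = torch.nn.Linear(64, 256, bias=False)
        self.down_proj = torch.nn.Linear(256, 64, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


model = ToyBlock()
# int8 weight-only quantization; other configs use the same string format,
# e.g. "int4wo-128" or "fp8dq-per_row".
model = apply_torchao_config_to_model(model, "int8wo")
out = model(torch.randn(1, 64))
```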
sglang/srt/lora/lora.py
CHANGED
@@ -31,7 +31,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from vllm.model_executor.model_loader.loader import DefaultModelLoader

 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -40,6 +39,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_loader.loader import DefaultModelLoader


 class BaseLayerWithLoRA(nn.Module):
sglang/srt/managers/io_struct.py
CHANGED
@@ -352,7 +352,7 @@ class FlushCacheReq:


 @dataclass
-class UpdateWeightReqInput:
+class UpdateWeightFromDiskReqInput:
     # The model path with the new weights
     model_path: str
     # The format to load the weights
@@ -360,11 +360,57 @@ class UpdateWeightReqInput:


 @dataclass
-class UpdateWeightReqOutput:
+class UpdateWeightFromDiskReqOutput:
     success: bool
     message: str


+@dataclass
+class UpdateWeightsFromDistributedReqInput:
+    name: str
+    dtype: str
+    shape: List[int]
+
+
+@dataclass
+class UpdateWeightsFromDistributedReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class InitWeightsUpdateGroupReqInput:
+    # The master address
+    master_address: str
+    # The master port
+    master_port: int
+    # The rank offset
+    rank_offset: int
+    # The world size
+    world_size: int
+    # The group name
+    group_name: str = "weight_update_group"
+    # The backend
+    backend: str = "nccl"
+
+
+@dataclass
+class InitWeightsUpdateGroupReqOutput:
+    success: bool
+    message: str
+
+
+@dataclass
+class GetWeightsByNameReqInput:
+    name: str
+    truncate_size: int = 100
+
+
+@dataclass
+class GetWeightsByNameReqOutput:
+    parameter: list
+
+
 @dataclass
 class AbortReq:
     # The request id
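Besides renaming the disk-reload structs to UpdateWeightFromDisk*, this adds request/response types for joining a distributed weight-update group, broadcasting weights from it, and reading parameters back by name. A small sketch of constructing the new inputs (all field values are illustrative):

```python
from sglang.srt.managers.io_struct import (
    GetWeightsByNameReqInput,
    InitWeightsUpdateGroupReqInput,
    UpdateWeightsFromDistributedReqInput,
)

# Join a weight-update process group rooted at an external trainer.
init_req = InitWeightsUpdateGroupReqInput(
    master_address="127.0.0.1",
    master_port=29500,
    rank_offset=1,
    world_size=2,
    group_name="weight_update_group",
    backend="nccl",
)

# Broadcast one named parameter from that group into the running server.
update_req = UpdateWeightsFromDistributedReqInput(
    name="model.layers.0.self_attn.q_proj.weight",
    dtype="bfloat16",
    shape=[4096, 4096],
)

# Fetch (a truncated view of) a parameter back for inspection.
get_req = GetWeightsByNameReqInput(name="lm_head.weight", truncate_size=8)
```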
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -58,6 +58,7 @@ global_server_args_dict = {
     "torchao_config": ServerArgs.torchao_config,
     "enable_nan_detection": ServerArgs.enable_nan_detection,
     "enable_dp_attention": ServerArgs.enable_dp_attention,
+    "enable_ep_moe": ServerArgs.enable_ep_moe,
 }


@@ -743,20 +744,24 @@ class ScheduleBatch:
         extend_lens = torch.tensor(self.extend_lens, dtype=torch.int32).to(
             self.device, non_blocking=True
         )
-        write_req_to_token_pool_triton[(bs,)](
-            self.req_to_token_pool.req_to_token,
-            self.req_pool_indices,
-            pre_lens,
-            self.seq_lens,
-            extend_lens,
-            self.out_cache_loc,
-            self.req_to_token_pool.req_to_token.shape[1],
-        )
+        if global_server_args_dict["attention_backend"] != "torch_native":
+            write_req_to_token_pool_triton[(bs,)](
+                self.req_to_token_pool.req_to_token,
+                self.req_pool_indices,
+                pre_lens,
+                self.seq_lens,
+                extend_lens,
+                self.out_cache_loc,
+                self.req_to_token_pool.req_to_token.shape[1],
+            )
+        else:
+            pt = 0
+            for i in range(bs):
+                self.req_to_token_pool.write(
+                    (self.req_pool_indices[i], slice(pre_lens[i], self.seq_lens[i])),
+                    self.out_cache_loc[pt : pt + self.extend_lens[i]],
+                )
+                pt += self.extend_lens[i]
         # TODO: some tensors can be reused for ForwardBatchInfo (e.g., extend_lens, cumsum_start)

         if self.model_config.is_encoder_decoder:
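When the torch_native attention backend is selected, the Triton req-to-token kernel is skipped and the mapping is written with plain tensor slicing. A toy sketch of what that fallback write does, using standalone tensors instead of the real pools:

```python
import torch

# Toy req_to_token table: one row per request slot, one column per token position.
req_to_token = torch.zeros(4, 16, dtype=torch.int64)

req_pool_indices = [0, 2]                # rows owned by the two requests in the batch
pre_lens = [3, 0]                        # tokens already cached per request
seq_lens = [7, 5]                        # total tokens after this extend step
extend_lens = [4, 5]                     # seq_len - pre_len
out_cache_loc = torch.arange(100, 109)   # newly allocated KV slots, length sum(extend_lens)

pt = 0
for i in range(len(req_pool_indices)):
    # Same write the fallback performs: fill [pre_len, seq_len) with the new slots.
    req_to_token[req_pool_indices[i], pre_lens[i]:seq_lens[i]] = (
        out_cache_loc[pt:pt + extend_lens[i]]
    )
    pt += extend_lens[i]
```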
sglang/srt/managers/schedule_policy.py
CHANGED
@@ -142,7 +142,7 @@ class PrefillAdder:

         self.req_states = None
         self.can_run_list = []
-        self.new_inflight_req = None
+        self.new_being_chunked_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0

@@ -182,7 +182,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len

-    def add_inflight_req(self, req: Req):
+    def add_being_chunked_req(self, req: Req):
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -269,10 +269,13 @@ class PrefillAdder:
         else:
             # Chunked prefill
             trunc_len = self.rem_chunk_tokens
+            if trunc_len == 0:
+                return AddReqResult.OTHER
+
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[:trunc_len]
             self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_being_chunked_req = req
         self._prefill_one_req(0, trunc_len, 0)

         return self.budget_state()
@@ -326,7 +329,7 @@ class PrefillAdder:
             req.extend_input_len = trunc_len
             req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
             self.can_run_list.append(req)
-            self.new_inflight_req = req
+            self.new_being_chunked_req = req
             self.tree_cache.inc_lock_ref(req.last_node)
             self._prefill_one_req(prefix_len, trunc_len, 0)
