sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py CHANGED
@@ -21,12 +21,11 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import MixtralConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      QKVParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
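
The hunks in this section repeat a single pattern across the model files: tensor-parallel helpers and get_rope are now imported from sglang.srt modules instead of vllm. As a quick orientation, a minimal sketch of the new import style (module paths copied from the hunks; the commented lines are the 0.4.1.post6 equivalents they replace):

    # 0.4.2 import style
    from sglang.srt.distributed import (
        get_tensor_model_parallel_world_size,
        tensor_model_parallel_all_reduce,
    )
    from sglang.srt.layers.rotary_embedding import get_rope

    # 0.4.1.post6 equivalents:
    # from vllm.distributed import (...)
    # from vllm.model_executor.layers.rotary_embedding import get_rope
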
sglang/srt/models/mixtral_quant.py CHANGED
@@ -23,13 +23,12 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import MixtralConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      QKVParallelLinear,
@@ -39,6 +38,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
sglang/srt/models/mllama.py CHANGED
@@ -8,14 +8,14 @@ import torch
  import torch.nn.functional as F
  import torch.utils.checkpoint
  import transformers.models.mllama.configuration_mllama as config_mllama
- import vllm.distributed.parallel_state as ps
  from torch import nn
  from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
  from transformers.models.mllama.modeling_mllama import (
      _prepare_aspect_ratio_attention_mask,
  )
- from vllm.distributed import get_tensor_model_parallel_world_size

+ import sglang.srt.distributed.parallel_state as ps
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
sglang/srt/models/olmo.py CHANGED
@@ -15,14 +15,13 @@
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1
  """Inference-only OLMo model compatible with HuggingFace weights."""
- from typing import Iterable, List, Optional, Tuple
+ from typing import Iterable, Optional, Tuple

  import torch
  from torch import nn
  from transformers import OlmoConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
sglang/srt/models/olmo2.py CHANGED
@@ -21,15 +21,13 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      split_tensor_along_last_dim,
      tensor_model_parallel_all_gather,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -40,11 +38,13 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers


sglang/srt/models/olmoe.py CHANGED
@@ -17,30 +17,24 @@

  """Inference-only OLMoE model compatible with HuggingFace weights."""

- from typing import Any, Dict, Iterable, List, Optional, Tuple
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
- import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
-     get_tensor_model_parallel_world_size,
-     tensor_model_parallel_all_reduce,
- )
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
- from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.layernorm import RMSNorm
- from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
sglang/srt/models/phi3_small.py CHANGED
@@ -5,9 +5,8 @@ import torch
  from torch import nn
  from transformers import Phi3Config
  from transformers.configuration_utils import PretrainedConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -17,6 +16,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      DEFAULT_VOCAB_PADDING_SIZE,
      ParallelLMHead,
sglang/srt/models/qwen.py CHANGED
@@ -20,9 +20,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
sglang/srt/models/qwen2.py CHANGED
@@ -20,9 +20,11 @@ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  from torch import nn
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -34,12 +36,16 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     kv_cache_scales_loader,
+ )
  from sglang.srt.utils import make_layers

  Qwen2Config = None
@@ -242,6 +248,9 @@ class Qwen2Model(nn.Module):
          )
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.embed_tokens(input_ids)
+
      def forward(
          self,
          input_ids: torch.Tensor,
@@ -265,9 +274,31 @@
          hidden_states, _ = self.norm(hidden_states, residual)
          return hidden_states

+     # If this function is called, it should always initialize KV cache scale
+     # factors (or else raise an exception). Thus, handled exceptions should
+     # make sure to leave KV cache scale factors in a known good (dummy) state
+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         tp_size = get_tensor_model_parallel_world_size()
+         tp_rank = get_tensor_model_parallel_rank()
+         for layer_idx, scaling_factor in kv_cache_scales_loader(
+             quantization_param_path,
+             tp_rank,
+             tp_size,
+             self.config.num_hidden_layers,
+             self.config.__class__.model_type,
+         ):
+             if not isinstance(self.layers[layer_idx], nn.Identity):
+                 layer_self_attn = self.layers[layer_idx].self_attn
+             if hasattr(layer_self_attn.attn, "k_scale"):
+                 layer_self_attn.attn.k_scale = scaling_factor
+                 layer_self_attn.attn.v_scale = scaling_factor
+             else:
+                 raise RuntimeError(
+                     "Self attention has no KV cache scaling " "factor attribute!"
+                 )

- class Qwen2ForCausalLM(nn.Module):

+ class Qwen2ForCausalLM(nn.Module):
      # BitandBytes specific attributes
      default_bitsandbytes_target_modules = [
          ".gate_proj.",
@@ -305,6 +336,9 @@ class Qwen2ForCausalLM(nn.Module):
          self.logits_processor = LogitsProcessor(config)
          self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
      @torch.no_grad()
      def forward(
          self,
@@ -373,5 +407,8 @@ class Qwen2ForCausalLM(nn.Module):
          torch.cuda.empty_cache()
          torch.cuda.synchronize()

+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         self.model.load_kv_cache_scales(quantization_param_path)
+

  EntryClass = Qwen2ForCausalLM
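
The load_kv_cache_scales method added above walks the (layer_idx, scaling_factor) pairs that kv_cache_scales_loader yields for the current tensor-parallel rank and copies each factor onto that layer's attention module. A minimal standalone restatement of that loop, with stand-in arguments rather than sglang objects:

    def apply_kv_cache_scales(layers, scales):
        # `layers` is a decoder layer list, `scales` an iterable of
        # (layer_idx, scaling_factor) pairs such as kv_cache_scales_loader
        # produces; mirrors Qwen2Model.load_kv_cache_scales above.
        for layer_idx, scaling_factor in scales:
            attn = layers[layer_idx].self_attn.attn
            if not hasattr(attn, "k_scale"):
                raise RuntimeError("Self attention has no KV cache scaling factor attribute!")
            attn.k_scale = scaling_factor
            attn.v_scale = scaling_factor
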
sglang/srt/models/qwen2_moe.py CHANGED
@@ -22,12 +22,11 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
sglang/srt/models/qwen2_vl.py CHANGED
@@ -22,6 +22,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
+ import logging
  from functools import lru_cache, partial
  from typing import Iterable, List, Optional, Tuple, Type, TypedDict

@@ -30,16 +31,13 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange, repeat
- from vllm.distributed import parallel_state
- from vllm.distributed import utils as dist_utils
- from vllm.logger import init_logger
  from vllm.model_executor.layers.activation import QuickGELU

  from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
+ from sglang.srt.distributed import parallel_state
+ from sglang.srt.distributed import utils as dist_utils
  from sglang.srt.hf_transformers_utils import get_processor
- from sglang.srt.layers.attention.triton_ops.prefill_attention import (
-     context_attention_fwd,
- )
+ from sglang.srt.layers.attention.vision import VisionAttention
  from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -50,7 +48,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2Model

- logger = init_logger(__name__)
+ logger = logging.getLogger(__name__)
+

  # === Vision Inputs === #

@@ -110,118 +109,6 @@ class Qwen2VisionMLP(nn.Module):
          return x


- def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
-     if not interleaved:
-         x1, x2 = x.chunk(2, dim=-1)
-         return torch.cat((-x2, x1), dim=-1)
-     else:
-         x1, x2 = x[..., ::2], x[..., 1::2]
-         return rearrange(
-             torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
-         )
-
-
- def apply_rotary_emb_torch(
-     x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
- ) -> torch.Tensor:
-     """
-     x: (batch_size, seqlen, nheads, headdim)
-     cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
-     """
-     ro_dim = cos.shape[-1] * 2
-     assert ro_dim <= x.shape[-1]
-     cos = repeat(
-         cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-     )
-     sin = repeat(
-         sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-     )
-     return torch.cat(
-         [
-             x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-             x[..., ro_dim:],
-         ],
-         dim=-1,
-     )
-
-
- def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-     t_ = t.float()
-     cos = freqs.cos()
-     sin = freqs.sin()
-     output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
-     return output
-
-
- class Qwen2VisionAttention(nn.Module):
-
-     def __init__(
-         self,
-         embed_dim: Optional[int] = None,
-         num_heads: Optional[int] = None,
-         projection_size: Optional[int] = None,
-         quant_config: Optional[QuantizationConfig] = None,
-     ) -> None:
-         super().__init__()
-         # Per attention head and per partition values.
-         world_size = parallel_state.get_tensor_model_parallel_world_size()
-         self.hidden_size_per_attention_head = dist_utils.divide(
-             projection_size, num_heads
-         )
-         self.num_attention_heads_per_partition = dist_utils.divide(
-             num_heads, world_size
-         )
-
-         self.qkv = ColumnParallelLinear(
-             input_size=embed_dim,
-             output_size=3 * projection_size,
-             quant_config=quant_config,
-         )
-         self.proj = RowParallelLinear(
-             input_size=projection_size, output_size=embed_dim, quant_config=quant_config
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         cu_seqlens: torch.Tensor,
-         rotary_pos_emb: torch.Tensor = None,
-     ) -> torch.Tensor:
-         # [s, b, c] --> [s, b, head * 3 * head_dim]
-         x, _ = self.qkv(x)
-
-         # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-         new_x_shape = x.size()[:-1] + (
-             self.num_attention_heads_per_partition,
-             3 * self.hidden_size_per_attention_head,
-         )
-         x = x.view(*new_x_shape)
-
-         # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-         q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
-         batch_size = q.shape[1]
-
-         q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
-         if rotary_pos_emb is not None:
-             q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-             k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
-
-         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
-         max_seqlen = (seq_lens).max().item()
-         q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-         output = torch.empty_like(q)
-         context_attention_fwd(
-             q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
-         )
-
-         context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
-         context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
-
-         output, _ = self.proj(context_layer)
-         return output
-
-
  class Qwen2VisionBlock(nn.Module):

      def __init__(
@@ -240,10 +127,11 @@ class Qwen2VisionBlock(nn.Module):
          self.norm2 = norm_layer(dim)
          mlp_hidden_dim = int(dim * mlp_ratio)

-         self.attn = Qwen2VisionAttention(
+         self.attn = VisionAttention(
              embed_dim=dim,
              num_heads=num_heads,
              projection_size=dim,
+             use_qkv_parallel=False,
              quant_config=quant_config,
          )
          self.mlp = Qwen2VisionMLP(
@@ -253,9 +141,13 @@
      def forward(
          self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
      ) -> torch.Tensor:
-         x = x + self.attn(
-             self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+         hidden_states = self.norm1(x)
+         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+         attn = self.attn(
+             hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
          )
+         attn = rearrange(attn, "b s ... -> s b ...")
+         x = x + attn
          x = x + self.mlp(self.norm2(x))
          return x

@@ -684,10 +576,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
+
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
+
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
@@ -696,6 +590,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
+
                  if "visual" in name and "qkv.weight" in name:
                      visual_num_heads = self.config.vision_config.num_heads
                      visual_embed_dim = self.config.vision_config.embed_dim
@@ -712,6 +607,11 @@
                      loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
                      loaded_weight = loaded_weight.transpose(0, 1)
                      loaded_weight = loaded_weight.reshape(-1)
+
+                 if "visual" in name:
+                     # adapt to VisionAttention
+                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
                  try:
                      # Skip loading extra bias for GPTQ models.
                      if name.endswith(".bias") and name not in params_dict:
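
The rewritten Qwen2VisionBlock.forward above keeps the vision tower's sequence-first [s, b, ...] tensors but hands a batch-first [b, s, ...] view to the shared VisionAttention layer and converts the result back. The two rearranges are exact inverses, which a short self-contained check makes concrete (the shapes are arbitrary example values):

    import torch
    from einops import rearrange

    x = torch.randn(16, 2, 1280)                        # [seq, batch, hidden]
    batch_first = rearrange(x, "s b ... -> b s ...")    # layout passed to self.attn
    assert batch_first.shape == (2, 16, 1280)
    restored = rearrange(batch_first, "b s ... -> s b ...")
    assert torch.equal(restored, x)                     # round trip is lossless
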
sglang/srt/models/stablelm.py CHANGED
@@ -24,9 +24,8 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -36,6 +35,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
sglang/srt/models/torch_native_llama.py CHANGED
@@ -47,17 +47,17 @@ import torch
  from torch import nn
  from torch.nn.parameter import Parameter
  from transformers import LlamaConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
          params_dict = dict(self.named_parameters())
          return len(params_dict)

-     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+     def load_weights_to_module(
+         self,
+         fqn: str,
+         weights: Iterable[Tuple[str, torch.Tensor]],
+     ):
+         """Load weights onto submodule pointed by path `fqn`."""
          stacked_params_mapping = [
              # (param_name, shard_name, shard_id)
              (".qkv_proj", ".q_proj", "q"),
@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
              (".gate_up_proj", ".gate_proj", 0),
              (".gate_up_proj", ".up_proj", 1),
          ]
-         params_dict = dict(self.named_parameters())
+         module = self.get_submodule(fqn)
+         params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))

          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                      continue
                  name = name.replace(weight_name, param_name)
                  # Skip loading extra bias for GPTQ models.
-                 if name.endswith(".bias") and name not in params_dict:
+                 if name.endswith(".bias") or name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = param.weight_loader
@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                  break
              else:
                  # Skip loading extra bias for GPTQ models.
-                 if name.endswith(".bias") and name not in params_dict:
+                 if name.endswith(".bias") or name not in params_dict:
                      continue
                  param = params_dict[name]
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

+     def load_weights(
+         self,
+         weights: Iterable[Tuple[str, torch.Tensor]],
+     ):
+         """Load weights onto the full model."""
+         self.load_weights_to_module("", weights)
+

  class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
      pass
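
With the change above, load_weights is a thin wrapper over load_weights_to_module, which resolves a dotted submodule path via get_submodule and matches incoming weight names against named_parameters(prefix=fqn, recurse=False). A toy illustration of that name-resolution mechanism (the Toy module is invented for the example; only the two torch calls are the point):

    from torch import nn

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            self.model = nn.Module()
            self.model.embed_tokens = nn.Embedding(8, 4)

    toy = Toy()
    sub = toy.get_submodule("model.embed_tokens")
    params = dict(sub.named_parameters(prefix="model.embed_tokens", recurse=False))
    # Incoming names such as "model.embed_tokens.weight" now match directly.
    assert "model.embed_tokens.weight" in params
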
sglang/srt/models/xverse.py CHANGED
@@ -21,19 +21,19 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.activation import SiluAndMul
- from vllm.model_executor.layers.layernorm import RMSNorm
- from vllm.model_executor.layers.linear import (
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,