sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff covers publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
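
Most of the per-model changes below follow a single mechanical migration: tensor-parallel helpers and `get_rope` are now imported from sglang's own `sglang.srt` packages instead of from `vllm`. A minimal before/after sketch of that pattern, shown on its own rather than tied to any one file:

    # Before (0.4.1.post5), model files pulled these symbols from vllm:
    # from vllm.distributed import (
    #     get_tensor_model_parallel_world_size,
    #     tensor_model_parallel_all_reduce,
    # )
    # from vllm.model_executor.layers.rotary_embedding import get_rope

    # After (0.4.1.post7), the same symbols come from sglang's own modules:
    from sglang.srt.distributed import (
        get_tensor_model_parallel_world_size,
        tensor_model_parallel_all_reduce,
    )
    from sglang.srt.layers.rotary_embedding import get_rope
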
sglang/srt/models/mixtral.py CHANGED
@@ -21,12 +21,11 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import MixtralConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      QKVParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,

sglang/srt/models/mixtral_quant.py CHANGED
@@ -23,13 +23,12 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import MixtralConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      QKVParallelLinear,
@@ -39,6 +38,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,

sglang/srt/models/mllama.py CHANGED
@@ -8,14 +8,14 @@ import torch
  import torch.nn.functional as F
  import torch.utils.checkpoint
  import transformers.models.mllama.configuration_mllama as config_mllama
- import vllm.distributed.parallel_state as ps
  from torch import nn
  from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
  from transformers.models.mllama.modeling_mllama import (
      _prepare_aspect_ratio_attention_mask,
  )
- from vllm.distributed import get_tensor_model_parallel_world_size

+ import sglang.srt.distributed.parallel_state as ps
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import get_act_fn
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (

sglang/srt/models/olmo.py CHANGED
@@ -15,14 +15,13 @@
  # Adapted from
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1
  """Inference-only OLMo model compatible with HuggingFace weights."""
- from typing import Iterable, List, Optional, Tuple
+ from typing import Iterable, Optional, Tuple

  import torch
  from torch import nn
  from transformers import OlmoConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,

sglang/srt/models/olmo2.py CHANGED
@@ -21,15 +21,13 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_rank,
      get_tensor_model_parallel_world_size,
      split_tensor_along_last_dim,
      tensor_model_parallel_all_gather,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -40,11 +38,13 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.utils import make_layers


sglang/srt/models/olmoe.py CHANGED
@@ -17,30 +17,24 @@

  """Inference-only OLMoE model compatible with HuggingFace weights."""

- from typing import Any, Dict, Iterable, List, Optional, Tuple
+ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
- import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
-     get_tensor_model_parallel_world_size,
-     tensor_model_parallel_all_reduce,
- )
- from vllm.model_executor.layers.linear import (
-     MergedColumnParallelLinear,
+
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
      QKVParallelLinear,
      ReplicatedLinear,
      RowParallelLinear,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
- from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.layernorm import RMSNorm
- from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,

sglang/srt/models/phi3_small.py CHANGED
@@ -5,9 +5,8 @@ import torch
  from torch import nn
  from transformers import Phi3Config
  from transformers.configuration_utils import PretrainedConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
      QKVParallelLinear,
@@ -17,6 +16,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      DEFAULT_VOCAB_PADDING_SIZE,
      ParallelLMHead,

sglang/srt/models/qwen.py CHANGED
@@ -20,9 +20,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,

sglang/srt/models/qwen2.py CHANGED
@@ -20,9 +20,11 @@ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  from torch import nn
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+ )
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -34,12 +36,16 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.model_loader.weight_utils import (
+     default_weight_loader,
+     kv_cache_scales_loader,
+ )
  from sglang.srt.utils import make_layers

  Qwen2Config = None
@@ -242,6 +248,9 @@ class Qwen2Model(nn.Module):
          )
          self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.embed_tokens(input_ids)
+
      def forward(
          self,
          input_ids: torch.Tensor,
@@ -265,9 +274,31 @@ class Qwen2Model(nn.Module):
          hidden_states, _ = self.norm(hidden_states, residual)
          return hidden_states

+     # If this function is called, it should always initialize KV cache scale
+     # factors (or else raise an exception). Thus, handled exceptions should
+     # make sure to leave KV cache scale factors in a known good (dummy) state
+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         tp_size = get_tensor_model_parallel_world_size()
+         tp_rank = get_tensor_model_parallel_rank()
+         for layer_idx, scaling_factor in kv_cache_scales_loader(
+             quantization_param_path,
+             tp_rank,
+             tp_size,
+             self.config.num_hidden_layers,
+             self.config.__class__.model_type,
+         ):
+             if not isinstance(self.layers[layer_idx], nn.Identity):
+                 layer_self_attn = self.layers[layer_idx].self_attn
+             if hasattr(layer_self_attn.attn, "k_scale"):
+                 layer_self_attn.attn.k_scale = scaling_factor
+                 layer_self_attn.attn.v_scale = scaling_factor
+             else:
+                 raise RuntimeError(
+                     "Self attention has no KV cache scaling " "factor attribute!"
+                 )

- class Qwen2ForCausalLM(nn.Module):

+ class Qwen2ForCausalLM(nn.Module):
      # BitandBytes specific attributes
      default_bitsandbytes_target_modules = [
          ".gate_proj.",
@@ -305,6 +336,9 @@ class Qwen2ForCausalLM(nn.Module):
          self.logits_processor = LogitsProcessor(config)
          self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
      @torch.no_grad()
      def forward(
          self,
@@ -362,5 +396,19 @@ class Qwen2ForCausalLM(nn.Module):
                  weight_loader = getattr(param, "weight_loader", default_weight_loader)
                  weight_loader(param, loaded_weight)

+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         self.model.load_kv_cache_scales(quantization_param_path)
+

  EntryClass = Qwen2ForCausalLM
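
`Qwen2ForCausalLM` gains two kinds of hooks in the hunks above: `get_embed_and_head` / `set_embed_and_head` expose and swap the embedding and LM-head tensors (one plausible use is letting a speculative-decoding draft model reuse the target model's weights), and `load_kv_cache_scales` feeds per-layer KV-cache scaling factors from `kv_cache_scales_loader` into the attention layers. A sketch of how the weight-sharing accessors could be combined; the helper function itself is hypothetical and not part of the package:

    from sglang.srt.models.qwen2 import Qwen2ForCausalLM
    from sglang.srt.models.qwen2_eagle import Qwen2ForCausalLMEagle


    def tie_draft_to_target(
        target: Qwen2ForCausalLM, draft: Qwen2ForCausalLMEagle
    ) -> None:
        # Hypothetical helper: reuse the target model's embedding and LM-head
        # tensors in the draft model via the accessors added in this release,
        # instead of keeping separate copies of those weights.
        embed, head = target.get_embed_and_head()
        draft.set_embed_and_head(embed, head)
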
sglang/srt/models/qwen2_eagle.py ADDED
@@ -0,0 +1,131 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+     http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from
+ # https://github.com/SafeAILab/EAGLE/blob/main/eagle/model/cnets.py
+ """Inference-only LLaMA-EAGLE model compatible with HuggingFace weights."""
+
+ from typing import Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import (
+     ParallelLMHead,
+     VocabParallelEmbedding,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2ForCausalLM
+
+ Qwen2Config = None
+
+
+ class Qwen2DecoderLayer(Qwen2DecoderLayer):
+     def __init__(
+         self,
+         config: Qwen2Config,
+         layer_id: int = 0,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(config, layer_id, quant_config)
+
+         # Skip the input_layernorm
+         # https://github.com/SafeAILab/EAGLE/blob/35c78f6cdc19a73e05cf5c330b4c358dad970c6a/eagle/model/cnets.py#L427
+         if layer_id == 0:
+             del self.input_layernorm
+             setattr(self, "input_layernorm", lambda x: x)
+
+
+ class Qwen2Model(nn.Module):
+     def __init__(
+         self,
+         config: Qwen2Config,
+         quant_config: Optional[QuantizationConfig] = None,
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.vocab_size = config.vocab_size
+         self.embed_tokens = VocabParallelEmbedding(
+             config.vocab_size,
+             config.hidden_size,
+         )
+         self.layers = nn.ModuleList(
+             [
+                 Qwen2DecoderLayer(
+                     config, i, quant_config=quant_config, prefix=f"model.layers.{i}"
+                 )
+                 for i in range(config.num_hidden_layers)
+             ]
+         )
+         self.fc = torch.nn.Linear(config.hidden_size * 2, config.hidden_size)
+
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+     ) -> torch.Tensor:
+         if input_embeds is None:
+             hidden_states = self.embed_tokens(input_ids)
+         else:
+             hidden_states = input_embeds
+
+         hidden_states = self.fc(
+             torch.cat((hidden_states, forward_batch.spec_info.hidden_states), dim=-1)
+         )
+
+         residual = None
+         for i in range(len(self.layers)):
+             layer = self.layers[i]
+             hidden_states, residual = layer(
+                 positions,
+                 hidden_states,
+                 forward_batch,
+                 residual,
+             )
+         return hidden_states + residual
+
+
+ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
+     def __init__(
+         self,
+         config: Qwen2Config,
+         quant_config: Optional[QuantizationConfig] = None,
+         cache_config=None,
+     ) -> None:
+         nn.Module.__init__(self)
+         self.config = config
+         self.quant_config = quant_config
+         self.model = Qwen2Model(config, quant_config=quant_config)
+         if self.config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size, config.hidden_size, quant_config=quant_config
+             )
+         self.logits_processor = LogitsProcessor(config)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         for name, loaded_weight in weights:
+             if "lm_head" not in name:
+                 name = "model." + name
+             super().load_weights([(name, loaded_weight)])
+
+
+ EntryClass = [Qwen2ForCausalLMEagle]
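
The draft model's distinguishing step is the `fc` projection in `Qwen2Model.forward` above: token embeddings are concatenated with the target model's hidden states (carried in `forward_batch.spec_info.hidden_states`) and projected from `2 * hidden_size` back to `hidden_size` before the decoder layers run. A standalone shape check of that step, with illustrative sizes:

    import torch

    # Illustrative sizes only: 4 draft tokens, hidden size 8.
    hidden_size = 8
    embeds = torch.randn(4, hidden_size)         # draft token embeddings
    target_hidden = torch.randn(4, hidden_size)  # hidden states from the target model

    fc = torch.nn.Linear(hidden_size * 2, hidden_size)
    fused = fc(torch.cat((embeds, target_hidden), dim=-1))
    print(fused.shape)  # torch.Size([4, 8])
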
sglang/srt/models/qwen2_moe.py CHANGED
@@ -22,12 +22,11 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
      get_tensor_model_parallel_world_size,
      tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
      VocabParallelEmbedding,

sglang/srt/models/qwen2_vl.py CHANGED
@@ -22,6 +22,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
+ import logging
  from functools import lru_cache, partial
  from typing import Iterable, List, Optional, Tuple, Type, TypedDict

@@ -30,16 +31,13 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange, repeat
- from vllm.distributed import parallel_state
- from vllm.distributed import utils as dist_utils
- from vllm.logger import init_logger
  from vllm.model_executor.layers.activation import QuickGELU

  from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
+ from sglang.srt.distributed import parallel_state
+ from sglang.srt.distributed import utils as dist_utils
  from sglang.srt.hf_transformers_utils import get_processor
- from sglang.srt.layers.attention.triton_ops.prefill_attention import (
-     context_attention_fwd,
- )
+ from sglang.srt.layers.attention.vision import VisionAttention
  from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.pooler import Pooler, PoolingType
@@ -50,7 +48,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2Model

- logger = init_logger(__name__)
+ logger = logging.getLogger(__name__)
+

  # === Vision Inputs === #

@@ -110,118 +109,6 @@ class Qwen2VisionMLP(nn.Module):
          return x


- def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
-     if not interleaved:
-         x1, x2 = x.chunk(2, dim=-1)
-         return torch.cat((-x2, x1), dim=-1)
-     else:
-         x1, x2 = x[..., ::2], x[..., 1::2]
-         return rearrange(
-             torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
-         )
-
-
- def apply_rotary_emb_torch(
-     x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
- ) -> torch.Tensor:
-     """
-     x: (batch_size, seqlen, nheads, headdim)
-     cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
-     """
-     ro_dim = cos.shape[-1] * 2
-     assert ro_dim <= x.shape[-1]
-     cos = repeat(
-         cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-     )
-     sin = repeat(
-         sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
-     )
-     return torch.cat(
-         [
-             x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
-             x[..., ro_dim:],
-         ],
-         dim=-1,
-     )
-
-
- def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
-     t_ = t.float()
-     cos = freqs.cos()
-     sin = freqs.sin()
-     output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
-     return output
-
-
- class Qwen2VisionAttention(nn.Module):
-
-     def __init__(
-         self,
-         embed_dim: Optional[int] = None,
-         num_heads: Optional[int] = None,
-         projection_size: Optional[int] = None,
-         quant_config: Optional[QuantizationConfig] = None,
-     ) -> None:
-         super().__init__()
-         # Per attention head and per partition values.
-         world_size = parallel_state.get_tensor_model_parallel_world_size()
-         self.hidden_size_per_attention_head = dist_utils.divide(
-             projection_size, num_heads
-         )
-         self.num_attention_heads_per_partition = dist_utils.divide(
-             num_heads, world_size
-         )
-
-         self.qkv = ColumnParallelLinear(
-             input_size=embed_dim,
-             output_size=3 * projection_size,
-             quant_config=quant_config,
-         )
-         self.proj = RowParallelLinear(
-             input_size=projection_size, output_size=embed_dim, quant_config=quant_config
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         cu_seqlens: torch.Tensor,
-         rotary_pos_emb: torch.Tensor = None,
-     ) -> torch.Tensor:
-         # [s, b, c] --> [s, b, head * 3 * head_dim]
-         x, _ = self.qkv(x)
-
-         # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
-         new_x_shape = x.size()[:-1] + (
-             self.num_attention_heads_per_partition,
-             3 * self.hidden_size_per_attention_head,
-         )
-         x = x.view(*new_x_shape)
-
-         # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
-         q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
-         batch_size = q.shape[1]
-
-         q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
-         if rotary_pos_emb is not None:
-             q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
-             k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
-
-         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
-         max_seqlen = (seq_lens).max().item()
-         q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-         output = torch.empty_like(q)
-         context_attention_fwd(
-             q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
-         )
-
-         context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
-         context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
-
-         output, _ = self.proj(context_layer)
-         return output
-
-
  class Qwen2VisionBlock(nn.Module):

      def __init__(
@@ -240,10 +127,11 @@ class Qwen2VisionBlock(nn.Module):
          self.norm2 = norm_layer(dim)
          mlp_hidden_dim = int(dim * mlp_ratio)

-         self.attn = Qwen2VisionAttention(
+         self.attn = VisionAttention(
              embed_dim=dim,
              num_heads=num_heads,
              projection_size=dim,
+             use_qkv_parallel=False,
              quant_config=quant_config,
          )
          self.mlp = Qwen2VisionMLP(
@@ -253,9 +141,13 @@ class Qwen2VisionBlock(nn.Module):
      def forward(
          self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
      ) -> torch.Tensor:
-         x = x + self.attn(
-             self.norm1(x), cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
+         hidden_states = self.norm1(x)
+         hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
+         attn = self.attn(
+             hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
          )
+         attn = rearrange(attn, "b s ... -> s b ...")
+         x = x + attn
          x = x + self.mlp(self.norm2(x))
          return x

@@ -684,10 +576,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
+
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
                  name = name.replace(weight_name, param_name)
+
                  # Skip loading extra bias for GPTQ models.
                  if name.endswith(".bias") and name not in params_dict:
                      continue
@@ -696,6 +590,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                  weight_loader(param, loaded_weight, shard_id)
                  break
              else:
+
                  if "visual" in name and "qkv.weight" in name:
                      visual_num_heads = self.config.vision_config.num_heads
                      visual_embed_dim = self.config.vision_config.embed_dim
@@ -712,6 +607,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                      loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
                      loaded_weight = loaded_weight.transpose(0, 1)
                      loaded_weight = loaded_weight.reshape(-1)
+
+                 if "visual" in name:
+                     # adapt to VisionAttention
+                     name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+
                  try:
                      # Skip loading extra bias for GPTQ models.
                      if name.endswith(".bias") and name not in params_dict:
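
Two details of the Qwen2-VL port above are easy to miss: `Qwen2VisionBlock.forward` now rearranges activations from the vision tower's [seq, batch, dim] layout into the [batch, seq, dim] layout that the shared `VisionAttention` module consumes (and back again before the residual add), and checkpoint weights named `attn.qkv.*` are renamed to `attn.qkv_proj.*` so they load into `VisionAttention`'s parameters. A toy round-trip of the layout change, with illustrative sizes:

    import torch
    from einops import rearrange

    # Illustrative sizes only: sequence length 16, batch 2, embed dim 32.
    x = torch.randn(16, 2, 32)              # vision tower activations, [s, b, d]
    h = rearrange(x, "s b ... -> b s ...")  # VisionAttention consumes [b, s, d]
    # ... attention would run on h here ...
    h = rearrange(h, "b s ... -> s b ...")  # back to [s, b, d] for the residual add
    out = x + h
    print(out.shape)  # torch.Size([16, 2, 32])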