sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77)
  1. sglang/bench_latency.py +31 -13
  2. sglang/bench_server_latency.py +21 -10
  3. sglang/bench_serving.py +101 -7
  4. sglang/global_config.py +0 -1
  5. sglang/srt/conversation.py +11 -2
  6. sglang/srt/layers/attention/__init__.py +27 -5
  7. sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +352 -83
  9. sglang/srt/layers/attention/triton_backend.py +6 -4
  10. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
  11. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
  12. sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
  13. sglang/srt/layers/sampler.py +6 -2
  14. sglang/srt/managers/data_parallel_controller.py +177 -0
  15. sglang/srt/managers/detokenizer_manager.py +31 -10
  16. sglang/srt/managers/io_struct.py +11 -2
  17. sglang/srt/managers/schedule_batch.py +126 -43
  18. sglang/srt/managers/schedule_policy.py +2 -1
  19. sglang/srt/managers/scheduler.py +245 -142
  20. sglang/srt/managers/tokenizer_manager.py +14 -1
  21. sglang/srt/managers/tp_worker.py +111 -1
  22. sglang/srt/mem_cache/chunk_cache.py +8 -4
  23. sglang/srt/mem_cache/memory_pool.py +77 -4
  24. sglang/srt/mem_cache/radix_cache.py +15 -7
  25. sglang/srt/model_executor/cuda_graph_runner.py +4 -4
  26. sglang/srt/model_executor/forward_batch_info.py +16 -21
  27. sglang/srt/model_executor/model_runner.py +100 -36
  28. sglang/srt/models/baichuan.py +2 -3
  29. sglang/srt/models/chatglm.py +5 -6
  30. sglang/srt/models/commandr.py +1 -2
  31. sglang/srt/models/dbrx.py +1 -2
  32. sglang/srt/models/deepseek.py +4 -5
  33. sglang/srt/models/deepseek_v2.py +5 -6
  34. sglang/srt/models/exaone.py +1 -2
  35. sglang/srt/models/gemma.py +2 -2
  36. sglang/srt/models/gemma2.py +5 -5
  37. sglang/srt/models/gpt_bigcode.py +5 -5
  38. sglang/srt/models/grok.py +1 -2
  39. sglang/srt/models/internlm2.py +1 -2
  40. sglang/srt/models/llama.py +1 -2
  41. sglang/srt/models/llama_classification.py +1 -2
  42. sglang/srt/models/llama_reward.py +2 -3
  43. sglang/srt/models/llava.py +4 -8
  44. sglang/srt/models/llavavid.py +1 -2
  45. sglang/srt/models/minicpm.py +1 -2
  46. sglang/srt/models/minicpm3.py +5 -6
  47. sglang/srt/models/mixtral.py +1 -2
  48. sglang/srt/models/mixtral_quant.py +1 -2
  49. sglang/srt/models/olmo.py +352 -0
  50. sglang/srt/models/olmoe.py +1 -2
  51. sglang/srt/models/qwen.py +1 -2
  52. sglang/srt/models/qwen2.py +1 -2
  53. sglang/srt/models/qwen2_moe.py +4 -5
  54. sglang/srt/models/stablelm.py +1 -2
  55. sglang/srt/models/torch_native_llama.py +1 -2
  56. sglang/srt/models/xverse.py +1 -2
  57. sglang/srt/models/xverse_moe.py +4 -5
  58. sglang/srt/models/yivl.py +1 -2
  59. sglang/srt/openai_api/adapter.py +97 -52
  60. sglang/srt/openai_api/protocol.py +10 -2
  61. sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
  62. sglang/srt/sampling/sampling_batch_info.py +105 -59
  63. sglang/srt/sampling/sampling_params.py +2 -0
  64. sglang/srt/server.py +171 -37
  65. sglang/srt/server_args.py +127 -48
  66. sglang/srt/utils.py +37 -14
  67. sglang/test/few_shot_gsm8k.py +4 -1
  68. sglang/test/few_shot_gsm8k_engine.py +144 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +16 -12
  70. sglang/version.py +1 -1
  71. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
  72. sglang-0.3.4.dist-info/RECORD +143 -0
  73. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
  74. sglang/srt/layers/attention/flashinfer_utils.py +0 -237
  75. sglang-0.3.3.dist-info/RECORD +0 -139
  76. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
  77. {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py CHANGED
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -295,7 +294,7 @@ class LlamaForCausalLM(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
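The same two-part change repeats across most of the model files below: the vllm.config.CacheConfig import is dropped and the constructor's cache_config parameter becomes an untyped, unused placeholder, so the model code no longer depends on vllm's config types at import time. A minimal sketch of the resulting constructor contract, using made-up Toy* names rather than anything from sglang:

from typing import Optional


class ToyQuantizationConfig:
    """Illustrative stand-in for sglang's QuantizationConfig."""


class ToyForCausalLM:
    def __init__(
        self,
        config: dict,
        quant_config: Optional[ToyQuantizationConfig] = None,
        cache_config=None,  # accepted for loader compatibility; no vllm type required
    ) -> None:
        self.config = config
        self.quant_config = quant_config
        # cache_config is intentionally ignored by the model itself.


# Callers may omit cache_config entirely or pass whatever object their
# runtime provides; nothing here imports vllm.config.CacheConfig.
model = ToyForCausalLM(config={"hidden_size": 4096})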
sglang/srt/models/llama_classification.py CHANGED
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.config import CacheConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -32,7 +31,7 @@ class LlamaForClassification(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/llama_reward.py CHANGED
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import LlamaConfig
- from vllm.config import CacheConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -33,7 +32,7 @@ class LlamaForSequenceClassification(nn.Module):
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
@@ -92,7 +91,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
  self,
  config: LlamaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__(config, quant_config, cache_config)
  self.weights = self.Weights(config.hidden_size, self.num_labels)
sglang/srt/models/llava.py CHANGED
@@ -31,7 +31,6 @@ from transformers import (
  SiglipVisionModel,
  )
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
- from vllm.config import CacheConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -161,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
  image_sizes = [
  image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
  ]
- image_offsets = [
- image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
- ]

  ########## Encode Image ########

@@ -359,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
  prefix_len = prefix_lens_cpu[i]

  # Multiple images
- for j, image_offset in enumerate(image_offsets[i]):
+ for j, image_offset in enumerate(image_inputs[i].image_offsets):
  if image_offset < prefix_len:
  continue

@@ -450,7 +446,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
  self,
  config: LlavaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()

@@ -472,7 +468,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
  self,
  config: LlavaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()

@@ -505,7 +501,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
  self,
  config: LlavaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()

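Besides the cache_config cleanup, llava.py stops building an intermediate image_offsets list and instead reads offsets directly from each sample's image_inputs entry. A small self-contained sketch of the two access patterns, with a made-up FakeImageInputs record standing in for the real image-input structure:

from dataclasses import dataclass, field
from typing import List


@dataclass
class FakeImageInputs:
    """Illustrative stand-in for the real per-sample image-input record."""
    image_offsets: List[int] = field(default_factory=list)


image_inputs = [FakeImageInputs([3, 17]), FakeImageInputs(), FakeImageInputs([5])]
need_vision = [True, False, True]
bs = len(image_inputs)

# Pre-0.3.4 style: collect offsets for the vision-enabled samples into a
# separate, filtered list.
image_offsets = [image_inputs[i].image_offsets for i in range(bs) if need_vision[i]]

# 0.3.4 style: read each sample's offsets straight from its own record,
# so sample i always refers to its own data.
for i in range(bs):
    if not need_vision[i]:
        continue
    for j, image_offset in enumerate(image_inputs[i].image_offsets):
        print(i, j, image_offset)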
sglang/srt/models/llavavid.py CHANGED
@@ -22,7 +22,6 @@ import torch
  from torch import nn
  from transformers import CLIPVisionModel, LlavaConfig
  from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
- from vllm.config import CacheConfig
  from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -36,7 +35,7 @@ class LlavaVidForCausalLM(nn.Module):
  self,
  config: LlavaConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/minicpm.py CHANGED
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  from torch import nn
- from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -278,7 +277,7 @@ class MiniCPMForCausalLM(nn.Module):
  self,
  config,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/minicpm3.py CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.linear import (
  ColumnParallelLinear,
@@ -108,7 +107,7 @@ class MiniCPM3Attention(nn.Module):
  rope_theta: float = 10000,
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  layer_id=None,
  ) -> None:
@@ -252,7 +251,7 @@ class MiniCPM3AttentionMLA(nn.Module):
  rope_theta: float = 10000,
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  layer_id=None,
  ) -> None:
@@ -409,7 +408,7 @@ class MiniCPM3DecoderLayer(nn.Module):
  self,
  config: PretrainedConfig,
  layer_id: int,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -501,7 +500,7 @@ class MiniCPM3Model(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -552,7 +551,7 @@ class MiniCPM3ForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
sglang/srt/models/mixtral.py CHANGED
@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import MixtralConfig
- from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.fused_moe import FusedMoE
  from vllm.model_executor.layers.rotary_embedding import get_rope
@@ -293,7 +292,7 @@ class MixtralForCausalLM(nn.Module):
  self,
  config: MixtralConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/mixtral_quant.py CHANGED
@@ -23,7 +23,6 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import MixtralConfig
- from vllm.config import CacheConfig
  from vllm.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
@@ -325,7 +324,7 @@ class QuantMixtralForCausalLM(nn.Module):
  self,
  config: MixtralConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/olmo.py ADDED
@@ -0,0 +1,352 @@
+ """
+ Copyright 2023-2024 SGLang Team
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+ """
+
+ # Adapted from
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1
+ """Inference-only OLMo model compatible with HuggingFace weights."""
+ from typing import Iterable, List, Optional, Tuple
+
+ import torch
+ from torch import nn
+ from transformers import OlmoConfig
+ from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.model_executor.layers.rotary_embedding import get_rope
+ from vllm.model_executor.layers.vocab_parallel_embedding import (
+ ParallelLMHead,
+ VocabParallelEmbedding,
+ )
+ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.linear import (
+ MergedColumnParallelLinear,
+ QKVParallelLinear,
+ RowParallelLinear,
+ )
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+
+ class OlmoAttention(nn.Module):
+ """
+ This is the attention block where the output is computed as
+ ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+ (plus another skip connection).
+ """
+
+ def __init__(
+ self,
+ config: OlmoConfig,
+ layer_id: int = 0,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
+ self.total_num_heads = config.num_attention_heads
+
+ assert self.hidden_size % self.total_num_heads == 0
+ assert self.total_num_heads % tensor_model_parallel_world_size == 0
+
+ self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
+ self.head_dim = self.hidden_size // self.total_num_heads
+ self.max_position_embeddings = config.max_position_embeddings
+ self.rope_theta = config.rope_theta
+ self.clip_qkv = config.clip_qkv
+
+ # Attention input projection. Projects x -> (q, k, v)
+ self.qkv_proj = QKVParallelLinear(
+ self.hidden_size,
+ self.head_dim,
+ self.total_num_heads,
+ bias=config.attention_bias,
+ )
+
+ # Rotary embeddings.
+ self.rotary_emb = get_rope(
+ self.head_dim,
+ rotary_dim=self.head_dim,
+ max_position=self.max_position_embeddings,
+ base=self.rope_theta,
+ )
+ self.scaling = self.head_dim**-0.5
+ self.attn = RadixAttention(
+ self.num_heads,
+ self.head_dim,
+ self.scaling,
+ num_kv_heads=self.num_heads,
+ layer_id=layer_id,
+ )
+
+ # Attention output projection.
+ self.o_proj = RowParallelLinear(
+ self.hidden_size,
+ self.hidden_size,
+ bias=config.attention_bias,
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> torch.Tensor:
+ qkv, _ = self.qkv_proj(hidden_states)
+ if self.clip_qkv is not None:
+ qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
+ q, k, v = qkv.chunk(chunks=3, dim=-1)
+ q, k = self.rotary_emb(positions, q, k)
+ attn_output = self.attn(q, k, v, forward_batch)
+ output, _ = self.o_proj(attn_output)
+ return output
+
+
+ class OlmoMLP(nn.Module):
+ """
+ This is the MLP block where the output is computed as
+ ``MLP(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
+ (plus another skip connection).
+ """
+
+ def __init__(
+ self,
+ config: OlmoConfig,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.hidden_size = config.hidden_size
+ self.intermediate_size = config.intermediate_size
+
+ # Feed-forward input projection.
+ self.gate_up_proj = MergedColumnParallelLinear(
+ self.hidden_size,
+ [self.intermediate_size] * 2,
+ bias=False,
+ quant_config=quant_config,
+ )
+
+ # Activation function.
+ self.act_fn = SiluAndMul()
+
+ # Feed-forward output projection.
+ self.down_proj = RowParallelLinear(
+ self.intermediate_size,
+ self.hidden_size,
+ bias=False,
+ quant_config=quant_config,
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ ) -> torch.Tensor:
+ gate_up, _ = self.gate_up_proj(x)
+ x = self.act_fn(gate_up)
+ x, _ = self.down_proj(x)
+ return x
+
+
+ class OlmoDecoderLayer(nn.Module):
+ """
+ This is a typical transformer block where the output is
+ computed as ``MLP(LN(x + Attention(LN(x))))``
+ (plus another skip connection).
+ """
+
+ def __init__(
+ self,
+ config: OlmoConfig,
+ layer_id: int = 0,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ # Attention block.
+ self.self_attn = OlmoAttention(config, layer_id, quant_config)
+
+ # MLP block.
+ self.mlp = OlmoMLP(config, quant_config)
+
+ # LayerNorm
+ self.input_layernorm = nn.LayerNorm(
+ config.hidden_size, elementwise_affine=False, bias=False
+ )
+ self.post_attention_layernorm = nn.LayerNorm(
+ config.hidden_size, elementwise_affine=False, bias=False
+ )
+
+ def forward(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
+ # Attention block.
+ residual = hidden_states
+ hidden_states = self.input_layernorm(hidden_states)
+ hidden_states = self.self_attn(positions, hidden_states, forward_batch)
+ hidden_states = hidden_states + residual
+
+ # MLP block.
+ residual = hidden_states
+ hidden_states = self.post_attention_layernorm(hidden_states)
+ hidden_states = self.mlp(hidden_states)
+ hidden_states = residual + hidden_states
+ return hidden_states
+
+
+ class OlmoModel(nn.Module):
+
+ def __init__(
+ self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None
+ ):
+ super().__init__()
+ self.config = config
+
+ self.embed_tokens = VocabParallelEmbedding(
+ config.vocab_size, config.hidden_size
+ )
+ self.layers = nn.ModuleList(
+ [
+ OlmoDecoderLayer(config, layer_idx, quant_config)
+ for layer_idx in range(config.num_hidden_layers)
+ ]
+ )
+ self.norm = nn.LayerNorm(
+ config.hidden_size, elementwise_affine=False, bias=False
+ )
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
+ ) -> torch.Tensor:
+ """
+ :param input_ids: A tensor of shape `(batch_size, seq_len)`.
+ """
+ # Get embeddings of input.
+ # shape: (batch_size, seq_len, d_model)
+
+ if input_embeds is None:
+ hidden_states = self.embed_tokens(input_ids)
+ else:
+ hidden_states = input_embeds
+
+ # Apply blocks one-by-one.
+ for layer_idx, decoder_layer in enumerate(self.layers):
+ # shape: (batch_size, seq_len, d_model)
+ hidden_states = decoder_layer(
+ positions,
+ hidden_states,
+ forward_batch,
+ )
+
+ # Apply final layer norm.
+ # shape: (batch_size, seq_len or 1, d_model)
+ hidden_states = self.norm(hidden_states)
+ return hidden_states
+
+
+ class OlmoForCausalLM(nn.Module):
+ """
+ Extremely barebones HF model wrapper.
+ """
+
+ def __init__(
+ self,
+ config: OlmoConfig,
+ cache_config=None,
+ quant_config: Optional[QuantizationConfig] = None,
+ ):
+ super().__init__()
+ self.config = config
+ self.model = OlmoModel(config, quant_config)
+ if config.tie_word_embeddings:
+ self.lm_head = self.model.embed_tokens
+ else:
+ self.unpadded_vocab_size = config.vocab_size
+ self.lm_head = ParallelLMHead(
+ self.unpadded_vocab_size,
+ config.hidden_size,
+ org_num_embeddings=config.vocab_size,
+ quant_config=quant_config,
+ )
+ self.logits_processor = LogitsProcessor(config)
+
+ def forward(
+ self,
+ input_ids: torch.Tensor,
+ positions: torch.Tensor,
+ forward_batch: ForwardBatch,
+ input_embeds: torch.Tensor = None,
+ ) -> torch.Tensor:
+ hidden_states = self.model(
+ input_ids=input_ids,
+ positions=positions,
+ forward_batch=forward_batch,
+ input_embeds=input_embeds,
+ )
+ return self.logits_processor(
+ input_ids, hidden_states, self.lm_head.weight, forward_batch
+ )
+
+ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ stacked_params_mapping = [
+ # (param_name, shard_name, shard_id)
+ ("qkv_proj", "q_proj", "q"),
+ ("qkv_proj", "k_proj", "k"),
+ ("qkv_proj", "v_proj", "v"),
+ ("gate_up_proj", "gate_proj", 0),
+ ("gate_up_proj", "up_proj", 1),
+ ]
+ params_dict = dict(self.named_parameters(remove_duplicate=False))
+ for name, loaded_weight in weights:
+ if "rotary_emb.inv_freq" in name:
+ continue
+ if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+ # Models trained using ColossalAI may include these tensors in
+ # the checkpoint. Skip them.
+ continue
+ # With tie_word_embeddings, we can skip lm_head.weight
+ # The weight might appear unnecessarily in the files if the model is
+ # processed with quantization, LoRA, fine-tuning, etc.
+ if self.config.tie_word_embeddings and "lm_head.weight" in name:
+ continue
+ for param_name, weight_name, shard_id in stacked_params_mapping:
+ if weight_name not in name:
+ continue
+ name = name.replace(weight_name, param_name)
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = param.weight_loader
+ weight_loader(param, loaded_weight, shard_id)
+ break
+ else:
+ # Skip loading extra bias for GPTQ models.
+ if name.endswith(".bias") and name not in params_dict:
+ continue
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight)
+
+
+ EntryClass = OlmoForCausalLM
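The new OLMo file follows the usual sglang model recipe: the q/k/v and gate/up projections are fused into qkv_proj and gate_up_proj at runtime, and load_weights maps the per-projection checkpoint names onto those fused parameters together with a shard id. A standalone sketch of that name-routing step (the checkpoint names below are hypothetical examples, not taken from a real model):

# Illustrative reimplementation of the name rewrite performed in
# OlmoForCausalLM.load_weights: separate q/k/v and gate/up checkpoint
# weights are routed to the fused qkv_proj / gate_up_proj parameters,
# with a shard id telling the weight loader where to copy them.
STACKED_PARAMS_MAPPING = [
    # (fused param name, checkpoint shard name, shard id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def route_weight(name: str):
    """Return (target_param_name, shard_id) for a checkpoint tensor name."""
    for param_name, weight_name, shard_id in STACKED_PARAMS_MAPPING:
        if weight_name in name:
            return name.replace(weight_name, param_name), shard_id
    return name, None  # unfused weights load as-is


# Hypothetical checkpoint names, for illustration only:
for ckpt_name in [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.self_attn.o_proj.weight",
]:
    print(route_weight(ckpt_name))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'q')
#    ('model.layers.0.mlp.gate_up_proj.weight', 1)
#    ('model.layers.0.self_attn.o_proj.weight', None)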
sglang/srt/models/olmoe.py CHANGED
@@ -23,7 +23,6 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import CacheConfig
  from vllm.distributed import (
  get_tensor_model_parallel_world_size,
  tensor_model_parallel_all_reduce,
@@ -298,7 +297,7 @@ class OlmoeForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
sglang/srt/models/qwen.py CHANGED
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -243,7 +242,7 @@ class QWenLMHeadModel(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ):
  super().__init__()
  self.config = config
sglang/srt/models/qwen2.py CHANGED
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
  from torch import nn
- from vllm.config import CacheConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
  from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -268,7 +267,7 @@ class Qwen2ForCausalLM(nn.Module):
  self,
  config: Qwen2Config,
  quant_config: Optional[QuantizationConfig] = None,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  ) -> None:
  super().__init__()
  self.config = config
sglang/srt/models/qwen2_moe.py CHANGED
@@ -23,7 +23,6 @@ import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.config import CacheConfig
  from vllm.distributed import (
  get_tensor_model_parallel_world_size,
  tensor_model_parallel_all_reduce,
@@ -160,7 +159,7 @@ class Qwen2MoeAttention(nn.Module):
  rope_theta: float = 10000,
  rope_scaling: Optional[Dict[str, Any]] = None,
  max_position_embeddings: int = 8192,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -236,7 +235,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
  self,
  config: PretrainedConfig,
  layer_id: int,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -306,7 +305,7 @@ class Qwen2MoeModel(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()
@@ -355,7 +354,7 @@ class Qwen2MoeForCausalLM(nn.Module):
  def __init__(
  self,
  config: PretrainedConfig,
- cache_config: Optional[CacheConfig] = None,
+ cache_config=None,
  quant_config: Optional[QuantizationConfig] = None,
  ) -> None:
  super().__init__()