sglang 0.3.1__py3-none-any.whl → 0.3.1.post1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
@@ -166,10 +164,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -178,6 +179,7 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -188,23 +190,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -255,20 +250,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -293,14 +288,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
@@ -315,7 +310,7 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
 
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
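Note: across the three hunks above, failures during a weight update are now reported back to the caller through the returned message instead of only being logged. A toy, self-contained sketch of that (success, message) contract; the function body here is a stand-in, not sglang's loading logic:

    def update_weights(path: str) -> tuple[bool, str]:
        # Toy stand-in for ModelRunner.update_weights: the failure reason travels
        # back to the caller in the returned message rather than only to the log.
        try:
            weights = open(path, "rb")  # placeholder for the real loading logic
        except Exception as e:
            return False, f"Failed to load model config: {e}."
        weights.close()
        return True, "Succeeded to update model weights."

    ok, message = update_weights("does_not_exist.bin")
    print(ok, message)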
@@ -521,21 +516,6 @@ class ModelRunner:
         else:
             raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
 
-    def _check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
-
     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
     ):
@@ -564,13 +544,16 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
     ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         batch.sampling_info.update_regex_vocab_mask(batch)
         batch.sampling_info.update_penalties()
         logits = self._apply_logits_bias(
             logits_output.next_token_logits, batch.sampling_info
         )
-        sample_output = self.sampler(logits, batch.sampling_info)
-        return self._check_sample_results(sample_output)
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, batch.sampling_info)
+        return next_token_ids
 
 
 @lru_cache()
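Note: the comment added to sample() relies on CUDA kernels being launched asynchronously, so Python-side work issued after the forward pass can run while the GPU is still executing it. A self-contained toy illustration of that overlap (all names, shapes, and the mask rule are made up; this is not sglang code):

    import torch

    # Toy stand-ins: a large matmul plays the role of the forward pass and a
    # Python loop plays the role of the CPU-heavy vocab-mask update.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    hidden = torch.randn(64, 4096, device=device)
    w_vocab = torch.randn(4096, 32000, device=device)

    logits = hidden @ w_vocab                        # kernel is queued asynchronously on CUDA
    banned = [i % 7 == 0 for i in range(32000)]      # CPU work overlaps with the GPU matmul
    vocab_mask = torch.tensor(banned, device=device)
    logits = logits.masked_fill(vocab_mask, float("-inf"))
    next_token_ids = torch.argmax(logits, dim=-1).cpu()  # copying back to the CPU forces the sync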
@@ -19,7 +19,6 @@ limitations under the License.
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -48,6 +47,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
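Note: this hunk and the matching one in the MiniCPM3 model below apply the same pattern: the CUDA-only flashinfer dependency is imported only when the build is not ROCm, so the module can still be imported on HIP. A generic, hypothetical sketch of that guard (the is_hip body and the None fallback are assumptions for illustration, not necessarily what sglang.srt.utils.is_hip does):

    import torch


    def is_hip() -> bool:
        # Assumed check: True when PyTorch was built against ROCm/HIP instead of CUDA.
        return torch.version.hip is not None


    if not is_hip():
        # flashinfer ships CUDA-only wheels, so guard the import on ROCm builds.
        from flashinfer import bmm_fp8
    else:
        bmm_fp8 = None  # assumed fallback; callers must avoid the FP8 path on ROCm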
@@ -649,6 +653,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         )
         self.logits_processor = LogitsProcessor(config)
 
+    @torch.no_grad()
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -19,7 +19,6 @@ import math
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
-from flashinfer import bmm_fp8
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig
@@ -44,6 +43,11 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.utils import is_hip
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import bmm_fp8
 
 
 class MiniCPM3MLP(nn.Module):
@@ -0,0 +1,415 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+# Adapted from:
+# https://github.com/vllm-project/vllm/pull/7922
+
+"""Inference-only OLMoE model compatible with HuggingFace weights."""
+from typing import Any, Dict, Iterable, List, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+from transformers import PretrainedConfig
+from vllm.config import CacheConfig
+from vllm.distributed import (
+    get_tensor_model_parallel_world_size,
+    tensor_model_parallel_all_reduce,
+)
+from vllm.model_executor.layers.fused_moe import FusedMoE
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    ReplicatedLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
+from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.utils import print_warning_once
+
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+
+
+class OlmoeMoE(nn.Module):
+    """A tensor-parallel MoE implementation for Olmoe that shards each expert
+    across all ranks.
+
+    Each expert's weights are sharded across all ranks and a fused MoE
+    kernel is used for the forward pass, and finally we reduce the outputs
+    across ranks.
+    """
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.hidden_size = hidden_size
+
+        # Gate always runs at half / full precision for now.
+        self.gate = ReplicatedLinear(
+            hidden_size, num_experts, bias=False, quant_config=None
+        )
+
+        self.experts = FusedMoE(
+            num_experts=num_experts,
+            top_k=top_k,
+            hidden_size=hidden_size,
+            intermediate_size=intermediate_size,
+            reduce_results=True,
+            renormalize=False,
+            quant_config=quant_config,
+            tp_size=tp_size,
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        # NOTE: hidden_states can have either 1D or 2D shape.
+        orig_shape = hidden_states.shape
+        hidden_states = hidden_states.view(-1, self.hidden_size)
+        # router_logits: (num_tokens, n_experts)
+        router_logits, _ = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        return final_hidden_states.view(orig_shape)
+
+
+class OlmoeAttention(nn.Module):
+
+    def __init__(
+        self,
+        layer_id: int,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        max_position_embeddings: int = 4096,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        self.head_dim = hidden_size // self.total_num_heads
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+        )
+        self.q_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.k_norm = RMSNorm(hidden_size, eps=1e-5)
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=True,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            layer_id=layer_id,
+            num_kv_heads=self.num_kv_heads,
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, input_metadata)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class OlmoeDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 4096)
+
+        self.self_attn = OlmoeAttention(
+            layer_id,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+        )
+
+        self.mlp = OlmoeMoE(
+            num_experts=config.num_experts,
+            top_k=config.num_experts_per_tok,
+            hidden_size=config.hidden_size,
+            intermediate_size=config.intermediate_size,
+            quant_config=quant_config,
+        )
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+        residual: Optional[torch.Tensor],
+    ) -> torch.Tensor:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.input_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            input_metadata=input_metadata,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class OlmoeModel(nn.Module):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+        )
+        self.layers = nn.ModuleList(
+            [
+                OlmoeDecoderLayer(config, layer_id, quant_config=quant_config)
+                for layer_id in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = RMSNorm(config.hidden_size, eps=1e-5)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        if input_embeds is None:
+            hidden_states = self.embed_tokens(input_ids)
+        else:
+            hidden_states = input_embeds
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions, hidden_states, input_metadata, residual
+            )
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class OlmoeForCausalLM(nn.Module):
+
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = OlmoeModel(config, quant_config)
+        self.lm_head = ParallelLMHead(
+            config.vocab_size, config.hidden_size, quant_config=quant_config
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        input_metadata: InputMetadata,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        # Params for weights, fp8 weight scales, fp8 activation scales
+        # (param_name, weight_name, expert_id, shard_id)
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                # Skip non-stacked layers and experts (experts handled below).
+                if weight_name not in name:
+                    continue
+                # We have mlp.experts[0].gate_proj in the checkpoint.
+                # Since we handle the experts below in expert_params_mapping,
+                # we need to skip here BEFORE we update the name, otherwise
+                # name will be updated to mlp.experts[0].gate_up_proj, which
+                # will then be updated below in expert_params_mapping
+                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
+                if "mlp.experts" in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                for mapping in expert_params_mapping:
+                    param_name, weight_name, expert_id, shard_id = mapping
+                    if weight_name not in name:
+                        continue
+                    name = name.replace(weight_name, param_name)
+                    param = params_dict[name]
+                    weight_loader = param.weight_loader
+                    weight_loader(
+                        param,
+                        loaded_weight,
+                        name,
+                        shard_id=shard_id,
+                        expert_id=expert_id,
+                    )
+                    break
+                else:
+                    # Skip loading extra bias for GPTQ models.
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    # Remapping the name of FP8 kv-scale.
+                    if name.endswith("kv_scale"):
+                        remapped_kv_scale_name = name.replace(
+                            ".kv_scale", ".attn.kv_scale"
+                        )
+                        if remapped_kv_scale_name not in params_dict:
+                            print_warning_once(
+                                "Found kv scale in the checkpoint "
+                                f"(e.g. {name}), but not found the expected "
+                                f"name in the model "
+                                f"(e.g. {remapped_kv_scale_name}). "
+                                "kv-scale is not loaded."
+                            )
+                            continue
+                        else:
+                            name = remapped_kv_scale_name
+
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+
+
+EntryClass = OlmoeForCausalLM
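Note: the new OLMoE model delegates expert routing to vllm's FusedMoE with renormalize=False, i.e. the softmax is taken over all experts and the selected top-k weights are not renormalized, which presumably matches how the checkpoint was trained. A small, hypothetical reference computation of that routing in plain PyTorch (shapes and values are made up, and plain Linear layers stand in for the gated experts; the fused kernel differs in implementation, not in this math):

    import torch
    import torch.nn.functional as F

    num_tokens, hidden, num_experts, top_k = 4, 16, 8, 2
    hidden_states = torch.randn(num_tokens, hidden)
    router = torch.randn(num_experts, hidden)                                  # stands in for the gate weights
    experts = [torch.nn.Linear(hidden, hidden) for _ in range(num_experts)]    # toy experts

    router_logits = hidden_states @ router.t()      # (num_tokens, num_experts)
    probs = F.softmax(router_logits, dim=-1)        # softmax over *all* experts
    topk_w, topk_idx = probs.topk(top_k, dim=-1)    # renormalize=False: weights kept as-is

    out = torch.zeros_like(hidden_states)
    for t in range(num_tokens):
        for w, e in zip(topk_w[t], topk_idx[t]):
            out[t] += w * experts[e](hidden_states[t])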
@@ -34,56 +34,6 @@ class SamplingBatchInfo:
     linear_penalties: torch.Tensor = None
     scaling_penalties: torch.Tensor = None
 
-    def __len__(self):
-        return len(self.temperatures)
-
-    def can_run_in_cuda_graph(self):
-        # Vocab bias and min_ps are not supported in CUDA graph
-        return (
-            self.logit_bias is None
-            and self.linear_penalties is None
-            and self.scaling_penalties is None
-            and not self.need_min_p_sampling
-        )
-
-    @classmethod
-    def dummy_one(cls, max_bs: int, vocab_size: int):
-        ret = cls(vocab_size=vocab_size)
-        with torch.device("cuda"):
-            ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float)
-            ret.top_ps = torch.ones((max_bs,), dtype=torch.float)
-            ret.top_ks = torch.ones((max_bs,), dtype=torch.int)
-            ret.vocab_mask = torch.zeros((max_bs, vocab_size), dtype=torch.bool)
-        return ret
-
-    def __getitem__(self, key):
-        if isinstance(key, slice):
-            # NOTE:This method is only used in CUDA graph
-            assert self.can_run_in_cuda_graph()
-            return SamplingBatchInfo(
-                vocab_size=self.vocab_size,
-                temperatures=self.temperatures[key],
-                top_ps=self.top_ps[key],
-                top_ks=self.top_ks[key],
-                vocab_mask=self.vocab_mask[key],
-            )
-        else:
-            raise NotImplementedError
-
-    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
-        # NOTE:This method is only used in CUDA graph
-        assert self.can_run_in_cuda_graph()
-
-        self.vocab_size = other.vocab_size
-        self.temperatures[:bs] = other.temperatures
-        self.top_ps[:bs] = other.top_ps
-        self.top_ks[:bs] = other.top_ks
-
-        if other.vocab_mask is None:
-            self.vocab_mask[:bs].fill_(False)
-        else:
-            self.vocab_mask[:bs] = other.vocab_mask
-
     @classmethod
     def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
         reqs = batch.reqs
@@ -130,6 +80,9 @@ class SamplingBatchInfo:
 
         return ret
 
+    def __len__(self):
+        return len(self.temperatures)
+
     def update_penalties(self):
         self.scaling_penalties = None
         self.linear_penalties = None