sglang 0.3.1__py3-none-any.whl → 0.3.1.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (37)
  1. sglang/bench_latency.py +10 -3
  2. sglang/bench_server_latency.py +187 -0
  3. sglang/bench_serving.py +1 -1
  4. sglang/global_config.py +5 -13
  5. sglang/lang/interpreter.py +0 -3
  6. sglang/srt/constrained/fsm_cache.py +5 -1
  7. sglang/srt/layers/activation.py +16 -1
  8. sglang/srt/layers/attention_backend.py +12 -12
  9. sglang/srt/layers/fused_moe/layer.py +27 -7
  10. sglang/srt/layers/layernorm.py +21 -6
  11. sglang/srt/layers/sampler.py +40 -98
  12. sglang/srt/lora/lora_manager.py +11 -8
  13. sglang/srt/managers/io_struct.py +3 -0
  14. sglang/srt/managers/policy_scheduler.py +49 -93
  15. sglang/srt/managers/schedule_batch.py +2 -1
  16. sglang/srt/managers/tp_worker.py +19 -13
  17. sglang/srt/model_executor/cuda_graph_runner.py +25 -13
  18. sglang/srt/model_executor/model_runner.py +37 -46
  19. sglang/srt/models/deepseek_v2.py +8 -3
  20. sglang/srt/models/llama.py +1 -3
  21. sglang/srt/models/llama_classification.py +2 -3
  22. sglang/srt/models/minicpm3.py +7 -3
  23. sglang/srt/models/olmoe.py +415 -0
  24. sglang/srt/models/xverse.py +1 -3
  25. sglang/srt/models/xverse_moe.py +1 -4
  26. sglang/srt/sampling/sampling_batch_info.py +3 -50
  27. sglang/srt/server.py +6 -1
  28. sglang/srt/server_args.py +39 -10
  29. sglang/srt/utils.py +7 -51
  30. sglang/test/few_shot_gsm8k.py +8 -2
  31. sglang/test/test_utils.py +1 -1
  32. sglang/version.py +1 -1
  33. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/METADATA +4 -5
  34. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/RECORD +37 -35
  35. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/WHEEL +1 -1
  36. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/LICENSE +0 -0
  37. {sglang-0.3.1.dist-info → sglang-0.3.1.post2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
@@ -21,12 +21,15 @@ import re
 from dataclasses import dataclass
 
 import torch
-from flashinfer import SegmentGEMMWrapper
 
 from sglang.srt.lora.lora import LoRAAdapter, get_lora_layer
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import replace_submodule
+from sglang.srt.utils import is_hip, replace_submodule
+
+# ROCm: flashinfer available later
+if not is_hip():
+    from flashinfer import SegmentGEMMWrapper
 
 
 def get_stacked_name(name):
@@ -96,10 +99,10 @@ class LoRAManager:
         # get configs and target modules
         self.configs = {}
         self.origin_target_modules = set()
-        for path in self.lora_paths:
-            self.configs[path] = LoRAConfig(path)
+        for name, path in self.lora_paths.items():
+            self.configs[name] = LoRAConfig(path)
             self.origin_target_modules = set(self.origin_target_modules) | set(
-                self.configs[path].target_modules
+                self.configs[name].target_modules
             )
         self.target_modules = set(
             [
@@ -114,11 +117,11 @@ class LoRAManager:
         # load all weights to cpu
         self.loras = []
         self.lora_id = {}
-        for path in self.lora_paths:
-            self.lora_id[path] = len(self.loras)
+        for name in self.lora_paths.keys():
+            self.lora_id[name] = len(self.loras)
             self.loras.append(
                 LoRAAdapter(
-                    path, self.configs[path], self.base_hf_config, self.load_config
+                    name, self.configs[name], self.base_hf_config, self.load_config
                 )
             )
             self.loras[-1].initialize_weights()
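The three lora_manager.py hunks above change `lora_paths` from a flat list of paths into a mapping from adapter name to path, so configs, ids, and adapters are keyed by name. A minimal sketch of the new keying scheme; the `LoRAConfig` here is a simplified stand-in for illustration, not the sglang class:

```python
from dataclasses import dataclass, field
from typing import Dict, List


@dataclass
class LoRAConfig:
    """Stand-in for sglang.srt.lora.lora_config.LoRAConfig (illustrative only)."""
    path: str
    target_modules: List[str] = field(default_factory=lambda: ["q_proj", "v_proj"])


def build_lora_tables(lora_paths: Dict[str, str]):
    # Key configs and ids by adapter *name*, not by filesystem path,
    # mirroring the loops rewritten in the hunks above.
    configs: Dict[str, LoRAConfig] = {}
    lora_id: Dict[str, int] = {}
    target_modules = set()
    for name, path in lora_paths.items():
        configs[name] = LoRAConfig(path)
        target_modules |= set(configs[name].target_modules)
        lora_id[name] = len(lora_id)
    return configs, lora_id, target_modules


configs, lora_id, modules = build_lora_tables(
    {"sql-adapter": "/models/lora/sql", "chat-adapter": "/models/lora/chat"}
)
print(lora_id)  # {'sql-adapter': 0, 'chat-adapter': 1}
```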
sglang/srt/managers/io_struct.py
@@ -133,6 +133,9 @@ class GenerateReqInput:
                 self.image_data = [None] * num
             elif not isinstance(self.image_data, list):
                 self.image_data = [self.image_data] * num
+            elif isinstance(self.image_data, list):
+                # multi-image with n > 1
+                self.image_data = self.image_data * num
 
             if self.sampling_params is None:
                 self.sampling_params = [{}] * num
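The new `elif` branch handles a request that already carries a list of images while `n > 1` parallel samples are requested: the list is repeated once per sample. A rough, self-contained sketch of the normalization (hypothetical helper, not the sglang API):

```python
def normalize_image_data(image_data, num):
    # Mirrors the branching added above (illustrative only).
    if image_data is None:
        return [None] * num
    if not isinstance(image_data, list):
        return [image_data] * num
    # multi-image input combined with n > 1: repeat the whole list
    return image_data * num


print(normalize_image_data(None, 2))                # [None, None]
print(normalize_image_data("img.png", 2))           # ['img.png', 'img.png']
print(normalize_image_data(["a.png", "b.png"], 2))  # ['a.png', 'b.png', 'a.png', 'b.png']
```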
sglang/srt/managers/policy_scheduler.py
@@ -119,19 +119,32 @@ class PrefillAdder:
         self.running_batch = running_batch
         self.new_token_ratio = new_token_ratio
         self.rem_total_tokens = rem_total_tokens - mixed_with_decode_tokens
-        self.rem_total_tokens_ = self.rem_total_tokens
-        self.total_tokens = rem_total_tokens
         self.rem_input_tokens = rem_input_tokens - mixed_with_decode_tokens
         self.rem_chunk_tokens = rem_chunk_tokens
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= mixed_with_decode_tokens
 
+        self.cur_rem_tokens = rem_total_tokens - mixed_with_decode_tokens
+
         self.req_states = None
         self.can_run_list = []
         self.new_inflight_req = None
         self.log_hit_tokens = 0
         self.log_input_tokens = 0
 
+        if running_batch is not None:
+            # Pre-remove the tokens which will be occupied by the running requests
+            self.rem_total_tokens -= sum(
+                [
+                    min(
+                        (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                        CLIP_MAX_NEW_TOKENS,
+                    )
+                    * self.new_token_ratio
+                    for r in running_batch.reqs
+                ]
+            )
+
     def no_remaining_tokens(self):
         return (
             self.rem_total_tokens <= 0
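Rather than calling a separate `remove_running_tokens` later, the constructor now reserves an estimate of the tokens the running requests will still consume, clipped by `CLIP_MAX_NEW_TOKENS` and scaled by `new_token_ratio`. A standalone sketch of that reservation, with a simplified request object standing in for `Req` and an assumed clip value:

```python
from dataclasses import dataclass
from typing import List

CLIP_MAX_NEW_TOKENS = 512  # assumed clip value, for illustration only


@dataclass
class RunningReq:
    max_new_tokens: int
    num_output_ids: int


def reserve_running_tokens(rem_total_tokens: int,
                           running_reqs: List[RunningReq],
                           new_token_ratio: float) -> float:
    # Pre-remove the tokens the running requests are expected to occupy,
    # as the rewritten PrefillAdder.__init__ does.
    reserved = sum(
        min(r.max_new_tokens - r.num_output_ids, CLIP_MAX_NEW_TOKENS) * new_token_ratio
        for r in running_reqs
    )
    return rem_total_tokens - reserved


print(reserve_running_tokens(10_000, [RunningReq(256, 10), RunningReq(1024, 900)], 0.5))
# 9815.0
```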
@@ -141,31 +154,14 @@ class PrefillAdder:
                 if self.rem_chunk_tokens is not None
                 else False
             )
-        )
-
-    def remove_running_tokens(self, running_batch: ScheduleBatch):
-        self.rem_total_tokens -= sum(
-            [
-                min(
-                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
-                    CLIP_MAX_NEW_TOKENS,
-                )
-                * self.new_token_ratio
-                for r in running_batch.reqs
-            ]
-        )
-        self.rem_total_tokens_ -= sum(
-            [
-                r.sampling_params.max_new_tokens - len(r.output_ids)
-                for r in running_batch.reqs
-            ]
+            or self.cur_rem_tokens <= 0
         )
 
     def _prefill_one_req(
         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
     ):
         self.rem_total_tokens -= extend_input_len + max_new_tokens
-        self.rem_total_tokens_ -= extend_input_len + max_new_tokens
+        self.cur_rem_tokens -= extend_input_len
         self.rem_input_tokens -= extend_input_len
         if self.rem_chunk_tokens is not None:
             self.rem_chunk_tokens -= extend_input_len
@@ -173,29 +169,7 @@ class PrefillAdder:
         self.log_hit_tokens += prefix_len
         self.log_input_tokens += extend_input_len
 
-    def add_inflight_req_ignore_eos(self, req: Req):
-        truncated = req.extend_input_len > self.rem_chunk_tokens
-        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
-        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
-        self.can_run_list.append(req)
-
-        self._prefill_one_req(
-            0,
-            req.extend_input_len,
-            (
-                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
-                if not truncated
-                else 0
-            ),
-        )
-
-        # Return if chunked prefill not finished
-        return req if truncated else None
-
     def add_inflight_req(self, req: Req):
-        if req.sampling_params.ignore_eos:
-            return self.add_inflight_req_ignore_eos(req)
-
         truncated = req.extend_input_len > self.rem_chunk_tokens
         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
@@ -225,7 +199,7 @@ class PrefillAdder:
         self.rem_total_tokens += delta
 
     def add_one_req_ignore_eos(self, req: Req):
-        def get_req_state(r):
+        def add_req_state(r, insert_sort=False):
             new_token_ratio = (
                 1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
             )
@@ -235,56 +209,38 @@ class PrefillAdder:
             tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
             if tokens_left > 0:
-                return (tokens_left, tokens_occupied)
-
-            return None
-
-        # Quick Check
-        can_run = False
-        if (
-            req.extend_input_len + req.sampling_params.max_new_tokens
-            <= self.rem_total_tokens
-        ):
-            can_run = True
-
-        if not can_run:
-            if self.req_states is None:
-                self.req_states = []
-                if self.running_batch is not None:
-                    for r in self.running_batch.reqs:
-                        state = get_req_state(r)
-                        if state is not None:
-                            self.req_states.append(state)
-                for r in self.can_run_list:
-                    state = get_req_state(r)
-                    if state is not None:
-                        self.req_states.append(state)
-                state = get_req_state(req)
-                if state is not None:
-                    self.req_states.append(state)
-
-                self.req_states.sort(key=lambda x: x[0])
-            else:
-                state = get_req_state(req)
-                if state is not None:
-                    for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                        if tokens_left >= state[0]:
-                            self.req_states.insert(i, state)
+                if not insert_sort:
+                    self.req_states.append((tokens_left, tokens_occupied))
+                else:
+                    for i in range(len(self.req_states)):
+                        if tokens_left <= self.req_states[i][0]:
                             break
-                    else:
-                        self.req_states.append(state)
-
-            tokens_freed = 0
-            for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
-                decode_steps = (
-                    self.req_states[i + 1][0]
-                    if i + 1 < len(self.req_states)
-                    else tokens_left
-                )
-                bs = len(self.req_states) - i
-                if self.total_tokens + tokens_freed - decode_steps * bs <= 0:
-                    return False
-                tokens_freed += tokens_occupied
+                    self.req_states.insert(i, (tokens_left, tokens_occupied))
+
+        if self.req_states is None:
+            self.req_states = []
+            add_req_state(req)
+            if self.running_batch is not None:
+                for r in self.running_batch.reqs:
+                    add_req_state(r)
+            for r in self.can_run_list:
+                add_req_state(r)
+            self.req_states.sort(key=lambda x: x[0])
+        else:
+            add_req_state(req, insert_sort=True)
+
+        cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
+        tokens_freed = 0
+        for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
+            decode_steps = (
+                self.req_states[i + 1][0]
+                if i + 1 < len(self.req_states)
+                else tokens_left
+            )
+            bs = len(self.req_states) - i
+            if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+                return False
+            tokens_freed += tokens_occupied
 
         if req.extend_input_len <= self.rem_chunk_tokens:
             self.can_run_list.append(req)
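The rewritten `add_one_req_ignore_eos` keeps `req_states` as a list of `(tokens_left, tokens_occupied)` pairs sorted by `tokens_left` and walks it to check that, as each request finishes and frees its cache, the remaining token budget stays positive. A standalone restatement of that feasibility loop (illustrative names, not the sglang API):

```python
from typing import List, Tuple


def fits_in_budget(cur_rem_tokens: int,
                   req_states: List[Tuple[float, int]]) -> bool:
    """req_states: (tokens_left, tokens_occupied), sorted by tokens_left ascending."""
    tokens_freed = 0
    for i, (tokens_left, tokens_occupied) in enumerate(req_states):
        # Decode until the *next* request finishes; every still-running
        # request consumes one token per decode step.
        decode_steps = req_states[i + 1][0] if i + 1 < len(req_states) else tokens_left
        batch_size = len(req_states) - i
        if cur_rem_tokens + tokens_freed - decode_steps * batch_size <= 0:
            return False
        tokens_freed += tokens_occupied  # this request finishes and frees its cache
    return True


print(fits_in_budget(100, [(10, 50), (40, 80)]))  # True
print(fits_in_budget(50, [(30, 10), (60, 20)]))   # False
```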
sglang/srt/managers/schedule_batch.py
@@ -40,7 +40,7 @@ global_server_args_dict = {
     "attention_backend": ServerArgs.attention_backend,
     "sampling_backend": ServerArgs.sampling_backend,
     "triton_attention_reduce_in_fp32": ServerArgs.triton_attention_reduce_in_fp32,
-    "enable_mla": ServerArgs.enable_mla,
+    "disable_mla": ServerArgs.disable_mla,
     "torchao_config": ServerArgs.torchao_config,
 }
 
@@ -360,6 +360,7 @@ class ScheduleBatch:
     tree_cache: BasePrefixCache
 
     forward_mode: ForwardMode = None
+    sampling_info: SamplingBatchInfo = None
 
     # Batched arguments to model runner
     input_ids: torch.Tensor = None
sglang/srt/managers/tp_worker.py
@@ -198,6 +198,7 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            constrained_json_whitespace_pattern=server_args.constrained_json_whitespace_pattern,
         )
         self.jump_forward_cache = JumpForwardCache()
 
@@ -414,7 +415,7 @@ class ModelTpServer:
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
-            logger.warn(
+            logger.warning(
                 "Request length is longer than the KV cache pool size or "
                 "the max context length. Truncated!!!"
             )
@@ -444,9 +445,6 @@ class ModelTpServer:
             num_mixed_running,
         )
 
-        if self.running_batch is not None:
-            adder.remove_running_tokens(self.running_batch)
-
         has_inflight = self.current_inflight_req is not None
         if self.current_inflight_req is not None:
             self.current_inflight_req.init_next_round_input(
@@ -464,9 +462,6 @@ class ModelTpServer:
             )
 
         for req in self.waiting_queue:
-            if adder.no_remaining_tokens():
-                break
-            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             if (
                 self.lora_paths is not None
                 and len(
@@ -477,6 +472,10 @@ class ModelTpServer:
                 > self.max_loras_per_batch
             ):
                 break
+
+            if adder.no_remaining_tokens():
+                break
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
             res = adder.add_one_req(req)
             if (
                 not res
@@ -506,6 +505,11 @@ class ModelTpServer:
         else:
             tree_cache_hit_rate = 0.0
 
+        num_used = self.max_total_num_tokens - (
+            self.token_to_kv_pool.available_size()
+            + self.tree_cache.evictable_size()
+        )
+
         if num_mixed_running > 0:
             logger.info(
                 f"Prefill batch"
@@ -514,6 +518,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
         else:
@@ -523,6 +528,7 @@ class ModelTpServer:
                 f"#new-token: {adder.log_input_tokens}, "
                 f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
                 f"#running-req: {running_bs}, "
                 f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
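Both prefill log lines gain a `token usage` field: the share of the KV-cache token pool that is currently neither free nor evictable. A small sketch of the bookkeeping, with the pool sizes passed in directly instead of queried from the pool objects:

```python
def token_usage(max_total_num_tokens: int,
                available_size: int,
                evictable_size: int) -> float:
    # num_used mirrors the computation added before the log statements above.
    num_used = max_total_num_tokens - (available_size + evictable_size)
    return num_used / max_total_num_tokens


# e.g. a 100k-token pool with 60k free slots and 15k evictable cached prefixes
print(f"token usage: {token_usage(100_000, 60_000, 15_000):.2f}")  # 0.25
```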
@@ -807,12 +813,10 @@ class ModelTpServer:
                 unfinished_indices.append(i)
 
             if req.finished() or (
-                (
-                    req.stream
-                    and (
-                        self.decode_forward_ct % self.stream_interval == 0
-                        or len(req.output_ids) == 1
-                    )
+                req.stream
+                and (
+                    self.decode_forward_ct % self.stream_interval == 0
+                    or len(req.output_ids) == 1
                 )
             ):
                 output_rids.append(req.rid)
@@ -937,6 +941,8 @@ class ModelTpServer:
         if success:
             flash_cache_success = self.flush_cache()
             assert flash_cache_success, "Cache flush failed after updating weights"
+        else:
+            logger.error(message)
         return success, message
 
 
sglang/srt/model_executor/cuda_graph_runner.py
@@ -41,6 +41,9 @@ if TYPE_CHECKING:
 def _to_torch(model: torch.nn.Module, reverse: bool = False):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
+            # NOTE: the torch-native FusedMoE implementation is not efficient
+            if "FusedMoE" in sub.__class__.__name__:
+                continue
             if reverse:
                 sub._forward_method = sub.forward_cuda
                 setattr(sub, "is_torch_compile", False)
@@ -105,23 +108,22 @@ class CudaGraphRunner:
             self.capture_bs = list(range(1, 32)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-        self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if self.use_torch_compile else []
 
-        # Common inputs
-        self.max_bs = max(self.capture_bs)
-        self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.req_pool_indices = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
-        self.position_ids_offsets = torch.ones(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
-        )
-        self.out_cache_loc = torch.zeros(
-            (self.max_bs,), dtype=torch.int32, device="cuda"
+        self.capture_bs = [
+            bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
+        ]
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.max_torch_compile_bs
+            ]
+            if self.use_torch_compile
+            else []
         )
 
         # Attention backend
+        self.max_bs = max(self.capture_bs)
         self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
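Capture batch sizes are now clamped to the request-pool size, and the torch.compile batch sizes are derived from `capture_bs` with a configurable `max_torch_compile_bs` cap instead of a hard-coded list. A small sketch of the filtering (the helper itself is illustrative, not part of sglang):

```python
from typing import List, Tuple


def plan_batch_sizes(capture_bs: List[int],
                     req_pool_size: int,
                     use_torch_compile: bool,
                     max_torch_compile_bs: int) -> Tuple[List[int], List[int]]:
    # Never capture a CUDA graph for a batch larger than the request pool.
    capture_bs = [bs for bs in capture_bs if bs <= req_pool_size]
    # Only compile the smaller batch sizes, where torch.compile pays off.
    compile_bs = (
        [bs for bs in capture_bs if bs <= max_torch_compile_bs]
        if use_torch_compile
        else []
    )
    return capture_bs, compile_bs


print(plan_batch_sizes([1, 2, 4, 8, 16, 32, 64, 128], req_pool_size=48,
                       use_torch_compile=True, max_torch_compile_bs=16))
# ([1, 2, 4, 8, 16, 32], [1, 2, 4, 8, 16])
```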
@@ -130,6 +132,16 @@ class CudaGraphRunner:
         if self.use_torch_compile:
             set_torch_compile_config()
 
+        # Common inputs
+        with torch.device("cuda"):
+            self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
+            self.seq_lens = torch.full(
+                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
+            )
+            self.position_ids_offsets = torch.ones((self.max_bs,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
+
         # Capture
         try:
             self.capture()
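The placeholder buffers are now created after the attention backend reports its `seq_len_fill_value`, inside a `torch.device(...)` context so every tensor in the block lands on the target device without repeating `device="cuda"`. A minimal sketch, using the CPU device so it runs anywhere:

```python
import torch

max_bs = 8
seq_len_fill_value = 1  # whatever the attention backend reports

# torch.device(...) works as a context manager in PyTorch >= 2.0:
# tensors created inside default to that device.
with torch.device("cpu"):  # the runner uses "cuda"
    input_ids = torch.zeros((max_bs,), dtype=torch.int32)
    req_pool_indices = torch.zeros((max_bs,), dtype=torch.int32)
    seq_lens = torch.full((max_bs,), seq_len_fill_value, dtype=torch.int32)
    position_ids_offsets = torch.ones((max_bs,), dtype=torch.int32)
    out_cache_loc = torch.zeros((max_bs,), dtype=torch.int32)

print(seq_lens)
```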
sglang/srt/model_executor/model_runner.py
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -54,11 +54,9 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
     is_generation_model,
-    is_llama3_405b_fp8_head_16,
     is_multimodal_model,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
-    monkey_patch_vllm_qvk_linear_loader,
 )
 
 logger = logging.getLogger(__name__)
@@ -88,12 +86,20 @@ class ModelRunner:
         self.is_multimodal_model = is_multimodal_model(
             self.model_config.hf_config.architectures
         )
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not self.server_args.disable_mla
+        ):
+            logger.info("MLA optimization is turned on. Use triton backend.")
+            self.server_args.attention_backend = "triton"
+
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
                 "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
-                "enable_mla": server_args.enable_mla,
+                "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
             }
         )
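MLA support becomes opt-out: the old `--enable-mla` flag is replaced by `--disable-mla`, and when an MLA-capable model is loaded with MLA active the runner switches the attention backend to Triton. A hedged sketch of that selection logic (the enum and args below are stand-ins, not the sglang classes):

```python
from dataclasses import dataclass
from enum import Enum, auto


class AttentionArch(Enum):
    MHA = auto()
    MLA = auto()


@dataclass
class Args:
    disable_mla: bool = False          # opt-out, replacing the old enable_mla opt-in
    attention_backend: str = "flashinfer"


def pick_attention_backend(arch: AttentionArch, args: Args) -> str:
    if arch == AttentionArch.MLA and not args.disable_mla:
        # Per the hunk above, the runner forces the Triton backend when MLA is active.
        return "triton"
    return args.attention_backend


print(pick_attention_backend(AttentionArch.MLA, Args()))                   # triton
print(pick_attention_backend(AttentionArch.MLA, Args(disable_mla=True)))   # flashinfer
```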
@@ -166,10 +172,13 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
-        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
+
+        # This can reduce thread conflicts and speed up weight loading.
+        torch.set_num_threads(1)
+
         if torch.cuda.get_device_capability()[0] < 8:
             logger.info(
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
@@ -178,6 +187,7 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
+        # Prepare the vllm model config
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
         self.load_config = LoadConfig(load_format=self.server_args.load_format)
@@ -188,23 +198,16 @@ class ModelRunner:
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
             dtype=self.server_args.dtype,
-            seed=42,
+            seed=self.server_args.random_seed,
             skip_tokenizer_init=True,
         )
-
-        # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
-        # Drop this after Sept, 2024.
-        if is_llama3_405b_fp8_head_16(self.model_config) and self.tp_size <= 8:
-            self.model_config.hf_config.num_key_value_heads = 8
-            self.vllm_model_config.hf_config.num_key_value_heads = 8
-            monkey_patch_vllm_qvk_linear_loader()
-
-        self.dtype = self.vllm_model_config.dtype
         if self.model_config.model_override_args is not None:
             self.vllm_model_config.hf_config.update(
                 self.model_config.model_override_args
             )
+        self.dtype = self.vllm_model_config.dtype
 
+        # Load the model
         self.model = get_model(
             model_config=self.vllm_model_config,
             load_config=self.load_config,
@@ -255,20 +258,20 @@ class ModelRunner:
                 tokenizer_mode=None,
                 trust_remote_code=self.server_args.trust_remote_code,
                 dtype=self.server_args.dtype,
-                seed=42,
+                seed=self.server_args.random_seed,
                 skip_tokenizer_init=True,
             )
         except Exception as e:
-            logger.error(f"Failed to load model config: {e}")
-            return False, "Failed to update model weights"
+            message = f"Failed to load model config: {e}."
+            return False, message
 
         load_config = LoadConfig(load_format=load_format)
 
         # Only support vllm DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
-            logger.error("Failed to get weights iterator: Unsupported loader")
-            return False, "Failed to update model weights"
+            message = f"Failed to get model loader: {loader}."
+            return False, message
 
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
@@ -293,14 +296,14 @@ class ModelRunner:
         try:
             iter = get_weight_iter(vllm_model_config)
         except Exception as e:
-            message = f"Failed to get weights iterator: {e}"
-            logger.error(message)
+            message = f"Failed to get weights iterator: {e}."
             return False, message
         try:
             model = model_load_weights(self.model, iter)
         except Exception as e:
-            message = f"Failed to update weights: {e}. \n Rolling back to original weights"
-            logger.error(message)
+            message = (
+                f"Failed to update weights: {e}.\nRolling back to original weights."
+            )
             del iter
             gc.collect()
             iter = get_weight_iter(self.vllm_model_config)
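The failure paths in `update_weights` now return specific messages instead of logging a generic one, and the `ModelTpServer` hunk earlier logs the propagated message on failure. A rough sketch of that contract (hypothetical function, not the sglang signature):

```python
import logging
from typing import Tuple

logger = logging.getLogger("weight_update_sketch")


def update_weights(model_path: str, simulate_failure: bool = False) -> Tuple[bool, str]:
    # Illustrative contract only: return (success, message) and let the caller log.
    if simulate_failure:
        return False, f"Failed to load model config: {model_path} not found."
    return True, "Succeeded to update model weights."


success, message = update_weights("/models/new-ckpt", simulate_failure=True)
if success:
    pass  # flush the KV/radix caches, as the TP worker does
else:
    logger.error(message)  # the TP worker now logs the reason it received
```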
@@ -315,7 +318,7 @@ class ModelRunner:
         self.model_config.path = model_path
 
         logger.info("Update weights end.")
-        return True, "Succeeded to update model weights"
+        return True, "Succeeded to update model weights."
 
     def init_lora_manager(self):
         self.lora_manager = LoRAManager(
@@ -334,7 +337,7 @@ class ModelRunner:
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.enable_mla
+            and not self.server_args.disable_mla
         ):
             cell_size = (
                 (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
@@ -397,12 +400,12 @@ class ModelRunner:
         )
 
         self.req_to_token_pool = ReqToTokenPool(
-            max_num_reqs,
-            self.model_config.context_len + 8,
+            max_num_reqs + 1,
+            self.model_config.context_len + 4,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
-            and self.server_args.enable_mla
+            and not self.server_args.disable_mla
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
@@ -521,21 +524,6 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {batch.forward_mode}")
 
-    def _check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
-
     def _apply_logits_bias(
         self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
     ):
@@ -564,13 +552,16 @@ class ModelRunner:
     def sample(
         self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
     ) -> torch.Tensor:
+        # Put CPU-heavy tasks here. They will be overlapped with the forward pass.
         batch.sampling_info.update_regex_vocab_mask(batch)
         batch.sampling_info.update_penalties()
         logits = self._apply_logits_bias(
             logits_output.next_token_logits, batch.sampling_info
         )
-        sample_output = self.sampler(logits, batch.sampling_info)
-        return self._check_sample_results(sample_output)
+
+        # Sample the next tokens.
+        next_token_ids = self.sampler(logits, batch.sampling_info)
+        return next_token_ids
 
 
 @lru_cache()