sglang 0.4.0__py3-none-any.whl → 0.4.0.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_offline_throughput.py +18 -6
  3. sglang/bench_one_batch.py +13 -0
  4. sglang/bench_serving.py +8 -1
  5. sglang/check_env.py +140 -48
  6. sglang/lang/backend/runtime_endpoint.py +1 -0
  7. sglang/lang/chat_template.py +32 -0
  8. sglang/llama3_eval.py +316 -0
  9. sglang/srt/constrained/outlines_backend.py +5 -0
  10. sglang/srt/constrained/xgrammar_backend.py +9 -6
  11. sglang/srt/layers/attention/__init__.py +5 -2
  12. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  13. sglang/srt/layers/attention/flashinfer_backend.py +22 -5
  14. sglang/srt/layers/attention/torch_native_backend.py +22 -8
  15. sglang/srt/layers/attention/triton_backend.py +38 -33
  16. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  17. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  18. sglang/srt/layers/ep_moe/__init__.py +0 -0
  19. sglang/srt/layers/ep_moe/kernels.py +349 -0
  20. sglang/srt/layers/ep_moe/layer.py +665 -0
  21. sglang/srt/layers/fused_moe_triton/fused_moe.py +64 -21
  22. sglang/srt/layers/fused_moe_triton/layer.py +1 -1
  23. sglang/srt/layers/logits_processor.py +133 -95
  24. sglang/srt/layers/quantization/__init__.py +2 -47
  25. sglang/srt/layers/quantization/fp8.py +607 -0
  26. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  27. sglang/srt/layers/radix_attention.py +11 -2
  28. sglang/srt/layers/sampler.py +29 -5
  29. sglang/srt/layers/torchao_utils.py +58 -45
  30. sglang/srt/managers/detokenizer_manager.py +37 -17
  31. sglang/srt/managers/io_struct.py +39 -10
  32. sglang/srt/managers/schedule_batch.py +39 -24
  33. sglang/srt/managers/schedule_policy.py +64 -5
  34. sglang/srt/managers/scheduler.py +236 -197
  35. sglang/srt/managers/tokenizer_manager.py +99 -58
  36. sglang/srt/managers/tp_worker_overlap_thread.py +7 -5
  37. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  38. sglang/srt/mem_cache/chunk_cache.py +2 -2
  39. sglang/srt/mem_cache/memory_pool.py +5 -1
  40. sglang/srt/mem_cache/radix_cache.py +12 -2
  41. sglang/srt/model_executor/cuda_graph_runner.py +39 -11
  42. sglang/srt/model_executor/model_runner.py +24 -9
  43. sglang/srt/model_parallel.py +67 -10
  44. sglang/srt/models/commandr.py +2 -2
  45. sglang/srt/models/deepseek_v2.py +87 -7
  46. sglang/srt/models/gemma2.py +34 -0
  47. sglang/srt/models/gemma2_reward.py +0 -1
  48. sglang/srt/models/granite.py +517 -0
  49. sglang/srt/models/grok.py +72 -13
  50. sglang/srt/models/llama.py +22 -5
  51. sglang/srt/models/llama_classification.py +11 -23
  52. sglang/srt/models/llama_reward.py +0 -2
  53. sglang/srt/models/llava.py +37 -14
  54. sglang/srt/models/mixtral.py +12 -9
  55. sglang/srt/models/phi3_small.py +0 -5
  56. sglang/srt/models/qwen2.py +20 -0
  57. sglang/srt/models/qwen2_moe.py +0 -5
  58. sglang/srt/models/torch_native_llama.py +0 -5
  59. sglang/srt/openai_api/adapter.py +4 -0
  60. sglang/srt/openai_api/protocol.py +9 -4
  61. sglang/srt/sampling/sampling_batch_info.py +9 -8
  62. sglang/srt/server.py +4 -4
  63. sglang/srt/server_args.py +62 -13
  64. sglang/srt/utils.py +57 -10
  65. sglang/test/test_utils.py +3 -2
  66. sglang/utils.py +10 -3
  67. sglang/version.py +1 -1
  68. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/METADATA +15 -9
  69. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/RECORD +72 -65
  70. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/LICENSE +0 -0
  71. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/WHEEL +0 -0
  72. {sglang-0.4.0.dist-info → sglang-0.4.0.post2.dist-info}/top_level.txt +0 -0

sglang/srt/managers/tokenizer_manager.py +99 -58

@@ -22,7 +22,7 @@ import signal
 import sys
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Union
 
 import fastapi
 import uvloop
@@ -76,6 +76,7 @@ class ReqState:
     out_list: List
     finished: bool
     event: asyncio.Event
+    obj: Any
 
     # For metrics
     created_time: float
@@ -283,7 +284,7 @@ class TokenizerManager:
     ):
         """Wait for the response of one request."""
         event = asyncio.Event()
-        state = ReqState([], False, event, created_time=created_time)
+        state = ReqState([], False, event, obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
 
         while True:
@@ -295,15 +296,7 @@ class TokenizerManager:
                     raise ValueError(f"Abort request {obj.rid}")
                 continue
 
-            if isinstance(obj, GenerateReqInput):
-                out = self.convert_logprob_style(
-                    state.out_list[-1],
-                    obj.return_logprob,
-                    obj.top_logprobs_num,
-                    obj.return_text_in_logprobs,
-                )
-            else:  # isinstance(obj, (EmbeddingReqInput,))
-                out = state.out_list[-1]
+            out = state.out_list[-1]
 
             state.out_list = []
             if state.finished:
@@ -315,7 +308,13 @@ class TokenizerManager:
                 break
 
             state.event.clear()
-            yield out
+
+            if obj.stream:
+                yield out
+            else:
+                if request is not None and await request.is_disconnected():
+                    self.abort_request(obj.rid)
+                    raise ValueError(f"Abort request {obj.rid}")
 
     async def _handle_batch_request(
         self,
@@ -573,7 +572,7 @@ class TokenizerManager:
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
-            await asyncio.sleep(60)
+            await asyncio.sleep(5)
 
         # drain requests
         while True:
@@ -609,29 +608,55 @@ class TokenizerManager:
            if state is None:
                continue
 
-            recv_obj.meta_info[i]["id"] = rid
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
+                )
+
            if isinstance(recv_obj, BatchStrOut):
                out_dict = {
                    "text": recv_obj.output_strs[i],
-                    "meta_info": recv_obj.meta_info[i],
+                    "meta_info": meta_info,
                }
            elif isinstance(recv_obj, BatchTokenIDOut):
                out_dict = {
                    "token_ids": recv_obj.output_ids[i],
-                    "meta_info": recv_obj.meta_info[i],
+                    "meta_info": meta_info,
                }
            else:
                assert isinstance(recv_obj, BatchEmbeddingOut)
                out_dict = {
                    "embedding": recv_obj.embeddings[i],
-                    "meta_info": recv_obj.meta_info[i],
+                    "meta_info": meta_info,
                }
            state.out_list.append(out_dict)
-            state.finished = recv_obj.finished_reason[i] is not None
+            state.finished = recv_obj.finished_reasons[i] is not None
            state.event.set()
 
            if self.enable_metrics:
-                completion_tokens = recv_obj.meta_info[i]["completion_tokens"]
+                completion_tokens = (
+                    recv_obj.completion_tokens[i]
+                    if recv_obj.completion_tokens
+                    else 0
+                )
 
                if state.first_token_time is None:
                    state.first_token_time = time.time()
@@ -647,7 +672,7 @@ class TokenizerManager:
 
                if state.finished:
                    self.metrics_collector.inc_prompt_tokens(
-                        recv_obj.meta_info[i]["prompt_tokens"]
+                        recv_obj.prompt_tokens[i]
                    )
                    self.metrics_collector.inc_generation_tokens(
                        completion_tokens
@@ -696,57 +721,73 @@ class TokenizerManager:
 
     def convert_logprob_style(
         self,
-        ret: dict,
-        return_logprob: bool,
+        meta_info: dict,
         top_logprobs_num: int,
         return_text_in_logprobs: bool,
+        recv_obj: BatchStrOut,
+        recv_obj_index: int,
     ):
-        if return_logprob:
-            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
+        meta_info["input_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.input_token_logprobs_val[recv_obj_index],
+            recv_obj.input_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["output_token_logprobs"] = self.detokenize_logprob_tokens(
+            recv_obj.output_token_logprobs_val[recv_obj_index],
+            recv_obj.output_token_logprobs_idx[recv_obj_index],
+            return_text_in_logprobs,
+        )
+        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
+            recv_obj_index
+        ]
+
+        if top_logprobs_num > 0:
+            meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.input_top_logprobs_val[recv_obj_index],
+                recv_obj.input_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
            )
-            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
-                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
+            meta_info["output_top_logprobs"] = self.detokenize_top_logprobs_tokens(
+                recv_obj.output_top_logprobs_val[recv_obj_index],
+                recv_obj.output_top_logprobs_idx[recv_obj_index],
+                return_text_in_logprobs,
            )
 
-        if top_logprobs_num > 0:
-            ret["meta_info"]["input_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["input_top_logprobs"],
-                    return_text_in_logprobs,
-                )
-            )
-            ret["meta_info"]["output_top_logprobs"] = (
-                self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
-                )
-            )
-        return ret
-
     def detokenize_logprob_tokens(
-        self, token_logprobs: List[Tuple[float, int]], decode_to_text: bool
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
    ):
-        # TODO(lianmin): This should run on DetokenizerManager
        if not decode_to_text:
-            return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
-
-        assert self.tokenizer is not None
-        token_ids = [tid for _, tid in token_logprobs]
-        token_texts = self.tokenizer.batch_decode(token_ids)
-        return [
-            (logprob, token_id, token_text)
-            for (logprob, token_id), token_text in zip(token_logprobs, token_texts)
-        ]
+            return [
+                (logprob, token_id, None)
+                for logprob, token_id in zip(token_logprobs_val, token_logprobs_idx)
+            ]
+        else:
+            assert self.tokenizer is not None
+            token_texts = self.tokenizer.batch_decode(token_logprobs_idx)
+            return list(zip(token_logprobs_val, token_logprobs_idx, token_texts))
 
-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
+    def detokenize_top_logprobs_tokens(
+        self,
+        token_logprobs_val: List[float],
+        token_logprobs_idx: List[int],
+        decode_to_text: bool,
+    ):
        # TODO: The current implementation only batches the detokenization for top-k tokens per single position.
        # We should batch all top-k tokens in all positions.
-        for i, token_top_logprobs in enumerate(top_logprobs):
-            if token_top_logprobs:
-                top_logprobs[i] = self.detokenize_logprob_tokens(
-                    token_top_logprobs, decode_to_text
+        ret = []
+        for i in range(len(token_logprobs_val)):
+            if token_logprobs_val[i]:
+                ret.append(
+                    self.detokenize_logprob_tokens(
+                        token_logprobs_val[i], token_logprobs_idx[i], decode_to_text
                    )
                )
-        return top_logprobs
+            else:
+                ret.append(None)
+        return ret
 
 
 class SignalHandler:
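
Per the tokenizer_manager.py hunks above, logprobs now arrive from the detokenizer as parallel *_val / *_idx lists per request instead of pre-assembled tuples, and the tokenizer manager zips them back into (logprob, token_id, text) triples. Below is a minimal, standalone sketch of that zipping step only; the `decode` callable is a hypothetical stand-in for a real tokenizer's batch_decode.

from typing import Callable, List, Optional, Tuple

def zip_logprobs(
    logprobs_val: List[float],
    token_ids_idx: List[int],
    decode: Optional[Callable[[List[int]], List[str]]] = None,
) -> List[Tuple[float, int, Optional[str]]]:
    # Mirrors the shape of the new detokenize_logprob_tokens: pair each logprob
    # with its token id, and optionally attach the decoded text for that token.
    if decode is None:
        return [(lp, tid, None) for lp, tid in zip(logprobs_val, token_ids_idx)]
    texts = decode(token_ids_idx)
    return list(zip(logprobs_val, token_ids_idx, texts))

# Toy usage with a fake decoder; a real tokenizer's batch_decode would go here.
print(zip_logprobs([-0.1, -2.3], [42, 7], decode=lambda ids: [f"<tok{i}>" for i in ids]))
# [(-0.1, 42, '<tok42>'), (-2.3, 7, '<tok7>')]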

sglang/srt/managers/tp_worker_overlap_thread.py +7 -5

@@ -32,12 +32,13 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_compiler_backend
 from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def resolve_future_token_ids(input_ids, future_token_ids_map):
     input_ids[:] = torch.where(
         input_ids < 0,
@@ -73,12 +74,13 @@ class TpModelWorkerClient:
         # Launch threads
         self.input_queue = Queue()
         self.output_queue = Queue()
-        self.forward_stream = torch.cuda.Stream()
+        self.forward_stream = torch.get_device_module(self.device).Stream()
         self.forward_thread = threading.Thread(
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
+        self.scheduler_stream = torch.get_device_module(self.device).current_stream()
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -97,7 +99,7 @@ class TpModelWorkerClient:
 
     def forward_thread_func(self):
         try:
-            with torch.cuda.stream(self.forward_stream):
+            with torch.get_device_module(self.device).stream(self.forward_stream):
                 self.forward_thread_func_()
         except Exception:
             traceback = get_exception_traceback()
@@ -122,7 +124,7 @@ class TpModelWorkerClient:
 
         # Create event
         self.launch_done = threading.Event()
-        copy_done = torch.cuda.Event()
+        copy_done = torch.get_device_module(self.device).Event()
 
         # Resolve future tokens in the input
         input_ids = model_worker_batch.input_ids
@@ -190,7 +192,7 @@ class TpModelWorkerClient:
         )
 
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        torch.cuda.current_stream().synchronize()
+        self.scheduler_stream.synchronize()
 
         # Push a new batch to the queue
         self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
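
The tp_worker_overlap_thread.py changes above replace hard-coded torch.cuda stream and event calls with torch.get_device_module(self.device), so the overlap worker is no longer tied to CUDA. The sketch below is not the sglang implementation; it only shows the same pattern in isolation, and it assumes a PyTorch build that provides torch.get_device_module and an available CUDA device for the demo at the bottom.

import torch

def make_overlap_streams(device: str = "cuda"):
    # Look up the per-device module once (torch.cuda, torch.xpu, ...) and use it
    # for Stream/Event/current_stream instead of hard-coding torch.cuda.
    dev = torch.get_device_module(device)
    forward_stream = dev.Stream()            # side stream for the forward thread
    scheduler_stream = dev.current_stream()  # stream the scheduler thread runs on
    copy_done = dev.Event()                  # signals that output copies finished
    return forward_stream, scheduler_stream, copy_done

if __name__ == "__main__" and torch.cuda.is_available():
    forward_stream, scheduler_stream, copy_done = make_overlap_streams("cuda")
    with torch.cuda.stream(forward_stream):
        out = torch.ones(8, device="cuda") * 3
        copy_done.record(forward_stream)
    copy_done.synchronize()          # wait for the side-stream work
    scheduler_stream.synchronize()   # mirror the scheduler-side sync in the diff
    print(out.cpu())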

sglang/srt/mem_cache/base_prefix_cache.py +2 -2

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable
+from typing import Callable, List, Tuple
 
 
 class BasePrefixCache(ABC):
@@ -10,7 +10,7 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs):
+    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
         pass
 
     @abstractmethod

sglang/srt/mem_cache/chunk_cache.py +2 -2

@@ -2,7 +2,7 @@ from __future__ import annotations
 
 """Cache for chunked prefill, used when RadixCache is disabled."""
 
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
@@ -30,7 +30,7 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}
 
-    def match_prefix(self, rid: int, key: List[int]):
+    def match_prefix(self, rid: int, key: List[int]) -> Tuple[List[int], int]:
         if rid not in self.entries:
             return [], None
 

sglang/srt/mem_cache/memory_pool.py +5 -1

@@ -27,6 +27,7 @@ from typing import List, Tuple, Union
 import torch
 
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_compiler_backend
 
 logger = logging.getLogger(__name__)
 
@@ -129,6 +130,9 @@ class BaseTokenToKVPool:
         return select_index.to(self.device, non_blocking=True)
 
     def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
         if self.is_not_in_free_group:
             self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
         else:
@@ -234,7 +238,7 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 
 # This compiled version is slower in the unit test
 # python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
-@torch.compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
     dst_1[loc] = src_1.to(dtype).view(store_dtype)
     dst_2[loc] = src_2.to(dtype).view(store_dtype)
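
Both memory_pool.py and tp_worker_overlap_thread.py now pass backend=get_compiler_backend() to torch.compile instead of relying on the default backend. get_compiler_backend lives in sglang.srt.utils and its body is not shown in this diff, so the snippet below uses a hypothetical selector purely to illustrate how a backend string is plugged into torch.compile; the backend names are standard PyTorch ones, not sglang's choices.

import torch

def pick_compiler_backend() -> str:
    # Hypothetical stand-in for sglang.srt.utils.get_compiler_backend: use the
    # default "inductor" backend when a CUDA device is present, otherwise fall
    # back to the portable "aot_eager" backend.
    return "inductor" if torch.cuda.is_available() else "aot_eager"

@torch.compile(dynamic=True, backend=pick_compiler_backend())
def scale_and_shift(x: torch.Tensor) -> torch.Tensor:
    return x * 2.0 + 1.0

print(scale_and_shift(torch.arange(4, dtype=torch.float32)))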

sglang/srt/mem_cache/radix_cache.py +12 -2

@@ -22,7 +22,7 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Callable, List, Optional, Tuple
 
 import torch
 
@@ -76,7 +76,17 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
 
-    def match_prefix(self, key: List, **kwargs):
+    def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
+        """Find the matching prefix from the radix tree.
+        Args:
+            key: A list of token IDs to find a matching prefix.
+        Returns:
+            A tuple of a tensor of matching prefix token IDs and
+            the last node that contains the prefix values. Note that
+            this API can modify the internal state of the Radix tree.
+            The last node create a new child if the prefix is shorter
+            than the last node's value.
+        """
         if self.disable:
             return [], self.root_node
 
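
The new docstring on RadixCache.match_prefix spells out the contract: given a list of token IDs, return the matched prefix tokens together with the deepest node reached. The toy trie below is a much-simplified illustration of that contract only; it models none of RadixCache's KV indices, node splitting, or eviction.

from typing import Dict, List, Tuple

class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}

class TokenTrie:
    def __init__(self) -> None:
        self.root = TrieNode()

    def insert(self, key: List[int]) -> None:
        node = self.root
        for tok in key:
            node = node.children.setdefault(tok, TrieNode())

    def match_prefix(self, key: List[int]) -> Tuple[List[int], TrieNode]:
        # Walk the trie as far as the key matches, collecting matched token IDs,
        # and return them with the last node reached (the root if nothing matched).
        node, matched = self.root, []
        for tok in key:
            if tok not in node.children:
                break
            node = node.children[tok]
            matched.append(tok)
        return matched, node

trie = TokenTrie()
trie.insert([1, 2, 3, 4])
print(trie.match_prefix([1, 2, 9])[0])  # [1, 2]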

sglang/srt/model_executor/cuda_graph_runner.py +39 -11

@@ -20,6 +20,8 @@ from contextlib import contextmanager
 from typing import TYPE_CHECKING, Callable
 
 import torch
+import tqdm
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
@@ -47,7 +49,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
        if "FusedMoE" in sub.__class__.__name__:
            if batch_size == 1:
                # The performance of torch.compile on this layer is not always good when bs > 1,
-                # so we decide to skip it for now.
+                # so we decide to only use torch.compile when bs =1
                sub._forward_method = fused_moe_forward_native
            else:
                sub._forward_method = sub.forward_native
@@ -127,9 +129,23 @@ class CudaGraphRunner:
 
         # Batch sizes to capture
         if model_runner.server_args.disable_cuda_graph_padding:
-            self.capture_bs = list(range(1, 32)) + [64, 128]
+            self.capture_bs = list(range(1, 33)) + [64, 128]
         else:
             self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
+
+        if max(self.capture_bs) > model_runner.req_to_token_pool.size:
+            # In some case (e.g., with a small GPU or --max-running-requests), the #max-running-requests
+            # is very samll. We add more values here to make sure we capture the maximum bs.
+            self.capture_bs = list(
+                sorted(
+                    set(
+                        self.capture_bs
+                        + [model_runner.req_to_token_pool.size - 1]
+                        + [model_runner.req_to_token_pool.size]
+                    )
+                )
+            )
+
         self.capture_bs = [
             bs
             for bs in self.capture_bs
@@ -241,7 +257,12 @@ class CudaGraphRunner:
     def capture(self):
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
-            for bs in self.capture_bs:
+            capture_bs = (
+                tqdm.tqdm(self.capture_bs)
+                if get_tensor_model_parallel_rank() == 0
+                else self.capture_bs
+            )
+            for bs in capture_bs:
                 with patch_model(
                     self.model_runner.model,
                     bs in self.compile_bs,
@@ -373,8 +394,14 @@ class CudaGraphRunner:
 
         # Extract logprobs
         if forward_batch.return_logprob:
-            next_token_logprobs = torch.nn.functional.log_softmax(
-                next_token_logits, dim=-1
+            logits_metadata = LogitsMetadata(
+                forward_mode=ForwardMode.DECODE,
+                top_logprobs_nums=forward_batch.top_logprobs_nums,
+            )
+            next_token_logprobs = (
+                LogitsProcessor.compute_temp_top_p_normalized_logprobs(
+                    next_token_logits, logits_metadata
+                )
             )
             logits_output = LogitsProcessorOutput(
                 next_token_logits=next_token_logits,
@@ -382,13 +409,14 @@ class CudaGraphRunner:
             )
             return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
             if return_top_logprob:
-                logits_metadata = LogitsMetadata(
-                    forward_mode=ForwardMode.DECODE,
-                    top_logprobs_nums=forward_batch.top_logprobs_nums,
-                )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                (
+                    logits_output.output_top_logprobs_val,
+                    logits_output.output_top_logprobs_idx,
+                ) = LogitsProcessor.get_top_logprobs(
                     next_token_logprobs, logits_metadata
-                )[1]
+                )[
+                    2:4
+                ]
             else:
                 logits_output = LogitsProcessorOutput(
                     next_token_logits=next_token_logits,

sglang/srt/model_executor/model_runner.py +24 -9

@@ -27,7 +27,6 @@ from vllm.distributed import (
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-from vllm.distributed.parallel_state import in_the_same_node_as
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig
@@ -38,6 +37,7 @@ from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBack
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
+from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -111,15 +111,20 @@ class ModelRunner:
         )
 
         if self.is_multimodal:
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
-            )
-            server_args.chunked_prefill_size = -1
             self.mem_fraction_static *= 0.95
+            if self.model_config.hf_config.architectures == [
+                "MllamaForConditionalGeneration"
+            ]:
+                logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+                server_args.chunked_prefill_size = -1
             # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
             if self.model_config.hf_config.architectures == [
                 "Qwen2VLForConditionalGeneration"
             ]:
+                logger.info(
+                    "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+                )
+                server_args.chunked_prefill_size = -1
                 server_args.disable_radix_cache = True
 
         # Global vars
@@ -139,6 +144,7 @@ class ModelRunner:
                 "torchao_config": server_args.torchao_config,
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
+                "enable_ep_moe": server_args.enable_ep_moe,
             }
         )
 
@@ -151,6 +157,11 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        # Apply torchao quantization
+        apply_torchao_config_to_model(
+            self.model, global_server_args_dict["torchao_config"]
+        )
+
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
         if self.tp_size > 1 and supports_torch_tp:
@@ -235,20 +246,22 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
-        # Prepare the vllm model config
+        # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
         )
-
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()
+
+        # Load the model
         self.model = get_model(
             model_config=self.model_config,
             load_config=self.load_config,
             device_config=DeviceConfig(self.device),
         )
 
+        # Parse other args
         self.sliding_window_size = (
             self.model.get_attention_sliding_window_size()
             if hasattr(self.model, "get_attention_sliding_window_size")
@@ -263,8 +276,10 @@ class ModelRunner:
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
 
-    def update_weights_from_disk(self, model_path: str, load_format: str):
-        """Update engine weights online from disk."""
+    def update_weights_from_disk(
+        self, model_path: str, load_format: str
+    ) -> tuple[bool, str]:
+        """Update engine weights in-place from the disk."""
         from sglang.srt.model_loader.loader import (
             DefaultModelLoader,
             device_loading_context,
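
update_weights_from_disk above now advertises a tuple[bool, str] return type, so callers can branch on a (success, message) pair. The snippet below is a hedged usage sketch against a stub with the same contract; it is not real ModelRunner setup code, and the model path is a placeholder.

class StubRunner:
    # Stand-in exposing the same return contract as ModelRunner.update_weights_from_disk.
    def update_weights_from_disk(self, model_path: str, load_format: str) -> tuple:
        return True, f"Loaded weights from {model_path} with format {load_format}."

def reload_weights(runner, model_path: str) -> None:
    success, message = runner.update_weights_from_disk(model_path, load_format="auto")
    if not success:
        raise RuntimeError(f"Weight update failed: {message}")
    print(message)

reload_weights(StubRunner(), "/models/example-model")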