sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
@@ -27,7 +27,7 @@ import logging
  import threading
  from enum import IntEnum
  from functools import wraps
- from typing import List, Tuple, Union
+ from typing import List, Optional, Tuple, Union

  import numpy as np
  import psutil
@@ -49,7 +49,6 @@ class ReqToTokenPool:
  size: int,
  max_context_len: int,
  device: str,
- use_records: bool,
  enable_memory_saver: bool,
  ):
  memory_saver_adapter = TorchMemorySaverAdapter.create(
@@ -64,17 +63,9 @@ class ReqToTokenPool:
  (size, max_context_len), dtype=torch.int32, device=device
  )
  self.free_slots = list(range(size))
- self.write_records = []
- self.use_records = use_records
-
- if self.use_records:
- self.write = self.write_with_records
- else:
- self.write = self.write_without_records

  def write(self, indices, values):
- # Keep the signature for type checking. It will be assigned during runtime.
- raise NotImplementedError()
+ self.req_to_token[indices] = values

  def available_size(self):
  return len(self.free_slots)
@@ -96,23 +87,6 @@ class ReqToTokenPool:

  def clear(self):
  self.free_slots = list(range(self.size))
- self.write_records = []
-
- def write_without_records(self, indices, values):
- self.req_to_token[indices] = values
-
- def write_with_records(self, indices, values):
- self.req_to_token[indices] = values
- self.write_records.append((indices, values))
-
- def get_write_records(self):
- ret = self.write_records
- self.write_records = []
- return ret
-
- def apply_write_records(self, write_records: List[Tuple]):
- for indices, values in write_records:
- self.req_to_token[indices] = values


  class BaseTokenToKVPool:
@@ -296,13 +270,17 @@ class MHATokenToKVPool(BaseTokenToKVPool):
  loc: torch.Tensor,
  cache_k: torch.Tensor,
  cache_v: torch.Tensor,
- k_scale: float = 1.0,
- v_scale: float = 1.0,
+ k_scale: Optional[float] = None,
+ v_scale: Optional[float] = None,
  ):
  layer_id = layer.layer_id
  if cache_k.dtype != self.dtype:
- cache_k = (cache_k / k_scale).to(self.dtype)
- cache_v = (cache_v / v_scale).to(self.dtype)
+ if k_scale is not None:
+ cache_k.div_(k_scale)
+ if v_scale is not None:
+ cache_v.div_(v_scale)
+ cache_k = cache_k.to(self.dtype)
+ cache_v = cache_v.to(self.dtype)
  if self.store_dtype != self.dtype:
  self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
  self.v_buffer[layer_id][loc] = cache_v.view(self.store_dtype)
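The set_kv_buffer change above makes the KV-cache scales optional: each tensor is divided in place only when a scale is provided, then cast to the cache dtype. A minimal sketch of the two call shapes; `pool`, `layer`, and `loc` are illustrative placeholders rather than names taken from this diff:

    # Sketch only: calling the new MHATokenToKVPool.set_kv_buffer signature.
    # Without scales, keys/values are cast straight to the cache dtype.
    pool.set_kv_buffer(layer, loc, cache_k, cache_v)
    # With serialized scales, each tensor is divided by its scale before the cast.
    pool.set_kv_buffer(layer, loc, cache_k, cache_v, k_scale=0.05, v_scale=0.04)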
@@ -25,6 +25,7 @@ class SchedulerStats:
  gen_throughput: float = 0.0
  num_queue_reqs: int = 0
  cache_hit_rate: float = 0.0
+ spec_accept_length: float = 0.0


  class SchedulerMetricsCollector:
@@ -37,42 +38,49 @@ class SchedulerMetricsCollector:

  self.num_running_reqs = Gauge(
  name="sglang:num_running_reqs",
- documentation="The number of running requests",
+ documentation="The number of running requests.",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )

  self.num_used_tokens = Gauge(
  name="sglang:num_used_tokens",
- documentation="The number of used tokens",
+ documentation="The number of used tokens.",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )

  self.token_usage = Gauge(
  name="sglang:token_usage",
- documentation="The token usage",
+ documentation="The token usage.",
  labelnames=labels.keys(),
  multiprocess_mode="mostrecent",
  )

  self.gen_throughput = Gauge(
  name="sglang:gen_throughput",
- documentation="The generate throughput (token/s)",
+ documentation="The generation throughput (token/s).",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )

  self.num_queue_reqs = Gauge(
  name="sglang:num_queue_reqs",
- documentation="The number of requests in the waiting queue",
+ documentation="The number of requests in the waiting queue.",
  labelnames=labels.keys(),
  multiprocess_mode="sum",
  )

  self.cache_hit_rate = Gauge(
  name="sglang:cache_hit_rate",
- documentation="The cache hit rate",
+ documentation="The prefix cache hit rate.",
+ labelnames=labels.keys(),
+ multiprocess_mode="mostrecent",
+ )
+
+ self.spec_accept_length = Gauge(
+ name="sglang:spec_accept_length",
+ documentation="The average acceptance length of speculative decoding.",
  labelnames=labels.keys(),
  multiprocess_mode="mostrecent",
  )
@@ -88,6 +96,7 @@ class SchedulerMetricsCollector:
  self._log_gauge(self.gen_throughput, stats.gen_throughput)
  self._log_gauge(self.num_queue_reqs, stats.num_queue_reqs)
  self._log_gauge(self.cache_hit_rate, stats.cache_hit_rate)
+ self._log_gauge(self.spec_accept_length, stats.spec_accept_length)


  class TokenizerMetricsCollector:
@@ -21,10 +21,10 @@ from typing import TYPE_CHECKING, Callable

  import torch
  import tqdm
- from vllm.distributed import get_tensor_model_parallel_rank
- from vllm.distributed.parallel_state import graph_capture
  from vllm.model_executor.custom_op import CustomOp

+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed.parallel_state import graph_capture
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
  from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -33,7 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import (
  ForwardBatch,
  ForwardMode,
  )
- from sglang.srt.utils import monkey_patch_vllm_all_gather

  if TYPE_CHECKING:
  from sglang.srt.model_executor.model_runner import ModelRunner
@@ -72,7 +71,6 @@ def patch_model(
  try:
  if enable_compile:
  _to_torch(model, reverse=False, batch_size=batch_size)
- monkey_patch_vllm_all_gather()
  backup_ca_comm = tp_group.ca_comm
  # Use custom-allreduce here.
  # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -88,7 +86,6 @@ def patch_model(
  finally:
  if enable_compile:
  _to_torch(model, reverse=True, batch_size=batch_size)
- monkey_patch_vllm_all_gather(reverse=True)
  tp_group.ca_comm = backup_ca_comm


@@ -122,6 +119,7 @@ class CudaGraphRunner:
  self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
  self.enable_dp_attention = self.model_runner.server_args.enable_dp_attention
  self.tp_size = self.model_runner.tp_size
+ self.dp_size = self.model_runner.server_args.dp_size

  # Batch sizes to capture
  self.capture_bs = self.model_runner.server_args.cuda_graph_bs
@@ -218,7 +216,7 @@ class CudaGraphRunner:
  if self.enable_dp_attention:
  self.gathered_buffer = torch.zeros(
  (
- self.max_bs * self.tp_size,
+ self.max_bs * self.dp_size,
  self.model_runner.model_config.hidden_size,
  ),
  dtype=self.model_runner.dtype,
@@ -21,20 +21,26 @@ from typing import List, Optional, Tuple

  import torch
  import torch.distributed as dist
- from vllm.distributed import (
+
+ from sglang.srt.configs.device_config import DeviceConfig
+ from sglang.srt.configs.load_config import LoadConfig
+ from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+ from sglang.srt.distributed import (
  get_tp_group,
  init_distributed_environment,
  initialize_model_parallel,
  set_custom_all_reduce,
  )
-
- from sglang.srt.configs.device_config import DeviceConfig
- from sglang.srt.configs.load_config import LoadConfig
- from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+ from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
  from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
  from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
  from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
  from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+ from sglang.srt.layers.dp_attention import (
+ get_attention_tp_group,
+ get_attention_tp_size,
+ initialize_dp_attention,
+ )
  from sglang.srt.layers.logits_processor import LogitsProcessorOutput
  from sglang.srt.layers.sampler import Sampler
  from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -57,8 +63,8 @@ from sglang.srt.utils import (
  init_custom_process_group,
  is_cuda,
  is_hip,
+ monkey_patch_p2p_access_check,
  monkey_patch_vllm_gguf_config,
- monkey_patch_vllm_p2p_access_check,
  set_cpu_offload_max_bytes,
  )

@@ -101,8 +107,10 @@ class ModelRunner:
  self.model_config.attention_arch == AttentionArch.MLA
  and not self.server_args.disable_mla
  ):
- logger.info("MLA optimization is turned on. Use triton backend.")
- self.server_args.attention_backend = "triton"
+ # TODO: add MLA optimization on CPU
+ if self.server_args.device != "cpu":
+ logger.info("MLA optimization is turned on. Use triton backend.")
+ self.server_args.attention_backend = "triton"

  if self.server_args.enable_double_sparsity:
  logger.info(
@@ -159,6 +167,7 @@ class ModelRunner:
  "enable_nan_detection": server_args.enable_nan_detection,
  "enable_dp_attention": server_args.enable_dp_attention,
  "enable_ep_moe": server_args.enable_ep_moe,
+ "device": server_args.device,
  }
  )

@@ -216,9 +225,12 @@ class ModelRunner:
  backend = "gloo"
  elif self.device == "hpu":
  backend = "hccl"
+ elif self.device == "cpu":
+ backend = "gloo"

  if not self.server_args.enable_p2p_check:
- monkey_patch_vllm_p2p_access_check(self.gpu_id)
+ monkey_patch_p2p_access_check()
+
  if self.server_args.dist_init_addr:
  dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
  else:
@@ -226,7 +238,7 @@
  set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)

  if not self.is_draft_worker:
- # Only initilzie the distributed environment on the target model worker.
+ # Only initialize the distributed environment on the target model worker.
  init_distributed_environment(
  backend=backend,
  world_size=self.tp_size,
@@ -235,11 +247,18 @@
  distributed_init_method=dist_init_method,
  )
  initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+ initialize_dp_attention(
+ enable_dp_attention=self.server_args.enable_dp_attention,
+ tp_rank=self.tp_rank,
+ tp_size=self.tp_size,
+ dp_size=self.server_args.dp_size,
+ )

  min_per_gpu_memory = get_available_gpu_memory(
  self.device, self.gpu_id, distributed=self.tp_size > 1
  )
  self.tp_group = get_tp_group()
+ self.attention_tp_group = get_attention_tp_group()

  # Check memory for tensor parallelism
  if self.tp_size > 1:
@@ -257,7 +276,8 @@
  )

  # This can reduce thread conflicts and speed up weight loading.
- torch.set_num_threads(1)
+ if self.device != "cpu":
+ torch.set_num_threads(1)
  if self.device == "cuda":
  if torch.cuda.get_device_capability()[0] < 8:
  logger.info(
@@ -277,12 +297,15 @@
  monkey_patch_vllm_gguf_config()

  # Load the model
+ # Remove monkey_patch when linear.py quant remove dependencies with vllm
+ monkey_patch_vllm_parallel_state()
  with self.memory_saver_adapter.region():
  self.model = get_model(
  model_config=self.model_config,
  load_config=self.load_config,
  device_config=DeviceConfig(self.device),
  )
+ monkey_patch_vllm_parallel_state(reverse=True)

  if self.server_args.kv_cache_dtype == "fp8_e4m3":
  if self.server_args.quantization_param_path is not None:
@@ -521,7 +544,7 @@ class ModelRunner:
  )
  else:
  cell_size = (
- self.model_config.get_num_kv_heads(self.tp_size)
+ self.model_config.get_num_kv_heads(get_attention_tp_size())
  * self.model_config.head_dim
  * self.model_config.num_hidden_layers
  * 2
@@ -595,7 +618,6 @@ class ModelRunner:
  size=max_num_reqs + 1,
  max_context_len=self.model_config.context_len + 4,
  device=self.device,
- use_records=False,
  enable_memory_saver=self.server_args.enable_memory_saver,
  )
  if (
@@ -615,7 +637,7 @@ class ModelRunner:
  self.token_to_kv_pool = DoubleSparseTokenToKVPool(
  self.max_total_num_tokens,
  dtype=self.kv_cache_dtype,
- head_num=self.model_config.get_num_kv_heads(self.tp_size),
+ head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
  layer_num=self.model_config.num_hidden_layers,
  device=self.device,
@@ -626,7 +648,7 @@ class ModelRunner:
  self.token_to_kv_pool = MHATokenToKVPool(
  self.max_total_num_tokens,
  dtype=self.kv_cache_dtype,
- head_num=self.model_config.get_num_kv_heads(self.tp_size),
+ head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
  head_dim=self.model_config.head_dim,
  layer_num=self.model_config.num_hidden_layers,
  device=self.device,
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
  from torch import nn
  from transformers import AutoModelForCausalLM, PretrainedConfig
  from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
- from vllm.distributed import (
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- )

  from sglang.srt.configs.device_config import DeviceConfig
  from sglang.srt.configs.load_config import LoadConfig, LoadFormat
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ )
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.model_loader.utils import (
  get_model_architecture,
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
  device_config: DeviceConfig,
  ) -> nn.Module:
  from safetensors.torch import safe_open
- from vllm.distributed import get_tensor_model_parallel_rank
+
+ from sglang.srt.distributed import get_tensor_model_parallel_rank

  local_model_path = self._prepare_weights(
  model_config.model_path, model_config.revision
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
  max_size: Optional[int] = None,
  ) -> None:
  from safetensors.torch import save_file
- from vllm.distributed import get_tensor_model_parallel_rank
+
+ from sglang.srt.distributed import get_tensor_model_parallel_rank

  if pattern is None:
  pattern = ShardedStateLoader.DEFAULT_PATTERN
@@ -9,7 +9,17 @@ import logging
  import os
  import tempfile
  from collections import defaultdict
- from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+ from typing import (
+ Any,
+ Callable,
+ Dict,
+ Generator,
+ Iterable,
+ List,
+ Optional,
+ Tuple,
+ Union,
+ )

  import filelock
  import gguf
@@ -19,10 +29,10 @@ import torch
  from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
  from safetensors.torch import load_file, safe_open, save_file
  from tqdm.auto import tqdm
- from vllm.distributed import get_tensor_model_parallel_rank

  from sglang.srt.configs.load_config import LoadConfig
  from sglang.srt.configs.model_config import ModelConfig
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
  from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
  from sglang.srt.utils import print_warning_once

@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:

  # If there were no matches, return the untouched param name
  return name
+
+
+ def kv_cache_scales_loader(
+ filename: str,
+ tp_rank: int,
+ tp_size: int,
+ num_hidden_layers: int,
+ model_type: Optional[str],
+ ) -> Iterable[Tuple[int, float]]:
+ """
+ A simple utility to read in KV cache scaling factors that have been
+ previously serialized to disk. Used by the model to populate the appropriate
+ KV cache scaling factors. The serialization should represent a dictionary
+ whose keys are the TP ranks and values are another dictionary mapping layers
+ to their KV cache scaling factors.
+ """
+ try:
+ with open(filename) as f:
+ context = {
+ "model_type": model_type,
+ "num_hidden_layers": num_hidden_layers,
+ "tp_rank": tp_rank,
+ "tp_size": tp_size,
+ }
+ schema_dct = json.load(f)
+ schema = QuantParamSchema.model_validate(schema_dct, context=context)
+ layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+ return layer_scales_map.items()
+ except FileNotFoundError:
+ logger.error("File or directory '%s' not found.", filename)
+ except json.JSONDecodeError:
+ logger.error("Error decoding JSON in file '%s'.", filename)
+ except Exception:
+ logger.exception("An error occurred while reading '%s'.", filename)
+ # This section is reached if and only if any of the excepts are hit
+ # Return an empty iterable (list) => no KV cache scales are loaded
+ # which ultimately defaults to 1.0 scales
+ logger.warning(
+ "Defaulting to KV cache scaling factors = 1.0 for all "
+ "layers in TP rank %d as an error occurred during loading.",
+ tp_rank,
+ )
+ return []
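The new kv_cache_scales_loader reads a JSON file keyed by TP rank and layer index. The sketch below shows a plausible file shape and call, inferred from the docstring and from the `schema.kv_cache.scaling_factor[tp_rank]` access; field names beyond those two are assumptions, since the exact schema is defined by QuantParamSchema, which is not part of this hunk:

    # Sketch only: write a minimal scales file and read it back.
    import json
    import tempfile

    scales = {
        "model_type": "llama",  # assumed field, checked against the loader's context
        "kv_cache": {
            "dtype": "float8_e4m3",  # assumed field
            # outer keys: TP ranks; inner keys: layer indices -> scaling factors
            "scaling_factor": {"0": {"0": 1.0, "1": 1.0}},
        },
    }
    with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
        json.dump(scales, f)

    from sglang.srt.model_loader.weight_utils import kv_cache_scales_loader

    for layer_idx, scale in kv_cache_scales_loader(
        f.name, tp_rank=0, tp_size=1, num_hidden_layers=2, model_type="llama"
    ):
        print(layer_idx, scale)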
@@ -24,22 +24,22 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.layers.linear import (
+ from sglang.srt.layers.activation import SiluAndMul
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import (
  MergedColumnParallelLinear,
  QKVParallelLinear,
  RowParallelLinear,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
- from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
@@ -21,10 +21,9 @@ from typing import Iterable, Optional, Tuple
  import torch
  from torch import nn
  from torch.nn import LayerNorm
- from vllm.distributed import get_tensor_model_parallel_world_size
- from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.configs import ChatGLMConfig
+ from sglang.srt.distributed import get_tensor_model_parallel_world_size
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
@@ -44,12 +44,11 @@ import torch.utils.checkpoint
  from torch import nn
  from torch.nn.parameter import Parameter
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
  MergedColumnParallelLinear,
@@ -59,6 +58,7 @@ from sglang.srt.layers.linear import (
  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
sglang/srt/models/dbrx.py CHANGED
@@ -19,14 +19,13 @@ from typing import Iterable, Optional, Tuple

  import torch
  import torch.nn as nn
- from vllm.distributed import (
+
+ from sglang.srt.configs import DbrxConfig
+ from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
- from sglang.srt.configs import DbrxConfig
  from sglang.srt.layers.linear import (
  QKVParallelLinear,
  ReplicatedLinear,
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
  DEFAULT_VOCAB_PADDING_SIZE,
  ParallelLMHead,
@@ -21,13 +21,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
  import torch
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
@@ -23,14 +23,13 @@ import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
  from vllm import _custom_ops as ops
- from vllm.distributed import (
+
+ from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  get_tp_group,
  tensor_model_parallel_all_reduce,
  )
- from vllm.model_executor.layers.rotary_embedding import get_rope
-
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
  normalize_e4m3fn_to_e4m3fnuz,
  )
  from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
  from sglang.srt.layers.vocab_parallel_embedding import (
  ParallelLMHead,
  VocabParallelEmbedding,
@@ -271,13 +271,14 @@ class DeepseekV2Attention(nn.Module):
  quant_config=quant_config,
  )
  rope_scaling["rope_type"] = "deepseek_yarn"
- self.rotary_emb = get_rope(
+ self.rotary_emb = get_rope_wrapper(
  qk_rope_head_dim,
  rotary_dim=qk_rope_head_dim,
  max_position=max_position_embeddings,
  base=rope_theta,
  rope_scaling=rope_scaling,
  is_neox_style=False,
+ device=global_server_args_dict["device"],
  )

  if rope_scaling:
@@ -855,10 +856,9 @@ class DeepseekV2ForCausalLM(nn.Module):
  forward_batch: ForwardBatch,
  ) -> torch.Tensor:
  hidden_states = self.model(input_ids, positions, forward_batch)
- if not forward_batch.forward_mode.is_idle():
- return self.logits_processor(
- input_ids, hidden_states, self.lm_head, forward_batch
- )
+ return self.logits_processor(
+ input_ids, hidden_states, self.lm_head, forward_batch
+ )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
  stacked_params_mapping = [