sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
@@ -21,20 +21,26 @@ from typing import List, Optional, Tuple
 
 import torch
 import torch.distributed as dist
-from vllm.distributed import (
+
+from sglang.srt.configs.device_config import DeviceConfig
+from sglang.srt.configs.load_config import LoadConfig
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed import (
     get_tp_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
 )
-
-from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_group,
+    get_attention_tp_size,
+    initialize_dp_attention,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
@@ -50,13 +56,15 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
+    is_cuda,
     is_hip,
+    monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
-    monkey_patch_vllm_p2p_access_check,
     set_cpu_offload_max_bytes,
 )
 
@@ -99,8 +107,10 @@ class ModelRunner:
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
-            logger.info("MLA optimization is turned on. Use triton backend.")
-            self.server_args.attention_backend = "triton"
+            # TODO: add MLA optimization on CPU
+            if self.server_args.device != "cpu":
+                logger.info("MLA optimization is turned on. Use triton backend.")
+                self.server_args.attention_backend = "triton"
 
         if self.server_args.enable_double_sparsity:
             logger.info(
@@ -157,6 +167,7 @@ class ModelRunner:
                 "enable_nan_detection": server_args.enable_nan_detection,
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
+                "device": server_args.device,
             }
         )
 
@@ -165,6 +176,10 @@ class ModelRunner:
        # Get memory before model loading
        min_per_gpu_memory = self.init_torch_distributed()
 
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=self.server_args.enable_memory_saver
+        )
+
        # Load the model
        self.sampler = Sampler()
        self.load_model()
@@ -210,9 +225,12 @@ class ModelRunner:
             backend = "gloo"
         elif self.device == "hpu":
             backend = "hccl"
+        elif self.device == "cpu":
+            backend = "gloo"
 
         if not self.server_args.enable_p2p_check:
-            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+            monkey_patch_p2p_access_check()
+
         if self.server_args.dist_init_addr:
             dist_init_method = f"tcp://{self.server_args.dist_init_addr}"
         else:
@@ -220,7 +238,7 @@ class ModelRunner:
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
 
         if not self.is_draft_worker:
-            # Only initilzie the distributed environment on the target model worker.
+            # Only initialize the distributed environment on the target model worker.
             init_distributed_environment(
                 backend=backend,
                 world_size=self.tp_size,
@@ -229,11 +247,18 @@ class ModelRunner:
                 distributed_init_method=dist_init_method,
             )
             initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+            initialize_dp_attention(
+                enable_dp_attention=self.server_args.enable_dp_attention,
+                tp_rank=self.tp_rank,
+                tp_size=self.tp_size,
+                dp_size=self.server_args.dp_size,
+            )
 
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
         self.tp_group = get_tp_group()
+        self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
@@ -251,7 +276,8 @@ class ModelRunner:
             )
 
         # This can reduce thread conflicts and speed up weight loading.
-        torch.set_num_threads(1)
+        if self.device != "cpu":
+            torch.set_num_threads(1)
         if self.device == "cuda":
             if torch.cuda.get_device_capability()[0] < 8:
                 logger.info(
@@ -271,11 +297,38 @@ class ModelRunner:
             monkey_patch_vllm_gguf_config()
 
         # Load the model
-        self.model = get_model(
-            model_config=self.model_config,
-            load_config=self.load_config,
-            device_config=DeviceConfig(self.device),
-        )
+        # Remove monkey_patch when linear.py quant remove dependencies with vllm
+        monkey_patch_vllm_parallel_state()
+        with self.memory_saver_adapter.region():
+            self.model = get_model(
+                model_config=self.model_config,
+                load_config=self.load_config,
+                device_config=DeviceConfig(self.device),
+            )
+        monkey_patch_vllm_parallel_state(reverse=True)
+
+        if self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if self.server_args.quantization_param_path is not None:
+                if callable(getattr(self.model, "load_kv_cache_scales", None)):
+                    self.model.load_kv_cache_scales(
+                        self.server_args.quantization_param_path
+                    )
+                    logger.info(
+                        "Loaded KV cache scaling factors from %s",
+                        self.server_args.quantization_param_path,
+                    )
+                else:
+                    raise RuntimeError(
+                        "Using FP8 KV cache and scaling factors provided but "
+                        "model %s does not support loading scaling factors.",
+                        self.model.__class__,
+                    )
+            else:
+                logger.warning(
+                    "Using FP8 KV cache but no scaling factors "
+                    "provided. Defaulting to scaling factors of 1.0. "
+                    "This may lead to less accurate results!"
+                )
 
         # Parse other args
        self.sliding_window_size = (
@@ -393,7 +446,7 @@ class ModelRunner:
 
         logger.info(
             f"init custom process group: master_address={master_address}, master_port={master_port}, "
-            f"rank_offset={rank_offset}, world_size={world_size}, group_name={group_name}, backend={backend}"
+            f"rank_offset={rank_offset}, rank={rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
         )
 
         try:
@@ -491,7 +544,7 @@ class ModelRunner:
             )
         else:
             cell_size = (
                self.model_config.get_num_kv_heads(self.tp_size)
+                self.model_config.get_num_kv_heads(get_attention_tp_size())
                 * self.model_config.head_dim
                 * self.model_config.num_hidden_layers
                 * 2
@@ -516,6 +569,9 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e5m2
+        elif self.server_args.kv_cache_dtype == "fp8_e4m3":
+            if is_cuda():
+                self.kv_cache_dtype = torch.float8_e4m3fn
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -562,7 +618,7 @@ class ModelRunner:
             size=max_num_reqs + 1,
             max_context_len=self.model_config.context_len + 4,
             device=self.device,
-            use_records=False,
+            enable_memory_saver=self.server_args.enable_memory_saver,
         )
         if (
             self.model_config.attention_arch == AttentionArch.MLA
@@ -575,25 +631,28 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
                 heavy_channel_num=self.server_args.ds_heavy_channel_num,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
                 dtype=self.kv_cache_dtype,
-                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
             )
         logger.info(
             f"Memory pool end. "
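The ModelRunner hunks above thread a new memory-saver path through model loading and KV-cache allocation: a TorchMemorySaverAdapter is created from `server_args.enable_memory_saver`, model loading runs inside `memory_saver_adapter.region()`, and the KV pools receive an `enable_memory_saver` flag. The sketch below illustrates only the calling pattern, using a hypothetical no-op adapter; it is not the real `sglang.srt.torch_memory_saver_adapter` implementation.

```python
from contextlib import contextmanager


class MemorySaverAdapterSketch:
    """Hypothetical stand-in for TorchMemorySaverAdapter, for illustration only."""

    def __init__(self, enable: bool):
        self.enable = enable

    @classmethod
    def create(cls, enable: bool) -> "MemorySaverAdapterSketch":
        return cls(enable)

    @contextmanager
    def region(self):
        # The real adapter would tag allocations made in this scope so the
        # engine can later release and restore them; here it is a no-op.
        yield


def load_model_with_memory_saver(load_fn, enable_memory_saver: bool):
    """Mirror of the load pattern shown in the ModelRunner.load_model hunk."""
    adapter = MemorySaverAdapterSketch.create(enable=enable_memory_saver)
    with adapter.region():
        return load_fn()


if __name__ == "__main__":
    dummy_model = load_model_with_memory_saver(dict, enable_memory_saver=True)
    print(type(dummy_model))  # <class 'dict'>
```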
@@ -21,14 +21,14 @@ from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
 from transformers import AutoModelForCausalLM, PretrainedConfig
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
 
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_loader.utils import (
     get_model_architecture,
@@ -496,7 +496,8 @@ class ShardedStateLoader(BaseModelLoader):
         device_config: DeviceConfig,
     ) -> nn.Module:
         from safetensors.torch import safe_open
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         local_model_path = self._prepare_weights(
             model_config.model_path, model_config.revision
@@ -556,7 +557,8 @@ class ShardedStateLoader(BaseModelLoader):
         max_size: Optional[int] = None,
     ) -> None:
         from safetensors.torch import save_file
-        from vllm.distributed import get_tensor_model_parallel_rank
+
+        from sglang.srt.distributed import get_tensor_model_parallel_rank
 
         if pattern is None:
             pattern = ShardedStateLoader.DEFAULT_PATTERN
@@ -9,7 +9,17 @@ import logging
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import filelock
 import gguf
@@ -19,10 +29,10 @@ import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.utils import print_warning_once
 
@@ -638,3 +648,46 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.exception("An error occurred while reading '%s'.", filename)
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
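The new `kv_cache_scales_loader` yields `(layer_index, scaling_factor)` pairs for the current TP rank, and the ModelRunner change above calls a model-provided `load_kv_cache_scales` when `kv_cache_dtype` is `fp8_e4m3` and a `quantization_param_path` is supplied. A hypothetical consumer of those pairs could look like the sketch below; the layer class and function names are illustrative, not sglang APIs.

```python
from typing import Iterable, List, Tuple


class AttentionLayerSketch:
    """Hypothetical attention layer holding per-layer KV-cache scales."""

    def __init__(self) -> None:
        # 1.0 matches the loader's fallback when no scales can be read.
        self.k_scale = 1.0
        self.v_scale = 1.0


def apply_kv_cache_scales(
    layers: List[AttentionLayerSketch],
    layer_scales: Iterable[Tuple[int, float]],
) -> None:
    """Store each (layer_index, scale) pair on the corresponding layer."""
    for layer_idx, scale in layer_scales:
        layers[layer_idx].k_scale = scale
        layers[layer_idx].v_scale = scale


if __name__ == "__main__":
    layers = [AttentionLayerSketch() for _ in range(4)]
    apply_kv_cache_scales(layers, [(0, 0.5), (2, 0.25)])
    print([layer.k_scale for layer in layers])  # [0.5, 1.0, 0.25, 1.0]
```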
@@ -24,22 +24,22 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.linear import (
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -21,10 +21,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.configs import ChatGLMConfig
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -44,12 +44,11 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -59,6 +58,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
sglang/srt/models/dbrx.py CHANGED
@@ -19,14 +19,13 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.distributed import (
+
+from sglang.srt.configs import DbrxConfig
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -21,13 +21,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -23,14 +23,13 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -271,13 +271,14 @@ class DeepseekV2Attention(nn.Module):
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -855,10 +856,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-        if not forward_batch.forward_mode.is_idle():
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -20,9 +20,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -21,9 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -34,6 +33,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -15,13 +15,13 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
 
-from typing import Iterable, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (
@@ -32,6 +32,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
@@ -44,23 +45,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-
-
-class GemmaRotaryEmbedding(RotaryEmbedding):
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
-                / self.rotary_dim
-            )
-        )
-        return inv_freq
-
-
 class Gemma2MLP(nn.Module):
     def __init__(
         self,
@@ -143,14 +127,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-        # from vLLM: TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
-            self.head_dim,
+        self.rotary_emb = get_rope(
             self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
 
         use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
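The gemma2.py hunks delete the temporary GemmaRotaryEmbedding workaround and call `get_rope(..., is_neox_style=True)` from sglang's own rotary_embedding module instead. For reference, the inverse-frequency schedule the removed helper computed is restated below as a standalone function; whether `get_rope` reproduces it bit-for-bit depends on that implementation and is not asserted here.

```python
import torch


def gemma_inv_freq(rotary_dim: int, base: float) -> torch.Tensor:
    # Frequency schedule from the removed GemmaRotaryEmbedding._compute_inv_freq:
    # inv_freq[i] = 1 / base ** (2i / rotary_dim) for i = 0 .. rotary_dim/2 - 1,
    # with the arange built in int64 and then cast to float.
    return 1.0 / (
        base
        ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
    )


if __name__ == "__main__":
    print(gemma_inv_freq(8, 10000.0))  # four decreasing frequencies
```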
sglang/srt/models/gpt2.py CHANGED
@@ -17,16 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
-# from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -21,8 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,