sglang 0.4.1.post6-py3-none-any.whl → 0.4.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -9,7 +9,17 @@ import logging
 import os
 import tempfile
 from collections import defaultdict
-from typing import Any, Callable, Dict, Generator, List, Optional, Tuple, Union
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import filelock
 import gguf
@@ -17,12 +27,13 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.utils import print_warning_once
 
@@ -393,8 +404,13 @@ def np_cache_weights_iterator(
 
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-    """Iterate over the weights in the model safetensor files."""
+    """Iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, it uses more optimize read by reading an
+    entire file instead of reading each tensor one by one.
+    """
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -404,9 +420,14 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        with safe_open(st_file, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
+        if not is_all_weights_sharded:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param
+        else:
+            result = load_file(st_file, device="cpu")
+            for name, param in result.items():
                 yield name, param
 
 
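Note: the new `is_all_weights_sharded` flag above switches the iterator from lazy per-tensor reads via `safe_open` to loading each shard with a single `load_file` call. A minimal standalone sketch of the two paths (not sglang's exact helper; the shard filename in the usage comment is hypothetical):

```python
from safetensors.torch import load_file, safe_open


def iter_weights(st_file: str, whole_file: bool = False):
    """Yield (name, tensor) pairs from one safetensors shard."""
    if not whole_file:
        # Lazy path: open the shard and materialize one tensor at a time.
        with safe_open(st_file, framework="pt") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)
    else:
        # Eager path: read the whole shard into CPU memory in one call.
        for name, tensor in load_file(st_file, device="cpu").items():
            yield name, tensor


# Hypothetical usage:
# for name, tensor in iter_weights("model-00001-of-00002.safetensors", whole_file=True):
#     print(name, tuple(tensor.shape))
```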
@@ -638,3 +659,121 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.error("An error occurred while reading '%s'.", filename)
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
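The `kv_cache_scales_loader` added above expects the on-disk layout described in its docstring: an optional model type plus a `kv_cache` block whose `scaling_factor` maps TP rank to per-layer scales. A hedged illustration of a matching file (the file name, model type, and scale values are made up for this sketch):

```python
import json

# Toy scales file: TP size 1, two hidden layers, FP8 E4M3 KV cache.
scales = {
    "model_type": "llama",
    "kv_cache": {
        "dtype": "float8_e4m3fn",
        # TP rank -> {layer index -> per-tensor scaling factor}
        "scaling_factor": {"0": {"0": 0.0213, "1": 0.0198}},
    },
}

with open("kv_cache_scales.json", "w") as f:
    json.dump(scales, f, indent=2)

# kv_cache_scales_loader("kv_cache_scales.json", tp_rank=0, tp_size=1,
#                        num_hidden_layers=2, model_type="llama")
# should then yield (layer_index, scale) pairs, and falls back to []
# (i.e. 1.0 scales) if the file is missing or malformed.
```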
sglang/srt/models/baichuan.py CHANGED
@@ -24,22 +24,22 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.linear import (
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/chatglm.py CHANGED
@@ -21,10 +21,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.configs import ChatGLMConfig
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/commandr.py CHANGED
@@ -44,12 +44,11 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -59,9 +58,13 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
@@ -372,10 +375,19 @@ class CohereForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
 
-EntryClass = CohereForCausalLM
+class Cohere2ForCausalLM(CohereForCausalLM):
+    pass
+
+
+EntryClass = [CohereForCausalLM, Cohere2ForCausalLM]
sglang/srt/models/dbrx.py CHANGED
@@ -19,14 +19,13 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.distributed import (
+
+from sglang.srt.configs import DbrxConfig
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -36,13 +35,17 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import set_weight_attrs
 
 
@@ -411,6 +414,11 @@ class DbrxForCausalLM(nn.Module):
                     weight_loader(param, loaded_weight, weight_name)
                     break
             else:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/deepseek.py CHANGED
@@ -21,13 +21,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/deepseek_v2.py CHANGED
@@ -23,14 +23,13 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -56,12 +56,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip
 
 is_hip_ = is_hip()
 
-if is_flashinfer_available():
-    from flashinfer import bmm_fp8
+if is_cuda_available():
+    from sgl_kernel import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
@@ -271,13 +271,14 @@ class DeepseekV2Attention(nn.Module):
            quant_config=quant_config,
        )
        rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb = get_rope(
+        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
            is_neox_style=False,
+            device=global_server_args_dict["device"],
        )
 
        if rope_scaling:
@@ -855,10 +856,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-        if not forward_batch.forward_mode.is_idle():
-            return self.logits_processor(
-                input_ids, hidden_states, self.lm_head, forward_batch
-            )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/exaone.py CHANGED
@@ -20,9 +20,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/gemma.py CHANGED
@@ -21,9 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -34,6 +33,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
sglang/srt/models/gemma2.py CHANGED
@@ -15,13 +15,13 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
 
-from typing import Iterable, Optional, Set, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (
@@ -32,9 +32,13 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import make_layers
 
 
@@ -44,23 +48,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-
-
-class GemmaRotaryEmbedding(RotaryEmbedding):
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
-                / self.rotary_dim
-            )
-        )
-        return inv_freq
-
-
 class Gemma2MLP(nn.Module):
     def __init__(
         self,
@@ -143,14 +130,12 @@ class Gemma2Attention(nn.Module):
            bias=config.attention_bias,
            quant_config=quant_config,
        )
-        # from vLLM: TODO(woosuk): Use the `get_rope` interface.
-        self.rotary_emb = GemmaRotaryEmbedding(
-            self.head_dim,
+        self.rotary_emb = get_rope(
            self.head_dim,
-            max_position_embeddings,
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
            base=self.rope_theta,
            is_neox_style=True,
-            dtype=torch.get_default_dtype(),
        )
 
        use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
@@ -442,6 +427,11 @@ class Gemma2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
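The deleted `GemmaRotaryEmbedding` only overrode the inverse-frequency computation, and the expression it used is the standard RoPE one, which is presumably why the generic `get_rope` call can now replace it. A standalone sketch of that formula for reference (illustration only, not sglang code):

```python
import torch


def rope_inv_freq(rotary_dim: int, base: float = 10000.0) -> torch.Tensor:
    # Same expression as the removed GemmaRotaryEmbedding._compute_inv_freq:
    # 1 / base**(2i / rotary_dim) for i = 0, 1, ..., rotary_dim // 2 - 1.
    return 1.0 / (
        base
        ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
    )


# e.g. rope_inv_freq(256) returns the 128 per-head frequencies.
```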
sglang/srt/models/gpt2.py CHANGED
@@ -17,16 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
-# from sglang.srt.layers.activation import get_act_fn
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
sglang/srt/models/gpt_bigcode.py CHANGED
@@ -21,8 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
sglang/srt/models/granite.py CHANGED
@@ -22,9 +22,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GraniteConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/grok.py CHANGED
@@ -22,12 +22,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )
 
sglang/srt/models/internlm2.py CHANGED
@@ -19,9 +19,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,