sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (67)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/dp_attention.py +3 -1
  12. sglang/srt/layers/layernorm.py +5 -5
  13. sglang/srt/layers/linear.py +24 -9
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  16. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  20. sglang/srt/layers/parameter.py +16 -7
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  22. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  23. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/fp8.py +4 -1
  31. sglang/srt/layers/rotary_embedding.py +6 -1
  32. sglang/srt/layers/sampler.py +28 -8
  33. sglang/srt/layers/torchao_utils.py +12 -6
  34. sglang/srt/managers/detokenizer_manager.py +1 -0
  35. sglang/srt/managers/io_struct.py +36 -5
  36. sglang/srt/managers/schedule_batch.py +31 -25
  37. sglang/srt/managers/scheduler.py +61 -35
  38. sglang/srt/managers/tokenizer_manager.py +4 -0
  39. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  40. sglang/srt/model_executor/forward_batch_info.py +5 -7
  41. sglang/srt/model_executor/model_runner.py +7 -4
  42. sglang/srt/model_loader/loader.py +75 -0
  43. sglang/srt/model_loader/weight_utils.py +91 -5
  44. sglang/srt/models/commandr.py +14 -2
  45. sglang/srt/models/dbrx.py +9 -1
  46. sglang/srt/models/deepseek_v2.py +3 -3
  47. sglang/srt/models/gemma2.py +9 -1
  48. sglang/srt/models/grok.py +1 -0
  49. sglang/srt/models/minicpm3.py +3 -3
  50. sglang/srt/models/torch_native_llama.py +17 -4
  51. sglang/srt/openai_api/adapter.py +139 -37
  52. sglang/srt/openai_api/protocol.py +5 -4
  53. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  54. sglang/srt/sampling/sampling_batch_info.py +4 -14
  55. sglang/srt/server.py +2 -2
  56. sglang/srt/server_args.py +20 -1
  57. sglang/srt/speculative/eagle_utils.py +37 -15
  58. sglang/srt/speculative/eagle_worker.py +11 -13
  59. sglang/srt/utils.py +62 -65
  60. sglang/test/test_programs.py +1 -0
  61. sglang/test/test_utils.py +81 -22
  62. sglang/version.py +1 -1
  63. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
  64. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
  65. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  66. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  67. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py CHANGED
@@ -24,7 +24,7 @@ import tqdm
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.distributed import get_tensor_model_parallel_rank
-from sglang.srt.distributed.parallel_state import graph_capture
+from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.moe.fused_moe_native import fused_moe_forward_native
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -38,7 +38,7 @@ if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
 
 
-def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
+def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
     for sub in model._modules.values():
         if isinstance(sub, CustomOp):
             if reverse:
@@ -47,7 +47,7 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
             else:
                 # NOTE: Temporarily workaround MoE
                 if "FusedMoE" in sub.__class__.__name__:
-                    if batch_size == 1:
+                    if num_tokens == 1:
                         # The performance of torch.compile on this layer is not always good when bs > 1,
                         # so we decide to only use torch.compile when bs =1
                         sub._forward_method = fused_moe_forward_native
@@ -55,22 +55,22 @@ def _to_torch(model: torch.nn.Module, reverse: bool, batch_size: int):
                     sub._forward_method = sub.forward_native
                 setattr(sub, "is_torch_compile", True)
         if isinstance(sub, torch.nn.Module):
-            _to_torch(sub, reverse, batch_size)
+            _to_torch(sub, reverse, num_tokens)
 
 
 @contextmanager
 def patch_model(
     model: torch.nn.Module,
     enable_compile: bool,
-    batch_size: int,
-    tp_group: "GroupCoordinator",
+    num_tokens: int,
+    tp_group: GroupCoordinator,
 ):
     """Patch the model to make it compatible with with torch.compile"""
     backup_ca_comm = None
 
     try:
         if enable_compile:
-            _to_torch(model, reverse=False, batch_size=batch_size)
+            _to_torch(model, reverse=False, num_tokens=num_tokens)
             backup_ca_comm = tp_group.ca_comm
             # Use custom-allreduce here.
             # We found the custom allreduce is much faster than the built-in allreduce in torch,
@@ -85,7 +85,7 @@ def patch_model(
         yield model.forward
     finally:
         if enable_compile:
-            _to_torch(model, reverse=True, batch_size=batch_size)
+            _to_torch(model, reverse=True, num_tokens=num_tokens)
             tp_group.ca_comm = backup_ca_comm
 
 
@@ -149,9 +149,18 @@ class CudaGraphRunner:
             and bs <= model_runner.server_args.cuda_graph_max_bs
         ]
 
+        self.compile_bs = (
+            [
+                bs
+                for bs in self.capture_bs
+                if bs <= self.model_runner.server_args.torch_compile_max_bs
+            ]
+            if self.use_torch_compile
+            else []
+        )
+
         self.capture_forward_mode = ForwardMode.DECODE
         self.num_tokens_per_bs = 1
-
         if model_runner.spec_algorithm.is_eagle():
             if self.model_runner.is_draft_worker:
                 self.num_tokens_per_bs = (
@@ -163,16 +172,6 @@ class CudaGraphRunner:
                     self.model_runner.server_args.speculative_num_draft_tokens
                 )
 
-        self.compile_bs = (
-            [
-                bs
-                for bs in self.capture_bs
-                if bs <= self.model_runner.server_args.torch_compile_max_bs
-            ]
-            if self.use_torch_compile
-            else []
-        )
-
         # Attention backend
         self.max_bs = max(self.capture_bs)
         self.max_num_token = self.max_bs * self.num_tokens_per_bs
@@ -180,7 +179,6 @@ class CudaGraphRunner:
         self.seq_len_fill_value = (
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )
-
         # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
         self.encoder_len_fill_value = 0
 
@@ -189,14 +187,14 @@ class CudaGraphRunner:
 
         # Common inputs
         with torch.device("cuda"):
-            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int32)
+            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
-            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
+            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
 
             # Speculative_inference
             if model_runner.spec_algorithm.is_eagle():
@@ -285,8 +283,8 @@ class CudaGraphRunner:
            with patch_model(
                self.model_runner.model,
                bs in self.compile_bs,
-               bs,
-               self.model_runner.tp_group,
+               num_tokens=bs * self.num_tokens_per_bs,
+               tp_group=self.model_runner.tp_group,
            ) as forward:
                (
                    graph,
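
Taken together, these hunks move the compile_bs computation ahead of the EAGLE adjustment and switch patch_model/_to_torch from a batch-size argument to a token-count argument. A minimal sketch of the resulting logic, using assumed server-arg values that are not taken from the diff:

    # Assumed values for illustration only.
    capture_bs = [1, 2, 4, 8, 16]      # batch sizes captured as CUDA graphs
    torch_compile_max_bs = 8           # server_args.torch_compile_max_bs
    use_torch_compile = True
    num_tokens_per_bs = 4              # >1 when EAGLE drafts several tokens per request

    compile_bs = (
        [bs for bs in capture_bs if bs <= torch_compile_max_bs]
        if use_torch_compile
        else []
    )  # -> [1, 2, 4, 8]

    for bs in capture_bs:
        num_tokens = bs * num_tokens_per_bs
        # patch_model(..., num_tokens=num_tokens, ...): the FusedMoE workaround in
        # _to_torch now keys off num_tokens == 1 rather than the batch size.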
sglang/srt/model_executor/forward_batch_info.py CHANGED
@@ -38,7 +38,7 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import maybe_torch_compile
+from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention import AttentionBackend
@@ -282,6 +282,9 @@ class ForwardBatch:
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
+            req_to_token_pool=model_runner.req_to_token_pool,
+            token_to_kv_pool=model_runner.token_to_kv_pool,
+            attn_backend=model_runner.attn_backend,
             spec_algorithm=batch.spec_algorithm,
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
@@ -336,11 +339,6 @@ class ForwardBatch:
         if model_runner.model_is_mrope:
             ret.compute_mrope_positions(model_runner, batch)
 
-        # Init attention information
-        ret.req_to_token_pool = model_runner.req_to_token_pool
-        ret.token_to_kv_pool = model_runner.token_to_kv_pool
-        ret.attn_backend = model_runner.attn_backend
-
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
             model_runner.lora_manager.prepare_lora_batch(ret)
@@ -417,6 +415,6 @@ def compute_position_torch(
     return positions.to(torch.int64), extend_start_loc
 
 
-@maybe_torch_compile(dynamic=True)
+@torch.compile(dynamic=True, backend=get_compiler_backend())
 def clamp_position(seq_lens):
     return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
sglang/srt/model_executor/model_runner.py CHANGED
@@ -185,9 +185,12 @@ class ModelRunner:
         self.load_model()
 
         # Apply torchao quantization
-        apply_torchao_config_to_model(
-            self.model, global_server_args_dict["torchao_config"]
-        )
+        torchao_applied = getattr(self.model, "torchao_applied", False)
+        # In layered loading, torchao may have been applied
+        if not torchao_applied:
+            apply_torchao_config_to_model(
+                self.model, global_server_args_dict["torchao_config"]
+            )
 
         # Apply torch TP if the model supports it
         supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
@@ -215,7 +218,7 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        # Init torch distributed
+
         torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
sglang/srt/model_loader/loader.py CHANGED
@@ -374,6 +374,78 @@ class DefaultModelLoader(BaseModelLoader):
         return model.eval()
 
 
+class LayeredModelLoader(DefaultModelLoader):
+    """Model loader that loads weights layer by layer so that one can quantize a
+    layer before loading another to make the peak memory envelope smaller."""
+
+    def __init__(self, load_config: LoadConfig):
+        # Back to the default load format
+        load_config.load_format = LoadFormat.AUTO
+        super().__init__(load_config)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+        torchao_config = global_server_args_dict.get("torchao_config")
+        target_device = torch.device(device_config.device)
+
+        with set_default_torch_dtype(model_config.dtype):
+            # Create model on meta device
+            with torch.device("meta"):
+                model = _initialize_model(
+                    model_config,
+                    self.load_config,
+                )
+
+            # Check model's layered load support
+            if not hasattr(model, "load_weights_to_module"):
+                raise ValueError(
+                    "LayeredModelLoader requires the model to have a "
+                    "`load_weights_to_module` method. "
+                    f"{model_config.model_path} does not support it."
+                )
+
+            # Get all weights from disk
+            weights = self._get_all_weights(model_config, model)
+
+            # Helper function to recursively fill the weights of a module
+            def fill_module(module, fqn: List[str], weights):
+                """
+                fqn: list of strings representing the fully qualified name of `module`.
+                """
+                # Layer by layer
+                for name, submod in module.named_children():
+                    fill_module(submod, fqn + [name], weights)
+
+                # First materialize on target device
+                module.to_empty(device=target_device, recurse=False)
+                fqn_path = ".".join(fqn)
+                # Fill weights
+                model.load_weights_to_module(
+                    fqn_path,
+                    weights,
+                )
+                # Quantize weights if applicable
+                if torchao_config and "proj" in fqn_path:
+                    # Note: `None` here is needed to indicate no filter, see
+                    # `apply_torchao_config_to_model` for details.
+                    apply_torchao_config_to_model(module, torchao_config, None)
+
+            # Start calling on root module
+            fill_module(model, [], weights)
+
+        if torchao_config:
+            model.torchao_applied = True
+
+        return model.eval()
+
+
 class DummyModelLoader(BaseModelLoader):
     """Model loader that will set model weights to random values."""
 
@@ -1149,4 +1221,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.GGUF:
         return GGUFModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.LAYERED:
+        return LayeredModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
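
For orientation, a minimal sketch of how the new loader could be selected (illustrative, not part of the diff; it assumes the LoadFormat.LAYERED value added in sglang/srt/configs/load_config.py and a LoadConfig constructed with that format):

    from sglang.srt.configs.load_config import LoadConfig, LoadFormat
    from sglang.srt.model_loader.loader import get_model_loader

    # get_model_loader dispatches on load_format, so this returns a LayeredModelLoader.
    loader = get_model_loader(LoadConfig(load_format=LoadFormat.LAYERED))
    # LayeredModelLoader.load_model then materializes the model submodule by submodule,
    # calling the model's load_weights_to_module for each fully qualified name and
    # applying the torchao config to "proj" submodules as they are filled, which keeps
    # the peak memory envelope smaller than load-then-quantize.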
sglang/srt/model_loader/weight_utils.py CHANGED
@@ -27,6 +27,7 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
 
@@ -403,8 +404,13 @@ def np_cache_weights_iterator(
 
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-    """Iterate over the weights in the model safetensor files."""
+    """Iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, it uses more optimize read by reading an
+    entire file instead of reading each tensor one by one.
+    """
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -414,9 +420,14 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        with safe_open(st_file, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
+        if not is_all_weights_sharded:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param
+        else:
+            result = load_file(st_file, device="cpu")
+            for name, param in result.items():
                 yield name, param
 
 
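A short usage sketch of the extended iterator (illustrative only; the shard file names are hypothetical):

    from sglang.srt.model_loader.weight_utils import safetensors_weights_iterator

    shards = ["model-00001-of-00002.safetensors", "model-00002-of-00002.safetensors"]
    # With is_all_weights_sharded=True each shard is read with a single load_file()
    # call instead of tensor-by-tensor via safe_open().
    for name, tensor in safetensors_weights_iterator(shards, is_all_weights_sharded=True):
        print(name, tuple(tensor.shape))
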
@@ -650,6 +661,81 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
     return name
 
 
+# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
+
+
 def kv_cache_scales_loader(
     filename: str,
     tp_rank: int,
@@ -681,7 +767,7 @@ def kv_cache_scales_loader(
     except json.JSONDecodeError:
         logger.error("Error decoding JSON in file '%s'.", filename)
     except Exception:
-        logger.exception("An error occurred while reading '%s'.", filename)
+        logger.error("An error occurred while reading '%s'.", filename)
     # This section is reached if and only if any of the excepts are hit
     # Return an empty iterable (list) => no KV cache scales are loaded
     # which ultimately defaults to 1.0 scales
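
For reference, a small validation sketch of the schema classes added above (values are made up; it assumes pydantic v2's model_validate with a context dict, which is what the validators read from info.context):

    from sglang.srt.model_loader.weight_utils import QuantParamSchema

    # Hypothetical scales for a 2-layer model served with TP size 1.
    data = {
        "model_type": "llama",
        "kv_cache": {
            "dtype": "float8_e4m3fn",
            "scaling_factor": {0: {0: 0.021, 1: 0.034}},  # TP rank -> layer -> scale
        },
    }
    schema = QuantParamSchema.model_validate(
        data,
        context={"model_type": "llama", "tp_size": 1, "tp_rank": 0, "num_hidden_layers": 2},
    )
    print(schema.kv_cache.scaling_factor[0][1])  # 0.034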
sglang/srt/models/commandr.py CHANGED
@@ -61,7 +61,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
@@ -372,10 +375,19 @@ class CohereForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
 
 
-EntryClass = CohereForCausalLM
+class Cohere2ForCausalLM(CohereForCausalLM):
+    pass
+
+
+EntryClass = [CohereForCausalLM, Cohere2ForCausalLM]
sglang/srt/models/dbrx.py CHANGED
@@ -42,7 +42,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import set_weight_attrs
 
 
@@ -411,6 +414,11 @@ class DbrxForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, weight_name)
                 break
             else:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/deepseek_v2.py CHANGED
@@ -56,12 +56,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available, is_hip
+from sglang.srt.utils import is_cuda_available, is_hip
 
 is_hip_ = is_hip()
 
-if is_flashinfer_available():
-    from flashinfer import bmm_fp8
+if is_cuda_available():
+    from sgl_kernel import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
sglang/srt/models/gemma2.py CHANGED
@@ -35,7 +35,10 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import make_layers
 
 
@@ -424,6 +427,11 @@ class Gemma2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/grok.py CHANGED
@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )
 
sglang/srt/models/minicpm3.py CHANGED
@@ -40,10 +40,10 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import is_cuda_available
 
-if is_flashinfer_available():
-    from flashinfer import bmm_fp8
+if is_cuda_available():
+    from sgl_kernel import bmm_fp8
 
 
 class MiniCPM3MLP(nn.Module):
sglang/srt/models/torch_native_llama.py CHANGED
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         return len(params_dict)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights_to_module(
+        self,
+        fqn: str,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto submodule pointed by path `fqn`."""
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
+        module = self.get_submodule(fqn)
+        params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
-                if name.endswith(".bias") and name not in params_dict:
+                if name.endswith(".bias") or name not in params_dict:
                     continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
 
+    def load_weights(
+        self,
+        weights: Iterable[Tuple[str, torch.Tensor]],
+    ):
+        """Load weights onto the full model."""
+        self.load_weights_to_module("", weights)
+
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
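
A tiny illustration (not from the diff) of why named_parameters(prefix=fqn, recurse=False) makes the per-submodule lookup line up with the fully qualified names a checkpoint iterator yields:

    from torch import nn

    net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))
    sub = net.get_submodule("1")
    # Direct parameters only, with their fully qualified names restored via the prefix.
    print(sorted(dict(sub.named_parameters(prefix="1", recurse=False))))
    # ['1.bias', '1.weight'] -> the same names the weight iterator produces,
    # so load_weights_to_module("1", weights) can index params_dict directly.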