sglang 0.4.1.post3__py3-none-any.whl → 0.4.1.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (86)
  1. sglang/bench_one_batch.py +2 -0
  2. sglang/bench_serving.py +18 -1
  3. sglang/lang/interpreter.py +71 -1
  4. sglang/lang/ir.py +2 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/chatglm.py +78 -0
  7. sglang/srt/configs/dbrx.py +279 -0
  8. sglang/srt/configs/model_config.py +1 -1
  9. sglang/srt/hf_transformers_utils.py +9 -14
  10. sglang/srt/layers/attention/__init__.py +22 -6
  11. sglang/srt/layers/attention/double_sparsity_backend.py +0 -52
  12. sglang/srt/layers/attention/flashinfer_backend.py +215 -83
  13. sglang/srt/layers/attention/torch_native_backend.py +1 -38
  14. sglang/srt/layers/attention/triton_backend.py +20 -11
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +4 -0
  16. sglang/srt/layers/linear.py +159 -55
  17. sglang/srt/layers/logits_processor.py +170 -215
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +198 -29
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -7
  41. sglang/srt/layers/parameter.py +431 -0
  42. sglang/srt/layers/quantization/__init__.py +3 -2
  43. sglang/srt/layers/quantization/fp8.py +3 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +174 -0
  45. sglang/srt/layers/sampler.py +57 -21
  46. sglang/srt/layers/torchao_utils.py +17 -3
  47. sglang/srt/layers/vocab_parallel_embedding.py +1 -1
  48. sglang/srt/managers/cache_controller.py +307 -0
  49. sglang/srt/managers/data_parallel_controller.py +2 -0
  50. sglang/srt/managers/io_struct.py +1 -2
  51. sglang/srt/managers/schedule_batch.py +33 -3
  52. sglang/srt/managers/schedule_policy.py +159 -90
  53. sglang/srt/managers/scheduler.py +68 -28
  54. sglang/srt/managers/session_controller.py +1 -1
  55. sglang/srt/managers/tokenizer_manager.py +27 -21
  56. sglang/srt/managers/tp_worker.py +16 -4
  57. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  58. sglang/srt/mem_cache/memory_pool.py +206 -1
  59. sglang/srt/metrics/collector.py +22 -30
  60. sglang/srt/model_executor/cuda_graph_runner.py +129 -77
  61. sglang/srt/model_executor/forward_batch_info.py +51 -21
  62. sglang/srt/model_executor/model_runner.py +72 -64
  63. sglang/srt/models/chatglm.py +1 -1
  64. sglang/srt/models/dbrx.py +1 -1
  65. sglang/srt/models/deepseek_v2.py +34 -7
  66. sglang/srt/models/grok.py +109 -29
  67. sglang/srt/models/llama.py +9 -2
  68. sglang/srt/openai_api/adapter.py +0 -17
  69. sglang/srt/openai_api/protocol.py +3 -3
  70. sglang/srt/sampling/sampling_batch_info.py +22 -0
  71. sglang/srt/sampling/sampling_params.py +9 -1
  72. sglang/srt/server.py +20 -13
  73. sglang/srt/server_args.py +120 -58
  74. sglang/srt/speculative/build_eagle_tree.py +347 -0
  75. sglang/srt/speculative/eagle_utils.py +626 -0
  76. sglang/srt/speculative/eagle_worker.py +184 -0
  77. sglang/srt/speculative/spec_info.py +5 -0
  78. sglang/srt/utils.py +47 -7
  79. sglang/test/test_programs.py +23 -1
  80. sglang/test/test_utils.py +36 -7
  81. sglang/version.py +1 -1
  82. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/METADATA +12 -12
  83. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/RECORD +86 -57
  84. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/WHEEL +1 -1
  85. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/LICENSE +0 -0
  86. {sglang-0.4.1.post3.dist-info → sglang-0.4.1.post5.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py CHANGED
@@ -17,7 +17,7 @@ import gc
  import json
  import logging
  import time
- from typing import Optional
+ from typing import List, Optional, Tuple

  import torch
  import torch.distributed as dist
@@ -48,8 +48,8 @@ from sglang.srt.mem_cache.memory_pool import (
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader import get_model
- from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
  from sglang.srt.server_args import ServerArgs
+ from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
  from sglang.srt.utils import (
  enable_show_time_cost,
  get_available_gpu_memory,
@@ -75,6 +75,7 @@ class ModelRunner:
  tp_size: int,
  nccl_port: int,
  server_args: ServerArgs,
+ is_draft_worker: bool = False,
  ):
  # Parse args
  self.model_config = model_config
@@ -85,8 +86,13 @@ class ModelRunner:
  self.tp_size = tp_size
  self.dist_port = nccl_port
  self.server_args = server_args
+ self.is_draft_worker = is_draft_worker
  self.is_generation = model_config.is_generation
  self.is_multimodal = model_config.is_multimodal
+ self.should_log = tp_rank == 0
+ self.spec_algorithm = SpeculativeAlgorithm.from_string(
+ server_args.speculative_algorithm
+ )

  # Model-specific adjustment
  if (
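Note: SpeculativeAlgorithm comes from the new sglang/srt/speculative/spec_info.py added in this release; only from_string and is_none appear in this diff. A minimal sketch of how the flag is expected to resolve (the "EAGLE" value is inferred from the new eagle_* modules in the file list, not from this hunk):

    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm

    # No --speculative-algorithm flag: from_string presumably maps None to a NONE member.
    algo = SpeculativeAlgorithm.from_string(None)
    print(algo.is_none())   # expected: True

    # Assumed value, based on the new eagle_worker.py / eagle_utils.py files.
    algo = SpeculativeAlgorithm.from_string("EAGLE")
    print(algo.is_none())   # expected: False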
@@ -112,15 +118,21 @@ class ModelRunner:

  if self.is_multimodal:
  self.mem_fraction_static *= 0.95
+ logger.info(
+ f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+ f"because this is a multimodal model."
+ )
+
  if self.model_config.hf_config.architectures == [
  "MllamaForConditionalGeneration"
  ]:
  logger.info("Automatically turn off --chunked-prefill-size for mllama.")
  server_args.chunked_prefill_size = -1
- # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+
  if self.model_config.hf_config.architectures == [
  "Qwen2VLForConditionalGeneration"
  ]:
+ # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
  logger.info(
  "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
  )
@@ -192,9 +204,9 @@ class ModelRunner:
  torch.get_device_module(self.device).set_device(self.gpu_id)
  if self.device == "cuda":
  backend = "nccl"
- # ToDO(liangan1):Just use gloo to bypass the initilization fail
- # Need to use xccl for xpu backend in the future
  elif self.device == "xpu":
+ # TODO(liangan1): Just use gloo to bypass the initilization fail
+ # Need to use xccl for xpu backend in the future
  backend = "gloo"
  elif self.device == "hpu":
  backend = "hccl"
@@ -206,14 +218,18 @@ class ModelRunner:
  else:
  dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
  set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
- init_distributed_environment(
- backend=backend,
- world_size=self.tp_size,
- rank=self.tp_rank,
- local_rank=self.gpu_id,
- distributed_init_method=dist_init_method,
- )
- initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
+ if not self.is_draft_worker:
+ # Only initilzie the distributed environment on the target model worker.
+ init_distributed_environment(
+ backend=backend,
+ world_size=self.tp_size,
+ rank=self.tp_rank,
+ local_rank=self.gpu_id,
+ distributed_init_method=dist_init_method,
+ )
+ initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+
  min_per_gpu_memory = get_available_gpu_memory(
  self.device, self.gpu_id, distributed=self.tp_size > 1
  )
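Note: with the guard above, only the target model's runner initializes the distributed environment; a speculative draft runner is expected to be created in the same process and reuse those groups. A hedged construction sketch (is_draft_worker, tp_size, nccl_port, and server_args come from the constructor hunks above; the other keyword names and draft_model_config are assumptions):

    draft_runner = ModelRunner(
        model_config=draft_model_config,              # hypothetical: config of the small draft model
        mem_fraction_static=server_args.mem_fraction_static,
        gpu_id=gpu_id,
        tp_rank=tp_rank,
        tp_size=server_args.tp_size,
        nccl_port=nccl_port,
        server_args=server_args,
        is_draft_worker=True,                         # skips init_distributed_environment above
    )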
@@ -408,7 +424,6 @@ class ModelRunner:
  target_dtype = (
  dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
  )
- current_dtype = self.dtype if isinstance(self.dtype, str) else self.dtype

  assert (
  self._model_update_group is not None
@@ -429,9 +444,9 @@ class ModelRunner:
  logger.error(error_msg)
  return False, error_msg

- def update_weights_from_tensor(self, name, tensor: torch.Tensor):
- self.model.load_weights([(name, tensor)])
- return True, "Success" # TODO error handling
+ def update_weights_from_tensor(self, named_tensors: List[Tuple[str, torch.Tensor]]):
+ self.model.load_weights(named_tensors)
+ return True, "Success"

  def get_weights_by_name(
  self, name: str, truncate_size: int = 100
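Note: update_weights_from_tensor now accepts a list of (name, tensor) pairs rather than a single pair, so several parameters can be swapped in one model.load_weights call. A hedged usage sketch (the runner variable, parameter names, and shapes are illustrative only):

    import torch

    named_tensors = [
        ("model.layers.0.mlp.gate_proj.weight", torch.zeros(11008, 4096, dtype=torch.bfloat16)),
        ("model.layers.0.mlp.up_proj.weight", torch.zeros(11008, 4096, dtype=torch.bfloat16)),
    ]
    success, message = model_runner.update_weights_from_tensor(named_tensors)
    # On success this returns (True, "Success"), per the new implementation above.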
@@ -507,6 +522,28 @@ class ModelRunner:
  )

  self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+
+ if max_num_reqs is None:
+ max_num_reqs = min(
+ max(
+ int(
+ self.max_total_num_tokens / self.model_config.context_len * 512
+ ),
+ 2048,
+ ),
+ 4096,
+ )
+
+ if not self.spec_algorithm.is_none():
+ if self.is_draft_worker:
+ self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+ else:
+ self.server_args.draft_runner_cache_size = (
+ self.max_total_num_tokens
+ + max_num_reqs * self.server_args.speculative_num_steps
+ + 100
+ )
+
  if max_total_tokens is not None:
  if max_total_tokens > self.max_total_num_tokens:
  logging.warning(
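Note: the new block above also sizes the draft runner's KV cache for speculative decoding. A worked example of the arithmetic with illustrative numbers: if max_total_num_tokens = 500000 and context_len = 32768, then max_num_reqs = min(max(int(500000 / 32768 * 512), 2048), 4096) = min(max(7812, 2048), 4096) = 4096; with speculative_num_steps = 5, the target worker sets draft_runner_cache_size = 500000 + 4096 * 5 + 100 = 520580, and the draft worker later adopts that value as its own max_total_num_tokens.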
@@ -521,17 +558,6 @@ class ModelRunner:
  "Not enough memory. Please try to increase --mem-fraction-static."
  )

- if max_num_reqs is None:
- max_num_reqs = min(
- max(
- int(
- self.max_total_num_tokens / self.model_config.context_len * 512
- ),
- 2048,
- ),
- 4096,
- )
-
  self.req_to_token_pool = ReqToTokenPool(
  size=max_num_reqs + 1,
  max_context_len=self.model_config.context_len + 4,
@@ -608,7 +634,6 @@ class ModelRunner:
  )

  def init_double_sparsity_channel_config(self, selected_channel):
-
  selected_channel = "." + selected_channel + "_proj"
  self.sorted_channels = []
  # load channel config
@@ -651,10 +676,6 @@ class ModelRunner:
  tensor_parallel(self.model, device_mesh)

  def forward_decode(self, forward_batch: ForwardBatch):
- if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
- return self.cuda_graph_runner.replay(forward_batch)
-
- forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
  self.attn_backend.init_forward_metadata(forward_batch)
  return self.model.forward(
  forward_batch.input_ids, forward_batch.positions, forward_batch
@@ -684,14 +705,18 @@ class ModelRunner:
  )

  def forward_idle(self, forward_batch: ForwardBatch):
- if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
- return self.cuda_graph_runner.replay(forward_batch)
-
  return self.model.forward(
  forward_batch.input_ids, forward_batch.positions, forward_batch
  )

  def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+ if (
+ forward_batch.forward_mode.is_cuda_graph()
+ and self.cuda_graph_runner
+ and self.cuda_graph_runner.can_run(forward_batch)
+ ):
+ return self.cuda_graph_runner.replay(forward_batch)
+
  if forward_batch.forward_mode.is_decode():
  return self.forward_decode(forward_batch)
  elif forward_batch.forward_mode.is_extend():
@@ -699,11 +724,12 @@ class ModelRunner:
  elif forward_batch.forward_mode.is_idle():
  return self.forward_idle(forward_batch)
  else:
- raise ValueError(f"Invaid forward mode: {forward_batch.forward_mode}")
+ raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

  def sample(
  self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
  ) -> torch.Tensor:
+ # Apply logit bias
  sampling_info = forward_batch.sampling_info
  if sampling_info.sampling_info_done:
  # Overlap mode: the function update_regex_vocab_mask was executed
@@ -714,35 +740,17 @@ class ModelRunner:
  # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
  sampling_info.update_regex_vocab_mask()
  sampling_info.update_penalties()
- logits = self.apply_logits_bias(logits_output.next_token_logits, sampling_info)
-
- # Sample the next tokens.
- next_token_ids = self.sampler(logits, sampling_info)
+ sampling_info.apply_logits_bias(logits_output.next_token_logits)
+
+ # Sample the next tokens
+ next_token_ids = self.sampler(
+ logits_output,
+ sampling_info,
+ forward_batch.return_logprob,
+ forward_batch.top_logprobs_nums,
+ )
  return next_token_ids

- def apply_logits_bias(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
- # Apply logit_bias
- if sampling_info.logit_bias is not None:
- logits.add_(sampling_info.logit_bias)
-
- # min-token, presence, frequency
- if sampling_info.linear_penalties is not None:
- logits.add_(sampling_info.linear_penalties)
-
- # repetition
- if sampling_info.scaling_penalties is not None:
- logits = torch.where(
- logits > 0,
- logits / sampling_info.scaling_penalties,
- logits * sampling_info.scaling_penalties,
- )
-
- # Apply regex vocab_mask
- if sampling_info.vocab_mask is not None:
- sampling_info.apply_mask(logits=logits, vocab_mask=sampling_info.vocab_mask)
-
- return logits
-
  @property
  def model_is_mrope(self) -> bool:
  """Detect if the model has "mrope" rope_scaling type.
sglang/srt/models/chatglm.py CHANGED
@@ -23,8 +23,8 @@ from torch import nn
  from torch.nn import LayerNorm
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.transformers_utils.configs import ChatGLMConfig

+ from sglang.srt.configs import ChatGLMConfig
  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
sglang/srt/models/dbrx.py CHANGED
@@ -25,8 +25,8 @@ from vllm.distributed import (
  tensor_model_parallel_all_reduce,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.transformers_utils.configs.dbrx import DbrxConfig

+ from sglang.srt.configs import DbrxConfig
  from sglang.srt.layers.linear import (
  QKVParallelLinear,
  ReplicatedLinear,
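Note: chatglm.py and dbrx.py now take their config classes from sglang's own configs package (the new sglang/srt/configs/chatglm.py and dbrx.py files listed above) instead of vllm.transformers_utils.configs, so the import becomes:

    from sglang.srt.configs import ChatGLMConfig, DbrxConfig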
sglang/srt/models/deepseek_v2.py CHANGED
@@ -46,6 +46,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8_utils import (
  block_quant_to_tensor_quant,
  input_to_float8,
+ normalize_e4m3fn_to_e4m3fnuz,
  )
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
@@ -55,7 +56,9 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_loader.weight_utils import default_weight_loader
- from sglang.srt.utils import is_flashinfer_available
+ from sglang.srt.utils import is_flashinfer_available, is_hip
+
+ is_hip_ = is_hip()

  if is_flashinfer_available():
  from flashinfer import bmm_fp8
@@ -573,7 +576,13 @@ class DeepseekV2AttentionMLA(nn.Module):
  )
  q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

- if self.w_kc.dtype == torch.float8_e4m3fn:
+ if self.w_kc.dtype == torch.float8_e4m3fnuz:
+ # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ q_nope_out = torch.bmm(
+ q_nope.to(torch.bfloat16).transpose(0, 1),
+ self.w_kc.to(torch.bfloat16) * self.w_scale,
+ )
+ elif self.w_kc.dtype == torch.float8_e4m3fn:
  q_nope_val, q_nope_scale = input_to_float8(
  q_nope.transpose(0, 1), torch.float8_e4m3fn
  )
@@ -598,7 +607,13 @@ class DeepseekV2AttentionMLA(nn.Module):
  attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

- if self.w_vc.dtype == torch.float8_e4m3fn:
+ if self.w_vc.dtype == torch.float8_e4m3fnuz:
+ # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ attn_bmm_output = torch.bmm(
+ attn_output.to(torch.bfloat16).transpose(0, 1),
+ self.w_vc.to(torch.bfloat16) * self.w_scale,
+ )
+ elif self.w_vc.dtype == torch.float8_e4m3fn:
  attn_output_val, attn_output_scale = input_to_float8(
  attn_output.transpose(0, 1), torch.float8_e4m3fn
  )
@@ -940,15 +955,25 @@ class DeepseekV2ForCausalLM(nn.Module):
  w = self_attn.kv_b_proj.weight
  # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
  # This may affect the accuracy of fp8 model.
- if (
- hasattr(self.quant_config, "weight_block_size")
- and w.dtype == torch.float8_e4m3fn
+ if hasattr(self.quant_config, "weight_block_size") and w.dtype in (
+ torch.float8_e4m3fn,
+ torch.float8_e4m3fnuz,
  ):
  weight_block_size = self.quant_config.weight_block_size
  if weight_block_size is not None:
  assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+ if is_hip_:
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=w,
+ weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+ input_scale=None,
+ )
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
+
  w, scale = block_quant_to_tensor_quant(
- w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+ weight, weight_scale, weight_block_size
  )
  self_attn.w_scale = scale
  w_kc, w_vc = w.unflatten(
@@ -961,6 +986,8 @@ class DeepseekV2ForCausalLM(nn.Module):
  and self_attn.w_scale is None
  ):
  self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+ if is_hip_:
+ self_attn.w_scale *= 2.0


  class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
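Note: the new float8_e4m3fnuz branches cover AMD/ROCm builds, where fp8 tensors use the FNUZ encoding. E4M3FNUZ uses an exponent bias of 8 instead of 7, so the same bit pattern represents half the E4M3FN value, which is presumably why the conversion path doubles the scale (normalize_e4m3fn_to_e4m3fnuz and the self_attn.w_scale *= 2.0 line above); and since no bmm_fp8 kernel exists for FNUZ yet, the MLA matmuls simply upcast to bfloat16. A hedged sketch of that upcast fallback with illustrative shapes (assumes a PyTorch build with float8_e4m3fnuz support):

    import torch

    w = torch.randn(16, 64, 32).to(torch.float8_e4m3fnuz)   # (batch, k, n) fp8 weight, illustrative
    scale = 0.02                                            # per-tensor scale (already adjusted for FNUZ)
    x = torch.randn(16, 8, 64, dtype=torch.bfloat16)        # (batch, m, k) activations, illustrative
    out = torch.bmm(x, w.to(torch.bfloat16) * scale)        # plain bf16 batched matmul, shape (batch, m, n)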
sglang/srt/models/grok.py CHANGED
@@ -16,13 +16,16 @@
  # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
  """Inference-only Grok1 model."""

- from typing import Iterable, Optional, Tuple
+ from typing import Iterable, List, Optional, Tuple

  import torch
  import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
- from vllm.distributed import get_tensor_model_parallel_world_size
+ from vllm.distributed import (
+ get_tensor_model_parallel_rank,
+ get_tensor_model_parallel_world_size,
+ )
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import GeluAndMul
@@ -42,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.loader import DefaultModelLoader
  from sglang.srt.model_loader.weight_utils import default_weight_loader


@@ -53,6 +57,7 @@ class Grok1MLP(nn.Module):
  quant_config: Optional[QuantizationConfig] = None,
  prefix: str = "",
  reduce_results=True,
+ use_presharded_weights: bool = False,
  ) -> None:
  super().__init__()
  self.gate_up_proj = MergedColumnParallelLinear(
@@ -61,6 +66,7 @@ class Grok1MLP(nn.Module):
  bias=False,
  quant_config=quant_config,
  prefix=f"{prefix}.gate_up_proj",
+ use_presharded_weights=use_presharded_weights,
  )
  self.down_proj = RowParallelLinear(
  intermediate_size,
@@ -69,6 +75,7 @@ class Grok1MLP(nn.Module):
  quant_config=quant_config,
  prefix=f"{prefix}.down_proj",
  reduce_results=reduce_results,
+ use_presharded_weights=use_presharded_weights,
  )
  self.act_fn = GeluAndMul(approximate="tanh")

@@ -99,6 +106,7 @@ class Grok1MoE(nn.Module):
  quant_config: Optional[QuantizationConfig] = None,
  tp_size: Optional[int] = None,
  reduce_results=True,
+ use_presharded_weights: bool = False,
  ):
  super().__init__()
  self.hidden_size = hidden_size
@@ -125,6 +133,7 @@ class Grok1MoE(nn.Module):
  renormalize=False,
  quant_config=quant_config,
  tp_size=tp_size,
+ use_presharded_weights=use_presharded_weights,
  )

  def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -152,6 +161,7 @@ class Grok1Attention(nn.Module):
  max_position: int = 4096 * 32,
  rope_theta: float = 10000,
  quant_config: Optional[QuantizationConfig] = None,
+ reduce_results: bool = True,
  ) -> None:
  super().__init__()
  self.config = config
@@ -190,6 +200,7 @@ class Grok1Attention(nn.Module):
  hidden_size,
  bias=False,
  quant_config=quant_config,
+ reduce_results=reduce_results,
  )
  self.rotary_emb = get_rope(
  self.head_dim,
@@ -230,10 +241,12 @@ class Grok1DecoderLayer(nn.Module):
  config: PretrainedConfig,
  layer_id: int = 0,
  quant_config: Optional[QuantizationConfig] = None,
+ use_presharded_weights: bool = False,
  ) -> None:
  super().__init__()
  self.num_experts = config.num_local_experts
  self.hidden_size = config.hidden_size
+ self.layer_id = layer_id

  rope_theta = getattr(config, "rope_theta", 10000)
  self.self_attn = Grok1Attention(
@@ -258,6 +271,7 @@ class Grok1DecoderLayer(nn.Module):
  ),
  quant_config=quant_config,
  reduce_results=True,
+ use_presharded_weights=use_presharded_weights,
  )
  self.pre_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
  self.post_attn_norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -295,6 +309,7 @@ class Grok1Model(nn.Module):
  self,
  config: PretrainedConfig,
  quant_config: Optional[QuantizationConfig] = None,
+ use_presharded_weights: bool = False,
  ) -> None:
  super().__init__()
  self.config = config
@@ -307,7 +322,12 @@ class Grok1Model(nn.Module):
  )
  self.layers = nn.ModuleList(
  [
- Grok1DecoderLayer(config, i, quant_config=quant_config)
+ Grok1DecoderLayer(
+ config,
+ i,
+ quant_config=quant_config,
+ use_presharded_weights=use_presharded_weights,
+ )
  for i in range(config.num_hidden_layers)
  ]
  )
@@ -343,7 +363,21 @@ class Grok1ForCausalLM(nn.Module):
  super().__init__()
  self.config = config
  self.quant_config = quant_config
- self.model = Grok1Model(config, quant_config=quant_config)
+
+ if (
+ self.config.num_local_experts > 0
+ and get_tensor_model_parallel_world_size() > 1
+ ):
+ self.use_presharded_weights = True
+ setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
+ else:
+ self.use_presharded_weights = False
+
+ self.model = Grok1Model(
+ config,
+ quant_config=quant_config,
+ use_presharded_weights=self.use_presharded_weights,
+ )
  self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
  self.logits_processor = LogitsProcessor(config)

@@ -359,7 +393,12 @@ class Grok1ForCausalLM(nn.Module):
  input_ids, hidden_states, self.lm_head, forward_batch
  )

- def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+ def load_weights(
+ self,
+ weights: Iterable[Tuple[str, torch.Tensor]],
+ ):
+ num_experts = self.config.num_local_experts
+
  stacked_params_mapping = [
  # (param_name, shard_name, shard_id)
  ("qkv_proj", "q_proj", "q"),
@@ -375,10 +414,23 @@ class Grok1ForCausalLM(nn.Module):
  ckpt_gate_proj_name="w1",
  ckpt_down_proj_name="w2",
  ckpt_up_proj_name="w3",
- num_experts=self.config.num_local_experts,
+ num_experts=num_experts,
  )

  params_dict = dict(self.named_parameters())
+ all_names = set(params_dict.keys())
+ hit_names = set()
+
+ def load_weight_wrapper(name, loaded_weight, *args, **kwargs):
+ if name not in params_dict:
+ return
+
+ param = params_dict[name]
+ weight_loader = getattr(param, "weight_loader", default_weight_loader)
+ weight_loader(param, loaded_weight, *args, **kwargs)
+
+ hit_names.add(name)
+
  for name, loaded_weight in weights:
  if "rotary_emb.inv_freq" in name:
  continue
@@ -391,9 +443,7 @@ class Grok1ForCausalLM(nn.Module):
  if name.endswith(".bias") and name not in params_dict:
  continue

- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(param, loaded_weight, shard_id)
+ load_weight_wrapper(name, loaded_weight, shard_id)
  break
  else:
  for mapping in expert_params_mapping:
@@ -402,15 +452,8 @@ class Grok1ForCausalLM(nn.Module):
  continue
  name = name.replace(weight_name, param_name)

- if (
- name.endswith(".bias") or name.endswith("_bias")
- ) and name not in params_dict:
- continue
-
- param = params_dict[name]
- weight_loader = param.weight_loader
- weight_loader(
- param,
+ load_weight_wrapper(
+ name,
  loaded_weight,
  name,
  shard_id=shard_id,
@@ -419,21 +462,58 @@ class Grok1ForCausalLM(nn.Module):
  break
  else:
  # Skip loading extra bias for GPTQ models.
- if (
- name.endswith(".bias") or name.endswith("_bias")
- ) and name not in params_dict:
- continue
- # Skip loading kv_scale from ckpts towards new design.
- if name.endswith(".kv_scale") and name not in params_dict:
+ if name.endswith(".bias") and name not in params_dict:
  continue
  if name is None:
  continue

- param = params_dict[name]
- weight_loader = getattr(
- param, "weight_loader", default_weight_loader
- )
- weight_loader(param, loaded_weight)
+ load_weight_wrapper(name=name, loaded_weight=loaded_weight)
+
+
+ old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
+
+
+ def _prepare_presharded_weights(
+ self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+ ) -> Tuple[str, List[str], bool]:
+ import glob
+ import os
+
+ if get_tensor_model_parallel_world_size() == 1:
+ return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
+
+ if not os.path.isdir(model_name_or_path):
+ from sglang.srt.model_loader.weight_utils import download_weights_from_hf
+
+ allow_patterns = ["*.safetensors", "*.bin"]
+ hf_folder = download_weights_from_hf(
+ model_name_or_path,
+ self.load_config.download_dir,
+ allow_patterns,
+ revision,
+ ignore_patterns=self.load_config.ignore_patterns,
+ )
+ else:
+ hf_folder = model_name_or_path
+
+ tp_rank = get_tensor_model_parallel_rank()
+
+ # The old format
+ allow_patterns = [f"*-{tp_rank:03d}.bin"]
+
+ # The new format
+ allow_patterns += [f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
+
+ hf_weights_files: List[str] = []
+ for pattern in allow_patterns:
+ hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+ if hf_weights_files[0].endswith("safetensors"):
+ use_safetensors = True
+ else:
+ use_safetensors = False
+
+ return hf_folder, hf_weights_files, use_safetensors


  class Grok1ModelForCausalLM(Grok1ForCausalLM):
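Note: _prepare_presharded_weights replaces DefaultModelLoader._prepare_weights so that each tensor-parallel rank only picks up its own shard files plus the shared ones. A small hedged sketch of the pattern matching (file names are illustrative):

    import fnmatch

    tp_rank = 2
    patterns = [f"*-{tp_rank:03d}.bin", f"*-TP-{tp_rank:03d}.safetensors", "*-TP-common.safetensors"]
    files = ["model-TP-000.safetensors", "model-TP-002.safetensors", "model-TP-common.safetensors"]
    matched = [f for f in files if any(fnmatch.fnmatch(f, p) for p in patterns)]
    print(matched)   # ['model-TP-002.safetensors', 'model-TP-common.safetensors']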