sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. sglang/bench_offline_throughput.py +4 -2
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_one_batch_server.py +143 -15
  4. sglang/bench_serving.py +9 -7
  5. sglang/compile_deep_gemm.py +1 -1
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +1 -0
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +48 -43
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +7 -2
  20. sglang/srt/disaggregation/fake/conn.py +1 -1
  21. sglang/srt/disaggregation/mooncake/conn.py +227 -120
  22. sglang/srt/disaggregation/nixl/conn.py +1 -0
  23. sglang/srt/disaggregation/prefill.py +7 -4
  24. sglang/srt/disaggregation/utils.py +7 -1
  25. sglang/srt/entrypoints/engine.py +17 -2
  26. sglang/srt/entrypoints/http_server.py +17 -2
  27. sglang/srt/function_call_parser.py +2 -2
  28. sglang/srt/layers/attention/flashattention_backend.py +1 -1
  29. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  30. sglang/srt/layers/attention/utils.py +4 -2
  31. sglang/srt/layers/dp_attention.py +71 -21
  32. sglang/srt/layers/layernorm.py +1 -1
  33. sglang/srt/layers/logits_processor.py +46 -11
  34. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  35. sglang/srt/layers/moe/ep_moe/layer.py +1 -1
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  37. sglang/srt/layers/moe/topk.py +1 -1
  38. sglang/srt/layers/quantization/__init__.py +1 -1
  39. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  40. sglang/srt/layers/quantization/deep_gemm.py +72 -71
  41. sglang/srt/layers/quantization/fp8.py +2 -2
  42. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  43. sglang/srt/layers/quantization/int8_kernel.py +2 -2
  44. sglang/srt/layers/sampler.py +0 -4
  45. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  46. sglang/srt/lora/lora_manager.py +1 -1
  47. sglang/srt/lora/mem_pool.py +4 -4
  48. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  49. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  50. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  51. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  52. sglang/srt/lora/utils.py +1 -1
  53. sglang/srt/managers/data_parallel_controller.py +3 -3
  54. sglang/srt/managers/detokenizer_manager.py +21 -8
  55. sglang/srt/managers/io_struct.py +3 -1
  56. sglang/srt/managers/mm_utils.py +1 -1
  57. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  58. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  59. sglang/srt/managers/schedule_batch.py +76 -24
  60. sglang/srt/managers/schedule_policy.py +0 -3
  61. sglang/srt/managers/scheduler.py +113 -88
  62. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  63. sglang/srt/managers/tokenizer_manager.py +133 -34
  64. sglang/srt/managers/tp_worker.py +12 -9
  65. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  66. sglang/srt/mem_cache/memory_pool.py +2 -0
  67. sglang/srt/metrics/collector.py +312 -37
  68. sglang/srt/model_executor/cuda_graph_runner.py +10 -11
  69. sglang/srt/model_executor/forward_batch_info.py +1 -1
  70. sglang/srt/model_executor/model_runner.py +19 -14
  71. sglang/srt/models/deepseek_janus_pro.py +2 -2
  72. sglang/srt/models/deepseek_v2.py +23 -20
  73. sglang/srt/models/llama.py +2 -0
  74. sglang/srt/models/llama4.py +5 -6
  75. sglang/srt/models/llava.py +248 -5
  76. sglang/srt/models/mixtral.py +98 -34
  77. sglang/srt/models/pixtral.py +467 -0
  78. sglang/srt/models/roberta.py +1 -1
  79. sglang/srt/models/torch_native_llama.py +1 -1
  80. sglang/srt/openai_api/adapter.py +30 -4
  81. sglang/srt/openai_api/protocol.py +0 -8
  82. sglang/srt/reasoning_parser.py +3 -3
  83. sglang/srt/sampling/custom_logit_processor.py +18 -3
  84. sglang/srt/sampling/sampling_batch_info.py +4 -56
  85. sglang/srt/sampling/sampling_params.py +2 -2
  86. sglang/srt/server_args.py +34 -4
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +7 -7
  89. sglang/srt/speculative/eagle_worker.py +22 -19
  90. sglang/srt/utils.py +6 -5
  91. sglang/test/few_shot_gsm8k.py +2 -2
  92. sglang/test/few_shot_gsm8k_engine.py +2 -2
  93. sglang/test/run_eval.py +2 -2
  94. sglang/test/runners.py +8 -1
  95. sglang/test/send_one.py +13 -3
  96. sglang/test/simple_eval_common.py +1 -1
  97. sglang/test/simple_eval_humaneval.py +1 -1
  98. sglang/test/test_programs.py +5 -5
  99. sglang/test/test_utils.py +89 -14
  100. sglang/utils.py +1 -1
  101. sglang/version.py +1 -1
  102. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/METADATA +6 -5
  103. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/RECORD +107 -104
  104. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  105. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/WHEEL +0 -0
  106. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post4.dist-info}/top_level.txt +0 -0
@@ -49,6 +49,7 @@ from sglang.srt.disaggregation.utils import (
 from sglang.srt.entrypoints.engine import _launch_subprocesses
 from sglang.srt.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
@@ -221,7 +222,7 @@ async def get_server_info():
     return {
         **dataclasses.asdict(_global_state.tokenizer_manager.server_args),
         **_global_state.scheduler_info,
-        **internal_states,
+        "internal_states": internal_states,
         "version": __version__,
     }

@@ -337,7 +338,11 @@ async def start_profile_async(obj: Optional[ProfileReqInput] = None):
         obj = ProfileReqInput()

     await _global_state.tokenizer_manager.start_profile(
-        obj.output_dir, obj.num_steps, obj.activities
+        output_dir=obj.output_dir,
+        num_steps=obj.num_steps,
+        activities=obj.activities,
+        with_stack=obj.with_stack,
+        record_shapes=obj.record_shapes,
     )
     return Response(
         content="Start profiling.\n",
@@ -539,6 +544,16 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
     return Response(status_code=200)


+@app.post("/abort_request")
+async def abort_request(obj: AbortReq, request: Request):
+    """Abort a request."""
+    try:
+        _global_state.tokenizer_manager.abort_request(rid=obj.rid)
+        return Response(status_code=200)
+    except Exception as e:
+        return _create_error_response(e)
+
+
 @app.post("/parse_function_call")
 async def parse_function_call_request(obj: ParseFunctionCallReq, request: Request):
     """
@@ -86,8 +86,8 @@ class StructureInfo:

 _GetInfoFunc = Callable[[str], StructureInfo]
 """
-helper alias of function
-ususally it is a function that takes a name string and returns a StructureInfo object,
+Helper alias of function
+Usually it is a function that takes a name string and returns a StructureInfo object,
 which can be used to construct a structural_tag object
 """

@@ -308,7 +308,7 @@ class FlashAttentionBackend(AttentionBackend):
         ), "Sliding window and cross attention are not supported together"

         self.forward_metadata: FlashAttentionMetadata = None
-        # extra metdata for handling speculative decoding topk > 1, extended draft decode and verify
+        # extra metadata for handling speculative decoding topk > 1, extended draft decode and verify
         self.forward_metadata_spec_decode_expand: FlashAttentionMetadata = None
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
@@ -919,7 +919,7 @@ def _fwd_kernel(

     e_max = n_e_max

-    # stage 2: compute the trianlge part
+    # stage 2: compute the triangle part

     cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
     for start_n in range(0, cur_block_m_end, BLOCK_N):
@@ -28,7 +28,8 @@ def create_flashinfer_kv_indices_triton(

     num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
     for i in range(num_loop):
-        offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
+        # index into req_to_token_ptr needs to be int64
+        offset = tl.arange(0, BLOCK_SIZE).to(tl.int64) + i * BLOCK_SIZE
         mask = offset < kv_end - kv_start
         data = tl.load(
             req_to_token_ptr
@@ -70,8 +71,9 @@ def create_flashmla_kv_indices_triton(
     num_pages_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)

     for i in range(num_pages_loop):
+        # index into req_to_token_ptr needs to be int64
         paged_offset = (
-            tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK
+            tl.arange(0, NUM_PAGE_PER_BLOCK).to(tl.int64) + i * NUM_PAGE_PER_BLOCK
         ) * PAGED_SIZE
         paged_offset_out = tl.arange(0, NUM_PAGE_PER_BLOCK) + i * NUM_PAGE_PER_BLOCK

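Note on the two hunks above: the offsets used to index req_to_token_ptr are now computed in int64, so the pointer arithmetic cannot wrap in 32-bit range on large token pools. A standalone sketch of the failure mode the cast avoids (illustrative values, plain PyTorch rather than Triton):

    import torch

    # An int32 product silently wraps once it exceeds 2**31 - 1; this is the kind
    # of overflow the .to(tl.int64) cast in the kernels guards against.
    base = torch.tensor(70_000, dtype=torch.int32)
    stride = torch.tensor(40_000, dtype=torch.int32)
    print((base * stride).item())                   # wrapped, incorrect value
    print((base.to(torch.int64) * stride).item())   # 2800000000, as intended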
@@ -24,8 +24,10 @@ if TYPE_CHECKING:
 _ATTN_TP_GROUP = None
 _ATTN_TP_RANK = None
 _ATTN_TP_SIZE = None
-_DP_RANK = None
-_DP_SIZE = None
+_ATTN_DP_RANK = None
+_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_SIZE = None
+_LOCAL_ATTN_DP_RANK = None


 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
@@ -33,9 +35,27 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
         return tp_rank, tp_size, 0

     attn_tp_size = tp_size // dp_size
-    dp_rank = tp_rank // attn_tp_size
+    attn_dp_rank = tp_rank // attn_tp_size
     attn_tp_rank = tp_rank % attn_tp_size
-    return attn_tp_rank, attn_tp_size, dp_rank
+
+    return attn_tp_rank, attn_tp_size, attn_dp_rank
+
+
+def compute_dp_attention_local_info(
+    enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size
+    local_tp_rank = tp_rank % local_tp_size
+    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))
+
+    local_attn_tp_size = local_tp_size // local_dp_size
+    local_attn_dp_rank = local_tp_rank // local_attn_tp_size
+    local_attn_tp_rank = local_tp_rank % local_attn_tp_size
+
+    return local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank


 def initialize_dp_attention(
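Worked example for the new compute_dp_attention_local_info above (illustrative values, reproducing the function's arithmetic outside sglang):

    # tp_size=8, dp_size=4, moe_dense_tp_size=2, tp_rank=5, dp attention enabled.
    tp_rank, tp_size, dp_size, moe_dense_tp_size = 5, 8, 4, 2

    # compute_dp_attention_world_info
    attn_tp_size = tp_size // dp_size        # 2
    attn_dp_rank = tp_rank // attn_tp_size   # 2
    attn_tp_rank = tp_rank % attn_tp_size    # 1

    # compute_dp_attention_local_info
    local_tp_size = moe_dense_tp_size if moe_dense_tp_size else tp_size  # 2
    local_tp_rank = tp_rank % local_tp_size                              # 1
    local_dp_size = max(1, dp_size // (tp_size // local_tp_size))        # 1
    local_attn_tp_size = local_tp_size // local_dp_size                  # 2
    local_attn_dp_rank = local_tp_rank // local_attn_tp_size             # 0
    local_attn_tp_rank = local_tp_rank % local_attn_tp_size              # 1

    print(attn_tp_rank, attn_tp_size, attn_dp_rank)                     # 1 2 2
    print(local_attn_tp_rank, local_attn_tp_size, local_attn_dp_rank)   # 1 2 0

So a rank that is attention-DP rank 2 globally becomes local attention-DP rank 0 when the dense layers run with moe_dense_tp_size=2.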
@@ -43,22 +63,32 @@ def initialize_dp_attention(
     tp_rank: int,
     tp_size: int,
     dp_size: int,
+    moe_dense_tp_size: int,
     pp_size: int,
 ):
-    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK, _ATTN_DP_SIZE
+    global _LOCAL_ATTN_DP_SIZE, _LOCAL_ATTN_DP_RANK

     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP

-    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _ATTN_DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
+    _, _, _LOCAL_ATTN_DP_RANK = compute_dp_attention_local_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size, moe_dense_tp_size
+    )

     if enable_dp_attention:
         local_rank = tp_rank % (tp_size // dp_size)
-        _DP_SIZE = dp_size
+        _ATTN_DP_SIZE = dp_size
+        if moe_dense_tp_size is None:
+            _LOCAL_ATTN_DP_SIZE = _ATTN_DP_SIZE
+        else:
+            _LOCAL_ATTN_DP_SIZE = max(1, dp_size // (tp_size // moe_dense_tp_size))
     else:
         local_rank = tp_rank
-        _DP_SIZE = 1
+        _ATTN_DP_SIZE = 1
+        _LOCAL_ATTN_DP_SIZE = 1

     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -93,13 +123,33 @@ def get_attention_tp_size():


 def get_attention_dp_rank():
-    assert _DP_RANK is not None, "dp attention not initialized!"
-    return _DP_RANK
+    assert _ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_DP_RANK


 def get_attention_dp_size():
-    assert _DP_SIZE is not None, "dp attention not initialized!"
-    return _DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE
+
+
+def get_local_attention_dp_rank():
+    assert _LOCAL_ATTN_DP_RANK is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_RANK
+
+
+def get_local_attention_dp_size():
+    assert _LOCAL_ATTN_DP_SIZE is not None, "dp attention not initialized!"
+    return _LOCAL_ATTN_DP_SIZE


 @contextmanager
@@ -112,19 +162,19 @@ def disable_dp_size():
     Args:
         tp_group (GroupCoordinator): the tp group coordinator
     """
-    global _DP_SIZE
-    assert _DP_SIZE is not None, "dp attention not initialized!"
+    global _ATTN_DP_SIZE
+    assert _ATTN_DP_SIZE is not None, "dp attention not initialized!"

-    old_dp_size = _DP_SIZE
-    _DP_SIZE = 1
+    old_dp_size = _ATTN_DP_SIZE
+    _ATTN_DP_SIZE = 1
     try:
         yield
     finally:
-        _DP_SIZE = old_dp_size
+        _ATTN_DP_SIZE = old_dp_size


 def get_dp_local_info(forward_batch: ForwardBatch):
-    dp_rank = get_attention_dp_rank()
+    dp_rank = get_local_attention_dp_rank()

     if forward_batch.dp_local_start_pos is None:
         cumtokens = torch.cumsum(forward_batch.global_num_tokens_gpu, dim=0)
@@ -201,7 +251,7 @@ def _dp_gather(
         global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
     )

-    # Input IDs are in int 32. We should use inplace_all_reduce for local case becaues of custom all reduce.
+    # Input IDs are in int 32. We should use inplace_all_reduce for local case because of custom all reduce.
     NUM_GPUS_PER_NODE = 8
     if (
         not local_tokens.dtype.is_floating_point
@@ -252,12 +302,12 @@ def dp_scatter(
     )


-def tp_reduce_scatter(
+def attn_tp_reduce_scatter(
     output: torch.Tensor,
     input_list: List[torch.Tensor],
 ):
     return get_attention_tp_group().reduce_scatter(output, input_list)


-def tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
+def attn_tp_all_gather(output_list: List[torch.Tensor], input_: torch.Tensor):
     return get_attention_tp_group().all_gather(input_, tensor_list=output_list)
@@ -76,7 +76,7 @@ class RMSNorm(CustomOp):
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
-            # NOTE: Romove this if aiter kernel supports discontinuous input
+            # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
             fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
@@ -23,15 +23,17 @@ import triton.language as tl
 from torch import nn

 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    attn_tp_all_gather,
     dp_gather_replicate,
     dp_scatter,
-    get_attention_dp_rank,
     get_attention_dp_size,
+    get_attention_tp_size,
+    get_local_attention_dp_rank,
+    get_local_attention_dp_size,
 )
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
@@ -45,6 +47,18 @@ from sglang.srt.utils import dump_to_file
 logger = logging.getLogger(__name__)


+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import (
+    CaptureHiddenMode,
+    ForwardBatch,
+    ForwardMode,
+)
+from sglang.srt.utils import dump_to_file
+
+logger = logging.getLogger(__name__)
+
+
 @dataclasses.dataclass
 class LogitsProcessorOutput:
     ## Part 1: This part will be assigned in python/sglang/srt/layers/logits_processor.py::LogitsProcessor
@@ -169,7 +183,7 @@ class LogitsMetadata:
             return

         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
-        dp_rank = get_attention_dp_rank()
+        dp_rank = get_local_attention_dp_rank()
         if dp_rank == 0:
             dp_local_start_pos = torch.zeros_like(
                 self.global_num_tokens_for_logprob_gpu[0]
@@ -198,12 +212,20 @@ class LogitsProcessor(nn.Module):
         super().__init__()
         self.config = config
         self.logit_scale = logit_scale
-        self.do_tensor_parallel_all_gather = (
-            not skip_all_gather and get_tensor_model_parallel_world_size() > 1
-        )
-        self.do_tensor_parallel_all_gather_dp_attn = (
-            self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
-        )
+        self.use_attn_tp_group = global_server_args_dict["enable_dp_lm_head"]
+        if self.use_attn_tp_group:
+            self.attn_tp_size = get_attention_tp_size()
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and self.attn_tp_size > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = False
+        else:
+            self.do_tensor_parallel_all_gather = (
+                not skip_all_gather and get_tensor_model_parallel_world_size() > 1
+            )
+            self.do_tensor_parallel_all_gather_dp_attn = (
+                self.do_tensor_parallel_all_gather and get_attention_dp_size() != 1
+            )
         self.final_logit_softcapping = getattr(
             self.config, "final_logit_softcapping", None
         )
@@ -315,7 +337,8 @@ class LogitsProcessor(nn.Module):

         if self.debug_tensor_dump_output_folder:
             assert (
-                not self.do_tensor_parallel_all_gather or get_attention_dp_size() == 1
+                not self.do_tensor_parallel_all_gather
+                or get_local_attention_dp_size() == 1
             ), "dp attention + sharded lm_head doesn't support full logits"
             full_logits = self._get_logits(hidden_states, lm_head, logits_metadata)
             dump_to_file(self.debug_tensor_dump_output_folder, "logits", full_logits)
@@ -442,7 +465,19 @@ class LogitsProcessor(nn.Module):
             logits.mul_(self.logit_scale)

         if self.do_tensor_parallel_all_gather:
-            logits = tensor_model_parallel_all_gather(logits)
+            if self.use_attn_tp_group:
+                global_logits = torch.empty(
+                    (self.config.vocab_size, logits.shape[0]),
+                    device=logits.device,
+                    dtype=logits.dtype,
+                )
+                global_logits = global_logits.T
+                attn_tp_all_gather(
+                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
+                )
+                logits = global_logits
+            else:
+                logits = tensor_model_parallel_all_gather(logits)

         if self.do_tensor_parallel_all_gather_dp_attn:
             logits, global_logits = (
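Note on the hunk above: with enable_dp_lm_head, each attention-TP rank holds a vocab shard of the logits; the full-vocab buffer is allocated as (vocab_size, num_tokens), viewed transposed, and split along the vocab dimension so the all-gather writes each rank's shard into the right slice. A shape-only sketch of that layout trick (plain PyTorch, no distributed calls; sizes are illustrative):

    import torch

    num_tokens, vocab_size, attn_tp_size = 4, 12, 2
    local_logits = torch.randn(num_tokens, vocab_size // attn_tp_size)  # one rank's shard

    # Allocate (vocab, tokens), view transposed, then split along vocab so each
    # chunk is a view into the full buffer, mirroring the code in the diff.
    global_logits = torch.empty(vocab_size, num_tokens).T           # (num_tokens, vocab_size)
    chunks = list(global_logits.tensor_split(attn_tp_size, dim=-1))

    # An all-gather would fill one chunk per rank; emulate both ranks locally.
    chunks[0].copy_(local_logits)
    chunks[1].copy_(local_logits)
    assert global_logits.shape == (num_tokens, vocab_size)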
@@ -116,7 +116,7 @@ def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
     seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
     src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)

-    # Find offet
+    # Find offset
     expert_ids = torch.arange(
         num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
     )
@@ -611,7 +611,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
             self.quant_config.weight_block_size[1],
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-        # Required by collum parallel or enabling merged weights
+        # Required by column parallel or enabling merged weights
         if intermediate_size % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "
@@ -994,7 +994,7 @@ def get_default_config(
             "num_stages": 2 if _is_hip else 4,
         }
     else:
-        # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_shape[0],
@@ -270,7 +270,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
-    # DeekSeek V2/V3/R1 serices models uses grouped_top_k
+    # DeepSeek V2/V3/R1 series models use grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -109,7 +109,7 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
         raise ValueError(
             f"{quantization} quantization requires some operators from vllm. "
-            "Pleaes install vllm by `pip install vllm==0.8.4`"
+            "Please install vllm by `pip install vllm==0.8.4`"
         )

     return QUANTIZATION_METHODS[quantization]
@@ -152,7 +152,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
                 f"{input_size_per_partition} is not divisible by "
                 f"weight quantization block_k = {block_k}."
             )
-        # Required by collum parallel or enabling merged weights
+        # Required by column parallel or enabling merged weights
         if (tp_size > 1 and output_size // output_size_per_partition == tp_size) or len(
             output_partition_sizes
         ) > 1:
@@ -285,7 +285,7 @@ class BlockInt8MoEMethod:
             self.quant_config.weight_block_size[1],
         )
         # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-        # Required by collum parallel or enabling merged weights
+        # Required by column parallel or enabling merged weights
         if intermediate_size % block_n != 0:
             raise ValueError(
                 f"The output_size of gate's and up's weight = "