sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +16 -6
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +27 -12
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +76 -102
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
  41. sglang/srt/layers/moe/topk.py +4 -2
  42. sglang/srt/layers/parameter.py +26 -17
  43. sglang/srt/layers/quantization/__init__.py +22 -23
  44. sglang/srt/layers/quantization/fp8.py +112 -55
  45. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  46. sglang/srt/layers/quantization/int8_kernel.py +54 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +2 -3
  48. sglang/srt/layers/quantization/w8a8_int8.py +117 -0
  49. sglang/srt/layers/radix_attention.py +2 -0
  50. sglang/srt/layers/rotary_embedding.py +1179 -31
  51. sglang/srt/layers/sampler.py +39 -1
  52. sglang/srt/layers/vocab_parallel_embedding.py +17 -4
  53. sglang/srt/lora/lora.py +1 -9
  54. sglang/srt/managers/configure_logging.py +46 -0
  55. sglang/srt/managers/data_parallel_controller.py +79 -72
  56. sglang/srt/managers/detokenizer_manager.py +23 -8
  57. sglang/srt/managers/image_processor.py +158 -2
  58. sglang/srt/managers/io_struct.py +54 -15
  59. sglang/srt/managers/schedule_batch.py +49 -22
  60. sglang/srt/managers/schedule_policy.py +26 -12
  61. sglang/srt/managers/scheduler.py +319 -181
  62. sglang/srt/managers/session_controller.py +1 -0
  63. sglang/srt/managers/tokenizer_manager.py +303 -158
  64. sglang/srt/managers/tp_worker.py +6 -4
  65. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  66. sglang/srt/managers/utils.py +44 -0
  67. sglang/srt/mem_cache/memory_pool.py +110 -77
  68. sglang/srt/metrics/collector.py +25 -11
  69. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  70. sglang/srt/model_executor/model_runner.py +80 -21
  71. sglang/srt/model_loader/loader.py +8 -6
  72. sglang/srt/model_loader/weight_utils.py +55 -2
  73. sglang/srt/models/baichuan.py +6 -6
  74. sglang/srt/models/chatglm.py +2 -2
  75. sglang/srt/models/commandr.py +3 -3
  76. sglang/srt/models/dbrx.py +4 -4
  77. sglang/srt/models/deepseek.py +3 -3
  78. sglang/srt/models/deepseek_v2.py +8 -8
  79. sglang/srt/models/exaone.py +2 -2
  80. sglang/srt/models/gemma.py +2 -2
  81. sglang/srt/models/gemma2.py +6 -24
  82. sglang/srt/models/gpt2.py +3 -5
  83. sglang/srt/models/gpt_bigcode.py +1 -1
  84. sglang/srt/models/granite.py +2 -2
  85. sglang/srt/models/grok.py +3 -3
  86. sglang/srt/models/internlm2.py +2 -2
  87. sglang/srt/models/llama.py +41 -4
  88. sglang/srt/models/minicpm.py +2 -2
  89. sglang/srt/models/minicpm3.py +6 -6
  90. sglang/srt/models/minicpmv.py +1238 -0
  91. sglang/srt/models/mixtral.py +3 -3
  92. sglang/srt/models/mixtral_quant.py +3 -3
  93. sglang/srt/models/mllama.py +2 -2
  94. sglang/srt/models/olmo.py +3 -3
  95. sglang/srt/models/olmo2.py +4 -4
  96. sglang/srt/models/olmoe.py +7 -13
  97. sglang/srt/models/phi3_small.py +2 -2
  98. sglang/srt/models/qwen.py +2 -2
  99. sglang/srt/models/qwen2.py +52 -4
  100. sglang/srt/models/qwen2_eagle.py +131 -0
  101. sglang/srt/models/qwen2_moe.py +3 -3
  102. sglang/srt/models/qwen2_vl.py +22 -122
  103. sglang/srt/models/stablelm.py +2 -2
  104. sglang/srt/models/torch_native_llama.py +3 -3
  105. sglang/srt/models/xverse.py +6 -6
  106. sglang/srt/models/xverse_moe.py +6 -6
  107. sglang/srt/openai_api/protocol.py +2 -0
  108. sglang/srt/sampling/custom_logit_processor.py +38 -0
  109. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
  110. sglang/srt/sampling/sampling_batch_info.py +153 -9
  111. sglang/srt/sampling/sampling_params.py +4 -2
  112. sglang/srt/server.py +4 -1037
  113. sglang/srt/server_args.py +84 -32
  114. sglang/srt/speculative/eagle_worker.py +1 -0
  115. sglang/srt/torch_memory_saver_adapter.py +59 -0
  116. sglang/srt/utils.py +130 -63
  117. sglang/test/runners.py +8 -13
  118. sglang/test/test_programs.py +1 -1
  119. sglang/test/test_utils.py +3 -1
  120. sglang/utils.py +12 -2
  121. sglang/version.py +1 -1
  122. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
  123. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
  124. sglang/launch_server_llavavid.py +0 -25
  125. sglang/srt/constrained/__init__.py +0 -16
  126. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  127. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  128. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  129. {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -18,6 +18,7 @@ import triton.language as tl

 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available

@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
-            // model_runner.tp_size,
+            // get_attention_tp_size(),
             num_kv_heads=model_runner.model_config.get_num_kv_heads(
-                model_runner.tp_size
+                get_attention_tp_size()
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
@@ -84,6 +85,10 @@ class FlashInferAttnBackend(AttentionBackend):
             self.num_wrappers = 1
             self.dispatch_reason = None

+        # Qwen2 models require higher flashinfer workspace size
+        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+
         # Allocate buffers
         self.workspace_buffer = torch.empty(
             global_config.flashinfer_workspace_size,
@@ -143,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}

     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -234,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
                 decode_wrappers.append(
@@ -303,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -353,7 +358,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )

             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +369,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +396,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)

             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)

@@ -412,13 +423,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )

         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )

         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -439,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
@@ -611,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype

sglang/srt/layers/attention/triton_backend.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
 import torch

 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd

-        if model_runner.server_args.enable_dp_attention:
-            self.num_head = model_runner.model_config.num_attention_heads
-        else:
-            self.num_head = (
-                model_runner.model_config.num_attention_heads // model_runner.tp_size
-            )
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )

         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
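The net effect of switching from model_runner.tp_size to get_attention_tp_size() in the two backends above, as a small worked sketch (the head count and world sizes are illustrative only; with dp attention disabled, the attention-TP size equals tp_size and the result is unchanged):

    # Sketch: per-rank attention head count before vs. after this change.
    # Illustrative numbers: 32 attention heads, tp_size=8, dp_size=2.
    num_attention_heads = 32
    tp_size, dp_size = 8, 2

    # Old: always divide by the full tensor-parallel world size.
    heads_old = num_attention_heads // tp_size        # 4 heads per rank

    # New: divide by the attention-TP size, which shrinks to tp_size // dp_size
    # when dp attention is enabled (see sglang/srt/layers/dp_attention.py below).
    attn_tp_size = tp_size // dp_size                 # 4
    heads_new = num_attention_heads // attn_tp_size   # 8 heads per attention-TP rank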
sglang/srt/layers/attention/vision.py (new file)

@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.quantization import QuantizationConfig
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class VisionAttention(nn.Module):
+    """Multi-headed attention without any cache, mostly used for ViT."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
+        use_qkv_parallel: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        # self.tp_size = get_tensor_model_parallel_world_size()
+        # num_heads = self.num_heads_per_partition
+        self.use_qkv_parallel = use_qkv_parallel
+        if use_qkv_parallel:
+            self.head_dim = embed_dim // num_heads
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.head_dim,
+                total_num_heads=num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        else:
+            self.qkv_proj = ColumnParallelLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        self.proj = RowParallelLinear(
+            input_size=embed_dim,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Input shape: [b, s, embed_dim]
+        Output shape: [s, b, num_heads * head_size]
+        """
+
+        bsz, s, _ = x.shape
+        if self.use_qkv_parallel:
+            # [b, s, embed_dim] --> [b, s, embed_dim]
+            qkv, _ = self.qkv_proj(x)
+            q, k, v = qkv.chunk(3, dim=-1)
+
+            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
+            q, k, v = [
+                x.reshape(
+                    bsz * s, self.num_attention_heads_per_partition, -1
+                ).contiguous()
+                for x in (q, k, v)
+            ]
+        else:
+            # [b, s, embed_dim] --> [s, b, embed_dim]
+            x = rearrange(x, "b s ... -> s b ...")
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            qkv, _ = self.qkv_proj(x)
+            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            new_x_shape = qkv.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            qkv = qkv.view(*new_x_shape)
+
+            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+
+            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            q, k, v = [
+                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+            ]
+
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+        if self.use_qkv_parallel:
+            pass
+        else:
+            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+        # [b * s, num_heads, head_size]
+        output = torch.empty_like(q)
+
+        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
+        max_seqlen = seq_lens.max().item()
+
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens,
+            max_seqlen,
+            is_causal=False,
+        )
+
+        if self.use_qkv_parallel:
+
+            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            output, _ = self.proj(output)
+        else:
+            # [b * s, head, head_dim] --> [b, s, head, head_dim]
+            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+            # [s, b, num_heads * head_size]
+            context_layer = rearrange(
+                context_layer, "b s h d -> s b (h d)"
+            ).contiguous()
+
+            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            output, _ = self.proj(context_layer)
+
+        output = output.view(bsz, s, -1)
+
+        return output
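For orientation, a minimal sketch of the rotary helper defined in this new file (assumes sglang 0.4.1.post7 with torch and einops installed; the tensor sizes are arbitrary illustration values, not anything the package prescribes):

    import torch

    from sglang.srt.layers.attention.vision import apply_rotary_pos_emb_vision

    # [batch, seqlen, nheads, headdim] query and [seqlen, rotary_dim / 2] frequencies,
    # matching the shapes documented in apply_rotary_emb_torch above.
    q = torch.randn(2, 16, 4, 64)
    freqs = torch.randn(16, 32)

    q_rot = apply_rotary_pos_emb_vision(q, freqs)
    assert q_rot.shape == q.shape  # only the first rotary_dim channels are rotated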
sglang/srt/layers/dp_attention.py (new file)

@@ -0,0 +1,69 @@
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, get_tp_group
+
+_ATTN_TP_GROUP = None
+_ATTN_TP_RANK = None
+_ATTN_TP_SIZE = None
+_DP_RANK = None
+_DP_SIZE = None
+
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+    return attn_tp_rank, attn_tp_size, dp_rank
+
+
+def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _DP_SIZE = dp_size
+
+    tp_group = get_tp_group()
+    _ATTN_TP_GROUP = GroupCoordinator(
+        [
+            list(range(head, head + _ATTN_TP_SIZE))
+            for head in range(0, tp_size, _ATTN_TP_SIZE)
+        ],
+        tp_rank,
+        torch.distributed.get_backend(tp_group.device_group),
+        False,
+        False,
+        False,
+        False,
+        False,
+        group_name="attention_tp",
+    )
+
+
+def get_attention_tp_group():
+    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+    return _ATTN_TP_GROUP
+
+
+def get_attention_tp_rank():
+    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_TP_RANK
+
+
+def get_attention_tp_size():
+    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_TP_SIZE
+
+
+def get_attention_dp_rank():
+    assert _DP_RANK is not None, "dp attention not initialized!"
+    return _DP_RANK
+
+
+def get_attention_dp_size():
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+    return _DP_SIZE
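A small sketch of the rank bookkeeping this module performs (the numbers are illustrative; the integer divisions above assume tp_size is a multiple of dp_size):

    from sglang.srt.layers.dp_attention import compute_dp_attention_world_info

    # With tp_size=8 split into dp_size=2 replicas, each replica keeps an
    # attention-TP group of 8 // 2 = 4 ranks; tp_rank=5 falls into dp_rank=1
    # with local attention-TP rank 5 % 4 = 1.
    attn_tp_rank, attn_tp_size, dp_rank = compute_dp_attention_world_info(
        enable_dp_attention=True, tp_rank=5, tp_size=8, dp_size=2
    )
    assert (attn_tp_rank, attn_tp_size, dp_rank) == (1, 4, 1)

    # With dp attention disabled, the plain TP layout is returned unchanged.
    assert compute_dp_attention_world_info(False, 5, 8, 2) == (5, 8, 0)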