sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/flashinfer_backend.py

@@ -18,6 +18,7 @@ import triton.language as tl
 
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -62,9 +63,9 @@ class FlashInferAttnBackend(AttentionBackend):
         self.decode_use_tensor_cores = should_use_tensor_core(
             kv_cache_dtype=model_runner.kv_cache_dtype,
             num_attention_heads=model_runner.model_config.num_attention_heads
-            // model_runner.tp_size,
+            // get_attention_tp_size(),
             num_kv_heads=model_runner.model_config.get_num_kv_heads(
-                model_runner.tp_size
+                get_attention_tp_size()
             ),
         )
         self.max_context_len = model_runner.model_config.context_len
@@ -84,6 +85,10 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None
 
+        # Qwen2 models require higher flashinfer workspace size
+        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+            global_config.flashinfer_workspace_size = 512 * 1024 * 1024
+
         # Allocate buffers
         self.workspace_buffer = torch.empty(
             global_config.flashinfer_workspace_size,
@@ -143,7 +148,7 @@ class FlashInferAttnBackend(AttentionBackend):
         self.prefill_cuda_graph_metadata = {}
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        if forward_batch.forward_mode.is_decode():
+        if forward_batch.forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 forward_batch.req_pool_indices,
                 forward_batch.seq_lens,
@@ -234,7 +239,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             decode_wrappers = []
             for i in range(self.num_wrappers):
                 decode_wrappers.append(
@@ -303,7 +308,7 @@ class FlashInferAttnBackend(AttentionBackend):
         forward_mode: ForwardMode,
         spec_info: Optional[SpecInfo],
     ):
-        if forward_mode.is_decode():
+        if forward_mode.is_decode_or_idle():
             self.indices_updater_decode.update(
                 req_pool_indices[:bs],
                 seq_lens[:bs],
@@ -353,7 +358,9 @@ class FlashInferAttnBackend(AttentionBackend):
             if k is not None:
                 assert v is not None
                 if save_kv_cache:
-                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -362,6 +369,8 @@ class FlashInferAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 window_left=layer.sliding_window_size,
                 logits_soft_cap=logits_soft_cap,
+                k_scale=layer.k_scale,
+                v_scale=layer.v_scale,
             )
         else:
             o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
@@ -387,7 +396,9 @@ class FlashInferAttnBackend(AttentionBackend):
                 o, _ = merge_state(o1, s1, o2, s2)
 
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
@@ -412,13 +423,17 @@ class FlashInferAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
             forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
             sm_scale=layer.scaling,
             logits_soft_cap=layer.logit_cap,
+            k_scale=layer.k_scale,
+            v_scale=layer.v_scale,
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
@@ -439,10 +454,10 @@ class FlashInferIndicesUpdaterDecode:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
        self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
@@ -611,10 +626,10 @@ class FlashInferIndicesUpdaterPrefill:
     def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
         # Parse Constants
         self.num_qo_heads = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
         self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
-            model_runner.tp_size
+            get_attention_tp_size()
         )
         self.head_dim = model_runner.model_config.head_dim
         self.data_type = model_runner.kv_cache_dtype
sglang/srt/layers/attention/triton_backend.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Optional
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -28,12 +29,9 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
-
-
-
-        self.num_head = (
-            model_runner.model_config.num_attention_heads // model_runner.tp_size
-        )
+        self.num_head = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
 
         self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
         self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
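Both attention backends above replace the raw tensor-parallel divisor (`model_runner.tp_size`) with `get_attention_tp_size()` from the new `sglang/srt/layers/dp_attention.py` module added at the end of this diff. A minimal sketch of the resulting per-rank head arithmetic, using made-up model and parallel sizes rather than any real configuration:

```python
# Illustrative arithmetic only, not sglang code; all sizes are hypothetical.
# Per compute_dp_attention_world_info (see dp_attention.py later in this diff),
# get_attention_tp_size() returns tp_size // dp_size when DP attention is
# enabled, and plain tp_size otherwise (the old model_runner.tp_size behavior).
num_attention_heads = 32      # hypothetical model config
tp_size, dp_size = 8, 2       # hypothetical launch configuration

attn_tp_size = tp_size // dp_size                          # 4
heads_per_rank_new = num_attention_heads // attn_tp_size   # 8 with DP attention
heads_per_rank_old = num_attention_heads // tp_size        # 4 under the old divisor
print(attn_tp_size, heads_per_rank_new, heads_per_rank_old)
```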
sglang/srt/layers/attention/vision.py (new file)

@@ -0,0 +1,204 @@
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+from sglang.srt.distributed import parallel_state
+from sglang.srt.distributed import utils as dist_utils
+from sglang.srt.layers.attention.triton_ops.prefill_attention import (
+    context_attention_fwd,
+)
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.quantization import QuantizationConfig
+
+
+def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
+    if not interleaved:
+        x1, x2 = x.chunk(2, dim=-1)
+        return torch.cat((-x2, x1), dim=-1)
+    else:
+        x1, x2 = x[..., ::2], x[..., 1::2]
+        return rearrange(
+            torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
+        )
+
+
+def apply_rotary_emb_torch(
+    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
+) -> torch.Tensor:
+    """
+    x: (batch_size, seqlen, nheads, headdim)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
+    """
+    ro_dim = cos.shape[-1] * 2
+    assert ro_dim <= x.shape[-1]
+    cos = repeat(
+        cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    sin = repeat(
+        sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
+    )
+    return torch.cat(
+        [
+            x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
+            x[..., ro_dim:],
+        ],
+        dim=-1,
+    )
+
+
+def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
+    t_ = t.float()
+    cos = freqs.cos()
+    sin = freqs.sin()
+    output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
+    return output
+
+
+class VisionAttention(nn.Module):
+    """Multi-headed attention without any cache, mostly used for ViT."""
+
+    def __init__(
+        self,
+        embed_dim: int,
+        num_heads: int,
+        projection_size: int,
+        use_qkv_parallel: bool,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+
+        self.hidden_size_per_attention_head = dist_utils.divide(
+            projection_size, num_heads
+        )
+        self.num_attention_heads_per_partition = dist_utils.divide(
+            num_heads, world_size
+        )
+        # self.tp_size = get_tensor_model_parallel_world_size()
+        # num_heads = self.num_heads_per_partition
+        self.use_qkv_parallel = use_qkv_parallel
+        if use_qkv_parallel:
+            self.head_dim = embed_dim // num_heads
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=embed_dim,
+                head_size=self.head_dim,
+                total_num_heads=num_heads,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        else:
+            self.qkv_proj = ColumnParallelLinear(
+                input_size=embed_dim,
+                output_size=3 * projection_size,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
+        self.proj = RowParallelLinear(
+            input_size=embed_dim,
+            output_size=embed_dim,
+            quant_config=quant_config,
+            prefix=f"{prefix}.out_proj",
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        rotary_pos_emb: torch.Tensor = None,
+    ) -> torch.Tensor:
+        """
+        Input shape: [b, s, embed_dim]
+        Output shape: [s, b, num_heads * head_size]
+        """
+
+        bsz, s, _ = x.shape
+        if self.use_qkv_parallel:
+            # [b, s, embed_dim] --> [b, s, embed_dim]
+            qkv, _ = self.qkv_proj(x)
+            q, k, v = qkv.chunk(3, dim=-1)
+
+            # [b, s, embed_dim] --> [b * s, num_heads, head_size]
+            q, k, v = [
+                x.reshape(
+                    bsz * s, self.num_attention_heads_per_partition, -1
+                ).contiguous()
+                for x in (q, k, v)
+            ]
+        else:
+            # [b, s, embed_dim] --> [s, b, embed_dim]
+            x = rearrange(x, "b s ... -> s b ...")
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            qkv, _ = self.qkv_proj(x)
+            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            new_x_shape = qkv.size()[:-1] + (
+                self.num_attention_heads_per_partition,
+                3 * self.hidden_size_per_attention_head,
+            )
+            qkv = qkv.view(*new_x_shape)
+
+            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+
+            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            q, k, v = [
+                rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+            ]
+
+        if rotary_pos_emb is not None:
+            q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
+            k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
+
+        if self.use_qkv_parallel:
+            pass
+        else:
+            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
+
+        # [b * s, num_heads, head_size]
+        output = torch.empty_like(q)
+
+        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
+        max_seqlen = seq_lens.max().item()
+
+        context_attention_fwd(
+            q,
+            k,
+            v,
+            output,
+            cu_seqlens.cuda(),
+            seq_lens,
+            max_seqlen,
+            is_causal=False,
+        )
+
+        if self.use_qkv_parallel:
+
+            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            output, _ = self.proj(output)
+        else:
+            # [b * s, head, head_dim] --> [b, s, head, head_dim]
+            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
+
+            # [s, b, num_heads * head_size]
+            context_layer = rearrange(
+                context_layer, "b s h d -> s b (h d)"
+            ).contiguous()
+
+            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            output, _ = self.proj(context_layer)
+
+        output = output.view(bsz, s, -1)
+
+        return output
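For reference, a small usage sketch of the rotary helpers introduced in the new vision attention module above. It assumes the 0.4.1.post7 wheel (and the runtime dependencies its imports pull in) is installed; the shapes follow the `apply_rotary_emb_torch` docstring, and the frequency table is a made-up 10000-base schedule rather than anything specific to a model.

```python
import torch

# Assumes sglang 0.4.1.post7 and its runtime dependencies are importable.
from sglang.srt.layers.attention.vision import apply_rotary_emb_torch

b, s, nheads, headdim, rotary_dim = 1, 16, 4, 64, 32   # hypothetical sizes
x = torch.randn(b, s, nheads, headdim)                 # (batch, seqlen, nheads, headdim)

# cos/sin tables of shape (seqlen, rotary_dim / 2); any such table works here.
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(s).float(), inv_freq)

out = apply_rotary_emb_torch(x, freqs.cos(), freqs.sin())
print(out.shape)  # torch.Size([1, 16, 4, 64]); only the first rotary_dim dims are rotated
```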
sglang/srt/layers/dp_attention.py (new file)

@@ -0,0 +1,69 @@
+import torch
+
+from sglang.srt.distributed import GroupCoordinator, get_tp_group
+
+_ATTN_TP_GROUP = None
+_ATTN_TP_RANK = None
+_ATTN_TP_SIZE = None
+_DP_RANK = None
+_DP_SIZE = None
+
+
+def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
+    if not enable_dp_attention:
+        return tp_rank, tp_size, 0
+
+    attn_tp_size = tp_size // dp_size
+    dp_rank = tp_rank // attn_tp_size
+    attn_tp_rank = tp_rank % attn_tp_size
+    return attn_tp_rank, attn_tp_size, dp_rank
+
+
+def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+    global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
+
+    _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
+        enable_dp_attention, tp_rank, tp_size, dp_size
+    )
+    _DP_SIZE = dp_size
+
+    tp_group = get_tp_group()
+    _ATTN_TP_GROUP = GroupCoordinator(
+        [
+            list(range(head, head + _ATTN_TP_SIZE))
+            for head in range(0, tp_size, _ATTN_TP_SIZE)
+        ],
+        tp_rank,
+        torch.distributed.get_backend(tp_group.device_group),
+        False,
+        False,
+        False,
+        False,
+        False,
+        group_name="attention_tp",
+    )
+
+
+def get_attention_tp_group():
+    assert _ATTN_TP_GROUP is not None, "dp attention not initialized!"
+    return _ATTN_TP_GROUP
+
+
+def get_attention_tp_rank():
+    assert _ATTN_TP_RANK is not None, "dp attention not initialized!"
+    return _ATTN_TP_RANK
+
+
+def get_attention_tp_size():
+    assert _ATTN_TP_SIZE is not None, "dp attention not initialized!"
+    return _ATTN_TP_SIZE
+
+
+def get_attention_dp_rank():
+    assert _DP_RANK is not None, "dp attention not initialized!"
+    return _DP_RANK
+
+
+def get_attention_dp_size():
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+    return _DP_SIZE
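A worked example of the rank mapping implemented by `compute_dp_attention_world_info` above, reproducing its arithmetic standalone for a hypothetical `tp_size=8`, `dp_size=2` launch:

```python
# Standalone reproduction of the arithmetic in compute_dp_attention_world_info
# shown above; tp_size=8 and dp_size=2 are hypothetical values.
tp_size, dp_size = 8, 2
attn_tp_size = tp_size // dp_size   # 4 ranks per attention-TP group

for tp_rank in range(tp_size):
    dp_rank = tp_rank // attn_tp_size
    attn_tp_rank = tp_rank % attn_tp_size
    print(f"tp_rank={tp_rank} -> attn_tp_rank={attn_tp_rank}, dp_rank={dp_rank}")

# Ranks 0-3 form dp_rank 0 and ranks 4-7 form dp_rank 1, matching the
# GroupCoordinator rank groups [[0, 1, 2, 3], [4, 5, 6, 7]] that
# initialize_dp_attention builds above.
```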