sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
@@ -1,6 +1,7 @@
  from __future__ import annotations
 
- from functools import lru_cache
+ import math
+ from functools import lru_cache, wraps
  from typing import Optional, Tuple
 
  import torch
@@ -8,6 +9,13 @@ import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange
 
+ from sglang.srt.utils import is_cuda
+
+ _is_cuda = is_cuda()
+
+ if _is_cuda:
+     from sgl_kernel.flash_attn import flash_attn_varlen_func
+
  from sglang.srt.distributed import parallel_state
  from sglang.srt.distributed import utils as dist_utils
  from sglang.srt.layers.attention.triton_ops.prefill_attention import (
@@ -19,166 +27,31 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.layers.quantization import QuantizationConfig
- from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb, rotate_half
- from sglang.srt.utils import add_prefix
+ from sglang.srt.layers.rotary_embedding import apply_rotary_pos_emb
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
+ from sglang.srt.utils import add_prefix, logger
 
-
- class VisionAttention(nn.Module):
-     r"""
-     Multi-headed attention without any cache, mostly used for ViT.
+ ROTARY_EMBED_CLASSES = {
+     "normal": apply_rotary_pos_emb,
+ }
 
 
-     Args:
-         use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
-         use_context_forward (bool, default to True):
-             if ``True``, a flash_attn style attention will be applied
-             Otherwise, a full-sequence attention will be applied.
-         softmax_in_single_precision (bool, default to False):
-             if ``True``, the softmax will be performed in single-precision
-             Otherwise, it will be performed in half-precision
+ def execute_once(func):
+     has_run = None
 
-     """
+     @wraps(func)
+     def wrapper(*args, **kwargs):
+         nonlocal has_run
+         if not has_run:
+             func(*args, **kwargs)
+             has_run = True
 
-     def __init__(
-         self,
-         embed_dim: int,
-         num_heads: int,
-         projection_size: int,
-         use_qkv_parallel: bool,
-         quant_config: Optional[QuantizationConfig] = None,
-         dropout: float = 0.0,
-         use_context_forward: bool = True,
-         softmax_in_single_precision: bool = False,
-         flatten_batch: bool = False,
-         prefix: str = "",
-     ):
-         super().__init__()
-         self.use_context_forward = use_context_forward
-         world_size = parallel_state.get_tensor_model_parallel_world_size()
-         self.dropout = dropout
-         self.head_size = embed_dim // num_heads
-         self.hidden_size_per_attention_head = dist_utils.divide(
-             projection_size, num_heads
-         )
-         self.num_attention_heads_per_partition = dist_utils.divide(
-             num_heads, world_size
-         )
+     return wrapper
 
-         if self.use_context_forward:
-             self.qkv_backend = VisionTritonAttention()
-         else:
-             self.qkv_backend = VisionSdpaAttention(
-                 head_size=self.head_size,
-                 dropout=dropout,
-                 flatten_batch=flatten_batch,
-                 softmax_in_single_precision=softmax_in_single_precision,
-             )
 
-         self.use_qkv_parallel = use_qkv_parallel
-         if use_qkv_parallel:
-             self.qkv_proj = QKVParallelLinear(
-                 hidden_size=embed_dim,
-                 head_size=self.head_size,
-                 total_num_heads=num_heads,
-                 quant_config=quant_config,
-                 prefix=add_prefix("qkv_proj", prefix),
-             )
-         else:
-             self.qkv_proj = ColumnParallelLinear(
-                 input_size=embed_dim,
-                 output_size=3 * projection_size,
-                 quant_config=quant_config,
-                 prefix=add_prefix("qkv_proj", prefix),
-             )
-         self.proj = RowParallelLinear(
-             input_size=embed_dim,
-             output_size=embed_dim,
-             quant_config=quant_config,
-             prefix=add_prefix("proj", prefix),
-         )
-
-     def forward(
-         self,
-         x: torch.Tensor,
-         cu_seqlens: Optional[torch.Tensor] = None,
-         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
-         attention_mask: Optional[torch.Tensor] = None,
-     ) -> torch.Tensor:
-         r"""
-         Args:
-             x: [b, s, embed_dim]
-             cu_seqlens: [b]
-         Returns:
-             [s, b, head * head_size]
-         """
-         bsz, s, _ = x.shape
-         head = self.num_attention_heads_per_partition
-         if self.use_qkv_parallel:
-             # [b, s, embed_dim] --> [b, s, embed_dim]
-             qkv, _ = self.qkv_proj(x)
-             q, k, v = qkv.chunk(3, dim=-1)
-
-             # [b, s, embed_dim] --> [b * s, head, head_size]
-             q, k, v = [x.reshape(bsz * s, head, -1).contiguous() for x in (q, k, v)]
-         else:
-             # [b, s, embed_dim] --> [s, b, embed_dim]
-             x = rearrange(x, "b s ... -> s b ...")
-             # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
-             qkv, _ = self.qkv_proj(x)
-             # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
-             new_x_shape = qkv.size()[:-1] + (
-                 head,
-                 3 * self.hidden_size_per_attention_head,
-             )
-             qkv = qkv.view(*new_x_shape)
-
-             # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
-             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
-
-             # [s, b, head, head_size] --> [b, s, head, head_size]
-             q, k, v = [
-                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
-             ]
-
-         if position_embeddings is not None:
-             cos, sin = position_embeddings
-             original_shape = q.shape
-             # [total_tokens, head, head_size]
-             q = q.view(-1, head, self.head_size)
-             k = k.view(-1, head, self.head_size)
-
-             q, k = apply_rotary_pos_emb(q, k, cos, sin)
-
-             q = q.view(original_shape)
-             k = k.view(original_shape)
-
-         if self.use_qkv_parallel:
-             pass
-         else:
-             # [b, s, head, head_size] --> [b * s, head, head_size]
-             q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
-
-         output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)
-
-         if self.use_qkv_parallel:
-             # [b * s, h, head_size] --> [b, s, h * head_size]
-             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
-
-             # [b, s, h * head_size] --> [b, s, h * head_size]
-             output, _ = self.proj(output)
-         else:
-             # [b * s, h, head_size] --> [s, b, h * head_size]
-             context_layer = rearrange(
-                 output, "(b s) h d -> s b (h d)", b=bsz, s=s
-             ).contiguous()
-
-             # [s, b, h * head_size] --> [s, b, h * head_size]
-             output, _ = self.proj(context_layer)
-
-             # [s, b, h * head_size] --> [b, s, h * head_size]
-             output = output.view(bsz, s, -1)
-
-         return output
+ @execute_once
+ def info_once(message: str):
+     logger.info(message)
 
 
  class VisionSdpaAttention(nn.Module):
@@ -189,16 +62,22 @@ class VisionSdpaAttention(nn.Module):
 
      def __init__(
          self,
-         head_size: int,
+         head_dim: int,
+         num_heads: int,
+         num_kv_heads: int,
          dropout: float = 0.0,
          flatten_batch: bool = False,
          softmax_in_single_precision: bool = False,
+         **kwargs,
      ):
          super().__init__()
-         self.head_size = head_size
+         self.head_size = head_dim
+         self.num_heads = num_heads
+         self.num_kv_heads = num_kv_heads
          self.flatten_batch = flatten_batch
          self.softmax_in_single_precision = softmax_in_single_precision
          self.dropout = dropout
+         self.scale = 1.0 / math.sqrt(self.head_size)
 
      @staticmethod
      @lru_cache(maxsize=128)
@@ -212,7 +91,7 @@ class VisionSdpaAttention(nn.Module):
              flatten_batch: whether to flatten batch dimension
              cu_seqlens: tuple of cumulative sequence lengths
          Returns:
-             attention mask tensor
+             attention mask tensor of shape [b, 1, s, s] or [1, s, s]
          """
          if flatten_batch:
              mask = torch.zeros([1, s, s], dtype=torch.bool)
@@ -241,7 +120,7 @@ class VisionSdpaAttention(nn.Module):
          flatten_batch: bool = False,
      ) -> Optional[torch.Tensor]:
          r"""
-         Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
+         Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, s, s)`.
          Args:
              s: sequence length
              cu_seqlens: cumulative sequence lengths tensor. If not, returns an empty mask
@@ -264,6 +143,7 @@ class VisionSdpaAttention(nn.Module):
          bsz: int,
          cu_seqlens: Optional[torch.Tensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
+         **kwargs,
      ) -> torch.Tensor:
          r"""
          Args:
@@ -274,6 +154,8 @@ class VisionSdpaAttention(nn.Module):
          if self.flatten_batch:
              assert bsz == 1, "flatten_batch is True, bsz must be 1"
 
+         assert q.dim() == 3, q.shape
+
          s = q.shape[0] // bsz
 
          # [b, 1, s, s]
@@ -291,10 +173,10 @@ class VisionSdpaAttention(nn.Module):
          q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
 
          if self.softmax_in_single_precision:
-             scale = self.head_size**-0.5
-             k_transposed = rearrange(k, "b h s d -> b h d s")
-             attn_weights = torch.matmul(q, k_transposed) * scale
-             del k, k_transposed
+             k = rearrange(k, "b h s d -> b h d s")
+             attn_weights = torch.matmul(q, k) * self.scale
+             del k
+             # masking
              attention_mask = (~attention_mask) * torch.finfo(q.dtype).min
              attn_weights = attn_weights + attention_mask
              del attention_mask
@@ -332,6 +214,7 @@ class VisionTritonAttention(nn.Module):
 
      def __init__(
          self,
+         **kwargs,
      ):
          super().__init__()
 
@@ -340,8 +223,8 @@ class VisionTritonAttention(nn.Module):
          q: torch.Tensor,
          k: torch.Tensor,
          v: torch.Tensor,
-         _bsz: int,
          cu_seqlens: Optional[torch.Tensor],
+         **kwargs,
      ) -> torch.Tensor:
          r"""
          Args:
@@ -366,3 +249,247 @@ class VisionTritonAttention(nn.Module):
          )
 
          return output
+
+
+ class VisionFlash3Attention(nn.Module):
+     def __init__(
+         self,
+         **kwargs,
+     ):
+         if not _is_cuda:
+             raise Exception("VisionFlash3Attention is only available for cuda")
+         super().__init__()
+
+     def forward(
+         self,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         cu_seqlens: Optional[torch.Tensor],
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         r"""
+         Args:
+             cu_seqlens: [b]
+         Returns:
+             [b * s, h, head_size]
+         """
+         cu_seqlens = cu_seqlens.to(dtype=torch.int32).cuda()
+         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
+         max_seqlen = seq_lens.max().item()
+         output = flash_attn_varlen_func(
+             q,
+             k,
+             v,
+             cu_seqlens_q=cu_seqlens,
+             cu_seqlens_k=cu_seqlens,
+             max_seqlen_q=max_seqlen,
+             max_seqlen_k=max_seqlen,
+         )
+
+         return output
+
+
+ QKV_BACKEND_IMPL = {
+     "triton_attn": VisionTritonAttention,
+     "sdpa": VisionSdpaAttention,
+     "fa3": VisionFlash3Attention,
+ }
+
+
+ class VisionAttention(nn.Module):
+     r"""
+     Multi-headed attention without any cache, mostly used for multimodal transformers.
+
+
+     Args:
+         use_qkv_parallel (bool, optional): If True, use QKV-parallel attention.
+         softmax_in_single_precision (bool, default to False):
+             if ``True``, the softmax will be performed in single-precision
+             Otherwise, it will be performed in half-precision
+
+     """
+
+     def __init__(
+         self,
+         embed_dim: int,
+         num_heads: int,
+         projection_size: int,
+         use_qkv_parallel: bool,
+         qkv_backend: Optional[str] = None,
+         quant_config: Optional[QuantizationConfig] = None,
+         dropout: float = 0.0,
+         softmax_in_single_precision: bool = False,
+         flatten_batch: bool = False,
+         prefix: str = "",
+         proj_bias: bool = True,
+         **kwargs,
+     ):
+         super().__init__()
+         world_size = parallel_state.get_tensor_model_parallel_world_size()
+         self.dropout = dropout
+         self.head_size = embed_dim // num_heads
+         self.hidden_size_per_attention_head = dist_utils.divide(
+             projection_size, num_heads
+         )
+         self.num_attention_heads_per_partition = dist_utils.divide(
+             num_heads, world_size
+         )
+         self.num_attention_kv_heads_per_partition = dist_utils.divide(
+             num_heads, world_size
+         )
+
+         self.q_size = self.num_attention_heads_per_partition * self.head_size
+         self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size
+
+         if global_server_args_dict["mm_attention_backend"] is None:
+             if qkv_backend is None:
+                 qkv_backend = "sdpa"
+             info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
+         else:
+             qkv_backend = global_server_args_dict["mm_attention_backend"]
+
+         info_once(f"Using {qkv_backend} as multimodal attention backend.")
+
+         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
+             head_dim=self.head_size,
+             num_heads=self.num_attention_heads_per_partition,
+             num_kv_heads=self.num_attention_kv_heads_per_partition,
+             dropout=dropout,
+             flatten_batch=flatten_batch,
+             softmax_in_single_precision=softmax_in_single_precision,
+         )
+
+         self.use_qkv_parallel = use_qkv_parallel
+         if use_qkv_parallel:
+             self.qkv_proj = QKVParallelLinear(
+                 hidden_size=embed_dim,
+                 head_size=self.head_size,
+                 total_num_heads=num_heads,
+                 total_num_kv_heads=num_heads,
+                 quant_config=quant_config,
+                 prefix=add_prefix("qkv_proj", prefix),
+             )
+         else:
+             self.qkv_proj = ColumnParallelLinear(
+                 input_size=embed_dim,
+                 output_size=3 * projection_size,
+                 quant_config=quant_config,
+                 prefix=add_prefix("qkv_proj", prefix),
+             )
+         self.proj = RowParallelLinear(
+             input_size=embed_dim,
+             output_size=embed_dim,
+             bias=proj_bias,
+             quant_config=quant_config,
+             prefix=add_prefix("proj", prefix),
+         )
+
+     def forward(
+         self,
+         x: torch.Tensor,
+         cu_seqlens: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         **kwargs,
+     ) -> torch.Tensor:
+         r"""
+         Args:
+             x: [b, s, embed_dim]
+             cu_seqlens: [b]
+         Returns:
+             [s, b, head * head_size]
+         """
+         if x.dim() == 2:
+             x = x.unsqueeze(0)
+         assert x.dim() == 3, x.shape
+         bsz, s, _ = x.shape
+         head = self.num_attention_heads_per_partition
+         kv_head = self.num_attention_kv_heads_per_partition
+         if self.use_qkv_parallel:
+             # [b, s, embed_dim] --> [b, s, embed_dim]
+             qkv, _ = self.qkv_proj(x)
+
+             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+
+             # [b, s, embed_dim] --> [b * s, head, head_size]
+             q = q.reshape(bsz * s, head, -1).contiguous()
+             k = k.reshape(bsz * s, kv_head, -1).contiguous()
+             v = v.reshape(bsz * s, kv_head, -1).contiguous()
+         else:
+             # [b, s, embed_dim] --> [s, b, embed_dim]
+             x = rearrange(x, "b s ... -> s b ...")
+             # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
+             qkv, _ = self.qkv_proj(x)
+
+             # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
+             new_x_shape = qkv.size()[:-1] + (
+                 head,
+                 3 * self.hidden_size_per_attention_head,
+             )
+             qkv = qkv.view(*new_x_shape)
+
+             # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
+             q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)
+             # [s, b, head, head_size] --> [b, s, head, head_size]
+             q, k, v = [
+                 rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
+             ]
+
+         if position_embeddings is not None:
+             cos, sin = position_embeddings
+             original_shape = q.shape
+             # [total_tokens, head, head_size]
+             q = q.view(-1, head, self.head_size)
+             k = k.view(-1, head, self.head_size)
+
+             q, k = apply_rotary_pos_emb(q, k, cos, sin)
+
+             q = q.view(original_shape)
+             k = k.view(original_shape)
+
+         if q.dim() == 4:
+             # [b, s, head, head_size] --> [b * s, head, head_size]
+             q = rearrange(q, "b s ... -> (b s) ...")
+         if k.dim() == 4:
+             # [b, s, head, head_size] --> [b * s, head, head_size]
+             k = rearrange(k, "b s ... -> (b s) ...")
+         if v.dim() == 4:
+             # [b, s, head, head_size] --> [b * s, head, head_size]
+             v = rearrange(v, "b s ... -> (b s) ...")
+
+         assert q.dim() == 3, q.dim()
+         assert k.dim() == 3, k.dim()
+         assert v.dim() == 3, v.dim()
+
+         output = self.qkv_backend.forward(
+             q=q,
+             k=k,
+             v=v,
+             bsz=bsz,
+             cu_seqlens=cu_seqlens,
+             attention_mask=attention_mask,
+         )
+
+         assert output.dim() == 3, output.shape
+
+         if self.use_qkv_parallel:
+             # [b * s, h, head_size] --> [b, s, h * head_size]
+             output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)
+
+             # [b, s, h * head_size] --> [b, s, h * head_size]
+             output, _ = self.proj(output)
+         else:
+             # [b * s, h, head_size] --> [s, b, h * head_size]
+             context_layer = rearrange(
+                 output, "(b s) h d -> s b (h d)", b=bsz, s=s
+             ).contiguous()
+
+             # [s, b, h * head_size] --> [s, b, h * head_size]
+             output, _ = self.proj(context_layer)
+
+             # [s, b, h * head_size] --> [b, s, h * head_size]
+             output = output.view(bsz, s, -1)
+
+         return output
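
Note on the vision.py hunk above: the refactor replaces the old use_context_forward switch with a named backend registry (QKV_BACKEND_IMPL, mapping "triton_attn", "sdpa", and "fa3" to implementation classes) plus a log-once helper. Below is a minimal, self-contained sketch of that selection pattern for orientation only. The DummySdpa/DummyFa3 classes, the select_backend helper, and the plain server_args dict are stand-ins invented for illustration, not sglang APIs; the real code dispatches to VisionSdpaAttention, VisionTritonAttention, and VisionFlash3Attention and reads global_server_args_dict["mm_attention_backend"].

# Sketch of the backend-registry pattern added in this diff (stand-in names).
import logging
from functools import wraps
from typing import Optional

logger = logging.getLogger(__name__)


def execute_once(func):
    # Same shape as the decorator added in the diff: run the wrapped
    # function on the first call only, then turn later calls into no-ops.
    has_run = None

    @wraps(func)
    def wrapper(*args, **kwargs):
        nonlocal has_run
        if not has_run:
            func(*args, **kwargs)
            has_run = True

    return wrapper


@execute_once
def info_once(message: str) -> None:
    logger.info(message)


class DummySdpa:  # stand-in for VisionSdpaAttention
    def __init__(self, **kwargs):
        self.kwargs = kwargs


class DummyFa3:  # stand-in for VisionFlash3Attention
    def __init__(self, **kwargs):
        self.kwargs = kwargs


BACKENDS = {"sdpa": DummySdpa, "fa3": DummyFa3}  # stand-in for QKV_BACKEND_IMPL


def select_backend(server_args: dict, qkv_backend: Optional[str] = None, **impl_kwargs):
    # Mirrors VisionAttention.__init__: the server-level setting wins,
    # then the per-module argument, then an "sdpa" fallback, with a
    # one-time log message either way.
    if server_args.get("mm_attention_backend") is None:
        if qkv_backend is None:
            qkv_backend = "sdpa"
        info_once(f"Multimodal attention backend not set. Use {qkv_backend}.")
    else:
        qkv_backend = server_args["mm_attention_backend"]

    info_once(f"Using {qkv_backend} as multimodal attention backend.")
    return BACKENDS[qkv_backend](**impl_kwargs)


if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    backend = select_backend({}, None, head_dim=64, num_heads=16, num_kv_heads=16)
    print(type(backend).__name__)  # DummySdpa

Running the sketch with an empty server_args falls back to "sdpa" and logs the fallback exactly once, which mirrors how the refactored VisionAttention picks its qkv_backend when neither the server flag nor the caller specifies one.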