sglang 0.4.4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/managers/cache_controller.py +2 -0
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/schedule_batch.py +1 -1
- sglang/srt/managers/scheduler.py +52 -18
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +9 -1
- sglang/srt/mem_cache/memory_pool.py +4 -1
- sglang/srt/model_executor/cuda_graph_runner.py +59 -16
- sglang/srt/model_executor/forward_batch_info.py +13 -4
- sglang/srt/models/deepseek_v2.py +180 -177
- sglang/srt/models/grok.py +374 -119
- sglang/srt/openai_api/adapter.py +22 -20
- sglang/srt/server_args.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +24 -22
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/function_call_parser.py

@@ -318,6 +318,10 @@ class Qwen25Detector(BaseFormatDetector):
         self.bot_token = "<tool_call>"
         self.eot_token = "</tool_call>"
 
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Qwen 2.5 format tool call."""
+        return self.bot_token in text
+
     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
         """
         One-time parsing: Detects and parses tool calls in the provided text.
@@ -352,6 +356,10 @@ class MistralDetector(BaseFormatDetector):
         self.bot_token = "[TOOL_CALLS] ["
         self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
 
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Mistral format tool call."""
+        return self.bot_token in text
+
     def _clean_text(self, text: str) -> str:
         """
         clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
@@ -397,12 +405,21 @@ class Llama32Detector(BaseFormatDetector):
         super().__init__()
         self.bot_token = "<|python_tag|>"
 
+    def has_tool_call(self, text: str) -> bool:
+        """Check if the text contains a Llama 3.2 format tool call."""
+        # depending on the prompt format the Llama model may or may not
+        # prefix the output with the <|python_tag|> token
+        return "<|python_tag|>" in text or text.startswith("{")
+
     def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
         """Parse function calls from text, handling multiple JSON objects."""
-        if "<|python_tag|>" not in text:
+        if "<|python_tag|>" not in text and not text.startswith("{"):
             return []
 
-        _, action_text = text.split("<|python_tag|>")
+        if "<|python_tag|>" in text:
+            _, action_text = text.split("<|python_tag|>")
+        else:
+            action_text = text
 
         # Split by semicolon and process each part
         json_parts = [part.strip() for part in action_text.split(";") if part.strip()]
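The behavioral effect of the Llama32Detector change is easiest to see in isolation. The snippet below is a standalone paraphrase of the new gating condition from this hunk, not an import from sglang: output that starts with a bare JSON object is now treated as a candidate tool call even when the <|python_tag|> prefix is absent.

def looks_like_llama32_tool_call(text: str) -> bool:
    # same condition as the new has_tool_call / detect_and_parse guard above
    return "<|python_tag|>" in text or text.startswith("{")

assert looks_like_llama32_tool_call('<|python_tag|>{"name": "get_weather"}')
assert looks_like_llama32_tool_call('{"name": "get_weather", "arguments": {}}')
assert not looks_like_llama32_tool_call("The weather is sunny today.")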
@@ -501,6 +518,20 @@ class FunctionCallParser:
         self.multi_format_parser = MultiFormatParser(detectors)
         self.tools = tools
 
+    def has_tool_call(self, text: str) -> bool:
+        """
+        Check if the given text contains a tool call in the format supported by this parser.
+        This delegates to the detector's implementation.
+
+        :param text: The text to check for tool calls
+        :return: True if the text contains a tool call, False otherwise
+        """
+        # Check all detectors in the multi_format_parser
+        for detector in self.multi_format_parser.detectors:
+            if detector.has_tool_call(text):
+                return True
+        return False
+
     def parse_non_stream(self, full_text: str):
         """
         Non-streaming call: one-time parsing
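The new FunctionCallParser.has_tool_call lets callers skip tool-call parsing for plain-text outputs. A minimal usage sketch, assuming a parser instance has already been constructed elsewhere (only has_tool_call and parse_non_stream appear in this diff; how the parser and its tools are built is not shown here):

from sglang.srt.function_call_parser import FunctionCallParser

def maybe_parse_tool_calls(parser: FunctionCallParser, model_output: str):
    # Cheap containment check across all registered detectors first.
    if not parser.has_tool_call(model_output):
        return None  # plain text: skip tool-call parsing entirely
    # One-time (non-streaming) parse of the full generation.
    return parser.parse_non_stream(model_output)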
sglang/srt/layers/dp_attention.py

@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import functools
+import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Union
 
 import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -86,6 +90,27 @@ def get_attention_dp_size():
     return _DP_SIZE
 
 
+@contextmanager
+def disable_dp_size():
+    """Patch the tp group temporarily until this function ends.
+
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size
+
+
 def get_dp_local_info(forward_batch: ForwardBatch):
     dp_rank = get_attention_dp_rank()
 
@@ -159,7 +184,8 @@ def dp_gather(
         layer_id != "embedding" or get_attention_tp_rank() == 0
     ):
         assert (
-            global_tokens.data_ptr() != local_tokens.data_ptr()
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
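The strengthened aliasing assertion compares untyped_storage().data_ptr() instead of data_ptr(). The sketch below (not from the diff; requires a PyTorch version that provides Tensor.untyped_storage) shows why: an offset view of a tensor has a different data_ptr() but still aliases the same storage, which is exactly the case the copy must guard against.

import torch

base = torch.zeros(16)
view = base[4:]  # shares memory with `base`, but starts at an offset
print(view.data_ptr() == base.data_ptr())                                      # False
print(view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr())  # True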
@@ -174,8 +200,9 @@ def dp_gather(
         torch.ops.sglang.inplace_all_reduce(
             global_tokens, group_name=get_tp_group().unique_name
         )
+
     else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
 
 
 def dp_scatter(
@@ -186,6 +213,7 @@ def dp_scatter(
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
     # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
     local_tokens.fill_(0)
     assert local_tokens.is_contiguous()
    assert global_tokens.is_contiguous()
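A usage sketch for the new disable_dp_size() context manager, following the intent stated in its docstring (draft workers of speculative decoding running with a different tp degree). The draft_worker object here is a placeholder, and dp attention must already be initialized or the assertion inside the context manager fires:

from sglang.srt.layers.dp_attention import disable_dp_size, get_attention_dp_size

def run_draft_forward(draft_worker):
    # Temporarily patch the module-level DP size to 1 for the draft model.
    with disable_dp_size():
        assert get_attention_dp_size() == 1
        draft_worker.forward()
    # On exit, the previous DP size is restored even if forward() raised.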
sglang/srt/layers/elementwise.py (new file)

@@ -0,0 +1,411 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+fused_softcap_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
+    ],
+    key=["n_ele"],
+)
+
+
+@triton.jit
+def fused_softcap_kernel(
+    output_ptr,
+    input_ptr,
+    n_ele,
+    softcap_const: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_ele
+    x = tl.load(input_ptr + offsets, mask=mask)
+    fx = x.to(tl.float32)
+    fxs = fx / softcap_const
+    exped = tl.exp(2 * fxs)
+    top = exped - 1
+    bottom = exped + 1
+    output = top / bottom * softcap_const
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
+
+
+def fused_softcap(x, softcap_const, autotune=False):
+    output = torch.empty_like(x, dtype=torch.float32)
+    n_elements = output.numel()
+    if autotune:
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+        fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
+    else:
+        fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
+            output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
+        )
+    return output
+
+
+# cast to float + softcap
+class Softcap:
+    def __init__(self, softcap_const: float):
+        self.softcap_const = softcap_const
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.is_cuda:
+            return self.forward_cuda(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
+
+    def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
+        return fused_softcap(x, self.softcap_const, autotune=autotune)
+
+
+rmsnorm_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
+    ],
+    key=["hidden_dim"],
+)
+
+
+@triton.jit
+def fused_dual_residual_rmsnorm_kernel(
+    output_ptr,
+    mid_ptr,
+    activ_ptr,
+    residual_ptr,
+    weight1_ptr,
+    weight2_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
+    w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a2r = r + (a / rms * w1).to(r.dtype)
+    tl.store(
+        mid_ptr + input_start + offsets,
+        a2r,
+        mask=mask,
+    )
+
+    a2r = a2r.to(tl.float32)
+    rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
+
+    w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
+    w2 = w2_.to(tl.float32)
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a2r / rms2 * w2,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
+    fused_dual_residual_rmsnorm_kernel
+)
+
+
+def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
+    assert len(x.shape) == 2
+    assert x.shape == residual.shape and x.dtype == residual.dtype
+    output, mid = torch.empty_like(x), torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    if autotune:
+        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
+            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
+        )
+    else:
+        config = {
+            "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+            "num_warps": max(
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            ),
+        }
+
+        fused_dual_residual_rmsnorm_kernel[(bs,)](
+            output,
+            mid,
+            x,
+            residual,
+            weight1,
+            weight2,
+            eps=eps,
+            hidden_dim=hidden_dim,
+            **config,
+        )
+
+    return output, mid
+
+
+@triton.jit
+def fused_rmsnorm_kernel(
+    output_ptr,
+    activ_ptr,
+    weight_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a_rms = a / rms * w1
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a_rms,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
+    assert len(x.shape) == 2
+    if inplace:
+        output = x
+    else:
+        output = torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_rmsnorm_kernel[(bs,)](
+        output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
+    )
+    return output
+
+
+class FusedDualResidualRMSNorm:
+    """
+    Fused implementation of
+        y = RMSNorm2(RMSNorm1(x) + residual))
+    """
+
+    def __init__(self, rmsnorm1, rmsnorm2) -> None:  # the one after rmsnorm1
+        self.rmsnorm1 = rmsnorm1
+        self.rmsnorm2 = rmsnorm2
+        self.variance_epsilon = self.rmsnorm1.variance_epsilon
+        assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
+        assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_flashinfer(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, residual: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_dual_residual_rmsnorm(
+            x,
+            residual,
+            self.rmsnorm1.weight,
+            self.rmsnorm2.weight,
+            self.variance_epsilon,
+            autotune=autotune,
+        )
+
+    def forward_flashinfer(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1(x)
+        residual = normed1 + residual
+        return self.rmsnorm2(residual), residual
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1.forward_native(x)
+        residual = normed1 + residual
+        return self.rmsnorm2.forward_native(residual), residual
+
+
+# gelu on first half of vector
+@triton.jit
+def gelu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # gelu
+    # cast down before mul to better match training?
+    gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
+    out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def gelu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+        ),
+    }
+
+    gelu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
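For reading the new kernels, their reference semantics in plain PyTorch are summarized below (a sketch that runs on CPU; the fused_* Triton paths themselves require a CUDA device and are not called here). softcap_reference mirrors Softcap.forward_native, and dual_residual_rmsnorm_reference mirrors FusedDualResidualRMSNorm.forward_native with a standard RMSNorm:

import torch

def softcap_reference(x: torch.Tensor, softcap_const: float) -> torch.Tensor:
    # tanh(x / c) * c in float32; the same value as the
    # (exp(2x/c) - 1) / (exp(2x/c) + 1) * c form computed by fused_softcap_kernel.
    return torch.tanh(x.float() / softcap_const) * softcap_const

def dual_residual_rmsnorm_reference(x, residual, weight1, weight2, eps):
    # mid = residual + RMSNorm1(x); out = RMSNorm2(mid), matching the
    # (output, mid) pair returned by fused_dual_residual_rmsnorm.
    def rmsnorm(t, w):
        rms = torch.sqrt(t.float().pow(2).mean(-1, keepdim=True) + eps)
        return (t.float() / rms * w.float()).to(t.dtype)
    mid = residual + rmsnorm(x, weight1)
    return rmsnorm(mid, weight2), mid

x = torch.randn(4, 8, dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(8, dtype=torch.float16)
out, mid = dual_residual_rmsnorm_reference(x, residual, weight, weight, eps=1e-6)
print(out.shape, mid.shape, softcap_reference(x, 30.0).dtype)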