sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ sglang/srt/layers/ep_moe/kernels.py
@@ -0,0 +1,349 @@
+import logging
+from typing import Optional
+
+import torch
+import triton
+import triton.language as tl
+
+logger = logging.getLogger(__name__)
+
+
+@triton.jit
+def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
+    expert = tl.program_id(0)
+    low = 0
+    high = num_toks - 1
+    target_location = -1
+    while low <= high:
+        mid = (low + high) // 2
+
+        if tl.load(reorder_topk_ids + mid) > expert:
+            high = mid - 1
+        else:
+            low = mid + 1
+            target_location = mid
+    tl.store(seg_indptr + expert + 1, target_location + 1)
+
+
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int32)
+
+    compute_seg_indptr_triton_kernel[(num_experts,)](
+        reorder_topk_ids, seg_indptr, topk_ids.numel()
+    )
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), BLOCK_SIZE
+    )
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
+@triton.jit
+def pre_reorder_triton_kernel(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+    for idx in range(topk):
+        expert_id = tl.load(topk_ids_ptr + idx)
+        if expert_id >= start_expert_id and expert_id <= end_expert_id:
+            if a1_scales_ptr is not None:
+                scale = 1.0 / tl.load(a1_scales_ptr + expert_id - start_expert_id)
+            else:
+                scale = 1.0
+
+            dst_idx = tl.load(src2dst_ptr + idx)
+            dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+            for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+                offset = start_offset + tl.arange(0, BLOCK_SIZE)
+                mask = offset < hidden_size
+                in_data = tl.load(src_ptr + offset, mask=mask).to(tl.float32)
+                out_data = (in_data * scale).to(OutDtype)
+                tl.store(dst_ptr + offset, out_data, mask=mask)
+
+
+@triton.jit
+def silu_and_mul_triton_kernel(
+    gateup_output,
+    down_input,
+    hidden_size,
+    reorder_topk_ids,
+    scales,
+    start_expert_id,
+    end_expert_id,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = gateup_output.dtype.element_ty
+    OutDtype = down_input.dtype.element_ty
+
+    half_hidden_size = hidden_size // 2
+
+    pid = tl.program_id(0)
+    expert_id = tl.load(reorder_topk_ids + pid)
+    if expert_id >= start_expert_id and expert_id <= end_expert_id:
+        gateup_output_ptr = gateup_output + pid * hidden_size
+        gate_output_ptr = gateup_output_ptr
+        up_output_ptr = gateup_output_ptr + half_hidden_size
+        down_input_ptr = down_input + pid * half_hidden_size
+
+        if scales is not None:
+            scale = tl.load(scales + expert_id - start_expert_id)
+            scale = (1 / scale).to(InDtype)
+        else:
+            scale = 1
+
+        for start_offset in tl.range(0, half_hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < half_hidden_size
+
+            gate_output = tl.load(gate_output_ptr + offset, mask=mask).to(tl.float32)
+            up_output = tl.load(up_output_ptr + offset, mask=mask)
+
+            # silu & mul & quantize
+            gate_output = gate_output * tl.sigmoid(gate_output)
+            gate_output = gate_output.to(InDtype)
+
+            silu_mul_output = gate_output * up_output * scale
+            silu_mul_output = silu_mul_output.to(OutDtype)
+            tl.store(down_input_ptr + offset, silu_mul_output, mask=mask)
+
+
+@triton.jit
+def post_reorder_triton_kernel(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    start_expert_id,
+    end_expert_id,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    computed = False
+    store_ptr = output_ptr + src_idx * hidden_size
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            expert_id = tl.load(topk_ids_ptr + idx)
+            if expert_id >= start_expert_id and expert_id <= end_expert_id:
+                computed = True
+                dst_idx = tl.load(src2dst_ptr + idx)
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+    if computed == False:
+        for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+            offset = start_offset + tl.arange(0, BLOCK_SIZE)
+            mask = offset < hidden_size
+            tl.store(
+                store_ptr + offset, tl.zeros([BLOCK_SIZE], dtype=InDtype), mask=mask
+            )
+
+
+@triton.jit
+def compute_m_range(
+    pid,
+    batch_size,
+    seg_indptr,
+    weight_indices,
+    m_num_tiles_indptr,
+    BLOCK_SIZE_M: tl.constexpr,
+):
+    idx = 0
+    for bs in range(batch_size):
+        tiles = tl.load(m_num_tiles_indptr + bs)
+        if pid >= tiles:
+            idx = bs
+
+    idx_start = tl.load(m_num_tiles_indptr + idx)
+
+    m_range_start = tl.load(seg_indptr + idx) + (pid - idx_start) * BLOCK_SIZE_M
+    m_range_end = min(tl.load(seg_indptr + idx + 1), m_range_start + BLOCK_SIZE_M)
+    expert_id = tl.load(weight_indices + idx)
+    return m_range_start, m_range_end, expert_id
+
+
+@triton.jit
+def grouped_gemm_triton_kernel(
+    a,
+    b,
+    c,
+    batch_size,
+    N,
+    K,
+    seg_indptr,
+    weight_indices,
+    m_num_tiles_indptr,
+    use_fp8_w8a8,
+    scale_a,
+    scale_b,
+    a_stride_0: tl.constexpr,
+    b_stride_0: tl.constexpr,
+    b_stride_1: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+):
+    c_dtype = c.dtype.element_ty
+
+    pid_m = tl.program_id(0)
+    pid_n = tl.program_id(1)
+    total_m_block = tl.load(m_num_tiles_indptr + batch_size)
+    if pid_m >= total_m_block:
+        return
+
+    m_range_start, m_range_end, expert_id = compute_m_range(
+        pid_m, batch_size, seg_indptr, weight_indices, m_num_tiles_indptr, BLOCK_SIZE_M
+    )
+    if m_range_end - m_range_start == 0:
+        return
+
+    n_range_start = pid_n * BLOCK_SIZE_N
+    n_range_end = min(n_range_start + BLOCK_SIZE_N, N)
+
+    offs_am = tl.arange(0, BLOCK_SIZE_M)
+    offs_bn = tl.arange(0, BLOCK_SIZE_N)
+
+    offs_am = tl.where(offs_am < m_range_end - m_range_start, offs_am, 0)
+    offs_bn = tl.where(offs_bn < n_range_end - n_range_start, offs_bn, 0)
+    offs_am = tl.max_contiguous(tl.multiple_of(offs_am, BLOCK_SIZE_M), BLOCK_SIZE_M)
+    offs_bn = tl.max_contiguous(tl.multiple_of(offs_bn, BLOCK_SIZE_N), BLOCK_SIZE_N)
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+
+    a_ptr = a + (m_range_start + offs_am[:, None]) * a_stride_0 + offs_k[None, :]
+    b_ptr = b + (
+        (expert_id * b_stride_0)
+        + (n_range_start + offs_bn[:, None]) * b_stride_1
+        + offs_k[None, :]
+    )
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        a_tile = tl.load(
+            a_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
+        )
+        b_tile = tl.load(
+            b_ptr, mask=offs_k[None, :] < (K - k * BLOCK_SIZE_K), other=0.0
+        )
+        accumulator = tl.dot(a_tile, b_tile.T, accumulator)
+        a_ptr += BLOCK_SIZE_K
+        b_ptr += BLOCK_SIZE_K
+
+    if use_fp8_w8a8:
+        scale_a_value = tl.load(scale_a + expert_id)
+        scale_b_value = tl.load(scale_b + expert_id)
+        accumulator *= scale_a_value * scale_b_value
+    c_tile = accumulator.to(c_dtype)
+
+    offs_cm = m_range_start + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = n_range_start + tl.arange(0, BLOCK_SIZE_N)
+    c_ptr = c + offs_cm[:, None] * N + offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < m_range_end) & (offs_cn[None, :] < n_range_end)
+    tl.store(c_ptr, c_tile, mask=c_mask)
+
+
+@triton.jit
+def compute_m_num_tiles_indptr(
+    m_num_tiles_indptr, seg_indptr, batch_size: tl.constexpr, BLOCK_SIZE_M: tl.constexpr
+):
+    for bs in range(batch_size):
+        m = tl.load(seg_indptr + bs + 1) - tl.load(seg_indptr + bs)
+        cur_num_tiles = tl.cdiv(m, BLOCK_SIZE_M)
+        pre_num_tiles = tl.load(m_num_tiles_indptr + bs)
+        tl.store(m_num_tiles_indptr + bs + 1, pre_num_tiles + cur_num_tiles)
+
+
+def grouped_gemm_triton(
+    a: torch.Tensor,
+    b: torch.Tensor,
+    c: torch.Tensor,
+    batch_size: int,
+    weight_column_major: bool,
+    seg_indptr: Optional[torch.Tensor] = None,
+    weight_indices: Optional[torch.Tensor] = None,
+    use_fp8_w8a8: bool = False,
+    scale_a: torch.Tensor = None,
+    scale_b: torch.Tensor = None,
+):
+    assert weight_column_major == True  # TODO: more
+    if use_fp8_w8a8:
+        assert scale_a is not None and scale_b is not None
+
+    config = {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+    }
+
+    m_num_tiles_indptr = torch.zeros(batch_size + 1, device=a.device, dtype=torch.int64)
+    compute_m_num_tiles_indptr[(1,)](
+        m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
+    )
+
+    grid = lambda META: (
+        triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
+        triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
+    )
+
+    grouped_gemm_triton_kernel[grid](
+        a,
+        b,
+        c,
+        batch_size,
+        b.size(1),
+        b.size(2),
+        seg_indptr,
+        weight_indices,
+        m_num_tiles_indptr,
+        use_fp8_w8a8,
+        scale_a,
+        scale_b,
+        a.stride(0),
+        b.stride(0),
+        b.stride(1),
+        **config,
+    )
+    return c
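
The new file above supplies the data-movement and GEMM primitives for the expert-parallel MoE path added in 0.4.0: run_moe_ep_preproess stably sorts the flattened (token, expert) pairs and builds the seg_indptr / src2dst scatter metadata; pre_reorder_triton_kernel copies (and optionally rescales) each routed token into an expert-contiguous buffer; grouped_gemm_triton runs one GEMM per expert segment against that expert's weight slice; silu_and_mul_triton_kernel fuses SiLU(gate) * up on the two halves of each row; and post_reorder_triton_kernel gathers the per-expert outputs back into token order, weighted by the router scores. As a concrete check of the metadata: for topk_ids = [[1, 0], [1, 2]] and num_experts = 4, the stable sort yields reorder_topk_ids = [0, 1, 1, 2], src2dst = [1, 0, 2, 3], and seg_indptr = [0, 1, 3, 4, 4].

The sketch below is not part of the diff. It is a minimal illustration of how these pieces compose, assuming all experts are local to one rank, no fp8 activation/weight scales, and invented sizes (num_tokens, hidden, inter, w13, w2 are placeholder names). The real consumer is presumably the also-new sglang/srt/layers/ep_moe/layer.py, which additionally handles expert partitioning across ranks and quantization.

import torch

from sglang.srt.layers.ep_moe.kernels import (
    grouped_gemm_triton,
    post_reorder_triton_kernel,
    pre_reorder_triton_kernel,
    run_moe_ep_preproess,
    silu_and_mul_triton_kernel,
)

num_tokens, hidden, inter = 16, 256, 512  # placeholder sizes
num_experts, topk = 4, 2
start_expert_id, end_expert_id = 0, num_experts - 1  # all experts local here

x = torch.randn(num_tokens, hidden, device="cuda", dtype=torch.bfloat16)
router_logits = torch.randn(num_tokens, num_experts, device="cuda")
topk_weights, topk_ids = torch.topk(router_logits.softmax(dim=-1), topk, dim=-1)

# 1) Sort (token, expert) pairs by expert; build segment and scatter indices.
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(topk_ids, num_experts)

# 2) Scatter activations into expert-contiguous order
#    (a1_scales_ptr=None: no activation quantization in this sketch).
gateup_input = torch.empty(num_tokens * topk, hidden, device="cuda", dtype=x.dtype)
pre_reorder_triton_kernel[(num_tokens,)](
    x, gateup_input, src2dst, topk_ids, None,
    start_expert_id, end_expert_id, topk, hidden, BLOCK_SIZE=512,
)

# 3) Grouped GEMM against per-expert fused gate/up weights of shape (E, 2*inter, hidden).
w13 = torch.randn(num_experts, 2 * inter, hidden, device="cuda", dtype=x.dtype)
w2 = torch.randn(num_experts, hidden, inter, device="cuda", dtype=x.dtype)
weight_indices = torch.arange(num_experts, device="cuda", dtype=torch.int64)
gateup_output = torch.empty(num_tokens * topk, 2 * inter, device="cuda", dtype=x.dtype)
grouped_gemm_triton(
    gateup_input, w13, gateup_output, num_experts, True,
    seg_indptr=seg_indptr, weight_indices=weight_indices,
)

# 4) Fused SiLU(gate) * up across the two halves of each gateup row.
down_input = torch.empty(num_tokens * topk, inter, device="cuda", dtype=x.dtype)
silu_and_mul_triton_kernel[(num_tokens * topk,)](
    gateup_output, down_input, 2 * inter, reorder_topk_ids, None,
    start_expert_id, end_expert_id, BLOCK_SIZE=512,
)

# 5) Down projection, then gather back to token order weighted by router scores.
down_output = torch.empty(num_tokens * topk, hidden, device="cuda", dtype=x.dtype)
grouped_gemm_triton(
    down_input, w2, down_output, num_experts, True,
    seg_indptr=seg_indptr, weight_indices=weight_indices,
)
output = torch.empty(num_tokens, hidden, device="cuda", dtype=x.dtype)
post_reorder_triton_kernel[(num_tokens,)](
    down_output, output, src2dst, topk_ids, topk_weights,
    start_expert_id, end_expert_id, topk, hidden, BLOCK_SIZE=512,
)

Note that grouped_gemm_triton expects column-major expert weights of shape (num_experts, N, K) and a preallocated output c, and with use_fp8_w8a8=True it additionally requires per-expert scale_a and scale_b tensors, as the kernel source above shows.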