sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +14 -1
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +27 -15
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +14 -13
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +30 -25
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +9 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
- sglang/srt/managers/scheduler.py +39 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +94 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +4 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +31 -10
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +4 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/linear.py
CHANGED
@@ -235,8 +235,9 @@ class ReplicatedLinear(LinearBase):
                 loaded_weight = loaded_weight[:1]
             else:
                 raise ValueError(f"{loaded_weight} are not all equal")
-
-        assert param.size() == loaded_weight.size()
+        assert (
+            param.size() == loaded_weight.size()
+        ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
         param.data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
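
The change attaches an explanatory message to the size check in ReplicatedLinear.weight_loader. A minimal sketch of the failure it surfaces, using hypothetical tensors that are not taken from sglang itself:

import torch

# Hypothetical shapes, for illustration only.
param = torch.empty(1)          # e.g. a per-tensor quantization scale parameter
loaded_weight = torch.empty(4)  # a checkpoint tensor that was not reduced to one element

assert (
    param.size() == loaded_weight.size()
), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
# AssertionError: Loading weight error: param: torch.Size([1]), loaded_weight: torch.Size([4])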

sglang/srt/layers/logits_processor.py
CHANGED
@@ -61,7 +61,7 @@ class LogitsProcessorOutput:
     hidden_states: Optional[torch.Tensor] = None
 
     ## Part 2: This part will be assigned in python/sglang/srt/layers/sampler.py::Sampler
-    # The log probs of output tokens
+    # The log probs of output tokens. If RETURN_ORIGINAL_LOGPROB = True, these are the log probs before applying temperature; if False, the log probs after applying temperature.
     next_token_logprobs: Optional[torch.Tensor] = None
     # The logprobs and ids of the top-k tokens in output positions. shape: [#seq, k]
     next_token_top_logprobs_val: Optional[List] = None

sglang/srt/layers/moe/cutlass_w4a8_moe.py
CHANGED
@@ -91,18 +91,10 @@ def cutlass_w4a8_moe(
     assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
     assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
     assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
-    assert (
-        w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
-        and w1_scale.shape[2] == w1_q.shape[1] * 4
-    ), "W1 scale shape mismatch"
-    assert (
-        w2_scale.shape[1] == w2_q.shape[2] * 2 / 512
-        and w2_scale.shape[2] == w2_q.shape[1] * 4
-    ), "W2 scale shape mismatch"
 
     assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
     assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
-    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
     assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
     num_experts = w1_q.size(0)
     m = a.size(0)
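
The removed checks tied the scale tensor shapes to the packed int4 weight shapes by a fixed ratio. A small worked example with hypothetical shapes, using the relation exactly as it appeared in the removed asserts:

# Hypothetical shapes, for illustration only.
E, N, K = 8, 512, 3584  # num experts, w1_q.shape[1], w1_q.shape[2]

# Removed check: w1_scale.shape[1] == w1_q.shape[2] * 2 / 512
#                and w1_scale.shape[2] == w1_q.shape[1] * 4
expected_w1_scale_shape = (E, K * 2 // 512, N * 4)
print(expected_w1_scale_shape)  # (8, 14, 2048)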

sglang/srt/layers/moe/ep_moe/kernels.py
CHANGED
@@ -1362,3 +1362,77 @@ def moe_ep_deepgemm_preprocess(
         gateup_input,
         gateup_input_scale,
     )
+
+
+@triton.jit
+def compute_identity_kernel(
+    top_k,
+    hidden_states_ptr,
+    expert_scales_ptr,
+    num_tokens,
+    output_ptr,
+    hidden_dim,
+    scales_stride,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(0)
+
+    batch_id = pid // (hidden_dim // BLOCK_SIZE)
+    dim_offset = pid % (hidden_dim // BLOCK_SIZE) * BLOCK_SIZE
+
+    if batch_id >= num_tokens or dim_offset >= hidden_dim:
+        return
+
+    h = tl.load(
+        hidden_states_ptr
+        + batch_id * hidden_dim
+        + dim_offset
+        + tl.arange(0, BLOCK_SIZE),
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+    result = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for i in range(top_k):
+        scale = tl.load(expert_scales_ptr + batch_id * scales_stride + i)
+        result += h * scale
+
+    tl.store(
+        output_ptr + batch_id * hidden_dim + dim_offset + tl.arange(0, BLOCK_SIZE),
+        result,
+        mask=(dim_offset + tl.arange(0, BLOCK_SIZE)) < hidden_dim,
+    )
+
+
+def zero_experts_compute_triton(
+    expert_indices, expert_scales, num_experts, zero_expert_type, hidden_states
+):
+    N = expert_indices.numel()
+    top_k = expert_indices.size(-1)
+    grid = lambda meta: (triton.cdiv(N, meta["BLOCK_SIZE"]),)
+
+    if zero_expert_type == "identity":
+        zero_expert_mask = expert_indices < num_experts
+        zero_expert_scales = expert_scales.clone()
+        zero_expert_scales[zero_expert_mask] = 0.0
+
+    normal_expert_mask = expert_indices >= num_experts
+    expert_indices[normal_expert_mask] = 0
+    expert_scales[normal_expert_mask] = 0.0
+
+    output = torch.zeros_like(hidden_states).to(hidden_states.device)
+    hidden_dim = hidden_states.size(-1)
+    num_tokens = hidden_states.size(0)
+
+    grid = lambda meta: (num_tokens * (hidden_dim // meta["BLOCK_SIZE"]),)
+    compute_identity_kernel[grid](
+        top_k,
+        hidden_states,
+        zero_expert_scales,
+        num_tokens,
+        output,
+        hidden_dim,
+        zero_expert_scales.stride(0),
+        BLOCK_SIZE=256,
+    )
+
+    return output
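
A usage sketch for the new zero_experts_compute_triton helper. The call signature is taken from the diff above; the shapes, dtypes, and CUDA device are assumptions made for illustration:

import torch
from sglang.srt.layers.moe.ep_moe.kernels import zero_experts_compute_triton

num_tokens, top_k, hidden_dim = 4, 2, 512
num_experts = 8  # ids >= num_experts denote "zero" (identity) experts

expert_indices = torch.randint(0, num_experts + 2, (num_tokens, top_k), device="cuda")
expert_scales = torch.rand(num_tokens, top_k, device="cuda", dtype=torch.float32)
hidden_states = torch.randn(num_tokens, hidden_dim, device="cuda", dtype=torch.float32)

# Returns hidden_states weighted by the per-token sum of the zero-expert scales, and
# zeroes the zero-expert entries of expert_indices / expert_scales in place so that the
# regular expert computation ignores them.
identity_out = zero_experts_compute_triton(
    expert_indices, expert_scales, num_experts, "identity", hidden_states
)
print(identity_out.shape)  # torch.Size([4, 512])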

sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -114,9 +114,6 @@ class EPMoE(FusedMoE):
             with_bias=with_bias,
         )
 
-        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
-
         self.intermediate_size = intermediate_size
 
         if isinstance(quant_config, Fp8Config):
@@ -232,7 +229,7 @@ class EPMoE(FusedMoE):
             (
                 _cast_to_e8m0_with_rounding_up(gateup_input_scale)
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                     gateup_input_scale
                 )
             ),
@@ -289,9 +286,7 @@ class EPMoE(FusedMoE):
             (
                 down_input_scale
                 if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.
-                    down_input_scale
-                )
+                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
             ),
         )
         down_output = torch.empty(
@@ -746,19 +741,25 @@ class DeepEPMoE(EPMoE):
         hidden_states = torch_npu.npu_grouped_matmul(
             x=[hidden_states],
             weight=[self.w13_weight],
-            scale=[self.w13_weight_scale.to(output_dtype)],
-            per_token_scale=[pertoken_scale],
             split_item=2,
             group_list_type=group_list_type,
             group_type=0,
             group_list=seg_indptr,
-            output_dtype=
+            output_dtype=torch.int32,
         )[0]
 
         # act_fn: swiglu
-        hidden_states = torch_npu.
-
-
+        hidden_states, swiglu_out_scale = torch_npu.npu_dequant_swiglu_quant(
+            x=hidden_states,
+            weight_scale=self.w13_weight_scale.to(torch.float32),
+            activation_scale=pertoken_scale,
+            bias=None,
+            quant_scale=None,
+            quant_offset=None,
+            group_index=seg_indptr,
+            activate_left=True,
+            quant_mode=1,
+        )
 
         # gmm2: down_proj
         hidden_states = torch_npu.npu_grouped_matmul(

sglang/srt/layers/moe/fused_moe_triton/__init__.py
CHANGED
@@ -1,16 +1,18 @@
 from contextlib import contextmanager
 from typing import Any, Dict, Optional
 
-from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
-    fused_experts,
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+from sglang.srt.layers.moe.fused_moe_triton.fused_moe_triton_config import (
     get_config_file_name,
-    moe_align_block_size,
     try_get_optimal_moe_config,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
     FusedMoE,
     FusedMoeWeightScaleSupported,
 )
+from sglang.srt.layers.moe.fused_moe_triton.moe_align_block_size import (
+    moe_align_block_size,
+)
 
 _config: Optional[Dict[str, Any]] = None
 
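
According to the file summary above, fused_moe.py is split into fused_moe_triton_config.py, fused_moe_triton_kernels.py, and moe_align_block_size.py, while the package root keeps re-exporting the same names. Existing call sites that import from the package root should therefore be unaffected; a sketch using only the names re-exported above:

from sglang.srt.layers.moe.fused_moe_triton import (
    fused_experts,
    get_config_file_name,
    moe_align_block_size,
    try_get_optimal_moe_config,
)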

sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 256,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 5
+    }
+}
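
These JSON files are per-device tuning tables for the fused MoE Triton kernel, keyed by the token count M, with the expert count E, dimension N, and dtype encoded in the filename. A minimal sketch of how such a table could be consulted; this is illustrative only and is not sglang's actual lookup code (which sits behind try_get_optimal_moe_config):

import json

def pick_config(table: dict, m: int) -> dict:
    # Choose the tuned entry whose key is closest to the requested M.
    best = min((int(k) for k in table), key=lambda k: abs(k - m))
    return table[str(best)]

with open("E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json") as f:
    table = json.load(f)

print(pick_config(table, 200))  # nearest key is "256"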