sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/dp_attention.py
CHANGED

```diff
@@ -1,6 +1,8 @@
 from __future__ import annotations
 
 import functools
+import logging
+from contextlib import contextmanager
 from typing import TYPE_CHECKING, Union
 
 import torch
@@ -14,6 +16,8 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_reduce,
 )
 
+logger = logging.getLogger(__name__)
+
 if TYPE_CHECKING:
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
@@ -34,7 +38,12 @@ def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_si
     return attn_tp_rank, attn_tp_size, dp_rank
 
 
-def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
+def initialize_dp_attention(
+    enable_dp_attention: bool,
+    tp_rank: int,
+    tp_size: int,
+    dp_size: int,
+):
     global _ATTN_TP_GROUP, _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK, _DP_SIZE
 
     from sglang.srt.layers.sampler import SYNC_TOKEN_IDS_ACROSS_TP
@@ -42,7 +51,11 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
     _ATTN_TP_RANK, _ATTN_TP_SIZE, _DP_RANK = compute_dp_attention_world_info(
         enable_dp_attention, tp_rank, tp_size, dp_size
     )
-    _DP_SIZE = dp_size
+
+    if enable_dp_attention:
+        _DP_SIZE = dp_size
+    else:
+        _DP_SIZE = 1
 
     tp_group = get_tp_group()
     _ATTN_TP_GROUP = GroupCoordinator(
@@ -50,7 +63,7 @@ def initialize_dp_attention(enable_dp_attention, tp_rank, tp_size, dp_size):
             list(range(head, head + _ATTN_TP_SIZE))
             for head in range(0, tp_size, _ATTN_TP_SIZE)
         ],
-
+        tp_group.local_rank,
         torch.distributed.get_backend(tp_group.device_group),
         SYNC_TOKEN_IDS_ACROSS_TP,
         False,
@@ -86,6 +99,27 @@ def get_attention_dp_size():
     return _DP_SIZE
 
 
+@contextmanager
+def disable_dp_size():
+    """Patch the tp group temporarily until this function ends.
+
+    This method is for draft workers of speculative decoding to run draft model
+    with different tp degree from that of target model workers.
+
+    Args:
+        tp_group (GroupCoordinator): the tp group coordinator
+    """
+    global _DP_SIZE
+    assert _DP_SIZE is not None, "dp attention not initialized!"
+
+    old_dp_size = _DP_SIZE
+    _DP_SIZE = 1
+    try:
+        yield
+    finally:
+        _DP_SIZE = old_dp_size
+
+
 def get_dp_local_info(forward_batch: ForwardBatch):
     dp_rank = get_attention_dp_rank()
 
@@ -144,22 +178,22 @@ def memcpy_triton(dst, src, dim, offset, sz, offset_src):
     memcpy_triton_kernel[grid](dst, src, offset, sz, offset_src, chunk_size, BLOCK_SIZE)
 
 
-def dp_gather(
+def _dp_gather(
     global_tokens: torch.Tensor,
     local_tokens: torch.Tensor,
     forward_batch: ForwardBatch,
-
+    is_partial: bool,
 ):
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
 
     global_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
-
-
-    ):
+
+    if local_tokens.shape[0] > 0 and (is_partial or get_attention_tp_rank() == 0):
         assert (
-            global_tokens.
+            global_tokens.untyped_storage().data_ptr()
+            != local_tokens.untyped_storage().data_ptr()
         ), "aliasing between global_tokens and local_tokens not allowed"
         memcpy_triton(
             global_tokens, local_tokens, 0, local_start_pos, local_num_tokens, False
@@ -174,8 +208,25 @@ def dp_gather(
         torch.ops.sglang.inplace_all_reduce(
             global_tokens, group_name=get_tp_group().unique_name
         )
+
     else:
-        global_tokens = tensor_model_parallel_all_reduce(global_tokens)
+        global_tokens[:] = tensor_model_parallel_all_reduce(global_tokens)
+
+
+def dp_gather_partial(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=True)
+
+
+def dp_gather_replicate(
+    global_tokens: torch.Tensor,
+    local_tokens: torch.Tensor,
+    forward_batch: ForwardBatch,
+):
+    _dp_gather(global_tokens, local_tokens, forward_batch, is_partial=False)
 
 
 def dp_scatter(
@@ -186,6 +237,7 @@ def dp_scatter(
     # local_num_tokens is not necessarily the same as local_tokens.shape[0],
     # since local_tokens may be padded for cuda graph
     local_start_pos, local_num_tokens = get_dp_local_info(forward_batch)
+
     local_tokens.fill_(0)
     assert local_tokens.is_contiguous()
     assert global_tokens.is_contiguous()
@@ -197,16 +249,3 @@
         memcpy_triton(
             local_tokens, global_tokens, 0, local_start_pos, local_num_tokens, True
         )
-
-
-def get_do_logits_dp_scatter(forward_batch: ForwardBatch):
-    def do_logits_dp_scatter(logits: torch.Tensor):
-        local_logits = torch.empty(
-            (forward_batch.input_ids.shape[0], *logits.shape[1:]),
-            dtype=logits.dtype,
-            device=logits.device,
-        )
-        dp_scatter(local_logits, logits, forward_batch)
-        return local_logits
-
-    return do_logits_dp_scatter
```
sglang/srt/layers/elementwise.py
ADDED

```diff
@@ -0,0 +1,411 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+fused_softcap_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 128}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 256}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 512}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 32768}, num_warps=32),
+    ],
+    key=["n_ele"],
+)
+
+
+@triton.jit
+def fused_softcap_kernel(
+    output_ptr,
+    input_ptr,
+    n_ele,
+    softcap_const: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_ele
+    x = tl.load(input_ptr + offsets, mask=mask)
+    fx = x.to(tl.float32)
+    fxs = fx / softcap_const
+    exped = tl.exp(2 * fxs)
+    top = exped - 1
+    bottom = exped + 1
+    output = top / bottom * softcap_const
+    tl.store(output_ptr + offsets, output, mask=mask)
+
+
+fused_softcap_kernel_autotuned = fused_softcap_autotune(fused_softcap_kernel)
+
+
+def fused_softcap(x, softcap_const, autotune=False):
+    output = torch.empty_like(x, dtype=torch.float32)
+    n_elements = output.numel()
+    if autotune:
+        grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
+        fused_softcap_kernel_autotuned[grid](output, x, n_elements, softcap_const)
+    else:
+        fused_softcap_kernel[(triton.cdiv(n_elements, 128),)](
+            output, x, n_elements, softcap_const, BLOCK_SIZE=128, num_warps=8
+        )
+    return output
+
+
+# cast to float + softcap
+class Softcap:
+    def __init__(self, softcap_const: float):
+        self.softcap_const = softcap_const
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        if x.is_cuda:
+            return self.forward_cuda(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        return torch.tanh(x.float() / self.softcap_const) * self.softcap_const
+
+    def forward_cuda(self, x: torch.Tensor, autotune=False) -> torch.Tensor:
+        return fused_softcap(x, self.softcap_const, autotune=autotune)
+
+
+rmsnorm_autotune = triton.autotune(
+    configs=[
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=4, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=8, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 1024}, num_warps=16, num_stages=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 2048}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 4096}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 8192}, num_warps=32, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=1),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=8, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=16, num_stages=4),
+        triton.Config(kwargs={"BLOCK_SIZE": 16384}, num_warps=32, num_stages=4),
+    ],
+    key=["hidden_dim"],
+)
+
+
+@triton.jit
+def fused_dual_residual_rmsnorm_kernel(
+    output_ptr,
+    mid_ptr,
+    activ_ptr,
+    residual_ptr,
+    weight1_ptr,
+    weight2_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    r = tl.load(residual_ptr + input_start + offsets, mask=mask, other=0.0)
+    w1_ = tl.load(weight1_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a2r = r + (a / rms * w1).to(r.dtype)
+    tl.store(
+        mid_ptr + input_start + offsets,
+        a2r,
+        mask=mask,
+    )
+
+    a2r = a2r.to(tl.float32)
+    rms2 = tl.sqrt(tl.sum(a2r * a2r, axis=0) / hidden_dim + eps)
+
+    w2_ = tl.load(weight2_ptr + offsets, mask=mask, other=0.0)
+    w2 = w2_.to(tl.float32)
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a2r / rms2 * w2,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+fused_dual_residual_rmsnorm_kernel_autotune = rmsnorm_autotune(
+    fused_dual_residual_rmsnorm_kernel
+)
+
+
+def fused_dual_residual_rmsnorm(x, residual, weight1, weight2, eps, autotune=False):
+    assert len(x.shape) == 2
+    assert x.shape == residual.shape and x.dtype == residual.dtype
+    output, mid = torch.empty_like(x), torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    if autotune:
+        fused_dual_residual_rmsnorm_kernel_autotune[(bs,)](
+            output, mid, x, residual, weight1, weight2, eps=eps, hidden_dim=hidden_dim
+        )
+    else:
+        config = {
+            "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+            "num_warps": max(
+                min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+            ),
+        }
+
+        fused_dual_residual_rmsnorm_kernel[(bs,)](
+            output,
+            mid,
+            x,
+            residual,
+            weight1,
+            weight2,
+            eps=eps,
+            hidden_dim=hidden_dim,
+            **config,
+        )
+
+    return output, mid
+
+
+@triton.jit
+def fused_rmsnorm_kernel(
+    output_ptr,
+    activ_ptr,
+    weight_ptr,
+    eps: tl.constexpr,
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    input_start = pid * hidden_dim
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    a_ = tl.load(activ_ptr + input_start + offsets, mask=mask, other=0.0)
+    a = a_.to(tl.float32)
+    rms = tl.sqrt(tl.sum(a * a, axis=0) / hidden_dim + eps)
+
+    w1_ = tl.load(weight_ptr + offsets, mask=mask, other=0.0)
+    w1 = w1_.to(tl.float32)
+
+    a_rms = a / rms * w1
+
+    tl.store(
+        output_ptr + input_start + offsets,
+        a_rms,  # implicitly casts to output dtype here
+        mask=mask,
+    )
+
+
+def fused_rmsnorm(x, weight, eps, autotune=False, inplace=False):
+    assert len(x.shape) == 2
+    if inplace:
+        output = x
+    else:
+        output = torch.empty_like(x)
+    bs, hidden_dim = x.shape
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_rmsnorm_kernel[(bs,)](
+        output, x, weight, eps=eps, hidden_dim=hidden_dim, **config
+    )
+    return output
+
+
+class FusedDualResidualRMSNorm:
+    """
+    Fused implementation of
+        y = RMSNorm2(RMSNorm1(x) + residual))
+    """
+
+    def __init__(self, rmsnorm1, rmsnorm2) -> None:  # the one after rmsnorm1
+        self.rmsnorm1 = rmsnorm1
+        self.rmsnorm2 = rmsnorm2
+        self.variance_epsilon = self.rmsnorm1.variance_epsilon
+        assert self.rmsnorm1.variance_epsilon == self.rmsnorm2.variance_epsilon
+        assert self.rmsnorm1.weight.shape == self.rmsnorm2.weight.shape
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_flashinfer(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, residual: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_dual_residual_rmsnorm(
+            x,
+            residual,
+            self.rmsnorm1.weight,
+            self.rmsnorm2.weight,
+            self.variance_epsilon,
+            autotune=autotune,
+        )
+
+    def forward_flashinfer(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1(x)
+        residual = normed1 + residual
+        return self.rmsnorm2(residual), residual
+
+    def forward_native(
+        self,
+        x: torch.Tensor,
+        residual: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        normed1 = self.rmsnorm1.forward_native(x)
+        residual = normed1 + residual
+        return self.rmsnorm2.forward_native(residual), residual
+
+
+# gelu on first half of vector
+@triton.jit
+def gelu_and_mul_kernel(
+    out_hidden_states_ptr,  # (bs, hidden_dim)
+    out_scales_ptr,  # (bs,)
+    hidden_states_ptr,  # (bs, hidden_dim * 2)
+    quant_max: tl.constexpr,
+    static_scale: tl.constexpr,
+    hidden_dim: tl.constexpr,  # the output hidden_dim
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    input_start = pid * hidden_dim * 2
+    output_start = pid * hidden_dim
+
+    input1_offs = tl.arange(0, BLOCK_SIZE)
+    mask = tl.arange(0, BLOCK_SIZE) < hidden_dim  # shared for input1, input3, output
+    input3_offs = hidden_dim + tl.arange(0, BLOCK_SIZE)
+    output_offs = tl.arange(0, BLOCK_SIZE)
+
+    x1 = tl.load(
+        hidden_states_ptr + input_start + input1_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+    x3 = tl.load(
+        hidden_states_ptr + input_start + input3_offs, mask=mask, other=0.0
+    ).to(tl.float32)
+
+    # gelu
+    # cast down before mul to better match training?
+    gelu_x1 = 0.5 * (1.0 + tl.erf(x1 * 0.7071067811865475)) * x1
+    out = x3 * gelu_x1.to(hidden_states_ptr.dtype.element_ty)
+
+    if quant_max is not None:
+        raise NotImplementedError()
+
+    tl.store(out_hidden_states_ptr + output_start + output_offs, out, mask=mask)
+
+
+def gelu_and_mul_triton(
+    hidden_states,
+    scales=None,
+    quantize=None,  # dtype to quantize to
+    out=None,
+):
+    bs, in_hidden_dim = hidden_states.shape
+    hidden_dim = in_hidden_dim // 2
+
+    if out is None:
+        out_hidden_states = torch.empty(
+            (bs, hidden_dim),
+            dtype=quantize or hidden_states.dtype,
+            device=hidden_states.device,
+        )
+    else:
+        assert out.shape == (bs, hidden_dim)
+        assert out.dtype == (quantize or hidden_states.dtype)
+        out_hidden_states = out
+    out_scales = None
+    static_scale = False
+    if quantize is not None:
+        if scales is None:
+            out_scales = torch.empty(
+                (bs,), dtype=torch.float32, device=hidden_states.device
+            )
+        else:
+            out_scales = scales
+            static_scale = True
+
+    config = {
+        # 8 ele per thread (not tuned)
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 8 * 32)), 32), 4
+        ),
+    }
+
+    gelu_and_mul_kernel[(bs,)](
+        out_hidden_states,
+        out_scales,
+        hidden_states,
+        quant_max=torch.finfo(quantize).max if quantize is not None else None,
+        static_scale=static_scale,
+        hidden_dim=hidden_dim,
+        BLOCK_SIZE=triton.next_power_of_2(hidden_dim),
+        **config,
+    )
+
+    if quantize is not None:
+        return out_hidden_states, out_scales
+    else:
+        return out_hidden_states, None
```
sglang/srt/layers/layernorm.py
CHANGED

```diff
@@ -21,7 +21,9 @@ import torch.nn as nn
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -117,7 +119,27 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_cuda_available():
+class Gemma3RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
```
sglang/srt/layers/linear.py
CHANGED

```diff
@@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import (
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        from sglang.srt.layers.parameter import _ColumnvLLMParameter
-
         if isinstance(param, _ColumnvLLMParameter):
             param.load_column_parallel_weight(
                 loaded_weight,
@@ -687,10 +686,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
     ):
         if loaded_shard_id is None:
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    shard_id=0,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                 return
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
-                param.load_merged_column_weight(
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                 return
             # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
@@ -719,6 +727,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_offset=shard_offset,
             shard_size=shard_size,
             use_presharded_weights=self.use_presharded_weights,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
         )
 
 
@@ -782,6 +792,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         else:
             self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
             self.num_kv_head_replicas = 1
+        self.q_proj_shard_size = self.num_heads * self.head_size
+        self.kv_proj_shard_size = self.num_kv_heads * self.head_size
         input_size = self.hidden_size
         output_size = (
             (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
@@ -1234,7 +1246,7 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        if isinstance(param,
+        if isinstance(param, RowvLLMParameter):
            # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
            # It supports additional parameters like tp_rank and use_presharded_weights.
            param.load_row_parallel_weight(
```