sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/mamba/mamba.py
ADDED
@@ -0,0 +1,64 @@
+from typing import Callable, List, Tuple
+
+import torch
+
+LoaderFunction = Callable[[torch.Tensor, torch.Tensor], None]
+
+
+def mamba_v2_sharded_weight_loader(
+    shard_spec: List[Tuple[int, int, float]],
+    tp_size: int,
+    tp_rank: int,
+) -> LoaderFunction:
+    """Create a weight loader for mamba v2. This ensures that the projections
+    are correctly sharded so that they can be split into x, B, C. It also
+    ensures the the all the groups corresponding to a head shard is placed
+    together with it.
+    """
+
+    def loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
+
+        # - track boundary of (sharded) param, and loaded_weight, respectively
+        boundary, loaded_boundary = 0, 0
+
+        # - iterate over the shard specs
+        for full_dim, extra, duplicate_groups in shard_spec:
+            # - full dim is the model dim (before TP).
+            # - extra > 0, means there is expected overall increase
+            #   of dimensions. This is so because of replication.
+            # - ratio is used map the tp_rank to the actual shard
+            #   rank. This is useful when there is replication of
+            #   groups to accompany head shards.
+
+            # - size of the loaded shard
+            shard_size = full_dim // tp_size
+
+            # - compute the rank into the loaded shard.
+            # - if there is replication, different TP shards will
+            #   take from the same rank.
+            # NOTE: currently we only support duplication
+            #   in the case where num_groups == 1
+            rank = 0 if duplicate_groups else tp_rank
+
+            # - leftmost boundary index into loaded weight.
+            loaded_skip = rank * shard_size
+            loaded_start_idx = loaded_boundary + loaded_skip
+
+            # - take these many dims from the loaded weight.
+            take = min(shard_size, full_dim - extra - loaded_skip)
+
+            # - always shard on dim 0
+            # - the ignore is for a mundane mypy error as it does not
+            #   seem to handle slices well.
+            #   https://github.com/python/mypy/issues/2410
+            param.data[
+                boundary : (boundary + take), ...  # type: ignore[misc]
+            ] = loaded_weight[
+                loaded_start_idx : (loaded_start_idx + take)  # type: ignore[misc]
+            ]  # type: ignore[misc]
+
+            # move indexing boundaries
+            boundary += shard_size
+            loaded_boundary += full_dim - extra
+
+    return loader
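
The shard spec above describes, per fused projection slice, the full (pre-TP) dimension, any replication overhead, and whether groups are duplicated across ranks. A minimal sketch of how the returned loader behaves, with made-up shapes and a single-entry shard spec; nothing below is taken from the package itself:

    import torch

    from sglang.srt.layers.attention.mamba.mamba import mamba_v2_sharded_weight_loader

    # Hypothetical: one fused slice of model dim 8, no replication overhead,
    # no duplicated groups, sharded across tp_size=2, viewed from tp_rank=1.
    loader = mamba_v2_sharded_weight_loader(
        shard_spec=[(8, 0, False)],
        tp_size=2,
        tp_rank=1,
    )

    param = torch.nn.Parameter(torch.zeros(4, 3))  # this rank's shard: 8 // 2 rows
    loaded_weight = torch.arange(24, dtype=torch.float32).reshape(8, 3)  # full tensor

    loader(param, loaded_weight)
    # rank 1 copies rows 4..7 of the checkpoint tensor into its local shard
    assert torch.equal(param.data, loaded_weight[4:8])
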
sglang/srt/layers/attention/torch_native_backend.py
CHANGED
@@ -193,10 +193,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
@@ -241,10 +244,13 @@ class TorchNativeAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
+        if layer.is_cross_attention:
+            cache_loc = forward_batch.encoder_out_cache_loc
+        else:
+            cache_loc = forward_batch.out_cache_loc
+
         if save_kv_cache:
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, forward_batch.out_cache_loc, k, v
-            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
 
sglang/srt/layers/attention/triton_backend.py
CHANGED
@@ -80,7 +80,13 @@ class TritonAttnBackend(AttentionBackend):
         self.num_kv_head = model_runner.model_config.get_num_kv_heads(
             get_attention_tp_size()
         )
-
+        if model_runner.is_hybrid_gdn:
+            # For hybrid linear models, layer_id = 0 may not be full attention
+            self.v_head_dim = model_runner.token_to_kv_pool.get_v_head_dim()
+        else:
+            self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[
+                -1
+            ]
         self.max_context_len = model_runner.model_config.context_len
         self.device = model_runner.device
         self.device_core_count = get_device_core_count(model_runner.gpu_id)
@@ -88,6 +94,11 @@ class TritonAttnBackend(AttentionBackend):
             "SGLANG_TRITON_DECODE_ATTN_STATIC_KV_SPLITS", "false"
         )
         self.max_kv_splits = model_runner.server_args.triton_attention_num_kv_splits
+        self.split_tile_size = model_runner.server_args.triton_attention_split_tile_size
+        if self.split_tile_size is not None:
+            self.max_kv_splits = (
+                self.max_context_len + self.split_tile_size - 1
+            ) // self.split_tile_size
 
         # Check arguments
         assert not (
@@ -147,6 +158,12 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits.fill_(self.max_kv_splits)
             return
 
+        if self.split_tile_size is not None:
+            num_kv_splits[:] = (
+                seq_lens + self.split_tile_size - 1
+            ) // self.split_tile_size
+            return
+
         if num_seq < 256:
             SCHEDULE_SEQ = 256
         else:
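
With the new `triton_attention_split_tile_size` server argument set, the number of KV splits becomes a ceiling division of the sequence length by the tile size, capped by the context length. A small worked example with made-up numbers (the tile size and lengths below are illustrative, not defaults from the package):

    # Illustrative values only.
    split_tile_size = 256
    max_context_len = 8192

    # Upper bound on splits, mirroring the __init__ hunk above:
    max_kv_splits = (max_context_len + split_tile_size - 1) // split_tile_size
    assert max_kv_splits == 32

    # Per-request splits, mirroring the seq_lens-based hunk above:
    seq_lens = [1, 256, 257, 1000]
    num_kv_splits = [(s + split_tile_size - 1) // split_tile_size for s in seq_lens]
    assert num_kv_splits == [1, 1, 2, 4]
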
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -20,6 +20,7 @@ from sglang.srt.layers.attention.utils import (
     create_flashmla_kv_indices_triton,
 )
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
@@ -45,11 +46,19 @@ TRTLLM_BLOCK_CONSTRAINT = 128
 global_zero_init_workspace_buffer = None
 
 
+@dataclass
+class TRTLLMMLAPrefillMetadata:
+    """Metadata for TRTLLM MLA prefill operations."""
+
+    max_seq_len: int
+    cum_seq_lens: torch.Tensor
+    seq_lens: torch.Tensor
+
+
 @dataclass
 class TRTLLMMLADecodeMetadata:
     """Metadata for TRTLLM MLA decode operations."""
 
-    workspace: Optional[torch.Tensor] = None
     block_kv_indices: Optional[torch.Tensor] = None
     max_seq_len: Optional[int] = None
 
@@ -64,7 +73,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indptr_buf: Optional[torch.Tensor] = None,
         q_indptr_decode_buf: Optional[torch.Tensor] = None,
     ):
-        super().__init__(
+        super().__init__(
+            model_runner,
+            skip_prefill,
+            kv_indptr_buf,
+            q_indptr_decode_buf,
+        )
 
         config = model_runner.model_config
 
@@ -101,7 +115,12 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
         self.decode_cuda_graph_kv_indices = None
-        self.
+        self.forward_prefill_metadata: Optional[TRTLLMMLAPrefillMetadata] = None
+        self.forward_decode_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+        self.disable_chunked_prefix_cache = global_server_args_dict[
+            "disable_chunked_prefix_cache"
+        ]
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
         """
@@ -177,9 +196,6 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.decode_cuda_graph_workspace = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
 
         super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
 
@@ -230,12 +246,11 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         max_seq_len_val = int(seq_lens.max().item())
 
         metadata = TRTLLMMLADecodeMetadata(
-            self.decode_cuda_graph_workspace,
             block_kv_indices,
             max_seq_len_val,
         )
         self.decode_cuda_graph_metadata[bs] = metadata
-        self.
+        self.forward_decode_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
         self,
@@ -291,31 +306,55 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
         # Delegate to parent for non-decode modes.
-        if
-
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            if self.disable_chunked_prefix_cache:
+                super().init_forward_metadata(forward_batch)
+
+            seq_lens = forward_batch.seq_lens - forward_batch.extend_prefix_lens
+            cum_seq_lens_q = torch.cat(
+                (
+                    torch.tensor([0], device=forward_batch.seq_lens.device),
+                    torch.cumsum(seq_lens, dim=0),
+                )
+            ).int()
+            max_seq_len = max(forward_batch.extend_seq_lens_cpu)
+            self.forward_prefill_metadata = TRTLLMMLAPrefillMetadata(
+                max_seq_len,
+                cum_seq_lens_q,
+                seq_lens,
+            )
+        elif forward_batch.forward_mode.is_decode_or_idle():
+            bs = forward_batch.batch_size
 
-
+            # Get maximum sequence length.
+            if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+                max_seq = forward_batch.seq_lens_cpu.max().item()
+            else:
+                max_seq = forward_batch.seq_lens.max().item()
+
+            max_seqlen_pad = self._calc_padded_blocks(max_seq)
+            block_kv_indices = self._create_block_kv_indices(
+                bs,
+                max_seqlen_pad,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens.device,
+            )
 
-
-
-
+            max_seq_len_val = int(max_seq)
+            self.forward_decode_metadata = TRTLLMMLADecodeMetadata(
+                block_kv_indices, max_seq_len_val
+            )
+            forward_batch.decode_trtllm_mla_metadata = self.forward_decode_metadata
         else:
-
-
-            max_seqlen_pad = self._calc_padded_blocks(max_seq)
-            block_kv_indices = self._create_block_kv_indices(
-                bs,
-                max_seqlen_pad,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens.device,
-            )
+            return super().init_forward_metadata(forward_batch)
 
-
-
-            self.workspace_buffer, block_kv_indices, max_seq_len_val
-        )
-        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+    def init_mha_chunk_metadata(self, forward_batch: ForwardBatch):
+        super().init_mha_chunk_metadata(forward_batch, disable_flashinfer_ragged=True)
 
     def quantize_and_rope_for_fp8(
         self,
@@ -459,7 +498,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         # Get metadata
         metadata = (
             getattr(forward_batch, "decode_trtllm_mla_metadata", None)
-            or self.
+            or self.forward_decode_metadata
         )
 
         # Scale computation for TRTLLM MLA kernel BMM1 operation:
@@ -482,7 +521,7 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
             query=query,
             kv_cache=kv_cache,
-            workspace_buffer=
+            workspace_buffer=self.workspace_buffer,
             qk_nope_head_dim=self.qk_nope_head_dim,
             kv_lora_rank=self.kv_lora_rank,
             qk_rope_head_dim=self.qk_rope_head_dim,
@@ -496,6 +535,60 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out.view(-1, layer.tp_q_head_num * layer.v_head_dim)
         return output
 
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        # chunked prefix cache is not enabled, use Flashinfer MLA prefill kernel
+        if forward_batch.attn_attend_prefix_cache is None:
+            return super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+
+        if not forward_batch.attn_attend_prefix_cache:
+            q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+            k = k.view(-1, layer.tp_k_head_num, layer.head_dim)
+            v = v.view(-1, layer.tp_k_head_num, layer.v_head_dim)
+            output = flashinfer.prefill.trtllm_ragged_attention_deepseek(
+                query=q,
+                key=k,
+                value=v,
+                workspace_buffer=self.workspace_buffer,
+                seq_lens=self.forward_prefill_metadata.seq_lens,
+                max_q_len=self.forward_prefill_metadata.max_seq_len,
+                max_kv_len=self.forward_prefill_metadata.max_seq_len,
+                bmm1_scale=layer.scaling,
+                bmm2_scale=1.0,
+                o_sf_scale=1.0,
+                batch_size=forward_batch.batch_size,
+                window_left=-1,
+                cum_seq_lens_q=self.forward_prefill_metadata.cum_seq_lens,
+                cum_seq_lens_kv=self.forward_prefill_metadata.cum_seq_lens,
+                enable_pdl=False,
+                is_causal=True,
+                return_lse=forward_batch.mha_return_lse,
+            )
+        else:
+            # replace with trtllm ragged attention once accuracy is resolved.
+            output = super().forward_extend(
+                q, k, v, layer, forward_batch, save_kv_cache, q_rope, k_rope
+            )
+        return output
+
 
 class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
     """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
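
For the new prefill path, `cum_seq_lens_q` is the int32 exclusive prefix sum of the per-request new-token counts, which is the layout the ragged prefill kernel consumes. A small sketch with made-up lengths, run on CPU purely for illustration:

    import torch

    # Hypothetical batch of three extend requests with 3, 5 and 2 new tokens
    # (i.e. seq_lens - extend_prefix_lens in the diff above).
    seq_lens = torch.tensor([3, 5, 2])
    cum_seq_lens_q = torch.cat(
        (torch.tensor([0], device=seq_lens.device), torch.cumsum(seq_lens, dim=0))
    ).int()
    print(cum_seq_lens_q)  # tensor([ 0,  3,  8, 10], dtype=torch.int32)
    max_seq_len = int(seq_lens.max())  # 5, analogous to max(extend_seq_lens_cpu)
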
sglang/srt/layers/attention/wave_ops/decode_attention.py
CHANGED
@@ -64,8 +64,7 @@ def get_wave_kernel(
         subs=hyperparams_0,
         canonicalize=True,
         run_bench=False,
-
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         dynamic_symbols=dynamic_symbols_0,
         wave_runtime=True,
@@ -77,8 +76,7 @@ def get_wave_kernel(
         subs=hyperparams_1,
         canonicalize=True,
         run_bench=False,
-
-        use_buffer_store_ops=False,
+        use_buffer_ops=False,
         waves_per_eu=4,
         dynamic_symbols=dynamic_symbols_1,
         wave_runtime=True,
sglang/srt/layers/attention/wave_ops/extend_attention.py
CHANGED
@@ -67,11 +67,9 @@ def get_wave_kernel(
         schedule=SchedulingType.NONE,
         use_scheduling_barriers=False,
         dynamic_symbols=dynamic_symbols,
-
-        use_buffer_store_ops=True,
+        use_buffer_ops=True,
         waves_per_eu=2,
         denorm_fp_math_f32="preserve-sign",
-        gpu_native_math_precision=True,
         wave_runtime=True,
     )
     options = set_default_run_config(options)
sglang/srt/layers/dp_attention.py
CHANGED
@@ -51,7 +51,12 @@ class DpPaddingMode(IntEnum):
         return self == DpPaddingMode.SUM_LEN
 
     @classmethod
-    def get_dp_padding_mode(
+    def get_dp_padding_mode(
+        cls, is_extend_in_batch, global_num_tokens: List[int]
+    ) -> DpPaddingMode:
+        if is_extend_in_batch:
+            return DpPaddingMode.SUM_LEN
+
         # we choose the mode that minimizes the communication cost
         max_len = max(global_num_tokens)
         sum_len = sum(global_num_tokens)
@@ -119,6 +124,18 @@ class _DpGatheredBufferWrapper:
     def get_dp_global_num_tokens(cls) -> List[int]:
         return cls._global_num_tokens
 
+    @classmethod
+    def get_dp_hidden_size(cls) -> int:
+        return cls._hidden_size
+
+    @classmethod
+    def get_dp_dtype(cls) -> torch.dtype:
+        return cls._dtype
+
+    @classmethod
+    def get_dp_device(cls) -> torch.device:
+        return cls._device
+
 
 def set_dp_buffer_len(
     global_dp_buffer_len: int,
@@ -150,6 +167,18 @@ def get_dp_global_num_tokens() -> List[int]:
     return _DpGatheredBufferWrapper.get_dp_global_num_tokens()
 
 
+def get_dp_hidden_size() -> int:
+    return _DpGatheredBufferWrapper.get_dp_hidden_size()
+
+
+def get_dp_dtype() -> torch.dtype:
+    return _DpGatheredBufferWrapper.get_dp_dtype()
+
+
+def get_dp_device() -> torch.device:
+    return _DpGatheredBufferWrapper.get_dp_device()
+
+
 def compute_dp_attention_world_info(enable_dp_attention, tp_rank, tp_size, dp_size):
     if not enable_dp_attention:
         return tp_rank, tp_size, 0
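
The new `is_extend_in_batch` argument short-circuits the padding-mode choice: extend (prefill) batches always use sum-of-lengths padding, while decode batches fall through to the max-versus-sum comparison shown in the context lines above. A minimal illustrative call (the token counts are made up):

    from sglang.srt.layers.dp_attention import DpPaddingMode

    mode = DpPaddingMode.get_dp_padding_mode(
        is_extend_in_batch=True, global_num_tokens=[7, 1, 1, 1]
    )
    assert mode == DpPaddingMode.SUM_LEN
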
sglang/srt/layers/layernorm.py
CHANGED
@@ -18,6 +18,7 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+from packaging.version import Version
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.utils import (
@@ -25,32 +26,38 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_cpu,
     is_cuda,
+    is_flashinfer_available,
     is_hip,
     is_npu,
+    is_xpu,
     supports_custom_op,
 )
 
 _is_cuda = is_cuda()
+_is_flashinfer_available = is_flashinfer_available()
 _is_hip = is_hip()
 _is_npu = is_npu()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_xpu = is_xpu()
 
 if _is_cuda:
-
-        fused_add_rmsnorm
-
-
-
-    )
+    if _is_flashinfer_available:
+        from flashinfer.norm import fused_add_rmsnorm
+    else:
+        from sgl_kernel import fused_add_rmsnorm
+    from sgl_kernel import gemma_fused_add_rmsnorm, gemma_rmsnorm, rmsnorm
 
 if _use_aiter:
     from aiter import rmsnorm2d_fwd as rms_norm
     from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
 elif _is_hip:
+    import vllm
     from vllm._custom_ops import fused_add_rms_norm, rms_norm
 
+    _vllm_version = Version(vllm.__version__)
+
 logger = logging.getLogger(__name__)
 
 if _is_npu:
@@ -127,8 +134,21 @@ class RMSNorm(CustomOp):
             # NOTE: Remove this if aiter kernel supports discontinuous input
             x = x.contiguous()
         if residual is not None:
-
-
+            if _vllm_version < Version("0.9"):
+                fused_add_rms_norm(x, residual, self.weight.data, self.variance_epsilon)
+                return x, residual
+            else:
+                residual_out = torch.empty_like(x)
+                output = torch.empty_like(x)
+                fused_add_rms_norm(
+                    output,
+                    x,
+                    residual_out,
+                    residual,
+                    self.weight.data,
+                    self.variance_epsilon,
+                )
+                return output, residual_out
         out = torch.empty_like(x)
         rms_norm(out, x, self.weight.data, self.variance_epsilon)
         return out
@@ -271,16 +291,11 @@ class GemmaRMSNorm(CustomOp):
         x: torch.Tensor,
         residual: Optional[torch.Tensor] = None,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        orig_dtype = x.dtype
         if residual is not None:
             x = x + residual
            residual = x
 
-        x = x.
-        variance = torch_npu.mean(torch_npu.pow(x, 2), dim=-1, keepdim=True)
-        x = x * torch_npu.rsqrt(variance + self.variance_epsilon)
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
+        x, _ = torch_npu.npu_gemma_rms_norm(x, self.weight, self.variance_epsilon)
         return x if residual is None else (x, residual)
 
 
@@ -312,7 +327,9 @@ class Gemma3RMSNorm(CustomOp):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not (
+if not (
+    _is_cuda or _is_hip or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_xpu
+):
     logger.info(
         "sgl-kernel layernorm implementation is not available on current platform. Fallback to other kernel libraries."
     )
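
For readers unfamiliar with the fused kernel's contract: it adds the residual into the hidden state and then applies RMSNorm, returning both the normalized output and the updated residual. A plain-PyTorch sketch of that semantics (unfused, for clarity; this is not any of the kernels dispatched above, and the dtype handling is an assumption):

    import torch

    def fused_add_rms_norm_reference(x, residual, weight, eps):
        # Unfused reference: residual connection first, then RMS normalization.
        residual_out = x + residual
        variance = residual_out.float().pow(2).mean(dim=-1, keepdim=True)
        normed = residual_out.float() * torch.rsqrt(variance + eps)
        return (normed * weight.float()).to(x.dtype), residual_out
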
sglang/srt/layers/linear.py
CHANGED
@@ -235,9 +235,8 @@ class ReplicatedLinear(LinearBase):
                 loaded_weight = loaded_weight[:1]
             else:
                 raise ValueError(f"{loaded_weight} are not all equal")
-
-
-        ), f"Loading weight error: param: {param.size()}, loaded_weight: {loaded_weight.size()}"
+
+        assert param.size() == loaded_weight.size()
         param.data.copy_(loaded_weight)
 
     def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
@@ -894,6 +893,35 @@ class QKVParallelLinear(ColumnParallelLinear):
             )
             self.weight_loader_v2(param, loaded_weight_shard, shard_id)
 
+    def _load_qkv_block_scale(
+        self, param: BasevLLMParameter, loaded_weight: torch.Tensor
+    ):
+        block_n, _ = self.quant_method.quant_config.weight_block_size
+        q_size = self.total_num_heads * self.head_size // block_n
+        k_size = self.total_num_kv_heads * self.head_size // block_n
+        v_size = self.total_num_kv_heads * self.head_size // block_n
+        shard_offsets = [
+            # (shard_id, shard_offset, shard_size)
+            ("q", 0, q_size),
+            ("k", q_size, k_size),
+            ("v", q_size + k_size, v_size),
+        ]
+        for shard_id, shard_offset, shard_size in shard_offsets:
+            loaded_weight_shard = loaded_weight.narrow(
+                param.output_dim, shard_offset, shard_size
+            )
+            rank_shard_offset = self._get_shard_offset_mapping(shard_id) // block_n
+            rank_shard_size = self._get_shard_size_mapping(shard_id) // block_n
+            param.load_qkv_weight(
+                loaded_weight=loaded_weight_shard,
+                num_heads=self.num_kv_head_replicas,
+                shard_id=shard_id,
+                shard_offset=rank_shard_offset,
+                shard_size=rank_shard_size,
+                tp_rank=self.tp_rank,
+                use_presharded_weights=self.use_presharded_weights,
+            )
+
     def weight_loader_v2(
         self,
         param: BasevLLMParameter,
@@ -907,6 +935,9 @@ class QKVParallelLinear(ColumnParallelLinear):
         elif type(param) in (RowvLLMParameter, BasevLLMParameter):
             param.load_qkv_weight(loaded_weight=loaded_weight)
             return
+        elif isinstance(param, BlockQuantScaleParameter):
+            self._load_qkv_block_scale(param, loaded_weight)
+            return
         # TODO: @dsikka - move to parameter.py
         self._load_fused_module_from_checkpoint(param, loaded_weight)
         return
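
The q/k/v offsets in `_load_qkv_block_scale` are expressed in units of weight-scale blocks rather than output rows. A worked example with hypothetical head counts (32 query heads, 8 KV heads, head_size 128) and a (128, 128) weight block size; these numbers are illustrative, not taken from any particular model:

    total_num_heads, total_num_kv_heads, head_size = 32, 8, 128
    block_n = 128  # first entry of weight_block_size

    q_size = total_num_heads * head_size // block_n      # 32 block-scale rows
    k_size = total_num_kv_heads * head_size // block_n   # 8
    v_size = total_num_kv_heads * head_size // block_n   # 8

    shard_offsets = [
        ("q", 0, q_size),
        ("k", q_size, k_size),
        ("v", q_size + k_size, v_size),
    ]
    assert shard_offsets == [("q", 0, 32), ("k", 32, 8), ("v", 40, 8)]
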