sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -47,10 +47,11 @@ from sglang.srt.configs.janus_pro import *
|
|
47
47
|
from sglang.srt.layers.attention.vision import VisionAttention
|
48
48
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
49
49
|
from sglang.srt.layers.quantization import QuantizationConfig
|
50
|
-
from sglang.srt.managers.
|
50
|
+
from sglang.srt.managers.mm_utils import (
|
51
51
|
MultiModalityDataPaddingPatternTokenPairs,
|
52
|
+
general_mm_embed_routine,
|
52
53
|
)
|
53
|
-
from sglang.srt.managers.schedule_batch import
|
54
|
+
from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
|
54
55
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
55
56
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
56
57
|
from sglang.srt.models.llama import LlamaForCausalLM
|
@@ -251,7 +252,7 @@ def resample_patch_embed(
|
|
251
252
|
try:
|
252
253
|
from torch import vmap
|
253
254
|
except ImportError:
|
254
|
-
from
|
255
|
+
from torch.func import vmap
|
255
256
|
|
256
257
|
assert len(patch_embed.shape) == 4, "Four dimensions expected"
|
257
258
|
assert len(new_size) == 2, "New shape should only be hw"
|
@@ -1083,7 +1084,7 @@ def create_siglip_vit(
|
|
1083
1084
|
)
|
1084
1085
|
|
1085
1086
|
if ckpt_path:
|
1086
|
-
state_dict = torch.load(ckpt_path, map_location="cpu")
|
1087
|
+
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
1087
1088
|
|
1088
1089
|
incompatible_keys = model.load_state_dict(state_dict, strict=False)
|
1089
1090
|
print(
|
@@ -1289,7 +1290,7 @@ class MlpProjector(nn.Module):
|
|
1289
1290
|
high_x, low_x = x_or_tuple
|
1290
1291
|
high_x = self.high_up_proj(high_x)
|
1291
1292
|
low_x = self.low_up_proj(low_x)
|
1292
|
-
x = torch.
|
1293
|
+
x = torch.cat([high_x, low_x], dim=-1)
|
1293
1294
|
else:
|
1294
1295
|
x = x_or_tuple
|
1295
1296
|
|
@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|
1958
1959
|
)
|
1959
1960
|
self.logits_processor = LogitsProcessor(config)
|
1960
1961
|
|
1961
|
-
def
|
1962
|
-
|
1963
|
-
|
1964
|
-
|
1965
|
-
|
1962
|
+
def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
|
1963
|
+
pixel_values = image_input.pixel_values
|
1964
|
+
bs, n = pixel_values.shape[0:2]
|
1965
|
+
pixel_values = pixel_values.to(
|
1966
|
+
device=self.vision_model.device, dtype=self.vision_model.dtype
|
1966
1967
|
)
|
1967
|
-
|
1968
|
-
|
1969
|
-
|
1970
|
-
|
1971
|
-
|
1968
|
+
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
|
1969
|
+
|
1970
|
+
# [b x n, T2, D]
|
1971
|
+
images_embeds = self.aligner(self.vision_model(images))
|
1972
|
+
|
1973
|
+
# [b x n, T2, D] -> [b, n x T2, D]
|
1974
|
+
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
|
1975
|
+
|
1976
|
+
return images_embeds
|
1977
|
+
|
1978
|
+
def get_input_embeddings(self) -> nn.Embedding:
|
1979
|
+
return self.language_model.model.embed_tokens
|
1972
1980
|
|
1973
1981
|
@torch.no_grad()
|
1974
1982
|
def forward(
|
@@ -1978,90 +1986,25 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|
1978
1986
|
forward_batch: ForwardBatch,
|
1979
1987
|
) -> torch.Tensor:
|
1980
1988
|
|
1981
|
-
inputs_embeds =
|
1982
|
-
|
1983
|
-
forward_batch
|
1984
|
-
|
1985
|
-
|
1986
|
-
)
|
1987
|
-
|
1988
|
-
image_inputs = forward_batch.image_inputs[0]
|
1989
|
-
|
1990
|
-
images_seq_mask = self.prepare_images_seq_mask(
|
1991
|
-
input_ids=input_ids, image_inputs=image_inputs
|
1992
|
-
)
|
1993
|
-
|
1994
|
-
if images_seq_mask is not None:
|
1995
|
-
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
1996
|
-
inputs_embeds = self.prepare_inputs_embeds(
|
1997
|
-
input_ids=input_ids,
|
1998
|
-
pixel_values=image_inputs.pixel_values,
|
1999
|
-
images_seq_mask=images_seq_mask,
|
2000
|
-
images_emb_mask=image_inputs.images_emb_mask,
|
2001
|
-
)
|
2002
|
-
input_ids = None
|
2003
|
-
|
2004
|
-
if input_ids is not None:
|
2005
|
-
input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
|
1989
|
+
inputs_embeds = general_mm_embed_routine(
|
1990
|
+
input_ids=input_ids,
|
1991
|
+
forward_batch=forward_batch,
|
1992
|
+
embed_tokens=self.get_input_embeddings(),
|
1993
|
+
mm_data_embedding_func=self.get_image_feature,
|
1994
|
+
)
|
2006
1995
|
|
2007
1996
|
return self.language_model(
|
2008
|
-
input_ids=
|
1997
|
+
input_ids=None,
|
2009
1998
|
positions=positions,
|
2010
1999
|
forward_batch=forward_batch,
|
2011
2000
|
input_embeds=inputs_embeds,
|
2012
2001
|
get_embedding=False,
|
2013
2002
|
)
|
2014
2003
|
|
2015
|
-
def prepare_inputs_embeds(
|
2016
|
-
self,
|
2017
|
-
input_ids: torch.LongTensor,
|
2018
|
-
pixel_values: torch.FloatTensor,
|
2019
|
-
images_seq_mask: torch.LongTensor,
|
2020
|
-
images_emb_mask: torch.BoolTensor,
|
2021
|
-
**_kwargs,
|
2022
|
-
):
|
2023
|
-
"""
|
2024
|
-
|
2025
|
-
Args:
|
2026
|
-
input_ids (torch.LongTensor): [b, T]
|
2027
|
-
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
|
2028
|
-
images_seq_mask (torch.BoolTensor): [b, T]
|
2029
|
-
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
|
2030
|
-
|
2031
|
-
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
|
2032
|
-
|
2033
|
-
Returns:
|
2034
|
-
input_embeds (torch.Tensor): [b, T, D]
|
2035
|
-
"""
|
2036
|
-
|
2037
|
-
bs, n = pixel_values.shape[0:2]
|
2038
|
-
pixel_values = pixel_values.to(
|
2039
|
-
device=self.vision_model.device, dtype=self.vision_model.dtype
|
2040
|
-
)
|
2041
|
-
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
|
2042
|
-
|
2043
|
-
# [b x n, T2, D]
|
2044
|
-
images_embeds = self.aligner(self.vision_model(images))
|
2045
|
-
|
2046
|
-
# [b x n, T2, D] -> [b, n x T2, D]
|
2047
|
-
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
|
2048
|
-
# [b, n, T2] -> [b, n x T2]
|
2049
|
-
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
|
2050
|
-
|
2051
|
-
# [b, T, D]
|
2052
|
-
# ignore the image embeddings
|
2053
|
-
input_ids[input_ids < 0] = 0
|
2054
|
-
inputs_embeds = self.language_model.model.embed_tokens(input_ids)
|
2055
|
-
|
2056
|
-
# replace with the image embeddings
|
2057
|
-
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
|
2058
|
-
|
2059
|
-
return inputs_embeds
|
2060
|
-
|
2061
2004
|
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
|
2062
2005
|
return self.gen_aligner(self.gen_embed(image_ids))
|
2063
2006
|
|
2064
|
-
def pad_input_ids(self, input_ids: List[int], image_inputs:
|
2007
|
+
def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
|
2065
2008
|
im_start_id = image_inputs.im_start_id
|
2066
2009
|
im_end_id = image_inputs.im_end_id
|
2067
2010
|
media_token_pairs = [(im_start_id, im_end_id)]
|
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
|
|
18
18
|
import torch
|
19
19
|
from torch import nn
|
20
20
|
from transformers import PretrainedConfig
|
21
|
-
from vllm import _custom_ops as ops
|
22
21
|
|
23
22
|
from sglang.srt.layers.layernorm import RMSNorm
|
24
23
|
from sglang.srt.layers.linear import ReplicatedLinear
|
@@ -41,9 +40,15 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
|
|
41
40
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
42
41
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
43
42
|
from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
|
44
|
-
from sglang.srt.utils import add_prefix, is_hip
|
43
|
+
from sglang.srt.utils import add_prefix, is_cuda, is_hip
|
45
44
|
|
46
45
|
_is_hip = is_hip()
|
46
|
+
_is_cuda = is_cuda()
|
47
|
+
|
48
|
+
if _is_cuda:
|
49
|
+
from sgl_kernel import awq_dequantize
|
50
|
+
else:
|
51
|
+
from vllm import _custom_ops as ops
|
47
52
|
|
48
53
|
|
49
54
|
class DeepseekModelNextN(nn.Module):
|
@@ -261,14 +266,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
|
|
261
266
|
self_attn = self.model.decoder.self_attn
|
262
267
|
if hasattr(self_attn.kv_b_proj, "qweight"):
|
263
268
|
# AWQ compatible
|
264
|
-
|
265
|
-
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
269
|
+
if _is_cuda:
|
270
|
+
w = awq_dequantize(
|
271
|
+
self_attn.kv_b_proj.qweight,
|
272
|
+
self_attn.kv_b_proj.scales,
|
273
|
+
self_attn.kv_b_proj.qzeros,
|
274
|
+
).T
|
275
|
+
else:
|
276
|
+
w = ops.awq_dequantize(
|
277
|
+
self_attn.kv_b_proj.qweight,
|
278
|
+
self_attn.kv_b_proj.scales,
|
279
|
+
self_attn.kv_b_proj.qzeros,
|
280
|
+
0,
|
281
|
+
0,
|
282
|
+
0,
|
283
|
+
).T
|
272
284
|
else:
|
273
285
|
w = self_attn.kv_b_proj.weight
|
274
286
|
# NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
|