sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
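
Before the per-file diffs, here is a minimal sketch (not part of the sglang package) of how a file-level summary like the listing above can be reproduced locally. The wheel filenames and the `pip download sglang==<version> --no-deps` step are assumptions, and the counts are approximate rather than the registry's exact numbers.

```python
# Compare two locally downloaded wheels and print approximate per-file
# added/removed line counts, similar to the listing above.
import difflib
import zipfile

OLD_WHL = "sglang-0.5.2rc2-py3-none-any.whl"  # assumed local file
NEW_WHL = "sglang-0.5.3rc0-py3-none-any.whl"  # assumed local file


def read_texts(path):
    """Return {member name: decoded text} for text files inside a wheel."""
    out = {}
    with zipfile.ZipFile(path) as zf:
        for name in zf.namelist():
            try:
                out[name] = zf.read(name).decode("utf-8")
            except UnicodeDecodeError:
                pass  # skip binary members
    return out


old, new = read_texts(OLD_WHL), read_texts(NEW_WHL)
for name in sorted(set(old) | set(new)):
    a = old.get(name, "").splitlines()
    b = new.get(name, "").splitlines()
    if a == b:
        continue
    diff = list(difflib.unified_diff(a, b, lineterm=""))
    added = sum(1 for l in diff if l.startswith("+") and not l.startswith("+++"))
    removed = sum(1 for l in diff if l.startswith("-") and not l.startswith("---"))
    print(f"{name} +{added} -{removed}")
```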
sglang/srt/models/bailing_moe_nextn.py
ADDED
@@ -0,0 +1,168 @@
+# coding=utf-8
+# Copyright 2023 Antgroup and The HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+""" SGLang BailingMoENextN model."""
+import logging
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.bailing_moe import BailingMoEBlock, BailingMoEForCausalLM
+from sglang.srt.utils import add_prefix
+
+LoraConfig = None
+logger = logging.getLogger(__name__)
+
+
+class BailingMoEModelNextN(nn.Module):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
+            logger.warning(
+                "Overriding DeepseekV3ForCausalLMNextN quant config for modelopt_fp4 Deepseek model."
+            )
+            quant_config = None
+
+        self.vocab_size = config.vocab_size
+
+        self.word_embeddings = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            enable_tp=not is_dp_attention_enabled(),
+            prefix=add_prefix("word_embeddings", prefix),
+        )
+
+        self.enorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.hnorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.eh_proj = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+
+        self.decoder = BailingMoEBlock(
+            config,
+            0,
+            quant_config=quant_config,
+            # is_nextn=True,
+            prefix=add_prefix("decoder", prefix),
+        )
+
+        self.shared_head = nn.Module()
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+    ) -> torch.Tensor:
+
+        if input_embeds is None:
+            hidden_states = self.word_embeddings(input_ids)
+        else:
+            hidden_states = input_embeds
+
+        if hidden_states.shape[0] > 0:
+            hidden_states = self.eh_proj(
+                torch.cat(
+                    (
+                        self.enorm(hidden_states),
+                        self.hnorm(forward_batch.spec_info.hidden_states),
+                    ),
+                    dim=-1,
+                )
+            )
+
+        residual = None
+        hidden_states, residual = self.decoder(
+            positions, hidden_states, forward_batch, residual
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            if residual is not None:
+                hidden_states, _ = self.final_layernorm(hidden_states, residual)
+            else:
+                hidden_states = self.final_layernorm(hidden_states)
+
+        return hidden_states
+
+
+class BailingMoeForCausalLMNextN(BailingMoEForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+        if hasattr(self, "determine_num_fused_shared_experts"):
+            # Asystem has determine_num_fused_shared_experts but theta does not.
+            self.determine_num_fused_shared_experts("BailingMoeForCausalLMNextN")
+
+        self.model = BailingMoEModelNextN(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch)
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        super().load_weights(weights, is_nextn=True)
+
+
+EntryClass = [BailingMoeForCausalLMNextN]
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -65,10 +65,11 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -151,6 +152,7 @@ if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -246,7 +248,11 @@ class DeepseekV2MLP(nn.Module):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
-        if
+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
             y = gemm_output_zero_allocator.allocate(
                 x.shape[0] * self.gate_up_proj.output_size_per_partition
             ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
@@ -264,6 +270,7 @@ class MoEGate(nn.Module):
     def __init__(
         self,
         config,
+        quant_config,
         prefix: str = "",
         is_nextn: bool = False,
     ):
@@ -273,8 +280,15 @@ class MoEGate(nn.Module):
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         if config.topk_method == "noaux_tc":
+            correction_bias_dtype = (
+                torch.bfloat16
+                if quant_config is not None
+                and quant_config.get_name() == "modelopt_fp4"
+                and should_use_flashinfer_trtllm_moe()
+                else torch.float32
+            )
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts), dtype=
+                torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
             )
         else:
             self.e_score_correction_bias = None
@@ -299,7 +313,9 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
         elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
             logits = aiter_dsv3_router_gemm(
                 hidden_states, self.weight, gemm_output_zero_allocator
@@ -347,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
         )
 
         self.gate = MoEGate(
-            config=config,
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("gate", prefix),
+            is_nextn=is_nextn,
         )
 
         self.experts = get_moe_impl_class(quant_config)(
@@ -372,9 +391,12 @@ class DeepseekV2MoE(nn.Module):
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )
 
         self.shared_experts_is_int8 = False
@@ -661,10 +683,14 @@ class DeepseekV2MoE(nn.Module):
 
         if shared_output is not None:
             x = shared_output
-
+            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
@@ -1033,6 +1059,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
@@ -1050,7 +1085,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1063,6 +1097,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             disable_ragged = (
                 attention_backend == "flashinfer" or attention_backend == "flashmla"
             ) and self.flashinfer_mla_disable_ragged
+
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
             if (
                 not disable_ragged
                 and forward_batch.forward_mode.is_extend()
@@ -1075,6 +1111,40 @@ class DeepseekV2AttentionMLA(nn.Module):
                     )
                     or sum_extend_prefix_lens == 0
                 )
+                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
+                # dp case. Redirect to mla kernel as a workaround.
+                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
+                and not (
+                    original_mode is not None
+                    and original_mode.is_decode()
+                    and is_sm100_supported()
+                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
+                )
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
+            if (
+                original_mode is not None
+                and original_mode.is_decode()
+                and is_sm100_supported()
+            ):
+                return _dispatch_mla_subtype()
+
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and (
+                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
@@ -1235,8 +1305,18 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
             q[..., self.qk_nope_head_dim :] = q_pe
             k = torch.empty_like(q)
-
-
+
+            # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+            if (
+                _is_cuda
+                and (self.num_local_heads == 128)
+                and (self.qk_nope_head_dim == 128)
+                and (self.qk_rope_head_dim == 64)
+            ):
+                concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+            else:
+                k[..., : self.qk_nope_head_dim] = k_nope
+                k[..., self.qk_nope_head_dim :] = k_pe
 
             if not _is_npu:
                 latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1998,7 +2078,10 @@ class DeepseekV2DecoderLayer(nn.Module):
         quant_format = (
             "mxfp4"
             if _is_gfx95_supported
-            and self.self_attn
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
             else ""
         )
 
@@ -2170,8 +2253,15 @@ class DeepseekV2Model(nn.Module):
                 [
                     "w13_weight",
                     "w2_weight",
-
-
+                    # only for nvfp4
+                    *(
+                        [
+                            "w13_blockscale_swizzled",
+                            "w2_blockscale_swizzled",
+                        ]
+                        if hasattr(module, "w13_blockscale_swizzled")
+                        else []
+                    ),
                 ]
                 if isinstance(module, FusedMoE)
                 else []
@@ -2553,7 +2643,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
 
-                if
+                if (
+                    _use_aiter_gfx95
+                    and self.quant_config is not None
+                    and self.quant_config.get_name() == "quark"
+                ):
                     w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
                         quark_post_load_weights(self_attn, w, "mxfp4")
                     )
sglang/srt/models/dots_vlm.py
ADDED
@@ -0,0 +1,174 @@
+# Copyright 2025 The RedNote HiLab team.
+# Copyright 2025 The SGLang team.
+#
+# This code is based on the DeepseekVL2ForCausalLM and DotsVisionTransformer
+# implementation in this library.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Inference-only Dots-VL model compatible with HuggingFace weights."""
+
+from typing import Iterable, List, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.dots_vlm import DotsVLMConfig
+from sglang.srt.distributed import parallel_state
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
+
+from .dots_vlm_vit import DotsVisionTransformer
+
+
+class DotsVLMForCausalLM(nn.Module):
+    """DotsVLM model for sglang inference"""
+
+    def __init__(
+        self, config: DotsVLMConfig, quant_config: Optional[QuantizationConfig] = None
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        self.image_token_id = config.im_span_id
+        self.video_token_id = config.video_span_id
+
+        self.language_model = DeepseekV2ForCausalLM(
+            config.language_config, quant_config
+        )
+
+        # Initialize vision tower (matching transformers naming for weight compatibility)
+        self.vision_tower = DotsVisionTransformer(config.vision_config)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        if "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        """Load weights for the model, separating vision and language weights"""
+        weights = list(weights)
+
+        # Separate vision tower weights and language model weights
+        vision_weights = []
+        language_weights = []
+
+        for name, loaded_weight in weights:
+            if name.startswith("vision_tower."):
+                vision_name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
+                vision_weights.append((vision_name, loaded_weight))
+            else:
+                # All other weights go to language model
+                language_weights.append((name, loaded_weight))
+
+        # Load vision tower weights
+        vision_state_dict = dict(vision_weights)
+        params_dict = dict(self.named_parameters(remove_duplicate=False))
+        for name, loaded_weight in vision_state_dict.items():
+            if name not in params_dict:
+                raise ValueError(f"Weight {name} not found in params_dict")
+            param = params_dict[name]
+            weight_loader = getattr(param, "weight_loader", default_weight_loader)
+            loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight)
+            weight_loader(param, loaded_weight)
+
+        # Load language model weights
+        if language_weights:
+            self.language_model.load_weights(language_weights)
+
+    @classmethod
+    def get_model_config_for_expert_location(cls, config):
+        return DeepseekV2ForCausalLM.get_model_config_for_expert_location(config)
+
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        """Pad input_ids with multimodal tokens"""
+        # Get image token ID for padding pattern
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        padded_input_ids = pattern.pad_input_tokens(input_ids, mm_inputs)
+        return padded_input_ids
+
+    def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
+        # Extract pixel values and grid information (following reference pattern)
+        pixel_values = torch.cat([item.feature for item in items], dim=0).type(
+            self.vision_tower.dtype
+        )
+        image_grid_thw = torch.concat(
+            [item.image_grid_thw for item in items], dim=0
+        ).to(self.vision_tower.device)
+
+        # Add dimension checks like in reference code
+        assert pixel_values.dim() == 2, f"{pixel_values.dim()=}"
+        assert image_grid_thw.dim() == 2, f"{image_grid_thw.dim()=}"
+
+        # Process through vision tower
+        image_embeds = self.vision_tower(pixel_values, image_grid_thw)
+
+        # Ensure consistent dtype for FlashInfer compatibility
+        # Force bfloat16 to match model's expected dtype
+        if image_embeds.dtype != torch.bfloat16 and hasattr(
+            self.language_model.model, "embed_tokens"
+        ):
+            target_dtype = self.language_model.model.embed_tokens.weight.dtype
+            image_embeds = image_embeds.to(target_dtype)
+
+        return image_embeds
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        **kwargs: object,
+    ) -> torch.Tensor:
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            positions=positions,
+            forward_batch=forward_batch,
+            multimodal_model=self,
+            language_model=self.language_model,
+        )
+        return hidden_states
+
+
+EntryClass = [DotsVLMForCausalLM]