sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/vila.py
DELETED
|
@@ -1,306 +0,0 @@
|
|
|
1
|
-
import logging
|
|
2
|
-
from typing import Any, Dict, Iterable, List, Optional, Tuple, cast
|
|
3
|
-
|
|
4
|
-
import torch
|
|
5
|
-
import torch.nn as nn
|
|
6
|
-
import torch.nn.functional as F
|
|
7
|
-
from torch import Tensor
|
|
8
|
-
from transformers.configuration_utils import PretrainedConfig
|
|
9
|
-
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
|
10
|
-
from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
|
|
11
|
-
from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
|
|
12
|
-
|
|
13
|
-
import sglang.srt.managers.mm_utils as mm_utils
|
|
14
|
-
import sglang.srt.model_loader.weight_utils as weight_utils
|
|
15
|
-
import sglang.srt.utils as utils
|
|
16
|
-
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
|
17
|
-
from sglang.srt.layers.pooler import Pooler, PoolingType
|
|
18
|
-
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
|
19
|
-
from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
|
|
20
|
-
from sglang.srt.managers.schedule_batch import (
|
|
21
|
-
Modality,
|
|
22
|
-
MultimodalDataItem,
|
|
23
|
-
MultimodalInputs,
|
|
24
|
-
)
|
|
25
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
26
|
-
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
|
27
|
-
|
|
28
|
-
# Module-level logger. NOTE(review): not referenced anywhere else in this module.
logger = logging.getLogger(__name__)
|
|
29
|
-
|
|
30
|
-
|
|
31
|
-
##### BEGIN COPY configuration.py #####
|
|
32
|
-
|
|
33
|
-
|
|
34
|
-
class VILAConfig(PretrainedConfig):
    """Hugging Face configuration for VILA (Qwen2 text model + SigLIP vision tower).

    Attributes set by ``__init__`` (besides the ``PretrainedConfig`` basics):
        text_config: ``Qwen2Config`` used to build the language model.
        vision_config: ``SiglipVisionConfig`` used to build the vision tower.
        hidden_size: Output width of the multimodal projector (presumably the
            LLM embedding width — verify against the checkpoint).
        image_token_id / video_token_id: Token ids associated with image /
            video inputs.
        mm_hidden_size: Feature width consumed by the multimodal projector.
        mm_projector_type: Name of the projector architecture.
        mm_vision_select_feature: Which vision-tower tokens feed the projector.
        mm_vision_select_layer: Index into the vision tower's hidden states.
    """

    # Class attributes.
    model_type: str = "vila"
    # transformers expects config *classes* here (cf. ``{"text_config": AutoConfig}``
    # in other multimodal configs), not pre-built instances.
    sub_configs: Dict[str, type] = {
        "text_config": Qwen2Config,
        "vision_config": SiglipVisionConfig,
    }
    _auto_class: Optional[str] = "AutoConfig"

    # Configuration for sub-modules. These class-level defaults are always
    # shadowed by the instance attributes assigned in __init__.
    text_config: Qwen2Config = Qwen2Config()
    vision_config: SiglipVisionConfig = SiglipVisionConfig()

    # Model configuration.
    hidden_size: int
    image_token_id: int
    mm_hidden_size: int
    mm_projector_type: str
    mm_vision_select_feature: str
    mm_vision_select_layer: int
    video_token_id: int

    def __init__(
        self,
        text_config: Optional[Dict[str, Any]] = None,
        vision_config: Optional[Dict[str, Any]] = None,
        *,
        hidden_size: int = 1536,
        image_token_id: int = 151649,
        mm_hidden_size: int = 1152,
        mm_projector_type: str = "mlp_downsample_3x3_fix",
        mm_vision_select_feature: str = "cls_patch",
        mm_vision_select_layer: int = -2,
        video_token_id: int = 151650,
        **kwargs,
    ):
        """Build the config.

        Args:
            text_config: Kwargs for ``Qwen2Config``, an existing config
                instance, or ``None``/empty for the defaults.
            vision_config: Kwargs for ``SiglipVisionConfig``, an existing
                config instance, or ``None``/empty for the defaults.
            **kwargs: Forwarded to ``PretrainedConfig.__init__``.
        """
        super().__init__(**kwargs)

        self.text_config = self._build_sub_config(Qwen2Config, text_config)
        self.vision_config = self._build_sub_config(SiglipVisionConfig, vision_config)

        self.hidden_size = hidden_size
        self.image_token_id = image_token_id
        self.mm_hidden_size = mm_hidden_size
        self.mm_projector_type = mm_projector_type
        self.mm_vision_select_feature = mm_vision_select_feature
        self.mm_vision_select_layer = mm_vision_select_layer
        self.video_token_id = video_token_id

    @staticmethod
    def _build_sub_config(config_cls: type, value: Any) -> PretrainedConfig:
        """Return a ``config_cls`` built from ``value``.

        An already-constructed config is returned unchanged (``config_cls(**obj)``
        would raise ``TypeError``); a mapping is used as constructor kwargs; a
        falsy value (``None`` / ``{}``) yields the class defaults.
        """
        if isinstance(value, PretrainedConfig):
            return value
        return config_cls(**value) if value else config_cls()
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
##### END COPY configuration.py #####
|
|
89
|
-
|
|
90
|
-
##### BEGIN COPY modeling_vila.py #####
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
class DownSample3x3BlockFix(nn.Module):
    """Fold each 3x3 neighbourhood of a square token grid into a single token.

    The token sequence is viewed as a ``side x side`` grid. The grid is
    zero-padded on the bottom and right up to a multiple of 3. Each 3x3 tile
    is then flattened into one token that carries 9x the input feature width.
    """

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, mm_hidden_size * 9).

        Raises:
            ValueError: if ``sequence_length`` is not a perfect square.
        """
        bsz, num_tokens, channels = x.shape

        side = int(num_tokens**0.5)
        if side * side != num_tokens:
            raise ValueError(
                f"Cannot take square root: sequence_length {num_tokens} is not a perfect square"
            )

        grid = x.reshape(bsz, side, side, channels)

        remainder = side % 3
        if remainder:
            extra = 3 - remainder
            # Pad order is (last dim, columns, rows): channels untouched,
            # `extra` zero columns on the right and zero rows at the bottom.
            grid = F.pad(grid, (0, 0, 0, extra, 0, extra))
            side += extra

        tiles = side // 3
        grid = grid.reshape(bsz, tiles, 3, tiles, 3, channels)
        # -> (batch, tile_row, tile_col, row_in_tile, col_in_tile, channels), so
        # the 9 tokens of each tile become contiguous before the final merge.
        grid = grid.permute(0, 1, 3, 2, 4, 5).contiguous()
        return grid.reshape(bsz, -1, 9 * channels)
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
class MultimodalProjector(nn.Module):
    """Map vision-tower features to ``config.hidden_size``-wide embeddings.

    Only ``mm_projector_type == "mlp_downsample_3x3_fix"`` is supported. Each
    3x3 neighbourhood of patch tokens is merged into one token
    (``DownSample3x3BlockFix``). The result then goes through
    LayerNorm -> Linear -> GELU -> LayerNorm -> Linear -> GELU -> Linear.
    """

    layers: nn.Sequential

    def __init__(
        self,
        config: VILAConfig,
        *args,
        **kwargs,
    ):
        """Build the projector layers from ``config``.

        Raises:
            NotImplementedError: for any other ``config.mm_projector_type``.
        """
        super().__init__(*args, **kwargs)

        if config.mm_projector_type != "mlp_downsample_3x3_fix":
            raise NotImplementedError(
                f"Unsupported mm_projector_type: {config.mm_projector_type}"
            )

        # Width after concatenating the 9 tokens of a 3x3 tile.
        merged_dim = config.mm_hidden_size * 9
        # The second LayerNorm/Linear must consume exactly what the first Linear
        # produces, so all three are sized from mm_hidden_size. Previously the
        # second pair used vision_config.hidden_size, which crashes whenever the
        # two differ; for released checkpoints they presumably coincide, so the
        # parameter shapes are unchanged there.
        intermediate_dim = config.mm_hidden_size * 3

        # Sequential indices (layers.0 ... layers.7) are unchanged, so
        # checkpoint parameter names still match.
        self.layers = nn.Sequential(
            DownSample3x3BlockFix(),
            nn.LayerNorm(merged_dim),
            nn.Linear(merged_dim, intermediate_dim),
            nn.GELU(),
            nn.LayerNorm(intermediate_dim),
            nn.Linear(intermediate_dim, config.hidden_size),
            nn.GELU(),
            nn.Linear(config.hidden_size, config.hidden_size),
        )

        # Module.type(None) would break (Tensor.type(None) returns a str), so
        # only cast when the config actually records a dtype.
        torch_dtype = getattr(config, "torch_dtype", None)
        if torch_dtype is not None:
            self.layers.type(torch_dtype)

    @property
    def device(self) -> torch.device:
        """Device of the projector's first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the projector's first parameter."""
        return next(self.parameters()).dtype

    def forward(self, x: Tensor) -> Tensor:
        """
        Args:
            x: The input tensor of shape (batch_size, sequence_length, mm_hidden_size).

        Returns:
            The output tensor of shape (batch_size, image_pad_len, hidden_size).
        """
        return self.layers(x.to(device=self.device, dtype=self.dtype))
|
|
177
|
-
|
|
178
|
-
|
|
179
|
-
##### END COPY modeling_vila.py #####
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
class VILAForConditionalGeneration(nn.Module):
    """VILA vision-language model: SigLIP vision tower + projector + Qwen2 LLM.

    ``forward`` hands ``self.llm`` and ``get_image_feature`` (as the
    ``Modality.IMAGE`` embedding function) to
    ``mm_utils.general_mm_embed_routine``.
    """

    config: VILAConfig
    quant_config: Optional[QuantizationConfig]

    logits_processor: LogitsProcessor
    pooler: Pooler

    llm: Qwen2ForCausalLM
    mm_projector: MultimodalProjector
    vision_tower: SiglipVisionModel

    def __init__(
        self,
        config: VILAConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        self.config = config
        self.quant_config = quant_config

        # NOTE(review): not referenced by any method of this class (forward()
        # delegates to self.llm); presumably kept for parity with other sglang
        # models — verify before removing.
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

        self.llm = Qwen2ForCausalLM(
            config=config.text_config,
            quant_config=quant_config,
            prefix=utils.add_prefix("llm", prefix),
        )
        self.mm_projector = MultimodalProjector(config)
        self.vision_tower = SiglipVisionModel(config.vision_config)

    @property
    def dtype(self) -> torch.dtype:
        """Model dtype as recorded in ``config.torch_dtype``."""
        return self.config.torch_dtype

    def forward(
        self,
        input_ids: Tensor,
        positions: Tensor,
        forward_batch: ForwardBatch,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        """Run the multimodal forward pass through ``general_mm_embed_routine``.

        Image items are embedded with ``get_image_feature``; everything else is
        delegated to ``self.llm``.
        """
        output = mm_utils.general_mm_embed_routine(
            input_ids=input_ids,
            forward_batch=forward_batch,
            language_model=self.llm,
            data_embedding_funcs={
                Modality.IMAGE: self.get_image_feature,
            },
            get_embedding=get_embedding,
            positions=positions,
        )

        return cast(LogitsProcessorOutput, output)

    def get_image_feature(self, mm_input: List[MultimodalDataItem]) -> Tensor:
        """Encode the pixel values of every item in ``mm_input``.

        The items' features are concatenated along dim 0, in order, before the
        vision tower runs. Using only the first item would silently drop the rest.

        Args:
            mm_input: Image items; each ``feature`` is treated as a pixel-value
                tensor for the vision tower.

        Returns:
            The multimodal projector's output for all images.

        Raises:
            ValueError: if ``mm_input`` is empty.
        """
        if not mm_input:
            raise ValueError("get_image_feature requires at least one MultimodalDataItem")

        device = self.vision_tower.device
        dtype = self.vision_tower.dtype
        # Move every item onto the vision tower's device/dtype first, so the
        # concatenation below never mixes devices. Assumes dim 0 is the image
        # batch dim of each feature — TODO confirm against the processor.
        pixel_values_list = [
            cast(Tensor, item.feature).to(device=device, dtype=dtype)
            for item in mm_input
        ]
        pixel_values = (
            pixel_values_list[0]
            if len(pixel_values_list) == 1
            else torch.cat(pixel_values_list, dim=0)
        )

        ##### BEGIN COPY modeling_vila.py #####

        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
            pixel_values,
            output_hidden_states=True,
        )

        mm_projector_input = self._vision_tower_output_to_mm_projector_input(
            vision_tower_output
        )

        image_embedding: Tensor = self.mm_projector(
            mm_projector_input.to(
                device=self.mm_projector.device, dtype=self.mm_projector.dtype
            )
        )

        ##### END COPY modeling_vila.py #####

        return image_embedding

    def load_weights(self, weights: Iterable[Tuple[str, Tensor]]) -> None:
        """Load checkpoint tensors into the model.

        ``llm.``-prefixed tensors have the prefix stripped. They are streamed to
        ``self.llm.load_weights`` in a *single* call, instead of one call per
        tensor, which rebuilt the LLM's parameter dict every time. Every other
        tensor is matched by exact name against this module's parameters. It is
        loaded with the parameter's ``weight_loader`` attribute, or with
        ``weight_utils.default_weight_loader``. Checkpoint order is preserved.

        Raises:
            KeyError: if a non-``llm.`` tensor has no matching parameter.
        """
        params_dict = dict(self.named_parameters())
        llm_prefix = "llm."

        def _route(
            stream: Iterable[Tuple[str, Tensor]],
        ) -> Iterable[Tuple[str, Tensor]]:
            # Yields LLM tensors to the LLM loader; loads all other tensors
            # in place as they are encountered.
            for name, loaded_weight in stream:
                if name.startswith(llm_prefix):
                    yield name[len(llm_prefix) :], loaded_weight
                    continue
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader", weight_utils.default_weight_loader
                )
                weight_loader(param, loaded_weight)

        routed = _route(weights)
        self.llm.load_weights(routed)
        # Defensive: if the LLM loader stopped iterating early, still load
        # the remaining non-LLM tensors.
        for _ in routed:
            pass

    def pad_input_ids(
        self, input_ids: List[int], mm_inputs: MultimodalInputs
    ) -> List[int]:
        """Pad ``input_ids`` with ``MultiModalityDataPaddingPatternMultimodalTokens``."""
        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
        return pattern.pad_input_tokens(input_ids, mm_inputs)

    ##### BEGIN COPY modeling_vila.py #####

    def _vision_tower_output_to_mm_projector_input(
        self,
        vision_tower_output: BaseModelOutputWithPooling,
    ) -> Tensor:
        """Select the vision-tower hidden states that feed the projector.

        Picks ``hidden_states[config.mm_vision_select_layer]``. Only the
        ``"cls_patch"`` selection (all tokens of that layer) is supported.

        Raises:
            RuntimeError: if the vision tower returned no hidden states.
            NotImplementedError: for other ``mm_vision_select_feature`` values.
        """
        hidden_states = vision_tower_output.hidden_states
        if hidden_states is None:
            # get_image_feature always passes output_hidden_states=True.
            raise RuntimeError("vision tower output is missing hidden_states")

        selected_layer_hidden_states = hidden_states[
            self.config.mm_vision_select_layer
        ]

        if self.config.mm_vision_select_feature == "cls_patch":
            return selected_layer_hidden_states
        raise NotImplementedError(
            f"Unsupported mm_vision_select_feature: {self.config.mm_vision_select_feature}"
        )
|
|
302
|
-
|
|
303
|
-
##### END COPY modeling_vila.py #####
|
|
304
|
-
|
|
305
|
-
|
|
306
|
-
# Model classes exported by this module. NOTE(review): presumably discovered by
# sglang's model registry through the ``EntryClass`` name — verify.
EntryClass = [VILAForConditionalGeneration]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|