sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry, and is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
@@ -12,7 +12,8 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Inference-only GLM-4.5, GLM-4.6
+"""Inference-only GLM-4.5, GLM-4.6 Speculative Decoding."""
+
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -33,7 +34,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.glm4_moe import Glm4MoeDecoderLayer, Glm4MoeForCausalLM
 from sglang.srt.server_args import get_global_server_args
-from sglang.srt.utils import
+from sglang.srt.utils import add_prefix
 
 logger = logging.getLogger(__name__)
 
@@ -84,14 +85,6 @@ class Glm4MoeModelNextN(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        zero_allocator = BumpAllocator(
-            buffer_size=2,
-            dtype=torch.float32,
-            device=(
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
-        )
-
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
         else:
@@ -111,7 +104,7 @@ class Glm4MoeModelNextN(nn.Module):
         residual = None
         with get_global_expert_distribution_recorder().disable_this_region():
             hidden_states, residual = self.decoder(
-                positions, hidden_states, forward_batch, residual
+                positions, hidden_states, forward_batch, residual
             )
 
         if not forward_batch.forward_mode.is_idle():
@@ -124,7 +117,6 @@ class Glm4MoeModelNextN(nn.Module):
 
 
 class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
-
     def __init__(
         self,
         config: PretrainedConfig,
@@ -135,8 +127,6 @@ class Glm4MoeForCausalLMNextN(Glm4MoeForCausalLM):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.determine_num_fused_shared_experts("Glm4MoeForCausalLMNextN")
-
         self.model = Glm4MoeModelNextN(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
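The `zero_allocator` removed above was a tiny per-forward scratch buffer handed down into the decoder. For readers unfamiliar with the pattern, the sketch below shows what a minimal bump allocator over one preallocated tensor looks like; it is an illustrative stand-in with made-up names, not sglang's `BumpAllocator`.

```python
import torch


class TinyBumpAllocator:
    """Illustrative bump allocator: hands out slices of one preallocated buffer."""

    def __init__(self, buffer_size: int, dtype: torch.dtype, device: str = "cpu"):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, num_elements: int) -> torch.Tensor:
        # Hand back the next unused slice; no per-allocation kernel launch, no free list.
        assert self._offset + num_elements <= self._buffer.numel(), "out of scratch space"
        out = self._buffer[self._offset : self._offset + num_elements]
        self._offset += num_elements
        return out


# Usage: one allocator per forward pass, two float32 zeros available as scratch.
scratch = TinyBumpAllocator(buffer_size=2, dtype=torch.float32)
zero_a = scratch.allocate(1)
zero_b = scratch.allocate(1)
```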
sglang/srt/models/glm4v.py CHANGED
@@ -1,15 +1,35 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Modeling from:
+# ./llama.py and
+# https://github.com/huggingface/transformers/blob/main/src/transformers/models/glm4v/modular_glm4v.py
+"""Inference-only GLM-4.1V model compatible with HuggingFace weights."""
+
 import logging
-from functools import lru_cache
+from functools import lru_cache
 from typing import Iterable, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from einops import rearrange
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
-from sglang.srt.layers.
+from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
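The new `einops` import is used later in this file to switch the vision block between sequence-first and batch-first layouts. A quick, self-contained illustration of the two patterns that appear below (a plain PyTorch `permute` behaves the same way):

```python
import torch
from einops import rearrange

S, B, H = 4, 2, 8           # sequence length, batch, hidden size
x = torch.randn(S, B, H)    # sequence-first layout, as the vision tower stores activations

bsh = rearrange(x, "s b h -> b s h")     # batch-first, what the attention call expects
assert bsh.shape == (B, S, H)
assert torch.equal(bsh, x.permute(1, 0, 2))

back = rearrange(bsh, "b s h -> s b h")  # and back after attention
assert torch.equal(back, x)
```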
@@ -20,13 +40,14 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import (
+    MultiModalityDataPaddingPatternMultimodalTokens,
+    general_mm_embed_routine,
+)
+from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4 import Glm4Model
-from sglang.srt.models.qwen2_5_vl import (
-    Qwen2_5_VisionBlock,
-    Qwen2_5_VLForConditionalGeneration,
-)
 from sglang.srt.utils import add_prefix
 from sglang.srt.utils.hf_transformers_utils import get_processor
 
@@ -56,7 +77,7 @@ class Glm4vVisionMLP(nn.Module):
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
             input_size=in_features,
-            output_sizes=[hidden_features] * 2,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
             prefix=add_prefix("gate_up_proj", prefix),
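The added comment documents the column layout of the fused projection: the first `hidden_features` output columns are the gate projection and the second half is the up projection, which is the layout `SiluAndMul` consumes. A minimal plain-PyTorch sketch of that contract, with made-up sizes and an ordinary `nn.Linear` standing in for `MergedColumnParallelLinear`:

```python
import torch
import torch.nn.functional as F

in_features, hidden_features = 16, 32          # illustrative sizes
gate_up_proj = torch.nn.Linear(in_features, 2 * hidden_features, bias=False)

x = torch.randn(3, in_features)
gate_up = gate_up_proj(x)                      # [..., 2 * hidden_features]
gate, up = gate_up.chunk(2, dim=-1)            # first half: gate_proj, second half: up_proj
activated = F.silu(gate) * up                  # what SiluAndMul computes on the fused output
assert activated.shape == (3, hidden_features)
```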
@@ -77,34 +98,95 @@ class Glm4vVisionMLP(nn.Module):
         return x
 
 
-class Glm4vVisionBlock(
+class Glm4vVisionBlock(nn.Module):
     def __init__(
         self,
-
-
+        dim: int,
+        intermediate_dim: int,
+        num_heads: int,
+        attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        num_dummy_heads: int = 0,
+        rms_norm_eps: float = 1e-5,
     ) -> None:
-        super().__init__(
-
-
-
-
-
+        super().__init__()
+        self.norm1 = RMSNorm(dim, eps=rms_norm_eps)
+        self.norm2 = RMSNorm(dim, eps=rms_norm_eps)
+
+        if attn_implementation is None:
+            softmax_in_single_precision = False
+            qkv_backend = None
+            flatten_batch = True
+        elif attn_implementation == "sdpa":
+            softmax_in_single_precision = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_2":
+            softmax_in_single_precision = False
+            qkv_backend = "triton_attn"
+            flatten_batch = True
+        elif attn_implementation == "eager":
+            softmax_in_single_precision = True
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
+            flatten_batch = True
+
+        self.attn = VisionAttention(
+            embed_dim=dim,
+            num_heads=num_heads,
+            projection_size=dim,
+            use_qkv_parallel=True,
+            rotary_embed="normal",
+            proj_bias=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=flatten_batch,
             quant_config=quant_config,
-            prefix=prefix,
-            num_dummy_heads=
-            rms_norm_eps=config.rms_norm_eps,
+            prefix=add_prefix("attn", prefix),
+            num_dummy_heads=num_dummy_heads,
         )
-
         self.mlp = Glm4vVisionMLP(
-
-
-            bias=False,
+            dim,
+            intermediate_dim,
             quant_config=quant_config,
             prefix=add_prefix("mlp", prefix),
         )
 
+    def forward(
+        self,
+        x: torch.Tensor,
+        cu_seqlens: torch.Tensor,
+        position_embeddings: torch.Tensor,
+    ) -> torch.Tensor:
+        S, B, H = x.shape
+        # norm1: flatten to 2D -> [S*B, H], then reshape back
+        x2d = x.reshape(-1, H)
+        hidden_states = self.norm1(x2d).reshape(S, B, H)
+
+        # Attention expects [B, S, H]
+        hidden_states = rearrange(hidden_states, "s b h -> b s h")
+        attn = self.attn(
+            hidden_states,
+            cu_seqlens=cu_seqlens,
+            position_embeddings=position_embeddings,
+        )
+        attn = rearrange(attn, "b s h -> s b h")
+
+        # norm2 with fused residual-add: also 2D
+        attn2d = attn.reshape(-1, H)
+        x_norm_2d, x_after_add_2d = self.norm2(x2d, residual=attn2d)
+        x_norm = x_norm_2d.reshape(S, B, H)
+        x_after_add = x_after_add_2d.reshape(S, B, H)
+
+        # MLP and final residual
+        mlp_out = self.mlp(x_norm)
+        x = x_after_add + mlp_out
+        return x
+
 
 class Glm4vVisionPatchEmbed(nn.Module):
     def __init__(
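The rewritten block no longer inherits from `Qwen2_5_VisionBlock`; its forward is a standard pre-norm layout: norm1, attention in a batch-first view, a residual add fused into the norm2 call, then MLP plus residual. The sketch below mirrors only that shape handling with vanilla PyTorch modules standing in for `VisionAttention`, the gated RMSNorm, and `Glm4vVisionMLP`; it is an illustration, not the sglang implementation.

```python
import torch
import torch.nn as nn


class TinyPreNormBlock(nn.Module):
    """Shape-level stand-in for the GLM-4V vision block's forward pass."""

    def __init__(self, dim: int, num_heads: int, mlp_dim: int):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.mlp = nn.Sequential(nn.Linear(dim, mlp_dim), nn.GELU(), nn.Linear(mlp_dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x is sequence-first, [S, B, H], like the GLM-4V vision tower activations.
        hidden = self.norm1(x)

        # Attention runs batch-first: [B, S, H] in, [B, S, H] out.
        q = hidden.transpose(0, 1)
        attn_out, _ = self.attn(q, q, q, need_weights=False)
        attn_out = attn_out.transpose(0, 1)

        # Residual add + second norm (in sglang this add is fused into the RMSNorm call).
        x_after_add = x + attn_out
        x_norm = self.norm2(x_after_add)

        # MLP and final residual.
        return x_after_add + self.mlp(x_norm)


block = TinyPreNormBlock(dim=64, num_heads=4, mlp_dim=128)
out = block(torch.randn(10, 2, 64))  # [S=10, B=2, H=64]
assert out.shape == (10, 2, 64)
```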
@@ -320,7 +402,6 @@ class Glm4vVisionModel(nn.Module):
     def __init__(
         self,
         vision_config: Glm4vVisionConfig,
-        norm_eps: float = 1e-6,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -344,17 +425,18 @@ class Glm4vVisionModel(nn.Module):
             hidden_size=self.hidden_size,
         )
 
-        norm_layer = partial(Glm4vRMSNorm, eps=norm_eps)
         head_dim = self.hidden_size // self.num_heads
         self.rotary_pos_emb = Glm4vVisionRotaryEmbedding(head_dim // 2)
 
         self.blocks = nn.ModuleList(
             [
                 Glm4vVisionBlock(
-
-
+                    dim=self.hidden_size,
+                    intermediate_dim=self.out_hidden_size,
+                    num_heads=self.num_heads,
                     quant_config=quant_config,
                     prefix=add_prefix(f"blocks.{layer_idx}", prefix),
+                    rms_norm_eps=vision_config.rms_norm_eps,
                 )
                 for layer_idx in range(depth)
             ]
@@ -461,29 +543,30 @@ class Glm4vVisionModel(nn.Module):
         return x
 
 
-class Glm4vForConditionalGeneration(
+class Glm4vForConditionalGeneration(nn.Module):
     def __init__(
         self,
         config: Glm4vConfig,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-
+        super().__init__()
 
         self.config = config
-        vision_utils.update_vit_attn_dummy_heads_config(self.config)
-        self.model = Glm4Model(
-            config,
-            quant_config,
-            prefix=add_prefix("model", prefix),
-        )
         self.visual = Glm4vVisionModel(
             config.vision_config,
-            norm_eps=getattr(config, "rms_norm_eps", 1e-5),
             quant_config=quant_config,
             prefix=add_prefix("visual", prefix),
         )
 
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
+
+        self.model = Glm4Model(
+            config,
+            quant_config=quant_config,
+            prefix=add_prefix("model", prefix),
+        )
+
         if config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
@@ -494,13 +577,18 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
             prefix=add_prefix("lm_head", prefix),
         )
 
+        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
+
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
-        self.is_mrope_enabled = "mrope_section" in self.config.rope_scaling
 
         # For EAGLE3 support
         self.capture_aux_hidden_states = False
 
+    def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
     def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
         pixel_values = torch.cat(
             [item.feature.squeeze(0) for item in items], dim=0
@@ -542,20 +630,60 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         video_embeds = torch.split(video_embeds, split_sizes)
         return torch.cat(video_embeds)
 
-    def
-
-        tp_size = get_attention_tp_size()
-        num_heads = self.config.vision_config.num_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
+    def get_input_embeddings(self):
+        return self.model.embed_tokens
 
-
-
-
-
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ):
+        """Run forward pass for GLM-4.1V.
+
+        Args:
+            input_ids: Flattened (concatenated) input_ids corresponding to a
+                batch.
+            positions: Flattened (concatenated) position ids corresponding to a
+                batch.
+                **NOTE**: If mrope is enabled (default setting for GLM-4.1V
+                opensource models), the shape will be `(3, seq_len)`,
+                otherwise it will be `(seq_len,).
+                (Use input_metadata.mrope_positions to replace it)
+        """
+        if self.is_mrope_enabled:
+            positions = forward_batch.mrope_positions
+
+        if not (
+            forward_batch.forward_mode.is_decode()
+            or not forward_batch.contains_image_inputs()
+        ):
+            if self.is_mrope_enabled:
+                assert positions.ndim == 2 and positions.size(0) == 3, (
+                    "multimodal section rotary embedding requires "
+                    f"(3, seq_len) positions, but got {positions.size()}"
+                )
+
+        hidden_states = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.model,
+            multimodal_model=self,
+            positions=positions,
+        )
 
-
-
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch, aux_hidden_states
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
 
     def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
         """pad attn qkv weights for dummy heads"""
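When mrope is enabled, `forward_batch.mrope_positions` carries three stacked position streams (in Qwen2-VL-style mrope these correspond to the rope sections, e.g. temporal/height/width) in a `(3, seq_len)` tensor instead of the usual `(seq_len,)` vector, which is what the assertion in the new forward checks. A small shape-only illustration; the values are made up and are not how sglang derives them:

```python
import torch

seq_len = 6

# Text-only / mrope disabled: one position id per token.
plain_positions = torch.arange(seq_len)                      # shape (seq_len,)

# mrope enabled: three stacked position streams, one row per rope section.
mrope_positions = torch.stack([torch.arange(seq_len)] * 3)   # shape (3, seq_len)

# The guard in the forward pass above boils down to this check:
assert mrope_positions.ndim == 2 and mrope_positions.size(0) == 3
```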
@@ -598,13 +726,12 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in weights:
-            if "language_model." in name:
-                name = name.replace("language_model.", "")
-            if "model.visual." in name:
-                name = name.replace("model.visual.", "visual.")
-
             if "rotary_emb.inv_freq" in name:
                 continue
+            if "language_model" in name:
+                name = name.replace(r"model.language_model.", r"model.")
+            if "model.visual." in name:
+                name = name.replace("model.visual.", "visual.")
 
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
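The reordered remapping first skips `rotary_emb.inv_freq` buffers and then strips the HF-style prefixes so checkpoint keys line up with this module tree (`model.language_model.*` becomes `model.*`, `model.visual.*` becomes `visual.*`). A standalone sketch of that renaming; the checkpoint keys used here are hypothetical and only show the mapping direction:

```python
from typing import Optional


def remap_glm4v_key(name: str) -> Optional[str]:
    """Mirror the renaming applied in load_weights above; keys are illustrative."""
    if "rotary_emb.inv_freq" in name:
        return None  # skipped entirely
    if "language_model" in name:
        name = name.replace("model.language_model.", "model.")
    if "model.visual." in name:
        name = name.replace("model.visual.", "visual.")
    return name


assert remap_glm4v_key("model.language_model.layers.0.mlp.gate_proj.weight") == (
    "model.layers.0.mlp.gate_proj.weight"
)
assert remap_glm4v_key("model.visual.blocks.0.attn.qkv.weight") == "visual.blocks.0.attn.qkv.weight"
assert remap_glm4v_key("model.layers.0.self_attn.rotary_emb.inv_freq") is None
```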
@@ -639,5 +766,19 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                )
                weight_loader(param, loaded_weight)
 
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            del self.lm_head.weight
+            self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
 
 EntryClass = [Glm4vForConditionalGeneration]
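The two new accessors let another component (for example a speculative-decoding draft worker) read or swap the embedding and LM-head tensors in place. A hedged, self-contained sketch of that sharing pattern; `TinyLM` and its sizes are invented for illustration and the real helpers additionally handle tied embeddings and CUDA cache cleanup as shown in the diff:

```python
import torch.nn as nn


class TinyLM(nn.Module):
    """Minimal stand-in exposing the same two helpers as the model above."""

    def __init__(self, vocab: int = 8, dim: int = 4):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab, dim)
        self.lm_head = nn.Linear(dim, vocab, bias=False)

    def get_embed_and_head(self):
        return self.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        # Re-point this model's parameters at the tensors handed in.
        self.embed_tokens.weight = embed
        self.lm_head.weight = head


target, draft = TinyLM(), TinyLM()
draft.set_embed_and_head(*target.get_embed_and_head())  # both now share one set of tensors
assert draft.lm_head.weight.data_ptr() == target.lm_head.weight.data_ptr()
```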