sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,576 @@
+import enum
+import logging
+from typing import Any, Iterable, List, Optional, Set, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.configs.falcon_h1 import FalconH1Config
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention.mamba.mamba import MambaMixer2
+from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_rank,
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix, is_cuda, make_layers
+
+logger = logging.getLogger(__name__)
+_is_cuda = is_cuda()
+
+
+class FalconH1MLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        layer_id: int,
+        mlp_multipliers: List[float],
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.gate_up_proj = MergedColumnParallelLinear(
+            hidden_size,
+            [intermediate_size] * 2,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("gate_up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "silu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only silu is supported for now."
+            )
+        self.act_fn = SiluAndMul()
+        self.layer_id = layer_id
+
+        self.intermediate_size = intermediate_size
+        self.tp_size = get_tensor_model_parallel_world_size()
+
+        self.gate_multiplier, self.down_multiplier = mlp_multipliers
+
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
+        gate_up, _ = self.gate_up_proj(x)
+        gate_up[:, : self.intermediate_size // self.tp_size] *= self.gate_multiplier
+
+        x = self.act_fn(gate_up)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
+        x = x * self.down_multiplier
+        return x
+
+
+class FalconH1HybridAttentionDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        layer_id: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        alt_stream: Optional[torch.cuda.Stream] = None,
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.attn_tp_rank = get_attention_tp_rank()
+        self.attn_tp_size = get_attention_tp_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = config.num_attention_heads
+        assert self.total_num_heads % self.attn_tp_size == 0
+        self.num_heads = self.total_num_heads // self.attn_tp_size
+        self.total_num_kv_heads = config.num_key_value_heads
+        if self.total_num_kv_heads >= self.attn_tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % self.attn_tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.attn_tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.attn_tp_size)
+        self.head_dim = config.head_dim or (self.hidden_size // self.num_heads)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = getattr(config, "rope_theta", 10000)
+        self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.rope_scaling = getattr(config, "rope_scaling", None)
+        self.partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.layer_id = layer_id
+
+        self.rotary_emb = get_rope(
+            head_size=self.head_dim,
+            rotary_dim=self.head_dim,
+            max_position=self.max_position_embeddings,
+            rope_scaling=self.rope_scaling,
+            base=self.rope_theta,
+            partial_rotary_factor=self.partial_rotary_factor,
+            is_neox_style=True,
+            dtype=torch.get_default_dtype(),  # see impl of get_rope
+        )
+
+        self.qkv_proj = QKVParallelLinear(
+            config.hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=False,
+            quant_config=quant_config,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            config.hidden_size,
+            bias=False,
+            quant_config=quant_config,
+            reduce_results=False,
+            tp_rank=self.attn_tp_rank,
+            tp_size=self.attn_tp_size,
+        )
+
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            prefix=f"{prefix}.attn",
+        )
+
+        self.d_ssm = (
+            int(config.mamba_expand * config.hidden_size)
+            if config.mamba_d_ssm is None
+            else config.mamba_d_ssm
+        )
+
+        self.mamba = MambaMixer2(
+            hidden_size=config.hidden_size,
+            ssm_state_size=config.mamba_d_state,
+            conv_kernel_size=config.mamba_d_conv,
+            intermediate_size=self.d_ssm,
+            use_conv_bias=config.mamba_conv_bias,
+            use_bias=config.mamba_proj_bias,
+            n_groups=config.mamba_n_groups,
+            num_heads=config.mamba_n_heads,
+            layer_id=layer_id,
+            head_dim=config.mamba_d_head,
+            rms_norm_eps=config.rms_norm_eps,
+            chunk_size=config.mamba_chunk_size,
+            activation=config.hidden_act,
+            use_rms_norm=config.mamba_rms_norm,
+            prefix=f"{prefix}.mixer",
+        )
+
+        # FalconH1 all layers are sparse and have no nextn now
+        self.is_layer_sparse = False
+        is_previous_layer_sparse = False
+
+        self.layer_scatter_modes = LayerScatterModes.init_new(
+            layer_id=layer_id,
+            num_layers=config.num_hidden_layers,
+            is_layer_sparse=self.is_layer_sparse,
+            is_previous_layer_sparse=is_previous_layer_sparse,
+        )
+
+        self.feed_forward = FalconH1MLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            layer_id=layer_id,
+            mlp_multipliers=config.mlp_multipliers,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.pre_ff_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+        self.layer_communicator = LayerCommunicator(
+            layer_scatter_modes=self.layer_scatter_modes,
+            input_layernorm=self.input_layernorm,
+            post_attention_layernorm=self.pre_ff_layernorm,
+            allow_reduce_scatter=True,
+        )
+
+        self.alt_stream = alt_stream
+        self.key_multiplier = config.key_multiplier
+
+        self.ssm_out_multiplier = config.ssm_out_multiplier
+        self.ssm_in_multiplier = config.ssm_in_multiplier
+
+        self.attention_in_multiplier = config.attention_in_multiplier
+        self.attn_out_multiplier = config.attention_out_multiplier
+
+        self.groups_time_state_size = self.mamba.n_groups * config.mamba_d_state
+        self.zxbcdt_multipliers = config.ssm_multipliers
+        self._init_mup_vector()
+
+    def _init_mup_vector(self):
+        """
+        Non learnable per-block scaling vector composed of element-wise
+        multipliersapplied to each separate contiguous block of the output
+        of the linear projection (in_proj) before further processing
+        (gating, convolution, SSM):
+
+        - Z block: [0 : d_ssm] → zxbcdt_multipliers[0]
+        - X block: [d_ssm : 2 * d_ssm] → zxbcdt_multipliers[1]
+        - B block: [2 * d_ssm : 2 * d_ssm + G * S] → zxbcdt_multipliers[2]
+        - C block: [2 * d_ssm + G * S : 2 * d_ssm + 2 * G * S]
+          → zxbcdt_multipliers[3]
+        - dt block: [2 * d_ssm + 2 * G * S : end] → zxbcdt_multipliers[4]
+
+        where:
+        - d_ssm: Dimension of state-space model latent
+        - G: Number of groups (n_groups)
+        - S: SSM state size per group
+        - All indices are divided by tp_size to support tensor parallelism
+        """
+        vector_shape = (
+            2 * self.d_ssm + 2 * self.groups_time_state_size + self.config.mamba_n_heads
+        ) // self.tp_size
+        mup_vector = torch.ones(1, vector_shape)
+        # Z vector 0 -> d_ssm
+        mup_vector[:, : self.d_ssm // self.tp_size] *= self.zxbcdt_multipliers[0]
+        # X vector d_ssm -> 2 * d_ssm
+        mup_vector[
+            :, (self.d_ssm // self.tp_size) : (2 * self.d_ssm // self.tp_size)
+        ] *= self.zxbcdt_multipliers[1]
+        # B vector 2 * d_ssm -> 2 * d_ssm + (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm)
+            // self.tp_size : (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[2]
+        # C vector 2 * d_ssm + (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state)
+        mup_vector[
+            :,
+            (2 * self.d_ssm + self.groups_time_state_size)
+            // self.tp_size : (2 * self.d_ssm + 2 * self.groups_time_state_size)
+            // self.tp_size,
+        ] *= self.zxbcdt_multipliers[3]
+        # dt vector 2 * d_ssm + 2 * (n_group * d_state)
+        # -> 2 * d_ssm + 2 * (n_group * d_state) + n_heads
+        mup_vector[
+            :,
+            (2 * self.d_ssm + 2 * self.groups_time_state_size) // self.tp_size :,
+        ] *= self.zxbcdt_multipliers[4]
+
+        self.register_buffer("mup_vector", mup_vector, persistent=False)
+
+    def self_attention(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        k = k * self.key_multiplier
+        q, k = self.rotary_emb(positions, q, k)
+
+        attn_output = self.attn(q, k, v, forward_batch)
+
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        residual: Optional[torch.Tensor],
+        forward_batch: ForwardBatch,
+        **kwargs: Any,
+    ):
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
+        )
+
+        if not forward_batch.forward_mode.is_idle():
+            # Attention block
+            attention_hidden_states = self.self_attention(
+                positions=positions,
+                hidden_states=hidden_states * self.attention_in_multiplier,
+                forward_batch=forward_batch,
+            )
+            attention_hidden_states = attention_hidden_states * self.attn_out_multiplier
+
+            # Mamba block
+            mamba_hidden_states = torch.empty_like(hidden_states)
+            self.mamba(
+                hidden_states * self.ssm_in_multiplier,
+                mamba_hidden_states,
+                forward_batch=forward_batch,
+                mup_vector=self.mup_vector,
+            )
+            mamba_hidden_states = mamba_hidden_states * self.ssm_out_multiplier
+
+            hidden_states = attention_hidden_states + mamba_hidden_states
+
+        # Fully Connected
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+        hidden_states = self.feed_forward(
+            hidden_states, forward_batch, use_reduce_scatter
+        )
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
+
+ALL_DECODER_LAYER_TYPES = {
+    "falcon_h1": FalconH1HybridAttentionDecoderLayer,
+}
+
+
+class FalconH1Model(nn.Module):
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+
+        alt_stream = torch.cuda.Stream() if _is_cuda else None
+        self.embedding_multiplier = config.embedding_multiplier
+
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size,
+            config.hidden_size,
+            org_num_embeddings=config.vocab_size,
+            enable_tp=not is_dp_attention_enabled(),
+        )
+
+        def get_layer(idx: int, prefix: str):
+            layer_class = ALL_DECODER_LAYER_TYPES[config.layers_block_type[idx]]
+            return layer_class(
+                config,
+                idx,
+                quant_config=quant_config,
+                prefix=prefix,
+                alt_stream=alt_stream,
+            )
+
+        self.layers = make_layers(
+            config.num_hidden_layers, get_layer, prefix=f"{prefix}.layers"
+        )
+
+        self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.infer_count = 0
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        # mamba_cache_params: MambaCacheParams,
+        inputs_embeds: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+
+        # pass a sequence index tensor, that is required for
+        # proper continuous batching computation including
+        # chunked prefill
+        if inputs_embeds is not None:
+            hidden_states = inputs_embeds * self.embedding_multiplier
+        else:
+            hidden_states = self.embed_tokens(input_ids) * self.embedding_multiplier
+
+        residual = None
+        for i in range(len(self.layers)):
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                layer_id=i,
+                positions=positions,
+                hidden_states=hidden_states,
+                residual=residual,
+                forward_batch=forward_batch,
+            )
+
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                hidden_states = self.final_layernorm(hidden_states)
+            else:
+                hidden_states, _ = self.final_layernorm(hidden_states, residual)
+
+        return hidden_states
+
+
+class HybridLayerType(enum.Enum):
+    full_attention = "attention"
+    swa_attention = "swa_attention"
+    linear_attention = "linear_attention"
+    mamba2 = "mamba"
+
+
+class FalconH1ForCausalLM(nn.Module):
+    fall_back_to_pt_during_load = False
+
+    def __init__(
+        self,
+        config: FalconH1Config,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.pp_group = get_pp_group()
+        assert self.pp_group.is_first_rank and self.pp_group.is_last_rank
+        self.quant_config = quant_config
+        self.model = FalconH1Model(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                org_num_embeddings=config.vocab_size,
+                prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            )
+        self.lm_head = self.lm_head.float()
+        self.lm_head_multiplier = config.lm_head_multiplier
+        self.logits_processor = LogitsProcessor(
+            config, logit_scale=self.lm_head_multiplier
+        )
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        hidden_states = self.model(input_ids, positions, forward_batch, inputs_embeds)
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ) -> Set[str]:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
+        for name, loaded_weight in weights:
+
+            if "rotary_emb.inv_freq" in name:
+                continue
+
+            if ".self_attn." in name:
+                name = name.replace(".self_attn", "")
+
+            if "A_log" in name:
+                name = name.replace("A_log", "A")
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip layers on other devices.
+                # if is_pp_missing_parameter(name, self):
+                # continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader")
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # if is_pp_missing_parameter(name, self):
+                # continue
+
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+
+                weight_loader(param, loaded_weight)
+
+            loaded_params.add(name)
+        return loaded_params
+
+
+EntryClass = FalconH1ForCausalLM
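As a reading aid only (not part of the diff), here is a minimal standalone sketch of the block-wise muP scaling that `_init_mup_vector` above builds for the mamba in_proj output. The block layout (Z, X, B, C, dt) comes from the docstring in the diff; the dimensions, multiplier values, and `tp_size = 1` are illustrative assumptions, and the multipliers stand in for `config.ssm_multipliers`.

```python
import torch

# Toy version of FalconH1HybridAttentionDecoderLayer._init_mup_vector (assumed sizes, tp_size = 1).
d_ssm, n_groups, d_state, n_heads = 8, 2, 4, 4
z_m, x_m, b_m, c_m, dt_m = 1.0, 0.5, 2.0, 2.0, 0.25  # stand-ins for config.ssm_multipliers

gs = n_groups * d_state               # G * S, the grouped SSM state width
width = 2 * d_ssm + 2 * gs + n_heads  # total width of the in_proj output
mup = torch.ones(1, width)
mup[:, :d_ssm] *= z_m                                 # Z block
mup[:, d_ssm : 2 * d_ssm] *= x_m                      # X block
mup[:, 2 * d_ssm : 2 * d_ssm + gs] *= b_m             # B block
mup[:, 2 * d_ssm + gs : 2 * d_ssm + 2 * gs] *= c_m    # C block
mup[:, 2 * d_ssm + 2 * gs :] *= dt_m                  # dt block
print(mup.shape)  # torch.Size([1, 36])
```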
sglang/srt/models/gemma3_causal.py  CHANGED
@@ -20,7 +20,6 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import (
     ROPE_INIT_FUNCTIONS,
-    AutoModel,
     Gemma3TextConfig,
     PretrainedConfig,
     PreTrainedModel,
@@ -761,4 +760,3 @@ class Gemma3ForCausalLM(PreTrainedModel):
 
 
 EntryClass = Gemma3ForCausalLM
-AutoModel.register(Gemma3TextConfig, Gemma3ForCausalLM, exist_ok=True)
sglang/srt/models/gemma3_mm.py  CHANGED
@@ -23,7 +23,6 @@ import torch
 from torch import nn
 from transformers import Gemma3Config, PreTrainedModel
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import Gemma3RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -44,6 +43,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3_causal import Gemma3ForCausalLM
 from sglang.srt.models.siglip import SiglipVisionModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/gemma3n_mm.py  CHANGED
@@ -14,7 +14,6 @@ from transformers import (
 )
 from transformers.models.auto.modeling_auto import AutoModel
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -38,6 +37,7 @@ from sglang.srt.model_loader.weight_utils import (
 from sglang.srt.models.gemma3n_audio import Gemma3nAudioEncoder
 from sglang.srt.models.gemma3n_causal import Gemma3nRMSNorm, Gemma3nTextModel
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/glm4_moe.py  CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Inference-only GLM-4.5 model compatible with HuggingFace weights"""
+"""Inference-only GLM-4.5, GLM-4.6 model compatible with HuggingFace weights"""
 
 import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
@@ -785,9 +785,9 @@ class Glm4MoeForCausalLM(DeepseekV2ForCausalLM):
             or self.config.architectures[0] != architecture
             or self.config.n_shared_experts != 1
         ):
-            disable_reason = "Only GLM-4.5 on NV-platform with capability >= 80 can use shared experts fusion optimization."
+            disable_reason = "Only GLM-4.5 or GLM-4.6 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
-            disable_reason = "Deepseek and GLM-4.5 can not use shared experts fusion optimization under expert parallelism."
+            disable_reason = "Deepseek and GLM-4.5 or GLM-4.6 can not use shared experts fusion optimization under expert parallelism."
 
         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
sglang/srt/models/glm4_moe_nextn.py  CHANGED
@@ -12,7 +12,7 @@
 # limitations under the License.
 # ==============================================================================
 
-"""Inference-only GLM-4.5 NextN Speculative Decoding."""
+"""Inference-only GLM-4.5, GLM-4.6 NextN Speculative Decoding."""
 import logging
 from typing import Iterable, Optional, Tuple
 
@@ -48,7 +48,7 @@ class Glm4MoeModelNextN(nn.Module):
         super().__init__()
         if quant_config is not None and quant_config.get_name() == "modelopt_fp4":
             logger.warning(
-                "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 model."
+                "Overriding Glm4MoeForCausalLMNextN quant config for modelopt_fp4 GLM-4.5 / GLM-4.6 model."
             )
             quant_config = None
 
sglang/srt/models/glm4v.py  CHANGED
@@ -7,7 +7,6 @@ import torch.nn as nn
 import torch.nn.functional as F
 from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisionConfig
 
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.layernorm import RMSNorm
@@ -28,6 +27,7 @@ from sglang.srt.models.qwen2_5_vl import (
     Qwen2_5_VLForConditionalGeneration,
 )
 from sglang.srt.utils import add_prefix
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 logger = logging.getLogger(__name__)
 
sglang/srt/models/glm4v_moe.py  CHANGED
@@ -10,7 +10,6 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_tensor_model_parallel_world_size,
 )
-from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
@@ -22,6 +21,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.glm4_moe import Glm4MoeModel
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration, Glm4vVisionModel
 from sglang.srt.utils import add_prefix, is_cuda, log_info_on_rank0
+from sglang.srt.utils.hf_transformers_utils import get_processor
 
 _is_cuda = is_cuda()
 
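A recurring change across the model files above is the relocation of `hf_transformers_utils` from `sglang.srt` into `sglang.srt.utils` (see the `sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py}` entry in the file list). For external code that imports `get_processor` directly, a hypothetical compatibility shim could look like the sketch below; both module paths are taken from this diff, while the try/except fallback itself is an assumption and not part of the package.

```python
# Hypothetical shim for code outside sglang (not part of this package):
# prefer the 0.5.3rc2 module layout, fall back to the older one.
try:
    from sglang.srt.utils.hf_transformers_utils import get_processor
except ImportError:
    from sglang.srt.hf_transformers_utils import get_processor
```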