sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen3_next_mtp.py (new file)
@@ -0,0 +1,109 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+"""Inference-only Qwen3Next MTP Speculative Decoding."""
+import logging
+from typing import Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+from transformers import PretrainedConfig
+
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
+from sglang.srt.layers.layernorm import GemmaRMSNorm, RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen3_moe import Qwen3MoeModel
+from sglang.srt.models.qwen3_next import Qwen3NextForCausalLM, Qwen3NextModel
+from sglang.srt.utils import add_prefix
+
+logger = logging.getLogger(__name__)
+
+
+class Qwen3NextForCausalLMMTP(Qwen3NextForCausalLM):
+
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        nn.Module.__init__(self)
+        self.config = config
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.quant_config = quant_config
+        # if not set, model load will be broken in Qwen3NextForCausalLM load_weights()
+        self.pp_group = get_pp_group()
+        # self.determine_num_fused_shared_experts("Qwen3NextForCausalLMMTP")
+
+        # currently based on the provided ckpt, we:
+        # (1) do not use_dedicated_mtp_embeddings provided in ckpt since not provided and directly use the target model embeddings
+        # (2) hardcode bias=False since not provided
+        self.fc = nn.Linear(2 * config.hidden_size, config.hidden_size, bias=False)
+        RMSNorm_cls = GemmaRMSNorm
+        self.pre_fc_norm_embedding = RMSNorm_cls(
+            config.hidden_size, config.rms_norm_eps
+        )
+        self.pre_fc_norm_hidden = RMSNorm_cls(config.hidden_size, config.rms_norm_eps)
+        config.num_hidden_layers = 1
+        config.full_attention_interval = 1
+        self.model = Qwen3NextModel(
+            config, quant_config, prefix=add_prefix("model", prefix)
+        )
+        self.lm_head = ParallelLMHead(
+            config.vocab_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("model.shared_head.head", prefix),
+            use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+        )
+        self.logits_processor = LogitsProcessor(config)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: Optional[torch.Tensor] = None,
+        **kwargs,
+    ):
+        if input_embeds is None:
+            input_embeds = self.model.embed_tokens(input_ids)
+
+        input_embeds = self.pre_fc_norm_embedding(input_embeds)
+        hidden_states = self.pre_fc_norm_hidden(forward_batch.spec_info.hidden_states)
+        hidden_states = self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
+
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            hidden_states,
+        )
+
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
+
+    def load_weights(
+        self, weights: Iterable[Tuple[str, torch.Tensor]], is_mtp: bool = False
+    ):
+        super().load_weights(weights, is_mtp=True)
+
+
+EntryClass = [Qwen3NextForCausalLMMTP]
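The new draft head fuses each drafted token's embedding with the previous step's hidden state: both are RMS-normalized, concatenated, and projected back to `hidden_size` by the bias-free `fc` layer before entering the single-layer trunk. A minimal standalone sketch of that fusion step, with illustrative dimensions and PyTorch's built-in `nn.RMSNorm` (PyTorch >= 2.4) standing in for sglang's `GemmaRMSNorm`:

```python
import torch
from torch import nn

# Illustrative size only; the real value comes from the model config.
hidden_size = 8

# Same shape as the MTP head above: two RMS norms feeding a bias-free
# projection from 2 * hidden_size back to hidden_size.
norm_emb = nn.RMSNorm(hidden_size)
norm_hid = nn.RMSNorm(hidden_size)
fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

token_embeds = torch.randn(4, hidden_size)  # embeddings of the drafted tokens
prev_hidden = torch.randn(4, hidden_size)   # target model's last hidden states
fused = fc(torch.cat((norm_emb(token_embeds), norm_hid(prev_hidden)), dim=-1))
print(fused.shape)  # torch.Size([4, 8])
```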
sglang/srt/models/torch_native_llama.py
@@ -22,7 +22,7 @@ Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
 
 Here is a quick example to enable TP:
 ```python
-from sglang.srt.model_parallel import tensor_parallel
+from sglang.srt.layers.model_parallel import tensor_parallel
 
 device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
 tensor_parallel(model, device_mesh)
sglang/srt/models/transformers.py
@@ -213,7 +213,7 @@ class TransformersForCausalLM(nn.Module):
         """
         tp_plan = getattr(self.model.config, "base_model_tp_plan", None) or {}
 
-        if not tp_plan and
+        if not tp_plan and tp_size > 1:
             raise ValueError(
                 f"{type(self.model)} does not support tensor parallel yet!"
             )
sglang/srt/multimodal/processors/base_processor.py
@@ -13,7 +13,9 @@ from PIL import Image
 from transformers import BaseImageProcessorFast
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
-from sglang.srt.utils import load_audio, load_image, load_video, logger
+from sglang.srt.utils import is_npu, load_audio, load_image, load_video, logger
+
+_is_npu = is_npu()
 
 
 @dataclasses.dataclass
@@ -232,7 +234,7 @@ class BaseMultimodalProcessor(ABC):
             and isinstance(processor.image_processor, BaseImageProcessorFast)
             and not self.server_args.disable_fast_image_processor
         ):
-            kwargs["device"] = "cuda"
+            kwargs["device"] = "cuda" if not _is_npu else "npu"
         result = processor.__call__(
             text=[input_text],
             padding=True,
sglang/srt/multimodal/processors/glm4v.py
@@ -2,7 +2,6 @@ import re
 from typing import List, Union
 
 from decord import VideoReader
-from transformers.video_utils import VideoMetadata
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.models.glm4v import Glm4vForConditionalGeneration
@@ -66,17 +65,18 @@ class Glm4vImageProcessor(SGLangBaseProcessor):
         total_num_frames = len(vr)
         duration = total_num_frames / video_fps if video_fps else 0
 
-        metadata = VideoMetadata(
-            total_num_frames=int(total_num_frames),
-            fps=float(video_fps),
-            duration=float(duration),
-            video_backend="decord",
-        )
-
         # Extract all frames
         indices = list(range(total_num_frames))
         frames = vr.get_batch(indices).asnumpy()
-
+
+        # Return metadata as dict so transformers can properly create VideoMetadata objects
+        metadata = {
+            "total_num_frames": int(total_num_frames),
+            "fps": float(video_fps),
+            "duration": float(duration),
+            "video_backend": "decord",
+            "frames_indices": indices,
+        }
 
         return frames, metadata
 
sglang/srt/multimodal/processors/internvl.py
@@ -2,8 +2,10 @@
 
 import numpy as np
 import torch
-
+import torchvision.transforms as T
+from decord import VideoReader, cpu, gpu
 from PIL import Image
+from torchvision.transforms import InterpolationMode
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
 from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
@@ -48,99 +50,6 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             image_token_id=tokenizer.convert_tokens_to_ids(self.IMG_CONTEXT_TOKEN),
         ).build(_image_processor)
 
-    @staticmethod
-    def build_transform(input_size):
-        IMAGENET_MEAN = (0.485, 0.456, 0.406)
-        IMAGENET_STD = (0.229, 0.224, 0.225)
-
-        def resize_image(img, size):
-            return img.resize((size, size), Image.Resampling.BICUBIC)
-
-        def to_tensor(img):
-            # Convert PIL Image to numpy array
-            img_array = np.array(img).astype(np.float32) / 255.0
-            # Convert HWC to CHW format
-            img_array = img_array.transpose(2, 0, 1)
-            return torch.from_numpy(img_array)
-
-        def normalize(tensor, mean, std):
-            mean = torch.tensor(mean).view(-1, 1, 1)
-            std = torch.tensor(std).view(-1, 1, 1)
-            return (tensor - mean) / std
-
-        def transform(img):
-            img = img.convert("RGB") if img.mode != "RGB" else img
-            img = resize_image(img, input_size)
-            tensor = to_tensor(img)
-            tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
-            return tensor
-
-        return transform
-
-    @staticmethod
-    def dynamic_preprocess(
-        image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
-    ):
-
-        def find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, width, height, image_size
-        ):
-            best_ratio_diff = float("inf")
-            best_ratio = (1, 1)
-            area = width * height
-            for ratio in target_ratios:
-                target_aspect_ratio = ratio[0] / ratio[1]
-                ratio_diff = abs(aspect_ratio - target_aspect_ratio)
-                if ratio_diff < best_ratio_diff:
-                    best_ratio_diff = ratio_diff
-                    best_ratio = ratio
-                elif ratio_diff == best_ratio_diff:
-                    if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
-                        best_ratio = ratio
-            return best_ratio
-
-        orig_width, orig_height = image.size
-        aspect_ratio = orig_width / orig_height
-
-        # calculate the existing image aspect ratio
-        target_ratios = set(
-            (i, j)
-            for n in range(min_num, max_num + 1)
-            for i in range(1, n + 1)
-            for j in range(1, n + 1)
-            if i * j <= max_num and i * j >= min_num
-        )
-        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
-
-        # find the closest aspect ratio to the target
-        target_aspect_ratio = find_closest_aspect_ratio(
-            aspect_ratio, target_ratios, orig_width, orig_height, image_size
-        )
-
-        # calculate the target width and height
-        target_width = image_size * target_aspect_ratio[0]
-        target_height = image_size * target_aspect_ratio[1]
-        blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
-
-        # resize the image
-        resized_img = image.resize((target_width, target_height))
-        processed_images = []
-        for i in range(blocks):
-            box = (
-                (i % (target_width // image_size)) * image_size,
-                (i // (target_width // image_size)) * image_size,
-                ((i % (target_width // image_size)) + 1) * image_size,
-                ((i // (target_width // image_size)) + 1) * image_size,
-            )
-            # split the image
-            split_img = resized_img.crop(box)
-            processed_images.append(split_img)
-        assert len(processed_images) == blocks
-        if use_thumbnail and len(processed_images) != 1:
-            thumbnail_img = image.resize((image_size, image_size))
-            processed_images.append(thumbnail_img)
-        return processed_images
-
     @staticmethod
     def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
         if bound:
@@ -160,27 +69,112 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
 
     @staticmethod
    def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
-
+        try:
+            vr = VideoReader(video_path, ctx=gpu(0), num_threads=1)
+            use_gpu = True
+        except (RuntimeError, OSError) as e:
+            print(
+                f"[WARNING] Load video on gpu decoding failed: {e}. Falling back to CPU."
+            )
+            vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
+            use_gpu = False
+
         max_frame = len(vr) - 1
         fps = float(vr.get_avg_fps())
 
-        pixel_values_list
-
+        pixel_values_list = []
+        num_patches_list = []
         frame_indices = InternVLImageProcessor.get_index(
             bound, fps, max_frame, first_idx=0, num_segments=num_segments
         )
+
         for frame_index in frame_indices:
-
-
-
+            # Load frame
+            frame = vr[frame_index]
+            if use_gpu:
+                img = frame.cuda().permute(2, 0, 1).float() / 255.0
+            else:
+                img_np = frame.asnumpy()
+                img = torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+
+            # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+            mean = img.mean(dim=[1, 2], keepdim=True)
+            # Prevent division by zero; clamp to minimum value of 1e-6
+            std = img.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+            img = (img - mean) / std
+
+            tiles = InternVLImageProcessor.dynamic_preprocess(
+                img, image_size=input_size, max_num=max_num, use_thumbnail=True
             )
-
-
-            num_patches_list.append(
-
-        pixel_values = torch.cat(pixel_values_list)
+
+            pixel_values_list.append(tiles)
+            num_patches_list.append(tiles.shape[0])
+
+        pixel_values = torch.cat(pixel_values_list, dim=0)
         return pixel_values, num_patches_list
 
+    @staticmethod
+    def dynamic_preprocess(tensor, image_size=448, max_num=12, use_thumbnail=False):
+        C, H, W = tensor.shape
+        aspect_ratio = W / H
+
+        # Generate all possible aspect ratios
+        target_ratios = set(
+            (i, j)
+            for n in range(1, max_num + 1)
+            for i in range(1, n + 1)
+            for j in range(1, n + 1)
+            if i * j <= max_num
+        )
+        target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
+
+        # Find closest ratio
+        best_ratio_diff = float("inf")
+        best_ratio = (1, 1)
+
+        for x, y in target_ratios:
+            target_ar = x / y
+            diff = abs(aspect_ratio - target_ar)
+            blocks = x * y
+            best_blocks = best_ratio[0] * best_ratio[1]
+
+            if diff < best_ratio_diff:
+                best_ratio_diff = diff
+                best_ratio = (x, y)
+            elif diff == best_ratio_diff and blocks > best_blocks:
+                best_ratio = (x, y)
+
+        target_w, target_h = image_size * best_ratio[0], image_size * best_ratio[1]
+        blocks = best_ratio[0] * best_ratio[1]
+
+        # Resize on GPU
+        resized = torch.nn.functional.interpolate(
+            tensor.unsqueeze(0),
+            size=(target_h, target_w),
+            mode="bicubic",
+            align_corners=False,
+        ).squeeze(0)
+
+        # Split into tiles
+        tiles = []
+        for i in range(blocks):
+            x = (i % best_ratio[0]) * image_size
+            y = (i // best_ratio[0]) * image_size
+            tile = resized[:, y : y + image_size, x : x + image_size]
+            tiles.append(tile)
+
+        # Add thumbnail if needed
+        if use_thumbnail and len(tiles) > 1:
+            thumb = torch.nn.functional.interpolate(
+                tensor.unsqueeze(0),
+                size=(image_size, image_size),
+                mode="bicubic",
+                align_corners=False,
+            ).squeeze(0)
+            tiles.append(thumb)
+
+        return torch.stack(tiles).to(torch.bfloat16)
+
     async def process_mm_data_async(
         self, image_data, input_text, request_obj, **kwargs
     ):
@@ -191,53 +185,71 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             discard_alpha_channel=True,
         )
 
-        def process_image_internvl(image, input_size=448, max_num=12):
-            transform = InternVLImageProcessor.build_transform(input_size=input_size)
-            images = InternVLImageProcessor.dynamic_preprocess(
-                image, image_size=input_size, use_thumbnail=True, max_num=max_num
-            )
-            pixel_values = [transform(image) for image in images]
-            pixel_values = torch.stack(pixel_values)
-            return pixel_values
-
         num_patches_list = []
         pixel_values = []
+
         # Process each input with allocated frames
-        for image_index,
+        for image_index, image in enumerate(base_output.images):
             try:
                 # TODO: video input
-
-
-
-
-
-
-
-
+                # Convert PIL to GPU tensor
+                if isinstance(image, Image.Image):
+                    img_np = np.array(image.convert("RGB"))
+                    tensor = (
+                        torch.from_numpy(img_np).permute(2, 0, 1).cuda().float() / 255.0
+                    )
+                else:
+                    tensor = image.cuda()  # assume already tensor
+
+                # Using the mean and variance of the ImageNet dataset for all input images can lead to accuracy issues, while using the mean and variance of each input image is a more accurate choice.
+                mean = tensor.mean(dim=[1, 2], keepdim=True)
+                # Prevent division by zero; clamp to minimum value of 1e-6
+                std = tensor.std(dim=[1, 2], keepdim=True).clamp(min=1e-6)
+                tensor = (tensor - mean) / std
+                tiles = self.dynamic_preprocess(
+                    tensor, image_size=448, max_num=12, use_thumbnail=True
+                )
+
+                pixel_values.append(tiles)
+                num_patches_list.append(tiles.shape[0])
+
+            except Exception as e:
+                print(f"[Error] Failed to process image {image_index}: {e}")
                 return None
 
+        # Concatenate all
         pixel_values = torch.cat(pixel_values, dim=0)
 
         original_placeholder = "<<<__IMG_CONTEXT_PLACEHOLDER__>>>"
         input_text = input_text.replace(self.IMG_CONTEXT_TOKEN, original_placeholder)
 
-
+        input_text_updated = input_text
+        for num_patches in num_patches_list:
             image_tokens = (
                 self.IMG_START_TOKEN
                 + self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
                 + self.IMG_END_TOKEN
             )
-
+            input_text_updated = input_text_updated.replace(
+                original_placeholder, image_tokens, 1
+            )
 
-
+        input_text_updated = input_text_updated.replace(
+            original_placeholder, self.IMG_CONTEXT_TOKEN
+        )
 
-
+        # Tokenize
+        input_ids_tensor = self.tokenizer(input_text_updated, return_tensors="pt")[
             "input_ids"
         ].flatten()
+        input_ids = input_ids_tensor.tolist()
+
+        # Get image token offsets
         image_offsets = self.get_mm_items_offset(
-            input_ids=
+            input_ids=input_ids_tensor.to("cuda"),
             mm_token_id=self.mm_tokens.image_token_id,
         )
+
         items = [
             MultimodalDataItem(
                 feature=pixel_values,
@@ -247,7 +259,7 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         ]
 
         return {
-            "input_ids": input_ids
+            "input_ids": input_ids,
            "mm_items": items,
            "im_start_id": self.img_start_token_id,
            "im_end_id": self.img_end_token_id,
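The rewritten `dynamic_preprocess` above works on a normalized CHW tensor and splits it into `image_size` tiles after resizing to the closest grid aspect ratio. A self-contained sketch of the same tiling idea with a fixed grid (the `tile_image` helper and all shapes are illustrative, not part of the package):

```python
import torch
import torch.nn.functional as F

def tile_image(tensor: torch.Tensor, image_size: int = 448, grid=(3, 2)):
    """Resize a CHW tensor so it divides evenly into grid[0] x grid[1]
    tiles of image_size x image_size, then split it (illustrative helper)."""
    cols, rows = grid
    resized = F.interpolate(
        tensor.unsqueeze(0),
        size=(rows * image_size, cols * image_size),
        mode="bicubic",
        align_corners=False,
    ).squeeze(0)
    tiles = []
    for i in range(cols * rows):
        x = (i % cols) * image_size
        y = (i // cols) * image_size
        tiles.append(resized[:, y : y + image_size, x : x + image_size])
    return torch.stack(tiles)  # (cols * rows, C, image_size, image_size)

img = torch.rand(3, 800, 1300)  # fake RGB image, CHW
print(tile_image(img).shape)    # torch.Size([6, 3, 448, 448])
```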
sglang/srt/sampling/penaltylib/orchestrator.py
@@ -1,7 +1,8 @@
 from __future__ import annotations
 
 import abc
-
+import weakref
+from typing import TYPE_CHECKING, Optional, Set, Type
 
 import torch
 
@@ -17,7 +18,7 @@ class BatchedPenalizerOrchestrator:
         penalizers: Set[Type["_BatchedPenalizer"]],
     ):
         self.vocab_size = vocab_size
-        self.
+        self._batch_ref = weakref.ref(batch)
         self.device = batch.device
         self.penalizers = {Penalizer: Penalizer(self) for Penalizer in penalizers}
 
@@ -27,6 +28,17 @@ class BatchedPenalizerOrchestrator:
             is_required |= pen_is_required
         self.is_required = is_required
 
+    @property
+    def batch(self) -> ScheduleBatch | None:
+        return self._batch_ref()
+
+    @batch.setter
+    def batch(self, value: Optional[ScheduleBatch]):
+        if value is None:
+            self._batch_ref = lambda: None
+        else:
+            self._batch_ref = weakref.ref(value)
+
     def reqs(self):
         return self.batch.reqs
 
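The orchestrator change replaces the strong `self.batch` reference with a `weakref.ref` exposed through a property, so an orchestrator that outlives its batch no longer keeps the batch (and its requests) from being garbage-collected. A minimal sketch of the pattern, with hypothetical stand-in names:

```python
import gc
import weakref

class Batch:  # stand-in for ScheduleBatch
    pass

class Orchestrator:
    def __init__(self, batch: Batch):
        # Weak reference: this object no longer prolongs the batch's lifetime.
        self._batch_ref = weakref.ref(batch)

    @property
    def batch(self):
        # Dereference; returns None once the batch has been collected.
        return self._batch_ref()

b = Batch()
orch = Orchestrator(b)
assert orch.batch is b
del b
gc.collect()
print(orch.batch)  # None: the batch was freed even though the orchestrator survives
```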
sglang/srt/sampling/sampling_batch_info.py
@@ -67,28 +67,31 @@ class SamplingBatchInfo:
     logit_bias: Optional[torch.Tensor] = None
 
     @classmethod
-    def
+    def _get_global_server_args_dict(cls):
         from sglang.srt.managers.schedule_batch import global_server_args_dict
 
+        return global_server_args_dict
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        global_server_args_dict = cls._get_global_server_args_dict()
+
         reqs = batch.reqs
         device = batch.device
-        temperatures = (
-
-
-
-
-            .view(-1, 1)
-            .to(device, non_blocking=True)
-        )
+        temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
         top_ps = torch.tensor(
-            [r.sampling_params.top_p for r in reqs], dtype=torch.float
-        )
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
         top_ks = torch.tensor(
-            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
-        )
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32, device=device
+        )
         min_ps = torch.tensor(
-            [r.sampling_params.min_p for r in reqs], dtype=torch.float
-        )
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
 
         logit_bias = None
         if any(r.sampling_params.logit_bias is not None for r in reqs):
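The `SamplingBatchInfo` hunk swaps CPU-side construction followed by `.to(device, non_blocking=True)` for tensors allocated directly on the target device, which removes an intermediate host tensor and an explicit copy per field. A small sketch of the two styles (values illustrative):

```python
import torch

temperatures = [0.7, 1.0, 0.2]  # e.g. one value per request
device = "cuda" if torch.cuda.is_available() else "cpu"

# Before: build on CPU, then copy asynchronously to the device.
t_old = (
    torch.tensor(temperatures, dtype=torch.float)
    .view(-1, 1)
    .to(device, non_blocking=True)
)

# After: allocate directly on the device, skipping the host-side intermediate.
t_new = torch.tensor(temperatures, dtype=torch.float, device=device).view(-1, 1)

assert torch.equal(t_old.cpu(), t_new.cpu())
```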
|