sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py:

@@ -22,7 +22,6 @@ from typing import (
 )
 
 import filelock
-import gguf
 import huggingface_hub.constants
 import numpy as np
 import safetensors.torch
@@ -93,7 +92,7 @@ def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
 ) -> None:
-    loaded = torch.load(pt_filename, map_location="cpu")
+    loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
     shared = _shared_pointers(loaded)
@@ -381,7 +380,7 @@ def np_cache_weights_iterator(
                 disable=not enable_tqdm,
                 bar_format=_BAR_FORMAT,
             ):
-                state = torch.load(bin_file, map_location="cpu")
+                state = torch.load(bin_file, map_location="cpu", weights_only=True)
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
                     with open(param_path, "wb") as f:
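Note: the only change in the two hunks above is passing `weights_only=True` to `torch.load`, which swaps in a restricted unpickler that reconstructs only tensors, primitive types, and plain containers rather than arbitrary pickled objects. A minimal, self-contained sketch of the behavior (the file name is illustrative):

```python
import torch

# Save a plain state dict, then reload it with the restricted unpickler.
torch.save({"state_dict": {"w": torch.zeros(2, 2)}}, "ckpt.pt")
loaded = torch.load("ckpt.pt", map_location="cpu", weights_only=True)
print(loaded["state_dict"]["w"].shape)  # torch.Size([2, 2])

# A checkpoint that pickles arbitrary Python objects (e.g. a custom class instance)
# would raise an UnpicklingError here instead of running its pickled constructor.
```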
@@ -464,6 +463,8 @@ def pt_weights_iterator(
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
     expected_gguf_keys = set(gguf_to_hf_name_map.keys())
     exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
@@ -479,6 +480,8 @@ def gguf_quant_weights_iterator(
     them to torch tensors
     """
 
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
 
     for tensor in reader.tensors:
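Together with the removal of the module-level `import gguf` in the first hunk, the two hunks above move the import inside the only functions that read GGUF files, so `gguf` is no longer required just to import the module. A distilled sketch of the same lazy-import pattern (the function name here is illustrative, not part of the package):

```python
def list_gguf_tensor_names(gguf_file: str) -> list:
    # Lazily import the optional dependency: a missing `gguf` package now fails
    # only when a GGUF checkpoint is actually loaded, not at module import time.
    import gguf

    reader = gguf.GGUFReader(gguf_file)
    return [tensor.name for tensor in reader.tensors]
```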
@@ -585,6 +588,51 @@ def composed_weight_loader(
     return composed_loader
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
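For reference, the new `set_runai_streamer_env` helper above only translates entries of `model_loader_extra_config` into the environment variables read by the Run:ai model streamer. A minimal, self-contained sketch of that mapping, using a `SimpleNamespace` as a stand-in for the real `LoadConfig` (the stand-in and example values are assumptions for illustration only):

```python
import os
from types import SimpleNamespace

# Hypothetical stand-in for sglang.srt.configs.load_config.LoadConfig.
load_config = SimpleNamespace(
    model_loader_extra_config={"concurrency": 8, "memory_limit": 2 << 30}
)

extra = load_config.model_loader_extra_config
if isinstance(extra.get("concurrency"), int):
    os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(extra["concurrency"])
if isinstance(extra.get("memory_limit"), int):
    os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(extra["memory_limit"])

print(os.environ["RUNAI_STREAMER_CONCURRENCY"])  # "8"
```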
sglang/srt/models/clip.py (new file):

@@ -0,0 +1,563 @@
+# Adapted from
+# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/clip/modeling_clip.py
+
+from functools import partial
+from typing import Iterable, List, Optional, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask
+
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
+
+
+class CLIPVisionEmbeddings(nn.Module):
+
+    def __init__(self, config: CLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        assert self.image_size % self.patch_size == 0
+
+        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size = pixel_values.shape[0]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+
+class CLIPTextEmbeddings(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        embed_dim = config.hidden_size
+
+        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(
+            config.max_position_embeddings, embed_dim
+        )
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = (
+            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        )
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.token_embedding(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+class CLIPMLP(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+class CLIPEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        attn_implementation: Optional[str] = "sdpa",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layer_norm1 = norm_layer(config.hidden_size)
+        self.layer_norm2 = norm_layer(config.hidden_size)
+        if attn_implementation == "sdpa":
+            use_context_forward = False
+            softmax_in_single_precision = False
+        elif attn_implementation == "flash_attention_2":
+            softmax_in_single_precision = False
+            use_context_forward = True
+        elif attn_implementation == "eager":
+            softmax_in_single_precision = True
+            use_context_forward = False
+        self.self_attn = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            use_context_forward=use_context_forward,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+        self.mlp = CLIPMLP(
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        hidden_states = self.layer_norm1(hidden_states)
+        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
+        if attention_mask is not None and causal_attention_mask is not None:
+            attn_mask = attention_mask + causal_attention_mask
+        elif causal_attention_mask is not None:
+            attn_mask = causal_attention_mask
+        else:
+            attn_mask = attention_mask
+        hidden_states = self.self_attn(
+            hidden_states,
+            attention_mask=attn_mask,
+            # causal_attention_mask=causal_attention_mask,
+        )
+
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class CLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self
+    attention layers. Each layer is a [`CLIPEncoderLayer`].
+
+    Args:
+        config: CLIPConfig
+    """
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        num_hidden_layers = config.num_hidden_layers
+        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layers = nn.ModuleList(
+            [
+                CLIPEncoderLayer(
+                    config=config,
+                    norm_layer=norm_layer,
+                    attn_implementation="sdpa",
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        causal_attention_mask: torch.Tensor = None,
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = [inputs_embeds]
+        hidden_states = inputs_embeds
+
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, causal_attention_mask
+            )
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        if return_all_hidden_states:
+            return hidden_states_pool
+        return hidden_states
+
+
+class CLIPTextTransformer(nn.Module):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+        self.embeddings = CLIPTextEmbeddings(config)
+        self.encoder = CLIPEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        hidden_states = self.embeddings(input_ids, position_ids)
+        causal_attention_mask = _create_4d_causal_attention_mask(
+            input_ids.shape, hidden_states.dtype, device=hidden_states.device
+        )
+        encoder_outputs = self.encoder(
+            hidden_states, attention_mask, causal_attention_mask
+        )
+        last_hidden_state = self.final_layer_norm(encoder_outputs)
+        return last_hidden_state
+
+
+class CLIPTextModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.text_model = CLIPTextTransformer(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("text_model", prefix),
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ):
+        return self.text_model(input_ids, position_ids)
+
+
+class CLIPVisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = CLIPVisionEmbeddings(config)
+
+        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
+        # the original transformers code and name of the model weights.
+        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+        self.encoder = CLIPEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+
+        hidden_states = self.embeddings(pixel_values.to(self.device))
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        return_all_hidden_states = False
+
+        last_hidden_state = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        return last_hidden_state
+
+
+class CLIPVisionModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.vision_model = CLIPVisionTransformer(
+            config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+
+    def forward(self, pixel_values: torch.Tensor):
+        return self.vision_model(pixel_values)
+
+
+class CLIPModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        if not isinstance(config.text_config, CLIPTextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type CLIPTextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.vision_config, CLIPVisionConfig):
+            raise TypeError(
+                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
+                f" {type(config.vision_config)}."
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        self.projection_dim = config.projection_dim
+        self.text_embed_dim = text_config.hidden_size
+        self.vision_embed_dim = vision_config.hidden_size
+        self.visual_projection = nn.Linear(
+            self.vision_embed_dim, self.projection_dim, bias=False
+        )
+        self.text_projection = nn.Linear(
+            self.text_embed_dim, self.projection_dim, bias=False
+        )
+        self.logit_scale = nn.Parameter(
+            torch.tensor(self.config.logit_scale_init_value)
+        )
+
+        text_model = CLIPTextModel(
+            text_config, quant_config, prefix=add_prefix("text_model", prefix)
+        )
+        vision_model = CLIPVisionModel(
+            vision_config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+        self.text_model = text_model.text_model
+        self.vision_model = vision_model.vision_model
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        monkey_patch_weight_loader()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = True,
+    ):
+        assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
+        image_inputs = None
+        if forward_batch.mm_inputs is not None:
+            image_inputs = forward_batch.mm_inputs
+
+        if image_inputs is not None and image_inputs[0] is not None:
+            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+            pooled_output = vision_outputs[:, 0, :]
+            image_embeds = self.visual_projection(pooled_output)
+            image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
+            return EmbeddingPoolerOutput(embeddings=image_embeds)
+
+        else:
+            text_outputs = self.text_model(input_ids, position_ids=positions)
+            pooled_output = self.pooler(text_outputs[0], forward_batch)
+            return EmbeddingPoolerOutput(
+                embeddings=self.text_projection(pooled_output.embeddings)
+            )
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Clip embeddings models handle text/image separately, so we don't need to pad input ids
+        return input_ids
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "position_ids" in name:
+                continue
+            if "out_proj" in name:
+                name = name.replace("out_proj", "proj")
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+# monkey patch weight loader to remove open_clip file
+def monkey_patch_weight_loader():
+    import glob
+    import os
+
+    from sglang.srt.model_loader.loader import DefaultModelLoader
+    from sglang.srt.model_loader.weight_utils import (
+        download_weights_from_hf,
+        filter_files_not_needed_for_inference,
+    )
+
+    def prepare_weights(
+        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+    ) -> Tuple[str, List[str], bool]:
+        model_name_or_path = (
+            self._maybe_download_from_modelscope(model_name_or_path, revision)
+            or model_name_or_path
+        )
+
+        is_local = os.path.isdir(model_name_or_path)
+        use_safetensors = False
+        allow_patterns = ["*.bin"]
+
+        if not is_local:
+            hf_folder = download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
+        else:
+            hf_folder = model_name_or_path
+
+        hf_weights_files: List[str] = []
+        for pattern in allow_patterns:
+            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
+
+        # remove open_clip file
+        hf_weights_files = [
+            file for file in hf_weights_files if "open_clip" not in file
+        ]
+
+        if len(hf_weights_files) == 0:
+            raise RuntimeError(
+                f"Cannot find any model weights with `{model_name_or_path}`"
+            )
+
+        return hf_folder, hf_weights_files, use_safetensors
+
+    setattr(DefaultModelLoader, "_prepare_weights", prepare_weights)
+
+
+EntryClass = CLIPModel
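The `load_weights` method in the new `CLIPModel` folds the checkpoint's separate q/k/v projection weights into the fused `qkv_proj` parameter and renames `out_proj` to `proj`. A small, self-contained sketch of just that name-remapping logic (the example checkpoint keys are illustrative, not taken from a real CLIP checkpoint):

```python
stacked_params_mapping = [
    # (param_name, shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def remap(name: str):
    """Resolve a checkpoint key the way CLIPModel.load_weights does."""
    if "out_proj" in name:
        name = name.replace("out_proj", "proj")
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            return name.replace(shard_name, param_name), shard_id
    return name, None


print(remap("vision_model.encoder.layers.0.self_attn.q_proj.weight"))
# ('vision_model.encoder.layers.0.self_attn.qkv_proj.weight', 'q')
print(remap("text_model.encoder.layers.0.self_attn.out_proj.weight"))
# ('text_model.encoder.layers.0.self_attn.proj.weight', None)
```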