sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/weight_utils.py
@@ -22,7 +22,6 @@ from typing import (
 )
 
 import filelock
-import gguf
 import huggingface_hub.constants
 import numpy as np
 import safetensors.torch
@@ -93,7 +92,7 @@ def convert_bin_to_safetensor_file(
     pt_filename: str,
     sf_filename: str,
 ) -> None:
-    loaded = torch.load(pt_filename, map_location="cpu")
+    loaded = torch.load(pt_filename, map_location="cpu", weights_only=True)
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
     shared = _shared_pointers(loaded)
@@ -381,7 +380,7 @@ def np_cache_weights_iterator(
                 disable=not enable_tqdm,
                 bar_format=_BAR_FORMAT,
             ):
-                state = torch.load(bin_file, map_location="cpu")
+                state = torch.load(bin_file, map_location="cpu", weights_only=True)
                 for name, param in state.items():
                     param_path = os.path.join(np_folder, name)
                     with open(param_path, "wb") as f:
@@ -464,6 +463,8 @@ def pt_weights_iterator(
 def get_gguf_extra_tensor_names(
     gguf_file: str, gguf_to_hf_name_map: Dict[str, str]
 ) -> List[str]:
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
     expected_gguf_keys = set(gguf_to_hf_name_map.keys())
     exact_gguf_keys = set([tensor.name for tensor in reader.tensors])
@@ -479,6 +480,8 @@ def gguf_quant_weights_iterator(
     them to torch tensors
     """
 
+    import gguf
+
     reader = gguf.GGUFReader(gguf_file)
 
     for tensor in reader.tensors:
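
Together with the removal of the module-level import gguf in the first hunk, these two changes make gguf an optional dependency that is only imported when a GGUF checkpoint is actually read. A minimal sketch of the same deferred-import pattern (the helper name is hypothetical, not part of the diff):

def read_gguf_tensor_names(gguf_file: str) -> list:
    # Import lazily so the module can be imported without the optional
    # dependency; the cost is only paid on the GGUF code path.
    try:
        import gguf
    except ImportError as e:
        raise ImportError(
            "Loading GGUF checkpoints requires the `gguf` package (pip install gguf)."
        ) from e
    reader = gguf.GGUFReader(gguf_file)
    return [tensor.name for tensor in reader.tensors]
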
@@ -585,6 +588,51 @@ def composed_weight_loader(
     return composed_loader
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
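
The new set_runai_streamer_env helper copies the integer keys "concurrency" and "memory_limit" from model_loader_extra_config into the RUNAI_STREAMER_* environment variables read by the Run:ai model streamer, and falls back to AWS_ENDPOINT_URL for the S3 endpoint when RUNAI_STREAMER_S3_ENDPOINT is unset. A minimal sketch of that mapping, assuming LoadConfig accepts model_loader_extra_config as a keyword argument (not verified against this exact version):

import os

from sglang.srt.configs.load_config import LoadConfig
from sglang.srt.model_loader.weight_utils import set_runai_streamer_env

# Illustrative values; non-integer values are ignored by the helper.
load_config = LoadConfig(
    model_loader_extra_config={"concurrency": 16, "memory_limit": 5 * 2**30}
)
set_runai_streamer_env(load_config)

print(os.environ.get("RUNAI_STREAMER_CONCURRENCY"))   # "16"
print(os.environ.get("RUNAI_STREAMER_MEMORY_LIMIT"))  # "5368709120"
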

sglang/srt/models/clip.py (new file)
@@ -0,0 +1,563 @@
+# Adapted from
+# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/clip/modeling_clip.py
+
+from functools import partial
+from typing import Iterable, List, Optional, Tuple, Type, Union
+
+import torch
+import torch.nn as nn
+from transformers import CLIPConfig, CLIPTextConfig, CLIPVisionConfig
+from transformers.modeling_attn_mask_utils import _create_4d_causal_attention_mask
+
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.schedule_batch import MultimodalInputs
+from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import add_prefix
+
+
+class CLIPVisionEmbeddings(nn.Module):
+
+    def __init__(self, config: CLIPVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+        assert self.image_size % self.patch_size == 0
+
+        self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            bias=False,
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches + 1
+        self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        batch_size = pixel_values.shape[0]
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [*, width, grid, grid]
+        patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
+
+        class_embeds = self.class_embedding.expand(batch_size, 1, -1)
+        embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+
+class CLIPTextEmbeddings(nn.Module):
+    def __init__(self, config: CLIPTextConfig):
+        super().__init__()
+        embed_dim = config.hidden_size
+
+        self.token_embedding = nn.Embedding(config.vocab_size, embed_dim)
+        self.position_embedding = nn.Embedding(
+            config.max_position_embeddings, embed_dim
+        )
+
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids",
+            torch.arange(config.max_position_embeddings).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+    ) -> torch.Tensor:
+        seq_length = (
+            input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
+        )
+
+        if position_ids is None:
+            position_ids = self.position_ids[:, :seq_length]
+
+        if inputs_embeds is None:
+            inputs_embeds = self.token_embedding(input_ids)
+
+        position_embeddings = self.position_embedding(position_ids)
+        embeddings = inputs_embeds + position_embeddings
+
+        return embeddings
+
+
+class CLIPMLP(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+class CLIPEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        attn_implementation: Optional[str] = "sdpa",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layer_norm1 = norm_layer(config.hidden_size)
+        self.layer_norm2 = norm_layer(config.hidden_size)
+        if attn_implementation == "sdpa":
+            use_context_forward = False
+            softmax_in_single_precision = False
+        elif attn_implementation == "flash_attention_2":
+            softmax_in_single_precision = False
+            use_context_forward = True
+        elif attn_implementation == "eager":
+            softmax_in_single_precision = True
+            use_context_forward = False
+        self.self_attn = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            use_context_forward=use_context_forward,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+        self.mlp = CLIPMLP(
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        hidden_states = self.layer_norm1(hidden_states)
+        # CLIP text model uses both `causal_attention_mask` and `attention_mask`
+        if attention_mask is not None and causal_attention_mask is not None:
+            attn_mask = attention_mask + causal_attention_mask
+        elif causal_attention_mask is not None:
+            attn_mask = causal_attention_mask
+        else:
+            attn_mask = attention_mask
+        hidden_states = self.self_attn(
+            hidden_states,
+            attention_mask=attn_mask,
+            # causal_attention_mask=causal_attention_mask,
+        )
+
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+class CLIPEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self
+    attention layers. Each layer is a [`CLIPEncoderLayer`].
+
+    Args:
+        config: CLIPConfig
+    """
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        num_hidden_layers = config.num_hidden_layers
+        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layers = nn.ModuleList(
+            [
+                CLIPEncoderLayer(
+                    config=config,
+                    norm_layer=norm_layer,
+                    attn_implementation="sdpa",
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        causal_attention_mask: torch.Tensor = None,
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = [inputs_embeds]
+        hidden_states = inputs_embeds
+
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, causal_attention_mask
+            )
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        if return_all_hidden_states:
+            return hidden_states_pool
+        return hidden_states
+
+
+class CLIPTextTransformer(nn.Module):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        embed_dim = config.hidden_size
+        self.embeddings = CLIPTextEmbeddings(config)
+        self.encoder = CLIPEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+        self.final_layer_norm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.Tensor] = None,
+    ):
+        input_shape = input_ids.size()
+        input_ids = input_ids.view(-1, input_shape[-1])
+        hidden_states = self.embeddings(input_ids, position_ids)
+        causal_attention_mask = _create_4d_causal_attention_mask(
+            input_ids.shape, hidden_states.dtype, device=hidden_states.device
+        )
+        encoder_outputs = self.encoder(
+            hidden_states, attention_mask, causal_attention_mask
+        )
+        last_hidden_state = self.final_layer_norm(encoder_outputs)
+        return last_hidden_state
+
+
+class CLIPTextModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPTextConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.text_model = CLIPTextTransformer(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("text_model", prefix),
+        )
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+    ):
+        return self.text_model(input_ids, position_ids)
+
+
+class CLIPVisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = CLIPVisionEmbeddings(config)
+
+        # NOTE: This typo of "layrnorm" is not fixed on purpose to match
+        # the original transformers code and name of the model weights.
+        self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+        self.encoder = CLIPEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+
+        hidden_states = self.embeddings(pixel_values.to(self.device))
+        hidden_states = self.pre_layrnorm(hidden_states)
+
+        return_all_hidden_states = False
+
+        last_hidden_state = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        return last_hidden_state
+
+
+class CLIPVisionModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.vision_model = CLIPVisionTransformer(
+            config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+
+    def forward(self, pixel_values: torch.Tensor):
+        return self.vision_model(pixel_values)
+
+
+class CLIPModel(nn.Module):
+    def __init__(
+        self,
+        config: CLIPConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        if not isinstance(config.text_config, CLIPTextConfig):
+            raise TypeError(
+                "config.text_config is expected to be of type CLIPTextConfig but is of type"
+                f" {type(config.text_config)}."
+            )
+
+        if not isinstance(config.vision_config, CLIPVisionConfig):
+            raise TypeError(
+                "config.vision_config is expected to be of type CLIPVisionConfig but is of type"
+                f" {type(config.vision_config)}."
+            )
+
+        text_config = config.text_config
+        vision_config = config.vision_config
+
+        self.projection_dim = config.projection_dim
+        self.text_embed_dim = text_config.hidden_size
+        self.vision_embed_dim = vision_config.hidden_size
+        self.visual_projection = nn.Linear(
+            self.vision_embed_dim, self.projection_dim, bias=False
+        )
+        self.text_projection = nn.Linear(
+            self.text_embed_dim, self.projection_dim, bias=False
+        )
+        self.logit_scale = nn.Parameter(
+            torch.tensor(self.config.logit_scale_init_value)
+        )
+
+        text_model = CLIPTextModel(
+            text_config, quant_config, prefix=add_prefix("text_model", prefix)
+        )
+        vision_model = CLIPVisionModel(
+            vision_config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+        self.text_model = text_model.text_model
+        self.vision_model = vision_model.vision_model
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        monkey_patch_weight_loader()
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = True,
+    ):
+        assert get_embedding, "CLIPEmbeddingModel is only used for embedding"
+        image_inputs = None
+        if forward_batch.mm_inputs is not None:
+            image_inputs = forward_batch.mm_inputs
+
+        if image_inputs is not None and image_inputs[0] is not None:
+            vision_outputs = self.vision_model(image_inputs[0].pixel_values)
+            pooled_output = vision_outputs[:, 0, :]
+            image_embeds = self.visual_projection(pooled_output)
+            image_embeds = nn.functional.normalize(image_embeds, p=2, dim=1)
+            return EmbeddingPoolerOutput(embeddings=image_embeds)
+
+        else:
+            text_outputs = self.text_model(input_ids, position_ids=positions)
+            pooled_output = self.pooler(text_outputs[0], forward_batch)
+            return EmbeddingPoolerOutput(
+                embeddings=self.text_projection(pooled_output.embeddings)
+            )
+
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
+        # Clip embeddings models handle text/image separately, so we don't need to pad input ids
+        return input_ids
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+        ]
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if "position_ids" in name:
+                continue
+            if "out_proj" in name:
+                name = name.replace("out_proj", "proj")
+            for param_name, shard_name, shard_id in stacked_params_mapping:
+                if shard_name not in name:
+                    continue
+                name = name.replace(shard_name, param_name)
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+
+# monkey patch weight loader to remove open_clip file
+def monkey_patch_weight_loader():
+    import glob
+    import os
+
+    from sglang.srt.model_loader.loader import DefaultModelLoader
+    from sglang.srt.model_loader.weight_utils import (
+        download_weights_from_hf,
+        filter_files_not_needed_for_inference,
+    )
+
+    def prepare_weights(
+        self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
+    ) -> Tuple[str, List[str], bool]:
+        model_name_or_path = (
+            self._maybe_download_from_modelscope(model_name_or_path, revision)
+            or model_name_or_path
+        )
+
+        is_local = os.path.isdir(model_name_or_path)
+        use_safetensors = False
+        allow_patterns = ["*.bin"]
+
+        if not is_local:
+            hf_folder = download_weights_from_hf(
+                model_name_or_path,
+                self.load_config.download_dir,
+                allow_patterns,
+                revision,
+                ignore_patterns=self.load_config.ignore_patterns,
+            )
+        else:
+            hf_folder = model_name_or_path
+
+        hf_weights_files: List[str] = []
+        for pattern in allow_patterns:
+            hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
+
+        hf_weights_files = filter_files_not_needed_for_inference(hf_weights_files)
+
+        # remove open_clip file
+        hf_weights_files = [
+            file for file in hf_weights_files if "open_clip" not in file
+        ]
+
+        if len(hf_weights_files) == 0:
+            raise RuntimeError(
+                f"Cannot find any model weights with `{model_name_or_path}`"
+            )
+
+        return hf_folder, hf_weights_files, use_safetensors
+
+    setattr(DefaultModelLoader, "_prepare_weights", prepare_weights)
+
+
+EntryClass = CLIPModel
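
EntryClass = CLIPModel registers the model with sglang's loader; the forward above only supports get_embedding=True and returns L2-normalized CLIP projections, using the CLS token through visual_projection for image input or the pooled text hidden state through text_projection for text input. A rough usage sketch against a server started in embedding mode; the launch flags and request shape follow sglang's OpenAI-compatible /v1/embeddings route and are assumptions, not verified against this exact version:

# Assumed launch:
#   python -m sglang.launch_server --model-path openai/clip-vit-large-patch14-336 --is-embedding
import requests

resp = requests.post(
    "http://127.0.0.1:30000/v1/embeddings",
    json={
        "model": "openai/clip-vit-large-patch14-336",
        "input": "a photo of a cat",
    },
)
embedding = resp.json()["data"][0]["embedding"]
print(len(embedding))  # the checkpoint's projection_dim, e.g. 768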