sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +59 -2
- sglang/api.py +40 -11
- sglang/backend/anthropic.py +17 -3
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +160 -12
- sglang/backend/runtime_endpoint.py +62 -27
- sglang/backend/vertexai.py +1 -0
- sglang/bench_latency.py +320 -0
- sglang/global_config.py +24 -3
- sglang/lang/chat_template.py +122 -6
- sglang/lang/compiler.py +2 -2
- sglang/lang/interpreter.py +206 -98
- sglang/lang/ir.py +98 -34
- sglang/lang/tracer.py +6 -4
- sglang/launch_server.py +4 -1
- sglang/launch_server_llavavid.py +32 -0
- sglang/srt/constrained/__init__.py +14 -6
- sglang/srt/constrained/fsm_cache.py +9 -2
- sglang/srt/constrained/jump_forward.py +113 -24
- sglang/srt/conversation.py +4 -2
- sglang/srt/flush_cache.py +18 -0
- sglang/srt/hf_transformers_utils.py +144 -3
- sglang/srt/layers/context_flashattention_nopad.py +1 -0
- sglang/srt/layers/extend_attention.py +20 -1
- sglang/srt/layers/fused_moe.py +596 -0
- sglang/srt/layers/logits_processor.py +190 -61
- sglang/srt/layers/radix_attention.py +62 -53
- sglang/srt/layers/token_attention.py +21 -9
- sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
- sglang/srt/managers/controller/dp_worker.py +113 -0
- sglang/srt/managers/controller/infer_batch.py +908 -0
- sglang/srt/managers/controller/manager_multi.py +195 -0
- sglang/srt/managers/controller/manager_single.py +177 -0
- sglang/srt/managers/controller/model_runner.py +359 -0
- sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
- sglang/srt/managers/controller/schedule_heuristic.py +65 -0
- sglang/srt/managers/controller/tp_worker.py +813 -0
- sglang/srt/managers/detokenizer_manager.py +42 -40
- sglang/srt/managers/io_struct.py +44 -10
- sglang/srt/managers/tokenizer_manager.py +224 -82
- sglang/srt/memory_pool.py +52 -59
- sglang/srt/model_config.py +97 -2
- sglang/srt/models/chatglm.py +399 -0
- sglang/srt/models/commandr.py +369 -0
- sglang/srt/models/dbrx.py +406 -0
- sglang/srt/models/gemma.py +34 -38
- sglang/srt/models/gemma2.py +436 -0
- sglang/srt/models/grok.py +738 -0
- sglang/srt/models/llama2.py +47 -37
- sglang/srt/models/llama_classification.py +107 -0
- sglang/srt/models/llava.py +92 -27
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/minicpm.py +366 -0
- sglang/srt/models/mixtral.py +302 -127
- sglang/srt/models/mixtral_quant.py +372 -0
- sglang/srt/models/qwen.py +40 -35
- sglang/srt/models/qwen2.py +33 -36
- sglang/srt/models/qwen2_moe.py +473 -0
- sglang/srt/models/stablelm.py +33 -39
- sglang/srt/models/yivl.py +19 -26
- sglang/srt/openai_api_adapter.py +411 -0
- sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
- sglang/srt/sampling_params.py +2 -0
- sglang/srt/server.py +197 -481
- sglang/srt/server_args.py +190 -74
- sglang/srt/utils.py +460 -95
- sglang/test/test_programs.py +73 -10
- sglang/test/test_utils.py +226 -7
- sglang/utils.py +97 -27
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
- sglang-0.1.21.dist-info/RECORD +82 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
- sglang/srt/backend_config.py +0 -13
- sglang/srt/managers/router/infer_batch.py +0 -503
- sglang/srt/managers/router/manager.py +0 -79
- sglang/srt/managers/router/model_rpc.py +0 -686
- sglang/srt/managers/router/model_runner.py +0 -514
- sglang/srt/managers/router/scheduler.py +0 -70
- sglang-0.1.14.dist-info/RECORD +0 -64
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
- {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,298 @@
|
|
1
|
+
"""Inference-only LLaVa video model compatible with HuggingFace weights."""
|
2
|
+
|
3
|
+
from typing import Iterable, List, Optional, Tuple
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import torch
|
7
|
+
from torch import nn
|
8
|
+
from transformers import CLIPVisionModel, LlavaConfig
|
9
|
+
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
10
|
+
from vllm.config import CacheConfig
|
11
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
12
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
13
|
+
|
14
|
+
from sglang.srt.managers.controller.infer_batch import ForwardMode
|
15
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
16
|
+
from sglang.srt.mm_utils import (
|
17
|
+
get_anyres_image_grid_shape,
|
18
|
+
unpad_image,
|
19
|
+
unpad_image_shape,
|
20
|
+
)
|
21
|
+
from sglang.srt.models.llama2 import LlamaForCausalLM
|
22
|
+
|
23
|
+
|
24
|
+
class LlavaVidForCausalLM(nn.Module):
    """LLaVA-style video model built from a CLIP vision tower and a Llama LM.

    Per-frame CLIP patch features are average-pooled (``self.resampler``),
    projected by ``multi_modal_projector`` and written into the Llama input
    embeddings at the positions reserved by ``pad_input_ids``.
    """

    def __init__(
        self,
        config: LlavaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        # The CLIP tower is created in `load_weights`, not here.
        self.vision_tower = None
        self.config.vision_config.hidden_size = config.mm_hidden_size
        self.config.text_config.hidden_size = config.hidden_size
        self.multi_modal_projector = LlavaMultiModalProjector(config)
        # Kernel/stride of the average pooling applied to each frame's patch grid.
        self.mm_spatial_pool_stride = getattr(self.config, "mm_spatial_pool_stride", 2)
        self.resampler = nn.AvgPool2d(
            kernel_size=self.mm_spatial_pool_stride, stride=self.mm_spatial_pool_stride
        )
        self.language_model = LlamaForCausalLM(config, quant_config=quant_config)
        self.num_frames = getattr(self.config, "num_frames", 16)
        if "unpad" in getattr(config, "mm_patch_merge_type", ""):
            self.language_model.model.image_newline = nn.Parameter(
                torch.empty(config.text_config.hidden_size, dtype=torch.float16)
            )

    def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
        """Replace the first image token in ``input_ids`` with image padding ids.

        Only a fixed-size feature block of ``self.image_feature_len`` tokens
        (computed in ``load_weights``) is supported; ``pt_shape`` and
        ``image_size`` are accepted for interface compatibility but unused.

        Args:
            input_ids: list of token ids containing ``config.image_token_index``.
            pad_value: list of ids repeated (and truncated) to fill the block.
        Returns:
            ``(new_input_ids, offset)`` where ``offset`` is the index at which the
            image token was replaced.
        """
        new_image_feature_len = self.image_feature_len

        # Repeat pad_value enough times to cover the block, then trim below.
        pad_ids = pad_value * (
            (new_image_feature_len + len(pad_value)) // len(pad_value)
        )
        offset = input_ids.index(self.config.image_token_index)
        # old_len + pad_len - 1, because the image token itself is removed.
        new_input_ids = (
            input_ids[:offset]
            + pad_ids[:new_image_feature_len]
            + input_ids[offset + 1 :]
        )
        return new_input_ids, offset

    def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
        """Encode frames with CLIP, pool spatially, and project to LM space.

        Args:
            pixel_values: batch for the vision tower; dim 0 is treated as the
                number of frames.
        Returns:
            Tensor of shape (num_frames, pooled_tokens, projector_out_dim).
        Raises:
            ValueError: if ``vision_feature_select_strategy`` is unsupported.
        """
        # NOTE: not memory efficient; output_hidden_states=True keeps every
        # layer's hidden states.
        image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
        selected_image_feature = image_outputs.hidden_states[self.vision_feature_layer]
        if self.vision_feature_select_strategy in ["default", "patch"]:
            # Drop the leading CLS token.
            selected_image_feature = selected_image_feature[:, 1:]
        elif self.vision_feature_select_strategy == "full":
            pass
        else:
            # Report the attribute actually checked above (the config does not
            # necessarily carry `vision_feature_select_strategy`).
            raise ValueError(
                f"Unexpected select feature strategy: {self.vision_feature_select_strategy}"
            )

        height = width = self.num_patches_per_side
        num_of_frames = selected_image_feature.shape[0]
        # (frames, tokens, C) -> (frames, C, H, W) so AvgPool2d pools spatially.
        selected_image_feature = selected_image_feature.view(
            num_of_frames, height, width, -1
        )
        selected_image_feature = selected_image_feature.permute(0, 3, 1, 2).contiguous()
        # Pool, then back to (frames, pooled_tokens, C).
        selected_image_feature = (
            self.resampler(selected_image_feature)
            .flatten(2)
            .transpose(1, 2)
            .contiguous()
        )

        return self.multi_modal_projector(selected_image_feature)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        pixel_values: Optional[List[Optional[np.array]]] = None,
        image_sizes: Optional[List[List[int]]] = None,
        image_offsets: Optional[List[int]] = None,
    ) -> torch.Tensor:
        """Run the LM, splicing image features into the embeddings on EXTEND.

        On ``ForwardMode.EXTEND``, requests whose extend segment starts before
        ``image_feature_len`` and that carry pixels get their image features
        written into ``input_embeds`` at ``extend_start_loc[i] + image_offsets[i]``.
        On ``ForwardMode.DECODE`` the LM is called directly. ``image_sizes`` is
        accepted for interface compatibility but unused.
        """
        if input_metadata.forward_mode == ForwardMode.EXTEND:
            bs = input_metadata.batch_size

            # Embed text input
            input_embeds = self.language_model.model.embed_tokens(input_ids)

            # Embed vision input
            need_vision = (
                (positions[input_metadata.extend_start_loc] < self.image_feature_len)
                .cpu()
                .numpy()
            )
            # FIXME: We need to subtract the length of the system prompt
            has_pixel = np.array([pixel_values[i] is not None for i in range(bs)])
            need_vision = need_vision & has_pixel

            if need_vision.any():
                pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]

                ########## Encode Image ########
                if pixel_values[0].ndim == 4:
                    # Video input: each entry is 4-D and its first dim is
                    # treated as the frame count. Encode all frames in one call,
                    # then split back per request.
                    concat_images = torch.tensor(
                        np.concatenate(pixel_values, axis=0),
                        device=self.vision_tower.device,
                    )
                    image_features = self.encode_images(concat_images)
                    split_sizes = [image.shape[0] for image in pixel_values]
                    image_features = torch.split(image_features, split_sizes, dim=0)
                else:
                    # One 3-D image per request (stacked into a batch). Add a
                    # single-frame axis so every entry is (frames, tokens, dim)
                    # like the video branch; otherwise the flatten below would
                    # collapse (tokens, dim) into a 1-D vector.
                    pixel_values = torch.tensor(
                        np.array(pixel_values), device=self.vision_tower.device
                    )
                    image_features = self.encode_images(pixel_values).unsqueeze(1)

                # Merge frame and token axes: (frames * tokens, dim) per request.
                image_features = [feature.flatten(0, 1) for feature in image_features]

                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
                pt = 0  # index into image_features (only requests needing vision)
                for i in range(bs):
                    if not need_vision[i]:
                        continue

                    start_idx = extend_start_loc_cpu[i]
                    pad_len, pad_dim = image_features[pt].shape
                    dim = input_embeds.shape[1]
                    assert (
                        pad_dim == dim
                    ), "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
                    # Fill in the placeholder for the image
                    begin = start_idx + image_offsets[i]
                    try:
                        input_embeds[begin : begin + pad_len] = image_features[pt]
                    except RuntimeError as e:
                        # Best-effort: keep the text embeddings and report.
                        print(f"RuntimeError in llava image encoding: {e}")
                        print(input_embeds.shape)
                        print(start_idx, image_offsets[i])
                    pt += 1

            return self.language_model(
                input_ids, positions, input_metadata, input_embeds=input_embeds
            )
        elif input_metadata.forward_mode == ForwardMode.DECODE:
            return self.language_model(input_ids, positions, input_metadata)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load the CLIP tower, projector and language-model weights.

        The CLIP vision model is loaded from ``config.mm_vision_tower`` (a
        HuggingFace name or a path) in fp16 on CUDA; projector / vision-tower
        tensors from ``weights`` are renamed via ``projector_weights``; all
        weights are then forwarded to the language model's loader.

        Raises:
            ValueError: if ``mm_vision_select_feature`` is not "patch" or
                "cls_patch".
        """
        vision_path = self.config.mm_vision_tower
        self.vision_tower = CLIPVisionModel.from_pretrained(
            vision_path, torch_dtype=torch.float16
        ).cuda()
        self.vision_tower.eval()

        self.vision_feature_layer = self.config.mm_vision_select_layer
        self.vision_feature_select_strategy = self.config.mm_vision_select_feature
        self.image_size = self.vision_tower.config.image_size
        self.patch_size = self.vision_tower.config.patch_size

        self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
        self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
        self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)

        print(f"target_frames: {self.num_frames}")
        # Tokens per frame after pooling, times the number of frames.
        self.image_feature_len = self.num_frames * int(
            (self.image_size / self.patch_size / self.mm_spatial_pool_stride) ** 2
        )
        if self.vision_feature_select_strategy == "patch":
            pass
        elif self.vision_feature_select_strategy == "cls_patch":
            self.image_feature_len += 1
        else:
            # Previously referenced the nonexistent `self.select_feature`,
            # which raised AttributeError instead of this ValueError.
            raise ValueError(
                f"Unexpected select feature: {self.vision_feature_select_strategy}"
            )

        # load mm_projector
        projector_weights = {
            "model.mm_projector.0": "multi_modal_projector.linear_1",
            "model.mm_projector.2": "multi_modal_projector.linear_2",
            "model.vision_resampler.mm_projector.0": "multi_modal_projector.linear_1",
            "model.vision_resampler.mm_projector.2": "multi_modal_projector.linear_2",
            # Update the vision tower weights if found in the checkpoint (it may be finetuned).
            "model.vision_tower.vision_tower": "vision_tower",
        }
        params_dict = dict(self.named_parameters())
        weights = list(weights)  # consumed twice: here and by the language model
        for name, loaded_weight in weights:
            # FIXME: why projector weights read two times?
            if "projector" in name or "vision_tower" in name:
                for weight_name, param_name in projector_weights.items():
                    if weight_name in name:
                        name = name.replace(weight_name, param_name)
                if name in params_dict:
                    param = params_dict[name]
                else:
                    print(f"Warning: {name} not found in the model")
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        # load language model
        self.language_model.load_weights(weights)

        monkey_path_clip_vision_embed_forward()

    @property
    def num_patches_per_side(self):
        """Number of CLIP patches along one side of the (square) input image."""
        return self.image_size // self.patch_size
|
264
|
+
|
265
|
+
|
266
|
+
# Kept for backward compatibility; set to False after the first patched call.
first_call = True


def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
    """Replacement ``CLIPVisionEmbeddings.forward`` that runs the patch conv on CPU.

    The patch-embedding convolution is moved to CPU in fp32 to avoid a bug in
    torch >= 2.1 on A10G. Its output is then moved to the device and dtype of
    ``self.class_embedding`` (CUDA fp16 for a tower loaded in fp16 and moved
    to CUDA). The class token and position embeddings are added there.

    Args:
        self: the ``CLIPVisionEmbeddings`` instance whose forward is patched.
        pixel_values: image batch; dim 0 is the batch size.
    Returns:
        Embeddings of shape (batch, 1 + num_patches, embed_dim).
    """
    global first_call
    batch_size = pixel_values.shape[0]

    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
    # Checked per instance rather than only on the first global call, so a
    # second patched CLIP model is also moved instead of failing with a
    # device mismatch.
    weight = self.patch_embedding.weight
    if weight.device.type != "cpu" or weight.dtype != torch.float32:
        self.patch_embedding.cpu().float()
    first_call = False

    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
    patch_embeds = self.patch_embedding(pixel_values)
    patch_embeds = patch_embeds.to(
        device=self.class_embedding.device, dtype=self.class_embedding.dtype
    )

    # (B, C, H', W') -> (B, H'*W', C)
    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)

    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
    embeddings = embeddings + self.position_embedding(self.position_ids)
    return embeddings
|
286
|
+
|
287
|
+
|
288
|
+
def monkey_path_clip_vision_embed_forward():
    """Install ``clip_vision_embed_forward`` as ``CLIPVisionEmbeddings.forward``.

    The class attribute in ``transformers`` itself is replaced, so every
    ``CLIPVisionEmbeddings`` instance in the process uses the patched forward.
    """
    import transformers

    embeddings_cls = transformers.models.clip.modeling_clip.CLIPVisionEmbeddings
    embeddings_cls.forward = clip_vision_embed_forward
|
296
|
+
|
297
|
+
|
298
|
+
# Module-level alias for the model class defined in this file; presumably the
# model loader looks up `EntryClass` by name — confirm against the registry.
EntryClass = LlavaVidForCausalLM
|
@@ -0,0 +1,366 @@
|
|
1
|
+
"""Inference-only MiniCPM model compatible with HuggingFace weights."""
|
2
|
+
|
3
|
+
import math
|
4
|
+
from typing import Any, Dict, Iterable, Optional, Tuple
|
5
|
+
|
6
|
+
import torch
|
7
|
+
from torch import nn
|
8
|
+
from vllm.config import CacheConfig
|
9
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
10
|
+
from vllm.model_executor.layers.activation import SiluAndMul
|
11
|
+
from vllm.model_executor.layers.layernorm import RMSNorm
|
12
|
+
from vllm.model_executor.layers.linear import (
|
13
|
+
MergedColumnParallelLinear,
|
14
|
+
QKVParallelLinear,
|
15
|
+
RowParallelLinear,
|
16
|
+
)
|
17
|
+
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
|
18
|
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
19
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
20
|
+
ParallelLMHead,
|
21
|
+
VocabParallelEmbedding,
|
22
|
+
)
|
23
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
24
|
+
|
25
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
26
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
27
|
+
from sglang.srt.managers.controller.model_runner import InputMetadata
|
28
|
+
|
29
|
+
|
30
|
+
class MiniCPMMLP(nn.Module):
    """Gated feed-forward block of a MiniCPM decoder layer.

    The gate and up projections are fused into one column-parallel matmul of
    width ``2 * intermediate_size``. They are combined by ``SiluAndMul`` and
    projected back to ``hidden_size`` by a row-parallel linear layer. Only the
    ``"silu"`` activation is accepted.
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        fused_widths = [intermediate_size] * 2
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size,
            fused_widths,
            bias=False,
            quant_config=quant_config,
        )
        self.down_proj = RowParallelLinear(
            intermediate_size,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )
        if hidden_act == "silu":
            self.act_fn = SiluAndMul()
        else:
            raise ValueError(
                f"Unsupported activation: {hidden_act}. "
                "Only silu is supported for now."
            )

    def forward(self, x):
        """Apply down(act(gate_up(x))); the linear layers' second output is ignored."""
        fused, _ = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _ = self.down_proj(activated)
        return projected
|
63
|
+
|
64
|
+
|
65
|
+
class MiniCPMAttention(nn.Module):
    """Tensor-parallel self-attention with rotary embeddings for MiniCPM.

    Query heads are split evenly across TP ranks. KV heads are partitioned
    when there are at least as many as ranks, otherwise replicated. Rotary
    embeddings are applied in fp32 and the result is cast back to the input
    dtype.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        layer_id: int = 0,
        rope_theta: float = 10000,
        rope_scaling: Optional[Dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        # Query heads must divide evenly across tensor-parallel ranks.
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size

        # KV heads: partition when there are enough of them, else replicate.
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads < tp_size:
            assert tp_size % self.total_num_kv_heads == 0
        else:
            assert self.total_num_kv_heads % tp_size == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)

        head_dim = hidden_size // self.total_num_heads
        self.head_dim = head_dim
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            head_dim,
            rotary_dim=head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        # Recompute the cos/sin cache so rope stays in fp32 instead of bf16.
        self.rotary_emb.cos_sin_cache = self.rotary_emb._compute_cos_sin_cache()
        self.attn = RadixAttention(
            self.num_heads,
            head_dim,
            self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """Project to q/k/v, apply fp32 rotary, attend, and project out."""
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        compute_dtype = q.dtype
        q, k = self.rotary_emb(positions, q.float(), k.float())
        q = q.to(compute_dtype)
        k = k.to(compute_dtype)
        context = self.attn(q, k, v, input_metadata)
        projected, _ = self.o_proj(context)
        return projected
|
147
|
+
|
148
|
+
|
149
|
+
class MiniCPMDecoderLayer(nn.Module):
    """One MiniCPM transformer block (pre-norm attention + pre-norm MLP).

    Each sub-block's output is scaled by
    ``config.scale_depth / sqrt(config.num_hidden_layers)`` before being added
    to the residual stream.
    """

    def __init__(
        self,
        config,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.self_attn = MiniCPMAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            layer_id=layer_id,
            rope_theta=getattr(config, "rope_theta", 10000),
            rope_scaling=getattr(config, "rope_scaling", None),
            max_position_embeddings=getattr(config, "max_position_embeddings", 8192),
            quant_config=quant_config,
        )
        self.mlp = MiniCPMMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        eps = config.rms_norm_eps
        self.input_layernorm = RMSNorm(config.hidden_size, eps=eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        input_metadata: InputMetadata,
        residual: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(new_hidden_states, None)``; the ``residual`` argument is unused."""
        branch_scale = self.config.scale_depth / math.sqrt(
            self.config.num_hidden_layers
        )

        # Self attention on the normalized input, added back to the stream.
        attn_out = self.self_attn(
            positions=positions,
            hidden_states=self.input_layernorm(hidden_states),
            input_metadata=input_metadata,
        )
        hidden_states = hidden_states + attn_out * branch_scale

        # Fully connected block, same residual pattern.
        mlp_out = self.mlp(self.post_attention_layernorm(hidden_states))
        hidden_states = hidden_states + mlp_out * branch_scale

        return hidden_states, None
|
211
|
+
|
212
|
+
|
213
|
+
class MiniCPMModel(nn.Module):
    """MiniCPM backbone: token embedding, decoder layers, final RMSNorm.

    When no ``input_embeds`` are given, token embeddings are multiplied by
    ``config.scale_emb``. Caller-provided embeddings are used as-is.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = VocabParallelEmbedding(
            self.vocab_size,
            config.hidden_size,
            org_num_embeddings=config.vocab_size,
        )
        decoder_layers = [
            MiniCPMDecoderLayer(config, layer_idx, quant_config=quant_config)
            for layer_idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Return the final normalized hidden states."""
        if input_embeds is not None:
            hidden_states = input_embeds
        else:
            hidden_states = self.embed_tokens(input_ids) * self.config.scale_emb

        residual = None
        for layer in self.layers:
            hidden_states, residual = layer(
                positions, hidden_states, input_metadata, residual
            )
        return self.norm(hidden_states)
|
259
|
+
|
260
|
+
|
261
|
+
class MiniCPMForCausalLM(nn.Module):
    """MiniCPM causal LM: ``MiniCPMModel`` backbone plus an optional LM head.

    Caller-provided ``input_embeds`` are scaled by ``config.scale_emb``. The
    final hidden states are divided by ``hidden_size / dim_model_base`` before
    logits are computed. With ``tie_word_embeddings`` the token-embedding
    matrix doubles as the output projection and no ``lm_head`` is created.
    """

    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,
        cache_config: Optional[CacheConfig] = None,
    ) -> None:
        super().__init__()
        self.config = config

        self.num_experts = getattr(self.config, "num_experts", 0)
        self.quant_config = quant_config
        self.model = MiniCPMModel(config, quant_config=quant_config)
        if not self.config.tie_word_embeddings:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
            )

        self.scale_width = self.config.hidden_size / self.config.dim_model_base

        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the backbone and return the logits processor's output."""
        if input_embeds is not None:
            # The backbone only scales embeddings it computes itself.
            input_embeds = input_embeds * self.config.scale_emb
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        hidden_states = hidden_states / self.scale_width
        if self.config.tie_word_embeddings:
            lm_head_weight = self.model.embed_tokens.weight
        else:
            lm_head_weight = self.lm_head.weight
        return self.logits_processor(
            input_ids, hidden_states, lm_head_weight, input_metadata
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors, fusing q/k/v and gate/up shards.

        Rotary caches are skipped. So is ``lm_head.weight`` when embeddings are
        tied, because no such parameter exists in that case. Extra GPTQ biases
        with no matching parameter are ignored.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        expert_params_mapping = [
            # (param_name, weight_name, expert_id)
            (
                "ws" if weight_name in ["w1", "w3"] else "w2s",
                f"experts.{expert_id}.{weight_name}.weight",
                expert_id,
            )
            for expert_id in range(self.num_experts)
            for weight_name in ["w1", "w2", "w3"]
        ]
        params_dict = dict(self.named_parameters())
        tie_word_embeddings = self.config.tie_word_embeddings
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if tie_word_embeddings and "lm_head.weight" in name:
                # Tied checkpoints may still ship lm_head.weight; there is no
                # lm_head parameter to receive it (previously a KeyError).
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for param_name, weight_name, expert_id in expert_params_mapping:
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param, loaded_weight, weight_name, expert_id=expert_id
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if name.endswith(".bias") and name not in params_dict:
                        continue
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
|
364
|
+
|
365
|
+
|
366
|
+
# Module-level alias for the model class defined in this file; presumably the
# model loader looks up `EntryClass` by name — confirm against the registry.
EntryClass = MiniCPMForCausalLM
|