sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/hf_transformers_utils.py +0 -149
- sglang/srt/layers/activation.py +93 -11
- sglang/srt/layers/layernorm.py +47 -4
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +15 -68
- sglang/srt/managers/io_struct.py +5 -4
- sglang/srt/managers/schedule_batch.py +20 -25
- sglang/srt/managers/tokenizer_manager.py +74 -61
- sglang/srt/managers/tp_worker.py +49 -43
- sglang/srt/model_executor/cuda_graph_runner.py +17 -31
- sglang/srt/model_executor/forward_batch_info.py +9 -26
- sglang/srt/model_executor/model_runner.py +20 -17
- sglang/srt/models/chatglm.py +13 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/gemma.py +3 -7
- sglang/srt/models/gemma2.py +2 -56
- sglang/srt/models/gpt_bigcode.py +2 -6
- sglang/srt/models/grok.py +10 -8
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama2.py +6 -11
- sglang/srt/models/llama_classification.py +2 -6
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/mixtral.py +1 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +2 -5
- sglang/srt/models/qwen2.py +5 -10
- sglang/srt/models/qwen2_moe.py +21 -24
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/openai_api/adapter.py +85 -4
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -74
- sglang/srt/sampling/sampling_params.py +4 -0
- sglang/srt/server.py +11 -4
- sglang/srt/utils.py +18 -33
- sglang/test/runners.py +2 -2
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/llavavid.py
CHANGED
```diff
@@ -26,11 +26,6 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.mm_utils import (
-    get_anyres_image_grid_shape,
-    unpad_image,
-    unpad_image_shape,
-)
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.models.llama2 import LlamaForCausalLM
 
@@ -59,23 +54,14 @@ class LlavaVidForCausalLM(nn.Module):
             torch.empty(config.text_config.hidden_size, dtype=torch.float16)
         )
 
-    def pad_input_ids(
+    def pad_input_ids(
+        self,
+        input_ids: List[int],
+        pad_value: List[int],
+        pixel_values: List,
+        image_sizes: List[List[int]],
+    ):
         new_image_feature_len = self.image_feature_len
-        # now only support spatial_unpad + anyres
-        # if self.mm_patch_merge_type.startswith("spatial"):
-        #     height = width = self.num_patches_per_side
-        #     if pt_shape[0] > 1:
-        #         if self.image_aspect_ratio == "anyres":
-        #             num_patch_width, num_patch_height = get_anyres_image_grid_shape(
-        #                 image_size,
-        #                 self.image_grid_pinpoints,
-        #                 self.vision_tower.config.image_size,
-        #             )
-        #         if "unpad" in self.mm_patch_merge_type:
-        #             h = num_patch_height * height
-        #             w = num_patch_width * width
-        #             new_h, new_w = unpad_image_shape(h, w, image_size)
-        #             new_image_feature_len += new_h * (new_w + 1)
 
         pad_ids = pad_value * (
             (new_image_feature_len + len(pad_value)) // len(pad_value)
@@ -87,7 +73,7 @@ class LlavaVidForCausalLM(nn.Module):
             + pad_ids[:new_image_feature_len]
             + input_ids[offset + 1 :]
         )
-        return new_input_ids, offset
+        return new_input_ids, [offset]
 
     def encode_images(self, pixel_values: torch.Tensor) -> torch.Tensor:
         image_outputs = self.vision_tower(pixel_values, output_hidden_states=True)
@@ -133,22 +119,18 @@ class LlavaVidForCausalLM(nn.Module):
         if input_metadata.forward_mode == ForwardMode.EXTEND:
             bs = input_metadata.batch_size
 
-            # Embed text
+            # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
 
-            #
-
-            (
-                .cpu()
-                .numpy()
+            # Whether the requests need vision inputs
+            max_image_offset = np.array(
+                [max(image_offsets[i]) if image_offsets[i] else -1 for i in range(bs)]
             )
-
-
-            need_vision = need_vision & has_pixel
+            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            need_vision = start_positions <= max_image_offset
 
             if need_vision.any():
                 pixel_values = [pixel_values[i] for i in range(bs) if need_vision[i]]
-                image_sizes = [image_sizes[i] for i in range(bs) if need_vision[i]]
 
                 ########## Encode Image ########
 
@@ -183,31 +165,36 @@ class LlavaVidForCausalLM(nn.Module):
                     new_image_features.append(image_feature.flatten(0, 1))
                 image_features = new_image_features
 
+                # Fill in the placeholder for the image
                 extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
                         continue
 
                     start_idx = extend_start_loc_cpu[i]
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+                    prefix_len = prefix_lens_cpu[i]
+
+                    # Multiple images
+                    for image_offset in image_offsets[i]:
+                        if image_offset < prefix_len:
+                            continue
+
+                        tmp_image_feature = image_features[pt]
+                        pad_len = tmp_image_feature.shape[0]
+
+                        left_idx = start_idx + (image_offset - prefix_len)
+                        right_idx = start_idx + (image_offset - prefix_len) + pad_len
+                        try:
+                            input_embeds[left_idx:right_idx] = tmp_image_feature
+                        except RuntimeError as e:
+                            print(f"RuntimeError in image encoding: {e}")
+                            print(f"{input_embeds.shape=}, {tmp_image_feature.shape=}")
+                            print(
+                                f"{start_idx=}, {image_offset=}, {prefix_len=}, {pad_len=}"
+                            )
+                        pt += 1
 
             return self.language_model(
                 input_ids, positions, input_metadata, input_embeds=input_embeds
@@ -216,8 +203,9 @@ class LlavaVidForCausalLM(nn.Module):
             return self.language_model(input_ids, positions, input_metadata)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
-        #
-        #
+        # Load clip vision model by cfg['mm_vision_tower']:
+        #   huggingface_name or path_of_clip_relative_to_llava_model_dir
+        # We put the initialization here instead of __init__ to allow it being reused by other subclasses.
         vision_path = self.config.mm_vision_tower
         self.vision_tower = CLIPVisionModel.from_pretrained(
             vision_path, torch_dtype=torch.float16
@@ -271,43 +259,9 @@ class LlavaVidForCausalLM(nn.Module):
         # load language model
         self.language_model.load_weights(weights)
 
-        monkey_path_clip_vision_embed_forward()
-
     @property
     def num_patches_per_side(self):
         return self.image_size // self.patch_size
 
 
-first_call = True
-
-
-def clip_vision_embed_forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
-    batch_size = pixel_values.shape[0]
-
-    # Move this conv layer to CPU to avoid a bug in torch >= 2.1 on A10G.
-    global first_call
-    if first_call:
-        self.patch_embedding.cpu().float()
-        first_call = False
-    pixel_values = pixel_values.to(dtype=torch.float32, device="cpu")
-    patch_embeds = self.patch_embedding(pixel_values).cuda().half()
-
-    patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
-
-    class_embeds = self.class_embedding.expand(batch_size, 1, -1)
-    embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
-    embeddings = embeddings + self.position_embedding(self.position_ids)
-    return embeddings
-
-
-def monkey_path_clip_vision_embed_forward():
-    import transformers
-
-    setattr(
-        transformers.models.clip.modeling_clip.CLIPVisionEmbeddings,
-        "forward",
-        clip_vision_embed_forward,
-    )
-
-
 EntryClass = LlavaVidForCausalLM
```
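For readers skimming the hunk above: image offsets are now per-request lists, and a request only needs its vision inputs if its first uncomputed position still lies at or before its last image placeholder. The following is a minimal standalone sketch of that `need_vision` test only, with made-up batch values; it is not sglang code, just the pattern the diff introduces.

```python
import numpy as np

# Hypothetical batch of three extend requests (values made up for illustration).
image_offsets = [[5, 40], [], [2]]       # per-request image positions, the new list format
start_positions = np.array([0, 0, 10])   # first position each request still has to compute

max_image_offset = np.array(
    [max(offs) if offs else -1 for offs in image_offsets]
)
need_vision = start_positions <= max_image_offset
print(need_vision)  # [ True False False]: only request 0 still needs its image embedded
```

Request 2 has an image, but it sits entirely inside the already-cached prefix (offset 2 < start position 10), so no vision encoding is redone for it.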
sglang/srt/models/minicpm.py
CHANGED
```diff
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
 
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
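The same pattern repeats across most model files in this release: the per-model `Sampler` is removed and `forward` now returns the `LogitsProcessor` output directly, leaving sampling to the caller. A toy illustration of the control-flow change only (not sglang code):

```python
import torch

class ToyModel(torch.nn.Module):
    """Mimics the 0.2.14.post2 shape: forward() ends at the logits."""

    def __init__(self, hidden: int = 8, vocab: int = 16):
        super().__init__()
        self.backbone = torch.nn.Linear(hidden, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Previously the model also sampled here and returned
        # (sample_output, logits_output); now it just returns the logits.
        return self.lm_head(self.backbone(x))

model = ToyModel()
logits = model(torch.randn(2, 8))
next_tokens = torch.argmax(logits, dim=-1)  # sampling now lives outside the model
```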
sglang/srt/models/mixtral.py
CHANGED
```diff
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     def forward(
         self,
@@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
sglang/srt/models/mixtral_quant.py
CHANGED
```diff
@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
         self.model = MixtralModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
sglang/srt/models/qwen.py
CHANGED
```diff
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
         vocab_size = ((config.vocab_size + 63) // 64) * 64
         self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -262,11 +260,10 @@ class QWenLMHeadModel(nn.Module):
         input_metadata: InputMetadata,
     ):
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        next_tokens = self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
+        return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
sglang/srt/models/qwen2.py
CHANGED
```diff
@@ -38,9 +38,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 Qwen2Config = None
@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
         self.model = Qwen2Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
     @torch.no_grad()
@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
         if not get_embedding:
-            logits_output = self.logits_processor(
+            return self.logits_processor(
                 input_ids, hidden_states, self.lm_head.weight, input_metadata
             )
-            sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-            return sample_output, logits_output
         else:
             return self.pooler(hidden_states, input_metadata)
 
@@ -316,6 +312,9 @@ class Qwen2ForCausalLM(nn.Module):
                 # Models trained using ColossalAI may include these tensors in
                 # the checkpoint. Skip them.
                 continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
@@ -323,8 +322,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
@@ -333,8 +330,6 @@ class Qwen2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                if name.startswith("model.vision_tower") and name not in params_dict:
-                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
```
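In qwen2.py, the `model.vision_tower` filter moves from inside the per-shard branches to the top of the weight loop, so multimodal checkpoint entries are skipped before any stacked-parameter matching is attempted. A standalone sketch of that filtering order, with made-up checkpoint names:

```python
# Hypothetical checkpoint entries; only the filter order is the point here.
params_dict = {"model.layers.0.qkv_proj.weight": object()}
weights = [
    ("model.vision_tower.encoder.layer.0.weight", "W0"),  # not in params_dict -> skipped early
    ("model.layers.0.q_proj.weight", "W1"),
]
stacked_params_mapping = [("qkv_proj", "q_proj", "q")]

for name, loaded_weight in weights:
    if name.startswith("model.vision_tower") and name not in params_dict:
        continue  # dropped before any stacked-param rewriting happens
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        name = name.replace(weight_name, param_name)
        print(f"would load {name} as shard {shard_id}")
        break
```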
sglang/srt/models/qwen2_moe.py
CHANGED
```diff
@@ -35,8 +35,10 @@ from vllm.model_executor.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
+from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -47,7 +49,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -365,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -376,11 +376,20 @@ class Qwen2MoeForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
+
+    def compute_logits(
+        self,
+        input_ids: torch.Tensor,
+        hidden_states: torch.Tensor,
+        input_metadata: InputMetadata,
+    ) -> torch.Tensor:
+        logits = self.logits_processor(
+            input_ids, hidden_states, self.lm_head.weight, input_metadata
+        )
+        return logits
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
@@ -392,24 +401,12 @@ class Qwen2MoeForCausalLM(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping =
-
-
-
-
-
-                if weight_name in ["gate_proj", "up_proj"]
-                else "experts.w2_weight"
-            ),
-            f"experts.{expert_id}.{weight_name}.weight",
-            expert_id,
-            shard_id,
-        )
-        for expert_id in range(self.config.num_experts)
-        for shard_id, weight_name in enumerate(
-            ["gate_proj", "down_proj", "up_proj"]
-        )
-        ]
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
+            ckpt_gate_proj_name="gate_proj",
+            ckpt_down_proj_name="down_proj",
+            ckpt_up_proj_name="up_proj",
+            num_experts=self.config.num_experts,
+        )
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
@@ -449,7 +446,7 @@ class Qwen2MoeForCausalLM(nn.Module):
                     weight_loader(
                         param,
                         loaded_weight,
-                        weight_name,
+                        name,
                         shard_id=shard_id,
                         expert_id=expert_id,
                     )
```
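qwen2_moe.py drops its hand-rolled expert weight mapping in favor of vLLM's `FusedMoE.make_expert_params_mapping` helper; the call and its keyword arguments are exactly what the hunk above adds. A hedged usage sketch, assuming a vLLM install that ships this classmethod and using a made-up expert count:

```python
from vllm.model_executor.layers.fused_moe import FusedMoE

# Same keyword arguments as in the diff; num_experts normally comes from the model config.
expert_params_mapping = FusedMoE.make_expert_params_mapping(
    ckpt_gate_proj_name="gate_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_up_proj_name="up_proj",
    num_experts=4,
)

# Each entry ties a fused expert parameter to the checkpoint weight that feeds it.
for entry in expert_params_mapping[:3]:
    print(entry)
```

Delegating the mapping to vLLM keeps sglang's loader in sync with however vLLM names its fused expert parameters, instead of duplicating that naming scheme in a list comprehension.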
sglang/srt/models/stablelm.py
CHANGED
```diff
@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
         self.model = StableLMEpochModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
```
sglang/srt/models/yivl.py
CHANGED
```diff
@@ -24,10 +24,7 @@ from vllm.config import CacheConfig
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.models.llava import (
-    LlavaLlamaForCausalLM,
-    monkey_path_clip_vision_embed_forward,
-)
+from sglang.srt.models.llava import LlavaLlamaForCausalLM
 
 
 class YiVLForCausalLM(LlavaLlamaForCausalLM):
@@ -50,7 +47,7 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
             self.config._name_or_path,
             torch_dtype=torch.float16,
             subfolder=self.vision_tower_subfolder,
-        ).cuda()
+        ).to("cuda")
 
         self.vision_tower.eval()
 
@@ -94,8 +91,6 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         # load language model
         self.language_model.load_weights(weights)
 
-        monkey_path_clip_vision_embed_forward()
-
 
 class YiVLMultiModalProjector(nn.Module):
     def __init__(self, config: LlavaConfig):
```