sglang 0.2.14__py3-none-any.whl → 0.2.14.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/layers/activation.py +83 -7
- sglang/srt/layers/layernorm.py +0 -3
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +15 -68
- sglang/srt/managers/schedule_batch.py +15 -20
- sglang/srt/managers/tp_worker.py +40 -33
- sglang/srt/model_executor/cuda_graph_runner.py +17 -31
- sglang/srt/model_executor/forward_batch_info.py +1 -8
- sglang/srt/model_executor/model_runner.py +5 -11
- sglang/srt/models/chatglm.py +12 -4
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +2 -6
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama2.py +3 -7
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/mixtral.py +1 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +2 -5
- sglang/srt/models/qwen2.py +2 -6
- sglang/srt/models/qwen2_moe.py +14 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/openai_api/adapter.py +85 -4
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -74
- sglang/srt/sampling/sampling_params.py +4 -0
- sglang/srt/server.py +8 -1
- sglang/test/runners.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/METADATA +10 -4
- {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/RECORD +42 -42
- {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/WHEEL +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.dist-info → sglang-0.2.14.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -17,6 +17,7 @@ limitations under the License.
 
 import bisect
 from contextlib import contextmanager
+from typing import Callable, List
 
 import torch
 from flashinfer import BatchDecodeWithPagedKVCacheWrapper
@@ -25,18 +26,16 @@ from vllm.distributed.parallel_state import graph_capture
 from vllm.model_executor.custom_op import CustomOp
 
 from sglang.srt.layers.logits_processor import (
+    LogitProcessorOutput,
     LogitsMetadata,
     LogitsProcessor,
-    LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
     update_flashinfer_indices,
 )
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather
 
 
@@ -53,12 +52,12 @@ def _to_torch(model: torch.nn.Module, reverse: bool = False):
 
 @contextmanager
 def patch_model(
-    model: torch.nn.Module,
+    model: torch.nn.Module, enable_compile: bool, tp_group: "GroupCoordinator"
 ):
     backup_ca_comm = None
 
     try:
-        if
+        if enable_compile:
             _to_torch(model)
             monkey_patch_vllm_all_gather()
             backup_ca_comm = tp_group.ca_comm
@@ -67,7 +66,7 @@ def patch_model(
         else:
             yield model.forward
     finally:
-        if
+        if enable_compile:
             _to_torch(model, reverse=True)
             monkey_patch_vllm_all_gather(reverse=True)
             tp_group.ca_comm = backup_ca_comm
@@ -88,7 +87,7 @@ def set_torch_compile_config():
 class CudaGraphRunner:
     def __init__(
         self,
-        model_runner,
+        model_runner: "ModelRunner",
         max_batch_size_to_capture: int,
         use_torch_compile: bool,
         disable_padding: bool,
@@ -145,22 +144,18 @@ class CudaGraphRunner:
             self.flashinfer_kv_indices.clone(),
         ]
 
-        # Sampling inputs
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
 
         if use_torch_compile:
             set_torch_compile_config()
 
-    def can_run(self, batch_size):
+    def can_run(self, batch_size: int):
         if self.disable_padding:
             return batch_size in self.graphs
         else:
             return batch_size <= self.max_bs
 
-    def capture(self, batch_size_list):
+    def capture(self, batch_size_list: List[int]):
         self.batch_size_list = batch_size_list
         with graph_capture() as graph_capture_context:
             self.stream = graph_capture_context.stream
@@ -181,7 +176,7 @@ class CudaGraphRunner:
             self.output_buffers[bs] = output_buffers
             self.flashinfer_handlers[bs] = flashinfer_handler
 
-    def capture_one_batch_size(self, bs, forward):
+    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = torch.cuda.CUDAGraph()
        stream = self.stream
 
@@ -240,7 +235,6 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -305,35 +299,27 @@ class CudaGraphRunner:
             self.flashinfer_handlers[bs],
         )
 
-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
         torch.cuda.synchronize()
         self.graphs[bs].replay()
         torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        output = self.output_buffers[bs]
 
         # Unpad
         if bs != raw_bs:
-            logits_output = LogitsProcessorOutput(
-                next_token_logits=logits_output.next_token_logits[:raw_bs],
+            output = LogitProcessorOutput(
+                next_token_logits=output.next_token_logits[:raw_bs],
                 next_token_logprobs=None,
                 normalized_prompt_logprobs=None,
                 input_token_logprobs=None,
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )
 
         # Extract logprobs
         if batch.return_logprob:
-            logits_output.next_token_logprobs = torch.nn.functional.log_softmax(
-                logits_output.next_token_logits, dim=-1
+            output.next_token_logprobs = torch.nn.functional.log_softmax(
+                output.next_token_logits, dim=-1
             )
             return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums)
             if return_top_logprob:
@@ -341,8 +327,8 @@ class CudaGraphRunner:
                     forward_mode=ForwardMode.DECODE,
                     top_logprobs_nums=batch.top_logprobs_nums,
                 )
-                logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
-                    logits_output.next_token_logprobs, logits_metadata
+                output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
+                    output.next_token_logprobs, logits_metadata
                 )[1]
 
-        return sample_output, logits_output
+        return output
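Note: with sampling removed from the captured graph, the replay path above recomputes next_token_logprobs outside the graph. A minimal, self-contained sketch of that extraction step (the batch size, vocab size, and top-k of 5 are illustrative assumptions, not values from the diff):

import torch

# Stand-in for the logits returned from the replayed CUDA graph: [batch, vocab].
next_token_logits = torch.randn(4, 32000)

# Log-probabilities over the vocabulary, as done under `if batch.return_logprob`.
next_token_logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)

# Per-request top-k, analogous to what LogitsProcessor.get_top_logprobs collects.
top_values, top_ids = next_token_logprobs.topk(5, dim=-1)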
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import numpy as np
 import torch
@@ -28,7 +26,6 @@ from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 
 if TYPE_CHECKING:
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
 
 class ForwardMode(IntEnum):
@@ -45,7 +42,6 @@ class InputMetadata:
     """Store all inforamtion of a forward pass."""
 
     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -183,7 +179,6 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=forward_mode,
-            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -194,8 +189,6 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )
 
-        ret.sampling_info.prepare_penalties()
-
        ret.compute_positions(batch)
 
        ret.compute_extend_infos(batch)
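With sampling_info gone, InputMetadata carries only forward-pass state. A trimmed sketch of the resulting shape (field names are taken from the hunks above; the real class declares many more fields, and only the ForwardMode members referenced in this diff are shown):

from dataclasses import dataclass
from enum import IntEnum, auto

import torch

class ForwardMode(IntEnum):
    # The diff references EXTEND and DECODE; other members are omitted here.
    EXTEND = auto()
    DECODE = auto()

@dataclass
class InputMetadata:
    forward_mode: ForwardMode
    batch_size: int
    req_pool_indices: torch.Tensor
    seq_lens: torch.Tensor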
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -21,7 +21,7 @@ import importlib.resources
 import logging
 import pkgutil
 from functools import lru_cache
-from typing import Optional, Tuple, Type
+from typing import Optional, Type
 
 import torch
 import torch.nn as nn
@@ -44,8 +44,6 @@ from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.global_config import global_config
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
@@ -161,6 +159,8 @@ class ModelRunner:
                 "Compute capability below sm80. Use float16 due to lack of bfloat16 support."
             )
             self.server_args.dtype = "float16"
+            if torch.cuda.get_device_capability()[1] < 5:
+                raise RuntimeError("SGLang only supports sm75 and above.")
 
         monkey_patch_vllm_dummy_weight_loader()
         self.device_config = DeviceConfig()
@@ -515,11 +515,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and not batch.sampling_info.has_bias()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
         input_metadata = InputMetadata.from_schedule_batch(
@@ -568,9 +564,7 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-        self, batch: ScheduleBatch, forward_mode: ForwardMode
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode):
         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
             return self.forward_extend_multi_modal(batch)
         elif forward_mode == ForwardMode.DECODE:
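The two lines added to ModelRunner tighten the hardware floor: a minor capability digit under 5 now raises instead of running on unsupported GPUs. A standalone sketch of the combined check (assuming the new guard nests inside the existing below-sm80 branch, which the hunk context suggests but the flattened indentation cannot prove):

import torch

major, minor = torch.cuda.get_device_capability()
if major < 8:
    dtype = "float16"  # bfloat16 needs sm80 or newer
    if minor < 5:      # within major < 8 this rejects e.g. sm70 and sm61
        raise RuntimeError("SGLang only supports sm75 and above.")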
sglang/srt/models/chatglm.py
CHANGED
@@ -31,18 +31,20 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
+from vllm.model_executor.layers.sampler import Sampler
 from vllm.model_executor.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.model_executor.sampling_metadata import SamplingMetadata
+from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs import ChatGLMConfig
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 LoraConfig = None
@@ -381,11 +383,17 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
+
+    def sample(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> Optional[SamplerOutput]:
+        next_tokens = self.sampler(logits, sampling_metadata)
+        return next_tokens
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
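ChatGLM now follows vLLM's two-phase convention: forward() ends at the logits processor, and a separate sample() drives vLLM's Sampler. A minimal, self-contained illustration of that split (a toy model, not the ChatGLM code; all names here are hypothetical):

import torch

class TinyLM(torch.nn.Module):
    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.lm_head = torch.nn.Linear(hidden, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        # Phase 1: hidden states -> next-token logits (the LogitsProcessor stage).
        return self.lm_head(self.embed(input_ids)).mean(dim=1)

    def sample(self, logits: torch.Tensor) -> torch.Tensor:
        # Phase 2: logits -> token ids, kept separate so serving code can swap samplers.
        return torch.multinomial(torch.softmax(logits, dim=-1), num_samples=1)

lm = TinyLM()
logits = lm(torch.randint(0, 100, (2, 8)))
print(lm.sample(logits))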
sglang/srt/models/commandr.py
CHANGED
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)
 
     @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/dbrx.py
CHANGED
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
sglang/srt/models/deepseek.py
CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -45,7 +45,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
@@ -633,7 +632,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     def forward(
         self,
@@ -642,11 +640,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma.py
CHANGED
@@ -37,7 +37,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/gemma2.py
CHANGED
@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -397,7 +396,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -408,11 +406,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import GPTBigCodeConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
@@ -33,9 +32,9 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
sglang/srt/models/grok.py
CHANGED
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -298,7 +297,6 @@ class Grok1ModelForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ModelForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/internlm2.py
CHANGED
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/llama2.py
CHANGED
@@ -39,9 +39,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -303,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -312,13 +310,11 @@ class LlamaForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
         input_embeds: torch.Tensor = None,
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    ) -> LogitProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def get_module_name(self, name):
         stacked_params_mapping = [
sglang/srt/models/llama_classification.py
CHANGED
@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
 
@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
             (input_metadata.batch_size, self.config.classification_out_size)
         ).to(input_ids.device)
 
-        return LogitsProcessorOutput(
+        return LogitProcessorOutput(
             next_token_logits=scores,
             next_token_logprobs=scores,
             normalized_prompt_logprobs=scores,
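Both Llama files now return LogitProcessorOutput directly. A sketch of that container as the hunks describe it (field names come from the diffs above; the exact type annotations are assumptions):

from dataclasses import dataclass
from typing import List, Optional

import torch

@dataclass
class LogitProcessorOutput:
    next_token_logits: torch.Tensor
    next_token_logprobs: Optional[torch.Tensor]
    normalized_prompt_logprobs: Optional[torch.Tensor]
    input_token_logprobs: Optional[torch.Tensor]
    input_top_logprobs: Optional[List]
    output_top_logprobs: Optional[List]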
sglang/srt/models/minicpm.py
CHANGED
@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
 
 
@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
         self.scale_width = self.config.hidden_size / self.config.dim_model_base
 
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
 
     @torch.no_grad()
     def forward(
@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, lm_head_weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
|