sglang 0.2.12__py3-none-any.whl → 0.2.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +13 -1
- sglang/bench_latency.py +10 -5
- sglang/bench_serving.py +50 -26
- sglang/check_env.py +15 -0
- sglang/global_config.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +60 -49
- sglang/lang/chat_template.py +10 -5
- sglang/lang/compiler.py +4 -0
- sglang/lang/interpreter.py +5 -2
- sglang/lang/ir.py +22 -4
- sglang/launch_server.py +8 -1
- sglang/srt/constrained/jump_forward.py +13 -2
- sglang/srt/conversation.py +50 -1
- sglang/srt/hf_transformers_utils.py +22 -23
- sglang/srt/layers/activation.py +24 -2
- sglang/srt/layers/decode_attention.py +338 -50
- sglang/srt/layers/extend_attention.py +3 -1
- sglang/srt/layers/fused_moe/__init__.py +1 -0
- sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
- sglang/srt/layers/fused_moe/layer.py +587 -0
- sglang/srt/layers/layernorm.py +3 -0
- sglang/srt/layers/logits_processor.py +64 -27
- sglang/srt/layers/radix_attention.py +41 -18
- sglang/srt/layers/sampler.py +154 -0
- sglang/srt/managers/controller_multi.py +2 -8
- sglang/srt/managers/controller_single.py +7 -10
- sglang/srt/managers/detokenizer_manager.py +20 -9
- sglang/srt/managers/io_struct.py +44 -11
- sglang/srt/managers/policy_scheduler.py +5 -2
- sglang/srt/managers/schedule_batch.py +59 -179
- sglang/srt/managers/tokenizer_manager.py +193 -84
- sglang/srt/managers/tp_worker.py +131 -50
- sglang/srt/mem_cache/memory_pool.py +82 -8
- sglang/srt/mm_utils.py +79 -7
- sglang/srt/model_executor/cuda_graph_runner.py +97 -28
- sglang/srt/model_executor/forward_batch_info.py +188 -82
- sglang/srt/model_executor/model_runner.py +269 -87
- sglang/srt/models/chatglm.py +6 -14
- sglang/srt/models/commandr.py +6 -2
- sglang/srt/models/dbrx.py +5 -1
- sglang/srt/models/deepseek.py +7 -3
- sglang/srt/models/deepseek_v2.py +12 -7
- sglang/srt/models/gemma.py +6 -2
- sglang/srt/models/gemma2.py +22 -8
- sglang/srt/models/gpt_bigcode.py +5 -1
- sglang/srt/models/grok.py +66 -398
- sglang/srt/models/internlm2.py +5 -1
- sglang/srt/models/llama2.py +7 -3
- sglang/srt/models/llama_classification.py +2 -2
- sglang/srt/models/llama_embedding.py +4 -0
- sglang/srt/models/llava.py +176 -59
- sglang/srt/models/minicpm.py +7 -3
- sglang/srt/models/mixtral.py +61 -255
- sglang/srt/models/mixtral_quant.py +6 -5
- sglang/srt/models/qwen.py +7 -4
- sglang/srt/models/qwen2.py +15 -5
- sglang/srt/models/qwen2_moe.py +7 -16
- sglang/srt/models/stablelm.py +6 -2
- sglang/srt/openai_api/adapter.py +149 -58
- sglang/srt/sampling/sampling_batch_info.py +209 -0
- sglang/srt/{sampling_params.py → sampling/sampling_params.py} +18 -4
- sglang/srt/server.py +107 -71
- sglang/srt/server_args.py +49 -15
- sglang/srt/utils.py +27 -18
- sglang/test/runners.py +38 -38
- sglang/test/simple_eval_common.py +9 -10
- sglang/test/simple_eval_gpqa.py +2 -1
- sglang/test/simple_eval_humaneval.py +2 -2
- sglang/test/simple_eval_math.py +2 -1
- sglang/test/simple_eval_mmlu.py +2 -1
- sglang/test/test_activation.py +55 -0
- sglang/test/test_programs.py +32 -5
- sglang/test/test_utils.py +37 -50
- sglang/version.py +1 -1
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/METADATA +102 -27
- sglang-0.2.14.dist-info/RECORD +114 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/WHEEL +1 -1
- sglang/launch_server_llavavid.py +0 -29
- sglang/srt/model_loader/model_loader.py +0 -292
- sglang/srt/model_loader/utils.py +0 -275
- sglang-0.2.12.dist-info/RECORD +0 -112
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/LICENSE +0 -0
- {sglang-0.2.12.dist-info → sglang-0.2.14.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py
ADDED
@@ -0,0 +1,209 @@
+from __future__ import annotations
+
+import dataclasses
+from typing import TYPE_CHECKING, List
+
+import torch
+
+import sglang.srt.sampling.penaltylib as penaltylib
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+
+
+@dataclasses.dataclass
+class SamplingBatchInfo:
+    # Basic Info
+    vocab_size: int
+
+    # Batched sampling params
+    temperatures: torch.Tensor = None
+    top_ps: torch.Tensor = None
+    top_ks: torch.Tensor = None
+    min_ps: torch.Tensor = None
+
+    # Dispatch in CUDA graph
+    need_min_p_sampling: bool = False
+
+    # Bias Tensors
+    logit_bias: torch.Tensor = None
+    vocab_mask: torch.Tensor = None
+
+    # Penalizer
+    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
+    linear_penalties: torch.Tensor = None
+    scaling_penalties: torch.Tensor = None
+
+    def has_bias(self):
+        return (
+            self.logit_bias is not None
+            or self.vocab_mask is not None
+            or self.linear_penalties is not None
+            or self.scaling_penalties is not None
+        )
+
+    @classmethod
+    def dummy_one(cls, max_bs: int, vocab_size: int):
+        ret = cls(vocab_size=vocab_size)
+        ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
+        ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
+        ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
+        ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
+        return ret
+
+    def __getitem__(self, key):
+        if isinstance(key, slice):
+            # NOTE: We do not use cuda graph when there is bias tensors
+            assert not self.has_bias()
+            return SamplingBatchInfo(
+                vocab_size=self.vocab_size,
+                temperatures=self.temperatures[key],
+                top_ps=self.top_ps[key],
+                top_ks=self.top_ks[key],
+                min_ps=self.min_ps[key],
+                need_min_p_sampling=self.need_min_p_sampling,
+            )
+        else:
+            raise NotImplementedError
+
+    def inplace_assign(self, bs: int, other: SamplingBatchInfo):
+        # NOTE: We do not use cuda graph when there is bias tensors
+        assert not self.has_bias()
+
+        self.vocab_size = other.vocab_size
+        self.need_min_p_sampling = other.need_min_p_sampling
+
+        self.temperatures[:bs] = other.temperatures
+        self.top_ps[:bs] = other.top_ps
+        self.top_ks[:bs] = other.top_ks
+        self.min_ps[:bs] = other.min_ps
+
+    @classmethod
+    def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
+        device = "cuda"
+        reqs = batch.reqs
+        ret = cls(vocab_size=vocab_size)
+
+        ret.temperatures = torch.tensor(
+            [r.sampling_params.temperature for r in reqs],
+            dtype=torch.float,
+            device=device,
+        ).view(-1, 1)
+        ret.top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int, device=device
+        )
+        ret.min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
+        )
+        ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
+
+        # Each penalizers will do nothing if they evaluate themselves as not required by looking at
+        # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
+        # should not add hefty computation overhead other than simple checks.
+        #
+        # While we choose not to even create the class instances if they are not required, this
+        # could add additional complexity to the {ScheduleBatch} class, especially we need to
+        # handle {filter_batch()} and {merge()} cases as well.
+        ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+            vocab_size=vocab_size,
+            batch=batch,
+            device=device,
+            Penalizers={
+                penaltylib.BatchedFrequencyPenalizer,
+                penaltylib.BatchedMinNewTokensPenalizer,
+                penaltylib.BatchedPresencePenalizer,
+                penaltylib.BatchedRepetitionPenalizer,
+            },
+        )
+
+        # Handle logit bias but only allocate when needed
+        ret.logit_bias = None
+
+        ret.update_regex_vocab_mask(batch)
+
+        return ret
+
+    def prepare_penalties(self):
+        self.scaling_penalties = None
+        self.linear_penalties = None
+
+        for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
+                if penalizer.is_prepared():
+                    self.scaling_penalties = penalizer.cumulated_repetition_penalties
+            else:
+                if penalizer.is_prepared():
+                    if self.linear_penalties is None:
+                        bs = self.penalizer_orchestrator.batch.batch_size()
+                        self.linear_penalties = torch.zeros(
+                            (bs, self.vocab_size),
+                            dtype=torch.float32,
+                            device="cuda",
+                        )
+                    self.linear_penalties = penalizer.apply(self.linear_penalties)
+
+    def update_regex_vocab_mask(self, batch: ScheduleBatch):
+        bs, reqs = batch.batch_size(), batch.reqs
+        device = "cuda"
+        has_regex = any(req.regex_fsm is not None for req in reqs)
+
+        # Reset the vocab mask
+        self.vocab_mask = None
+
+        if has_regex:
+            for i, req in enumerate(reqs):
+                if req.regex_fsm is not None:
+                    if self.vocab_mask is None:
+                        self.vocab_mask = torch.zeros(
+                            bs, self.vocab_size, dtype=torch.bool, device=device
+                        )
+                    self.vocab_mask[i][
+                        req.regex_fsm.get_next_instruction(req.regex_fsm_state).tokens
+                    ] = 1
+
+    def filter(self, unfinished_indices: List[int], new_indices: torch.Tensor):
+        self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+            "logit_bias",
+        ]:
+            self_val = getattr(self, item, None)
+            if self_val is not None:  # logit_bias can be None
+                setattr(self, item, self_val[new_indices])
+
+    def merge(self, other: "SamplingBatchInfo"):
+        self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
+
+        for item in [
+            "temperatures",
+            "top_ps",
+            "top_ks",
+            "min_ps",
+        ]:
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
+            )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])
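The new `SamplingBatchInfo` packs each request's sampling parameters into batch-wide tensors so the sampler can apply them with a few vectorized operations instead of per-request Python loops. The sketch below is not part of the diff (the real consumers are the logits processor and the new `sglang/srt/layers/sampler.py`, changed elsewhere in this release); it only illustrates, under that assumption, how such tensors would typically be applied to a `[batch, vocab]` logits matrix.

```python
import torch


def apply_batched_sampling_info(logits: torch.Tensor, info) -> torch.Tensor:
    """Illustrative only: apply SamplingBatchInfo-style tensors to batched logits."""
    # Temperature scaling; `temperatures` is [batch, 1], so it broadcasts over vocab.
    logits = logits / info.temperatures

    # Additive penalties prepared by the penalizer orchestrator, plus per-request logit bias.
    if info.linear_penalties is not None:
        logits = logits + info.linear_penalties
    if info.logit_bias is not None:
        logits = logits + info.logit_bias

    # Constrained decoding: `vocab_mask` marks the tokens allowed by each request's
    # regex FSM, so everything else is pushed to -inf.
    if info.vocab_mask is not None:
        logits = logits.masked_fill(~info.vocab_mask, float("-inf"))
    return logits
```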
sglang/srt/{sampling_params.py → sampling/sampling_params.py}
CHANGED
@@ -30,19 +30,20 @@ class SamplingParams:
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
+        min_p: float = 0.0,
         frequency_penalty: float = 0.0,
         presence_penalty: float = 0.0,
         repetition_penalty: float = 1.0,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
         spaces_between_special_tokens: bool = True,
-        dtype: Optional[str] = None,
         regex: Optional[str] = None,
         n: int = 1,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k
+        self.min_p = min_p
         self.frequency_penalty = frequency_penalty
         self.presence_penalty = presence_penalty
         self.repetition_penalty = repetition_penalty
@@ -53,7 +54,6 @@ class SamplingParams:
         self.ignore_eos = ignore_eos
         self.skip_special_tokens = skip_special_tokens
         self.spaces_between_special_tokens = spaces_between_special_tokens
-        self.dtype = dtype
         self.regex = regex
         self.n = n

@@ -63,8 +63,6 @@ class SamplingParams:
             self.top_k = 1
         if self.top_k == -1:
             self.top_k = 1 << 30  # whole vocabulary
-        if self.dtype == "int":
-            self.stop_strs = [" ", "\n"]

     def verify(self):
         if self.temperature < 0.0:
@@ -73,6 +71,8 @@ class SamplingParams:
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
+        if not 0.0 <= self.min_p <= 1.0:
+            raise ValueError(f"min_p must be in [0, 1], got {self.min_p}.")
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
                 f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}."
@@ -127,3 +127,17 @@ class SamplingParams:
         else:
             stop_str_max_len = max(stop_str_max_len, len(stop_str))
         self.stop_str_max_len = stop_str_max_len
+
+    def to_srt_kwargs(self):
+        return {
+            "max_new_tokens": self.max_new_tokens,
+            "stop": self.stop_strs,
+            "stop_token_ids": list(self.stop_token_ids),
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            "top_k": self.top_k,
+            "frequency_penalty": self.frequency_penalty,
+            "presence_penalty": self.presence_penalty,
+            "ignore_eos": self.ignore_eos,
+            "regex": self.regex,
+        }
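`min_p` is the new user-facing sampling knob here: a token is kept only if its probability is at least `min_p` times the probability of the most likely token, so the cutoff adapts to how peaked the distribution is. The diff also removes the old `dtype == "int"` shortcut and adds a `to_srt_kwargs()` helper that packs the parameters into a plain dict. The helper below is a self-contained sketch of the min_p rule, not sglang's own batched sampling path (see the new `sglang/srt/layers/sampler.py` in the file list, which is not reproduced here).

```python
import torch


def min_p_filter(probs: torch.Tensor, min_ps: torch.Tensor) -> torch.Tensor:
    """Sketch of min_p filtering on [batch, vocab] softmax probabilities.

    min_ps holds one value in [0, 1] per request; 0 disables the filter.
    """
    max_probs = probs.max(dim=-1, keepdim=True).values   # [batch, 1]
    threshold = min_ps.unsqueeze(-1) * max_probs          # [batch, 1]
    kept = torch.where(probs >= threshold, probs, torch.zeros_like(probs))
    return kept / kept.sum(dim=-1, keepdim=True)          # renormalize
```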
sglang/srt/server.py
CHANGED
@@ -24,7 +24,6 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import sys
 import threading
 import time
 from http import HTTPStatus
@@ -34,7 +33,6 @@ from typing import Dict, List, Optional, Union
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)

 import aiohttp
-import psutil
 import requests
 import uvicorn
 import uvloop
@@ -52,7 +50,11 @@ from sglang.srt.managers.controller_single import (
     start_controller_process as start_controller_process_single,
 )
 from sglang.srt.managers.detokenizer_manager import start_detokenizer_process
-from sglang.srt.managers.io_struct import
+from sglang.srt.managers.io_struct import (
+    EmbeddingReqInput,
+    GenerateReqInput,
+    UpdateWeightReqInput,
+)
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
 from sglang.srt.openai_api.adapter import (
     load_chat_template_for_openai_api,
@@ -72,6 +74,7 @@ from sglang.srt.utils import (
     add_api_key_middleware,
     allocate_init_ports,
     assert_pkg_version,
+    configure_logger,
     enable_show_time_cost,
     kill_child_process,
     maybe_set_triton_cache_manager,
@@ -92,10 +95,25 @@ tokenizer_manager = None

 @app.get("/health")
 async def health() -> Response:
-    """
+    """Check the health of the http server."""
     return Response(status_code=200)


+@app.get("/health_generate")
+async def health_generate(request: Request) -> Response:
+    """Check the health of the inference server by generating one token."""
+    gri = GenerateReqInput(
+        text="s", sampling_params={"max_new_tokens": 1, "temperature": 0.7}
+    )
+    try:
+        async for _ in tokenizer_manager.generate_request(gri, request):
+            break
+        return Response(status_code=200)
+    except Exception as e:
+        logger.exception(e)
+        return Response(status_code=503)
+
+
 @app.get("/get_model_info")
 async def get_model_info():
     result = {
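Unlike `/health`, which only confirms the HTTP process is alive, `/health_generate` schedules a real one-token generation through the tokenizer manager, so it also catches a stuck scheduler or model worker. A minimal client-side probe (host and port are placeholders for wherever the server listens):

```python
import requests

# 200 means the server produced a token; 503 means generation failed.
resp = requests.get("http://127.0.0.1:30000/health_generate", timeout=60)
print(resp.status_code)
```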
@@ -120,6 +138,23 @@ async def flush_cache():
     )


+@app.post("/update_weights")
+async def update_weights(obj: UpdateWeightReqInput, request: Request):
+
+    success, message = await tokenizer_manager.update_weights(obj, request)
+    content = {"message": message, "success": str(success)}
+    if success:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.OK,
+        )
+    else:
+        return JSONResponse(
+            content,
+            status_code=HTTPStatus.BAD_REQUEST,
+        )
+
+
 async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
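`/update_weights` lets a running server swap model weights in place and reports the outcome in a JSON body. The exact request schema comes from `UpdateWeightReqInput` in `sglang/srt/managers/io_struct.py`, which is not shown in this diff; the payload below assumes it carries the new model path and is purely illustrative.

```python
import requests

resp = requests.post(
    "http://127.0.0.1:30000/update_weights",
    json={"model_path": "/path/to/new/checkpoint"},  # hypothetical field name
)
# 200 on success, 400 on failure; the body echoes {"message": ..., "success": ...}.
print(resp.status_code, resp.json())
```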
@@ -236,15 +271,12 @@ def launch_server(
     """Launch an HTTP server."""
     global tokenizer_manager

-
-        level=getattr(logging, server_args.log_level.upper()),
-        format="%(message)s",
-    )
+    configure_logger(server_args)

     server_args.check_server_args()
     _set_envs_and_config(server_args)

-    # Allocate ports
+    # Allocate ports for inter-process communications
     server_args.port, server_args.additional_ports = allocate_init_ports(
         server_args.port,
         server_args.additional_ports,
@@ -264,30 +296,34 @@ def launch_server(
     server_args.tokenizer_path = prepare_tokenizer(server_args.tokenizer_path)

     # Launch processes for multi-node tensor parallelism
-    if server_args.nnodes > 1:
-
-
-
-
-
-
-                range(
-                    server_args.node_rank * tp_size_local,
-                    (server_args.node_rank + 1) * tp_size_local,
-                )
-            )
-            procs = launch_tp_servers(
-                gpu_ids,
-                tp_rank_range,
-                server_args,
-                ports[3],
-                model_overide_args,
+    if server_args.nnodes > 1 and server_args.node_rank != 0:
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+        tp_rank_range = list(
+            range(
+                server_args.node_rank * tp_size_local,
+                (server_args.node_rank + 1) * tp_size_local,
             )
-
-
+        )
+        procs = launch_tp_servers(
+            gpu_ids,
+            tp_rank_range,
+            server_args,
+            ports[3],
+            model_overide_args,
+        )
+
+        try:
+            for p in procs:
+                p.join()
+        finally:
+            kill_child_process(os.getpid(), including_parent=False)
+            return

     # Launch processes
     tokenizer_manager = TokenizerManager(server_args, port_args, model_overide_args)
+    if server_args.chat_template:
+        load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
     pipe_controller_reader, pipe_controller_writer = mp.Pipe(duplex=False)
     pipe_detoken_reader, pipe_detoken_writer = mp.Pipe(duplex=False)

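For multi-node tensor parallelism, a node with `node_rank != 0` now launches only its local shard of TP servers, joins them, and cleans up its child processes before returning; only rank 0 goes on to start the tokenizer, controller, detokenizer, and HTTP endpoint. A hedged sketch of what such a worker node would run through the Python entry point (model path and sizes are placeholders; distributed rendezvous options are omitted):

```python
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

# Rank-1 node of a 2-node, 16-way tensor-parallel deployment (illustrative values).
# With node_rank != 0, launch_server starts tp_size // nnodes local TP servers,
# joins them, and returns after killing its children; only rank 0 serves HTTP.
server_args = ServerArgs(
    model_path="meta-llama/Meta-Llama-3-70B-Instruct",
    tp_size=16,
    nnodes=2,
    node_rank=1,
)
launch_server(server_args)
```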
@@ -295,11 +331,13 @@ def launch_server(
         start_process = start_controller_process_single
     else:
         start_process = start_controller_process_multi
+
     proc_controller = mp.Process(
         target=start_process,
         args=(server_args, port_args, pipe_controller_writer, model_overide_args),
     )
     proc_controller.start()
+
     proc_detoken = mp.Process(
         target=start_detokenizer_process,
         args=(
@@ -317,15 +355,11 @@ def launch_server(
     if controller_init_state != "init ok" or detoken_init_state != "init ok":
         proc_controller.kill()
         proc_detoken.kill()
-
-
-
-
-        print(
-            f"Initialization failed. detoken_init_state: {detoken_init_state}",
-            flush=True,
+        raise RuntimeError(
+            "Initialization failed. "
+            f"controller_init_state: {controller_init_state}, "
+            f"detoken_init_state: {detoken_init_state}"
         )
-        sys.exit(1)
     assert proc_controller.is_alive() and proc_detoken.is_alive()

     # Add api key authorization
@@ -334,12 +368,12 @@ def launch_server(

     # Send a warmup request
     t = threading.Thread(
-        target=_wait_and_warmup, args=(server_args, pipe_finish_writer)
+        target=_wait_and_warmup, args=(server_args, pipe_finish_writer, os.getpid())
     )
     t.start()

-    # Listen for requests
     try:
+        # Listen for requests
         uvicorn.run(
             app,
             host=server_args.host,
@@ -358,6 +392,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"

     # Set ulimit
     set_ulimit()
@@ -375,23 +410,18 @@ def _set_envs_and_config(server_args: ServerArgs):
     # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
     maybe_set_triton_cache_manager()

-    # Set global chat template
-    if server_args.chat_template:
-        # TODO: replace this with huggingface transformers template
-        load_chat_template_for_openai_api(server_args.chat_template)
-
     # Check flashinfer version
     if not server_args.disable_flashinfer:
         assert_pkg_version(
             "flashinfer",
-            "0.1.
+            "0.1.5",
             "Please uninstall the old version and "
             "reinstall the latest version by following the instructions "
             "at https://docs.flashinfer.ai/installation.html.",
         )


-def _wait_and_warmup(server_args, pipe_finish_writer):
+def _wait_and_warmup(server_args, pipe_finish_writer, pid):
     headers = {}
     url = server_args.url()
     if server_args.api_key:
@@ -414,8 +444,9 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
     if not success:
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-
-
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     # Send a warmup request
     request_name = "/generate" if model_info["is_generation"] else "/encode"
@@ -440,21 +471,13 @@ def _wait_and_warmup(server_args, pipe_finish_writer):
             timeout=600,
         )
         assert res.status_code == 200, f"{res}"
-    except Exception
+    except Exception:
         last_traceback = get_exception_traceback()
         if pipe_finish_writer is not None:
             pipe_finish_writer.send(last_traceback)
-
-
-
-    # Print warnings here
-    if server_args.disable_radix_cache and server_args.chunked_prefill_size is not None:
-        logger.warning(
-            "You set both `--disable-radix-cache` and `--chunked-prefill-size`. "
-            "This combination is an experimental feature and we noticed it can lead to "
-            "wrong generation results. If you want to use chunked prefill, it is recommended "
-            "not using `--disable-radix-cache`."
-        )
+        logger.error(f"Initialization failed. warmup error: {last_traceback}")
+        kill_child_process(pid, including_parent=False)
+        return

     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
@@ -492,6 +515,7 @@ class Runtime:

         self.pid = None
         pipe_reader, pipe_writer = mp.Pipe(duplex=False)
+
         proc = mp.Process(
             target=launch_server,
             args=(self.server_args, model_overide_args, pipe_writer),
@@ -533,11 +557,18 @@ class Runtime:
         prompt: str,
         sampling_params: Optional[Dict] = None,
     ):
-
-
-
-
-
+        if self.server_args.skip_tokenizer_init:
+            json_data = {
+                "input_ids": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
+        else:
+            json_data = {
+                "text": prompt,
+                "sampling_params": sampling_params,
+                "stream": True,
+            }
         pos = 0

         timeout = aiohttp.ClientTimeout(total=3 * 3600)
@@ -549,24 +580,29 @@ class Runtime:
                     if chunk == "data: [DONE]\n\n":
                         break
                     data = json.loads(chunk[5:].strip("\n"))
-
-
-
-
+                    if hasattr(data, "text"):
+                        cur = data["text"][pos:]
+                        if cur:
+                            yield cur
+                        pos += len(cur)
+                    else:
+                        yield data

     add_request = async_generate

     def generate(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
         sampling_params: Optional[Dict] = None,
         return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
     ):
         json_data = {
             "text": prompt,
             "sampling_params": sampling_params,
             "return_logprob": return_logprob,
+            "logprob_start_len": logprob_start_len,
             "top_logprobs_num": top_logprobs_num,
         }
         response = requests.post(
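With these signature changes, `Runtime.generate` accepts a single prompt or a list of prompts and gains a `logprob_start_len` argument for choosing where log-probability tracking starts; the next hunk gives `Runtime.encode` the same flexible prompt type. A rough usage sketch (model path and sampling values are placeholders; the printed value is whatever the `/generate` endpoint returned):

```python
import sglang as sgl

runtime = sgl.Runtime(model_path="meta-llama/Meta-Llama-3-8B-Instruct")
out = runtime.generate(
    ["The capital of France is", "The capital of Japan is"],
    sampling_params={"max_new_tokens": 8, "temperature": 0},
    return_logprob=True,
    logprob_start_len=0,
)
print(out)
runtime.shutdown()
```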
@@ -577,7 +613,7 @@ class Runtime:

     def encode(
         self,
-        prompt: str,
+        prompt: Union[str, List[str]],
     ):
         json_data = {
             "text": prompt,