sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py CHANGED

```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import dataclasses
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
@@ -20,6 +20,9 @@ class SamplingBatchInfo:
     top_ks: torch.Tensor
     min_ps: torch.Tensor
 
+    # All requests use greedy sampling
+    is_all_greedy: bool
+
     # Dispatch in CUDA graph
     need_min_p_sampling: bool
 
@@ -33,27 +36,39 @@ class SamplingBatchInfo:
     regex_fsm_states: List[int] = None
 
     # Penalizer
-    penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
-    linear_penalties: torch.Tensor = None
-    scaling_penalties: torch.Tensor = None
+    penalizer_orchestrator: Optional[penaltylib.BatchedPenalizerOrchestrator] = None
+    linear_penalties: Optional[torch.Tensor] = None
+    scaling_penalties: Optional[torch.Tensor] = None
+
+    # Device
+    device: str = "cuda"
 
     @classmethod
-    def from_schedule_batch(
+    def from_schedule_batch(
+        cls,
+        batch: ScheduleBatch,
+        vocab_size: int,
+        disable_penalizer: bool,
+    ):
         reqs = batch.reqs
-
-
+        device = batch.input_ids.device
+        temperatures = (
+            torch.tensor(
                 [r.sampling_params.temperature for r in reqs],
                 dtype=torch.float,
-            ).view(-1, 1)
-            top_ps = torch.tensor(
-                [r.sampling_params.top_p for r in reqs], dtype=torch.float
-            )
-            top_ks = torch.tensor(
-                [r.sampling_params.top_k for r in reqs], dtype=torch.int
-            )
-            min_ps = torch.tensor(
-                [r.sampling_params.min_p for r in reqs], dtype=torch.float
             )
+            .view(-1, 1)
+            .to(device, non_blocking=True)
+        )
+        top_ps = torch.tensor(
+            [r.sampling_params.top_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
+        top_ks = torch.tensor(
+            [r.sampling_params.top_k for r in reqs], dtype=torch.int32
+        ).to(device, non_blocking=True)
+        min_ps = torch.tensor(
+            [r.sampling_params.min_p for r in reqs], dtype=torch.float
+        ).to(device, non_blocking=True)
 
         ret = cls(
             temperatures=temperatures,
@@ -61,7 +76,9 @@ class SamplingBatchInfo:
             top_ks=top_ks,
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
+            is_all_greedy=top_ks.max().item() <= 1,
             vocab_size=vocab_size,
+            device=batch.input_ids.device,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.
 
@@ -71,18 +88,21 @@ class SamplingBatchInfo:
         #
         # While we choose not to even create the class instances if they are not required, this
         # could add additional complexity to the {ScheduleBatch} class, especially we need to
-        # handle {filter_batch()} and {
-
-
-
-
-
-
-
-
-
-
-
+        # handle {filter_batch()} and {merge_batch()} cases as well.
+        if disable_penalizer:
+            ret.penalizer_orchestrator = None
+        else:
+            ret.penalizer_orchestrator = penaltylib.BatchedPenalizerOrchestrator(
+                vocab_size=vocab_size,
+                batch=batch,
+                device=batch.input_ids.device,
+                Penalizers={
+                    penaltylib.BatchedFrequencyPenalizer,
+                    penaltylib.BatchedMinNewTokensPenalizer,
+                    penaltylib.BatchedPresencePenalizer,
+                    penaltylib.BatchedRepetitionPenalizer,
+                },
+            )
 
         # Handle logit bias but only allocate when needed
         ret.logit_bias = None
@@ -93,43 +113,50 @@ class SamplingBatchInfo:
         return len(self.temperatures)
 
     def update_penalties(self):
+        if not self.penalizer_orchestrator:
+            return
+
         self.scaling_penalties = None
         self.linear_penalties = None
 
         for penalizer in self.penalizer_orchestrator.penalizers.values():
+            if not penalizer.is_prepared():
+                continue
+
             if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
-
-                self.scaling_penalties = penalizer.cumulated_repetition_penalties
+                self.scaling_penalties = penalizer.cumulated_repetition_penalties
             else:
-                if
-
-
-                self.
-
-
-
-
-                self.linear_penalties = penalizer.apply(self.linear_penalties)
+                if self.linear_penalties is None:
+                    bs = self.penalizer_orchestrator.batch.batch_size()
+                    self.linear_penalties = torch.zeros(
+                        (bs, self.vocab_size),
+                        dtype=torch.float32,
+                        device=self.device,
+                    )
+                self.linear_penalties = penalizer.apply(self.linear_penalties)
 
     def update_regex_vocab_mask(self):
         has_regex = self.regex_fsms and any(regex_fsm for regex_fsm in self.regex_fsms)
-
-
-
-
-
-        self.
-
-
-
-
-
-
-
-
+        if not has_regex:
+            self.vocab_mask = None
+            return
+
+        self.vocab_mask = torch.zeros(
+            len(self.temperatures),
+            self.vocab_size,
+            dtype=torch.bool,
+            device=self.device,
+        )
+        for i, regex_fsm in enumerate(self.regex_fsms):
+            if regex_fsm is not None:
+                self.vocab_mask[i].fill_(1)
+                self.vocab_mask[i][
+                    regex_fsm.get_next_instruction(self.regex_fsm_states[i]).tokens
+                ] = 0
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
-        self.penalizer_orchestrator
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
 
         for item in [
             "temperatures",
@@ -144,7 +171,12 @@ class SamplingBatchInfo:
 
     @staticmethod
     def merge_bias_tensor(
-        lhs: torch.Tensor,
+        lhs: torch.Tensor,
+        rhs: torch.Tensor,
+        bs1: int,
+        bs2: int,
+        device: str,
+        default: int = 0,
     ):
         # bias tensor can be None
         if lhs is not None or rhs is not None:
@@ -155,15 +187,16 @@ class SamplingBatchInfo:
             shape, dtype = rhs.shape[1:], rhs.dtype
             with torch.dtype(dtype):
                 if lhs is None:
-                    lhs = torch.empty((bs1, *shape), device=
+                    lhs = torch.empty((bs1, *shape), device=device).fill_(default)
                 if rhs is None:
-                    rhs = torch.empty((bs2, *shape), device=
+                    rhs = torch.empty((bs2, *shape), device=device).fill_(default)
             return torch.cat([lhs, rhs])
 
         return None
 
     def merge_batch(self, other: "SamplingBatchInfo"):
-        self.penalizer_orchestrator
+        if self.penalizer_orchestrator:
+            self.penalizer_orchestrator.merge(other.penalizer_orchestrator)
 
         for item in [
             "temperatures",
@@ -175,6 +208,19 @@ class SamplingBatchInfo:
             other_val = getattr(other, item, None)
             setattr(self, item, torch.concat([self_val, other_val]))
 
+        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
         self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other)
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device
+        )
+
+    def copy(self):
+        return SamplingBatchInfo(
+            temperatures=self.temperatures,
+            top_ps=self.top_ps,
+            top_ks=self.top_ks,
+            min_ps=self.min_ps,
+            is_all_greedy=self.is_all_greedy,
+            need_min_p_sampling=self.need_min_p_sampling,
+            vocab_size=self.vocab_size,
+            device=self.device,
         )
```
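The visible theme of this file's changes: sampling tensors are now built on the host and moved to the batch's device with non-blocking copies, and an `is_all_greedy` flag lets a batch of purely greedy requests bypass the sampling kernels. A minimal, self-contained sketch of that idiom (plain PyTorch, not sglang internals; the per-request values are made up):

```python
import torch

# Hypothetical per-request sampling settings standing in for `batch.reqs`.
top_ks = [1, 1, 1]            # top_k <= 1 everywhere means pure greedy decoding
temperatures = [0.0, 0.0, 0.0]

device = "cuda" if torch.cuda.is_available() else "cpu"

# Build on the host, then copy asynchronously to the target device, mirroring
# the `.to(device, non_blocking=True)` calls added in this release.
top_ks_t = torch.tensor(top_ks, dtype=torch.int32).to(device, non_blocking=True)
temps_t = (
    torch.tensor(temperatures, dtype=torch.float)
    .view(-1, 1)
    .to(device, non_blocking=True)
)

# `is_all_greedy` is simply "no request needs top-k/top-p sampling".
is_all_greedy = top_ks_t.max().item() <= 1
print(is_all_greedy)  # True -> the scheduler can dispatch plain argmax decoding
```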
sglang/srt/sampling/sampling_params.py CHANGED

```diff
@@ -40,6 +40,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        no_stop_trim: bool = False,
     ) -> None:
         self.temperature = temperature
         self.top_p = top_p
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.no_stop_trim = no_stop_trim
 
         # Process some special cases
         if self.temperature < _SAMPLING_EPS:
```
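The only functional change here is the new `no_stop_trim` flag on `SamplingParams`. Going by its name (the actual handling lives in the scheduler and detokenizer changes listed above, not in this hunk), it presumably keeps the matched stop string in the returned text instead of trimming it. A hedged usage sketch against the native HTTP API, assuming the flag is accepted inside the `sampling_params` dict of a `/generate` request:

```python
import requests

# Assumes a locally running sglang server on the default port; the exact effect
# of `no_stop_trim` is inferred from its name, not documented in this diff.
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "List three colors:",
        "sampling_params": {
            "max_new_tokens": 64,
            "stop": ["\n\n"],
            "no_stop_trim": True,  # keep the stop sequence in the output (assumed)
        },
    },
)
print(resp.json()["text"])
```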
sglang/srt/server.py CHANGED

```diff
@@ -25,11 +25,12 @@ import json
 import logging
 import multiprocessing as mp
 import os
-import random
 import threading
 import time
 from http import HTTPStatus
-from typing import Dict, List, Optional, Union
+from typing import AsyncIterator, Dict, List, Optional, Union
+
+import orjson
 
 # Fix a bug of Python threading
 setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
@@ -40,10 +41,14 @@ import uvicorn
 import uvloop
 from fastapi import FastAPI, File, Form, Request, UploadFile
 from fastapi.middleware.cors import CORSMiddleware
-from fastapi.responses import
+from fastapi.responses import ORJSONResponse, Response, StreamingResponse
+from uvicorn.config import LOGGING_CONFIG
 
 from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
 from sglang.srt.hf_transformers_utils import get_tokenizer
+from sglang.srt.managers.data_parallel_controller import (
+    run_data_parallel_controller_process,
+)
 from sglang.srt.managers.detokenizer_manager import run_detokenizer_process
 from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
@@ -145,18 +150,40 @@ async def flush_cache():
     )
 
 
+@app.get("/start_profile")
+@app.post("/start_profile")
+async def start_profile():
+    """Start profiling."""
+    tokenizer_manager.start_profile()
+    return Response(
+        content="Start profiling.\n",
+        status_code=200,
+    )
+
+
+@app.get("/stop_profile")
+@app.post("/stop_profile")
+async def stop_profile():
+    """Stop profiling."""
+    tokenizer_manager.stop_profile()
+    return Response(
+        content="Stop profiling. This will take some time.\n",
+        status_code=200,
+    )
+
+
 @app.post("/update_weights")
 async def update_weights(obj: UpdateWeightReqInput, request: Request):
     """Update the weights inplace without re-launching the server."""
     success, message = await tokenizer_manager.update_weights(obj, request)
     content = {"success": success, "message": message}
     if success:
-        return
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.OK,
         )
     else:
-        return
+        return ORJSONResponse(
             content,
             status_code=HTTPStatus.BAD_REQUEST,
         )
```
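The new `/start_profile` and `/stop_profile` routes (registered for both GET and POST) simply delegate to `TokenizerManager`; what gets profiled and where the trace lands is outside this hunk. A minimal client sketch, assuming the default local server address:

```python
import requests

base = "http://localhost:30000"  # default sglang server address (assumption)

requests.post(f"{base}/start_profile")
# ... drive some /generate traffic while the profiler is recording ...
requests.post(f"{base}/stop_profile")  # the server notes this can take a while
```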
```diff
@@ -167,14 +194,18 @@ async def generate_request(obj: GenerateReqInput, request: Request):
     """Handle a generate request."""
     if obj.stream:
 
-        async def stream_results():
+        async def stream_results() -> AsyncIterator[bytes]:
             try:
                 async for out in tokenizer_manager.generate_request(obj, request):
-                    yield
+                    yield b"data: " + orjson.dumps(
+                        out, option=orjson.OPT_NON_STR_KEYS
+                    ) + b"\n\n"
             except ValueError as e:
                 out = {"error": {"message": str(e)}}
-                yield
-
+                yield b"data: " + orjson.dumps(
+                    out, option=orjson.OPT_NON_STR_KEYS
+                ) + b"\n\n"
+            yield b"data: [DONE]\n\n"
 
         return StreamingResponse(
             stream_results(),
```
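Streamed `/generate` responses are now serialized with orjson and framed as server-sent-event style chunks: each payload is `data: <json>\n\n` and the stream ends with `data: [DONE]\n\n`. A client sketch that consumes that framing (endpoint path and port are the usual sglang defaults, assumed here):

```python
import json

import requests

with requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Once upon a time,",
        "sampling_params": {"max_new_tokens": 32},
        "stream": True,
    },
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line or not line.startswith(b"data:"):
            continue
        payload = line[len(b"data:"):].strip()
        if payload == b"[DONE]":
            break
        out = json.loads(payload)
        print(out["text"])  # cumulative text decoded so far
```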
```diff
@@ -186,7 +217,7 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return
+        return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
@@ -201,7 +232,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return
+        return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
@@ -216,7 +247,7 @@ async def judge_request(obj: RewardReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return
+        return ORJSONResponse(
             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
         )
 
@@ -235,13 +266,13 @@ async def openai_v1_chat_completions(raw_request: Request):
     return await v1_chat_completions(tokenizer_manager, raw_request)
 
 
-@app.post("/v1/embeddings")
+@app.post("/v1/embeddings", response_class=ORJSONResponse)
 async def openai_v1_embeddings(raw_request: Request):
     response = await v1_embeddings(tokenizer_manager, raw_request)
     return response
 
 
-@app.get("/v1/models")
+@app.get("/v1/models", response_class=ORJSONResponse)
 def available_models():
     """Show available models."""
     served_model_names = [tokenizer_manager.served_model_name]
@@ -315,30 +346,40 @@ def launch_engine(
         server_args.model_path, server_args.tokenizer_path
     )
 
-
-
-
-
-
-
-
-
-
+    if server_args.dp_size == 1:
+        # Launch tensor parallel scheduler processes
+        scheduler_procs = []
+        scheduler_pipe_readers = []
+        tp_size_per_node = server_args.tp_size // server_args.nnodes
+        tp_rank_range = range(
+            tp_size_per_node * server_args.node_rank,
+            tp_size_per_node * (server_args.node_rank + 1),
+        )
+        for tp_rank in tp_rank_range:
+            reader, writer = mp.Pipe(duplex=False)
+            gpu_id = tp_rank % tp_size_per_node
+            proc = mp.Process(
+                target=run_scheduler_process,
+                args=(server_args, port_args, gpu_id, tp_rank, None, writer),
+            )
+            proc.start()
+            scheduler_procs.append(proc)
+            scheduler_pipe_readers.append(reader)
+
+        if server_args.node_rank >= 1:
+            # For other nodes, they do not need to run tokenizer or detokenizer,
+            # so they can just wait here.
+            while True:
+                pass
+    else:
+        # Launch the data parallel controller
         reader, writer = mp.Pipe(duplex=False)
-
+        scheduler_pipe_readers = [reader]
         proc = mp.Process(
-            target=
-            args=(server_args, port_args,
+            target=run_data_parallel_controller_process,
+            args=(server_args, port_args, writer),
         )
         proc.start()
-        scheduler_procs.append(proc)
-        scheduler_pipe_readers.append(reader)
-
-        if server_args.node_rank >= 1:
-            # For other nodes, they do not need to run tokenizer or detokenizer,
-            # so they can just wait here.
-            while True:
-                pass
 
     # Launch detokenizer process
     detoken_proc = mp.Process(
```
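`launch_engine` now branches on `dp_size`: with `dp_size == 1` it spawns one scheduler process per local tensor-parallel rank as before, otherwise it hands everything to the new data-parallel controller process. A hedged launch sketch; the `ServerArgs` field names and `launch_server` come from this file, but the exact `launch_server` signature is assumed and the model path is a placeholder:

```python
from sglang.srt.server import launch_server
from sglang.srt.server_args import ServerArgs

# With dp_size > 1, run_data_parallel_controller_process is spawned instead of
# one run_scheduler_process per local TP rank (see the hunk above).
server_args = ServerArgs(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model
    tp_size=1,
    dp_size=2,
)
launch_server(server_args)
```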
```diff
@@ -394,6 +435,14 @@ def launch_server(
 
     try:
         # Listen for HTTP requests
+        LOGGING_CONFIG["formatters"]["default"][
+            "fmt"
+        ] = "[%(asctime)s] %(levelprefix)s %(message)s"
+        LOGGING_CONFIG["formatters"]["default"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
+        LOGGING_CONFIG["formatters"]["access"][
+            "fmt"
+        ] = '[%(asctime)s] %(levelprefix)s %(client_addr)s - "%(request_line)s" %(status_code)s'
+        LOGGING_CONFIG["formatters"]["access"]["datefmt"] = "%Y-%m-%d %H:%M:%S"
         uvicorn.run(
             app,
             host=server_args.host,
@@ -412,7 +461,7 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_CUMEM_ENABLE"] = "0"
     os.environ["NCCL_NVLS_ENABLE"] = "0"
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
-    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "
+    os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
 
     # Set ulimit
     set_ulimit()
@@ -493,6 +542,8 @@ def _wait_and_warmup(server_args, pipe_finish_writer, pid):
         kill_child_process(pid, including_parent=False)
         return
 
+    # logger.info(f"{res.json()=}")
+
     logger.info("The server is fired up and ready to roll!")
     if pipe_finish_writer is not None:
         pipe_finish_writer.send("ready")
@@ -657,6 +708,10 @@ class Runtime:
         self.shutdown()
 
 
+STREAM_END_SYMBOL = b"data: [DONE]"
+STREAM_CHUNK_START_SYMBOL = b"data:"
+
+
 class Engine:
     """
     SRT Engine without an HTTP server layer.
@@ -681,7 +736,10 @@ class Engine:
         logprob_start_len: Optional[Union[List[int], int]] = None,
         top_logprobs_num: Optional[Union[List[int], int]] = None,
         lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
     ):
+        # TODO (ByronHsu): refactor to reduce the duplicated code
+
         obj = GenerateReqInput(
             text=prompt,
             sampling_params=sampling_params,
@@ -689,13 +747,89 @@ class Engine:
             logprob_start_len=logprob_start_len,
             top_logprobs_num=top_logprobs_num,
             lora_path=lora_path,
+            stream=stream,
         )
 
         # get the current event loop
         loop = asyncio.get_event_loop()
-
+        ret = loop.run_until_complete(generate_request(obj, None))
+
+        if stream is True:
+
+            def generator_wrapper():
+                offset = 0
+                loop = asyncio.get_event_loop()
+                generator = ret.body_iterator
+                while True:
+                    chunk = loop.run_until_complete(generator.__anext__())
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            # we cannot yield in the scope of generate() because python does not allow yield + return in the same function
+            # however, it allows to wrap the generator as a subfunction and return
+            return generator_wrapper()
+        else:
+            return ret
+
+    async def async_generate(
+        self,
+        prompt: Union[str, List[str]],
+        sampling_params: Optional[Dict] = None,
+        return_logprob: Optional[Union[List[bool], bool]] = False,
+        logprob_start_len: Optional[Union[List[int], int]] = None,
+        top_logprobs_num: Optional[Union[List[int], int]] = None,
+        lora_path: Optional[List[Optional[str]]] = None,
+        stream: bool = False,
+    ):
+        obj = GenerateReqInput(
+            text=prompt,
+            sampling_params=sampling_params,
+            return_logprob=return_logprob,
+            logprob_start_len=logprob_start_len,
+            top_logprobs_num=top_logprobs_num,
+            lora_path=lora_path,
+            stream=stream,
+        )
+
+        ret = await generate_request(obj, None)
+
+        if stream is True:
+            generator = ret.body_iterator
+
+            async def generator_wrapper():
+
+                offset = 0
+
+                while True:
+                    chunk = await generator.__anext__()
+
+                    if chunk.startswith(STREAM_END_SYMBOL):
+                        break
+                    else:
+                        data = json.loads(chunk[len(STREAM_CHUNK_START_SYMBOL) :])
+                        data["text"] = data["text"][offset:]
+                        offset += len(data["text"])
+                        yield data
+
+            return generator_wrapper()
+        else:
+            return ret
 
     def shutdown(self):
         kill_child_process(os.getpid(), including_parent=False)
 
-
+    def get_tokenizer(self):
+        global tokenizer_manager
+
+        if tokenizer_manager is None:
+            raise ReferenceError("Tokenizer Manager is not initialized.")
+        else:
+            return tokenizer_manager.tokenizer
+
+    # TODO (ByronHsu): encode
```