sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,61 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
6
|
+
# Unless required by applicable law or agreed to in writing, software
|
7
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
8
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
9
|
+
# See the License for the specific language governing permissions and
|
10
|
+
# limitations under the License.
|
11
|
+
# ==============================================================================
|
12
|
+
|
13
|
+
import copy
|
14
|
+
import uuid
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from typing import Optional
|
17
|
+
|
18
|
+
from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
|
19
|
+
from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
|
20
|
+
|
21
|
+
|
22
|
+
class Session:
|
23
|
+
def __init__(self, capacity_of_str_len: int, session_id: str = None):
|
24
|
+
self.session_id = session_id if session_id is not None else uuid.uuid4().hex
|
25
|
+
self.capacity_of_str_len = capacity_of_str_len
|
26
|
+
self.reqs: List[Req] = []
|
27
|
+
|
28
|
+
def create_req(self, req: TokenizedGenerateReqInput, tokenizer):
|
29
|
+
if req.session_rid is not None:
|
30
|
+
while len(self.reqs) > 0:
|
31
|
+
if self.reqs[-1].rid == req.session_rid:
|
32
|
+
break
|
33
|
+
self.reqs = self.reqs[:-1]
|
34
|
+
else:
|
35
|
+
self.reqs = []
|
36
|
+
if len(self.reqs) > 0:
|
37
|
+
input_ids = (
|
38
|
+
self.reqs[-1].origin_input_ids
|
39
|
+
+ self.reqs[-1].output_ids[
|
40
|
+
: self.reqs[-1].sampling_params.max_new_tokens
|
41
|
+
]
|
42
|
+
+ req.input_ids
|
43
|
+
)
|
44
|
+
else:
|
45
|
+
input_ids = req.input_ids
|
46
|
+
new_req = Req(
|
47
|
+
req.rid,
|
48
|
+
None,
|
49
|
+
input_ids,
|
50
|
+
req.sampling_params,
|
51
|
+
lora_path=req.lora_path,
|
52
|
+
session_id=self.session_id,
|
53
|
+
)
|
54
|
+
new_req.tokenizer = tokenizer
|
55
|
+
if req.session_rid is not None and len(self.reqs) == 0:
|
56
|
+
new_req.finished_reason = FINISH_ABORT(
|
57
|
+
f"Invalid request: requested session rid {req.session_rid} does not exist in the session history"
|
58
|
+
)
|
59
|
+
else:
|
60
|
+
self.reqs.append(new_req)
|
61
|
+
return new_req
|
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""TokenizerManager is a process that tokenizes the text."""
|
17
15
|
|
18
16
|
import asyncio
|
@@ -23,6 +21,7 @@ import os
|
|
23
21
|
import signal
|
24
22
|
import sys
|
25
23
|
import time
|
24
|
+
import uuid
|
26
25
|
from typing import Dict, List, Optional, Tuple, Union
|
27
26
|
|
28
27
|
import fastapi
|
@@ -42,11 +41,14 @@ from sglang.srt.managers.io_struct import (
|
|
42
41
|
BatchEmbeddingOut,
|
43
42
|
BatchStrOut,
|
44
43
|
BatchTokenIDOut,
|
44
|
+
CloseSessionReqInput,
|
45
45
|
EmbeddingReqInput,
|
46
46
|
FlushCacheReq,
|
47
47
|
GenerateReqInput,
|
48
48
|
GetMemPoolSizeReq,
|
49
49
|
GetMemPoolSizeReqOutput,
|
50
|
+
OpenSessionReqInput,
|
51
|
+
OpenSessionReqOutput,
|
50
52
|
ProfileReq,
|
51
53
|
TokenizedEmbeddingReqInput,
|
52
54
|
TokenizedGenerateReqInput,
|
@@ -146,6 +148,9 @@ class TokenizerManager:
|
|
146
148
|
self.model_update_lock = asyncio.Lock()
|
147
149
|
self.model_update_result = None
|
148
150
|
|
151
|
+
# For session info
|
152
|
+
self.session_futures = {} # session_id -> asyncio event
|
153
|
+
|
149
154
|
# Others
|
150
155
|
self.gracefully_exit = False
|
151
156
|
|
@@ -196,8 +201,18 @@ class TokenizerManager:
|
|
196
201
|
):
|
197
202
|
"""Tokenize one request."""
|
198
203
|
# Tokenize
|
204
|
+
input_embeds = None
|
199
205
|
input_text = obj.text
|
200
|
-
if obj.
|
206
|
+
if obj.input_embeds is not None:
|
207
|
+
if not self.server_args.disable_radix_cache:
|
208
|
+
raise ValueError(
|
209
|
+
"input_embeds is provided while disable_radix_cache is False. "
|
210
|
+
"Please add `--disable-radix-cach` when you launch the server "
|
211
|
+
"if you want to use input_embeds as inputs."
|
212
|
+
)
|
213
|
+
input_embeds = obj.input_embeds
|
214
|
+
input_ids = obj.input_ids
|
215
|
+
elif obj.input_ids is None:
|
201
216
|
input_ids = self.tokenizer.encode(input_text)
|
202
217
|
else:
|
203
218
|
input_ids = obj.input_ids
|
@@ -211,8 +226,10 @@ class TokenizerManager:
|
|
211
226
|
return_logprob = obj.return_logprob
|
212
227
|
logprob_start_len = obj.logprob_start_len
|
213
228
|
top_logprobs_num = obj.top_logprobs_num
|
229
|
+
session_id = obj.session[0] if obj.session else None
|
230
|
+
session_rid = obj.session[1] if obj.session else None
|
214
231
|
|
215
|
-
if len(input_ids) >= self.context_len:
|
232
|
+
if obj.input_ids is not None and len(input_ids) >= self.context_len:
|
216
233
|
raise ValueError(
|
217
234
|
f"The input ({len(input_ids)} tokens) is longer than the "
|
218
235
|
f"model's context length ({self.context_len} tokens)."
|
@@ -235,7 +252,10 @@ class TokenizerManager:
|
|
235
252
|
logprob_start_len,
|
236
253
|
top_logprobs_num,
|
237
254
|
obj.stream,
|
238
|
-
obj.lora_path,
|
255
|
+
lora_path=obj.lora_path,
|
256
|
+
input_embeds=input_embeds,
|
257
|
+
session_id=session_id,
|
258
|
+
session_rid=session_rid,
|
239
259
|
)
|
240
260
|
elif isinstance(obj, EmbeddingReqInput):
|
241
261
|
tokenized_obj = TokenizedEmbeddingReqInput(
|
@@ -451,6 +471,26 @@ class TokenizerManager:
|
|
451
471
|
else:
|
452
472
|
return False, "Another update is in progress. Please try again later."
|
453
473
|
|
474
|
+
async def open_session(
|
475
|
+
self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
|
476
|
+
):
|
477
|
+
if self.to_create_loop:
|
478
|
+
self.create_handle_loop()
|
479
|
+
|
480
|
+
session_id = uuid.uuid4().hex
|
481
|
+
obj.session_id = session_id
|
482
|
+
self.send_to_scheduler.send_pyobj(obj)
|
483
|
+
self.session_futures[session_id] = asyncio.Future()
|
484
|
+
session_id = await self.session_futures[session_id]
|
485
|
+
del self.session_futures[session_id]
|
486
|
+
return session_id
|
487
|
+
|
488
|
+
async def close_session(
|
489
|
+
self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
|
490
|
+
):
|
491
|
+
assert not self.to_create_loop, "close session should not be the first request"
|
492
|
+
await self.send_to_scheduler.send_pyobj(obj)
|
493
|
+
|
454
494
|
def create_abort_task(self, obj: GenerateReqInput):
|
455
495
|
# Abort the request if the client is disconnected.
|
456
496
|
async def abort_request():
|
@@ -521,6 +561,11 @@ class TokenizerManager:
|
|
521
561
|
if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
|
522
562
|
self.mem_pool_size.set_result(self.mem_pool_size_tmp)
|
523
563
|
continue
|
564
|
+
elif isinstance(recv_obj, OpenSessionReqOutput):
|
565
|
+
self.session_futures[recv_obj.session_id].set_result(
|
566
|
+
recv_obj.session_id
|
567
|
+
)
|
568
|
+
continue
|
524
569
|
|
525
570
|
assert isinstance(
|
526
571
|
recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)
|
sglang/srt/managers/tp_worker.py
CHANGED
@@ -1,21 +1,20 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""A tensor parallel worker."""
|
17
15
|
|
18
16
|
import logging
|
17
|
+
import threading
|
19
18
|
from typing import Optional
|
20
19
|
|
21
20
|
from sglang.srt.configs.model_config import ModelConfig
|
@@ -134,9 +133,19 @@ class TpModelWorker:
|
|
134
133
|
self.model_runner.token_to_kv_pool,
|
135
134
|
)
|
136
135
|
|
137
|
-
def
|
136
|
+
def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
|
137
|
+
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
138
|
+
self.model_runner.forward(forward_batch)
|
139
|
+
|
140
|
+
def forward_batch_generation(
|
141
|
+
self,
|
142
|
+
model_worker_batch: ModelWorkerBatch,
|
143
|
+
launch_done: Optional[threading.Event] = None,
|
144
|
+
):
|
138
145
|
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
139
146
|
logits_output = self.model_runner.forward(forward_batch)
|
147
|
+
if launch_done:
|
148
|
+
launch_done.set()
|
140
149
|
next_token_ids = self.model_runner.sample(logits_output, model_worker_batch)
|
141
150
|
return logits_output, next_token_ids
|
142
151
|
|
@@ -1,23 +1,21 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""A tensor parallel worker."""
|
17
15
|
|
16
|
+
import dataclasses
|
18
17
|
import logging
|
19
18
|
import threading
|
20
|
-
import time
|
21
19
|
from queue import Queue
|
22
20
|
from typing import Optional
|
23
21
|
|
@@ -26,7 +24,6 @@ import torch
|
|
26
24
|
from sglang.srt.managers.io_struct import UpdateWeightReqInput
|
27
25
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
28
26
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
29
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
30
27
|
from sglang.srt.server_args import ServerArgs
|
31
28
|
|
32
29
|
logger = logging.getLogger(__name__)
|
@@ -56,6 +53,7 @@ class TpModelWorkerClient:
|
|
56
53
|
self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
|
57
54
|
self.max_running_requests = self.worker.max_running_requests
|
58
55
|
self.device = self.worker.device
|
56
|
+
self.gpu_id = gpu_id
|
59
57
|
|
60
58
|
# Init future mappings
|
61
59
|
self.future_token_ids_ct = 0
|
@@ -73,12 +71,6 @@ class TpModelWorkerClient:
|
|
73
71
|
)
|
74
72
|
self.forward_thread.start()
|
75
73
|
|
76
|
-
self.copy_queue = Queue()
|
77
|
-
self.copy_thread = threading.Thread(
|
78
|
-
target=self.copy_thread_func,
|
79
|
-
)
|
80
|
-
self.copy_thread.start()
|
81
|
-
|
82
74
|
def get_worker_info(self):
|
83
75
|
return self.worker.get_worker_info()
|
84
76
|
|
@@ -98,15 +90,25 @@ class TpModelWorkerClient:
|
|
98
90
|
with torch.cuda.stream(self.forward_stream):
|
99
91
|
self.forward_thread_func_()
|
100
92
|
|
101
|
-
@torch.
|
93
|
+
@torch.no_grad()
|
102
94
|
def forward_thread_func_(self):
|
95
|
+
batch_pt = 0
|
96
|
+
batch_lists = [None] * 2
|
97
|
+
|
103
98
|
while True:
|
104
|
-
self.has_inflight_batch = False
|
105
99
|
model_worker_batch, future_token_ids_ct = self.input_queue.get()
|
106
100
|
if not model_worker_batch:
|
107
101
|
break
|
108
|
-
|
109
|
-
|
102
|
+
|
103
|
+
# Keep a reference of model_worker_batch by storing it into a list.
|
104
|
+
# Otherwise, the tensor members of model_worker_batch will be released
|
105
|
+
# by pytorch and cause CUDA illegal memory access errors.
|
106
|
+
batch_lists[batch_pt % 2] = model_worker_batch
|
107
|
+
batch_pt += 1
|
108
|
+
|
109
|
+
# Create event
|
110
|
+
self.launch_done = threading.Event()
|
111
|
+
copy_done = torch.cuda.Event()
|
110
112
|
|
111
113
|
# Resolve future tokens in the input
|
112
114
|
input_ids = model_worker_batch.input_ids
|
@@ -114,7 +116,7 @@ class TpModelWorkerClient:
|
|
114
116
|
|
115
117
|
# Run forward
|
116
118
|
logits_output, next_token_ids = self.worker.forward_batch_generation(
|
117
|
-
model_worker_batch
|
119
|
+
model_worker_batch, self.launch_done
|
118
120
|
)
|
119
121
|
|
120
122
|
# Update the future token ids map
|
@@ -139,44 +141,45 @@ class TpModelWorkerClient:
|
|
139
141
|
)
|
140
142
|
)
|
141
143
|
next_token_ids = next_token_ids.to("cpu", non_blocking=True)
|
142
|
-
|
143
|
-
copy_event.record()
|
144
|
+
copy_done.record()
|
144
145
|
|
145
|
-
self.
|
146
|
-
self.copy_queue.put((copy_event, logits_output, next_token_ids))
|
146
|
+
self.output_queue.put((copy_done, logits_output, next_token_ids))
|
147
147
|
|
148
|
-
def
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
break
|
153
|
-
while not copy_event.query():
|
154
|
-
time.sleep(1e-5)
|
148
|
+
def resolve_batch_result(self, bid: int):
|
149
|
+
copy_done, logits_output, next_token_ids = self.output_queue.get()
|
150
|
+
copy_done.synchronize()
|
151
|
+
self.launch_done.wait()
|
155
152
|
|
156
|
-
|
157
|
-
|
158
|
-
|
153
|
+
if logits_output.next_token_logprobs is not None:
|
154
|
+
logits_output.next_token_logprobs = (
|
155
|
+
logits_output.next_token_logprobs.tolist()
|
156
|
+
)
|
157
|
+
if logits_output.input_token_logprobs is not None:
|
158
|
+
logits_output.input_token_logprobs = (
|
159
|
+
logits_output.input_token_logprobs.tolist()
|
159
160
|
)
|
160
|
-
|
161
|
-
logits_output.
|
162
|
-
|
163
|
-
|
164
|
-
logits_output.normalized_prompt_logprobs = (
|
165
|
-
logits_output.normalized_prompt_logprobs.tolist()
|
166
|
-
)
|
167
|
-
|
168
|
-
self.output_queue.put((logits_output, next_token_ids.tolist()))
|
169
|
-
|
170
|
-
def resulve_batch_result(self, bid: int):
|
171
|
-
logits_output, next_token_ids = self.output_queue.get()
|
172
|
-
if self.has_inflight_batch:
|
173
|
-
# Wait until the batch is launched
|
174
|
-
self.launch_event.wait()
|
161
|
+
logits_output.normalized_prompt_logprobs = (
|
162
|
+
logits_output.normalized_prompt_logprobs.tolist()
|
163
|
+
)
|
164
|
+
next_token_ids = next_token_ids.tolist()
|
175
165
|
return logits_output, next_token_ids
|
176
166
|
|
177
167
|
def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
|
168
|
+
# Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
|
169
|
+
sampling_info = model_worker_batch.sampling_info
|
170
|
+
sampling_info.update_penalties()
|
171
|
+
model_worker_batch.sampling_info = self.cur_sampling_info = dataclasses.replace(
|
172
|
+
sampling_info,
|
173
|
+
sampling_info_done=threading.Event(),
|
174
|
+
scaling_penalties=sampling_info.scaling_penalties,
|
175
|
+
linear_penalties=sampling_info.linear_penalties,
|
176
|
+
)
|
177
|
+
|
178
|
+
# A cuda stream sync here to avoid the cuda illegal memory access error.
|
179
|
+
torch.cuda.current_stream().synchronize()
|
180
|
+
|
178
181
|
# Push a new batch to the queue
|
179
|
-
self.input_queue.put((model_worker_batch
|
182
|
+
self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
|
180
183
|
|
181
184
|
# Allocate output future objects
|
182
185
|
bs = len(model_worker_batch.seq_lens)
|
@@ -192,16 +195,8 @@ class TpModelWorkerClient:
|
|
192
195
|
) % self.future_token_ids_limit
|
193
196
|
return None, future_next_token_ids
|
194
197
|
|
195
|
-
def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
|
196
|
-
forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
|
197
|
-
logits_output = self.model_runner.forward(forward_batch)
|
198
|
-
embeddings = logits_output.embeddings
|
199
|
-
return embeddings
|
200
|
-
|
201
198
|
def update_weights(self, recv_req: UpdateWeightReqInput):
|
202
|
-
success, message = self.
|
203
|
-
recv_req.model_path, recv_req.load_format
|
204
|
-
)
|
199
|
+
success, message = self.worker.update_weights(recv_req)
|
205
200
|
return success, message
|
206
201
|
|
207
202
|
def __delete__(self):
|
sglang/srt/metrics/collector.py
CHANGED
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""Utilities for Prometheus Metrics Collection."""
|
17
15
|
|
18
16
|
from dataclasses import dataclass
|
sglang/srt/metrics/func_timer.py
CHANGED
@@ -1,18 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
15
|
-
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
16
14
|
"""
|
17
15
|
Records the latency of some functions
|
18
16
|
"""
|
sglang/srt/mm_utils.py
CHANGED
@@ -1,17 +1,16 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
4
|
-
|
5
|
-
|
6
|
-
|
7
|
-
|
8
|
-
|
9
|
-
|
10
|
-
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
"""
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
15
14
|
|
16
15
|
# Source: https://github.com/LLaVA-VL/LLaVA-NeXT/blob/main/llava/mm_utils.py
|
17
16
|
"""
|