sglang 0.4.1.post5__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +24 -16
- sglang/bench_one_batch.py +51 -3
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +37 -28
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +15 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/model_config.py +16 -6
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +449 -0
- sglang/srt/entrypoints/http_server.py +579 -0
- sglang/srt/layers/activation.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +27 -12
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +69 -0
- sglang/srt/layers/linear.py +76 -102
- sglang/srt/layers/logits_processor.py +48 -63
- sglang/srt/layers/moe/ep_moe/layer.py +4 -4
- sglang/srt/layers/moe/fused_moe_native.py +69 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -14
- sglang/srt/layers/moe/topk.py +4 -2
- sglang/srt/layers/parameter.py +26 -17
- sglang/srt/layers/quantization/__init__.py +22 -23
- sglang/srt/layers/quantization/fp8.py +112 -55
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/int8_kernel.py +54 -0
- sglang/srt/layers/quantization/modelopt_quant.py +2 -3
- sglang/srt/layers/quantization/w8a8_int8.py +117 -0
- sglang/srt/layers/radix_attention.py +2 -0
- sglang/srt/layers/rotary_embedding.py +1179 -31
- sglang/srt/layers/sampler.py +39 -1
- sglang/srt/layers/vocab_parallel_embedding.py +17 -4
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +46 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +23 -8
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +54 -15
- sglang/srt/managers/schedule_batch.py +49 -22
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +319 -181
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +303 -158
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +110 -77
- sglang/srt/metrics/collector.py +25 -11
- sglang/srt/model_executor/cuda_graph_runner.py +4 -6
- sglang/srt/model_executor/model_runner.py +80 -21
- sglang/srt/model_loader/loader.py +8 -6
- sglang/srt/model_loader/weight_utils.py +55 -2
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +3 -3
- sglang/srt/models/dbrx.py +4 -4
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +8 -8
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +6 -24
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +41 -4
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +6 -6
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +52 -4
- sglang/srt/models/qwen2_eagle.py +131 -0
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +3 -3
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +15 -5
- sglang/srt/sampling/sampling_batch_info.py +153 -9
- sglang/srt/sampling/sampling_params.py +4 -2
- sglang/srt/server.py +4 -1037
- sglang/srt/server_args.py +84 -32
- sglang/srt/speculative/eagle_worker.py +1 -0
- sglang/srt/torch_memory_saver_adapter.py +59 -0
- sglang/srt/utils.py +130 -63
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +26 -13
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +126 -117
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post5.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
@@ -131,6 +131,7 @@ class Session:
             sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
+            custom_logit_processor=req.custom_logit_processor,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
@@ -18,10 +18,14 @@ import copy
 import dataclasses
 import logging
 import os
+import pickle
 import signal
 import sys
+import threading
 import time
 import uuid
+from datetime import datetime
+from http import HTTPStatus
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import fastapi
@@ -43,6 +47,7 @@ from sglang.srt.managers.io_struct import (
     BatchStrOut,
     BatchTokenIDOut,
     CloseSessionReqInput,
+    ConfigureLoggingReq,
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
@@ -53,6 +58,10 @@ from sglang.srt.managers.io_struct import (
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
+    ReleaseMemoryOccupationReqInput,
+    ReleaseMemoryOccupationReqOutput,
+    ResumeMemoryOccupationReqInput,
+    ResumeMemoryOccupationReqOutput,
     SessionParams,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
@@ -71,6 +80,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -103,16 +113,19 @@ class TokenizerManager:
         port_args: PortArgs,
     ):
         # Parse args
+
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
+        self.log_requests = server_args.log_requests
+        self.log_requests_level = 0
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
         self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
         )
 
         # Read model args
@@ -145,6 +158,7 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -158,11 +172,15 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
 
         # Store states
-        self.
+        self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
+        self.dump_requests_folder = ""  # By default do not dump
+        self.dump_requests_threshold = 1000
+        self.dump_request_list: List[Tuple] = []
 
         # The event to notify the weight sync is finished.
         self.model_update_lock = RWLock()
@@ -188,6 +206,14 @@ class TokenizerManager:
         self.get_weights_by_name_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        self.release_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        self.resume_memory_occupation_communicator = _Communicator(
+            self.send_to_scheduler, server_args.dp_size
+        )
+        # Set after scheduler is initialized
+        self.max_req_input_len = None
 
         # Metrics
         if self.enable_metrics:
@@ -198,6 +224,44 @@ class TokenizerManager:
             },
         )
 
+        self._result_dispatcher = TypeBasedDispatcher(
+            [
+                (
+                    (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
+                    self._handle_batch_output,
+                ),
+                (OpenSessionReqOutput, self._handle_open_session_req_output),
+                (
+                    UpdateWeightFromDiskReqOutput,
+                    self._handle_update_weights_from_disk_req_output,
+                ),
+                (
+                    InitWeightsUpdateGroupReqOutput,
+                    self.init_weights_update_group_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromDistributedReqOutput,
+                    self.update_weights_from_distributed_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromTensorReqOutput,
+                    self.update_weights_from_tensor_communicator.handle_recv,
+                ),
+                (
+                    GetWeightsByNameReqOutput,
+                    self.get_weights_by_name_communicator.handle_recv,
+                ),
+                (
+                    ReleaseMemoryOccupationReqOutput,
+                    self.release_memory_occupation_communicator.handle_recv,
+                ),
+                (
+                    ResumeMemoryOccupationReqOutput,
+                    self.resume_memory_occupation_communicator.handle_recv,
+                ),
+            ]
+        )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
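The registration above replaces handle_loop's long isinstance chain with a table of (type, handler) pairs consumed by sglang.utils.TypeBasedDispatcher. As a rough illustration of that pattern only (an assumption, not the actual class in sglang/utils.py), a minimal dispatcher could look like this:

```python
# Minimal sketch of a type-based dispatcher, for illustration only; it is an
# assumption about the pattern, not the real sglang.utils.TypeBasedDispatcher.
from typing import Any, Callable, List, Tuple, Type, Union

TypeKey = Union[Type, Tuple[Type, ...]]


class MiniTypeDispatcher:
    def __init__(self, mapping: List[Tuple[TypeKey, Callable[[Any], Any]]]):
        # Each entry maps a type (or tuple of types) to a handler callable.
        self._mapping = mapping

    def __call__(self, obj: Any):
        for types, handler in self._mapping:
            if isinstance(obj, types):
                return handler(obj)
        raise ValueError(f"No handler registered for {type(obj)}")


# Usage mirrors handle_loop: one received object is routed to exactly one handler.
dispatcher = MiniTypeDispatcher(
    [
        (str, lambda s: print(f"text: {s}")),
        ((int, float), lambda x: print(f"number: {x}")),
    ]
)
dispatcher("hello")  # -> text: hello
dispatcher(42)       # -> number: 42
```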
@@ -215,8 +279,11 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
-        if self.
-
+        if self.log_requests:
+            max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+            logger.info(
+                f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+            )
 
         async with self.model_update_lock.reader_lock:
             is_single = obj.is_single
@@ -248,15 +315,21 @@ class TokenizerManager:
                 )
             input_embeds = obj.input_embeds
             input_ids = obj.input_ids
-        elif obj.input_ids is None:
-            input_ids = self.tokenizer.encode(input_text)
-        else:
+        elif obj.input_ids is not None:
             input_ids = obj.input_ids
+        else:
+            if self.tokenizer is None:
+                raise ValueError(
+                    "The engine initialized with skip_tokenizer_init=True cannot "
+                    "accept text prompts. Please provide input_ids or re-initialize "
+                    "the engine with skip_tokenizer_init=False."
+                )
+            input_ids = self.tokenizer.encode(input_text)
 
         if self.is_generation:
             # TODO: also support getting embeddings for multimodal models
             image_inputs: Dict = await self.image_processor.process_images_async(
-                obj.image_data, input_text or input_ids, obj
+                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
@@ -267,12 +340,28 @@ class TokenizerManager:
                 SessionParams(**obj.session_params) if obj.session_params else None
             )
 
-        if
+        input_token_num = len(input_ids) if input_ids is not None else 0
+        if input_token_num >= self.context_len:
             raise ValueError(
-                f"The input ({
+                f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        if (
+            obj.sampling_params.get("max_new_tokens") is not None
+            and obj.sampling_params.get("max_new_tokens") + input_token_num
+            >= self.context_len
+        ):
+            raise ValueError(
+                f"Requested token count exceeds the model's maximum context length "
+                f"of {self.context_len} tokens. You requested a total of "
+                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"tokens: {input_token_num} tokens from the input messages and "
+                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+                f"completion. Please reduce the number of tokens in the input "
+                f"messages or the completion to fit within the limit."
+            )
+
         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
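Read together, the two checks added above enforce one rule: a request is rejected when the prompt alone, or the prompt plus the requested completion, reaches the model's context length. A self-contained sketch of that rule (hypothetical helper, not part of sglang):

```python
# Hypothetical standalone version of the length checks above, for illustration.
from typing import Optional


def check_request_length(
    input_token_num: int, max_new_tokens: Optional[int], context_len: int
) -> None:
    # The prompt alone must fit inside the context window.
    if input_token_num >= context_len:
        raise ValueError(
            f"The input ({input_token_num} tokens) is longer than the "
            f"model's context length ({context_len} tokens)."
        )
    # The prompt plus the requested completion must also fit.
    if max_new_tokens is not None and max_new_tokens + input_token_num >= context_len:
        raise ValueError(
            f"Requested {max_new_tokens + input_token_num} tokens in total "
            f"({input_token_num} prompt + {max_new_tokens} completion), which "
            f"does not fit in the {context_len}-token context length."
        )


check_request_length(3000, 1000, 4096)  # OK: 3000 + 1000 < 4096
try:
    check_request_length(4000, 200, 4096)
except ValueError as e:
    print(e)  # 4200 total tokens do not fit in the 4096-token context
```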
@@ -293,6 +382,7 @@ class TokenizerManager:
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
+                custom_logit_processor=obj.custom_logit_processor,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -336,10 +426,21 @@ class TokenizerManager:
 
             state.out_list = []
             if state.finished:
-                if self.
-
+                if self.log_requests:
+                    max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+                    msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
                     logger.info(msg)
                 del self.rid_to_state[obj.rid]
+
+                # Check if this was an abort/error created by scheduler
+                if isinstance(out["meta_info"].get("finish_reason"), dict):
+                    finish_reason = out["meta_info"]["finish_reason"]
+                    if (
+                        finish_reason.get("type") == "abort"
+                        and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
+                    ):
+                        raise ValueError(finish_reason["message"])
+
                 yield out
                 break
 
@@ -548,6 +649,22 @@ class TokenizerManager:
         else:
             return all_parameters
 
+    async def release_memory_occupation(
+        self,
+        obj: ReleaseMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.release_memory_occupation_communicator(obj)
+
+    async def resume_memory_occupation(
+        self,
+        obj: ResumeMemoryOccupationReqInput,
+        request: Optional[fastapi.Request] = None,
+    ):
+        self.auto_create_handle_loop()
+        await self.resume_memory_occupation_communicator(obj)
+
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
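Both new endpoints follow the same request/response pattern as the other _Communicator-backed calls: push a request object to the scheduler, then suspend until the dispatcher registered in __init__ feeds the matching ReqOutput back through handle_recv. A hedged sketch of that pattern (an assumption about _Communicator's shape, not its actual implementation):

```python
# Illustrative sketch only: the send-then-await-future pattern assumed above.
import asyncio
from typing import Any, Callable, List, Optional


class MiniCommunicator:
    def __init__(self, send_fn: Callable[[Any], None], fan_in: int = 1):
        self._send_fn = send_fn      # e.g. a ZMQ socket's send_pyobj
        self._fan_in = fan_in        # dp_size: number of replies to collect
        self._replies: List[Any] = []
        self._future: Optional[asyncio.Future] = None

    async def __call__(self, req: Any) -> List[Any]:
        self._replies = []
        self._future = asyncio.get_running_loop().create_future()
        self._send_fn(req)           # request goes out to the scheduler
        return await self._future    # resolved later by handle_recv()

    def handle_recv(self, reply: Any) -> None:
        self._replies.append(reply)
        if len(self._replies) == self._fan_in:
            self._future.set_result(self._replies)


async def demo():
    # The fake send function "replies" immediately, standing in for the scheduler.
    comm = MiniCommunicator(send_fn=lambda req: comm.handle_recv(f"ack:{req}"))
    print(await comm("release_memory"))  # ['ack:release_memory']


asyncio.run(demo())
```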
@@ -568,9 +685,19 @@ class TokenizerManager:
     async def close_session(
         self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        assert not self.to_create_loop, "close session should not be the first request"
         await self.send_to_scheduler.send_pyobj(obj)
 
+    def configure_logging(self, obj: ConfigureLoggingReq):
+        if obj.log_requests is not None:
+            self.log_requests = obj.log_requests
+        if obj.log_requests_level is not None:
+            self.log_requests_level = obj.log_requests_level
+        if obj.dump_requests_folder is not None:
+            self.dump_requests_folder = obj.dump_requests_folder
+        if obj.dump_requests_threshold is not None:
+            self.dump_requests_threshold = obj.dump_requests_threshold
+        logging.info(f"Config logging: {obj=}")
+
     def create_abort_task(self, obj: GenerateReqInput):
         # Abort the request if the client is disconnected.
         async def abort_request():
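configure_logging() reads four optional fields from the incoming ConfigureLoggingReq; the real dataclass lives in sglang/srt/managers/io_struct.py, and the sketch below only infers its shape from the handler. A hypothetical call that turns on verbose request logging and request dumping:

```python
# Shape inferred from configure_logging() above; the actual ConfigureLoggingReq
# definition in io_struct.py may differ.
from dataclasses import dataclass
from typing import Optional


@dataclass
class ConfigureLoggingReqSketch:
    log_requests: Optional[bool] = None
    log_requests_level: Optional[int] = None
    dump_requests_folder: Optional[str] = None
    dump_requests_threshold: Optional[int] = None


# Hypothetical usage: enable full-length request logs and dump every
# 100 finished requests to /tmp/sglang_dumps.
req = ConfigureLoggingReqSketch(
    log_requests=True,
    log_requests_level=1,
    dump_requests_folder="/tmp/sglang_dumps",
    dump_requests_threshold=100,
)
# tokenizer_manager.configure_logging(req)
```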
@@ -586,22 +713,35 @@ class TokenizerManager:
             return background_tasks
 
     def auto_create_handle_loop(self):
-        if
+        if self.no_create_loop:
             return
 
-        self.
+        self.no_create_loop = True
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )
 
-
-
-
+        # We cannot add signal handler when the tokenizer manager is not in
+        # the main thread due to the CPython limitation.
+        if threading.current_thread() is threading.main_thread():
+            signal_handler = SignalHandler(self)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        else:
+            logger.warning(
+                "Signal handler is not added because the tokenizer manager is "
+                "not in the main thread. This disables graceful shutdown of the "
+                "tokenizer manager when SIGTERM is received."
+            )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
             await asyncio.sleep(5)
 
-        #
+        # Drain requests
         while True:
             remain_num_req = len(self.rid_to_state)
             logger.info(
@@ -619,143 +759,64 @@ class TokenizerManager:
         """The event loop that handles requests"""
 
         while True:
-            recv_obj
-
-                BatchEmbeddingOut,
-                BatchTokenIDOut,
-                UpdateWeightFromDiskReqOutput,
-                UpdateWeightsFromDistributedReqOutput,
-                GetWeightsByNameReqOutput,
-                InitWeightsUpdateGroupReqOutput,
-            ] = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
-                for i, rid in enumerate(recv_obj.rids):
-                    state = self.rid_to_state.get(rid, None)
-                    if state is None:
-                        continue
-
-                    meta_info = {
-                        "id": rid,
-                        "finish_reason": recv_obj.finished_reasons[i],
-                        "prompt_tokens": recv_obj.prompt_tokens[i],
-                    }
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            self._result_dispatcher(recv_obj)
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    elif isinstance(recv_obj, BatchTokenIDOut):
-                        out_dict = {
-                            "token_ids": recv_obj.output_ids[i],
-                            "meta_info": meta_info,
-                        }
-                    else:
-                        assert isinstance(recv_obj, BatchEmbeddingOut)
-                        out_dict = {
-                            "embedding": recv_obj.embeddings[i],
-                            "meta_info": meta_info,
-                        }
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished_reasons[i] is not None
-                    state.event.set()
-
-                    if self.enable_metrics:
-                        completion_tokens = (
-                            recv_obj.completion_tokens[i]
-                            if getattr(recv_obj, "completion_tokens", None)
-                            else 0
-                        )
-
-                        if state.first_token_time is None:
-                            state.first_token_time = time.time()
-                            self.metrics_collector.observe_time_to_first_token(
-                                state.first_token_time - state.created_time
-                            )
-                        else:
-                            if completion_tokens >= 2:
-                                # Compute time_per_output_token for the streaming case
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.first_token_time)
-                                    / (completion_tokens - 1)
-                                )
-
-                        if state.finished:
-                            self.metrics_collector.inc_prompt_tokens(
-                                recv_obj.prompt_tokens[i]
-                            )
-                            self.metrics_collector.inc_generation_tokens(
-                                completion_tokens
-                            )
-                            self.metrics_collector.observe_e2e_request_latency(
-                                time.time() - state.created_time
-                            )
-                            # Compute time_per_output_token for the non-streaming case
-                            if (
-                                hasattr(state.obj, "stream")
-                                and not state.obj.stream
-                                and completion_tokens >= 1
-                            ):
-                                self.metrics_collector.observe_time_per_output_token(
-                                    (time.time() - state.created_time)
-                                    / completion_tokens
-                                )
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id if recv_obj.success else None
+    def _handle_batch_output(
+        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+    ):
+        for i, rid in enumerate(recv_obj.rids):
+            state = self.rid_to_state.get(rid, None)
+            if state is None:
+                continue
+
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
+                )
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
                 )
-
-
-
-
-
-
-
-
-
-
-
-                ), "dp_size must be 1 for init parameter update group"
-                self.init_weights_update_group_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromDistributedReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_distributed_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, UpdateWeightsFromTensorReqOutput):
-                assert (
-                    self.server_args.dp_size == 1
-                ), "dp_size must be 1 for update weights from distributed"
-                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-                self.get_weights_by_name_communicator.handle_recv(recv_obj)
+
+            if isinstance(recv_obj, BatchStrOut):
+                out_dict = {
+                    "text": recv_obj.output_strs[i],
+                    "meta_info": meta_info,
+                }
+            elif isinstance(recv_obj, BatchTokenIDOut):
+                out_dict = {
+                    "token_ids": recv_obj.output_ids[i],
+                    "meta_info": meta_info,
+                }
             else:
-
+                assert isinstance(recv_obj, BatchEmbeddingOut)
+                out_dict = {
+                    "embedding": recv_obj.embeddings[i],
+                    "meta_info": meta_info,
+                }
+            state.out_list.append(out_dict)
+            state.finished = recv_obj.finished_reasons[i] is not None
+            state.event.set()
+
+            if self.enable_metrics and state.obj.log_metrics:
+                self.collect_metrics(state, recv_obj, i)
+            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+                self.dump_requests(state, out_dict)
 
     def convert_logprob_style(
         self,
@@ -775,9 +836,6 @@ class TokenizerManager:
             recv_obj.output_token_logprobs_idx[recv_obj_index],
             return_text_in_logprobs,
         )
-        meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
-            recv_obj_index
-        ]
 
         if top_logprobs_num > 0:
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
@@ -827,6 +885,93 @@ class TokenizerManager:
                 ret.append(None)
         return ret
 
+    def collect_metrics(self, state: ReqState, recv_obj: BatchStrOut, i: int):
+        completion_tokens = (
+            recv_obj.completion_tokens[i]
+            if getattr(recv_obj, "completion_tokens", None)
+            else 0
+        )
+
+        if state.first_token_time is None:
+            state.first_token_time = time.time()
+            self.metrics_collector.observe_time_to_first_token(
+                state.first_token_time - state.created_time
+            )
+        else:
+            if completion_tokens >= 2:
+                # Compute time_per_output_token for the streaming case
+                self.metrics_collector.observe_time_per_output_token(
+                    (time.time() - state.first_token_time) / (completion_tokens - 1)
+                )
+
+        if state.finished:
+            self.metrics_collector.observe_one_finished_request(
+                recv_obj.prompt_tokens[i], completion_tokens
+            )
+            self.metrics_collector.observe_e2e_request_latency(
+                time.time() - state.created_time
+            )
+            # Compute time_per_output_token for the non-streaming case
+            if (
+                hasattr(state.obj, "stream")
+                and not state.obj.stream
+                and completion_tokens >= 1
+            ):
+                self.metrics_collector.observe_time_per_output_token(
+                    (time.time() - state.created_time) / completion_tokens
+                )
+
+    def dump_requests(self, state: ReqState, out_dict: dict):
+        self.dump_request_list.append(
+            (state.obj, out_dict, state.created_time, time.time())
+        )
+
+        if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
+            to_dump = self.dump_request_list
+            self.dump_request_list = []
+
+            def background_task():
+                os.makedirs(self.dump_requests_folder, exist_ok=True)
+                with open(filename, "wb") as f:
+                    pickle.dump(to_dump, f)
+
+            # Schedule the task to run in the background without awaiting it
+            asyncio.create_task(asyncio.to_thread(background_task))
+
+    def _handle_open_session_req_output(self, recv_obj):
+        self.session_futures[recv_obj.session_id].set_result(
+            recv_obj.session_id if recv_obj.success else None
+        )
+
+    def _handle_update_weights_from_disk_req_output(self, recv_obj):
+        if self.server_args.dp_size == 1:
+            self.model_update_result.set_result(recv_obj)
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp.append(recv_obj)
+            # set future if the all results are recevied
+            if len(self.model_update_tmp) == self.server_args.dp_size:
+                self.model_update_result.set_result(self.model_update_tmp)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
 
 class SignalHandler:
     def __init__(self, tokenizer_manager):
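dump_requests() above pickles a list of (request_obj, out_dict, created_time, finish_time) tuples into a timestamped .pkl file. A small hedged sketch for reading such a dump back offline (hypothetical script; the folder path is only an example):

```python
# Hypothetical offline reader for the request dumps written by dump_requests().
import glob
import pickle

for path in sorted(glob.glob("/tmp/sglang_dumps/*.pkl")):  # dump_requests_folder
    with open(path, "rb") as f:
        records = pickle.load(f)
    for obj, out, created_time, finish_time in records:
        rid = out["meta_info"]["id"]
        print(f"{rid}: {finish_time - created_time:.3f}s end-to-end")
```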
sglang/srt/managers/tp_worker.py CHANGED
@@ -83,6 +83,7 @@ class TpModelWorker:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
         else:
@@ -90,6 +91,7 @@ class TpModelWorker:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
         self.device = self.model_runner.device
 
@@ -101,6 +103,7 @@ class TpModelWorker:
                 self.max_total_num_tokens // 2
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
+                // (server_args.dp_size if server_args.enable_dp_attention else 1)
             ),
             self.model_runner.req_to_token_pool.size,
         )
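The extra divisor above splits the running-request budget across data-parallel replicas when DP attention is enabled. A worked example with illustrative numbers:

```python
# Illustrative values only: how the new divisor changes the per-worker cap.
max_running_requests = 128
dp_size = 4
enable_dp_attention = True

per_worker_cap = max_running_requests // (dp_size if enable_dp_attention else 1)
print(per_worker_cap)  # 32; with DP attention disabled it stays 128
```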
@@ -142,16 +145,15 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group
 
+    def get_attention_tp_cpu_group(self):
+        return self.model_runner.attention_tp_group.cpu_group
+
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
             self.model_runner.token_to_kv_pool,
         )
 
-    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        self.model_runner.forward(forward_batch)
-
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,