sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/session_controller.py
CHANGED
@@ -131,6 +131,7 @@ class Session:
             sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
+            custom_logit_processor=req.custom_logit_processor,
         )
         if last_req is not None:
             new_req.image_inputs = last_req.image_inputs
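The new custom_logit_processor field threaded through Session.create_req above (and through TokenizerManager.generate_request below) carries a user-supplied hook that edits next-token logits before sampling. The concrete CustomLogitProcessor interface lives in the new sglang/srt/sampling/custom_logit_processor.py, which this excerpt does not show; the snippet below is only a hypothetical illustration of the kind of callable such a hook represents, not sglang's API.

import torch

def ban_token_ids(logits: torch.Tensor, banned_ids: list) -> torch.Tensor:
    # Hypothetical hook: set banned token ids to -inf so they can never be sampled.
    logits = logits.clone()
    logits[..., banned_ids] = float("-inf")
    return logits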
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -21,9 +21,11 @@ import os
 import pickle
 import signal
 import sys
+import threading
 import time
 import uuid
 from datetime import datetime
+from http import HTTPStatus
 from typing import Any, Awaitable, Dict, Generic, List, Optional, Tuple, TypeVar, Union
 
 import fastapi
@@ -78,6 +80,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     kill_process_tree,
 )
+from sglang.utils import TypeBasedDispatcher, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -110,17 +113,19 @@ class TokenizerManager:
         port_args: PortArgs,
     ):
         # Parse args
+
         self.server_args = server_args
         self.enable_metrics = server_args.enable_metrics
         self.log_requests = server_args.log_requests
+        self.log_requests_level = 0
 
         # Init inter-process communication
         context = zmq.asyncio.Context(2)
         self.recv_from_detokenizer = get_zmq_socket(
-            context, zmq.PULL, port_args.tokenizer_ipc_name
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
         self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
         )
 
         # Read model args
@@ -153,6 +158,7 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
             os.environ["TOKENIZERS_PARALLELISM"] = "false"
@@ -166,10 +172,11 @@ class TokenizerManager:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
 
         # Store states
-        self.
+        self.no_create_loop = False
         self.rid_to_state: Dict[str, ReqState] = {}
         self.dump_requests_folder = ""  # By default do not dump
         self.dump_requests_threshold = 1000
@@ -205,6 +212,8 @@ class TokenizerManager:
         self.resume_memory_occupation_communicator = _Communicator(
             self.send_to_scheduler, server_args.dp_size
         )
+        # Set after scheduler is initialized
+        self.max_req_input_len = None
 
         # Metrics
         if self.enable_metrics:
@@ -215,6 +224,44 @@ class TokenizerManager:
             },
         )
 
+        self._result_dispatcher = TypeBasedDispatcher(
+            [
+                (
+                    (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut),
+                    self._handle_batch_output,
+                ),
+                (OpenSessionReqOutput, self._handle_open_session_req_output),
+                (
+                    UpdateWeightFromDiskReqOutput,
+                    self._handle_update_weights_from_disk_req_output,
+                ),
+                (
+                    InitWeightsUpdateGroupReqOutput,
+                    self.init_weights_update_group_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromDistributedReqOutput,
+                    self.update_weights_from_distributed_communicator.handle_recv,
+                ),
+                (
+                    UpdateWeightsFromTensorReqOutput,
+                    self.update_weights_from_tensor_communicator.handle_recv,
+                ),
+                (
+                    GetWeightsByNameReqOutput,
+                    self.get_weights_by_name_communicator.handle_recv,
+                ),
+                (
+                    ReleaseMemoryOccupationReqOutput,
+                    self.release_memory_occupation_communicator.handle_recv,
+                ),
+                (
+                    ResumeMemoryOccupationReqOutput,
+                    self.resume_memory_occupation_communicator.handle_recv,
+                ),
+            ]
+        )
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -233,7 +280,10 @@ class TokenizerManager:
         obj.normalize_batch_and_arguments()
 
         if self.log_requests:
-
+            max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+            logger.info(
+                f"Receive: obj={dataclass_to_string_truncated(obj, max_length)}"
+            )
 
         async with self.model_update_lock.reader_lock:
             is_single = obj.is_single
@@ -265,15 +315,21 @@ class TokenizerManager:
            )
            input_embeds = obj.input_embeds
            input_ids = obj.input_ids
-        elif obj.input_ids is None:
-            input_ids = self.tokenizer.encode(input_text)
-        else:
+        elif obj.input_ids is not None:
            input_ids = obj.input_ids
+        else:
+            if self.tokenizer is None:
+                raise ValueError(
+                    "The engine initialized with skip_tokenizer_init=True cannot "
+                    "accept text prompts. Please provide input_ids or re-initialize "
+                    "the engine with skip_tokenizer_init=False."
+                )
+            input_ids = self.tokenizer.encode(input_text)
 
         if self.is_generation:
             # TODO: also support getting embeddings for multimodal models
             image_inputs: Dict = await self.image_processor.process_images_async(
-                obj.image_data, input_text or input_ids, obj
+                obj.image_data, input_text or input_ids, obj, self.max_req_input_len
             )
             if image_inputs and "input_ids" in image_inputs:
                 input_ids = image_inputs["input_ids"]
@@ -284,12 +340,28 @@ class TokenizerManager:
             SessionParams(**obj.session_params) if obj.session_params else None
         )
 
-        if 
+        input_token_num = len(input_ids) if input_ids is not None else 0
+        if input_token_num >= self.context_len:
             raise ValueError(
-                f"The input ({
+                f"The input ({input_token_num} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )
 
+        if (
+            obj.sampling_params.get("max_new_tokens") is not None
+            and obj.sampling_params.get("max_new_tokens") + input_token_num
+            >= self.context_len
+        ):
+            raise ValueError(
+                f"Requested token count exceeds the model's maximum context length "
+                f"of {self.context_len} tokens. You requested a total of "
+                f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+                f"tokens: {input_token_num} tokens from the input messages and "
+                f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
+                f"completion. Please reduce the number of tokens in the input "
+                f"messages or the completion to fit within the limit."
+            )
+
         # Parse sampling parameters
         sampling_params = SamplingParams(**obj.sampling_params)
         sampling_params.normalize(self.tokenizer)
@@ -310,6 +382,7 @@ class TokenizerManager:
                 lora_path=obj.lora_path,
                 input_embeds=input_embeds,
                 session_params=session_params,
+                custom_logit_processor=obj.custom_logit_processor,
             )
         elif isinstance(obj, EmbeddingReqInput):
             tokenized_obj = TokenizedEmbeddingReqInput(
@@ -354,9 +427,20 @@ class TokenizerManager:
                 state.out_list = []
                 if state.finished:
                     if self.log_requests:
-
+                        max_length = 2048 if self.log_requests_level == 0 else 1 << 30
+                        msg = f"Finish: obj={dataclass_to_string_truncated(obj, max_length)}, out={dataclass_to_string_truncated(out, max_length)}"
                         logger.info(msg)
                     del self.rid_to_state[obj.rid]
+
+                    # Check if this was an abort/error created by scheduler
+                    if isinstance(out["meta_info"].get("finish_reason"), dict):
+                        finish_reason = out["meta_info"]["finish_reason"]
+                        if (
+                            finish_reason.get("type") == "abort"
+                            and finish_reason.get("status_code") == HTTPStatus.BAD_REQUEST
+                        ):
+                            raise ValueError(finish_reason["message"])
+
                     yield out
                     break
 
@@ -601,12 +685,13 @@ class TokenizerManager:
     async def close_session(
         self, obj: CloseSessionReqInput, request: Optional[fastapi.Request] = None
     ):
-        assert not self.to_create_loop, "close session should not be the first request"
        await self.send_to_scheduler.send_pyobj(obj)
 
     def configure_logging(self, obj: ConfigureLoggingReq):
         if obj.log_requests is not None:
             self.log_requests = obj.log_requests
+        if obj.log_requests_level is not None:
+            self.log_requests_level = obj.log_requests_level
         if obj.dump_requests_folder is not None:
             self.dump_requests_folder = obj.dump_requests_folder
         if obj.dump_requests_threshold is not None:
@@ -628,16 +713,29 @@ class TokenizerManager:
         return background_tasks
 
     def auto_create_handle_loop(self):
-        if 
+        if self.no_create_loop:
             return
 
-        self.
+        self.no_create_loop = True
         loop = asyncio.get_event_loop()
-        self.asyncio_tasks.add(
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.handle_loop))
+        )
 
-
-
-
+        # We cannot add signal handler when the tokenizer manager is not in
+        # the main thread due to the CPython limitation.
+        if threading.current_thread() is threading.main_thread():
+            signal_handler = SignalHandler(self)
+            loop.add_signal_handler(signal.SIGTERM, signal_handler.signal_handler)
+        else:
+            logger.warning(
+                "Signal handler is not added because the tokenizer manager is "
+                "not in the main thread. This disables graceful shutdown of the "
+                "tokenizer manager when SIGTERM is received."
+            )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
+        )
 
     async def sigterm_watchdog(self):
         while not self.gracefully_exit:
@@ -661,106 +759,68 @@ class TokenizerManager:
         """The event loop that handles requests"""
 
         while True:
-            recv_obj
-
-
-
-
-
-
-
-
-
-            ] = await self.recv_from_detokenizer.recv_pyobj()
-
-            if isinstance(recv_obj, (BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut)):
-                for i, rid in enumerate(recv_obj.rids):
-                    state = self.rid_to_state.get(rid, None)
-                    if state is None:
-                        continue
-
-                    meta_info = {
-                        "id": rid,
-                        "finish_reason": recv_obj.finished_reasons[i],
-                        "prompt_tokens": recv_obj.prompt_tokens[i],
-                    }
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            self._result_dispatcher(recv_obj)
+
+    def _handle_batch_output(
+        self, recv_obj: Union[BatchStrOut, BatchEmbeddingOut, BatchTokenIDOut]
+    ):
+        for i, rid in enumerate(recv_obj.rids):
+            state = self.rid_to_state.get(rid, None)
+            if state is None:
+                continue
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-                        "cached_tokens": recv_obj.cached_tokens[i],
-                    }
-                )
-
-                    if isinstance(recv_obj, BatchStrOut):
-                        out_dict = {
-                            "text": recv_obj.output_strs[i],
-                            "meta_info": meta_info,
-                        }
-                    elif isinstance(recv_obj, BatchTokenIDOut):
-                        out_dict = {
-                            "token_ids": recv_obj.output_ids[i],
-                            "meta_info": meta_info,
-                        }
-                    else:
-                        assert isinstance(recv_obj, BatchEmbeddingOut)
-                        out_dict = {
-                            "embedding": recv_obj.embeddings[i],
-                            "meta_info": meta_info,
-                        }
-                    state.out_list.append(out_dict)
-                    state.finished = recv_obj.finished_reasons[i] is not None
-                    state.event.set()
-
-                    if self.enable_metrics:
-                        self.collect_metrics(state, recv_obj, i)
-                    if self.dump_requests_folder and state.finished:
-                        self.dump_requests(state, out_dict)
-            elif isinstance(recv_obj, OpenSessionReqOutput):
-                self.session_futures[recv_obj.session_id].set_result(
-                    recv_obj.session_id if recv_obj.success else None
+            meta_info = {
+                "id": rid,
+                "finish_reason": recv_obj.finished_reasons[i],
+                "prompt_tokens": recv_obj.prompt_tokens[i],
+            }
+
+            if getattr(state.obj, "return_logprob", False):
+                self.convert_logprob_style(
+                    meta_info,
+                    state.obj.top_logprobs_num,
+                    state.obj.return_text_in_logprobs,
+                    recv_obj,
+                    i,
                 )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                self.update_weights_from_tensor_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, GetWeightsByNameReqOutput):
-                self.get_weights_by_name_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, ReleaseMemoryOccupationReqOutput):
-                self.release_memory_occupation_communicator.handle_recv(recv_obj)
-            elif isinstance(recv_obj, ResumeMemoryOccupationReqOutput):
-                self.resume_memory_occupation_communicator.handle_recv(recv_obj)
+
+            if self.server_args.speculative_algorithm:
+                meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
+
+            if not isinstance(recv_obj, BatchEmbeddingOut):
+                meta_info.update(
+                    {
+                        "completion_tokens": recv_obj.completion_tokens[i],
+                        "cached_tokens": recv_obj.cached_tokens[i],
+                    }
+                )
+
+            if isinstance(recv_obj, BatchStrOut):
+                out_dict = {
+                    "text": recv_obj.output_strs[i],
+                    "meta_info": meta_info,
+                }
+            elif isinstance(recv_obj, BatchTokenIDOut):
+                out_dict = {
+                    "token_ids": recv_obj.output_ids[i],
+                    "meta_info": meta_info,
+                }
            else:
-
+                assert isinstance(recv_obj, BatchEmbeddingOut)
+                out_dict = {
+                    "embedding": recv_obj.embeddings[i],
+                    "meta_info": meta_info,
+                }
+
+            state.out_list.append(out_dict)
+            state.finished = recv_obj.finished_reasons[i] is not None
+            state.event.set()
+
+            if self.enable_metrics and state.obj.log_metrics:
+                self.collect_metrics(state, recv_obj, i)
+            if self.dump_requests_folder and state.finished and state.obj.log_metrics:
+                self.dump_requests(state, out_dict)
 
     def convert_logprob_style(
         self,
@@ -780,9 +840,6 @@ class TokenizerManager:
                 recv_obj.output_token_logprobs_idx[recv_obj_index],
                 return_text_in_logprobs,
             )
-            meta_info["normalized_prompt_logprob"] = recv_obj.normalized_prompt_logprob[
-                recv_obj_index
-            ]
 
         if top_logprobs_num > 0:
             meta_info["input_top_logprobs"] = self.detokenize_top_logprobs_tokens(
@@ -874,19 +931,51 @@ class TokenizerManager:
        )
 
        if len(self.dump_request_list) >= self.dump_requests_threshold:
+            filename = os.path.join(
+                self.dump_requests_folder,
+                datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + ".pkl",
+            )
+            logger.info(f"Dump {len(self.dump_request_list)} requests to {filename}")
+
            to_dump = self.dump_request_list
            self.dump_request_list = []
 
            def background_task():
                os.makedirs(self.dump_requests_folder, exist_ok=True)
-
-                filename = current_time.strftime("%Y-%m-%d_%H-%M-%S") + ".pkl"
-                with open(os.path.join(self.dump_requests_folder, filename), "wb") as f:
+                with open(filename, "wb") as f:
                    pickle.dump(to_dump, f)
 
            # Schedule the task to run in the background without awaiting it
            asyncio.create_task(asyncio.to_thread(background_task))
 
+    def _handle_open_session_req_output(self, recv_obj):
+        self.session_futures[recv_obj.session_id].set_result(
+            recv_obj.session_id if recv_obj.success else None
+        )
+
+    def _handle_update_weights_from_disk_req_output(self, recv_obj):
+        if self.server_args.dp_size == 1:
+            self.model_update_result.set_result(recv_obj)
+        else:  # self.server_args.dp_size > 1
+            self.model_update_tmp.append(recv_obj)
+            # set future if the all results are recevied
+            if len(self.model_update_tmp) == self.server_args.dp_size:
+                self.model_update_result.set_result(self.model_update_tmp)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"TokenizerManager hit an exception: {traceback}")
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
 
 class SignalHandler:
     def __init__(self, tokenizer_manager):
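The rewritten handle_loop above replaces the long isinstance chain of 0.4.1 with a single table lookup: every detokenizer output type is registered against a handler in the TypeBasedDispatcher imported from sglang.utils, and the loop body shrinks to one receive plus one dispatch. The dispatcher's own implementation is not part of this excerpt; the sketch below is a minimal illustration of the pattern under that assumption, not the sglang source.

from typing import Any, Callable, List, Tuple, Type, Union

class TypeBasedDispatcher:
    # Route an object to the first handler whose registered type(s) match it.
    def __init__(self, mapping: List[Tuple[Union[Type, Tuple[Type, ...]], Callable]]):
        self._mapping = mapping

    def __call__(self, obj: Any):
        for types, handler in self._mapping:
            if isinstance(obj, types):
                return handler(obj)
        raise ValueError(f"Invalid object: {obj!r}")

# Usage mirroring the new TokenizerManager.handle_loop:
#   dispatcher = TypeBasedDispatcher([(BatchStrOut, handle_batch_output), ...])
#   dispatcher(recv_obj)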
sglang/srt/managers/tp_worker.py
CHANGED
@@ -83,6 +83,7 @@ class TpModelWorker:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
             self.tokenizer = self.processor.tokenizer
         else:
@@ -90,6 +91,7 @@ class TpModelWorker:
                 server_args.tokenizer_path,
                 tokenizer_mode=server_args.tokenizer_mode,
                 trust_remote_code=server_args.trust_remote_code,
+                revision=server_args.revision,
             )
         self.device = self.model_runner.device
 
@@ -101,6 +103,7 @@ class TpModelWorker:
                 self.max_total_num_tokens // 2
                 if server_args.max_running_requests is None
                 else server_args.max_running_requests
+                // (server_args.dp_size if server_args.enable_dp_attention else 1)
             ),
             self.model_runner.req_to_token_pool.size,
         )
@@ -142,16 +145,15 @@ class TpModelWorker:
     def get_tp_cpu_group(self):
         return self.model_runner.tp_group.cpu_group
 
+    def get_attention_tp_cpu_group(self):
+        return self.model_runner.attention_tp_group.cpu_group
+
     def get_memory_pool(self):
         return (
             self.model_runner.req_to_token_pool,
             self.model_runner.token_to_kv_pool,
         )
 
-    def forward_batch_idle(self, model_worker_batch: ModelWorkerBatch):
-        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        self.model_runner.forward(forward_batch)
-
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -82,6 +82,8 @@ class TpModelWorkerClient:
         self.forward_thread.start()
         self.parent_process = psutil.Process().parent()
         self.scheduler_stream = torch.get_device_module(self.device).current_stream()
+        if self.device == "cpu":
+            self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -92,6 +94,9 @@ class TpModelWorkerClient:
     def get_tp_cpu_group(self):
         return self.worker.get_tp_cpu_group()
 
+    def get_attention_tp_cpu_group(self):
+        return self.worker.get_attention_tp_cpu_group()
+
     def get_memory_pool(self):
         return (
             self.worker.model_runner.req_to_token_pool,
@@ -151,11 +156,6 @@ class TpModelWorkerClient:
             logits_output.input_token_logprobs = (
                 logits_output.input_token_logprobs.to("cpu", non_blocking=True)
             )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.to(
-                    "cpu", non_blocking=True
-                )
-            )
         next_token_ids = next_token_ids.to("cpu", non_blocking=True)
         copy_done.record()
 
@@ -174,9 +174,6 @@ class TpModelWorkerClient:
             logits_output.input_token_logprobs = (
                 logits_output.input_token_logprobs.tolist()
             )
-            logits_output.normalized_prompt_logprobs = (
-                logits_output.normalized_prompt_logprobs.tolist()
-            )
         next_token_ids = next_token_ids.tolist()
         return logits_output, next_token_ids
 
sglang/srt/managers/utils.py
ADDED
@@ -0,0 +1,44 @@
+import logging
+from http import HTTPStatus
+from typing import Optional
+
+from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
+
+logger = logging.getLogger(__name__)
+
+
+def validate_input_length(
+    req: Req, max_req_input_len: int, allow_auto_truncate: bool
+) -> Optional[str]:
+    """Validate and potentially truncate input length.
+
+    Args:
+        req: The request containing input_ids to validate
+        max_req_input_len: Maximum allowed input length
+        allow_auto_truncate: Whether to truncate long inputs
+
+    Returns:
+        Error message if validation fails, None if successful
+    """
+    if len(req.origin_input_ids) >= max_req_input_len:
+        if allow_auto_truncate:
+            logger.warning(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated. "
+                f"{len(req.origin_input_ids)=}, {max_req_input_len=}."
+            )
+            req.origin_input_ids = req.origin_input_ids[:max_req_input_len]
+            return None
+        else:
+            error_msg = (
+                f"Input length ({len(req.origin_input_ids)} tokens) exceeds "
+                f"the maximum allowed length ({max_req_input_len} tokens). "
+                f"Use a shorter input or enable --allow-auto-truncate."
+            )
+            logger.error(error_msg)
+            req.finished_reason = FINISH_ABORT(
+                error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
+            )
+            return error_msg
+
+    return None
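The new helper only reads req.origin_input_ids and writes req.finished_reason, so a duck-typed stand-in is enough to exercise both branches. A minimal usage sketch follows; the SimpleNamespace request is a hypothetical stand-in rather than a real Req, and only validate_input_length and its module path come from this diff.

from types import SimpleNamespace

from sglang.srt.managers.utils import validate_input_length

# With auto-truncation enabled, an over-long prompt is clipped and no error is returned.
req = SimpleNamespace(origin_input_ids=list(range(5000)), finished_reason=None)
assert validate_input_length(req, max_req_input_len=4096, allow_auto_truncate=True) is None
assert len(req.origin_input_ids) == 4096

# Without auto-truncation, the request is marked aborted and an error string comes back.
req = SimpleNamespace(origin_input_ids=list(range(5000)), finished_reason=None)
error = validate_input_length(req, max_req_input_len=4096, allow_auto_truncate=False)
assert error is not None and req.finished_reason is not None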