sglang 0.2.4__py3-none-any.whl → 0.2.6__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in its public registry. It is provided for informational purposes only.
- sglang/lang/backend/runtime_endpoint.py +4 -4
- sglang/lang/interpreter.py +4 -4
- sglang/srt/constrained/fsm_cache.py +21 -1
- sglang/srt/hf_transformers_utils.py +3 -1
- sglang/srt/layers/logits_processor.py +70 -61
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/token_attention.py +1 -1
- sglang/srt/managers/controller/cuda_graph_runner.py +26 -17
- sglang/srt/managers/controller/infer_batch.py +54 -13
- sglang/srt/managers/controller/model_runner.py +22 -7
- sglang/srt/managers/controller/tp_worker.py +47 -41
- sglang/srt/managers/io_struct.py +2 -2
- sglang/srt/managers/tokenizer_manager.py +62 -43
- sglang/srt/model_config.py +5 -0
- sglang/srt/models/deepseek_v2.py +517 -0
- sglang/srt/models/llama_classification.py +3 -3
- sglang/srt/openai_api/adapter.py +33 -33
- sglang/srt/openai_api/protocol.py +1 -1
- sglang/srt/sampling_params.py +5 -4
- sglang/srt/server.py +2 -15
- sglang/srt/server_args.py +28 -7
- sglang/test/test_programs.py +5 -1
- sglang/version.py +1 -1
- {sglang-0.2.4.dist-info → sglang-0.2.6.dist-info}/METADATA +9 -7
- {sglang-0.2.4.dist-info → sglang-0.2.6.dist-info}/RECORD +28 -27
- {sglang-0.2.4.dist-info → sglang-0.2.6.dist-info}/LICENSE +0 -0
- {sglang-0.2.4.dist-info → sglang-0.2.6.dist-info}/WHEEL +0 -0
- {sglang-0.2.4.dist-info → sglang-0.2.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py
CHANGED
@@ -98,17 +98,21 @@ class ModelTpServer:
             if server_args.max_prefill_tokens is None
             else server_args.max_prefill_tokens
         )
-        self.max_running_requests = (
-            self.max_total_num_tokens // 2
-            if server_args.max_running_requests is None
-            else server_args.max_running_requests
-        )
         self.max_running_requests = min(
-
+            (
+                self.max_total_num_tokens // 2
+                if server_args.max_running_requests is None
+                else server_args.max_running_requests
+            ),
+            self.model_runner.req_to_token_pool.size - 1,
         )
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
+        self.max_req_input_len = min(
+            self.model_config.context_len - 1,
+            self.max_total_num_tokens - 1,
+        )
         set_random_seed(server_args.random_seed)

         # Print info
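
For reference, the new limits read as a small piece of standalone arithmetic. The sketch below is hypothetical (the helper name and the example numbers are illustrative, not from the package); only the two formulas mirror the hunk above.

```python
# Illustrative sketch of the new server limits; only the formulas mirror the diff.
from typing import Optional


def compute_limits(
    max_total_num_tokens: int,
    req_to_token_pool_size: int,
    context_len: int,
    max_running_requests: Optional[int] = None,
):
    # max_running_requests: the user override if given, otherwise half the KV pool,
    # and never more than the request-to-token pool can index.
    running = min(
        max_total_num_tokens // 2 if max_running_requests is None else max_running_requests,
        req_to_token_pool_size - 1,
    )
    # max_req_input_len: a prompt must leave room for at least one more token in both
    # the model context window and the KV cache pool.
    max_req_input_len = min(context_len - 1, max_total_num_tokens - 1)
    return running, max_req_input_len


if __name__ == "__main__":
    print(compute_limits(
        max_total_num_tokens=200_000,
        req_to_token_pool_size=4096,
        context_len=8192,
    ))  # -> (4095, 8191)
```
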
@@ -295,18 +299,20 @@ class ModelTpServer:
         )

         # Truncate prompts that are too long
-        req.origin_input_ids
+        if len(req.origin_input_ids) >= self.max_req_input_len:
+            logger.warn(
+                "Request length is longer than the KV cache pool size or "
+                "the max context length. Truncated!!!"
+            )
+            req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
         req.sampling_params.max_new_tokens = min(
-
-
-
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_input_len - 1 - len(req.origin_input_ids),
         )
-        if req.sampling_params.max_new_tokens < 0:
-            req.origin_input_ids = req.origin_input_ids[
-                : self.max_total_num_tokens - 128
-            ]
-            logger.error("Request longer than memory pool size, truncated!!!")
-
         self.forward_queue.append(req)

     def get_new_prefill_batch(self) -> Optional[Batch]:
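
The truncation logic above reads naturally as a pure function: clamp the prompt to `max_req_input_len`, then cap `max_new_tokens` so the request still fits. The following is a hypothetical rephrasing of the hunk, not code from the package.

```python
# Hypothetical rephrasing of the truncation logic above.
import logging
from typing import List, Optional, Tuple

logger = logging.getLogger(__name__)


def clamp_request(
    input_ids: List[int],
    max_new_tokens: Optional[int],
    max_req_input_len: int,
) -> Tuple[List[int], int]:
    if len(input_ids) >= max_req_input_len:
        logger.warning(
            "Request length is longer than the KV cache pool size or "
            "the max context length. Truncated."
        )
        input_ids = input_ids[:max_req_input_len]
    # An unset max_new_tokens is treated as effectively unbounded (1 << 30).
    max_new_tokens = min(
        max_new_tokens if max_new_tokens is not None else 1 << 30,
        max_req_input_len - 1 - len(input_ids),
    )
    return input_ids, max_new_tokens


# Example: a 5-token prompt against an 8-token input budget.
print(clamp_request(list(range(5)), None, 8))  # -> ([0, 1, 2, 3, 4], 2)
```
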
@@ -449,7 +455,7 @@ class ModelTpServer:
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
-            output.
+            output.input_token_logprobs = output.input_token_logprobs.tolist()
             output.normalized_prompt_logprobs = (
                 output.normalized_prompt_logprobs.tolist()
             )
@@ -475,24 +481,24 @@ class ModelTpServer:
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-        if req.
+        if req.input_token_logprobs is None:
             # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-            req.
+            req.input_token_logprobs = list(
                 zip(
-                    output.
+                    output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
                     req.input_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
-                req.
+                req.input_token_logprobs = [
                     (None, req.input_ids[0])
-                ] + req.
+                ] + req.input_token_logprobs

         if req.last_update_decode_tokens != 0:
-            req.
+            req.output_token_logprobs.extend(
                 list(
                     zip(
-                        output.
+                        output.input_token_logprobs[
                             pt
                             + req.extend_input_len
                             - req.last_update_decode_tokens : pt
@@ -504,21 +510,21 @@ class ModelTpServer:
             )
         )

-        req.
+        req.output_token_logprobs.append(
             (output.next_token_logprobs[i], next_token_ids[i])
         )

         if req.top_logprobs_num > 0:
-            if req.
-                req.
+            if req.input_top_logprobs is None:
+                req.input_top_logprobs = output.input_top_logprobs[i]
                 if req.logprob_start_len == 0:
-                    req.
+                    req.input_top_logprobs = [None] + req.input_top_logprobs

             if req.last_update_decode_tokens != 0:
-                req.
-                output.
+                req.output_top_logprobs.extend(
+                    output.input_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
                 )
-                req.
+            req.output_top_logprobs.append(output.output_top_logprobs[i])

     def cache_filled_batch(self, batch: Batch):
         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
@@ -583,11 +589,11 @@ class ModelTpServer:
             req.check_finished()

             if req.return_logprob:
-                req.
+                req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
                 )
                 if req.top_logprobs_num > 0:
-                    req.
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])

         self.handle_finished_requests(batch)

@@ -639,16 +645,16 @@ class ModelTpServer:
                 }
                 if req.return_logprob:
                     (
-                        meta_info["
-                        meta_info["
-                        meta_info["
-                        meta_info["
+                        meta_info["input_token_logprobs"],
+                        meta_info["output_token_logprobs"],
+                        meta_info["input_top_logprobs"],
+                        meta_info["output_top_logprobs"],
                         meta_info["normalized_prompt_logprob"],
                     ) = (
-                        req.
-                        req.
-                        req.
-                        req.
+                        req.input_token_logprobs,
+                        req.output_token_logprobs,
+                        req.input_top_logprobs,
+                        req.output_top_logprobs,
                         req.normalized_prompt_logprob,
                     )
                 output_meta_info.append(meta_info)
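
The ModelTpServer hunks above consistently switch the per-request logprob fields, and the `meta_info` keys derived from them, to `input_*`/`output_*` names. A hypothetical client-side helper that reads those keys could look like the following; the response shape is assumed from the keys written in the hunk above.

```python
# Hypothetical client-side helper; the meta_info keys match the names written
# by the hunk above, everything else (response shape, printing) is assumed.
from typing import Any, Dict


def summarize_logprobs(meta_info: Dict[str, Any]) -> str:
    lines = []
    for key in (
        "input_token_logprobs",
        "output_token_logprobs",
        "input_top_logprobs",
        "output_top_logprobs",
        "normalized_prompt_logprob",
    ):
        value = meta_info.get(key)
        if isinstance(value, list):
            lines.append(f"{key}: {len(value)} entries")
        else:
            lines.append(f"{key}: {value}")
    return "\n".join(lines)


print(summarize_logprobs({
    "input_token_logprobs": [(None, 1), (-0.5, 42)],
    "output_token_logprobs": [(-0.1, 7)],
    "input_top_logprobs": [None, [(-0.5, 42)]],
    "output_top_logprobs": [[(-0.1, 7)]],
    "normalized_prompt_logprob": -0.5,
}))
```
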
sglang/srt/managers/io_struct.py
CHANGED
@@ -20,7 +20,7 @@ class GenerateReqInput:
     # The image input. It can be a file name, a url, or base64 encoded string.
     # See also python/sglang/srt/utils.py:load_image.
     image_data: Optional[Union[List[str], str]] = None
-    # The sampling_params.
+    # The sampling_params. See descriptions below.
     sampling_params: Union[List[Dict], Dict] = None
     # The request id.
     rid: Optional[Union[List[str], str]] = None
@@ -30,7 +30,7 @@ class GenerateReqInput:
     logprob_start_len: Optional[Union[List[int], int]] = None
     # The number of top logprobs to return.
     top_logprobs_num: Optional[Union[List[int], int]] = None
-    # Whether to detokenize tokens in logprobs.
+    # Whether to detokenize tokens in text in the returned logprobs.
     return_text_in_logprobs: bool = False
     # Whether to stream output.
     stream: bool = False
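
The GenerateReqInput fields documented above map onto the JSON body of a generate request. A sketch of such a payload follows; the prompt, sampling settings, and values are illustrative, only the field names come from the dataclass above.

```python
# Illustrative request body built from the GenerateReqInput fields shown above.
import json

payload = {
    "text": "The capital of France is",
    "sampling_params": {"max_new_tokens": 16, "temperature": 0.0},
    "return_logprob": True,
    "logprob_start_len": 0,
    "top_logprobs_num": 2,
    # Whether to detokenize tokens in text in the returned logprobs.
    "return_text_in_logprobs": True,
    "stream": False,
}

print(json.dumps(payload, indent=2))
```
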
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -133,24 +133,10 @@ class TokenizerManager:
         async for response in self._handle_batch_request(obj, request):
             yield response

-    async def _handle_single_request(
-
-
-
-                rid = obj.rid[index]
-            else:
-                input_text = obj.text
-                rid = obj.rid[0]
-            input_ids = self.tokenizer.encode(input_text)
-            sampling_params = SamplingParams(**obj.sampling_params[0])
-            sampling_params.max_new_tokens = 0
-            pixel_values, image_hash, image_size = await self._get_pixel_values(
-                obj.image_data[0]
-            )
-            return_logprob = obj.return_logprob[0]
-            logprob_start_len = obj.logprob_start_len[0]
-            top_logprobs_num = obj.top_logprobs_num[0]
-        else:
+    async def _handle_single_request(
+        self, obj, request, index=None, is_cache_for_prefill=False
+    ):
+        if not is_cache_for_prefill:
             rid = obj.rid if index is None else obj.rid[index]
             input_text = obj.text if index is None else obj.text[index]
             input_ids = (
@@ -177,6 +163,22 @@ class TokenizerManager:
             top_logprobs_num = (
                 obj.top_logprobs_num if index is None else obj.top_logprobs_num[index]
             )
+        else:
+            if isinstance(obj.text, list):
+                input_text = obj.text[index]
+                rid = obj.rid[index]
+            else:
+                input_text = obj.text
+                rid = obj.rid[0]
+            input_ids = self.tokenizer.encode(input_text)
+            sampling_params = SamplingParams(**obj.sampling_params[0])
+            sampling_params.max_new_tokens = 0
+            pixel_values, image_hash, image_size = await self._get_pixel_values(
+                obj.image_data[0]
+            )
+            return_logprob = obj.return_logprob[0]
+            logprob_start_len = obj.logprob_start_len[0]
+            top_logprobs_num = obj.top_logprobs_num[0]

         tokenized_obj = TokenizedGenerateReqInput(
             rid,
@@ -196,26 +198,26 @@ class TokenizerManager:
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
-        if
-            await self._wait_for_prefill_response(event, state, obj, request, rid)
-            yield input_ids
-        else:
+        if not is_cache_for_prefill:
             async for response in self._wait_for_response(
                 event, state, obj, rid, request
             ):
                 yield response
+        else:
+            await self._wait_for_cache_prefill_response(event, state, obj, rid, request)
+            yield input_ids

-    async def _handle_batch_request(self, obj, request):
+    async def _handle_batch_request(self, obj: GenerateReqInput, request):
         batch_size = obj.batch_size
         parallel_sample_num = obj.sampling_params[0].get("n", 1)

         if parallel_sample_num != 1:
-
+            # Send prefill requests to cache the common input
             parallel_sample_num += 1
             input_id_result = [] if obj.input_ids is None else None
             for i in range(batch_size):
                 async for input_id in self._handle_single_request(
-                    obj, request, index=i,
+                    obj, request, index=i, is_cache_for_prefill=True
                 ):
                     if input_id_result is not None:
                         input_id_result.append(input_id)
@@ -224,6 +226,7 @@ class TokenizerManager:
                 obj.input_ids = input_id_result
             elif input_id_result is not None:
                 obj.input_ids = input_id_result[0]
+
         # First send out all requests
         for i in range(batch_size):
             for j in range(parallel_sample_num):
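
These hunks give `_handle_single_request` a second mode, selected by `is_cache_for_prefill`: tokenize the prompt, run it with `max_new_tokens = 0` so the shared prefix lands in the cache, and only then fan out the parallel samples from `_handle_batch_request`. A heavily simplified, hypothetical sketch of that control flow follows; every name below is a stand-in, only the ordering mirrors the diff.

```python
# Hypothetical, heavily simplified sketch of the batch flow above: when n > 1,
# first issue a zero-output prefill per prompt to cache the shared prefix,
# then send the real sampled requests.
import asyncio
from typing import List


async def cache_prefill(prompt: str) -> List[int]:
    # Stand-in for _handle_single_request(..., is_cache_for_prefill=True):
    # tokenize, run with max_new_tokens=0, and return the input ids.
    await asyncio.sleep(0)           # pretend to wait for the scheduler
    return [ord(c) for c in prompt]  # toy "tokenizer"


async def generate(prompt: str, sample_idx: int) -> str:
    await asyncio.sleep(0)
    return f"{prompt!r} sample {sample_idx}"


async def handle_batch(prompts: List[str], n: int) -> List[str]:
    if n != 1:
        # Send prefill requests to cache the common input.
        await asyncio.gather(*(cache_prefill(p) for p in prompts))
    # First send out all real requests, then collect their outputs.
    tasks = [generate(p, j) for p in prompts for j in range(n)]
    return await asyncio.gather(*tasks)


print(asyncio.run(handle_batch(["hello", "world"], n=2)))
```
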
@@ -308,17 +311,15 @@ class TokenizerManager:

             yield output_list

-    def _validate_input_length(self, input_ids):
+    def _validate_input_length(self, input_ids: List[int]):
         if len(input_ids) >= self.context_len:
             raise ValueError(
                 f"The input ({len(input_ids)} tokens) is longer than the "
                 f"model's context length ({self.context_len} tokens)."
             )

-    def _get_sampling_params(self, sampling_params_data
+    def _get_sampling_params(self, sampling_params_data: dict):
         sampling_params = SamplingParams(**sampling_params_data)
-        if max_new_tokens is not None:
-            sampling_params.max_new_tokens = max_new_tokens
         if sampling_params.max_new_tokens != 0:
             sampling_params.normalize(self.tokenizer)
         sampling_params.verify()
@@ -332,7 +333,14 @@ class TokenizerManager:
         else:
             return None, None, None

-    async def _wait_for_response(
+    async def _wait_for_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(event.wait(), timeout=4)
@@ -361,7 +369,14 @@ class TokenizerManager:
             event.clear()
             yield out

-    async def
+    async def _wait_for_cache_prefill_response(
+        self,
+        event: asyncio.Event,
+        state: ReqState,
+        obj: GenerateReqInput,
+        rid: str,
+        request,
+    ):
         while True:
             try:
                 await asyncio.wait_for(state.event.wait(), timeout=4)
@@ -380,7 +395,7 @@ class TokenizerManager:
         req = FlushCacheReq()
         self.send_to_router.send_pyobj(req)

-    def abort_request(self, rid):
+    def abort_request(self, rid: str):
         if rid not in self.rid_to_state:
             return
         del self.rid_to_state[rid]
@@ -426,31 +441,35 @@ class TokenizerManager:
         state.event.set()

     def convert_logprob_style(
-        self,
+        self,
+        ret: dict,
+        return_logprob: bool,
+        top_logprobs_num: int,
+        return_text_in_logprobs: bool,
     ):
         if return_logprob:
-            ret["meta_info"]["
-            ret["meta_info"]["
+            ret["meta_info"]["input_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["input_token_logprobs"], return_text_in_logprobs
             )
-            ret["meta_info"]["
-            ret["meta_info"]["
+            ret["meta_info"]["output_token_logprobs"] = self.detokenize_logprob_tokens(
+                ret["meta_info"]["output_token_logprobs"], return_text_in_logprobs
             )

             if top_logprobs_num > 0:
-                ret["meta_info"]["
+                ret["meta_info"]["input_top_logprobs"] = (
                     self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["
+                        ret["meta_info"]["input_top_logprobs"],
                         return_text_in_logprobs,
                     )
                 )
-                ret["meta_info"]["
+                ret["meta_info"]["output_top_logprobs"] = (
                     self.detokenize_top_logprobs_tokens(
-                        ret["meta_info"]["
+                        ret["meta_info"]["output_top_logprobs"], return_text_in_logprobs
                     )
                 )
         return ret

-    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text):
+    def detokenize_logprob_tokens(self, token_logprobs, decode_to_text: bool):
         if not decode_to_text:
             return [(logprob, token_id, None) for logprob, token_id in token_logprobs]

@@ -461,7 +480,7 @@ class TokenizerManager:
             for (logprob, token_id), token_text, in zip(token_logprobs, token_texts)
         ]

-    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text):
+    def detokenize_top_logprobs_tokens(self, top_logprobs, decode_to_text: bool):
         for i, t in enumerate(top_logprobs):
             if t:
                 top_logprobs[i] = self.detokenize_logprob_tokens(t, decode_to_text)
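
`detokenize_logprob_tokens` turns `(logprob, token_id)` pairs into `(logprob, token_id, token_text)` triples when `return_text_in_logprobs` is set. A standalone sketch follows; the toy vocabulary is a stand-in for the real tokenizer (which batch-decodes the token ids), only the tuple shapes follow the code above.

```python
# Standalone sketch of the (logprob, token_id) -> (logprob, token_id, text)
# conversion; the toy vocabulary below is a stand-in for the real tokenizer.
from typing import List, Optional, Tuple

TOY_VOCAB = {1: "<bos>", 42: "Hello", 7: "!"}


def detokenize_logprob_tokens(
    token_logprobs: List[Tuple[Optional[float], int]],
    decode_to_text: bool,
):
    if not decode_to_text:
        return [(logprob, token_id, None) for logprob, token_id in token_logprobs]
    return [
        (logprob, token_id, TOY_VOCAB.get(token_id, "<unk>"))
        for logprob, token_id in token_logprobs
    ]


print(detokenize_logprob_tokens([(None, 1), (-0.5, 42), (-0.1, 7)], True))
# -> [(None, 1, '<bos>'), (-0.5, 42, 'Hello'), (-0.1, 7, '!')]
```
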
sglang/srt/model_config.py
CHANGED
@@ -36,6 +36,11 @@ class ModelConfig:
             "head_dim",
             self.hf_config.hidden_size // self.hf_config.num_attention_heads,
         )
+
+        # FIXME: temporary special judge for deepseek v2 MLA architecture
+        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
+            self.head_dim = 256
+
         self.num_attention_heads = self.hf_config.num_attention_heads
         self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)

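
The head_dim resolution with the new DeepseekV2 special case can be summarized as below. The config object and the example numbers are illustrative; only the `getattr` fallback, the architecture name, and the constant 256 come from the diff.

```python
# Sketch of the head_dim resolution above; SimpleNamespace stands in for the
# HuggingFace config object, and the example numbers are illustrative.
from types import SimpleNamespace


def resolve_head_dim(hf_config) -> int:
    head_dim = getattr(
        hf_config,
        "head_dim",
        hf_config.hidden_size // hf_config.num_attention_heads,
    )
    # Temporary special case for the DeepSeek-V2 MLA architecture (FIXME in the diff).
    if "DeepseekV2ForCausalLM" in hf_config.architectures:
        head_dim = 256
    return head_dim


llama_like = SimpleNamespace(
    architectures=["LlamaForCausalLM"], hidden_size=4096, num_attention_heads=32
)
deepseek_v2 = SimpleNamespace(
    architectures=["DeepseekV2ForCausalLM"], hidden_size=5120, num_attention_heads=128
)
print(resolve_head_dim(llama_like))   # -> 128
print(resolve_head_dim(deepseek_v2))  # -> 256
```
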