sglang 0.2.14__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/constrained/fsm_cache.py +11 -2
- sglang/srt/constrained/jump_forward.py +1 -0
- sglang/srt/hf_transformers_utils.py +0 -149
- sglang/srt/layers/activation.py +93 -11
- sglang/srt/layers/layernorm.py +47 -4
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/sampler.py +15 -68
- sglang/srt/managers/io_struct.py +5 -4
- sglang/srt/managers/schedule_batch.py +20 -25
- sglang/srt/managers/tokenizer_manager.py +74 -61
- sglang/srt/managers/tp_worker.py +49 -43
- sglang/srt/model_executor/cuda_graph_runner.py +17 -31
- sglang/srt/model_executor/forward_batch_info.py +9 -26
- sglang/srt/model_executor/model_runner.py +20 -17
- sglang/srt/models/chatglm.py +13 -5
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/gemma.py +3 -7
- sglang/srt/models/gemma2.py +2 -56
- sglang/srt/models/gpt_bigcode.py +2 -6
- sglang/srt/models/grok.py +10 -8
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/llama2.py +6 -11
- sglang/srt/models/llama_classification.py +2 -6
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/mixtral.py +1 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +2 -5
- sglang/srt/models/qwen2.py +5 -10
- sglang/srt/models/qwen2_moe.py +21 -24
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/openai_api/adapter.py +85 -4
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -74
- sglang/srt/sampling/sampling_params.py +4 -0
- sglang/srt/server.py +11 -4
- sglang/srt/utils.py +18 -33
- sglang/test/runners.py +2 -2
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +11 -5
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +52 -51
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +1 -1
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
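A minimal check, assuming the new wheel has been installed into the current environment, that the post-release build is the one actually in use:

```python
from importlib.metadata import version

# sglang/version.py carries the new version string; verify the installed build.
assert version("sglang") == "0.2.14.post2"
```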
sglang/srt/managers/schedule_batch.py
CHANGED
@@ -1,5 +1,3 @@
-from __future__ import annotations
-
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,7 +17,7 @@ limitations under the License.
 
 import logging
 from dataclasses import dataclass
-from typing import
+from typing import List, Optional, Union
 
 import torch
 
@@ -31,10 +29,6 @@ from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -127,8 +121,8 @@ class Req:
 
         # For vision input
         self.pixel_values = None
-        self.
-        self.
+        self.image_sizes = None
+        self.image_offsets = None
         self.pad_value = None
 
         # Prefix info
@@ -268,7 +262,14 @@ class Req:
 
         all_text = self.origin_input_text + self.decoded_text + jump_forward_str
         all_ids = self.tokenizer.encode(all_text)
+        if not all_ids:
+            logger.warning("Encoded all_text resulted in empty all_ids")
+            return False
+
         prompt_tokens = len(self.origin_input_ids_unpadded)
+        if prompt_tokens > len(all_ids):
+            logger.warning("prompt_tokens is larger than encoded all_ids")
+            return False
 
         if all_ids[prompt_tokens - 1] != self.origin_input_ids_unpadded[-1]:
             # TODO(lsyin): fix token fusion
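The two new guards abort the jump-forward path instead of indexing past the end of the re-encoded token ids. A standalone sketch of the same checks (the function name and arguments are illustrative, not part of the package):

```python
def can_jump_forward(all_ids, origin_input_ids_unpadded):
    # Re-tokenizing the concatenated text can yield an empty list, or fewer
    # tokens than the original prompt; both cases now fail gracefully.
    if not all_ids:
        return False
    prompt_tokens = len(origin_input_ids_unpadded)
    if prompt_tokens > len(all_ids):
        return False
    # Unchanged check from the surrounding code: the last prompt token
    # must survive re-tokenization.
    return all_ids[prompt_tokens - 1] == origin_input_ids_unpadded[-1]
```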
@@ -599,12 +600,12 @@ class ScheduleBatch:
                 if req.pixel_values is not None:
                     (
                         req.origin_input_ids,
-                        req.
+                        req.image_offsets,
                     ) = model_runner.model.pad_input_ids(
                         req.origin_input_ids_unpadded,
                         req.pad_value,
-                        req.pixel_values
-                        req.
+                        req.pixel_values,
+                        req.image_sizes,
                     )
 
                 jump_forward_reqs.append(req)
@@ -677,17 +678,11 @@ class ScheduleBatch:
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-    def
-
-
-
-
-
-        argmax_ids = torch.argmax(probs, dim=-1)
-        batch_next_token_ids = torch.where(
-            sample_output.success, batch_next_token_ids, argmax_ids
-        )
-        sample_output.probs = probs
-        sample_output.batch_next_token_ids = batch_next_token_ids
+    def sample(self, logits: torch.Tensor):
+        from sglang.srt.layers.sampler import Sampler
+
+        sampler = Sampler()
+
+        batch_next_token_ids = sampler(logits, self.sampling_info)
 
-        return
+        return batch_next_token_ids
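The removed `check_sample_results` body shows what the old path did after the forward pass: rows whose sampling failed fell back to the greedy token. That responsibility now lives behind `ScheduleBatch.sample`, which delegates to `sglang.srt.layers.sampler.Sampler`. A minimal sketch of the removed fallback, for reference (tensor names are illustrative):

```python
import torch


def fallback_to_greedy(probs, sampled_ids, success):
    # Rows whose sampling failed (success == False) fall back to the argmax token,
    # mirroring the torch.where(...) logic that was deleted above.
    argmax_ids = torch.argmax(probs, dim=-1)
    return torch.where(success, sampled_ids, argmax_ids)
```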
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -23,6 +23,7 @@ import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union
 
+import fastapi
 import numpy as np
 import transformers
 import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-
         self.is_generation = is_generation_model(
             self.hf_config.architectures, self.server_args.is_embedding
         )
-
-
-
-        else:
-            self.context_len = get_context_length(self.hf_config)
+        self.context_len = server_args.context_length or get_context_length(
+            self.hf_config
+        )
 
         # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.
+            if is_multimodal_model(self.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
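With this change an explicit `--context-length` always wins and the old `if/else` is gone. A minimal sketch of the precedence (the config attribute name below is an assumption; sglang actually resolves it through `get_context_length`):

```python
def resolve_context_len(cli_context_length, hf_config):
    # Explicit server argument takes priority; otherwise fall back to the model config.
    if cli_context_length:
        return cli_context_length
    return getattr(hf_config, "max_position_embeddings", None)
```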
@@ -118,6 +116,9 @@ class TokenizerManager:
                 )
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+                # We want to parallelize the image pre-processing so we
+                # create an executor for it
                 self.executor = concurrent.futures.ProcessPoolExecutor(
                     initializer=init_global_processor,
                     mp_context=mp.get_context("fork"),
@@ -134,12 +135,14 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
-        #
+        # For update model weights
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
     async def generate_request(
-        self,
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
     async def _handle_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-        request,
+        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
@@ -182,8 +185,8 @@ class TokenizerManager:
             )
 
         if self.is_generation:
-            pixel_values,
-                obj.image_data
+            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
+                obj.image_data if not_use_index else obj.image_data[index]
             )
             return_logprob = (
                 obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
             )
             if return_logprob and logprob_start_len == -1:
                 logprob_start_len = len(input_ids) - 1
-
             top_logprobs_num = (
                 obj.top_logprobs_num
                 if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:
 
             sampling_params = SamplingParams(**obj.sampling_params[0])
             sampling_params.max_new_tokens = 0
-            pixel_values,
+            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                 obj.image_data[0]
             )
             return_logprob = obj.return_logprob[0]
             logprob_start_len = obj.logprob_start_len[0]
             top_logprobs_num = obj.top_logprobs_num[0]
 
+        # Send to the controller
         if self.is_generation:
             if return_logprob and logprob_start_len == -1:
                 logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 pixel_values,
-
-
+                image_hashes,
+                image_sizes,
                 sampling_params,
                 return_logprob,
                 logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
                 input_ids,
                 sampling_params,
             )
-
         self.send_to_router.send_pyobj(tokenized_obj)
 
+        # Recv results
         event = asyncio.Event()
         state = ReqState([], False, event)
         self.rid_to_state[rid] = state
         if not is_cache_for_prefill:
-            async for response in self._wait_for_response(
-                event, state, obj, rid, request
-            ):
+            async for response in self._wait_for_response(state, obj, rid, request):
                 yield response
         else:
            assert self.is_generation
-            await self._wait_for_cache_prefill_response(
+            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids
 
     async def _handle_batch_request(
-        self,
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         batch_size = obj.batch_size
         if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
             if self.is_generation:
                 if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
                     obj.logprob_start_len[index] = len(input_ids) - 1
-                pixel_values,
-                    obj.image_data[index]
+                pixel_values, image_hashes, image_sizes = (
+                    await self._get_pixel_values(obj.image_data[index])
                 )
 
                 tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
                     input_text,
                     input_ids,
                     pixel_values,
-
-
+                    image_hashes,
+                    image_sizes,
                     sampling_params,
                     obj.return_logprob[index],
                     obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:
 
                 generators.append(
                     self._wait_for_response(
-                        event,
                         state,
                         obj,
                         rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
             tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
             output_list = [None] * len(tasks)
 
+            # Recv results
             while tasks:
                 done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
 
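The batch path fans requests out to one generator per sub-request and collects results as they finish. A minimal standalone sketch of that fan-in pattern (the helper name is illustrative):

```python
import asyncio


async def first_results(generators):
    # Kick off __anext__() on every per-request generator and return whatever
    # completes first; the real loop keeps waiting until all tasks are drained.
    tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
    done, pending = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
    for task in pending:
        task.cancel()
    return [task.result() for task in done]
```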
@@ -426,25 +429,18 @@ class TokenizerManager:
         sampling_params.verify()
         return sampling_params
 
-    async def _get_pixel_values(self, image_data):
-        if image_data is None:
-            return None, None, None
-        else:
-            return await self._get_pixel_values_internal(image_data)
-
     async def _wait_for_response(
         self,
-        event: asyncio.Event,
         state: ReqState,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
         rid: str,
-        request,
-        index: int = None,
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
         response_index: int = 0,
     ):
         while True:
             try:
-                await asyncio.wait_for(event.wait(), timeout=4)
+                await asyncio.wait_for(state.event.wait(), timeout=4)
             except asyncio.TimeoutError:
                 if request is not None and await request.is_disconnected():
                     for rid in [obj.rid] if obj.is_single else obj.rid:
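Passing the `fastapi.Request` down lets the wait loop notice a dropped HTTP client: the event wait times out every few seconds, and a disconnected request aborts instead of hanging forever. A minimal sketch of that polling pattern (names are illustrative):

```python
import asyncio


async def wait_or_abort(event: asyncio.Event, request=None, timeout: float = 4.0) -> bool:
    # Returns True once the event fires, False if the client went away first.
    while True:
        try:
            await asyncio.wait_for(event.wait(), timeout=timeout)
            return True
        except asyncio.TimeoutError:
            if request is not None and await request.is_disconnected():
                return False
```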
@@ -478,16 +474,15 @@ class TokenizerManager:
                     yield out
                     break
 
-            event.clear()
+            state.event.clear()
             yield out
 
     async def _wait_for_cache_prefill_response(
         self,
-        event: asyncio.Event,
         state: ReqState,
         obj: GenerateReqInput,
         rid: str,
-        request,
+        request: Optional[fastapi.Request] = None,
     ):
         while True:
             try:
@@ -514,7 +509,9 @@ class TokenizerManager:
         req = AbortReq(rid)
         self.send_to_router.send_pyobj(req)
 
-    async def update_weights(
+    async def update_weights(
+        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    ):
         if self.to_create_loop:
             self.create_handle_loop()
 
@@ -659,12 +656,11 @@ class TokenizerManager:
             )
         return top_logprobs
 
-    async def
-
-
-
-
-        )
+    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+        if not image_data:
+            return None, None, None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
         grid_pinpoints = (
             self.hf_config.image_grid_pinpoints
             if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
         )
 
         if isinstance(image_data, list) and len(image_data) > 0:
-
+            # Multiple images
             if len(image_data) > 1:
                 aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
                 for img_data in image_data:
                     pixel_v, image_h, image_s = await self._process_single_image(
                         img_data, aspect_ratio, grid_pinpoints
                     )
                     pixel_values.append(pixel_v)
-
-
-
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
             else:
+                # A single image
                 pixel_values, image_hash, image_size = await self._process_single_image(
                     image_data[0], aspect_ratio, grid_pinpoints
                 )
-
-
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
         elif isinstance(image_data, str):
+            # A single image
             pixel_values, image_hash, image_size = await self._process_single_image(
                 image_data, aspect_ratio, grid_pinpoints
             )
-
-
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
         else:
-
+            raise ValueError(f"Invalid image data: {image_data}")
 
-        return pixel_values,
+        return pixel_values, image_hashes, image_sizes
 
-    async def _process_single_image(
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
         if self.executor is not None:
             loop = asyncio.get_event_loop()
             return await loop.run_in_executor(
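`_get_pixel_values` now always hands back three parallel values — pixel data, image hashes, and image sizes — whether the input is one image or many. A minimal sketch of that normalization (the helper name is illustrative):

```python
import numpy as np


def normalize_image_batch(per_image_results):
    # per_image_results: iterable of (pixel_values, image_hash, image_size) tuples.
    pixel_values, image_hashes, image_sizes = [], [], []
    for pixel_v, image_h, image_s in per_image_results:
        pixel_values.append(pixel_v)
        image_hashes.append(image_h)
        image_sizes.append(image_s)
    # Stack per-image arrays into one batch array, mirroring the np.stack above.
    if pixel_values and isinstance(pixel_values[0], np.ndarray):
        pixel_values = np.stack(pixel_values, axis=0)
    return pixel_values, image_hashes, image_sizes
```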
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
 
 
 def _process_single_image_task(
-    image_data
+    image_data: Union[str, bytes],
+    image_aspect_ratio: Optional[str] = None,
+    image_grid_pinpoints: Optional[str] = None,
+    processor=None,
 ):
     try:
         processor = processor or global_processor
         image, image_size = load_image(image_data)
         if image_size is not None:
+            # It is a video with multiple images
            image_hash = hash(image_data)
            pixel_values = processor.image_processor(image)["pixel_values"]
            for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
             pixel_values = np.stack(pixel_values, axis=0)
             return pixel_values, image_hash, image_size
         else:
+            # It is an image
             image_hash = hash(image_data)
             if image_aspect_ratio == "pad":
                 image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
                 pixel_values = processor.image_processor(image.convert("RGB"))[
                     "pixel_values"
                 ][0]
-            elif image_aspect_ratio == "anyres" or
+            elif image_aspect_ratio == "anyres" or (
+                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
+            ):
                 pixel_values = process_anyres_image(
                     image, processor.image_processor, image_grid_pinpoints
                 )
             else:
                 pixel_values = processor.image_processor(image)["pixel_values"][0]
-
+
+            if isinstance(pixel_values, np.ndarray):
+                pixel_values = pixel_values.astype(np.float16)
+
             return pixel_values, image_hash, image.size
     except Exception:
         logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
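The new `astype(np.float16)` cast happens in the worker process, before the pixel array is pickled back to the tokenizer manager; halving the payload size is the presumed motivation. A quick check of the saving for a typical 336x336 CLIP-style input (the shape is an assumption):

```python
import numpy as np

pixels_fp32 = np.zeros((3, 336, 336), dtype=np.float32)
pixels_fp16 = pixels_fp32.astype(np.float16)
print(pixels_fp32.nbytes, pixels_fp16.nbytes)  # 1354752 vs 677376 bytes
```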
sglang/srt/managers/tp_worker.py
CHANGED
@@ -31,7 +31,7 @@ from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.layers.logits_processor import
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -108,7 +108,7 @@ class ModelTpServer:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(
+            if is_multimodal_model(self.model_config.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -197,6 +197,16 @@ class ModelTpServer:
                 "trust_remote_code": server_args.trust_remote_code,
             },
             skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=False,
+        )
+        self.json_fsm_cache = FSMCache(
+            server_args.tokenizer_path,
+            {
+                "tokenizer_mode": server_args.tokenizer_mode,
+                "trust_remote_code": server_args.trust_remote_code,
+            },
+            skip_tokenizer_init=server_args.skip_tokenizer_init,
+            json_schema_mode=True,
         )
         self.jump_forward_cache = JumpForwardCache()
 
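The second `FSMCache` is built with `json_schema_mode=True`, so a request can constrain decoding with a JSON schema instead of a raw regex (the `json_schema` field is the new addition in `sampling_params.py`). A hypothetical client-side sketch against the default `/generate` endpoint; the host, port, and payload shape are assumptions:

```python
import requests

schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'
resp = requests.post(
    "http://localhost:30000/generate",
    json={
        "text": "Give me a JSON object describing a user:",
        "sampling_params": {"max_new_tokens": 64, "json_schema": schema},
    },
)
print(resp.json())
```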
@@ -323,34 +333,42 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             req.pixel_values = recv_req.pixel_values
             if req.pixel_values is not None:
-
-
-
-                    else recv_req.image_hash
-                )
+                # Use image hash as fake token_ids, which is then used
+                # for prefix matching
+                image_hash = hash(tuple(recv_req.image_hashes))
                 req.pad_value = [
                     (image_hash) % self.model_config.vocab_size,
                     (image_hash >> 16) % self.model_config.vocab_size,
                     (image_hash >> 32) % self.model_config.vocab_size,
                     (image_hash >> 64) % self.model_config.vocab_size,
                 ]
-                req.
+                req.image_sizes = recv_req.image_sizes
                 (
                     req.origin_input_ids,
-                    req.
+                    req.image_offsets,
                 ) = self.model_runner.model.pad_input_ids(
                     req.origin_input_ids_unpadded,
                     req.pad_value,
-                    req.pixel_values
-                    req.
+                    req.pixel_values,
+                    req.image_sizes,
                 )
             req.return_logprob = recv_req.return_logprob
             req.logprob_start_len = recv_req.logprob_start_len
             req.top_logprobs_num = recv_req.top_logprobs_num
             req.stream = recv_req.stream
 
+            # Init regex fsm fron json
+            if req.sampling_params.json_schema is not None:
+                req.regex_fsm, computed_regex_string = self.json_fsm_cache.query(
+                    req.sampling_params.json_schema
+                )
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        computed_regex_string
+                    )
+
             # Init regex fsm
-
+            elif req.sampling_params.regex is not None:
                 req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
                 if not self.disable_regex_jump_forward:
                     req.jump_forward_map = self.jump_forward_cache.query(
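Switching from a single `image_hash` to `hash(tuple(recv_req.image_hashes))` keeps the fake prefix tokens stable for multi-image requests. A standalone sketch of how the pad values are derived from that hash (mirrors the lines above; the function name is illustrative):

```python
def pad_values_from_hash(image_hash, vocab_size):
    # Fold the arbitrarily wide Python int hash into four pseudo token ids so
    # that identical image sets share a prefix in the radix cache.
    return [
        image_hash % vocab_size,
        (image_hash >> 16) % vocab_size,
        (image_hash >> 32) % vocab_size,
        (image_hash >> 64) % vocab_size,
    ]
```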
@@ -486,29 +504,21 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-
-
-                )
-                next_token_ids = batch.check_sample_results(sample_output)
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
 
                 # Move logprobs to cpu
-                if
-
-
-
-
-
-
-
-                    )
-                    logits_output.input_token_logprobs = (
-                        logits_output.input_token_logprobs.tolist()
-                    )
-                    logits_output.normalized_prompt_logprobs = (
-                        logits_output.normalized_prompt_logprobs.tolist()
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
                     )
 
                 next_token_ids = next_token_ids.tolist()
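The logprob post-processing keeps only the logprob of the token actually sampled in each row, via advanced indexing with `torch.arange`. A minimal sketch of that gather (the function name is illustrative):

```python
import torch


def gather_chosen_logprobs(next_token_logprobs, next_token_ids):
    # next_token_logprobs: [batch, vocab]; next_token_ids: [batch]
    rows = torch.arange(len(next_token_ids), device=next_token_ids.device)
    return next_token_logprobs[rows, next_token_ids]
```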
@@ -547,14 +557,12 @@ class ModelTpServer:
                     self.req_to_token_pool.free(req.req_pool_idx)
 
                 if req.return_logprob:
-                    self.add_logprob_return_values(
-                        i, req, pt, next_token_ids, logits_output
-                    )
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
                 pt += req.extend_input_len
         else:
             assert batch.extend_num_tokens != 0
-
-            embeddings =
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
 
         # Check finish conditions
         for i, req in enumerate(batch.reqs):
@@ -582,7 +590,7 @@ class ModelTpServer:
         req: Req,
         pt: int,
         next_token_ids: List[int],
-        output:
+        output: LogitProcessorOutput,
     ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
@@ -664,17 +672,15 @@ class ModelTpServer:
         batch.prepare_for_decode()
 
         # Forward and sample the next tokens
-
-
-        )
-        next_token_ids = batch.check_sample_results(sample_output)
+        output = self.model_runner.forward(batch, ForwardMode.DECODE)
+        next_token_ids = batch.sample(output.next_token_logits)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
 
         # Move logprobs to cpu
-        if
-        next_token_logprobs =
+        if output.next_token_logprobs is not None:
+            next_token_logprobs = output.next_token_logprobs[
                 torch.arange(len(next_token_ids), device=next_token_ids.device),
                 next_token_ids,
             ].tolist()
@@ -700,7 +706,7 @@ class ModelTpServer:
                     (next_token_logprobs[i], next_token_id)
                 )
                 if req.top_logprobs_num > 0:
-                    req.output_top_logprobs.append(
+                    req.output_top_logprobs.append(output.output_top_logprobs[i])
 
         self.handle_finished_requests(batch)
 