sglang 0.2.14.post1__py3-none-any.whl → 0.2.14.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/launch_server_llavavid.py +26 -0
- sglang/srt/hf_transformers_utils.py +0 -149
- sglang/srt/layers/activation.py +10 -4
- sglang/srt/layers/layernorm.py +47 -1
- sglang/srt/managers/io_struct.py +5 -4
- sglang/srt/managers/schedule_batch.py +5 -5
- sglang/srt/managers/tokenizer_manager.py +74 -61
- sglang/srt/managers/tp_worker.py +9 -10
- sglang/srt/model_executor/forward_batch_info.py +10 -20
- sglang/srt/model_executor/model_runner.py +15 -6
- sglang/srt/models/chatglm.py +1 -1
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +1 -51
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/llama2.py +3 -4
- sglang/srt/models/llama_classification.py +0 -4
- sglang/srt/models/llama_embedding.py +3 -4
- sglang/srt/models/llava.py +69 -91
- sglang/srt/models/llavavid.py +40 -86
- sglang/srt/models/qwen2.py +3 -4
- sglang/srt/models/qwen2_moe.py +7 -19
- sglang/srt/models/yivl.py +2 -7
- sglang/srt/server.py +3 -3
- sglang/srt/utils.py +18 -33
- sglang/test/runners.py +1 -1
- sglang/test/test_layernorm.py +53 -1
- sglang/version.py +1 -1
- {sglang-0.2.14.post1.dist-info → sglang-0.2.14.post2.dist-info}/METADATA +3 -3
- {sglang-0.2.14.post1.dist-info → sglang-0.2.14.post2.dist-info}/RECORD +32 -31
- {sglang-0.2.14.post1.dist-info → sglang-0.2.14.post2.dist-info}/LICENSE +0 -0
- {sglang-0.2.14.post1.dist-info → sglang-0.2.14.post2.dist-info}/WHEEL +0 -0
- {sglang-0.2.14.post1.dist-info → sglang-0.2.14.post2.dist-info}/top_level.txt +0 -0
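The release marker itself is the one-line bump in `sglang/version.py`. A minimal sketch for confirming which build is installed (the `importlib.metadata` call is standard library; reading `__version__` from `sglang.version` assumes that module exposes the attribute, which the file list implies but this diff does not show):

```python
# Sketch: confirm the installed sglang build matches the target of this diff.
from importlib.metadata import version  # Python 3.8+ standard library

installed = version("sglang")
print(installed)  # expected: "0.2.14.post2"

# Assumption: sglang/version.py (the +1 -1 change above) exposes __version__.
from sglang.version import __version__
assert __version__ == installed
```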
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -23,6 +23,7 @@ import multiprocessing as mp
 import os
 from typing import Dict, List, Optional, Tuple, Union
 
+import fastapi
 import numpy as np
 import transformers
 import uvloop
@@ -96,21 +97,18 @@ class TokenizerManager:
             trust_remote_code=server_args.trust_remote_code,
             model_overide_args=model_overide_args,
         )
-
         self.is_generation = is_generation_model(
            self.hf_config.architectures, self.server_args.is_embedding
        )
-
-
-
-        else:
-            self.context_len = get_context_length(self.hf_config)
+        self.context_len = server_args.context_length or get_context_length(
+            self.hf_config
+        )
 
         # Create tokenizer
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(self.
+            if is_multimodal_model(self.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -118,6 +116,9 @@ class TokenizerManager:
                 )
                 self.tokenizer = self.processor.tokenizer
                 os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+                # We want to parallelize the image pre-processing so we
+                # create an executor for it
                 self.executor = concurrent.futures.ProcessPoolExecutor(
                     initializer=init_global_processor,
                     mp_context=mp.get_context("fork"),
@@ -134,12 +135,14 @@ class TokenizerManager:
         self.to_create_loop = True
         self.rid_to_state: Dict[str, ReqState] = {}
 
-        #
+        # For update model weights
         self.model_update_lock = asyncio.Lock()
         self.model_update_result = None
 
     async def generate_request(
-        self,
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
     ):
         if self.to_create_loop:
             self.create_handle_loop()
@@ -160,7 +163,7 @@ class TokenizerManager:
     async def _handle_single_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
-        request,
+        request: Optional[fastapi.Request] = None,
         index: Optional[int] = None,
         is_cache_for_prefill: Optional[bool] = False,
     ):
@@ -182,8 +185,8 @@ class TokenizerManager:
            )
 
            if self.is_generation:
-                pixel_values,
-                    obj.image_data
+                pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
+                    obj.image_data if not_use_index else obj.image_data[index]
                )
                return_logprob = (
                    obj.return_logprob if not_use_index else obj.return_logprob[index]
@@ -195,7 +198,6 @@ class TokenizerManager:
                )
                if return_logprob and logprob_start_len == -1:
                    logprob_start_len = len(input_ids) - 1
-
                top_logprobs_num = (
                    obj.top_logprobs_num
                    if not_use_index
@@ -238,13 +240,14 @@ class TokenizerManager:
 
            sampling_params = SamplingParams(**obj.sampling_params[0])
            sampling_params.max_new_tokens = 0
-            pixel_values,
+            pixel_values, image_hashes, image_sizes = await self._get_pixel_values(
                obj.image_data[0]
            )
            return_logprob = obj.return_logprob[0]
            logprob_start_len = obj.logprob_start_len[0]
            top_logprobs_num = obj.top_logprobs_num[0]
 
+        # Send to the controller
        if self.is_generation:
            if return_logprob and logprob_start_len == -1:
                logprob_start_len = len(input_ids) - 1
@@ -253,8 +256,8 @@ class TokenizerManager:
                input_text,
                input_ids,
                pixel_values,
-
-
+                image_hashes,
+                image_sizes,
                sampling_params,
                return_logprob,
                logprob_start_len,
@@ -268,24 +271,24 @@ class TokenizerManager:
                input_ids,
                sampling_params,
            )
-
        self.send_to_router.send_pyobj(tokenized_obj)
 
+        # Recv results
        event = asyncio.Event()
        state = ReqState([], False, event)
        self.rid_to_state[rid] = state
        if not is_cache_for_prefill:
-            async for response in self._wait_for_response(
-                event, state, obj, rid, request
-            ):
+            async for response in self._wait_for_response(state, obj, rid, request):
                yield response
        else:
            assert self.is_generation
-            await self._wait_for_cache_prefill_response(
+            await self._wait_for_cache_prefill_response(state, obj, rid, request)
            yield input_ids
 
    async def _handle_batch_request(
-        self,
+        self,
+        obj: Union[GenerateReqInput, EmbeddingReqInput],
+        request: Optional[fastapi.Request] = None,
    ):
        batch_size = obj.batch_size
        if self.is_generation:
@@ -340,8 +343,8 @@ class TokenizerManager:
            if self.is_generation:
                if obj.return_logprob[index] and obj.logprob_start_len[index] == -1:
                    obj.logprob_start_len[index] = len(input_ids) - 1
-                pixel_values,
-                    obj.image_data[index]
+                pixel_values, image_hashes, image_sizes = (
+                    await self._get_pixel_values(obj.image_data[index])
                )
 
                tokenized_obj = TokenizedGenerateReqInput(
@@ -349,8 +352,8 @@ class TokenizerManager:
                    input_text,
                    input_ids,
                    pixel_values,
-
-
+                    image_hashes,
+                    image_sizes,
                    sampling_params,
                    obj.return_logprob[index],
                    obj.logprob_start_len[index],
@@ -372,7 +375,6 @@ class TokenizerManager:
 
            generators.append(
                self._wait_for_response(
-                    event,
                    state,
                    obj,
                    rid,
@@ -388,6 +390,7 @@ class TokenizerManager:
        tasks = [asyncio.create_task(gen.__anext__()) for gen in generators]
        output_list = [None] * len(tasks)
 
+        # Recv results
        while tasks:
            done, _ = await asyncio.wait(tasks, return_when=asyncio.FIRST_COMPLETED)
 
@@ -426,25 +429,18 @@ class TokenizerManager:
        sampling_params.verify()
        return sampling_params
 
-    async def _get_pixel_values(self, image_data):
-        if image_data is None:
-            return None, None, None
-        else:
-            return await self._get_pixel_values_internal(image_data)
-
    async def _wait_for_response(
        self,
-        event: asyncio.Event,
        state: ReqState,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
        rid: str,
-        request,
-        index: int = None,
+        request: Optional[fastapi.Request] = None,
+        index: Optional[int] = None,
        response_index: int = 0,
    ):
        while True:
            try:
-                await asyncio.wait_for(event.wait(), timeout=4)
+                await asyncio.wait_for(state.event.wait(), timeout=4)
            except asyncio.TimeoutError:
                if request is not None and await request.is_disconnected():
                    for rid in [obj.rid] if obj.is_single else obj.rid:
@@ -478,16 +474,15 @@ class TokenizerManager:
                    yield out
                    break
 
-            event.clear()
+            state.event.clear()
            yield out
 
    async def _wait_for_cache_prefill_response(
        self,
-        event: asyncio.Event,
        state: ReqState,
        obj: GenerateReqInput,
        rid: str,
-        request,
+        request: Optional[fastapi.Request] = None,
    ):
        while True:
            try:
@@ -514,7 +509,9 @@ class TokenizerManager:
        req = AbortReq(rid)
        self.send_to_router.send_pyobj(req)
 
-    async def update_weights(
+    async def update_weights(
+        self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
+    ):
        if self.to_create_loop:
            self.create_handle_loop()
 
@@ -659,12 +656,11 @@ class TokenizerManager:
            )
        return top_logprobs
 
-    async def
-
-
-
-
-    )
+    async def _get_pixel_values(self, image_data: List[Union[str, bytes]]):
+        if not image_data:
+            return None, None, None
+
+        aspect_ratio = getattr(self.hf_config, "image_aspect_ratio", None)
        grid_pinpoints = (
            self.hf_config.image_grid_pinpoints
            if hasattr(self.hf_config, "image_grid_pinpoints")
@@ -673,35 +669,42 @@ class TokenizerManager:
        )
 
        if isinstance(image_data, list) and len(image_data) > 0:
-
+            # Multiple images
            if len(image_data) > 1:
                aspect_ratio = "pad"  # LLaVA OneVision Handling: more than one image --> interleaved image mode or video mode. We do not use anyres
+                pixel_values, image_hashes, image_sizes = [], [], []
                for img_data in image_data:
                    pixel_v, image_h, image_s = await self._process_single_image(
                        img_data, aspect_ratio, grid_pinpoints
                    )
                    pixel_values.append(pixel_v)
-
-
-
+                    image_hashes.append(image_h)
+                    image_sizes.append(image_s)
+
+                if isinstance(pixel_values[0], np.ndarray):
+                    pixel_values = np.stack(pixel_values, axis=0)
            else:
+                # A single image
                pixel_values, image_hash, image_size = await self._process_single_image(
                    image_data[0], aspect_ratio, grid_pinpoints
                )
-
-
+                image_hashes = [image_hash]
+                image_sizes = [image_size]
        elif isinstance(image_data, str):
+            # A single image
            pixel_values, image_hash, image_size = await self._process_single_image(
                image_data, aspect_ratio, grid_pinpoints
            )
-
-
+            image_hashes = [image_hash]
+            image_sizes = [image_size]
        else:
-
+            raise ValueError(f"Invalid image data: {image_data}")
 
-        return pixel_values,
+        return pixel_values, image_hashes, image_sizes
 
-    async def _process_single_image(
+    async def _process_single_image(
+        self, image_data: Union[bytes, str], aspect_ratio: str, grid_pinpoints: str
+    ):
        if self.executor is not None:
            loop = asyncio.get_event_loop()
            return await loop.run_in_executor(
@@ -732,12 +735,16 @@ def init_global_processor(server_args: ServerArgs):
 
 
 def _process_single_image_task(
-    image_data
+    image_data: Union[str, bytes],
+    image_aspect_ratio: Optional[str] = None,
+    image_grid_pinpoints: Optional[str] = None,
+    processor=None,
 ):
    try:
        processor = processor or global_processor
        image, image_size = load_image(image_data)
        if image_size is not None:
+            # It is a video with multiple images
            image_hash = hash(image_data)
            pixel_values = processor.image_processor(image)["pixel_values"]
            for _ in range(len(pixel_values)):
@@ -745,6 +752,7 @@ def _process_single_image_task(
            pixel_values = np.stack(pixel_values, axis=0)
            return pixel_values, image_hash, image_size
        else:
+            # It is an image
            image_hash = hash(image_data)
            if image_aspect_ratio == "pad":
                image = expand2square(
@@ -754,13 +762,18 @@ def _process_single_image_task(
                pixel_values = processor.image_processor(image.convert("RGB"))[
                    "pixel_values"
                ][0]
-            elif image_aspect_ratio == "anyres" or
+            elif image_aspect_ratio == "anyres" or (
+                image_aspect_ratio is not None and "anyres_max" in image_aspect_ratio
+            ):
                pixel_values = process_anyres_image(
                    image, processor.image_processor, image_grid_pinpoints
                )
            else:
                pixel_values = processor.image_processor(image)["pixel_values"][0]
-
+
+            if isinstance(pixel_values, np.ndarray):
+                pixel_values = pixel_values.astype(np.float16)
+
            return pixel_values, image_hash, image.size
    except Exception:
        logger.error("Exception in TokenizerManager:\n" + get_exception_traceback())
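The rewritten `_get_pixel_values` above folds the old two-method pair into one coroutine that returns a `(pixel_values, image_hashes, image_sizes)` triple, or `(None, None, None)` when no image data is supplied. A simplified, standalone sketch of that dispatch and return contract, with `process_one` standing in for the real `_process_single_image` helper (the stand-in and its shapes are illustrative only):

```python
# Simplified sketch of the new single-vs-multi image dispatch in TokenizerManager.
from typing import List, Tuple, Union

import numpy as np


def process_one(img) -> Tuple[np.ndarray, int, Tuple[int, int]]:
    # Placeholder for _process_single_image: decode the image and run the HF image processor.
    arr = np.zeros((3, 336, 336), dtype=np.float16)
    return arr, hash(img), (336, 336)


def get_pixel_values(image_data: Union[None, str, List[Union[str, bytes]]]):
    """Return (pixel_values, image_hashes, image_sizes), or (None, None, None) for no images."""
    if not image_data:
        return None, None, None
    if isinstance(image_data, list):
        pixel_values, image_hashes, image_sizes = [], [], []
        for img in image_data:
            pv, h, size = process_one(img)
            pixel_values.append(pv)
            image_hashes.append(h)
            image_sizes.append(size)
        if isinstance(pixel_values[0], np.ndarray):
            pixel_values = np.stack(pixel_values, axis=0)
        return pixel_values, image_hashes, image_sizes
    if isinstance(image_data, str):
        pv, h, size = process_one(image_data)
        return pv, [h], [size]
    raise ValueError(f"Invalid image data: {image_data}")


print(get_pixel_values(["img_a.png", "img_b.png"])[0].shape)  # (2, 3, 336, 336)
```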
sglang/srt/managers/tp_worker.py
CHANGED
@@ -108,7 +108,7 @@ class ModelTpServer:
         if server_args.skip_tokenizer_init:
             self.tokenizer = self.processor = None
         else:
-            if is_multimodal_model(
+            if is_multimodal_model(self.model_config.hf_config.architectures):
                 self.processor = get_processor(
                     server_args.tokenizer_path,
                     tokenizer_mode=server_args.tokenizer_mode,
@@ -333,26 +333,24 @@ class ModelTpServer:
        if self.model_runner.is_generation:
            req.pixel_values = recv_req.pixel_values
            if req.pixel_values is not None:
-
-
-
-                    else recv_req.image_hash
-                )
+                # Use image hash as fake token_ids, which is then used
+                # for prefix matching
+                image_hash = hash(tuple(recv_req.image_hashes))
                req.pad_value = [
                    (image_hash) % self.model_config.vocab_size,
                    (image_hash >> 16) % self.model_config.vocab_size,
                    (image_hash >> 32) % self.model_config.vocab_size,
                    (image_hash >> 64) % self.model_config.vocab_size,
                ]
-                req.
+                req.image_sizes = recv_req.image_sizes
                (
                    req.origin_input_ids,
-                    req.
+                    req.image_offsets,
                ) = self.model_runner.model.pad_input_ids(
                    req.origin_input_ids_unpadded,
                    req.pad_value,
-                    req.pixel_values
-                    req.
+                    req.pixel_values,
+                    req.image_sizes,
                )
            req.return_logprob = recv_req.return_logprob
            req.logprob_start_len = recv_req.logprob_start_len
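The hunk above replaces the old single `image_hash` field with a list of per-image hashes (`recv_req.image_hashes`) that the worker folds into one hash and then into four vocabulary-sized pad token ids, so the radix cache can do prefix matching on requests that contain the same images. A self-contained illustration of that folding (the vocabulary size and hash inputs are made up):

```python
# Sketch: fold a combined image hash into in-vocabulary "fake" token ids
# used as pad values for prefix matching.
vocab_size = 32000                              # illustrative model vocabulary size
image_hashes = [123456789, 987654321]           # one hash per input image

image_hash = hash(tuple(image_hashes))          # single hash for the whole request
pad_value = [
    (image_hash) % vocab_size,
    (image_hash >> 16) % vocab_size,
    (image_hash >> 32) % vocab_size,
    (image_hash >> 64) % vocab_size,
]
print(pad_value)  # four in-range token ids derived from the images
```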
@@ -368,6 +366,7 @@ class ModelTpServer:
                 req.jump_forward_map = self.jump_forward_cache.query(
                     computed_regex_string
                 )
+
             # Init regex fsm
             elif req.sampling_params.regex is not None:
                 req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -16,7 +16,7 @@ limitations under the License.
 """ModelRunner runs the forward passes of the models."""
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List
+from typing import TYPE_CHECKING, List
 
 import numpy as np
 import torch
@@ -58,6 +58,7 @@ class InputMetadata:
 
     # For extend
     extend_seq_lens: torch.Tensor = None
+    extend_prefix_lens: torch.Tensor = None
     extend_start_loc: torch.Tensor = None
     extend_no_prefix: bool = None
 
@@ -69,8 +70,8 @@ class InputMetadata:
 
     # For multimodal
     pixel_values: List[torch.Tensor] = None
-    image_sizes: List[List[int]] = None
-    image_offsets: List[int] = None
+    image_sizes: List[List[List[int]]] = None
+    image_offsets: List[List[int]] = None
 
     # Trition attention backend
     triton_max_seq_len: int = 0
@@ -87,20 +88,8 @@ class InputMetadata:
     def init_multimuldal_info(self, batch: ScheduleBatch):
         reqs = batch.reqs
         self.pixel_values = [r.pixel_values for r in reqs]
-        self.image_sizes = [r.
-        self.image_offsets = []
-        for r in reqs:
-            if isinstance(r.image_offset, list):
-                self.image_offsets.append(
-                    [
-                        (image_offset - len(r.prefix_indices))
-                        for image_offset in r.image_offset
-                    ]
-                )
-            elif isinstance(r.image_offset, int):
-                self.image_offsets.append(r.image_offset - len(r.prefix_indices))
-            elif r.image_offset is None:
-                self.image_offsets.append(0)
+        self.image_sizes = [r.image_sizes for r in reqs]
+        self.image_offsets = [r.image_offsets for r in reqs]
 
     def compute_positions(self, batch: ScheduleBatch):
         position_ids_offsets = batch.position_ids_offsets
@@ -153,6 +142,7 @@ class InputMetadata:
             for i, r in enumerate(batch.reqs)
         ]
         self.extend_seq_lens = torch.tensor(extend_lens_cpu, device="cuda")
+        self.extend_prefix_lens = torch.tensor(batch.prefix_lens_cpu, device="cuda")
         self.extend_start_loc = torch.zeros_like(self.seq_lens)
         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
         self.extend_no_prefix = all(l == 0 for l in batch.prefix_lens_cpu)
@@ -238,10 +228,10 @@ class InputMetadata:
         prefix_lens_cpu,
         flashinfer_use_ragged,
     ):
-        if self.forward_mode
-            prefix_lens = torch.tensor(prefix_lens_cpu, device="cuda")
-        else:
+        if self.forward_mode == ForwardMode.DECODE:
             prefix_lens = None
+        else:
+            prefix_lens = self.extend_prefix_lens
 
         update_flashinfer_indices(
             self.forward_mode,
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -50,7 +50,7 @@ from sglang.srt.mem_cache.memory_pool import (
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.model_config import AttentionArch
+from sglang.srt.model_config import AttentionArch, ModelConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -69,7 +69,7 @@ logger = logging.getLogger(__name__)
 class ModelRunner:
     def __init__(
         self,
-        model_config,
+        model_config: ModelConfig,
         mem_fraction_static: float,
         gpu_id: int,
         tp_rank: int,
@@ -85,7 +85,9 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-        self.is_multimodal_model = is_multimodal_model(
+        self.is_multimodal_model = is_multimodal_model(
+            self.model_config.hf_config.architectures
+        )
         global_server_args_dict.update(
             {
                 "disable_flashinfer": server_args.disable_flashinfer,
@@ -95,6 +97,13 @@ class ModelRunner:
             }
         )
 
+        if self.is_multimodal_model:
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
+            )
+            server_args.chunked_prefill_size = None
+            server_args.mem_fraction_static *= 0.95
+
         min_per_gpu_memory = self.init_torch_distributed()
         self.load_model()
         self.init_memory_pool(
@@ -507,9 +516,9 @@ class ModelRunner:
            raise Exception(
                f"Capture cuda graph failed: {e}\n"
                "Possible solutions:\n"
-                "1. disable
-                "2.
-                "3.
+                "1. disable cuda graph by --disable-cuda-graph\n"
+                "2. set --mem-fraction-static to a smaller value\n"
+                "3. disable torch compile by not using --enable-torch-compile\n"
                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
            )
 
sglang/srt/models/chatglm.py
CHANGED
@@ -17,7 +17,7 @@ limitations under the License.
 # Adapted from
 # https://github.com/THUDM/ChatGLM2-6B
 """Inference-only ChatGLM model compatible with THUDM weights."""
-from typing import Iterable,
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
sglang/srt/models/gemma.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -34,6 +33,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
+from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
@@ -60,7 +60,7 @@ class GemmaMLP(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        self.act_fn = GeluAndMul()
+        self.act_fn = GeluAndMul("none")
 
     def forward(self, x):
         gate_up, _ = self.gate_up_proj(x)
sglang/srt/models/gemma2.py
CHANGED
@@ -22,11 +22,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
-
-# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
 from vllm.model_executor.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -39,6 +34,7 @@ from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmb
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import GeluAndMul
+from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -50,52 +46,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
-class GemmaRMSNorm(CustomOp):
-    """RMS normalization for Gemma.
-
-    Two differences from the above RMSNorm:
-        1. x * (1 + w) instead of x * w.
-        2. (x * w).to(orig_dtype) instead of x.to(orig_dtype) * w.
-    """
-
-    def __init__(
-        self,
-        hidden_size: int,
-        eps: float = 1e-6,
-    ) -> None:
-        super().__init__()
-        self.weight = nn.Parameter(torch.zeros(hidden_size))
-        self.variance_epsilon = eps
-
-    def forward_native(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        """PyTorch-native implementation equivalent to forward()."""
-        orig_dtype = x.dtype
-        if residual is not None:
-            x = x + residual
-            residual = x
-
-        x = x.float()
-        variance = x.pow(2).mean(dim=-1, keepdim=True)
-        x = x * torch.rsqrt(variance + self.variance_epsilon)
-        # Llama does x.to(float16) * w whilst Gemma is (x * w).to(float16)
-        # See https://github.com/huggingface/transformers/pull/29402
-        x = x * (1.0 + self.weight.float())
-        x = x.to(orig_dtype)
-        return x if residual is None else (x, residual)
-
-    def forward_cuda(
-        self,
-        x: torch.Tensor,
-        residual: Optional[torch.Tensor] = None,
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
-        # from vLLM: TODO(woosuk): Implement an optimized kernel for GemmaRMSNorm.
-        return self.forward_native(x, residual)
-
-
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
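The `GemmaRMSNorm` class deleted above is not dropped from the package; gemma2.py now imports it from `sglang.srt.layers.layernorm` (matching the `+47 -1` change to that file in the summary). For reference, the math the removed `forward_native` implemented, rewritten here as a free function (a sketch mirroring the deleted code, not the actual layer interface):

```python
# Reference math of GemmaRMSNorm: normalize in float32, scale by (1 + weight),
# then cast the scaled result back to the original dtype.
import torch


def gemma_rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    orig_dtype = x.dtype
    x = x.float()
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    # Gemma uses x * (1 + w) and casts after the multiply, unlike Llama-style RMSNorm.
    return (x * (1.0 + weight.float())).to(orig_dtype)


x = torch.randn(2, 8, dtype=torch.float16)
w = torch.zeros(8)  # Gemma initializes the norm weight to zeros
print(gemma_rms_norm(x, w).shape)
```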
sglang/srt/models/grok.py
CHANGED
@@ -273,9 +273,9 @@ class Grok1Model(nn.Module):
     ) -> torch.Tensor:
         if input_embeds is None:
             hidden_states = self.embed_tokens(input_ids)
+            hidden_states.mul_(self.config.embedding_multiplier_scale)
         else:
             hidden_states = input_embeds
-        hidden_states.mul_(self.config.embedding_multiplier_scale)
 
         for i in range(len(self.layers)):
             hidden_states = self.layers[i](positions, hidden_states, input_metadata)
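The hunk above moves where `embedding_multiplier_scale` is applied: the in-place `mul_` now runs only when the model embeds `input_ids` itself, so caller-supplied `input_embeds` are no longer rescaled a second time. A small sketch of the corrected control flow (the embedding layer and scale value are illustrative stand-ins, not Grok's real configuration):

```python
# Sketch: apply the embedding multiplier only to freshly embedded input_ids.
import torch

scale = 78.38  # stands in for config.embedding_multiplier_scale
embed = torch.nn.Embedding(100, 16)  # stands in for the model's embed_tokens


def prepare_hidden_states(input_ids=None, input_embeds=None):
    if input_embeds is None:
        hidden_states = embed(input_ids)
        hidden_states.mul_(scale)      # new behavior: scale only this branch
    else:
        hidden_states = input_embeds   # caller-provided embeddings stay untouched
    return hidden_states


with torch.no_grad():
    print(prepare_hidden_states(input_ids=torch.tensor([1, 2, 3])).shape)
```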
@@ -284,7 +284,7 @@ class Grok1Model(nn.Module):
     return hidden_states
 
 
-class
+class Grok1ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
@@ -415,4 +415,10 @@ def _prepare_presharded_weights(
     return hf_folder, hf_weights_files, use_safetensors
 
 
-
+class Grok1ModelForCausalLM(Grok1ForCausalLM):
+    """An alias for backward-compatbility."""
+
+    pass
+
+
+EntryClass = [Grok1ForCausalLM, Grok1ModelForCausalLM]