sglang 0.3.6.post1__py3-none-any.whl → 0.3.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +4 -8
- sglang/bench_one_batch_server.py +6 -5
- sglang/check_env.py +7 -1
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +2 -4
- sglang/srt/configs/model_config.py +2 -6
- sglang/srt/layers/attention/flashinfer_backend.py +3 -3
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -6
- sglang/srt/managers/image_processor.py +7 -10
- sglang/srt/managers/io_struct.py +0 -10
- sglang/srt/managers/schedule_batch.py +51 -13
- sglang/srt/managers/scheduler.py +41 -29
- sglang/srt/managers/session_controller.py +15 -7
- sglang/srt/managers/tokenizer_manager.py +4 -33
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -2
- sglang/srt/models/grok.py +11 -48
- sglang/srt/models/llava.py +16 -9
- sglang/srt/models/olmo2.py +392 -0
- sglang/srt/models/qwen2_vl.py +10 -3
- sglang/srt/openai_api/adapter.py +1 -1
- sglang/srt/server.py +48 -45
- sglang/srt/server_args.py +1 -1
- sglang/srt/utils.py +22 -24
- sglang/test/test_utils.py +21 -8
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/METADATA +4 -2
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/RECORD +34 -36
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post1.dist-info → sglang-0.3.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py
CHANGED
@@ -15,6 +15,7 @@
 
 import logging
 import os
+import signal
 import threading
 import time
 import warnings
@@ -23,6 +24,7 @@ from concurrent import futures
 from types import SimpleNamespace
 from typing import List, Optional
 
+import psutil
 import torch
 import zmq
 
@@ -36,8 +38,6 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     CloseSessionReqInput,
     FlushCacheReq,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -71,9 +71,9 @@ from sglang.srt.utils import (
     broadcast_pyobj,
     configure_logger,
     crash_on_warnings,
+    get_bool_env_var,
     get_zmq_socket,
-
-    kill_parent_process,
+    set_gpu_proc_affinity,
     set_random_seed,
     suppress_other_loggers,
 )
@@ -82,7 +82,7 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 # Test retract decode
-test_retract =
+test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")
 
 
 class Scheduler:
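The `test_retract` flag is now read through the new `get_bool_env_var` import instead of a hand-rolled check. The helper's body is not part of this diff; a minimal sketch of what such a boolean environment-variable reader could look like (implementation and accepted values are assumptions, not the packaged code):

import os


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Treat "true"/"1" (case-insensitive) as enabled, anything else as disabled.
    return os.getenv(name, default).lower() in ("true", "1")


# Example: the scheduler's retract-decode test path is gated like this.
test_retract = get_bool_env_var("SGLANG_TEST_RETRACT")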
@@ -169,6 +169,10 @@ class Scheduler:
             self.enable_overlap = False
             logger.info("Overlap scheduler is disabled for embedding models.")
 
+        if self.model_config.is_multimodal:
+            self.enable_overlap = False
+            logger.info("Overlap scheduler is disabled for multimodal models.")
+
         if self.enable_overlap:
             self.disable_jump_forward = True
 
@@ -311,6 +315,7 @@ class Scheduler:
         self.watchdog_timeout = server_args.watchdog_timeout
         t = threading.Thread(target=self.watchdog_thread, daemon=True)
         t.start()
+        self.parent_process = psutil.Process().parent()
 
         # Init profiler
         if os.getenv("SGLANG_TORCH_PROFILER_DIR", "") == "":
@@ -354,7 +359,7 @@ class Scheduler:
             self.watchdog_last_time = time.time()
             time.sleep(self.watchdog_timeout / 2)
 
-
+        self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
     def event_loop_normal(self):
@@ -514,10 +519,6 @@ class Scheduler:
                 self.send_to_tokenizer.send_pyobj(OpenSessionReqOutput(session_id))
             elif isinstance(recv_req, CloseSessionReqInput):
                 self.close_session(recv_req)
-            elif isinstance(recv_req, GetMemPoolSizeReq):
-                self.send_to_tokenizer.send_pyobj(
-                    GetMemPoolSizeReqOutput(self.max_total_num_tokens)
-                )
             else:
                 raise ValueError(f"Invalid request: {recv_req}")
 
@@ -525,8 +526,9 @@
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
+        # Create a new request
         if recv_req.session_id is None or recv_req.session_id not in self.sessions:
-
+
             if recv_req.input_embeds is not None:
                 # Generate fake input_ids based on the length of input_embeds
                 seq_length = len(recv_req.input_embeds)
@@ -557,24 +559,30 @@ class Scheduler:
             self.waiting_queue.append(req)
             return
 
-        #
+        # Handle image inputs
         if recv_req.image_inputs is not None:
-
-
-            )
+            image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
+            # Expand a single image token into multiple dummy tokens for receiving image embeddings
             req.origin_input_ids = self.pad_input_ids_func(
-                req.
+                req.origin_input_ids, image_inputs
             )
+            req.extend_image_inputs(image_inputs)
 
-            if len(req.origin_input_ids)
-
-                "
-                "
+            if len(req.origin_input_ids) >= self.max_req_input_len:
+                logger.error(
+                    "Multimodal prompt is too long after expanding multimodal tokens. "
+                    f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}. "
                 )
+                req.origin_input_ids = [0]
+                req.image_inputs = None
                 req.sampling_params.max_new_tokens = 0
+                req.finished_reason = FINISH_ABORT(
+                    "Multimodal prompt is too long. Check server logs for details."
+                )
                 self.waiting_queue.append(req)
                 return
 
+        # Copy more attributes
         req.return_logprob = recv_req.return_logprob
         req.top_logprobs_num = recv_req.top_logprobs_num
         req.stream = recv_req.stream
@@ -1342,13 +1350,15 @@ class Scheduler:
 
         if to_del is not None:
             del self.waiting_queue[to_del]
+            logger.debug(f"Abort queued request. {req.rid=}")
+            return
 
         # Delete requests in the running batch
         if self.running_batch:
             for req in self.running_batch.reqs:
                 if req.rid == recv_req.rid and not req.finished():
-                    req.
-
+                    logger.debug(f"Abort running request. {req.rid=}")
+                    req.to_abort = True
                     break
 
     def update_weights(self, recv_req: UpdateWeightReqInput):
@@ -1404,11 +1414,12 @@ def run_scheduler_process(
     pipe_writer,
 ):
     # set cpu affinity to this gpu process
-
+    if get_bool_env_var("SGLANG_SET_CPU_AFFINITY"):
+        set_gpu_proc_affinity(server_args.tp_size, server_args.nnodes, gpu_id)
 
-    # [For Router] if env var "
-    if dp_rank is None and "
-    dp_rank = int(os.environ["
+    # [For Router] if env var "SGLANG_DP_RANK" exist, set dp_rank to the value of the env var
+    if dp_rank is None and "SGLANG_DP_RANK" in os.environ:
+        dp_rank = int(os.environ["SGLANG_DP_RANK"])
 
     if dp_rank is None:
         configure_logger(server_args, prefix=f" TP{tp_rank}")
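The scheduler process now picks up its data-parallel rank from the `SGLANG_DP_RANK` environment variable when `dp_rank` is not passed explicitly. A hedged illustration of how a router-style launcher could set it per worker (the command and arguments are placeholders, not the actual router code):

import os
import subprocess

# Illustrative only: spawn one server per data-parallel rank and let the
# scheduler read the rank from SGLANG_DP_RANK instead of a CLI argument.
for rank in range(2):
    env = {**os.environ, "SGLANG_DP_RANK": str(rank)}
    subprocess.Popen(
        ["python", "-m", "sglang.launch_server"],  # model/port flags omitted here
        env=env,
    )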
@@ -1416,6 +1427,7 @@ def run_scheduler_process(
         configure_logger(server_args, prefix=f" DP{dp_rank} TP{tp_rank}")
 
     suppress_other_loggers()
+    parent_process = psutil.Process().parent()
 
     try:
         scheduler = Scheduler(server_args, port_args, gpu_id, tp_rank, dp_rank)
@@ -1427,6 +1439,6 @@ def run_scheduler_process(
         else:
             scheduler.event_loop_normal()
     except Exception:
-
-        logger.error(
-
+        traceback = get_exception_traceback()
+        logger.error(f"Scheduler hit an exception: {traceback}")
+        parent_process.send_signal(signal.SIGQUIT)
sglang/srt/managers/session_controller.py
CHANGED
@@ -10,10 +10,7 @@
 # limitations under the License.
 # ==============================================================================
 
-import copy
 import uuid
-from dataclasses import dataclass
-from typing import Optional
 
 from sglang.srt.managers.io_struct import TokenizedGenerateReqInput
 from sglang.srt.managers.schedule_batch import FINISH_ABORT, List, Req
@@ -41,16 +38,27 @@ class Session:
                 ]
                 + req.input_ids
             )
+            input_ids_unpadded = (
+                self.reqs[-1].origin_input_ids_unpadded
+                + self.reqs[-1].output_ids[
+                    : self.reqs[-1].sampling_params.max_new_tokens
+                ]
+                + req.input_ids
+            )
         else:
             input_ids = req.input_ids
+            input_ids_unpadded = req.input_ids
         new_req = Req(
-            req.rid,
-            None,
-            input_ids,
-
+            rid=req.rid,
+            origin_input_text=None,
+            origin_input_ids=input_ids,
+            origin_input_ids_unpadded=input_ids_unpadded,
+            sampling_params=req.sampling_params,
             lora_path=req.lora_path,
             session_id=self.session_id,
         )
+        if len(self.reqs) > 0:
+            new_req.image_inputs = self.reqs[-1].image_inputs
         new_req.tokenizer = tokenizer
         if req.session_rid is not None and len(self.reqs) == 0:
             new_req.finished_reason = FINISH_ABORT(
sglang/srt/managers/tokenizer_manager.py
CHANGED
@@ -45,8 +45,6 @@ from sglang.srt.managers.io_struct import (
     EmbeddingReqInput,
     FlushCacheReq,
     GenerateReqInput,
-    GetMemPoolSizeReq,
-    GetMemPoolSizeReqOutput,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -58,7 +56,7 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_zmq_socket,
+from sglang.srt.utils import get_zmq_socket, kill_process_tree
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -218,7 +216,8 @@ class TokenizerManager:
             input_ids = obj.input_ids
 
         if self.is_generation:
-
+            # TODO: also support getting embeddings for multimodal models
+            image_inputs: Dict = await self.image_processor.process_images_async(
                 obj.image_data, input_text or input_ids, obj
             )
             if image_inputs and "input_ids" in image_inputs:
@@ -406,25 +405,6 @@ class TokenizerManager:
         req = ProfileReq.STOP_PROFILE
         self.send_to_scheduler.send_pyobj(req)
 
-    async def get_memory_pool_size(self):
-        if self.to_create_loop:
-            self.create_handle_loop()
-
-        req = GetMemPoolSizeReq()
-
-        self.send_to_scheduler.send_pyobj(req)
-        self.mem_pool_size = asyncio.Future()
-
-        # FIXME: Each request should have its own future instead of using `self.mem_pool_size`.
-        if self.server_args.dp_size == 1:
-            res = await self.mem_pool_size
-            return res.size
-        else:  # self.server_args.dp_size > 1
-            self.mem_pool_size_tmp = []
-            res = await self.mem_pool_size
-            ret = [r.size for r in res]
-            return ret
-
     async def update_weights(
         self, obj: UpdateWeightReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -532,7 +512,7 @@ class TokenizerManager:
             else:
                 break
 
-
+        kill_process_tree(os.getpid(), include_parent=True)
         sys.exit(0)
 
     async def handle_loop(self):
@@ -552,15 +532,6 @@ class TokenizerManager:
                     if len(self.model_update_tmp) == self.server_args.dp_size:
                         self.model_update_result.set_result(self.model_update_tmp)
                 continue
-            elif isinstance(recv_obj, GetMemPoolSizeReqOutput):
-                if self.server_args.dp_size == 1:
-                    self.mem_pool_size.set_result(recv_obj)
-                else:  # self.sever_args.dp_size > 1
-                    self.mem_pool_size_tmp.append(recv_obj)
-                    # set future if the all results are received
-                    if len(self.mem_pool_size_tmp) == self.server_args.dp_size:
-                        self.mem_pool_size.set_result(self.mem_pool_size_tmp)
-                continue
             elif isinstance(recv_obj, OpenSessionReqOutput):
                 self.session_futures[recv_obj.session_id].set_result(
                     recv_obj.session_id
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -15,16 +15,19 @@
 
 import dataclasses
 import logging
+import signal
 import threading
 from queue import Queue
 from typing import Optional
 
+import psutil
 import torch
 
 from sglang.srt.managers.io_struct import UpdateWeightReqInput
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.server_args import ServerArgs
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
@@ -70,6 +73,7 @@ class TpModelWorkerClient:
             target=self.forward_thread_func,
         )
         self.forward_thread.start()
+        self.parent_process = psutil.Process().parent()
 
     def get_worker_info(self):
         return self.worker.get_worker_info()
@@ -87,8 +91,13 @@ class TpModelWorkerClient:
         )
 
     def forward_thread_func(self):
-
-        self.
+        try:
+            with torch.cuda.stream(self.forward_stream):
+                self.forward_thread_func_()
+        except Exception:
+            traceback = get_exception_traceback()
+            logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
+            self.parent_process.send_signal(signal.SIGQUIT)
 
     @torch.no_grad()
     def forward_thread_func_(self):
sglang/srt/models/grok.py
CHANGED
@@ -16,22 +16,17 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
 
-import
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
+from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -41,10 +36,12 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 
 
@@ -293,17 +290,11 @@ class Grok1ForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
+        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
-        # Monkey patch _prepare_weights to load pre-sharded weights
-        setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
-
-        self.use_presharded_weights = True
-
-        warnings.filterwarnings("ignore", category=FutureWarning)
-
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -357,28 +348,23 @@ class Grok1ForCausalLM(nn.Module):
                     continue
                 name = name.replace(weight_name, param_name)
 
-                if self.use_presharded_weights:
-                    extra_kwargs = {
-                        "use_presharded_weights": self.use_presharded_weights
-                    }
-                else:
-                    extra_kwargs = {}
-
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(
                     param,
                     loaded_weight,
-
+                    name,
                     shard_id=shard_id,
                     expert_id=expert_id,
-                    **extra_kwargs,
                 )
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
                 if name is None:
                     continue
 
@@ -388,30 +374,7 @@ class Grok1ForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-
-old_prepare_weights = getattr(DefaultModelLoader, "_prepare_weights")
-
-
-def _prepare_presharded_weights(
-    self, model_name_or_path: str, revision: Optional[str], fall_back_to_pt: bool
-) -> Tuple[str, List[str], bool]:
-    import glob
-    import os
-
-    if get_tensor_model_parallel_world_size() == 1:
-        return old_prepare_weights(self, model_name_or_path, revision, fall_back_to_pt)
-
-    tp_rank = get_tensor_model_parallel_rank()
-    allow_patterns = [f"*-{tp_rank:03d}.bin"]
-
-    hf_folder = model_name_or_path
-
-    hf_weights_files: List[str] = []
-    for pattern in allow_patterns:
-        hf_weights_files += glob.glob(os.path.join(hf_folder, pattern))
-    use_safetensors = False
-
-    return hf_folder, hf_weights_files, use_safetensors
+        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
 
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
sglang/srt/models/llava.py
CHANGED
@@ -49,9 +49,15 @@ class LlavaBaseForCausalLM(nn.Module):
         image_sizes, pad_values = image_inputs.image_sizes, image_inputs.pad_values
 
         # hardcode for spatial_unpad + anyres
-
+        if image_inputs.modalities is not None and (
+            "multi-images" in image_inputs.modalities
+            or "video" in image_inputs.modalities
+        ):
+            image_aspect_ratio = "pad"
+        else:
+            image_aspect_ratio = "anyres"
         offset_list = []
-        for image_s in image_sizes:
+        for image_idx, image_s in enumerate(image_sizes):
             if len(image_sizes) > 16:
                 # 2x2 pooling with stride 2
                 new_image_feature_len = (
@@ -86,10 +92,6 @@ class LlavaBaseForCausalLM(nn.Module):
                 new_w = int(new_w // times)
                 new_image_feature_len += new_h * (new_w + 1)
 
-            pad_ids = pad_values * (
-                (new_image_feature_len + len(pad_values)) // len(pad_values)
-            )
-            # print("calculated new_image_feature_len: ", new_image_feature_len)
             try:
                 offset = input_ids.index(self.config.image_token_index)
             except ValueError:
@@ -97,7 +99,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # old_len + pad_len - 1, because we need to remove image_token_id
             input_ids = (
                 input_ids[:offset]
-                +
+                + [pad_values[image_idx]] * new_image_feature_len
                 + input_ids[offset + 1 :]
             )
             offset_list.append(offset)
@@ -132,7 +134,6 @@ class LlavaBaseForCausalLM(nn.Module):
         image_inputs = forward_batch.image_inputs
 
         if forward_batch.forward_mode.is_extend():
-            bs = forward_batch.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -140,11 +141,16 @@ class LlavaBaseForCausalLM(nn.Module):
             for im in image_inputs:
                 if im and im.modalities is not None:
                     modalities_list.extend(im.modalities)
-                if im and im.image_offsets
+                if im and im.image_offsets:
                     max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)
 
+            # Clamp input ids. This is because the input_ids for the image tokens are
+            # filled with the hash values of the image for the prefix matching in the radix attention.
+            # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
 
@@ -152,6 +158,7 @@ class LlavaBaseForCausalLM(nn.Module):
             need_vision = start_positions <= np.array(max_image_offset)
 
             if need_vision.any():
+                bs = forward_batch.batch_size
                 pixel_values = [
                     image_inputs[i].pixel_values for i in range(bs) if need_vision[i]
                 ]