sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the sglang package, as published to their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/loader.py +160 -2

@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -194,7 +202,7 @@ class DefaultModelLoader(BaseModelLoader):
     def _maybe_download_from_modelscope(
         self, model: str, revision: Optional[str]
     ) -> Optional[str]:
-        """Download model from ModelScope hub if
+        """Download model from ModelScope hub if SGLANG_USE_MODELSCOPE is True.

         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """

     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model


+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load Tensors from remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                            f_key = f"{model_name}/files/{file_name}"
+                            client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into " "parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""

@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)

+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
sglang/srt/model_loader/weight_utils.py +45 -0

@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader


+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
sglang/srt/models/deepseek_janus_pro.py +29 -86

@@ -47,10 +47,11 @@ from sglang.srt.configs.janus_pro import *
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.managers.
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1289,7 +1290,7 @@ class MlpProjector(nn.Module):
             high_x, low_x = x_or_tuple
             high_x = self.high_up_proj(high_x)
             low_x = self.low_up_proj(low_x)
-            x = torch.
+            x = torch.cat([high_x, low_x], dim=-1)
         else:
             x = x_or_tuple

@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)

-    def
-
-
-
-
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
         )
-
-
-
-
-
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+
+        return images_embeds
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.model.embed_tokens

     @torch.no_grad()
     def forward(
@@ -1978,90 +1986,25 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:

-        inputs_embeds =
-
-            forward_batch
-
-
-        )
-
-        image_inputs = forward_batch.image_inputs[0]
-
-        images_seq_mask = self.prepare_images_seq_mask(
-            input_ids=input_ids, image_inputs=image_inputs
-        )
-
-        if images_seq_mask is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-            inputs_embeds = self.prepare_inputs_embeds(
-                input_ids=input_ids,
-                pixel_values=image_inputs.pixel_values,
-                images_seq_mask=images_seq_mask,
-                images_emb_mask=image_inputs.images_emb_mask,
-            )
-            input_ids = None
-
-        if input_ids is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )

         return self.language_model(
-            input_ids=
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
             get_embedding=False,
         )

-    def prepare_inputs_embeds(
-        self,
-        input_ids: torch.LongTensor,
-        pixel_values: torch.FloatTensor,
-        images_seq_mask: torch.LongTensor,
-        images_emb_mask: torch.BoolTensor,
-        **_kwargs,
-    ):
-        """
-
-        Args:
-            input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
-            images_seq_mask (torch.BoolTensor): [b, T]
-            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
-
-            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
-
-        Returns:
-            input_embeds (torch.Tensor): [b, T, D]
-        """
-
-        bs, n = pixel_values.shape[0:2]
-        pixel_values = pixel_values.to(
-            device=self.vision_model.device, dtype=self.vision_model.dtype
-        )
-        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
-
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-        # [b, n, T2] -> [b, n x T2]
-        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
-
-        # [b, T, D]
-        # ignore the image embeddings
-        input_ids[input_ids < 0] = 0
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # replace with the image embeddings
-        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
-
-        return inputs_embeds
-
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))

-    def pad_input_ids(self, input_ids: List[int], image_inputs:
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         im_start_id = image_inputs.im_start_id
         im_end_id = image_inputs.im_end_id
         media_token_pairs = [(im_start_id, im_end_id)]
sglang/srt/models/deepseek_nextn.py +22 -10

@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops

 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
@@ -41,9 +40,15 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip

 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import awq_dequantize
+else:
+    from vllm import _custom_ops as ops


 class DeepseekModelNextN(nn.Module):
@@ -261,14 +266,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):
             # AWQ compatible
-
-
-
-
-
-
-
-
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
+            else:
+                w = ops.awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
         else:
             w = self_attn.kv_b_proj.weight
         # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.