sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/model_loader/loader.py

@@ -9,6 +9,7 @@ import json
 import logging
 import math
 import os
+import time
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
 from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
@@ -25,6 +26,12 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.connector import (
+    ConnectorType,
+    create_remote_connector,
+    get_connector_type,
+)
+from sglang.srt.connector.utils import parse_model_name
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -46,6 +53,7 @@ from sglang.srt.model_loader.weight_utils import (
     np_cache_weights_iterator,
     pt_weights_iterator,
     safetensors_weights_iterator,
+    set_runai_streamer_env,
 )
 from sglang.srt.utils import (
     get_bool_env_var,
@@ -194,7 +202,7 @@ class DefaultModelLoader(BaseModelLoader):
     def _maybe_download_from_modelscope(
         self, model: str, revision: Optional[str]
     ) -> Optional[str]:
-        """Download model from ModelScope hub if VLLM_USE_MODELSCOPE is True.
+        """Download model from ModelScope hub if SGLANG_USE_MODELSCOPE is True.
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
@@ -490,7 +498,7 @@ class ShardedStateLoader(BaseModelLoader):
     Model loader that directly loads each worker's model state dict, which
     enables a fast load path for large tensor-parallel models where each worker
     only needs to read its own shard rather than the entire checkpoint. See
-    `examples/save_sharded_state.py` for creating a sharded checkpoint.
+    `examples/runtime/engine/save_sharded_state.py` for creating a sharded checkpoint.
     """
 
     DEFAULT_PATTERN = "model-rank-{rank}-part-{part}.safetensors"
@@ -1204,6 +1212,153 @@ class GGUFModelLoader(BaseModelLoader):
         return model
 
 
+class RemoteModelLoader(BaseModelLoader):
+    """Model loader that can load tensors from a remote database."""
+
+    def __init__(self, load_config: LoadConfig):
+        super().__init__(load_config)
+        # TODO @DellCurry: move to s3 connector only
+        set_runai_streamer_env(load_config)
+
+    def _get_weights_iterator_kv(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.KV
+        rank = get_tensor_model_parallel_rank()
+        return client.weight_iterator(rank)
+
+    def _get_weights_iterator_fs(
+        self,
+        client,
+    ) -> Generator[Tuple[str, torch.Tensor], None, None]:
+        """Get an iterator for the model weights from remote storage."""
+        assert get_connector_type(client) == ConnectorType.FS
+        return client.weight_iterator()
+
+    def download_model(self, model_config: ModelConfig) -> None:
+        pass
+
+    @staticmethod
+    def save_model(
+        model: torch.nn.Module,
+        model_path: str,
+        url: str,
+    ) -> None:
+        with create_remote_connector(url) as client:
+            assert get_connector_type(client) == ConnectorType.KV
+            model_name = parse_model_name(url)
+            rank = get_tensor_model_parallel_rank()
+            state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+            for key, tensor in state_dict.items():
+                r_key = f"{model_name}/keys/rank_{rank}/{key}"
+                client.set(r_key, tensor)
+
+            for root, _, files in os.walk(model_path):
+                for file_name in files:
+                    # ignore hidden files
+                    if file_name.startswith("."):
+                        continue
+                    if os.path.splitext(file_name)[1] not in (
+                        ".bin",
+                        ".pt",
+                        ".safetensors",
+                    ):
+                        file_path = os.path.join(root, file_name)
+                        with open(file_path, encoding="utf-8") as file:
+                            file_content = file.read()
+                        f_key = f"{model_name}/files/{file_name}"
+                        client.setstr(f_key, file_content)
+
+    def _load_model_from_remote_kv(self, model: nn.Module, client):
+        for _, module in model.named_modules():
+            quant_method = getattr(module, "quant_method", None)
+            if quant_method is not None:
+                quant_method.process_weights_after_loading(module)
+        weights_iterator = self._get_weights_iterator_kv(client)
+        state_dict = ShardedStateLoader._filter_subtensors(model.state_dict())
+        for key, tensor in weights_iterator:
+            # If loading with LoRA enabled, additional padding may
+            # be added to certain parameters. We only load into a
+            # narrowed view of the parameter data.
+            param_data = state_dict[key].data
+            param_shape = state_dict[key].shape
+            for dim, size in enumerate(tensor.shape):
+                if size < param_shape[dim]:
+                    param_data = param_data.narrow(dim, 0, size)
+            if tensor.shape != param_shape:
+                logger.warning(
+                    "loading tensor of shape %s into parameter '%s' of shape %s",
+                    tensor.shape,
+                    key,
+                    param_shape,
+                )
+            param_data.copy_(tensor)
+            state_dict.pop(key)
+        if state_dict:
+            raise ValueError(f"Missing keys {tuple(state_dict)} in loaded state!")
+
+    def _load_model_from_remote_fs(
+        self, model, client, model_config: ModelConfig, device_config: DeviceConfig
+    ) -> nn.Module:
+
+        target_device = torch.device(device_config.device)
+        with set_default_torch_dtype(model_config.dtype):
+            model.load_weights(self._get_weights_iterator_fs(client))
+
+            for _, module in model.named_modules():
+                quant_method = getattr(module, "quant_method", None)
+                if quant_method is not None:
+                    # When quant methods need to process weights after loading
+                    # (for repacking, quantizing, etc), they expect parameters
+                    # to be on the global target device. This scope is for the
+                    # case where cpu offloading is used, where we will move the
+                    # parameters onto device for processing and back off after.
+                    with device_loading_context(module, target_device):
+                        quant_method.process_weights_after_loading(module)
+
+    def load_model(
+        self,
+        *,
+        model_config: ModelConfig,
+        device_config: DeviceConfig,
+    ) -> nn.Module:
+        logger.info("Loading weights from remote storage ...")
+        start = time.perf_counter()
+        load_config = self.load_config
+
+        assert load_config.load_format == LoadFormat.REMOTE, (
+            f"Model loader {self.load_config.load_format} is not supported for "
+            f"load format {load_config.load_format}"
+        )
+
+        model_weights = model_config.model_path
+        if hasattr(model_config, "model_weights"):
+            model_weights = model_config.model_weights
+
+        with set_default_torch_dtype(model_config.dtype):
+            with torch.device(device_config.device):
+                model = _initialize_model(model_config, self.load_config)
+                for _, module in model.named_modules():
+                    quant_method = getattr(module, "quant_method", None)
+                    if quant_method is not None:
+                        quant_method.process_weights_after_loading(module)
+
+            with create_remote_connector(model_weights, device_config.device) as client:
+                connector_type = get_connector_type(client)
+                if connector_type == ConnectorType.KV:
+                    self._load_model_from_remote_kv(model, client)
+                elif connector_type == ConnectorType.FS:
+                    self._load_model_from_remote_fs(
+                        model, client, model_config, device_config
+                    )
+
+        end = time.perf_counter()
+        logger.info("Loaded weights from remote storage in %.2f seconds.", end - start)
+        return model.eval()
+
+
 def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     """Get a model loader based on the load format."""
 
@@ -1225,4 +1380,7 @@ def get_model_loader(load_config: LoadConfig) -> BaseModelLoader:
     if load_config.load_format == LoadFormat.LAYERED:
         return LayeredModelLoader(load_config)
 
+    if load_config.load_format == LoadFormat.REMOTE:
+        return RemoteModelLoader(load_config)
+
     return DefaultModelLoader(load_config)
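
Taken together, the pieces above add a remote load path: `get_model_loader` now dispatches `LoadFormat.REMOTE` to `RemoteModelLoader`, which resolves the weight URL to either a KV connector (Redis) or an FS connector (S3) from the new `sglang/srt/connector` package. A minimal usage sketch via the offline engine follows; the `"remote"` format string and the URL scheme handling are assumptions inferred from `LoadFormat.REMOTE` and the connector module names in this diff, not documented behavior:

    import sglang as sgl

    # Hypothetical: point model_path at remote storage instead of a local
    # directory or HF repo; load_format="remote" routes into RemoteModelLoader.
    engine = sgl.Engine(
        model_path="s3://my-bucket/meta-llama/Llama-3.1-8B-Instruct",
        load_format="remote",
    )
    print(engine.generate("Hello,", {"max_new_tokens": 8}))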
sglang/srt/model_loader/weight_utils.py

@@ -585,6 +585,51 @@ def composed_weight_loader(
     return composed_loader
 
 
+def runai_safetensors_weights_iterator(
+    hf_weights_files: List[str],
+) -> Generator[Tuple[str, torch.Tensor], None, None]:
+    """Iterate over the weights in the model safetensor files."""
+    from runai_model_streamer import SafetensorsStreamer
+
+    enable_tqdm = (
+        not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
+    )
+
+    with SafetensorsStreamer() as streamer:
+        for st_file in tqdm(
+            hf_weights_files,
+            desc="Loading safetensors using Runai Model Streamer",
+            disable=not enable_tqdm,
+            bar_format=_BAR_FORMAT,
+        ):
+            streamer.stream_file(st_file)
+            yield from streamer.get_tensors()
+
+
+def set_runai_streamer_env(load_config: LoadConfig):
+    if load_config.model_loader_extra_config:
+        extra_config = load_config.model_loader_extra_config
+
+        if "concurrency" in extra_config and isinstance(
+            extra_config.get("concurrency"), int
+        ):
+            os.environ["RUNAI_STREAMER_CONCURRENCY"] = str(
+                extra_config.get("concurrency")
+            )
+
+        if "memory_limit" in extra_config and isinstance(
+            extra_config.get("memory_limit"), int
+        ):
+            os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] = str(
+                extra_config.get("memory_limit")
+            )
+
+    runai_streamer_s3_endpoint = os.getenv("RUNAI_STREAMER_S3_ENDPOINT")
+    aws_endpoint_url = os.getenv("AWS_ENDPOINT_URL")
+    if runai_streamer_s3_endpoint is None and aws_endpoint_url is not None:
+        os.environ["RUNAI_STREAMER_S3_ENDPOINT"] = aws_endpoint_url
+
+
 def initialize_dummy_weights(
     model: torch.nn.Module,
     low: float = -1e-3,
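
These two helpers wire the Run:ai Model Streamer into the loader: `runai_safetensors_weights_iterator` streams safetensors files, and `set_runai_streamer_env` (called from `RemoteModelLoader.__init__`) translates `model_loader_extra_config` keys into the streamer's environment variables. A small sketch of that mapping, assuming `LoadConfig` accepts these keyword arguments (its constructor lives in `sglang/srt/configs/load_config.py`, whose hunks are not shown here):

    import os

    from sglang.srt.configs.load_config import LoadConfig
    from sglang.srt.model_loader.weight_utils import set_runai_streamer_env

    # Hypothetical tuning: 16 concurrent streams, a 4 GiB staging buffer.
    load_config = LoadConfig(
        load_format="remote",
        model_loader_extra_config={"concurrency": 16, "memory_limit": 4 * 1024**3},
    )
    set_runai_streamer_env(load_config)
    assert os.environ["RUNAI_STREAMER_CONCURRENCY"] == "16"
    assert os.environ["RUNAI_STREAMER_MEMORY_LIMIT"] == str(4 * 1024**3)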
sglang/srt/models/deepseek_janus_pro.py

@@ -47,10 +47,11 @@ from sglang.srt.configs.janus_pro import *
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization import QuantizationConfig
-from sglang.srt.managers.multi_modality_padding import (
+from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
+    general_mm_embed_routine,
 )
-from sglang.srt.managers.schedule_batch import ImageInputs
+from sglang.srt.managers.schedule_batch import MultimodalInputs, global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -1289,7 +1290,7 @@ class MlpProjector(nn.Module):
             high_x, low_x = x_or_tuple
             high_x = self.high_up_proj(high_x)
             low_x = self.low_up_proj(low_x)
-            x = torch.concat([high_x, low_x], dim=-1)
+            x = torch.cat([high_x, low_x], dim=-1)
         else:
             x = x_or_tuple
 
@@ -1958,17 +1959,24 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         )
         self.logits_processor = LogitsProcessor(config)
 
-    def prepare_images_seq_mask(
-        self, input_ids: torch.Tensor, image_inputs: ImageInputs
-    ) -> Optional[torch.LongTensor]:
-        images_seq_mask = torch.isin(
-            input_ids, torch.tensor(image_inputs.pad_values, device=input_ids.device)
-        )
-        if images_seq_mask.sum() == 0:
-            # sometimes image_inputs is not empty, but input_ids contain no image token because of prefix-cache
-            return None
-        else:
-            return images_seq_mask
+    def get_image_feature(self, image_input: MultimodalInputs) -> torch.Tensor:
+        pixel_values = image_input.pixel_values
+        bs, n = pixel_values.shape[0:2]
+        pixel_values = pixel_values.to(
+            device=self.vision_model.device, dtype=self.vision_model.dtype
+        )
+        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
+
+        # [b x n, T2, D]
+        images_embeds = self.aligner(self.vision_model(images))
+
+        # [b x n, T2, D] -> [b, n x T2, D]
+        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
+
+        return images_embeds
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.language_model.model.embed_tokens
 
     @torch.no_grad()
     def forward(
@@ -1978,90 +1986,25 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
 
-        inputs_embeds = None
-        if (
-            forward_batch.image_inputs is not None
-            and len(forward_batch.image_inputs) != 0
-            and forward_batch.image_inputs[0] is not None
-        ):
-
-            image_inputs = forward_batch.image_inputs[0]
-
-            images_seq_mask = self.prepare_images_seq_mask(
-                input_ids=input_ids, image_inputs=image_inputs
-            )
-
-            if images_seq_mask is not None:
-                input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
-                inputs_embeds = self.prepare_inputs_embeds(
-                    input_ids=input_ids,
-                    pixel_values=image_inputs.pixel_values,
-                    images_seq_mask=images_seq_mask,
-                    images_emb_mask=image_inputs.images_emb_mask,
-                )
-                input_ids = None
-
-        if input_ids is not None:
-            input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+        inputs_embeds = general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            embed_tokens=self.get_input_embeddings(),
+            mm_data_embedding_func=self.get_image_feature,
+        )
 
         return self.language_model(
-            input_ids=input_ids,
+            input_ids=None,
             positions=positions,
             forward_batch=forward_batch,
             input_embeds=inputs_embeds,
             get_embedding=False,
         )
 
-    def prepare_inputs_embeds(
-        self,
-        input_ids: torch.LongTensor,
-        pixel_values: torch.FloatTensor,
-        images_seq_mask: torch.LongTensor,
-        images_emb_mask: torch.BoolTensor,
-        **_kwargs,
-    ):
-        """
-
-        Args:
-            input_ids (torch.LongTensor): [b, T]
-            pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
-            images_seq_mask (torch.BoolTensor): [b, T]
-            images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
-
-            assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
-
-        Returns:
-            input_embeds (torch.Tensor): [b, T, D]
-        """
-
-        bs, n = pixel_values.shape[0:2]
-        pixel_values = pixel_values.to(
-            device=self.vision_model.device, dtype=self.vision_model.dtype
-        )
-        images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
-
-        # [b x n, T2, D]
-        images_embeds = self.aligner(self.vision_model(images))
-
-        # [b x n, T2, D] -> [b, n x T2, D]
-        images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
-        # [b, n, T2] -> [b, n x T2]
-        images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
-
-        # [b, T, D]
-        # ignore the image embeddings
-        input_ids[input_ids < 0] = 0
-        inputs_embeds = self.language_model.model.embed_tokens(input_ids)
-
-        # replace with the image embeddings
-        inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
-
-        return inputs_embeds
-
     def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
         return self.gen_aligner(self.gen_embed(image_ids))
 
-    def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
+    def pad_input_ids(self, input_ids: List[int], image_inputs: MultimodalInputs):
         im_start_id = image_inputs.im_start_id
         im_end_id = image_inputs.im_end_id
         media_token_pairs = [(im_start_id, im_end_id)]
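
With this refactor the model no longer hand-rolls mask construction and embedding scatter: it only supplies `get_image_feature` and `get_input_embeddings`, and the shared `general_mm_embed_routine` from the new `sglang/srt/managers/mm_utils.py` stitches them together. A rough sketch of the pattern such a routine centralizes (an illustration of the idea, not the helper's actual body; it assumes the image embeddings arrive flattened to one row per padded position):

    import torch

    def mm_embed_sketch(input_ids, pad_values, embed_tokens, image_embeds):
        # Positions whose token id is one of the multimodal pad values.
        mask = torch.isin(
            input_ids, torch.tensor(pad_values, device=input_ids.device)
        )
        # Embed text tokens; clamping keeps out-of-vocab pad ids in range.
        inputs_embeds = embed_tokens(input_ids.clamp(min=0))
        # With prefix caching, the image region may already be cached and the
        # mask comes back empty (the case the old code signaled with None).
        if mask.any():
            inputs_embeds[mask] = image_embeds.to(inputs_embeds.dtype)
        return inputs_embeds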
sglang/srt/models/deepseek_nextn.py

@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm import _custom_ops as ops
 
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import ReplicatedLinear
@@ -41,9 +40,15 @@ from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_v2 import DeepseekV2DecoderLayer, DeepseekV3ForCausalLM
-from sglang.srt.utils import add_prefix, is_hip
+from sglang.srt.utils import add_prefix, is_cuda, is_hip
 
 _is_hip = is_hip()
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sgl_kernel import awq_dequantize
+else:
+    from vllm import _custom_ops as ops
 
 
 class DeepseekModelNextN(nn.Module):
@@ -261,14 +266,21 @@ class DeepseekV3ForCausalLMNextN(DeepseekV3ForCausalLM):
         self_attn = self.model.decoder.self_attn
         if hasattr(self_attn.kv_b_proj, "qweight"):
             # AWQ compatible
-            w = ops.awq_dequantize(
-                self_attn.kv_b_proj.qweight,
-                self_attn.kv_b_proj.scales,
-                self_attn.kv_b_proj.qzeros,
-                0,
-                0,
-                0,
-            ).T
+            if _is_cuda:
+                w = awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                ).T
+            else:
+                w = ops.awq_dequantize(
+                    self_attn.kv_b_proj.qweight,
+                    self_attn.kv_b_proj.scales,
+                    self_attn.kv_b_proj.qzeros,
+                    0,
+                    0,
+                    0,
+                ).T
         else:
             w = self_attn.kv_b_proj.weight
         # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
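
The split exists so CUDA builds can dequantize AWQ weights through `sgl_kernel` without importing vLLM. Condensed, the call-site difference is just the signature; treating the two kernels as otherwise interchangeable here is an inference from this diff rather than something it states:

    # sgl_kernel variant: tensor arguments only.
    w = awq_dequantize(qweight, scales, qzeros).T
    # vllm._custom_ops variant: trailing split_k_iters / thx / thy, passed as 0 here.
    w = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0).T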