sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (116)
  1. sglang/__init__.py +1 -1
  2. sglang/bench_one_batch.py +4 -0
  3. sglang/bench_serving.py +13 -0
  4. sglang/check_env.py +1 -1
  5. sglang/srt/_custom_ops.py +118 -0
  6. sglang/srt/configs/device_config.py +17 -0
  7. sglang/srt/configs/load_config.py +84 -0
  8. sglang/srt/configs/model_config.py +161 -4
  9. sglang/srt/configs/qwen2vl.py +5 -8
  10. sglang/srt/constrained/outlines_backend.py +11 -1
  11. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  12. sglang/srt/constrained/xgrammar_backend.py +5 -5
  13. sglang/srt/distributed/__init__.py +3 -0
  14. sglang/srt/distributed/communication_op.py +34 -0
  15. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  16. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  19. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  20. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  21. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  22. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  23. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  24. sglang/srt/distributed/parallel_state.py +1275 -0
  25. sglang/srt/distributed/utils.py +223 -0
  26. sglang/srt/hf_transformers_utils.py +37 -1
  27. sglang/srt/layers/attention/__init__.py +5 -2
  28. sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
  29. sglang/srt/layers/attention/flashinfer_backend.py +33 -20
  30. sglang/srt/layers/attention/torch_native_backend.py +299 -0
  31. sglang/srt/layers/attention/triton_backend.py +22 -8
  32. sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
  33. sglang/srt/layers/ep_moe/__init__.py +0 -0
  34. sglang/srt/layers/ep_moe/kernels.py +349 -0
  35. sglang/srt/layers/ep_moe/layer.py +661 -0
  36. sglang/srt/layers/fused_moe_patch.py +20 -11
  37. sglang/srt/layers/linear.py +1 -0
  38. sglang/srt/layers/logits_processor.py +17 -3
  39. sglang/srt/layers/quantization/__init__.py +36 -2
  40. sglang/srt/layers/quantization/fp8.py +559 -0
  41. sglang/srt/layers/quantization/fp8_utils.py +27 -0
  42. sglang/srt/layers/radix_attention.py +4 -2
  43. sglang/srt/layers/sampler.py +2 -0
  44. sglang/srt/layers/torchao_utils.py +23 -45
  45. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  46. sglang/srt/lora/lora.py +1 -1
  47. sglang/srt/managers/io_struct.py +48 -2
  48. sglang/srt/managers/schedule_batch.py +19 -14
  49. sglang/srt/managers/schedule_policy.py +7 -4
  50. sglang/srt/managers/scheduler.py +145 -85
  51. sglang/srt/managers/tokenizer_manager.py +166 -68
  52. sglang/srt/managers/tp_worker.py +36 -3
  53. sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
  54. sglang/srt/mem_cache/memory_pool.py +5 -1
  55. sglang/srt/model_executor/cuda_graph_runner.py +30 -7
  56. sglang/srt/model_executor/forward_batch_info.py +9 -4
  57. sglang/srt/model_executor/model_runner.py +146 -153
  58. sglang/srt/model_loader/__init__.py +34 -0
  59. sglang/srt/model_loader/loader.py +1139 -0
  60. sglang/srt/model_loader/utils.py +41 -0
  61. sglang/srt/model_loader/weight_utils.py +640 -0
  62. sglang/srt/model_parallel.py +1 -5
  63. sglang/srt/models/baichuan.py +9 -10
  64. sglang/srt/models/chatglm.py +6 -15
  65. sglang/srt/models/commandr.py +4 -5
  66. sglang/srt/models/dbrx.py +2 -3
  67. sglang/srt/models/deepseek.py +4 -11
  68. sglang/srt/models/deepseek_v2.py +90 -18
  69. sglang/srt/models/exaone.py +2 -3
  70. sglang/srt/models/gemma.py +2 -6
  71. sglang/srt/models/gemma2.py +3 -14
  72. sglang/srt/models/gemma2_reward.py +0 -1
  73. sglang/srt/models/gpt2.py +5 -12
  74. sglang/srt/models/gpt_bigcode.py +6 -22
  75. sglang/srt/models/grok.py +3 -8
  76. sglang/srt/models/internlm2.py +2 -3
  77. sglang/srt/models/internlm2_reward.py +0 -1
  78. sglang/srt/models/llama.py +96 -31
  79. sglang/srt/models/llama_classification.py +1 -2
  80. sglang/srt/models/llama_embedding.py +1 -2
  81. sglang/srt/models/llama_reward.py +2 -3
  82. sglang/srt/models/llava.py +1 -4
  83. sglang/srt/models/llavavid.py +1 -2
  84. sglang/srt/models/minicpm.py +4 -7
  85. sglang/srt/models/minicpm3.py +6 -19
  86. sglang/srt/models/mixtral.py +24 -14
  87. sglang/srt/models/mixtral_quant.py +2 -3
  88. sglang/srt/models/mllama.py +3 -7
  89. sglang/srt/models/olmo.py +2 -8
  90. sglang/srt/models/olmo2.py +0 -1
  91. sglang/srt/models/olmoe.py +3 -5
  92. sglang/srt/models/phi3_small.py +8 -13
  93. sglang/srt/models/qwen.py +2 -3
  94. sglang/srt/models/qwen2.py +10 -9
  95. sglang/srt/models/qwen2_moe.py +4 -16
  96. sglang/srt/models/qwen2_vl.py +2 -6
  97. sglang/srt/models/registry.py +99 -0
  98. sglang/srt/models/stablelm.py +2 -3
  99. sglang/srt/models/torch_native_llama.py +6 -17
  100. sglang/srt/models/xverse.py +2 -4
  101. sglang/srt/models/xverse_moe.py +4 -11
  102. sglang/srt/models/yivl.py +2 -3
  103. sglang/srt/openai_api/adapter.py +9 -5
  104. sglang/srt/openai_api/protocol.py +1 -0
  105. sglang/srt/sampling/sampling_batch_info.py +9 -8
  106. sglang/srt/server.py +270 -173
  107. sglang/srt/server_args.py +102 -29
  108. sglang/srt/utils.py +295 -28
  109. sglang/test/test_utils.py +7 -0
  110. sglang/version.py +1 -1
  111. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
  112. sglang-0.4.0.post1.dist-info/RECORD +189 -0
  113. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  114. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
  115. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
  116. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2.py CHANGED
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
 
 Qwen2Config = None
@@ -230,6 +230,7 @@ class Qwen2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -270,13 +271,17 @@ class Qwen2ForCausalLM(nn.Module):
         self,
         config: Qwen2Config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
         self.model = Qwen2Model(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
 
@@ -292,7 +297,7 @@ class Qwen2ForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -306,6 +311,7 @@ class Qwen2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
+
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -335,11 +341,6 @@ class Qwen2ForCausalLM(nn.Module):
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
-                if (
-                    self.config.tie_word_embeddings
-                    and name == "model.embed_tokens.weight"
-                ):
-                    weight_loader(params_dict["lm_head.weight"], loaded_weight)
 
 
 EntryClass = Qwen2ForCausalLM
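
The change above recurs across the model files below: the logits processor now receives the lm_head module instead of its raw weight tensor, and when tie_word_embeddings is set the input embedding is reused as the output head instead of copying its weights. A minimal, self-contained sketch of that idea, where SimpleLogitsProcessor and TinyLM are hypothetical stand-ins rather than sglang classes:

# Illustrative only: passing the lm_head *module* (not lm_head.weight) lets tied
# and untied configurations share a single code path.
import torch
import torch.nn as nn

class SimpleLogitsProcessor(nn.Module):
    def forward(self, hidden_states: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
        # Works whether lm_head is a dedicated projection or the tied embedding,
        # since both expose a .weight of shape [vocab_size, hidden_size].
        return torch.matmul(hidden_states, lm_head.weight.t())

class TinyLM(nn.Module):
    def __init__(self, vocab_size: int, hidden_size: int, tie_word_embeddings: bool):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        if tie_word_embeddings:
            # Reuse the input embedding as the output head (no extra parameters).
            self.lm_head = self.embed_tokens
        else:
            self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        self.logits_processor = SimpleLogitsProcessor()

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.embed_tokens(input_ids)
        return self.logits_processor(hidden_states, self.lm_head)

logits = TinyLM(vocab_size=32, hidden_size=8, tie_word_embeddings=True)(torch.tensor([[1, 2, 3]]))
print(logits.shape)  # torch.Size([1, 3, 32])
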
sglang/srt/models/qwen2_moe.py CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import FusedMoE
@@ -41,13 +40,12 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class Qwen2MoeMLP(nn.Module):
@@ -158,7 +156,6 @@ class Qwen2MoeAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -234,7 +231,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -250,7 +246,6 @@ class Qwen2MoeDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
 
@@ -304,7 +299,6 @@ class Qwen2MoeModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -317,9 +311,7 @@ class Qwen2MoeModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                Qwen2MoeDecoderLayer(
-                    config, layer_id, cache_config, quant_config=quant_config
-                )
+                Qwen2MoeDecoderLayer(config, layer_id, quant_config=quant_config)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -353,14 +345,12 @@ class Qwen2MoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
-        self.model = Qwen2MoeModel(config, cache_config, quant_config)
+        self.model = Qwen2MoeModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
@@ -376,7 +366,7 @@ class Qwen2MoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -452,7 +442,5 @@ class Qwen2MoeForCausalLM(nn.Module):
                 )
                 weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 EntryClass = Qwen2MoeForCausalLM
sglang/srt/models/qwen2_vl.py CHANGED
@@ -30,12 +30,10 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from einops import rearrange, repeat
-from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import parallel_state
 from vllm.distributed import utils as dist_utils
 from vllm.logger import init_logger
 from vllm.model_executor.layers.activation import QuickGELU
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
 from sglang.srt.hf_transformers_utils import get_processor
@@ -49,6 +47,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
 from sglang.srt.managers.schedule_batch import ImageInputs
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.qwen2 import Qwen2Model
 
 logger = init_logger(__name__)
@@ -536,7 +535,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
     def __init__(
         self,
         config: Qwen2VLConfig,
-        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -668,7 +666,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
 
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head.weight, forward_batch
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -686,8 +684,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
-            if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
sglang/srt/models/registry.py ADDED
@@ -0,0 +1,99 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
+
+import importlib
+import logging
+import pkgutil
+from dataclasses import dataclass, field
+from functools import lru_cache
+from typing import AbstractSet, Dict, List, Optional, Tuple, Type, Union
+
+import torch.nn as nn
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class _ModelRegistry:
+    # Keyed by model_arch
+    models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
+
+    def get_supported_archs(self) -> AbstractSet[str]:
+        return self.models.keys()
+
+    def _raise_for_unsupported(self, architectures: List[str]):
+        all_supported_archs = self.get_supported_archs()
+
+        if any(arch in all_supported_archs for arch in architectures):
+            raise ValueError(
+                f"Model architectures {architectures} failed "
+                "to be inspected. Please check the logs for more details."
+            )
+
+        raise ValueError(
+            f"Model architectures {architectures} are not supported for now. "
+            f"Supported architectures: {all_supported_archs}"
+        )
+
+    def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]:
+        if model_arch not in self.models:
+            return None
+
+        return self.models[model_arch]
+
+    def _normalize_archs(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> List[str]:
+        if isinstance(architectures, str):
+            architectures = [architectures]
+        if not architectures:
+            logger.warning("No model architectures are specified")
+
+        return architectures
+
+    def resolve_model_cls(
+        self,
+        architectures: Union[str, List[str]],
+    ) -> Tuple[Type[nn.Module], str]:
+        architectures = self._normalize_archs(architectures)
+
+        for arch in architectures:
+            model_cls = self._try_load_model_cls(arch)
+            if model_cls is not None:
+                return (model_cls, arch)
+
+        return self._raise_for_unsupported(architectures)
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            try:
+                module = importlib.import_module(name)
+            except Exception as e:
+                logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+                continue
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(
+                    entry, list
+                ):  # To support multiple model classes in one module
+                    for tmp in entry:
+                        assert (
+                            tmp.__name__ not in model_arch_name_to_cls
+                        ), f"Duplicated model implementation for {tmp.__name__}"
+                        model_arch_name_to_cls[tmp.__name__] = tmp
+                else:
+                    assert (
+                        entry.__name__ not in model_arch_name_to_cls
+                    ), f"Duplicated model implementation for {entry.__name__}"
+                    model_arch_name_to_cls[entry.__name__] = entry
+
+    return model_arch_name_to_cls
+
+
+ModelRegistry = _ModelRegistry(import_model_classes())
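
The new registry scans sglang.srt.models once, collects every module's EntryClass, and maps architecture names (as found in a model's config.json) to model classes. A quick illustration of the lookup API shown above, assuming sglang 0.4.0 and its model dependencies are installed so the import succeeds:

# Illustrative usage of the new ModelRegistry (a sketch, not part of the diff).
from sglang.srt.models.registry import ModelRegistry

model_cls, arch = ModelRegistry.resolve_model_cls(["Qwen2ForCausalLM"])
print(arch, model_cls)                                   # matched architecture and its class
print(sorted(ModelRegistry.get_supported_archs())[:5])   # a few of the registered architectures
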
sglang/srt/models/stablelm.py CHANGED
@@ -26,7 +26,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class StablelmMLP(nn.Module):
@@ -242,7 +242,6 @@ class StableLmForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -261,7 +260,7 @@ class StableLmForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/torch_native_llama.py CHANGED
@@ -52,20 +52,18 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 tp_size = get_tensor_model_parallel_world_size()
 tp_rank = get_tensor_model_parallel_rank()
@@ -388,15 +386,16 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
 
         # turning off autotune for fp8dq since it doesn't give speedup and
@@ -413,7 +412,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
        )
 
     def get_hidden_dim(self, module_name):
@@ -501,16 +500,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-        if (
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-        ):
-            # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-            param = self.lm_head.weight
-            weight_loader = getattr(param, "weight_loader", default_weight_loader)
-            weight_loader(param, self.model.embed_tokens.weight)
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
     pass
sglang/srt/models/xverse.py CHANGED
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class XverseMLP(nn.Module):
@@ -295,8 +295,6 @@ class XverseForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
-        efficient_weight_load=False,
     ) -> None:
         super().__init__()
         self.config = config
@@ -315,7 +313,7 @@ class XverseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(
sglang/srt/models/xverse_moe.py CHANGED
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class XverseMLP(nn.Module):
@@ -181,7 +181,6 @@ class XverseAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -258,7 +257,6 @@ class XverseDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -277,7 +275,6 @@ class XverseDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         if config.num_experts is not None:
@@ -326,7 +323,6 @@ class XverseModel(nn.Module):
     def __init__(
         self,
        config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -339,9 +335,7 @@ class XverseModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                XverseDecoderLayer(
-                    config, layer_id, cache_config, quant_config=quant_config
-                )
+                XverseDecoderLayer(config, layer_id, quant_config=quant_config)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -369,13 +363,12 @@ class XverseMoeForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = XverseModel(config, cache_config, quant_config)
+        self.model = XverseModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
@@ -390,7 +383,7 @@ class XverseMoeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/yivl.py CHANGED
@@ -18,9 +18,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 import torch.nn as nn
 from transformers import CLIPVisionModel, LlavaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llava import LlavaLlamaForCausalLM
 
 
@@ -29,9 +29,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
         self,
         config: LlavaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config, cache_config)
+        super().__init__(config, quant_config)
 
         self.multi_modal_projector = YiVLMultiModalProjector(self.config)
         self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
sglang/srt/openai_api/adapter.py CHANGED
@@ -486,6 +486,7 @@ def v1_generate_request(
     return_logprobs = []
     logprob_start_lens = []
     top_logprobs_nums = []
+    lora_paths = []
 
     for request in all_requests:
         # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -496,6 +497,7 @@ def v1_generate_request(
         )
 
         prompts.append(request.prompt)
+        lora_paths.append(request.lora_path)
         if request.echo and request.logprobs:
             current_logprob_start_len = 0
         else:
@@ -519,7 +521,7 @@ def v1_generate_request(
                 "skip_special_tokens": request.skip_special_tokens,
             }
         )
-        return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+        return_logprobs.append(request.logprobs is not None)
         logprob_start_lens.append(current_logprob_start_len)
         top_logprobs_nums.append(
             request.logprobs if request.logprobs is not None else 0
@@ -534,6 +536,7 @@ def v1_generate_request(
         return_logprobs = return_logprobs[0]
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
+        lora_paths = lora_paths[0]
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -549,6 +552,7 @@ def v1_generate_request(
         return_text_in_logprobs=True,
         stream=all_requests[0].stream,
         rid=request_ids,
+        lora_path=lora_paths,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -591,9 +595,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
             text = prompts[prompt_index] + text
 
         logprobs = False
-        if isinstance(request, list) and request[idx].logprobs:
+        if isinstance(request, list) and request[idx].logprobs is not None:
             logprobs = True
-        elif (not isinstance(request, list)) and request.logprobs:
+        elif (not isinstance(request, list)) and request.logprobs is not None:
             logprobs = True
         if logprobs:
             if echo:
@@ -735,7 +739,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                     # Prepend prompt in response text.
                     text = prompts + text
 
                if request.logprobs is not None:
                    # The first chunk and echo is enabled.
                    if not stream_buffer and request.echo:
                        input_token_logprobs = content["meta_info"][
@@ -1275,7 +1279,7 @@ def v1_embedding_request(all_requests, tokenizer_manager):
     for request in all_requests:
         prompt = request.input
         assert (
-            type(prompt) == first_prompt_type
+            type(prompt) is first_prompt_type
         ), "All prompts must be of the same type in file input settings"
         prompts.append(prompt)
 
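
Note that the logprobs checks above switch from truthiness to an explicit `is not None`, because `logprobs=0` is a legal request value that the old check silently treated as "disabled". A tiny illustration of the difference:

# Why "is not None" matters: logprobs=0 is falsy but still a valid setting.
logprobs = 0
print(bool(logprobs))        # False -> the old truthiness check would skip logprobs
print(logprobs is not None)  # True  -> the new check correctly enables them
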
sglang/srt/openai_api/protocol.py CHANGED
@@ -166,6 +166,7 @@ class CompletionRequest(BaseModel):
     temperature: float = 1.0
     top_p: float = 1.0
     user: Optional[str] = None
+    lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
 
     # Extra parameters for SRT backend only and will be ignored by OpenAI models.
     json_schema: Optional[str] = None
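
Together with the adapter changes above, this field lets a completions request select a LoRA adapter per prompt (a single value or one entry per prompt). A hedged sketch of such a request, where the server address, port, and adapter name "my-adapter" are placeholders and assume the server was started with matching LoRA adapters loaded:

# Illustrative request against an sglang OpenAI-compatible endpoint (names are placeholders).
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "lora_path": "my-adapter",  # or a list with one adapter name per prompt
    },
)
print(resp.json())
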
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -158,22 +158,23 @@ class SamplingBatchInfo:
             return
 
         # find a grammar from the list
-        grammar = next(grammar for grammar in self.grammars if grammar)
+        first_grammar = next(grammar for grammar in self.grammars if grammar)
 
         # maybe we can reuse the existing mask?
-        self.vocab_mask = grammar.allocate_vocab_mask(
+        self.vocab_mask = first_grammar.allocate_vocab_mask(
             vocab_size=self.vocab_size,
             batch_size=len(self.temperatures),
             device=self.device,
         )
-        self.apply_mask = type(grammar).apply_vocab_mask  # force to use static method
+        self.apply_mask = first_grammar.apply_vocab_mask  # force to use static method
 
+        # Apply the mask
         for i, grammar in enumerate(self.grammars):
-            if grammar is not None:
-                try:
-                    grammar.fill_vocab_mask(self.vocab_mask, i)
-                except RuntimeError:
-                    continue
+            if grammar and not grammar.finished:
+                grammar.fill_vocab_mask(self.vocab_mask, i)
+
+        # Move the mask to the device if needed
+        self.vocab_mask = first_grammar.move_vocab_mask(self.vocab_mask, self.device)
 
     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
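
The hunk above restructures the grammar vocab-mask flow: allocate once for the batch, fill per unfinished request, then move the mask to the target device before applying it to logits. A minimal sketch of that flow with a hypothetical DummyGrammar standing in for the real outlines/xgrammar backends, which expose the same allocate/fill/move/apply surface used here:

# Illustrative sketch only; DummyGrammar is not an sglang class.
import torch

class DummyGrammar:
    finished = False

    def allocate_vocab_mask(self, vocab_size, batch_size, device):
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)

    def fill_vocab_mask(self, vocab_mask, idx):
        vocab_mask[idx, 1::2] = True  # pretend odd token ids are currently illegal

    def move_vocab_mask(self, vocab_mask, device):
        return vocab_mask.to(device, non_blocking=True)

    @staticmethod
    def apply_vocab_mask(logits, vocab_mask):
        logits.masked_fill_(vocab_mask, float("-inf"))

grammars = [DummyGrammar(), None]  # one constrained request, one unconstrained
first_grammar = next(g for g in grammars if g)
mask = first_grammar.allocate_vocab_mask(vocab_size=8, batch_size=2, device="cpu")
for i, g in enumerate(grammars):
    if g and not g.finished:
        g.fill_vocab_mask(mask, i)
mask = first_grammar.move_vocab_mask(mask, "cpu")

logits = torch.zeros(2, 8)
first_grammar.apply_vocab_mask(logits, mask)
print(logits)  # row 0 has -inf at masked positions; row 1 is untouched
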