sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. sglang/bench_offline_throughput.py +55 -2
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +4 -3
  4. sglang/bench_serving.py +13 -0
  5. sglang/check_env.py +1 -1
  6. sglang/launch_server.py +3 -2
  7. sglang/srt/_custom_ops.py +118 -0
  8. sglang/srt/configs/device_config.py +17 -0
  9. sglang/srt/configs/load_config.py +84 -0
  10. sglang/srt/configs/model_config.py +161 -4
  11. sglang/srt/configs/qwen2vl.py +5 -8
  12. sglang/srt/constrained/outlines_backend.py +6 -1
  13. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  14. sglang/srt/distributed/__init__.py +3 -0
  15. sglang/srt/distributed/communication_op.py +34 -0
  16. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  17. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  18. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  19. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  20. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  21. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  22. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  24. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  25. sglang/srt/distributed/parallel_state.py +1275 -0
  26. sglang/srt/distributed/utils.py +223 -0
  27. sglang/srt/hf_transformers_utils.py +37 -1
  28. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  29. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  30. sglang/srt/layers/fused_moe_patch.py +20 -11
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/logits_processor.py +17 -3
  33. sglang/srt/layers/quantization/__init__.py +34 -0
  34. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  35. sglang/srt/lora/lora.py +1 -1
  36. sglang/srt/managers/data_parallel_controller.py +7 -11
  37. sglang/srt/managers/detokenizer_manager.py +7 -4
  38. sglang/srt/managers/image_processor.py +1 -1
  39. sglang/srt/managers/io_struct.py +48 -12
  40. sglang/srt/managers/schedule_batch.py +42 -36
  41. sglang/srt/managers/schedule_policy.py +7 -4
  42. sglang/srt/managers/scheduler.py +111 -46
  43. sglang/srt/managers/session_controller.py +0 -3
  44. sglang/srt/managers/tokenizer_manager.py +169 -100
  45. sglang/srt/managers/tp_worker.py +36 -3
  46. sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
  47. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  48. sglang/srt/model_executor/forward_batch_info.py +9 -4
  49. sglang/srt/model_executor/model_runner.py +136 -150
  50. sglang/srt/model_loader/__init__.py +34 -0
  51. sglang/srt/model_loader/loader.py +1139 -0
  52. sglang/srt/model_loader/utils.py +41 -0
  53. sglang/srt/model_loader/weight_utils.py +640 -0
  54. sglang/srt/models/baichuan.py +9 -10
  55. sglang/srt/models/chatglm.py +6 -15
  56. sglang/srt/models/commandr.py +2 -3
  57. sglang/srt/models/dbrx.py +2 -3
  58. sglang/srt/models/deepseek.py +4 -11
  59. sglang/srt/models/deepseek_v2.py +3 -11
  60. sglang/srt/models/exaone.py +2 -3
  61. sglang/srt/models/gemma.py +2 -6
  62. sglang/srt/models/gemma2.py +3 -14
  63. sglang/srt/models/gemma2_reward.py +0 -1
  64. sglang/srt/models/gpt2.py +5 -12
  65. sglang/srt/models/gpt_bigcode.py +6 -22
  66. sglang/srt/models/grok.py +14 -51
  67. sglang/srt/models/internlm2.py +2 -3
  68. sglang/srt/models/internlm2_reward.py +0 -1
  69. sglang/srt/models/llama.py +97 -27
  70. sglang/srt/models/llama_classification.py +1 -2
  71. sglang/srt/models/llama_embedding.py +1 -2
  72. sglang/srt/models/llama_reward.py +2 -3
  73. sglang/srt/models/llava.py +10 -12
  74. sglang/srt/models/llavavid.py +1 -2
  75. sglang/srt/models/minicpm.py +4 -7
  76. sglang/srt/models/minicpm3.py +6 -19
  77. sglang/srt/models/mixtral.py +12 -5
  78. sglang/srt/models/mixtral_quant.py +2 -3
  79. sglang/srt/models/mllama.py +3 -7
  80. sglang/srt/models/olmo.py +2 -8
  81. sglang/srt/models/olmo2.py +391 -0
  82. sglang/srt/models/olmoe.py +3 -5
  83. sglang/srt/models/phi3_small.py +8 -8
  84. sglang/srt/models/qwen.py +2 -3
  85. sglang/srt/models/qwen2.py +10 -9
  86. sglang/srt/models/qwen2_moe.py +4 -11
  87. sglang/srt/models/qwen2_vl.py +12 -9
  88. sglang/srt/models/registry.py +99 -0
  89. sglang/srt/models/stablelm.py +2 -3
  90. sglang/srt/models/torch_native_llama.py +6 -12
  91. sglang/srt/models/xverse.py +2 -4
  92. sglang/srt/models/xverse_moe.py +4 -11
  93. sglang/srt/models/yivl.py +2 -3
  94. sglang/srt/openai_api/adapter.py +10 -6
  95. sglang/srt/openai_api/protocol.py +1 -0
  96. sglang/srt/server.py +303 -204
  97. sglang/srt/server_args.py +65 -31
  98. sglang/srt/utils.py +253 -48
  99. sglang/test/test_utils.py +27 -7
  100. sglang/utils.py +2 -2
  101. sglang/version.py +1 -1
  102. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
  103. sglang-0.4.0.dist-info/RECORD +184 -0
  104. sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
  105. sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
  106. sglang/srt/layers/fused_moe_grok/layer.py +0 -630
  107. sglang-0.3.6.post2.dist-info/RECORD +0 -164
  108. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  109. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  110. {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/models/qwen2_vl.py CHANGED
@@ -30,12 +30,10 @@ import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from einops import rearrange, repeat
- from vllm.config import CacheConfig, MultiModalConfig
  from vllm.distributed import parallel_state
  from vllm.distributed import utils as dist_utils
  from vllm.logger import init_logger
  from vllm.model_executor.layers.activation import QuickGELU
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
  from sglang.srt.hf_transformers_utils import get_processor
@@ -49,6 +47,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
  from sglang.srt.managers.schedule_batch import ImageInputs
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.qwen2 import Qwen2Model

  logger = init_logger(__name__)
@@ -500,7 +499,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
          return num_image_tokens

      # Use grid_t * grid_w * grid_h to pad tokens for each image
-     # and replaced padding by unique image hash
+     # add replaced padding by unique image hash
      def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs):
          image_grid_thws = image_inputs.image_grid_thws
          pad_values = image_inputs.pad_values
@@ -536,7 +535,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
      def __init__(
          self,
          config: Qwen2VLConfig,
-         cache_config: Optional[CacheConfig] = None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -597,13 +595,15 @@ class Qwen2VLForConditionalGeneration(nn.Module):
              image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM.
                  `None` if no images are passed.
          """
+         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
+             positions = forward_batch.mrope_positions
+
          image_inputs = None
          if forward_batch.image_inputs is not None:
              image_inputs = [
                  img for img in forward_batch.image_inputs if img is not None
              ]
-         if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope":
-             positions = forward_batch.mrope_positions
+
          if (
              forward_batch.forward_mode.is_decode()
              or image_inputs is None
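For context, `mrope_positions` refers to M-RoPE (the multimodal rotary position embedding used by Qwen2-VL), which replaces 1-D position ids with a `(3, seq_len)` tensor of (temporal, height, width) coordinates; hoisting the check above the image-input handling means decode-only batches also pick up these positions. A rough, hypothetical illustration of the layout, not sglang's actual construction code:

```python
import torch

# Hypothetical sequence: 2 text tokens followed by a 2x2 grid of image patches.
# Text tokens repeat the same index on all three axes; image patches get
# (t, h, w) grid coordinates offset by the preceding text length.
text_pos = torch.arange(2).repeat(3, 1)        # [[0, 1], [0, 1], [0, 1]]
t = torch.zeros(4, dtype=torch.long)           # a single temporal frame
h = torch.tensor([0, 0, 1, 1])                 # patch row indices
w = torch.tensor([0, 1, 0, 1])                 # patch column indices
img_pos = torch.stack([t, h, w]) + 2           # offset by text length
mrope_positions = torch.cat([text_pos, img_pos], dim=1)  # shape (3, 6)
```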
@@ -617,6 +617,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
                      f"(3, seq_len) positions, but got {positions.size()}"
                  )

+         # Clamp input ids. This is because the input_ids for the image tokens are
+         # filled with the hash values of the image for the prefix matching in the radix attention.
+         # There values are useless because their embeddings will be replaced by vision embeddings anyway.
+         input_ids.clamp_(min=0, max=self.config.vocab_size - 1)
+
          inputs_embeds = self.model.embed_tokens(input_ids)
          extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
          prefix_lens_cpu = forward_batch.extend_prefix_lens_cpu
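The clamp is needed because `pad_input_ids` (earlier in this file) fills image-token slots with per-image hash values so the radix cache can tell different images apart, and those values can fall far outside the embedding table. A toy sketch of the failure mode being avoided (the hash value below is made up):

```python
import torch

vocab_size = 32000
embed = torch.nn.Embedding(vocab_size, 16)

# Image-token slots carry hash-derived pad values, not real token ids.
input_ids = torch.tensor([1, 15, 1554254516, 1554254516, 7])

# Indexing the embedding with the raw ids would fail, so they are clamped
# first; the resulting embeddings are later overwritten by vision features.
safe_ids = input_ids.clamp(min=0, max=vocab_size - 1)
out = embed(safe_ids)  # shape (5, 16)
```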
@@ -661,7 +666,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):

          if not get_embedding:
              return self.logits_processor(
-                 input_ids, hidden_states, self.lm_head.weight, forward_batch
+                 input_ids, hidden_states, self.lm_head, forward_batch
              )
          else:
              return self.pooler(hidden_states, forward_batch)
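This `lm_head.weight` → `lm_head` change recurs across the model files below: `LogitsProcessor` now receives the head module rather than a raw weight tensor, so it can decide how to produce logits (for example through a quantized method, or a tied embedding standing in for the head). A simplified sketch of the two calling conventions, not sglang's actual `LogitsProcessor`:

```python
import torch
import torch.nn as nn

def compute_logits_old(hidden: torch.Tensor, lm_head_weight: torch.Tensor) -> torch.Tensor:
    # Old convention: the caller unwraps .weight, so only a plain matmul is possible.
    return hidden @ lm_head_weight.T

def compute_logits_new(hidden: torch.Tensor, lm_head: nn.Module) -> torch.Tensor:
    # New convention: the processor gets the module and may dispatch on its type;
    # here we just fall back to the same matmul.
    return hidden @ lm_head.weight.T

head = nn.Linear(16, 32000, bias=False)
hidden = torch.randn(2, 16)
assert torch.allclose(compute_logits_old(hidden, head.weight), compute_logits_new(hidden, head))
```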
@@ -679,8 +684,6 @@ class Qwen2VLForConditionalGeneration(nn.Module):
          for name, loaded_weight in weights:
              if "rotary_emb.inv_freq" in name:
                  continue
-             if self.config.tie_word_embeddings and "lm_head.weight" in name:
-                 continue
              for param_name, weight_name, shard_id in stacked_params_mapping:
                  if weight_name not in name:
                      continue
sglang/srt/models/registry.py ADDED
@@ -0,0 +1,99 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/models/registry.py
+
+ import importlib
+ import logging
+ import pkgutil
+ from dataclasses import dataclass, field
+ from functools import lru_cache
+ from typing import AbstractSet, Dict, List, Optional, Tuple, Type, Union
+
+ import torch.nn as nn
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class _ModelRegistry:
+     # Keyed by model_arch
+     models: Dict[str, Union[Type[nn.Module], str]] = field(default_factory=dict)
+
+     def get_supported_archs(self) -> AbstractSet[str]:
+         return self.models.keys()
+
+     def _raise_for_unsupported(self, architectures: List[str]):
+         all_supported_archs = self.get_supported_archs()
+
+         if any(arch in all_supported_archs for arch in architectures):
+             raise ValueError(
+                 f"Model architectures {architectures} failed "
+                 "to be inspected. Please check the logs for more details."
+             )
+
+         raise ValueError(
+             f"Model architectures {architectures} are not supported for now. "
+             f"Supported architectures: {all_supported_archs}"
+         )
+
+     def _try_load_model_cls(self, model_arch: str) -> Optional[Type[nn.Module]]:
+         if model_arch not in self.models:
+             return None
+
+         return self.models[model_arch]
+
+     def _normalize_archs(
+         self,
+         architectures: Union[str, List[str]],
+     ) -> List[str]:
+         if isinstance(architectures, str):
+             architectures = [architectures]
+         if not architectures:
+             logger.warning("No model architectures are specified")
+
+         return architectures
+
+     def resolve_model_cls(
+         self,
+         architectures: Union[str, List[str]],
+     ) -> Tuple[Type[nn.Module], str]:
+         architectures = self._normalize_archs(architectures)
+
+         for arch in architectures:
+             model_cls = self._try_load_model_cls(arch)
+             if model_cls is not None:
+                 return (model_cls, arch)
+
+         return self._raise_for_unsupported(architectures)
+
+
+ @lru_cache()
+ def import_model_classes():
+     model_arch_name_to_cls = {}
+     package_name = "sglang.srt.models"
+     package = importlib.import_module(package_name)
+     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+         if not ispkg:
+             try:
+                 module = importlib.import_module(name)
+             except Exception as e:
+                 logger.warning(f"Ignore import error when loading {name}. " f"{e}")
+                 continue
+             if hasattr(module, "EntryClass"):
+                 entry = module.EntryClass
+                 if isinstance(
+                     entry, list
+                 ):  # To support multiple model classes in one module
+                     for tmp in entry:
+                         assert (
+                             tmp.__name__ not in model_arch_name_to_cls
+                         ), f"Duplicated model implementation for {tmp.__name__}"
+                         model_arch_name_to_cls[tmp.__name__] = tmp
+                 else:
+                     assert (
+                         entry.__name__ not in model_arch_name_to_cls
+                     ), f"Duplicated model implementation for {entry.__name__}"
+                     model_arch_name_to_cls[entry.__name__] = entry
+
+     return model_arch_name_to_cls
+
+
+ ModelRegistry = _ModelRegistry(import_model_classes())
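The registry maps architecture names (the `architectures` field of a checkpoint's `config.json`) to the model classes each module exports as `EntryClass`. A minimal usage sketch based on the methods above; the architecture name is only an example:

```python
from sglang.srt.models.registry import ModelRegistry

# Architectures discovered from the EntryClass exports at import time.
print(sorted(ModelRegistry.get_supported_archs()))

# Resolve a class for a checkpoint; raises ValueError for unsupported archs.
model_cls, arch = ModelRegistry.resolve_model_cls(["LlamaForCausalLM"])
```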
sglang/srt/models/stablelm.py CHANGED
@@ -26,7 +26,6 @@ from torch import nn
  from transformers import PretrainedConfig
  from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.linear import (
@@ -42,6 +41,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class StablelmMLP(nn.Module):
@@ -242,7 +242,6 @@ class StableLmForCausalLM(nn.Module):
          self,
          config: PretrainedConfig,
          quant_config: Optional[QuantizationConfig] = None,
-         cache_config=None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -261,7 +260,7 @@ class StableLmForCausalLM(nn.Module):
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
+             input_ids, hidden_states, self.lm_head, forward_batch
          )

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/torch_native_llama.py CHANGED
@@ -52,7 +52,6 @@ from vllm.distributed import (
      get_tensor_model_parallel_world_size,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.activation import SiluAndMul
  from sglang.srt.layers.layernorm import RMSNorm
@@ -66,6 +65,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
  )
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader

  tp_size = get_tensor_model_parallel_world_size()
  tp_rank = get_tensor_model_parallel_rank()
@@ -388,7 +388,6 @@ class TorchNativeLlamaForCausalLM(nn.Module):
          self,
          config: LlamaConfig,
          quant_config: Optional[QuantizationConfig] = None,
-         cache_config=None,
      ) -> None:
          super().__init__()
          self.config = config
@@ -396,7 +395,10 @@
          self.torchao_config = global_server_args_dict["torchao_config"]
          self.supports_torch_tp = True
          self.model = LlamaModel(config, quant_config=quant_config)
-         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
+         if self.config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
          self.logits_processor = LogitsProcessor(config)

          # turning off autotune for fp8dq since it doesn't give speedup and
@@ -413,7 +415,7 @@
      ) -> LogitsProcessorOutput:
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
+             input_ids, hidden_states, self.lm_head, forward_batch
          )

      def get_hidden_dim(self, module_name):
@@ -501,14 +503,6 @@
              weight_loader = getattr(param, "weight_loader", default_weight_loader)
              weight_loader(param, loaded_weight)

-         if (
-             hasattr(self.config, "tie_word_embeddings")
-             and self.config.tie_word_embeddings
-         ):
-             # Tie output embedding layer to input embedding layer, to solve issues where lm_head.weight is missing
-             param = self.lm_head.weight
-             weight_loader = getattr(param, "weight_loader", default_weight_loader)
-             weight_loader(param, self.model.embed_tokens.weight)
          apply_torchao_config_(self, params_dict, set(["proj.weight"]))

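The net effect: tied word embeddings are now expressed by aliasing the embedding module as the head at construction time, so the post-load weight copy in `load_weights` becomes unnecessary. A minimal sketch of the before/after with toy modules (illustrative only):

```python
import torch.nn as nn

vocab, hidden = 1000, 64
embed_tokens = nn.Embedding(vocab, hidden)

# Before: a separate head plus a post-load copy of the embedding weights.
lm_head = nn.Linear(hidden, vocab, bias=False)
lm_head.weight.data.copy_(embed_tokens.weight.data)  # extra step, extra memory

# After: the embedding module doubles as the head; any consumer reading
# lm_head.weight now sees the embedding weight itself.
lm_head = embed_tokens
assert lm_head.weight is embed_tokens.weight
```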
sglang/srt/models/xverse.py CHANGED
@@ -30,7 +30,6 @@ from vllm.model_executor.layers.linear import (
      RowParallelLinear,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.logits_processor import LogitsProcessor
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -40,6 +39,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.model_runner import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class XverseMLP(nn.Module):
@@ -295,8 +295,6 @@
          self,
          config: LlamaConfig,
          quant_config: Optional[QuantizationConfig] = None,
-         cache_config=None,
-         efficient_weight_load=False,
      ) -> None:
          super().__init__()
          self.config = config
@@ -315,7 +313,7 @@
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
          return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
+             input_ids, hidden_states, self.lm_head, forward_batch
          )

      def load_weights(
sglang/srt/models/xverse_moe.py CHANGED
@@ -32,7 +32,6 @@ from vllm.model_executor.layers.linear import (
      RowParallelLinear,
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.fused_moe_triton import fused_moe
  from sglang.srt.layers.logits_processor import LogitsProcessor
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
      VocabParallelEmbedding,
  )
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader


  class XverseMLP(nn.Module):
@@ -181,7 +181,6 @@ class XverseAttention(nn.Module):
          rope_theta: float = 10000,
          rope_scaling: Optional[Dict[str, Any]] = None,
          max_position_embeddings: int = 8192,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -258,7 +257,6 @@ class XverseDecoderLayer(nn.Module):
          self,
          config: PretrainedConfig,
          layer_id: int,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -277,7 +275,6 @@
              rope_theta=rope_theta,
              rope_scaling=rope_scaling,
              max_position_embeddings=max_position_embeddings,
-             cache_config=cache_config,
              quant_config=quant_config,
          )
          if config.num_experts is not None:
@@ -326,7 +323,6 @@ class XverseModel(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
@@ -339,9 +335,7 @@ class XverseModel(nn.Module):
          )
          self.layers = nn.ModuleList(
              [
-                 XverseDecoderLayer(
-                     config, layer_id, cache_config, quant_config=quant_config
-                 )
+                 XverseDecoderLayer(config, layer_id, quant_config=quant_config)
                  for layer_id in range(config.num_hidden_layers)
              ]
          )
@@ -369,13 +363,12 @@ class XverseMoeForCausalLM(nn.Module):
      def __init__(
          self,
          config: PretrainedConfig,
-         cache_config=None,
          quant_config: Optional[QuantizationConfig] = None,
      ) -> None:
          super().__init__()
          self.config = config
          self.quant_config = quant_config
-         self.model = XverseModel(config, cache_config, quant_config)
+         self.model = XverseModel(config, quant_config)
          self.lm_head = ParallelLMHead(
              config.vocab_size, config.hidden_size, quant_config=quant_config
          )
@@ -390,7 +383,7 @@ class XverseMoeForCausalLM(nn.Module):
      ) -> torch.Tensor:
          hidden_states = self.model(input_ids, positions, forward_batch)
          return self.logits_processor(
-             input_ids, hidden_states, self.lm_head.weight, forward_batch
+             input_ids, hidden_states, self.lm_head, forward_batch
          )

      def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/yivl.py CHANGED
@@ -18,9 +18,9 @@ from typing import Iterable, Optional, Tuple
  import torch
  import torch.nn as nn
  from transformers import CLIPVisionModel, LlavaConfig
- from vllm.model_executor.model_loader.weight_utils import default_weight_loader

  from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
  from sglang.srt.models.llava import LlavaLlamaForCausalLM


@@ -29,9 +29,8 @@ class YiVLForCausalLM(LlavaLlamaForCausalLM):
          self,
          config: LlavaConfig,
          quant_config: Optional[QuantizationConfig] = None,
-         cache_config=None,
      ) -> None:
-         super().__init__(config, quant_config, cache_config)
+         super().__init__(config, quant_config)

          self.multi_modal_projector = YiVLMultiModalProjector(self.config)
          self.vision_tower_subfolder = self.config.mm_vision_tower.replace(
sglang/srt/openai_api/adapter.py CHANGED
@@ -486,6 +486,7 @@ def v1_generate_request(
      return_logprobs = []
      logprob_start_lens = []
      top_logprobs_nums = []
+     lora_paths = []

      for request in all_requests:
          # NOTE: with openai API, the prompt's logprobs are always not computed
@@ -496,6 +497,7 @@
          )

          prompts.append(request.prompt)
+         lora_paths.append(request.lora_path)
          if request.echo and request.logprobs:
              current_logprob_start_len = 0
          else:
@@ -519,7 +521,7 @@
                  "skip_special_tokens": request.skip_special_tokens,
              }
          )
-         return_logprobs.append(request.logprobs is not None and request.logprobs > 0)
+         return_logprobs.append(request.logprobs is not None)
          logprob_start_lens.append(current_logprob_start_len)
          top_logprobs_nums.append(
              request.logprobs if request.logprobs is not None else 0
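Note the semantic change in `return_logprobs`: under the old truthiness-style check, `logprobs=0` (a valid OpenAI value meaning "chosen-token logprobs, no top alternatives") disabled logprobs entirely; `is not None` keeps them enabled. In short:

```python
logprobs = 0  # OpenAI-style: return chosen-token logprobs, zero alternatives
old = logprobs is not None and logprobs > 0  # False: logprobs silently dropped
new = logprobs is not None                   # True: logprobs returned
```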
@@ -534,6 +536,7 @@
          return_logprobs = return_logprobs[0]
          logprob_start_lens = logprob_start_lens[0]
          top_logprobs_nums = top_logprobs_nums[0]
+         lora_paths = lora_paths[0]
      else:
          if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
              prompt_kwargs = {"text": prompts}
@@ -549,6 +552,7 @@
          return_text_in_logprobs=True,
          stream=all_requests[0].stream,
          rid=request_ids,
+         lora_path=lora_paths,
      )

      return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -591,9 +595,9 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
              text = prompts[prompt_index] + text

          logprobs = False
-         if isinstance(request, list) and request[idx].logprobs:
+         if isinstance(request, list) and request[idx].logprobs is not None:
              logprobs = True
-         elif (not isinstance(request, list)) and request.logprobs:
+         elif (not isinstance(request, list)) and request.logprobs is not None:
              logprobs = True
          if logprobs:
              if echo:
@@ -735,7 +739,7 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
                      # Prepend prompt in response text.
                      text = prompts + text

-                 if request.logprobs:
+                 if request.logprobs is not None:
                      # The first chunk and echo is enabled.
                      if not stream_buffer and request.echo:
                          input_token_logprobs = content["meta_info"][
@@ -1275,7 +1279,7 @@ def v1_embedding_request(all_requests, tokenizer_manager):
      for request in all_requests:
          prompt = request.input
          assert (
-             type(prompt) == first_prompt_type
+             type(prompt) is first_prompt_type
          ), "All prompts must be of the same type in file input settings"
          prompts.append(prompt)

@@ -1286,7 +1290,7 @@
          else:
              prompt_kwargs = {"input_ids": prompt}
      else:
-         if isinstance(prompts[0], str) or isinstance(propmts[0][0], str):
+         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
              prompt_kwargs = {"text": prompts}
          else:
              prompt_kwargs = {"input_ids": prompts}
sglang/srt/openai_api/protocol.py CHANGED
@@ -166,6 +166,7 @@ class CompletionRequest(BaseModel):
      temperature: float = 1.0
      top_p: float = 1.0
      user: Optional[str] = None
+     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None

      # Extra parameters for SRT backend only and will be ignored by OpenAI models.
      json_schema: Optional[str] = None
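Together with the `lora_paths` plumbing in `v1_generate_request` above, this field lets each completion request (or each prompt in a batch) select a LoRA adapter. A hypothetical call, assuming the server was started with an adapter registered under the name `my-adapter` via `--lora-paths`:

```python
import requests

resp = requests.post(
    "http://localhost:30000/v1/completions",
    json={
        "model": "default",
        "prompt": "Hello, my name is",
        "max_tokens": 16,
        "lora_path": "my-adapter",  # hypothetical adapter name
    },
)
print(resp.json())
```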