sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff compares publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/model_config.py CHANGED
@@ -15,10 +15,9 @@ class ModelConfig:
         self.path = path
         self.trust_remote_code = trust_remote_code
         self.revision = revision
-        self.hf_config = get_config(self.path, trust_remote_code, revision)
-
-        if model_overide_args is not None:
-            self.hf_config.update(model_overide_args)
+        self.model_overide_args = model_overide_args
+        self.hf_config = get_config(self.path, trust_remote_code, revision,
+                                    model_overide_args=model_overide_args)
 
         if context_length is not None:
             self.context_len = context_length
@@ -44,4 +43,4 @@ class ModelConfig:
         self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_config.hidden_size
         self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
+        self.vocab_size = self.hf_config.vocab_size
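The hunk above moves the override handling out of ModelConfig and into get_config itself, so overrides are applied while the HuggingFace config is being built. A minimal sketch of what the new signature implies, assuming get_config simply applies the dict on top of the loaded config (the real helper lives in sglang/srt/hf_transformers_utils.py and may do more):

```python
from transformers import AutoConfig


def get_config(model_path, trust_remote_code=False, revision=None,
               model_overide_args=None):
    # Load the HuggingFace config, then apply any overrides on top of it,
    # mirroring what the old ModelConfig.__init__ did inline.
    config = AutoConfig.from_pretrained(
        model_path, trust_remote_code=trust_remote_code, revision=revision
    )
    if model_overide_args is not None:
        config.update(model_overide_args)
    return config
```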
sglang/srt/models/commandr.py CHANGED
@@ -18,15 +18,19 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/commandr.py#L1
+
 # This file is based on the LLama model definition file in transformers
 """PyTorch Cohere model."""
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Iterable
 
 import torch
 import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
+from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -41,11 +45,11 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 @torch.compile
@@ -301,6 +305,7 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -324,13 +329,7 @@ class CohereForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -341,9 +340,7 @@ class CohereForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
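The load_weights rewrite follows vLLM's newer loader interface: the model no longer resolves checkpoint files itself through hf_model_weights_iterator; a stream of (name, tensor) pairs is handed in by the caller. A simplified sketch of the pattern the hunk converges on (guards for tied or skipped weights are omitted; default_weight_loader is vLLM's plain copy-into-parameter fallback):

```python
from typing import Iterable, Tuple

import torch
from vllm.model_executor.model_loader.weight_utils import default_weight_loader


def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    # q/k/v checkpoint tensors are fused into one qkv_proj parameter,
    # so each incoming name is first checked against this mapping.
    stacked_params_mapping = [
        # (param_name, shard_name, shard_id)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]
    params_dict = dict(self.named_parameters())
    for name, loaded_weight in weights:
        for param_name, shard_name, shard_id in stacked_params_mapping:
            if shard_name not in name:
                continue
            param = params_dict[name.replace(shard_name, param_name)]
            # Stacked parameters carry a custom weight_loader that writes
            # the shard into the right slice of the fused tensor.
            param.weight_loader(param, loaded_weight, shard_id)
            break
        else:
            param = params_dict[name]
            # Fall back to a plain copy for everything else.
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
```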
sglang/srt/models/dbrx.py CHANGED
@@ -1,10 +1,11 @@
 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/14ccd94c89d0ffd9da283545d93ab1dfea5da340/vllm/model_executor/models/dbrx.py
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/dbrx.py#L1
 # coding=utf-8
-from typing import Optional
+from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
+from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -24,12 +25,12 @@ from vllm.model_executor.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from vllm.model_executor.utils import set_weight_attrs
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
+from vllm.transformers_utils.configs.dbrx import DbrxConfig
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.models.dbrx_config import DbrxConfig
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class DbrxRouter(nn.Module):
@@ -352,6 +353,7 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ):
         super().__init__()
         self.config = config
@@ -377,13 +379,7 @@ class DbrxForCausalLM(nn.Module):
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
             (
                 "ws" if weight_name in ["w1", "v1"] else "w2s",
@@ -392,9 +388,7 @@ class DbrxForCausalLM(nn.Module):
             for weight_name in ["w1", "v1", "w2"]
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
                     continue
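The dbrx hunks make the same interface change for the MoE expert weights: per-expert w1/v1 tensors are stacked into a fused ws parameter and w2 into w2s, as the expert_params_mapping in the hunk shows. A hedged sketch of that loop; the experts.mlp checkpoint prefix is an assumption, since the diff elides the second element of each mapping tuple:

```python
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
    # w1/v1 expert shards are stacked into the fused "ws" parameter,
    # w2 shards into "w2s". The "experts.mlp." checkpoint prefix is an
    # assumption; the hunk elides the second element of each tuple.
    expert_params_mapping = [
        ("ws" if weight_name in ["w1", "v1"] else "w2s",
         f"experts.mlp.{weight_name}")
        for weight_name in ["w1", "v1", "w2"]
    ]
    params_dict = dict(self.named_parameters(remove_duplicate=False))
    for name, loaded_weight in weights:
        for param_name, weight_name in expert_params_mapping:
            if weight_name not in name:
                continue
            param = params_dict[name.replace(weight_name, param_name)]
            # Expert parameters use a loader keyed by the original name.
            param.weight_loader(param, loaded_weight, weight_name)
            break
        else:
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader", default_weight_loader)
            weight_loader(param, loaded_weight)
```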
sglang/srt/models/gemma.py CHANGED
@@ -1,12 +1,12 @@
 # Adapted from:
-# https://github.com/vllm-project/vllm/blob/d65fac2738f0287a41955b45df76a2d5a919bff6/vllm/model_executor/models/gemma.py
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
 """Inference-only Gemma model compatible with HuggingFace weights."""
-from typing import Optional, Tuple
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
+from vllm.config import LoRAConfig, CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import GeluAndMul
 from vllm.model_executor.layers.layernorm import RMSNorm
@@ -18,11 +18,11 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
-from sglang.srt.weight_utils import default_weight_loader, hf_model_weights_iterator
+from sglang.srt.managers.controller.model_runner import InputMetadata
 
 
 class GemmaMLP(nn.Module):
@@ -264,6 +264,7 @@ class GemmaForCausalLM(nn.Module):
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoRAConfig] = None,
+        cache_config: Optional[CacheConfig] = None,
     ) -> None:
         del lora_config  # Unused.
         super().__init__()
@@ -285,13 +286,7 @@ class GemmaForCausalLM(nn.Module):
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
 
-    def load_weights(
-        self,
-        model_name_or_path: str,
-        cache_dir: Optional[str] = None,
-        load_format: str = "auto",
-        revision: Optional[str] = None,
-    ):
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -302,9 +297,7 @@ class GemmaForCausalLM(nn.Module):
         ]
         params_dict = dict(self.named_parameters())
         loaded_params = set()
-        for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision
-        ):
+        for name, loaded_weight in weights:
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
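Under the new convention, producing the weight stream is the caller's responsibility. A minimal usage sketch, assuming sharded safetensors checkpoints on disk; the function name and directory layout here are illustrative, not sglang's actual loader, which now comes from vLLM's model_loader:

```python
import glob

from safetensors import safe_open


def iterate_weights(checkpoint_dir: str):
    # Yield (name, tensor) pairs from every safetensors shard, matching
    # the Iterable[Tuple[str, torch.Tensor]] that load_weights now expects.
    for shard in sorted(glob.glob(f"{checkpoint_dir}/*.safetensors")):
        with safe_open(shard, framework="pt") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)


# model.load_weights(iterate_weights("/path/to/checkpoint"))
```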