sglang 0.1.26__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -15,7 +15,6 @@ from flashinfer import (
15
15
  BatchPrefillWithRaggedKVCacheWrapper,
16
16
  )
17
17
  from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
18
- from torch.nn.parameter import Parameter
19
18
  from vllm.config import DeviceConfig, LoadConfig
20
19
  from vllm.config import ModelConfig as VllmModelConfig
21
20
  from vllm.distributed import (
@@ -23,7 +22,6 @@ from vllm.distributed import (
23
22
  init_distributed_environment,
24
23
  initialize_model_parallel,
25
24
  )
26
- from vllm.model_executor.layers.linear import QKVParallelLinear
27
25
  from vllm.model_executor.models import ModelRegistry
28
26
 
29
27
  from sglang.global_config import global_config
@@ -32,26 +30,16 @@ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
32
30
  from sglang.srt.server_args import ServerArgs
33
31
  from sglang.srt.utils import (
34
32
  get_available_gpu_memory,
33
+ is_llama3_405b_fp8,
35
34
  is_multimodal_model,
36
35
  monkey_patch_vllm_dummy_weight_loader,
37
36
  monkey_patch_vllm_p2p_access_check,
37
+ monkey_patch_vllm_qvk_linear_loader,
38
38
  )
39
39
 
40
40
  logger = logging.getLogger("srt.model_runner")
41
41
 
42
42
 
43
- def is_llama3_405b_fp8(model_config):
44
- if (
45
- model_config.hf_config.architectures[0] == "LlamaForCausalLM"
46
- and model_config.hf_config.hidden_size == 16384
47
- and model_config.hf_config.intermediate_size == 53248
48
- and model_config.hf_config.num_hidden_layers == 126
49
- and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
50
- ):
51
- return True
52
- return False
53
-
54
-
55
43
  class ModelRunner:
56
44
  def __init__(
57
45
  self,
@@ -132,9 +120,13 @@ class ModelRunner:
132
120
  seed=42,
133
121
  skip_tokenizer_init=True,
134
122
  )
123
+
135
124
  if is_llama3_405b_fp8(self.model_config):
125
+ # A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints
136
126
  self.model_config.hf_config.num_key_value_heads = 8
137
127
  vllm_model_config.hf_config.num_key_value_heads = 8
128
+ monkey_patch_vllm_qvk_linear_loader()
129
+
138
130
  self.dtype = vllm_model_config.dtype
139
131
  if self.model_config.model_overide_args is not None:
140
132
  vllm_model_config.hf_config.update(self.model_config.model_overide_args)
@@ -387,39 +379,5 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
387
379
  return model_arch_name_to_cls[model_arch]
388
380
 
389
381
 
390
- def get_original_weight(loaded_weight, head_dim):
391
- n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
392
- dim = loaded_weight.shape[1]
393
- for i in range(n_kv_head):
394
- loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
395
- 2 * i * head_dim : (2 * i + 1) * head_dim, :
396
- ]
397
- original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
398
- assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
399
- return original_kv_weight
400
-
401
-
402
- def get_weight_loader_srt(weight_loader):
403
- def weight_loader_srt(
404
- self,
405
- param: Parameter,
406
- loaded_weight: torch.Tensor,
407
- loaded_shard_id: Optional[str] = None,
408
- ):
409
- if (
410
- loaded_shard_id in ["k", "v"]
411
- and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
412
- ):
413
- loaded_weight = get_original_weight(loaded_weight, self.head_size)
414
-
415
- weight_loader(self, param, loaded_weight, loaded_shard_id)
416
-
417
- return weight_loader_srt
418
-
419
-
420
382
  # Monkey patch model loader
421
383
  setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
422
- original_weight_loader = QKVParallelLinear.weight_loader
423
- setattr(
424
- QKVParallelLinear, "weight_loader", get_weight_loader_srt(original_weight_loader)
425
- )
sglang/srt/server.py CHANGED
@@ -202,15 +202,12 @@ def launch_server(
202
202
  "reinstall the latest version by following the instructions "
203
203
  "at https://docs.flashinfer.ai/installation.html.",
204
204
  )
205
-
206
- if server_args.tp_size // server_args.dp_size > 1:
205
+ if server_args.tp_size * server_args.dp_size > 1:
207
206
  # FIXME: remove this after https://github.com/triton-lang/triton/pull/4295 is used as a dependency.
208
207
  maybe_set_triton_cache_manager()
209
-
210
208
  if server_args.chat_template:
211
209
  # TODO: replace this with huggingface transformers template
212
210
  load_chat_template_for_openai_api(server_args.chat_template)
213
-
214
211
  if server_args.enable_torch_compile:
215
212
  _set_torch_compile_config()
216
213
 
sglang/srt/utils.py CHANGED
@@ -21,6 +21,7 @@ import torch.distributed as dist
21
21
  from fastapi.responses import JSONResponse
22
22
  from packaging import version as pkg_version
23
23
  from starlette.middleware.base import BaseHTTPMiddleware
24
+ from torch.nn.parameter import Parameter
24
25
  from triton.runtime.cache import (
25
26
  FileCacheManager,
26
27
  default_cache_dir,
@@ -471,7 +472,7 @@ def maybe_set_triton_cache_manager() -> None:
471
472
  cache_manger = os.environ.get("TRITON_CACHE_MANAGER", None)
472
473
  if cache_manger is None:
473
474
  manager = "sglang.srt.utils:CustomCacheManager"
474
- logger.info("Setting Triton cache manager to: %s", manager)
475
+ logger.debug("Setting Triton cache manager to: %s", manager)
475
476
  os.environ["TRITON_CACHE_MANAGER"] = manager
476
477
 
477
478
 
@@ -615,3 +616,51 @@ def set_ulimit(target_soft_limit=65535):
615
616
  resource.setrlimit(resource_type, (target_soft_limit, current_hard))
616
617
  except ValueError as e:
617
618
  logger.warn(f"Fail to set RLIMIT_NOFILE: {e}")
619
+
620
+
621
+ def is_llama3_405b_fp8(model_config):
622
+ """Return whether the model is meta-llama/Meta-Llama-3.1-405B-FP8 with 16 kv heads."""
623
+ if (
624
+ model_config.hf_config.architectures[0] == "LlamaForCausalLM"
625
+ and model_config.hf_config.hidden_size == 16384
626
+ and model_config.hf_config.intermediate_size == 53248
627
+ and model_config.hf_config.num_hidden_layers == 126
628
+ and model_config.hf_config.num_key_value_heads == 16
629
+ and model_config.hf_config.quantization_config["quant_method"] == "fbgemm_fp8"
630
+ ):
631
+ return True
632
+ return False
633
+
634
+
635
+ def monkey_patch_vllm_qvk_linear_loader():
636
+ """A temporary hack to fix the num_heads for meta-llama/Meta-Llama-3.1-405B-FP8 checkpoints."""
637
+ from vllm.model_executor.layers.linear import QKVParallelLinear
638
+
639
+ origin_weight_loader = QKVParallelLinear.weight_loader
640
+
641
+ def get_original_weight(loaded_weight, head_dim):
642
+ n_kv_head = loaded_weight.shape[0] // (2 * head_dim)
643
+ dim = loaded_weight.shape[1]
644
+ for i in range(n_kv_head):
645
+ loaded_weight[i * head_dim : (i + 1) * head_dim, :] = loaded_weight[
646
+ 2 * i * head_dim : (2 * i + 1) * head_dim, :
647
+ ]
648
+ original_kv_weight = loaded_weight[: n_kv_head * head_dim, :]
649
+ assert original_kv_weight.shape == (n_kv_head * head_dim, dim)
650
+ return original_kv_weight
651
+
652
+ def weight_loader_srt(
653
+ self,
654
+ param: Parameter,
655
+ loaded_weight: torch.Tensor,
656
+ loaded_shard_id: Optional[str] = None,
657
+ ):
658
+ if (
659
+ loaded_shard_id in ["k", "v"]
660
+ and loaded_weight.shape[0] == self.head_size * self.total_num_kv_heads * 2
661
+ ):
662
+ loaded_weight = get_original_weight(loaded_weight, self.head_size)
663
+
664
+ origin_weight_loader(self, param, loaded_weight, loaded_shard_id)
665
+
666
+ setattr(QKVParallelLinear, "weight_loader", weight_loader_srt)
sglang/version.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.1.26"
1
+ __version__ = "0.2.0"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: sglang
3
- Version: 0.1.26
3
+ Version: 0.2.0
4
4
  Summary: SGLang is yet another fast serving framework for large language models and vision language models.
5
5
  License: Apache License
6
6
  Version 2.0, January 2004
@@ -315,11 +315,6 @@ docker run --gpus all \
315
315
  ```
316
316
 
317
317
  ### Common Notes
318
- - If you see errors from the Triton compiler, please install the [Triton Nightly](https://triton-lang.org/main/getting-started/installation.html) by
319
- ```
320
- pip uninstall -y triton triton-nightly
321
- pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
322
- ```
323
318
  - If you cannot install FlashInfer, check out its [installation](https://docs.flashinfer.ai/installation.html#) page. If you still cannot install it, you can use the slower Triton kernels by adding `--disable-flashinfer` when launching the server.
324
319
  - If you only need to use the OpenAI backend, you can avoid installing other dependencies by using `pip install "sglang[openai]"`.
325
320
 
@@ -402,6 +397,7 @@ python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3-8B-Instruct
402
397
  ```
403
398
  - If the model does not have a template in the Hugging Face tokenizer, you can specify a [custom chat template](docs/custom_chat_template.md).
404
399
  - To enable fp8 quantization, you can add `--quantization fp8` on a fp16 checkpoint or directly load a fp8 checkpoint without specifying any arguments.
400
+ - To enable experimental torch.compile support, you can add `--enable-torch-compile`. It accelerates small models on small batch sizes.
405
401
 
406
402
  ### Supported Models
407
403
 
@@ -7,7 +7,7 @@ sglang/global_config.py,sha256=QG-ABVJksKK_llvUx7fSZcmK4GGCs-hBUVcM4LCr7Nw,1749
7
7
  sglang/launch_server.py,sha256=Gg8CwNlTCCfg1dF65ZT9ePLxOT9LKtY79GhIPG6PCrU,358
8
8
  sglang/launch_server_llavavid.py,sha256=40uaazMsavKuk6YXFa5v37kdUpFGuealgJJeph1g8gU,1025
9
9
  sglang/utils.py,sha256=arJuwOAEX445M2NL9SAOi6jBNu0-cfU04PLAr-hIH3U,8168
10
- sglang/version.py,sha256=3_QdGLpuk_SDY7k9PpNcHpSTjlPdhadPiEgF82wzkqk,23
10
+ sglang/version.py,sha256=Zn1KFblwuFHiDRdRAiRnDBRkbPttWh44jKa5zG2ov0E,22
11
11
  sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  sglang/lang/chat_template.py,sha256=psIlhaDo70twgLrx5Lgln03metLEA3-FZuixeI0Y7Ao,13309
13
13
  sglang/lang/compiler.py,sha256=UiXUmPR9wBAPtnORrLcyQX8Uh0ZL0nKeV8ZgBozAJPw,7531
@@ -28,9 +28,9 @@ sglang/srt/memory_pool.py,sha256=FhJk5GtYortO3MJIsMMQ-o49agwDHVX1aEQH2LITq6c,394
28
28
  sglang/srt/mm_utils.py,sha256=OptgAHDX-73Bk4jAdr2BOAJtiEXJNzPrMhaM-dy275c,8889
29
29
  sglang/srt/model_config.py,sha256=lZu1D-XLVMETHS6FBMoPn8Uowa9QFGe95d3SuWrr2q8,5282
30
30
  sglang/srt/sampling_params.py,sha256=OI11asr1Bd_E5soDjih614v4flgWxdMZU9HAF0aBafQ,3062
31
- sglang/srt/server.py,sha256=DXhcJt0V24a7yhydP1abPrK1qqV3qt7r8cyOMVOAI4M,14611
31
+ sglang/srt/server.py,sha256=IUed6vnXCx7-xbrpEMAaJZ_aa4UubPAQ5pXvcv-xNoY,14607
32
32
  sglang/srt/server_args.py,sha256=aF6L35mEB-FU3BL_ooKuCIcOXLhYLxA9-MjpaOTQRCo,13189
33
- sglang/srt/utils.py,sha256=bUp3SLzbDms0dvuETaccDPAGRHOIGW5A61pqH62XiT0,20370
33
+ sglang/srt/utils.py,sha256=DZtYSTvtSf_HWZjKZyo8TFiXahz-JfeujJcKBuBkhpQ,22318
34
34
  sglang/srt/constrained/__init__.py,sha256=5LB3_mDTMW6wcRkFA5J2Rd5HPHHEKRyiELhe4gtlBYM,1472
35
35
  sglang/srt/constrained/base_cache.py,sha256=QQjmFEiT8jlOskJoZobhrDl2TKB-B4b1LPQo9JQCP_w,1405
36
36
  sglang/srt/constrained/fsm_cache.py,sha256=P4qNDHHxpKpTnYL_8V1R6OFXlUwbM6ZcBdzddpcBgb4,1135
@@ -51,7 +51,7 @@ sglang/srt/managers/controller/cuda_graph_runner.py,sha256=0aRqA1_34oJ557Zn8PjpJ
51
51
  sglang/srt/managers/controller/infer_batch.py,sha256=SKwCwhnZ_CNlG0mVCEc4X0e4HNjJFke-c8zdWP3TzjQ,34186
52
52
  sglang/srt/managers/controller/manager_multi.py,sha256=DT8Y9RF5OyTxlrLEZYz4claNWir3UrVztdOZaVPiA6g,6077
53
53
  sglang/srt/managers/controller/manager_single.py,sha256=2xO_iWK6tWvc0B31nKbe2N3klxwQBJmPTnFhNjzhVSI,4566
54
- sglang/srt/managers/controller/model_runner.py,sha256=FwZ7FU7nhJsYhtoTNxYFc4e6oMEwSqOh8ohXOKtFPKc,15828
54
+ sglang/srt/managers/controller/model_runner.py,sha256=WzbyGkMnULuDkZ_SUe-UfOH2OZEQ-IE8aYYdQacy7fM,14349
55
55
  sglang/srt/managers/controller/radix_cache.py,sha256=tx8LEQpqLxipw9UUVj4D1YQLMMDmWnjDYv8oDlOl-co,8210
56
56
  sglang/srt/managers/controller/schedule_heuristic.py,sha256=SQAGzPS3aB_TPj7rnPBhewwyR6W1sVwW4D3zG3JUY00,2714
57
57
  sglang/srt/managers/controller/tp_worker.py,sha256=yjz-Xzl0zEy4QSU-EYneZH5vi3oHtBuXTtYe4VuDp2g,30517
@@ -85,8 +85,8 @@ sglang/test/test_conversation.py,sha256=gF_AyOxQgpPQBPnA57-kq-M0p_zFu-rBDMFgAq65
85
85
  sglang/test/test_openai_protocol.py,sha256=DVx3r6hrb8oRqbo5AYIleldxbqMBTtb-gtORM6t_Y1c,1661
86
86
  sglang/test/test_programs.py,sha256=uefeHUFKT2NJESOujj-CsnPXdw1aQQN2TzUbPCHJjGs,13654
87
87
  sglang/test/test_utils.py,sha256=kD_fQe3WroZ9Kc3NBRKPiZOFJ_JD2uEE9XIvPp6AD9Y,11048
88
- sglang-0.1.26.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
89
- sglang-0.1.26.dist-info/METADATA,sha256=QnzTK6blFTHKTDw9ULRpaJVvXyg0MuzkdqwYkk0zPb0,30986
90
- sglang-0.1.26.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
91
- sglang-0.1.26.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
92
- sglang-0.1.26.dist-info/RECORD,,
88
+ sglang-0.2.0.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
89
+ sglang-0.2.0.dist-info/METADATA,sha256=mk2lWkWZKtTJFXM7e_z2dMdke8WiV67X9aL48lGLRaw,30791
90
+ sglang-0.2.0.dist-info/WHEEL,sha256=Wyh-_nZ0DJYolHNn1_hMa4lM7uDedD_RGVwbmTjyItk,91
91
+ sglang-0.2.0.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
92
+ sglang-0.2.0.dist-info/RECORD,,