sglang 0.1.20__tar.gz → 0.1.21__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. {sglang-0.1.20/sglang.egg-info → sglang-0.1.21}/PKG-INFO +9 -1
  2. {sglang-0.1.20 → sglang-0.1.21}/README.md +8 -0
  3. {sglang-0.1.20 → sglang-0.1.21}/pyproject.toml +1 -1
  4. {sglang-0.1.20 → sglang-0.1.21}/sglang/__init__.py +1 -1
  5. {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/runtime_endpoint.py +14 -4
  6. {sglang-0.1.20 → sglang-0.1.21}/sglang/bench_latency.py +0 -1
  7. {sglang-0.1.20 → sglang-0.1.21}/sglang/global_config.py +3 -1
  8. {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/chat_template.py +2 -2
  9. {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/ir.py +3 -3
  10. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/cuda_graph_runner.py +36 -12
  11. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/infer_batch.py +32 -26
  12. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/manager_multi.py +6 -2
  13. sglang-0.1.21/sglang/srt/managers/controller/manager_single.py +177 -0
  14. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/model_runner.py +17 -5
  15. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/radix_cache.py +4 -3
  16. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/schedule_heuristic.py +4 -0
  17. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/tp_worker.py +38 -40
  18. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/memory_pool.py +29 -54
  19. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/minicpm.py +1 -8
  20. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/qwen2_moe.py +126 -107
  21. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/server.py +10 -15
  22. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/server_args.py +4 -2
  23. {sglang-0.1.20 → sglang-0.1.21/sglang.egg-info}/PKG-INFO +9 -1
  24. sglang-0.1.20/sglang/srt/managers/controller/manager_single.py +0 -102
  25. {sglang-0.1.20 → sglang-0.1.21}/LICENSE +0 -0
  26. {sglang-0.1.20 → sglang-0.1.21}/setup.cfg +0 -0
  27. {sglang-0.1.20 → sglang-0.1.21}/sglang/api.py +0 -0
  28. {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/__init__.py +0 -0
  29. {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/anthropic.py +0 -0
  30. {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/base_backend.py +0 -0
  31. {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/litellm.py +0 -0
  32. {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/openai.py +0 -0
  33. {sglang-0.1.20 → sglang-0.1.21}/sglang/backend/vertexai.py +0 -0
  34. {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/__init__.py +0 -0
  35. {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/compiler.py +0 -0
  36. {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/interpreter.py +0 -0
  37. {sglang-0.1.20 → sglang-0.1.21}/sglang/lang/tracer.py +0 -0
  38. {sglang-0.1.20 → sglang-0.1.21}/sglang/launch_server.py +0 -0
  39. {sglang-0.1.20 → sglang-0.1.21}/sglang/launch_server_llavavid.py +0 -0
  40. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/__init__.py +0 -0
  41. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/base_cache.py +0 -0
  42. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/fsm_cache.py +0 -0
  43. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/constrained/jump_forward.py +0 -0
  44. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/conversation.py +0 -0
  45. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/flush_cache.py +0 -0
  46. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/hf_transformers_utils.py +0 -0
  47. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/context_flashattention_nopad.py +0 -0
  48. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/extend_attention.py +0 -0
  49. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/fused_moe.py +0 -0
  50. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/logits_processor.py +0 -0
  51. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/radix_attention.py +0 -0
  52. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/layers/token_attention.py +0 -0
  53. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/controller/dp_worker.py +0 -0
  54. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/detokenizer_manager.py +0 -0
  55. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/io_struct.py +0 -0
  56. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/managers/tokenizer_manager.py +0 -0
  57. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/mm_utils.py +0 -0
  58. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/model_config.py +0 -0
  59. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/chatglm.py +0 -0
  60. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/commandr.py +0 -0
  61. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/dbrx.py +0 -0
  62. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/gemma.py +0 -0
  63. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/gemma2.py +0 -0
  64. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/grok.py +0 -0
  65. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llama2.py +0 -0
  66. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llama_classification.py +0 -0
  67. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llava.py +0 -0
  68. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/llavavid.py +0 -0
  69. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/mistral.py +0 -0
  70. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/mixtral.py +0 -0
  71. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/mixtral_quant.py +0 -0
  72. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/qwen.py +0 -0
  73. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/qwen2.py +0 -0
  74. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/stablelm.py +0 -0
  75. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/models/yivl.py +0 -0
  76. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/openai_api_adapter.py +0 -0
  77. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/openai_protocol.py +0 -0
  78. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/sampling_params.py +0 -0
  79. {sglang-0.1.20 → sglang-0.1.21}/sglang/srt/utils.py +1 -1
  80. {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_conversation.py +0 -0
  81. {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_openai_protocol.py +0 -0
  82. {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_programs.py +0 -0
  83. {sglang-0.1.20 → sglang-0.1.21}/sglang/test/test_utils.py +0 -0
  84. {sglang-0.1.20 → sglang-0.1.21}/sglang/utils.py +0 -0
  85. {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/SOURCES.txt +0 -0
  86. {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/dependency_links.txt +0 -0
  87. {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/requires.txt +0 -0
  88. {sglang-0.1.20 → sglang-0.1.21}/sglang.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: sglang
- Version: 0.1.20
+ Version: 0.1.21
  Summary: A structured generation langauge for LLMs.
  License: Apache License
  Version 2.0, January 2004
@@ -623,6 +623,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
  ```
  - See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
+ - Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port.
+ ```
+ # Node 0
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
+
+ # Node 1
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
+ ```
 
  ### Supported Models
  - Llama
@@ -377,6 +377,14 @@ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port
  python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --port 30000 --mem-fraction-static 0.7
  ```
  - See [hyperparameter_tuning.md](docs/hyperparameter_tuning.md) on tuning hyperparameters for better performance.
+ - Add `--nnodes 2` to run tensor parallelism on multiple nodes. If you have two nodes with two GPUs on each node and want to run TP=4, let `sgl-dev-1` be the hostname of the first node and `50000` be an available port.
+ ```
+ # Node 0
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 0
+
+ # Node 1
+ python -m sglang.launch_server --model-path meta-llama/Llama-2-7b-chat-hf --tp 4 --nccl-init sgl-dev-1:50000 --nnodes 2 --node-rank 1
+ ```
 
  ### Supported Models
  - Llama
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
  [project]
  name = "sglang"
- version = "0.1.20"
+ version = "0.1.21"
  description = "A structured generation langauge for LLMs."
  readme = "README.md"
  requires-python = ">=3.8"
@@ -1,4 +1,4 @@
- __version__ = "0.1.20"
+ __version__ = "0.1.21"
 
  # SGL API Components
  from sglang.api import (
@@ -12,7 +12,6 @@ from sglang.utils import http_request
 
 
  class RuntimeEndpoint(BaseBackend):
-
  def __init__(
  self,
  base_url: str,
@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
  self.model_info = res.json()
 
  self.chat_template = get_chat_template_by_model_path(
- self.model_info["model_path"])
+ self.model_info["model_path"]
+ )
 
  def get_model_name(self):
  return self.model_info["model_path"]
@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
  else:
  raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
- for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+ for item in [
+ "return_logprob",
+ "logprob_start_len",
+ "top_logprobs_num",
+ "return_text_in_logprobs",
+ ]:
  value = getattr(sampling_params, item, None)
  if value is not None:
  data[item] = value
@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
  else:
  raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
 
- for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
+ for item in [
+ "return_logprob",
+ "logprob_start_len",
+ "top_logprobs_num",
+ "return_text_in_logprobs",
+ ]:
  value = getattr(sampling_params, item, None)
  if value is not None:
  data[item] = value
@@ -32,7 +32,6 @@ import logging
  import multiprocessing
  import time
 
-
  import numpy as np
  import torch
  import torch.distributed as dist
@@ -25,7 +25,8 @@ class GlobalConfig:
  # This can improve the speed for large batch sizes during prefill.
  self.layer_sync_threshold = 8192
 
- # Runtime constants: Flashinfer
+ # Runtime constants: others
+ self.num_continue_decode_steps = 10
  self.flashinfer_workspace_size = 192 * 1024 * 1024
 
  # Output tokenization configs
@@ -44,4 +45,5 @@ class GlobalConfig:
  # adjust_cache: Adjust the position embedding of KV cache.
  self.concate_and_append_mode = "no_adjust"
 
+
  global_config = GlobalConfig()
@@ -84,7 +84,7 @@ register_chat_template(
  "system": ("SYSTEM:", "\n"),
  "user": ("USER:", "\n"),
  "assistant": ("ASSISTANT:", "\n"),
- }
+ },
  )
  )
 
@@ -177,7 +177,7 @@ register_chat_template(
  "assistant": ("", "<|im_end|>\n"),
  },
  style=ChatTemplateStyle.PLAIN,
- stop_str=("<|im_end|>",)
+ stop_str=("<|im_end|>",),
  )
  )
 
@@ -24,9 +24,9 @@ class SglSamplingParams:
  presence_penalty: float = 0.0
  ignore_eos: bool = False
  return_logprob: Optional[bool] = None
- logprob_start_len: Optional[int] = None,
- top_logprobs_num: Optional[int] = None,
- return_text_in_logprobs: Optional[bool] = None,
+ logprob_start_len: Optional[int] = (None,)
+ top_logprobs_num: Optional[int] = (None,)
+ return_text_in_logprobs: Optional[bool] = (None,)
 
  # for constrained generation, not included in to_xxx_kwargs
  dtype: Optional[str] = None
@@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture
  from sglang.global_config import global_config
  from sglang.srt.layers.logits_processor import LogitProcessorOutput
  from sglang.srt.managers.controller.infer_batch import (
- Batch, ForwardMode, InputMetadata, init_flashinfer_args
+ Batch,
+ ForwardMode,
+ InputMetadata,
+ init_flashinfer_args,
  )
 
 
@@ -24,18 +27,28 @@ class CudaGraphRunner:
  # Common inputs
  self.max_bs = max_batch_size_to_capture
  self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
- self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+ self.req_pool_indices = torch.zeros(
+ (self.max_bs,), dtype=torch.int32, device="cuda"
+ )
  self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
- self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
- self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
+ self.position_ids_offsets = torch.zeros(
+ (self.max_bs,), dtype=torch.int32, device="cuda"
+ )
+ self.out_cache_loc = torch.zeros(
+ (self.max_bs,), dtype=torch.int32, device="cuda"
+ )
 
  # FlashInfer inputs
- self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
+ self.flashinfer_workspace_buffer = (
+ self.model_runner.flashinfer_workspace_buffers[0]
+ )
  self.flashinfer_kv_indptr = torch.zeros(
  (self.max_bs + 1,), dtype=torch.int32, device="cuda"
  )
  self.flashinfer_kv_indices = torch.zeros(
- (self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
+ (self.max_bs * model_runner.model_config.context_len,),
+ dtype=torch.int32,
+ device="cuda",
  )
  self.flashinfer_kv_last_page_len = torch.ones(
  (self.max_bs,), dtype=torch.int32, device="cuda"
@@ -49,7 +62,12 @@ class CudaGraphRunner:
  with graph_capture() as graph_capture_context:
  self.stream = graph_capture_context.stream
  for bs in batch_size_list:
- graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
+ (
+ graph,
+ input_buffers,
+ output_buffers,
+ flashinfer_handler,
+ ) = self.capture_one_batch_size(bs)
  self.graphs[bs] = graph
  self.input_buffers[bs] = input_buffers
  self.output_buffers[bs] = output_buffers
@@ -71,17 +89,19 @@ class CudaGraphRunner:
 
  # FlashInfer inputs
  if not _grouped_size_compiled_for_decode_kernels(
- self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
+ self.model_runner.model_config.num_attention_heads
+ // self.model_runner.tp_size,
  self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
  ):
  use_tensor_cores = True
  else:
  use_tensor_cores = False
  flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
- self.flashinfer_workspace_buffer, "NHD",
+ self.flashinfer_workspace_buffer,
+ "NHD",
  use_cuda_graph=True,
  use_tensor_cores=use_tensor_cores,
- paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
+ paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
  paged_kv_indices_buffer=self.flashinfer_kv_indices,
  paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
  )
@@ -163,10 +183,14 @@ class CudaGraphRunner:
  else:
  output = LogitProcessorOutput(
  next_token_logits=output.next_token_logits[:raw_bs],
- next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
+ next_token_logprobs=output.next_token_logprobs[:raw_bs]
+ if output.next_token_logprobs is not None
+ else None,
  normalized_prompt_logprobs=None,
  prefill_token_logprobs=None,
  prefill_top_logprobs=None,
- decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
+ decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
+ if output.decode_top_logprobs is not None
+ else None,
  )
  return output
@@ -174,9 +174,6 @@ class Req:
 
  return False, ""
 
- def max_new_tokens(self):
- return self.sampling_params.max_new_tokens
-
  def check_finished(self):
  if self.finished():
  return
@@ -352,7 +349,7 @@ class Batch:
  extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
  out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
  if out_cache_loc is None:
- self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+ self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
  out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
 
  if out_cache_loc is None:
@@ -422,7 +419,7 @@ class Batch:
  if self.token_to_kv_pool.available_size() >= bs:
  return True
 
- self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+ self.tree_cache.evict(bs, self.token_to_kv_pool.free)
 
  if self.token_to_kv_pool.available_size() >= bs:
  return True
@@ -453,7 +450,7 @@ class Batch:
  token_indices = self.req_to_token_pool.req_to_token[
  req_pool_indices_cpu[idx]
  ][last_uncached_pos : seq_lens_cpu[idx]]
- self.token_to_kv_pool.dec_refs(token_indices)
+ self.token_to_kv_pool.free(token_indices)
 
  # release the last node
  self.tree_cache.dec_lock_ref(req.last_node)
@@ -596,8 +593,7 @@ class Batch:
  "logit_bias",
  ]:
  self_val = getattr(self, item, None)
- # logit_bias can be None
- if self_val is not None:
+ if self_val is not None: # logit_bias can be None
  setattr(self, item, self_val[new_indices])
 
  def merge(self, other: "Batch"):
@@ -668,7 +664,9 @@ class Batch:
  sampled_index = torch.multinomial(probs_sort, num_samples=1)
  except RuntimeError as e:
  warnings.warn(f"Ignore errors in sampling: {e}")
- sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
+ sampled_index = torch.ones(
+ probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
+ )
  batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
  -1
  )
@@ -749,8 +747,14 @@ class InputMetadata:
  skip_flashinfer_init=False,
  ):
  if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
- init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
- model_runner.flashinfer_decode_wrapper)
+ init_flashinfer_args(
+ forward_mode,
+ model_runner,
+ req_pool_indices,
+ seq_lens,
+ prefix_lens,
+ model_runner.flashinfer_decode_wrapper,
+ )
 
  batch_size = len(req_pool_indices)
 
@@ -807,16 +811,24 @@ class InputMetadata:
  )
 
  if model_runner.server_args.disable_flashinfer:
- (ret.triton_max_seq_len,
- ret.triton_max_extend_len,
- ret.triton_start_loc,
- ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
+ (
+ ret.triton_max_seq_len,
+ ret.triton_max_extend_len,
+ ret.triton_start_loc,
+ ret.triton_prefix_lens,
+ ) = init_triton_args(forward_mode, seq_lens, prefix_lens)
 
  return ret
 
 
- def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
- flashinfer_decode_wrapper):
+ def init_flashinfer_args(
+ forward_mode,
+ model_runner,
+ req_pool_indices,
+ seq_lens,
+ prefix_lens,
+ flashinfer_decode_wrapper,
+ ):
  num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
  num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
  head_dim = model_runner.model_config.head_dim
@@ -827,9 +839,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
  else:
  paged_kernel_lens = prefix_lens
 
- kv_indptr = torch.zeros(
- (batch_size + 1,), dtype=torch.int32, device="cuda"
- )
+ kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
  kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
  req_pool_indices_cpu = req_pool_indices.cpu().numpy()
  paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
@@ -842,9 +852,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
  ],
  dim=0,
  ).contiguous()
- kv_last_page_len = torch.ones(
- (batch_size,), dtype=torch.int32, device="cuda"
- )
+ kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
 
  if forward_mode == ForwardMode.DECODE:
  flashinfer_decode_wrapper.end_forward()
@@ -859,9 +867,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
  )
  else:
  # extend part
- qo_indptr = torch.zeros(
- (batch_size + 1,), dtype=torch.int32, device="cuda"
- )
+ qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
  qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
 
  model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
@@ -42,6 +42,8 @@ class LoadBalanceMethod(Enum):
 
 
  class Controller:
+ """A controller that manages multiple data parallel workers."""
+
  def __init__(
  self,
  load_balance_method: str,
@@ -183,9 +185,11 @@ def start_controller_process(
  except Exception:
  pipe_writer.send(get_exception_traceback())
  raise
-
  pipe_writer.send("init ok")
- loop = asyncio.get_event_loop()
+
+ loop = asyncio.new_event_loop()
+ loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
+
  asyncio.set_event_loop(loop)
  loop.create_task(controller.loop_for_recv_requests())
  loop.run_until_complete(controller.loop_for_forward())
@@ -0,0 +1,177 @@
+ """A controller that manages a group of tensor parallel workers."""
+
+ import multiprocessing
+ import logging
+ import os
+ import pickle
+
+ import torch
+ import torch.distributed as dist
+ import zmq
+ import zmq.asyncio
+
+ from sglang.srt.managers.controller.tp_worker import ModelTpServer
+ from sglang.srt.server_args import PortArgs, ServerArgs, ModelPortArgs
+ from sglang.srt.utils import kill_parent_process
+ from sglang.utils import get_exception_traceback
+
+
+ logger = logging.getLogger("srt.controller")
+
+
+ def run_tp_server(
+ gpu_id: int,
+ tp_rank: int,
+ server_args: ServerArgs,
+ model_port_args: ModelPortArgs,
+ model_overide_args: dict,
+ ):
+ """Run a tp server."""
+ try:
+ model_server = ModelTpServer(
+ gpu_id,
+ tp_rank,
+ server_args,
+ model_port_args,
+ model_overide_args,
+ )
+ tp_cpu_group = model_server.model_runner.tp_group.cpu_group
+
+ while True:
+ recv_reqs = broadcast_recv_input(None, tp_rank, tp_cpu_group)
+ model_server.exposed_step(recv_reqs)
+ except Exception:
+ logger.error("Exception in run_tp_server:\n" + get_exception_traceback())
+ raise
+
+
+ def launch_tp_servers(gpu_ids, tp_rank_range, server_args,
+ model_port_args, model_overide_args):
+ """Launch multiple tp servers."""
+ procs = []
+ for i in tp_rank_range:
+ proc = multiprocessing.Process(target=run_tp_server, args=(
+ gpu_ids[i], i, server_args, model_port_args, model_overide_args
+ ))
+ proc.start()
+ procs.append(proc)
+
+ return procs
+
+
+ def broadcast_recv_input(data, rank, dist_group):
+ """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
+
+ if rank == 0:
+ if len(data) == 0:
+ tensor_size = torch.tensor([0], dtype=torch.long)
+ dist.broadcast(tensor_size, src=0, group=dist_group)
+ else:
+ serialized_data = pickle.dumps(data)
+ size = len(serialized_data)
+ tensor_data = torch.ByteTensor(list(serialized_data))
+ tensor_size = torch.tensor([size], dtype=torch.long)
+
+ dist.broadcast(tensor_size, src=0, group=dist_group)
+ dist.broadcast(tensor_data, src=0, group=dist_group)
+ else:
+ tensor_size = torch.tensor([0], dtype=torch.long)
+ dist.broadcast(tensor_size, src=0, group=dist_group)
+ size = tensor_size.item()
+
+ if size == 0:
+ return []
+
+ tensor_data = torch.empty(size, dtype=torch.uint8)
+ dist.broadcast(tensor_data, src=0, group=dist_group)
+
+ serialized_data = bytes(tensor_data.tolist())
+ data = pickle.loads(serialized_data)
+ return data
+
+
+ class ControllerSingle:
+ """A controller that manages a group of tensor parallel workers."""
+
+ def __init__(self, server_args: ServerArgs, port_args: PortArgs, model_overide_args: dict):
+ # Parse args
+ self.server_args = server_args
+
+ # Init communication
+ context = zmq.Context(2)
+ self.recv_from_tokenizer = context.socket(zmq.PULL)
+ self.recv_from_tokenizer.bind(f"tcp://127.0.0.1:{port_args.router_port}")
+
+ self.send_to_detokenizer = context.socket(zmq.PUSH)
+ self.send_to_detokenizer.connect(
+ f"tcp://127.0.0.1:{port_args.detokenizer_port}"
+ )
+
+ # Init model server
+ tp_size_local = server_args.tp_size // server_args.nnodes
+ gpu_ids = [i for _ in range(server_args.nnodes) for i in range(tp_size_local)]
+
+ # Launch other tp ranks
+ if tp_size_local > 1:
+ tp_rank_range = range(1, tp_size_local)
+ self.tp_procs = launch_tp_servers(
+ gpu_ids, tp_rank_range, server_args,
+ port_args.model_port_args[0], model_overide_args)
+
+ # Launch tp rank 0
+ self.tp_server = ModelTpServer(
+ gpu_ids[0],
+ 0,
+ server_args,
+ port_args.model_port_args[0],
+ model_overide_args,
+ )
+ self.tp_cpu_group = self.tp_server.model_runner.tp_group.cpu_group
+
+ def loop_for_forward(self):
+ while True:
+ recv_reqs = self.recv_requests()
+
+ if self.server_args.tp_size > 1:
+ broadcast_recv_input(recv_reqs, 0, self.tp_cpu_group)
+
+ out_pyobjs = self.tp_server.exposed_step(recv_reqs)
+
+ for obj in out_pyobjs:
+ self.send_to_detokenizer.send_pyobj(obj)
+
+ def recv_requests(self):
+ recv_reqs = []
+ while True:
+ try:
+ recv_req = self.recv_from_tokenizer.recv_pyobj(zmq.NOBLOCK)
+ recv_reqs.append(recv_req)
+ except zmq.ZMQError:
+ break
+ return recv_reqs
+
+
+ def start_controller_process(
+ server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args: dict
+ ):
+ logging.basicConfig(
+ level=getattr(logging, server_args.log_level.upper()),
+ format="%(message)s",
+ )
+
+ try:
+ controller = ControllerSingle(server_args, port_args, model_overide_args)
+ except Exception:
+ pipe_writer.send(get_exception_traceback())
+ raise
+
+ pipe_writer.send("init ok")
+
+ try:
+ controller.loop_for_forward()
+ except Exception:
+ logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+ finally:
+ for t in controller.tp_procs:
+ os.kill(t.pid, 9)
+ kill_parent_process()
@@ -11,12 +11,17 @@ import torch
  import torch.nn as nn
  from vllm.config import DeviceConfig, LoadConfig
  from vllm.config import ModelConfig as VllmModelConfig
- from vllm.distributed import init_distributed_environment, initialize_model_parallel
+ from vllm.distributed import init_distributed_environment, initialize_model_parallel, get_tp_group
  from vllm.model_executor.model_loader import get_model
  from vllm.model_executor.models import ModelRegistry
 
  from sglang.global_config import global_config
- from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
+ from sglang.srt.managers.controller.infer_batch import (
+ Batch,
+ ForwardMode,
+ InputMetadata,
+ global_server_args_dict,
+ )
  from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.utils import (
@@ -70,6 +75,7 @@ class ModelRunner:
  distributed_init_method=nccl_init_method,
  )
  initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+ self.tp_group = get_tp_group()
  total_gpu_memory = get_available_gpu_memory(
  self.gpu_id, distributed=self.tp_size > 1
  )
@@ -83,7 +89,9 @@ class ModelRunner:
 
  # Set some global args
  global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
- global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
+ global_server_args_dict[
+ "attention_reduce_in_fp32"
+ ] = server_args.attention_reduce_in_fp32
 
  # Load the model and create memory pool
  self.load_model()
@@ -217,7 +225,9 @@ class ModelRunner:
  self.flashinfer_workspace_buffers[1], "NHD"
  )
  self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
- self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+ self.flashinfer_workspace_buffers[0],
+ "NHD",
+ use_tensor_cores=use_tensor_cores,
  )
 
  def init_cuda_graphs(self):
@@ -229,7 +239,9 @@ class ModelRunner:
 
  logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
  batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
- self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
+ self.cuda_graph_runner = CudaGraphRunner(
+ self, max_batch_size_to_capture=max(batch_size_list)
+ )
  self.cuda_graph_runner.capture(batch_size_list)
 
  @torch.inference_mode()