sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{router/manager.py → controller/manager_single.py}

@@ -1,20 +1,26 @@
+"""A controller that manages a group of tensor parallel workers."""
+
 import asyncio
 import logging
+from concurrent.futures import ThreadPoolExecutor

 import uvloop
 import zmq
 import zmq.asyncio

 from sglang.global_config import global_config
-from sglang.srt.managers.router.model_rpc import ModelRpcClient
+from sglang.srt.managers.controller.tp_worker import ModelTpClient
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback
+from sglang.srt.utils import kill_parent_process
+from sglang.utils import get_exception_traceback

 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

+logger = logging.getLogger("srt.controller")
+

-class RouterManager:
-    def __init__(self, model_client: ModelRpcClient, port_args: PortArgs):
+class ControllerSingle:
+    def __init__(self, model_client: ModelTpClient, port_args: PortArgs):
         # Init communication
         context = zmq.asyncio.Context(2)
         self.recv_from_tokenizer = context.socket(zmq.PULL)
@@ -30,7 +36,7 @@ class RouterManager:
         self.recv_reqs = []

         # Init some configs
-        self.request_dependency_time = global_config.request_dependency_time
+        self.request_dependency_delay = global_config.request_dependency_delay

     async def loop_for_forward(self):
         while True:
@@ -44,14 +50,16 @@
             # async sleep for receiving the subsequent request and avoiding cache miss
             slept = False
             if len(out_pyobjs) != 0:
-                has_finished = any([obj.finished for obj in out_pyobjs])
+                has_finished = any(
+                    [obj.finished_reason is not None for obj in out_pyobjs]
+                )
                 if has_finished:
-                    if self.request_dependency_time > 0:
+                    if self.request_dependency_delay > 0:
                         slept = True
-                        await asyncio.sleep(self.request_dependency_time)
+                        await asyncio.sleep(self.request_dependency_delay)

             if not slept:
-                await asyncio.sleep(0.0006)
+                await asyncio.sleep(global_config.wait_for_new_request_delay)

     async def loop_for_recv_requests(self):
         while True:
@@ -59,7 +67,7 @@
             self.recv_reqs.append(recv_req)


-def start_router_process(
+def start_controller_process(
     server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
@@ -68,8 +76,14 @@ def start_router_process(
     )

     try:
-        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
-        router = RouterManager(model_client, port_args)
+        tp_size_local = server_args.tp_size // server_args.nnodes
+        model_client = ModelTpClient(
+            [i for _ in range(server_args.nnodes) for i in range(tp_size_local)],
+            server_args,
+            port_args.model_port_args[0],
+            model_overide_args,
+        )
+        controller = ControllerSingle(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
@@ -77,6 +91,12 @@ def start_router_process(
     pipe_writer.send("init ok")

     loop = asyncio.new_event_loop()
+    loop.set_default_executor(ThreadPoolExecutor(max_workers=256))
     asyncio.set_event_loop(loop)
-    loop.create_task(router.loop_for_recv_requests())
-    loop.run_until_complete(router.loop_for_forward())
+    loop.create_task(controller.loop_for_recv_requests())
+    try:
+        loop.run_until_complete(controller.loop_for_forward())
+    except Exception:
+        logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
+    finally:
+        kill_parent_process()
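Note: a minimal launch sketch for the renamed entry point. Only the start_controller_process signature and the "init ok" handshake come from the hunk above; the launch_controller helper and the multiprocessing/Pipe wiring are illustrative assumptions, not code shipped in this package.

# Hypothetical sketch: spawn the controller in a child process and wait for its handshake.
import multiprocessing as mp

from sglang.srt.managers.controller.manager_single import start_controller_process
from sglang.srt.server_args import PortArgs, ServerArgs


def launch_controller(server_args: ServerArgs, port_args: PortArgs, model_overide_args=None):
    # start_controller_process reports either "init ok" or a traceback string via pipe_writer.
    pipe_reader, pipe_writer = mp.Pipe(duplex=False)
    proc = mp.Process(
        target=start_controller_process,
        args=(server_args, port_args, pipe_writer, model_overide_args),
    )
    proc.start()
    init_state = pipe_reader.recv()  # blocks until the controller finishes initialization
    if init_state != "init ok":
        proc.terminate()
        raise RuntimeError(f"Controller failed to initialize:\n{init_state}")
    return proc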
sglang/srt/managers/{router → controller}/model_runner.py

@@ -1,69 +1,40 @@
+"""ModelRunner runs the forward passes of the models."""
+
 import importlib
 import importlib.resources
-import inspect
 import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type

 import numpy as np
 import torch
-from vllm.distributed import initialize_model_parallel
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
-
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry
+
+from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model, get_available_gpu_memory
-
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)

-QUANTIZATION_CONFIG_MAPPING = {
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}
-
-logger = logging.getLogger("model_runner")
+logger = logging.getLogger("srt.model_runner")

 # for server args in model endpoints
 global_server_args_dict = {}


-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-        else:
-            raise ValueError(
-                f"Unsupported architectures: {arch}. "
-                f"Supported list: {list(model_arch_name_to_cls.keys())}"
-            )
-    return model_class
-
-
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -94,73 +65,82 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens

         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()

-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
+            # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
             )
-            args = [
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1
             )
-            self.decode_wrapper.begin_forward(
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype
             )

     def init_extend_args(self):
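Aside: a tiny worked example of the CSR-style indptr arrays that init_flashinfer_args builds for an EXTEND batch. The batch shape and lengths are made-up numbers, not values from this diff; the cumsum pattern mirrors the hunk above.

# Illustrative only: batch of 2 requests, prefix lengths 3 and 0, extend lengths 2 and 4.
import torch

prefix_lens = torch.tensor([3, 0], dtype=torch.int32)
extend_seq_lens = torch.tensor([2, 4], dtype=torch.int32)

# Paged (cached-prefix) side: kv_indptr = [0, 3, 3]
kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)

# Ragged (newly extended tokens) side: qo_indptr = [0, 2, 6]
qo_indptr = torch.zeros(3, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)

print(kv_indptr.tolist(), qo_indptr.tolist())  # [0, 3, 3] [0, 2, 6]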
@@ -184,6 +164,9 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -216,7 +199,6 @@ class InputMetadata:
             other_kv_index = None

         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -234,13 +216,20 @@ class InputMetadata:
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )

         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

-        if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim
+            )

         return ret

@@ -249,122 +238,180 @@ class ModelRunner:
     def __init__(
         self,
         model_config,
-        mem_fraction_static,
-        tp_rank,
-        tp_size,
-        nccl_port,
-        load_format="auto",
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        mem_fraction_static: float,
+        gpu_id: int,
+        tp_rank: int,
+        tp_size: int,
+        nccl_port: int,
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
+        self.gpu_id = gpu_id
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.load_format = load_format
-        self.trust_remote_code = trust_remote_code
-
-        global global_server_args_dict
-        global_server_args_dict = server_args_dict
+        self.server_args = server_args
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()

         # Init torch distributed
-        torch.cuda.set_device(self.tp_rank)
-        torch.distributed.init_process_group(
+        logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
+        torch.cuda.set_device(self.gpu_id)
+        logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+        monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
+        init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
-            init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            local_rank=self.gpu_id,
+            distributed_init_method=nccl_init_method
         )
-
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-
         total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+            self.gpu_id, distributed=self.tp_size > 1
+        )
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError(
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                )
+
+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "disable_flashinfer": server_args.disable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_cublas()
+        self.init_flash_infer()

     def load_model(self):
-        """See also vllm/model_executor/model_loader.py::get_model"""
-        # Select model class
-        architectures = getattr(self.model_config.hf_config, "architectures", [])
-        model_class = get_model_cls_by_arch_name(architectures)
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
-
-        # Load weights
-        quant_config = None
-
-        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
-        if quant_cfg is not None:
-            quant_method = quant_cfg.get("quant_method", "").lower()
-            # compat: autogptq >=0.8.0 use checkpoint_format: str
-            # compat: autogptq <=0.7.1 is_marlin_format: bool
-            is_format_marlin = quant_cfg.get(
-                "checkpoint_format"
-            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
-
-            # Use marlin if the GPTQ model is serialized in marlin format.
-            if quant_method == "gptq" and is_format_marlin:
-                quant_method = "marlin"
-
-            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
-
-            if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {quant_method}")
-
-            quant_config = quant_config_class.from_config(quant_cfg)
-            logger.info(f"quant_config: {quant_config}")
-
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-            model.load_weights(
-                self.model_config.path,
-                cache_dir=None,
-                load_format=self.load_format,
-                revision=None,
-            )
-        self.model = model.eval()
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight begin. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=self.server_args.dtype,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        self.dtype = vllm_model_config.dtype
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+            cache_config=None,
+        )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Load weight end. "
+            f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )

     def profile_max_num_token(self, total_gpu_memory):
         available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+            self.gpu_id, distributed=self.tp_size > 1
+        )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

     def init_memory_pool(self, total_gpu_memory):
-        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

-        if self.max_total_num_token <= 0:
+        if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. " "Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_token / self.model_config.context_len * 256),
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_token,
-            dtype=torch.float16,
-            head_num=self.model_config.num_key_value_heads // self.tp_size,
+            self.max_total_num_tokens,
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
+        logger.info(
+            f"[gpu_id={self.gpu_id}] Memory pool end. "
+            f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+        )
+
+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithRaggedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size)):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
+            workspace_buffers = torch.empty(
+                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffers[0], "NHD"
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            )
+        else:
+            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None

     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
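Aside: the memory-pool sizing above boils down to a few lines of arithmetic. The sketch below repeats it with illustrative numbers; the GPU size, model shape, and mem_fraction_static are assumptions, not values from this package.

# Rough arithmetic behind profile_max_num_token (illustrative figures only).
head_num = 32            # KV heads visible to this tensor-parallel rank
head_dim = 128
num_hidden_layers = 32
element_size = 2         # bytes per value for fp16/bf16

# Bytes needed to store one token's K and V across all layers.
cell_size = head_num * head_dim * num_hidden_layers * 2 * element_size  # 524288 bytes

available_gb = 22.0      # free memory measured after loading weights (assumed)
total_gb = 24.0          # total GPU memory (assumed)
mem_fraction_static = 0.8
rest_gb = available_gb - total_gb * (1 - mem_fraction_static)  # 17.2 GB left for KV cache

max_num_tokens = int(rest_gb * (1 << 30) // cell_size)
print(max_num_tokens)    # ~35,000 tokens of KV cache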
@@ -379,6 +426,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -397,6 +447,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -417,6 +470,9 @@ class ModelRunner:
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -435,6 +491,9 @@ class ModelRunner:
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
@@ -456,3 +515,48 @@ class ModelRunner:
             return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(
+                    entry, list
+                ):  # To support multiple model classes in one module
+                    for tmp in entry:
+                        model_arch_name_to_cls[tmp.__name__] = tmp
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
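Aside: a hypothetical model module illustrating the convention that import_model_classes() scans for; MyModelForCausalLM and MyModelLegacyForCausalLM are placeholder names, not models shipped in this package. Only the EntryClass / EntryClassRemapping attribute names come from the hunk above.

# Sketch of a module that would live under sglang/srt/models/ and be auto-discovered.
import torch.nn as nn


class MyModelForCausalLM(nn.Module):
    def __init__(self, config, quant_config=None):
        super().__init__()
        ...  # layers, forward, weight loading, etc. omitted


# Picked up via hasattr(module, "EntryClass"); a list registers several classes at once.
EntryClass = MyModelForCausalLM

# Optional: map an architecture name found in config.json onto the class above.
EntryClassRemapping = [("MyModelLegacyForCausalLM", MyModelForCausalLM)]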