sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py (new file)
@@ -0,0 +1,462 @@
+ import importlib
+ import importlib.resources
+ import logging
+ import pkgutil
+ from dataclasses import dataclass
+ from functools import lru_cache
+ from typing import List, Optional, Type
+
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from vllm.config import DeviceConfig, LoadConfig
+ from vllm.config import ModelConfig as VllmModelConfig
+ from vllm.distributed import initialize_model_parallel, init_distributed_environment
+ from vllm.model_executor.model_loader import get_model
+ from vllm.model_executor.models import ModelRegistry
+
+ from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
+ from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
+ from sglang.srt.server_args import ServerArgs
+ from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
+
+
+ logger = logging.getLogger("srt.model_runner")
+
+ # for server args in model endpoints
+ global_server_args_dict = {}
+
+
+ @dataclass
+ class InputMetadata:
+     model_runner: "ModelRunner"
+     forward_mode: ForwardMode
+     batch_size: int
+     total_num_tokens: int
+     max_seq_len: int
+     req_pool_indices: torch.Tensor
+     start_loc: torch.Tensor
+     seq_lens: torch.Tensor
+     prefix_lens: torch.Tensor
+     positions: torch.Tensor
+     req_to_token_pool: ReqToTokenPool
+     token_to_kv_pool: TokenToKVPool
+
+     # for extend
+     extend_seq_lens: torch.Tensor = None
+     extend_start_loc: torch.Tensor = None
+     max_extend_len: int = 0
+
+     out_cache_loc: torch.Tensor = None
+     out_cache_cont_start: torch.Tensor = None
+     out_cache_cont_end: torch.Tensor = None
+
+     other_kv_index: torch.Tensor = None
+     return_logprob: bool = False
+     top_logprobs_nums: List[int] = None
+
+     # for flashinfer
+     qo_indptr: torch.Tensor = None
+     kv_indptr: torch.Tensor = None
+     kv_indices: torch.Tensor = None
+     kv_last_page_len: torch.Tensor = None
+     prefill_wrapper = None
+     decode_wrapper = None
+
+     def init_flashinfer_args(self, tp_size):
+         from flashinfer import (
+             BatchDecodeWithPagedKVCacheWrapper,
+             BatchPrefillWithPagedKVCacheWrapper,
+         )
+
+         self.kv_indptr = torch.zeros(
+             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+         )
+         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+         self.kv_last_page_len = torch.ones(
+             (self.batch_size,), dtype=torch.int32, device="cuda"
+         )
+         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+         seq_lens_cpu = self.seq_lens.cpu().numpy()
+         self.kv_indices = torch.cat(
+             [
+                 self.req_to_token_pool.req_to_token[
+                     req_pool_indices_cpu[i], : seq_lens_cpu[i]
+                 ]
+                 for i in range(self.batch_size)
+             ],
+             dim=0,
+         ).contiguous()
+
+         workspace_buffer = torch.empty(
+             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
+         )
+         if (
+             self.forward_mode == ForwardMode.PREFILL
+             or self.forward_mode == ForwardMode.EXTEND
+         ):
+             self.qo_indptr = torch.zeros(
+                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
+             )
+             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
+             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
+                 workspace_buffer, "NHD"
+             )
+             args = [
+                 self.qo_indptr,
+                 self.kv_indptr,
+                 self.kv_indices,
+                 self.kv_last_page_len,
+                 self.model_runner.model_config.num_attention_heads // tp_size,
+                 self.model_runner.model_config.num_key_value_heads // tp_size,
+                 self.model_runner.model_config.head_dim,
+             ]
+
+             self.prefill_wrapper.begin_forward(*args)
+         else:
+             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                 workspace_buffer, "NHD"
+             )
+             self.decode_wrapper.begin_forward(
+                 self.kv_indptr,
+                 self.kv_indices,
+                 self.kv_last_page_len,
+                 self.model_runner.model_config.num_attention_heads // tp_size,
+                 self.model_runner.model_config.num_key_value_heads // tp_size,
+                 self.model_runner.model_config.head_dim,
+                 1,
+                 "NONE",
+                 "float16",
+             )
+
+     def init_extend_args(self):
+         self.extend_seq_lens = self.seq_lens - self.prefix_lens
+         self.extend_start_loc = torch.zeros_like(self.seq_lens)
+         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
+         self.max_extend_len = int(torch.max(self.extend_seq_lens))
+
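Editorial reading aid, not part of the diff: init_flashinfer_args above builds a paged KV layout with a page size of 1, so kv_indptr is a prefix sum of the sequence lengths, kv_last_page_len is all ones, and kv_indices concatenates each request's KV slot ids from req_to_token. A minimal sketch of that layout, with made-up sequence lengths:

# Illustrative only; the sequence lengths are invented, not taken from the diff.
import torch

seq_lens = torch.tensor([3, 5], dtype=torch.int32)    # two requests in the batch
kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)          # -> [0, 3, 8]
kv_last_page_len = torch.ones(2, dtype=torch.int32)    # page size 1: every last page holds one token
# kv_indices would then hold 8 slot ids: request 0's 3 slots followed by request 1's 5 slots.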
+     @classmethod
+     def create(
+         cls,
+         model_runner,
+         tp_size,
+         forward_mode,
+         req_pool_indices,
+         seq_lens,
+         prefix_lens,
+         position_ids_offsets,
+         out_cache_loc,
+         out_cache_cont_start=None,
+         out_cache_cont_end=None,
+         top_logprobs_nums=None,
+         return_logprob=False,
+     ):
+         batch_size = len(req_pool_indices)
+         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
+         start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
+         total_num_tokens = int(torch.sum(seq_lens))
+         max_seq_len = int(torch.max(seq_lens))
+
+         if forward_mode == ForwardMode.DECODE:
+             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
+             other_kv_index = model_runner.req_to_token_pool.req_to_token[
+                 req_pool_indices[0], seq_lens[0] - 1
+             ].item()
+         else:
+             seq_lens_cpu = seq_lens.cpu().numpy()
+             prefix_lens_cpu = prefix_lens.cpu().numpy()
+             position_ids_offsets_cpu = position_ids_offsets.cpu().numpy()
+             positions = torch.tensor(
+                 np.concatenate(
+                     [
+                         np.arange(
+                             prefix_lens_cpu[i] + position_ids_offsets_cpu[i],
+                             seq_lens_cpu[i] + position_ids_offsets_cpu[i],
+                         )
+                         for i in range(batch_size)
+                     ],
+                     axis=0,
+                 ),
+                 device="cuda",
+             )
+             other_kv_index = None
+
+         ret = cls(
+             model_runner=model_runner,
+             forward_mode=forward_mode,
+             batch_size=batch_size,
+             total_num_tokens=total_num_tokens,
+             max_seq_len=max_seq_len,
+             req_pool_indices=req_pool_indices,
+             start_loc=start_loc,
+             seq_lens=seq_lens,
+             prefix_lens=prefix_lens,
+             positions=positions,
+             req_to_token_pool=model_runner.req_to_token_pool,
+             token_to_kv_pool=model_runner.token_to_kv_pool,
+             out_cache_loc=out_cache_loc,
+             out_cache_cont_start=out_cache_cont_start,
+             out_cache_cont_end=out_cache_cont_end,
+             other_kv_index=other_kv_index,
+             return_logprob=return_logprob,
+             top_logprobs_nums=top_logprobs_nums,
+         )
+
+         if forward_mode == ForwardMode.EXTEND:
+             ret.init_extend_args()
+
+         if global_server_args_dict.get("enable_flashinfer", False):
+             ret.init_flashinfer_args(tp_size)
+
+         return ret
+
+
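Editorial reading aid, not part of the diff: create() computes token positions differently per mode. For DECODE each request contributes a single position, seq_len - 1 plus its offset; for PREFILL/EXTEND it enumerates positions from prefix_len to seq_len so only the non-cached suffix of each request is run. A tiny illustration with invented lengths:

# Illustrative only; the lengths are invented, not taken from the diff.
import numpy as np

prefix_lens = np.array([2, 0])
seq_lens = np.array([5, 3])
extend_positions = np.concatenate(
    [np.arange(prefix_lens[i], seq_lens[i]) for i in range(2)]
)                                   # -> [2, 3, 4, 0, 1, 2]
decode_positions = seq_lens - 1     # -> [4, 2], one new token per request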
+ class ModelRunner:
+     def __init__(
+         self,
+         model_config,
+         mem_fraction_static: float,
+         gpu_id: int,
+         tp_rank: int,
+         tp_size: int,
+         nccl_port: int,
+         server_args: ServerArgs,
+     ):
+         self.model_config = model_config
+         self.mem_fraction_static = mem_fraction_static
+         self.gpu_id = gpu_id
+         self.tp_rank = tp_rank
+         self.tp_size = tp_size
+         self.nccl_port = nccl_port
+         self.server_args = server_args
+
+         global global_server_args_dict
+         global_server_args_dict = {
+             "enable_flashinfer": server_args.enable_flashinfer,
+             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+         }
+
+         # Init torch distributed
+         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
+         torch.cuda.set_device(self.gpu_id)
+         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
+         monkey_patch_vllm_p2p_access_check()
+         init_distributed_environment(
+             backend="nccl",
+             world_size=self.tp_size,
+             rank=self.tp_rank,
+             local_rank=self.gpu_id,
+             distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+         )
+         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+         total_gpu_memory = get_available_gpu_memory(
+             self.gpu_id, distributed=self.tp_size > 1
+         )
+
+         if self.tp_size > 1:
+             total_local_gpu_memory = get_available_gpu_memory(self.gpu_id)
+             if total_local_gpu_memory < total_gpu_memory * 0.9:
+                 raise ValueError(
+                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                 )
+
+         self.load_model()
+         self.init_memory_pool(total_gpu_memory)
+         self.is_multimodal_model = is_multimodal_model(self.model_config)
+
+     def load_model(self):
+         logger.info(
+             f"[gpu_id={self.gpu_id}] Load weight begin. "
+             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+         )
+
+         device_config = DeviceConfig()
+         load_config = LoadConfig(load_format=self.server_args.load_format)
+         vllm_model_config = VllmModelConfig(
+             model=self.server_args.model_path,
+             quantization=self.server_args.quantization,
+             tokenizer=None,
+             tokenizer_mode=None,
+             trust_remote_code=self.server_args.trust_remote_code,
+             dtype=torch.float16,
+             seed=42,
+             skip_tokenizer_init=True,
+         )
+         if self.model_config.model_overide_args is not None:
+             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+         self.model = get_model(
+             model_config=vllm_model_config,
+             device_config=device_config,
+             load_config=load_config,
+             lora_config=None,
+             vision_language_config=None,
+             parallel_config=None,
+             scheduler_config=None,
+             cache_config=None,
+         )
+         logger.info(
+             f"[gpu_id={self.gpu_id}] Load weight end. "
+             f"type={type(self.model).__name__}, "
+             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+         )
+
+     def profile_max_num_token(self, total_gpu_memory):
+         available_gpu_memory = get_available_gpu_memory(
+             self.gpu_id, distributed=self.tp_size > 1
+         )
+         head_dim = self.model_config.head_dim
+         head_num = self.model_config.num_key_value_heads // self.tp_size
+         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+         rest_memory = available_gpu_memory - total_gpu_memory * (
+             1 - self.mem_fraction_static
+         )
+         max_num_token = int(rest_memory * (1 << 30) // cell_size)
+         return max_num_token
+
+     def init_memory_pool(self, total_gpu_memory):
+         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
+
+         if self.max_total_num_tokens <= 0:
+             raise RuntimeError(
+                 "Not enough memory. Please try to increase --mem-fraction-static."
+             )
+
+         self.req_to_token_pool = ReqToTokenPool(
+             int(self.max_total_num_tokens / self.model_config.context_len * 256),
+             self.model_config.context_len + 8,
+         )
+         self.token_to_kv_pool = TokenToKVPool(
+             self.max_total_num_tokens,
+             dtype=torch.float16,
+             head_num=self.model_config.num_key_value_heads // self.tp_size,
+             head_dim=self.model_config.head_dim,
+             layer_num=self.model_config.num_hidden_layers,
+         )
+         logger.info(
+             f"[gpu_id={self.gpu_id}] Memory pool end. "
+             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
+         )
+
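Editorial reading aid, not part of the diff: cell_size in profile_max_num_token is the number of bytes one cached token occupies across all layers on this tensor-parallel rank (per-rank KV heads x head_dim x layers x 2 for K and V x 2 bytes for fp16), and max_num_token is simply the KV budget divided by that. A back-of-the-envelope sketch with assumed model shapes and an assumed memory budget:

# Illustrative arithmetic only; the shapes and memory figures are assumptions,
# not values taken from the diff.
num_kv_heads, head_dim, num_layers, tp_size = 32, 128, 32, 1
cell_size = (num_kv_heads // tp_size) * head_dim * num_layers * 2 * 2  # K+V, fp16
print(cell_size)                      # 524288 bytes, i.e. 0.5 MiB per cached token
rest_memory_gb = 30.0                 # memory left for the KV pool (assumed)
max_num_token = int(rest_memory_gb * (1 << 30) // cell_size)
print(max_num_token)                  # 61440 tokens fit in the pool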
+     @torch.inference_mode()
+     def forward_prefill(self, batch: Batch):
+         input_metadata = InputMetadata.create(
+             self,
+             forward_mode=ForwardMode.PREFILL,
+             tp_size=self.tp_size,
+             req_pool_indices=batch.req_pool_indices,
+             seq_lens=batch.seq_lens,
+             prefix_lens=batch.prefix_lens,
+             position_ids_offsets=batch.position_ids_offsets,
+             out_cache_loc=batch.out_cache_loc,
+             top_logprobs_nums=batch.top_logprobs_nums,
+             return_logprob=batch.return_logprob,
+         )
+         return self.model.forward(
+             batch.input_ids, input_metadata.positions, input_metadata
+         )
+
+     @torch.inference_mode()
+     def forward_extend(self, batch: Batch):
+         input_metadata = InputMetadata.create(
+             self,
+             forward_mode=ForwardMode.EXTEND,
+             tp_size=self.tp_size,
+             req_pool_indices=batch.req_pool_indices,
+             seq_lens=batch.seq_lens,
+             prefix_lens=batch.prefix_lens,
+             position_ids_offsets=batch.position_ids_offsets,
+             out_cache_loc=batch.out_cache_loc,
+             top_logprobs_nums=batch.top_logprobs_nums,
+             return_logprob=batch.return_logprob,
+         )
+         return self.model.forward(
+             batch.input_ids, input_metadata.positions, input_metadata
+         )
+
+     @torch.inference_mode()
+     def forward_decode(self, batch: Batch):
+         input_metadata = InputMetadata.create(
+             self,
+             forward_mode=ForwardMode.DECODE,
+             tp_size=self.tp_size,
+             req_pool_indices=batch.req_pool_indices,
+             seq_lens=batch.seq_lens,
+             prefix_lens=batch.prefix_lens,
+             position_ids_offsets=batch.position_ids_offsets,
+             out_cache_loc=batch.out_cache_loc,
+             out_cache_cont_start=batch.out_cache_cont_start,
+             out_cache_cont_end=batch.out_cache_cont_end,
+             top_logprobs_nums=batch.top_logprobs_nums,
+             return_logprob=batch.return_logprob,
+         )
+         return self.model.forward(
+             batch.input_ids, input_metadata.positions, input_metadata
+         )
+
+     @torch.inference_mode()
+     def forward_extend_multi_modal(self, batch: Batch):
+         input_metadata = InputMetadata.create(
+             self,
+             forward_mode=ForwardMode.EXTEND,
+             tp_size=self.tp_size,
+             req_pool_indices=batch.req_pool_indices,
+             seq_lens=batch.seq_lens,
+             prefix_lens=batch.prefix_lens,
+             position_ids_offsets=batch.position_ids_offsets,
+             out_cache_loc=batch.out_cache_loc,
+             top_logprobs_nums=batch.top_logprobs_nums,
+             return_logprob=batch.return_logprob,
+         )
+         return self.model.forward(
+             batch.input_ids,
+             input_metadata.positions,
+             input_metadata,
+             batch.pixel_values,
+             batch.image_sizes,
+             batch.image_offsets,
+         )
+
+     def forward(self, batch: Batch, forward_mode: ForwardMode):
+         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
+             return self.forward_extend_multi_modal(batch)
+         elif forward_mode == ForwardMode.DECODE:
+             return self.forward_decode(batch)
+         elif forward_mode == ForwardMode.EXTEND:
+             return self.forward_extend(batch)
+         elif forward_mode == ForwardMode.PREFILL:
+             return self.forward_prefill(batch)
+         else:
+             raise ValueError(f"Invalid forward mode: {forward_mode}")
+
+
+ @lru_cache()
+ def import_model_classes():
+     model_arch_name_to_cls = {}
+     package_name = "sglang.srt.models"
+     package = importlib.import_module(package_name)
+     for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+         if not ispkg:
+             module = importlib.import_module(name)
+             if hasattr(module, "EntryClass"):
+                 entry = module.EntryClass
+                 if isinstance(entry, list): # To support multiple model classes in one module
+                     for tmp in entry:
+                         model_arch_name_to_cls[tmp.__name__] = tmp
+                 else:
+                     model_arch_name_to_cls[entry.__name__] = entry
+     return model_arch_name_to_cls
+
+
+ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+     model_arch_name_to_cls = import_model_classes()
+     if model_arch not in model_arch_name_to_cls:
+         raise ValueError(
+             f"Unsupported architectures: {model_arch}. "
+             f"Supported list: {list(model_arch_name_to_cls.keys())}"
+         )
+     return model_arch_name_to_cls[model_arch]
+
+
+ # Monkey patch model loader
+ setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
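Editorial reading aid, not part of the diff: model discovery above works by convention. Each module under sglang.srt.models exposes an EntryClass (a class, or a list of classes), import_model_classes builds an architecture-name -> class table from those, and vLLM's ModelRegistry.load_model_cls is monkey-patched to resolve through that table. A minimal sketch of the convention; the module contents and the architecture name here are invented for illustration:

# Hypothetical sketch of the EntryClass convention; MyModelForCausalLM is invented.
import torch.nn as nn

class MyModelForCausalLM(nn.Module):
    # sglang model classes take (input_ids, positions, input_metadata), per the diff above.
    def forward(self, input_ids, positions, input_metadata):
        ...

# At the bottom of a module placed under sglang/srt/models/:
EntryClass = MyModelForCausalLM

# After the monkey patch, the loader would resolve the class by architecture name, e.g.:
# model_cls = ModelRegistry.load_model_cls("MyModelForCausalLM")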