sglang-0.1.14-py3-none-any.whl → sglang-0.1.21-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
@@ -1,514 +0,0 @@
- import importlib
- import logging
- import inspect
- from dataclasses import dataclass
- from functools import lru_cache
- from pathlib import Path
- import importlib.resources
-
- import numpy as np
- import torch
- from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
- from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
- from sglang.srt.utils import is_multimodal_model
- from sglang.utils import get_available_gpu_memory
- from vllm.model_executor.layers.quantization.awq import AWQConfig
- from vllm.model_executor.layers.quantization.gptq import GPTQConfig
- from vllm.model_executor.layers.quantization.marlin import MarlinConfig
- from vllm.model_executor.model_loader import _set_default_torch_dtype
- from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel
-
- import importlib
- import pkgutil
-
- import sglang
-
- QUANTIONCONFIG_MAPPING = {"awq": AWQConfig, "gptq": GPTQConfig, "marlin": MarlinConfig}
-
- logger = logging.getLogger("model_runner")
-
-
- # for server args in model endpoints
- global_server_args_dict: dict = None
-
-
- @lru_cache()
- def import_model_classes():
-     model_arch_name_to_cls = {}
-     package_name = "sglang.srt.models"
-     package = importlib.import_module(package_name)
-     for finder, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + '.'):
-         if not ispkg:
-             module = importlib.import_module(name)
-             if hasattr(module, "EntryClass"):
-                 model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-     return model_arch_name_to_cls
-
-
- def get_model_cls_by_arch_name(model_arch_names):
-     model_arch_name_to_cls = import_model_classes()
-
-     model_class = None
-     for arch in model_arch_names:
-         if arch in model_arch_name_to_cls:
-             model_class = model_arch_name_to_cls[arch]
-             break
-     else:
-         raise ValueError(
-             f"Unsupported architectures: {arch}. "
-             f"Supported list: {list(model_arch_name_to_cls.keys())}"
-         )
-     return model_class
-
-
- @dataclass
- class InputMetadata:
-     model_runner: "ModelRunner"
-     forward_mode: ForwardMode
-     batch_size: int
-     total_num_tokens: int
-     max_seq_len: int
-     req_pool_indices: torch.Tensor
-     start_loc: torch.Tensor
-     seq_lens: torch.Tensor
-     prefix_lens: torch.Tensor
-     positions: torch.Tensor
-     req_to_token_pool: ReqToTokenPool
-     token_to_kv_pool: TokenToKVPool
-
-     # for extend
-     extend_seq_lens: torch.Tensor = None
-     extend_start_loc: torch.Tensor = None
-     max_extend_len: int = 0
-
-     out_cache_loc: torch.Tensor = None
-     out_cache_cont_start: torch.Tensor = None
-     out_cache_cont_end: torch.Tensor = None
-
-     other_kv_index: torch.Tensor = None
-     return_logprob: bool = False
-
-     # for flashinfer
-     qo_indptr: torch.Tensor = None
-     kv_indptr: torch.Tensor = None
-     kv_indices: torch.Tensor = None
-     kv_last_page_len: torch.Tensor = None
-     prefill_wrapper = None
-     decode_wrapper = None
-
-     def init_flashinfer_args(self, tp_size):
-         from flashinfer import (
-             BatchDecodeWithPagedKVCacheWrapper,
-             BatchPrefillWithPagedKVCacheWrapper,
-         )
-
-         self.kv_indptr = torch.zeros(
-             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-         )
-         self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
-         self.kv_indices = torch.cat(
-             [
-                 self.req_to_token_pool.req_to_token[
-                     self.req_pool_indices[i].item(), : self.seq_lens[i].item()
-                 ]
-                 for i in range(self.batch_size)
-             ],
-             dim=0,
-         ).contiguous()
-         self.kv_last_page_len = torch.ones(
-             (self.batch_size,), dtype=torch.int32, device="cuda"
-         )
-
-         workspace_buffer = torch.empty(
-             32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-         )
-         if (
-             self.forward_mode == ForwardMode.PREFILL
-             or self.forward_mode == ForwardMode.EXTEND
-         ):
-             self.qo_indptr = torch.zeros(
-                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
-             )
-             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-             self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                 workspace_buffer, "NHD"
-             )
-             args = [
-                 self.qo_indptr,
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 self.model_runner.model_config.num_attention_heads // tp_size,
-                 self.model_runner.model_config.num_key_value_heads // tp_size,
-             ]
-
-             # flashinfer >= 0.0.3
-             # FIXME: Drop this when flashinfer updates to 0.0.4
-             if len(inspect.signature(self.prefill_wrapper.begin_forward).parameters) == 7:
-                 args.append(self.model_runner.model_config.head_dim)
-
-             self.prefill_wrapper.begin_forward(*args)
-         else:
-             self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                 workspace_buffer, "NHD"
-             )
-             self.decode_wrapper.begin_forward(
-                 self.kv_indptr,
-                 self.kv_indices,
-                 self.kv_last_page_len,
-                 self.model_runner.model_config.num_attention_heads // tp_size,
-                 self.model_runner.model_config.num_key_value_heads // tp_size,
-                 self.model_runner.model_config.head_dim,
-                 1,
-                 "NONE",
-                 "float16",
-             )
-
-     def init_extend_args(self):
-         self.extend_seq_lens = self.seq_lens - self.prefix_lens
-         self.extend_start_loc = torch.zeros_like(self.seq_lens)
-         self.extend_start_loc[1:] = torch.cumsum(self.extend_seq_lens[:-1], dim=0)
-         self.max_extend_len = int(torch.max(self.extend_seq_lens))
-
-     @classmethod
-     def create(
-         cls,
-         model_runner,
-         tp_size,
-         forward_mode,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         out_cache_cont_start=None,
-         out_cache_cont_end=None,
-         return_logprob=False,
-     ):
-         batch_size = len(req_pool_indices)
-         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
-         start_loc[1:] = torch.cumsum(seq_lens[:-1], dim=0)
-         total_num_tokens = int(torch.sum(seq_lens))
-         max_seq_len = int(torch.max(seq_lens))
-
-         if forward_mode == ForwardMode.DECODE:
-             positions = ((seq_lens - 1) + position_ids_offsets).to(torch.int64)
-             other_kv_index = model_runner.req_to_token_pool.req_to_token[
-                 req_pool_indices[0], seq_lens[0] - 1
-             ].item()
-         else:
-             seq_lens_np = seq_lens.cpu().numpy()
-             prefix_lens_np = prefix_lens.cpu().numpy()
-             position_ids_offsets_np = position_ids_offsets.cpu().numpy()
-             positions = torch.tensor(
-                 np.concatenate(
-                     [
-                         np.arange(
-                             prefix_lens_np[i] + position_ids_offsets_np[i],
-                             seq_lens_np[i] + position_ids_offsets_np[i],
-                         )
-                         for i in range(batch_size)
-                     ],
-                     axis=0,
-                 ),
-                 device="cuda",
-             )
-             other_kv_index = None
-
-         ret = cls(
-             model_runner=model_runner,
-             forward_mode=forward_mode,
-             batch_size=batch_size,
-             total_num_tokens=total_num_tokens,
-             max_seq_len=max_seq_len,
-             req_pool_indices=req_pool_indices,
-             start_loc=start_loc,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             positions=positions,
-             req_to_token_pool=model_runner.req_to_token_pool,
-             token_to_kv_pool=model_runner.token_to_kv_pool,
-             out_cache_loc=out_cache_loc,
-             out_cache_cont_start=out_cache_cont_start,
-             out_cache_cont_end=out_cache_cont_end,
-             return_logprob=return_logprob,
-             other_kv_index=other_kv_index,
-         )
-
-         if forward_mode == ForwardMode.EXTEND:
-             ret.init_extend_args()
-
-         if global_server_args_dict.get("enable_flashinfer", False):
-             ret.init_flashinfer_args(tp_size)
-
-         return ret
-
-
- class ModelRunner:
-     def __init__(
-         self,
-         model_config,
-         mem_fraction_static,
-         tp_rank,
-         tp_size,
-         nccl_port,
-         load_format="auto",
-         trust_remote_code=True,
-         server_args_dict: dict = {},
-     ):
-         self.model_config = model_config
-         self.mem_fraction_static = mem_fraction_static
-         self.tp_rank = tp_rank
-         self.tp_size = tp_size
-         self.nccl_port = nccl_port
-         self.load_format = load_format
-         self.trust_remote_code = trust_remote_code
-
-         global global_server_args_dict
-         global_server_args_dict = server_args_dict
-
-         # Init torch distributed
-         torch.cuda.set_device(self.tp_rank)
-         torch.distributed.init_process_group(
-             backend="nccl",
-             world_size=self.tp_size,
-             rank=self.tp_rank,
-             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
-         )
-
-         # A small all_reduce for warmup.
-         if self.tp_size > 1:
-             torch.distributed.all_reduce(torch.zeros(1).cuda())
-         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
-
-         total_gpu_memory = get_available_gpu_memory(
-             self.tp_rank, distributed=self.tp_size > 1
-         ) * (1 << 30)
-         self.load_model()
-         self.init_memory_pool(total_gpu_memory)
-
-         self.is_multimodal_model = is_multimodal_model(self.model_config)
-
-     def load_model(self):
-         """See also vllm/model_executor/model_loader.py::get_model"""
-         # Select model class
-         architectures = getattr(self.model_config.hf_config, "architectures", [])
-         model_class = get_model_cls_by_arch_name(architectures)
-         logger.info(f"Rank {self.tp_rank}: load weight begin.")
-
-         # Load weights
-         linear_method = None
-         with _set_default_torch_dtype(torch.float16):
-             with torch.device("cuda"):
-                 hf_quant_config = getattr(
-                     self.model_config.hf_config, "quantization_config", None
-                 )
-                 if hf_quant_config is not None:
-                     hf_quant_method = hf_quant_config["quant_method"]
-
-                     # compat: autogptq uses is_marlin_format within quant config
-                     if (hf_quant_method == "gptq"
-                         and "is_marlin_format" in hf_quant_config
-                         and hf_quant_config["is_marlin_format"]):
-                         hf_quant_method = "marlin"
-                     quant_config_class = QUANTIONCONFIG_MAPPING.get(hf_quant_method)
-
-                     if quant_config_class is None:
-                         raise ValueError(
-                             f"Unsupported quantization method: {hf_quant_config['quant_method']}"
-                         )
-                     quant_config = quant_config_class.from_config(hf_quant_config)
-                     logger.info(f"quant_config: {quant_config}")
-                     linear_method = quant_config.get_linear_method()
-                 model = model_class(
-                     config=self.model_config.hf_config, linear_method=linear_method
-                 )
-             model.load_weights(
-                 self.model_config.path,
-                 cache_dir=None,
-                 load_format=self.load_format,
-                 revision=None,
-             )
-         self.model = model.eval()
-
-         logger.info(f"Rank {self.tp_rank}: load weight end.")
-
-     def profile_max_num_token(self, total_gpu_memory):
-         available_gpu_memory = get_available_gpu_memory(
-             self.tp_rank, distributed=self.tp_size > 1
-         ) * (1 << 30)
-         head_dim = self.model_config.head_dim
-         head_num = self.model_config.num_key_value_heads // self.tp_size
-         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
-         rest_memory = available_gpu_memory - total_gpu_memory * (
-             1 - self.mem_fraction_static
-         )
-         max_num_token = int(rest_memory // cell_size)
-         return max_num_token
-
-     def init_memory_pool(self, total_gpu_memory):
-         self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
-
-         if self.max_total_num_token <= 0:
-             raise RuntimeError(
-                 "Not enought memory. " "Please try to increase --mem-fraction-static."
-             )
-
-         self.req_to_token_pool = ReqToTokenPool(
-             int(self.max_total_num_token / self.model_config.context_len * 256),
-             self.model_config.context_len + 8,
-         )
-         self.token_to_kv_pool = TokenToKVPool(
-             self.max_total_num_token,
-             dtype=torch.float16,
-             head_num=self.model_config.num_key_value_heads // self.tp_size,
-             head_dim=self.model_config.head_dim,
-             layer_num=self.model_config.num_hidden_layers,
-         )
-
-     @torch.inference_mode()
-     def forward_prefill(
-         self,
-         input_ids,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         return_logprob,
-     ):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.PREFILL,
-             tp_size=self.tp_size,
-             req_pool_indices=req_pool_indices,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             position_ids_offsets=position_ids_offsets,
-             out_cache_loc=out_cache_loc,
-             return_logprob=return_logprob,
-         )
-         return self.model.forward(input_ids, input_metadata.positions, input_metadata)
-
-     @torch.inference_mode()
-     def forward_extend(
-         self,
-         input_ids,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         return_logprob,
-     ):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
-             req_pool_indices=req_pool_indices,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             position_ids_offsets=position_ids_offsets,
-             out_cache_loc=out_cache_loc,
-             return_logprob=return_logprob,
-         )
-         return self.model.forward(input_ids, input_metadata.positions, input_metadata)
-
-     @torch.inference_mode()
-     def forward_decode(
-         self,
-         input_ids,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         out_cache_cont_start,
-         out_cache_cont_end,
-         return_logprob,
-     ):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.DECODE,
-             tp_size=self.tp_size,
-             req_pool_indices=req_pool_indices,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             position_ids_offsets=position_ids_offsets,
-             out_cache_loc=out_cache_loc,
-             out_cache_cont_start=out_cache_cont_start,
-             out_cache_cont_end=out_cache_cont_end,
-             return_logprob=return_logprob,
-         )
-         return self.model.forward(input_ids, input_metadata.positions, input_metadata)
-
-     @torch.inference_mode()
-     def forward_extend_multi_modal(
-         self,
-         input_ids,
-         pixel_values,
-         image_sizes,
-         image_offsets,
-         req_pool_indices,
-         seq_lens,
-         prefix_lens,
-         position_ids_offsets,
-         out_cache_loc,
-         return_logprob,
-     ):
-         input_metadata = InputMetadata.create(
-             self,
-             forward_mode=ForwardMode.EXTEND,
-             tp_size=self.tp_size,
-             req_pool_indices=req_pool_indices,
-             seq_lens=seq_lens,
-             prefix_lens=prefix_lens,
-             position_ids_offsets=position_ids_offsets,
-             out_cache_loc=out_cache_loc,
-             return_logprob=return_logprob,
-         )
-         return self.model.forward(
-             input_ids,
-             input_metadata.positions,
-             input_metadata,
-             pixel_values,
-             image_sizes,
-             image_offsets,
-         )
-
-     def forward(self, batch: Batch, forward_mode: ForwardMode, return_logprob=False):
-         if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND:
-             kwargs = {
-                 "input_ids": batch.input_ids,
-                 "pixel_values": batch.pixel_values,
-                 "image_sizes": batch.image_sizes,
-                 "image_offsets": batch.image_offsets,
-                 "req_pool_indices": batch.req_pool_indices,
-                 "seq_lens": batch.seq_lens,
-                 "prefix_lens": batch.prefix_lens,
-                 "position_ids_offsets": batch.position_ids_offsets,
-                 "out_cache_loc": batch.out_cache_loc,
-                 "return_logprob": return_logprob,
-             }
-             return self.forward_extend_multi_modal(**kwargs)
-         else:
-             kwargs = {
-                 "input_ids": batch.input_ids,
-                 "req_pool_indices": batch.req_pool_indices,
-                 "seq_lens": batch.seq_lens,
-                 "prefix_lens": batch.prefix_lens,
-                 "position_ids_offsets": batch.position_ids_offsets,
-                 "out_cache_loc": batch.out_cache_loc,
-                 "return_logprob": return_logprob,
-             }
-
-             if forward_mode == ForwardMode.DECODE:
-                 kwargs["out_cache_cont_start"] = batch.out_cache_cont_start
-                 kwargs["out_cache_cont_end"] = batch.out_cache_cont_end
-                 return self.forward_decode(**kwargs)
-             elif forward_mode == ForwardMode.EXTEND:
-                 return self.forward_extend(**kwargs)
-             elif forward_mode == ForwardMode.PREFILL:
-                 return self.forward_prefill(**kwargs)
-             else:
-                 raise ValueError(f"Invaid forward mode: {forward_mode}")
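Note on the removed flashinfer setup above: init_flashinfer_args lays out the paged KV cache as a ragged index, where kv_indptr is the cumulative sum of the per-request sequence lengths, kv_indices concatenates each request's slot row from req_to_token, and, since the page size is 1, kv_last_page_len is all ones. The snippet below is a minimal, illustrative sketch of that layout only; it is not code from either package version, and the toy seq_lens and req_to_token values are made up.

    import torch

    # Two requests with 3 and 2 cached tokens; req_to_token maps
    # (request, position) -> KV-cache slot, as in the removed code above.
    seq_lens = torch.tensor([3, 2], dtype=torch.int32)
    req_to_token = torch.tensor([[10, 11, 12, 0], [20, 21, 0, 0]], dtype=torch.int32)

    kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)       # -> [0, 3, 5]
    kv_indices = torch.cat(
        [req_to_token[i, : seq_lens[i]] for i in range(len(seq_lens))]
    )                                                    # -> [10, 11, 12, 20, 21]
    kv_last_page_len = torch.ones(len(seq_lens), dtype=torch.int32)  # page size 1

    print(kv_indptr.tolist(), kv_indices.tolist(), kv_last_page_len.tolist())

Request i's cached tokens are therefore kv_indices[kv_indptr[i] : kv_indptr[i + 1]], which is the shape the flashinfer wrappers consume.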
@@ -1,70 +0,0 @@
- import random
- from collections import defaultdict
-
-
- class Scheduler:
-     def __init__(
-         self,
-         schedule_heuristic,
-         max_running_seq,
-         max_prefill_num_token,
-         max_total_num_token,
-         tree_cache,
-     ):
-         self.schedule_heuristic = schedule_heuristic
-         self.max_running_seq = max_running_seq
-         self.max_prefill_num_token = max_prefill_num_token
-         self.max_total_num_token = max_total_num_token
-         self.tree_cache = tree_cache
-
-     def get_priority_queue(self, forward_queue):
-         if self.schedule_heuristic == "lpm":
-             # longest prefix match
-             forward_queue.sort(key=lambda x: -len(x.prefix_indices))
-             return forward_queue
-         elif self.schedule_heuristic == "random":
-             random.shuffle(forward_queue)
-             return forward_queue
-         elif self.schedule_heuristic == "fcfs":
-             return forward_queue
-         elif self.schedule_heuristic == "weight":
-             last_node_to_reqs = defaultdict(list)
-             for req in forward_queue:
-                 last_node_to_reqs[req.last_node].append(req)
-             for node in last_node_to_reqs:
-                 last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))
-
-             node_to_weight = defaultdict(int)
-             self._calc_weight_recursive(
-                 self.tree_cache.root_node, last_node_to_reqs, node_to_weight
-             )
-
-             tmp_queue = []
-             self._get_weight_priority_recursive(
-                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
-             )
-             assert len(tmp_queue) == len(forward_queue)
-             return tmp_queue
-         else:
-             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")
-
-     def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
-         node_to_weight[cur_node] = 1
-         if cur_node in last_node_to_reqs:
-             node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
-         for child in cur_node.children.values():
-             self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
-             node_to_weight[cur_node] += node_to_weight[child]
-
-     def _get_weight_priority_recursive(
-         self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue
-     ):
-         visit_list = [child for child in cur_node.children.values()]
-         visit_list.sort(key=lambda x: -node_to_wight[x])
-         # for node in visit_list:
-         #     print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-         for child in visit_list:
-             self._get_weight_priority_recursive(
-                 child, node_to_wight, last_node_to_reqs, tmp_queue
-             )
-         tmp_queue.extend(last_node_to_reqs[cur_node])
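For context on the removed Scheduler above: the default "lpm" (longest prefix match) heuristic simply orders the waiting queue by descending length of the cached prefix, so requests that reuse the most radix-cache tokens run first. A minimal, illustrative sketch follows; the Req dataclass here is a hypothetical stand-in for the real request objects, which carry prefix_indices of cached tokens.

    from dataclasses import dataclass, field

    @dataclass
    class Req:
        name: str
        prefix_indices: list = field(default_factory=list)

    queue = [Req("a", [1]), Req("b", [1, 2, 3]), Req("c", [])]
    queue.sort(key=lambda r: -len(r.prefix_indices))   # longest cached prefix first
    print([r.name for r in queue])                     # -> ['b', 'a', 'c']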
@@ -1,64 +0,0 @@
- sglang/__init__.py,sha256=Nxa2M7XCh2-e6I7VrCg7OSBL6BvEW3gyRD14ZdykpRM,96
- sglang/api.py,sha256=0-Eh7c41hWKjPXrzzvLFdLAUVkvmPGJGLAsrG9evDTE,4576
- sglang/global_config.py,sha256=PAX7TWeFcq0HBzNUWyCONAOjqIokWqw8vT7I6sBSKTc,797
- sglang/launch_server.py,sha256=jKPZRDN5bUe8Wgz5eoDkqeePhmKa8DLD4DpXQLT5auo,294
- sglang/utils.py,sha256=2dUXLMPz9VhhzbIRQABmfZnVW5yz61F3UVtb6yKyevM,6237
- sglang/backend/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sglang/backend/anthropic.py,sha256=GJ_T1Jg0VOtajgkgczPKt5sjuVYdbAiWd2jXlJRNRmg,1677
- sglang/backend/base_backend.py,sha256=APiMht4WYECLCOGRPCEUF6lX-an1vjVe2dWoMSgymWY,1831
- sglang/backend/openai.py,sha256=nPdA88A5GISJTH88svJdww3qHWIHZcGG2NEn0XjMkLU,9578
- sglang/backend/runtime_endpoint.py,sha256=r7dTazselaudlFx8hqk-PQLYDHZhpbAKjyFF1zLuM_E,8022
- sglang/backend/vertexai.py,sha256=BLfWf_tEgoHY9srCufJM5PLe3tql2j0G6ia7cPykxCM,4713
- sglang/lang/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- sglang/lang/chat_template.py,sha256=MaCF0fvNky0nJC9OvmAeApeHYgM6Lr03mtRhF0lS31U,8000
- sglang/lang/compiler.py,sha256=wNn_UqV6Sxl22mv-PpzFUtRgiFFV-Y4OYpO4LshEoRM,7527
- sglang/lang/interpreter.py,sha256=ahRxuEJZ7b1Tts2Lr7wViWIqL-Z12T3anvgj0XdvMN8,26666
- sglang/lang/ir.py,sha256=8Ap-uEUz6K9eNQTOKtMixePuLwRFHFKcN0Z5Yn44nKk,13320
- sglang/lang/tracer.py,sha256=pFiSNzPSg0l7ZZIlGqJDLCmQALR-wyo2dFgJP73J4_Y,8260
- sglang/srt/backend_config.py,sha256=UIV6kIU2j-Xh0eoezn1aXcYIy0miftHsWFeAZwqpbGE,227
- sglang/srt/conversation.py,sha256=mTstD-SsXG5p_YhWQUPEWU-vzzDMF4RgQ7KmLkOOC7U,15496
- sglang/srt/hf_transformers_utils.py,sha256=soRyYLoCn7GxgxvonufGFkdFBA3eH5i3Izk_wi7p1l0,5285
- sglang/srt/memory_pool.py,sha256=BMoX2wvicj214mV-xvcr_Iv_Je0qs3zTuzXfQVpV8u4,3609
- sglang/srt/mm_utils.py,sha256=OptgAHDX-73Bk4jAdr2BOAJtiEXJNzPrMhaM-dy275c,8889
- sglang/srt/model_config.py,sha256=ned-odjmKBKBhVPo04FEpus9gJsUWxrFLrLxahLwSaw,1328
- sglang/srt/sampling_params.py,sha256=83Fp-4HWThC20TEh139XcIb_erBqfI7KZg5txdRBq7c,2896
- sglang/srt/server.py,sha256=WLXissKuXQI7JFb2V8D47QSF-PPHnW-JZCiQm4YW0xE,24070
- sglang/srt/server_args.py,sha256=bvbi-Rb_JudqztFFfRsuXBYtUsG9hq4zMFt7X97uDhA,8954
- sglang/srt/utils.py,sha256=IEqpmWx_hl4eXn_KoHM0EPXmxeN2wKkgK7H01_t0x5Q,7355
- sglang/srt/constrained/__init__.py,sha256=BPRNDJnWtzYJ13X4urRS5aE6wFuwAVNBA9qeWIHF8rE,1236
- sglang/srt/constrained/base_cache.py,sha256=QQjmFEiT8jlOskJoZobhrDl2TKB-B4b1LPQo9JQCP_w,1405
- sglang/srt/constrained/fsm_cache.py,sha256=20mEgtDXU1Zeoicl5KBQC3arkg-RhRWiYnchJc00m1g,901
- sglang/srt/constrained/jump_forward.py,sha256=Z-pz2Jnvk1CxSEZA65OVq0GryqdiKuOkhhc13v5T6Lo,2482
- sglang/srt/layers/context_flashattention_nopad.py,sha256=TVYQ6IjftWVXORmKpEROMqQxDOnF6n2g0G1Ci4LquYM,5209
- sglang/srt/layers/extend_attention.py,sha256=KGqQOA5mel9qScXMAQP_3Qyhp3BNbiQ7Y_6wi38Lxcs,12622
- sglang/srt/layers/logits_processor.py,sha256=MW2bpqSXyghODMojqeMSYWZhUHuAFPk_gUkyyLw9HkM,4827
- sglang/srt/layers/radix_attention.py,sha256=bqrb8H8K8RbKTr1PzVmpnUxRzMj0H-OWCi1JYZKuRDw,5597
- sglang/srt/layers/token_attention.py,sha256=waOjGsWZlvf6epFhYerRJlAaMwvDTy_Z3uzPaXsVQUU,8516
- sglang/srt/managers/detokenizer_manager.py,sha256=1lPNh_Pe6Pr0v-TzlCBBREbvz4uFWxyw31SmnEZh0s8,3292
- sglang/srt/managers/io_struct.py,sha256=nXJh3CrOvv9MdAfIFoo6SCXuNQTG3KswmRKkwF61Tek,3141
- sglang/srt/managers/openai_protocol.py,sha256=cttqg9iv3de8fhtCqDI4cYoPPZ_gULedMXstV1ok6WA,4563
- sglang/srt/managers/tokenizer_manager.py,sha256=hgsR9AMj6ic9S3-2WiELh7Hnp8Xnb_bzp7kpbjHwHtM,9733
- sglang/srt/managers/router/infer_batch.py,sha256=U-Ckt9ad1WaOQF_dW6Eo9AMIRQoOJQ-Pm-MMXnEmPP8,18399
- sglang/srt/managers/router/manager.py,sha256=TNYs0IrkZGkPvZJViwL7BMUg0VlvzeyTjDMjuvRoMDI,2529
- sglang/srt/managers/router/model_rpc.py,sha256=VlwLNpHZ92bnteQl4PhVKoAXM0C8Y4_2LBBVaffeu3g,26766
- sglang/srt/managers/router/model_runner.py,sha256=-wWv00EbB_UkkLpio6VKGBTagfzxLHfY-eKDDQ0rZQc,18292
- sglang/srt/managers/router/radix_cache.py,sha256=XGUF5mxQTSCzD7GW_ltNP2p5aelEKrMXzdezufJ7NCQ,6484
- sglang/srt/managers/router/scheduler.py,sha256=V-LAnVSzgD2ddy2eXW3jWURCeq9Lv7YxCGk4kHyytfM,2818
- sglang/srt/models/gemma.py,sha256=8XlfHPtVixPYYjz5F9T4DOAuoordWFStmyFFWGfny1k,11582
- sglang/srt/models/llama2.py,sha256=VL4iN8R3wyTNr0bDxxKdLNnVGEvdXF6iGvA768YeakA,11611
- sglang/srt/models/llava.py,sha256=42sn-AgI-6dMaTEU4aEbi4Js5epy0J3JVQoMooUOKt8,14922
- sglang/srt/models/mistral.py,sha256=XSn7fiZqspyWVTYrpVAacAnWdwAybBtyn9-Sh9AvMTM,254
- sglang/srt/models/mixtral.py,sha256=wqIwKfR90ih0gDiTZkFZcQD4PIYpZFD3CmzxRcuKIqw,13915
- sglang/srt/models/qwen.py,sha256=CvdbcF90aI1tJPSQ-3OMUaQGMuaxCGe0y29m5nU_Yj0,9225
- sglang/srt/models/qwen2.py,sha256=myPc0wvgf5ZzJyGhUGN49YjY-tMf4t8Jn_Imjg8D7Mk,11307
- sglang/srt/models/stablelm.py,sha256=vMZUNgwXKPGYr5FcdYHw5g3QifVu9owKqq51_-EBOY0,10817
- sglang/srt/models/yivl.py,sha256=Qvp-zQ93cOZGg3zVyaiQLhRsfXiLrQhxu9TyQP2FMm4,4414
- sglang/test/test_conversation.py,sha256=1zIrXcXiwEliPHgDAsqsQUA7JKzZ5fnQEU-U6L887FU,1592
- sglang/test/test_openai_protocol.py,sha256=eePzoskYR3PqfWczSVZvg8ja63qbT8TFUNEMyzDZpa8,1657
- sglang/test/test_programs.py,sha256=mrLhGuprwvx8ZJ-0Qe28E-iCw5Qv-9T0SAv1Jgo1AJw,11421
- sglang/test/test_utils.py,sha256=6PhTRi8UnR-BRNjit6aGu0M5lO0RebNQwEcDt712hE4,4830
- sglang-0.1.14.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
- sglang-0.1.14.dist-info/METADATA,sha256=C5N0VOYRHixdJcsf4dExIvP-Q099kYBMKs_dA4LBXSM,28809
- sglang-0.1.14.dist-info/WHEEL,sha256=GJ7t_kWBFywbagK5eo9IoUwLW6oyOeTKmQ-9iHFVNxQ,92
- sglang-0.1.14.dist-info/top_level.txt,sha256=yxhh3pYQkcnA7v3Bg889C2jZhvtJdEincysO7PEB09M,7
- sglang-0.1.14.dist-info/RECORD,,