sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py

@@ -1,30 +1,25 @@
 import importlib
 import importlib.resources
-import inspect
 import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type

 import numpy as np
 import torch
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.gptq import GPTQConfig
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import initialize_model_parallel
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.utils import is_multimodal_model
-from sglang.utils import get_available_gpu_memory
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model

-QUANTIZATION_CONFIG_MAPPING = {
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}

 logger = logging.getLogger("model_runner")

@@ -32,35 +27,6 @@ logger = logging.getLogger("model_runner")
 global_server_args_dict = {}


-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-    else:
-        raise ValueError(
-            f"Unsupported architectures: {arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_class
-
-
 @dataclass
 class InputMetadata:
     model_runner: "ModelRunner"
@@ -110,8 +76,8 @@ class InputMetadata:
         self.kv_last_page_len = torch.ones(
            (self.batch_size,), dtype=torch.int32, device="cuda"
        )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
-        seq_lens_cpu = self.seq_lens.tolist()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
@@ -143,7 +109,7 @@ class InputMetadata:
             self.kv_last_page_len,
             self.model_runner.model_config.num_attention_heads // tp_size,
             self.model_runner.model_config.num_key_value_heads // tp_size,
-            self.model_runner.model_config.head_dim
+            self.model_runner.model_config.head_dim,
         ]

         self.prefill_wrapper.begin_forward(*args)
@@ -253,113 +219,102 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-        load_format="auto",
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.load_format = load_format
-        self.trust_remote_code = trust_remote_code
+        self.server_args = server_args

         global global_server_args_dict
-        global_server_args_dict = server_args_dict
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }

         # Init torch distributed
+        logger.info(f"[rank={self.tp_rank}] Set cuda device.")
         torch.cuda.set_device(self.tp_rank)
+        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
         torch.distributed.init_process_group(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
-
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        logger.info(f"[rank={self.tp_rank}] Init torch end.")
+
+        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")

-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
         self.load_model()
         self.init_memory_pool(total_gpu_memory)

         self.is_multimodal_model = is_multimodal_model(self.model_config)

     def load_model(self):
-        """See also vllm/model_executor/model_loader.py::get_model"""
-        # Select model class
-        architectures = getattr(self.model_config.hf_config, "architectures", [])
-        model_class = get_model_cls_by_arch_name(architectures)
-        logger.info(f"Rank {self.tp_rank}: load weight begin.")
-
-        # Load weights
-        quant_config = None
-
-        quant_cfg = getattr(self.model_config.hf_config, "quantization_config", None)
-        if quant_cfg is not None:
-            quant_method = quant_cfg.get("quant_method", "").lower()
-            # compat: autogptq >=0.8.0 use checkpoint_format: str
-            # compat: autogptq <=0.7.1 is_marlin_format: bool
-            is_format_marlin = quant_cfg.get(
-                "checkpoint_format"
-            ) == "marlin" or quant_cfg.get("is_marlin_format", False)
-
-            # Use marlin if the GPTQ model is serialized in marlin format.
-            if quant_method == "gptq" and is_format_marlin:
-                quant_method = "marlin"
-
-            quant_config_class = QUANTIZATION_CONFIG_MAPPING.get(quant_method)
-
-            if quant_config_class is None:
-                raise ValueError(f"Unsupported quantization method: {quant_method}")
-
-            quant_config = quant_config_class.from_config(quant_cfg)
-            logger.info(f"quant_config: {quant_config}")
-
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-            model.load_weights(
-                self.model_config.path,
-                cache_dir=None,
-                load_format=self.load_format,
-                revision=None,
-            )
-        self.model = model.eval()
-
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
+
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=torch.float16,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+        )
+        logger.info(f"[rank={self.tp_rank}] Load weight end. "
+                    f"Type={type(self.model).__name__}. "
+                    f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")

     def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
         head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

     def init_memory_pool(self, total_gpu_memory):
-        self.max_total_num_token = self.profile_max_num_token(total_gpu_memory)
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

-        if self.max_total_num_token <= 0:
+        if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enought memory. " "Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
-            int(self.max_total_num_token / self.model_config.context_len * 256),
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
-            self.max_total_num_token,
+            self.max_total_num_tokens,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
             head_dim=self.model_config.head_dim,
@@ -456,3 +411,35 @@ class ModelRunner:
             return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(entry, list):  # To support multiple model classes in one module
+                    for cls in entry:
+                        model_arch_name_to_cls[cls.__name__] = cls
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
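For context on the hunks above: weight loading now goes through vllm's get_model, while sglang keeps its own model implementations discoverable by monkey-patching ModelRegistry.load_model_cls. The registry scan relies on each module under sglang/srt/models exposing an EntryClass. Below is a minimal, hypothetical sketch of that convention; the class name and shapes are illustrative only, not taken from the package.

import torch.nn as nn

# Hypothetical stand-in for a file under sglang/srt/models.
class ToyForCausalLM(nn.Module):
    def __init__(self, hidden_size=16):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(self, x):
        return self.proj(x)

# import_model_classes() keys the registry by __name__, so the class name is
# expected to match the architecture string in the Hugging Face config.
# A list of classes is also accepted (see the isinstance(entry, list) branch above).
EntryClass = ToyForCausalLM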
sglang/srt/managers/router/radix_cache.py

@@ -11,7 +11,7 @@ class TreeNode:
         self.parent = None
         self.key = None
         self.value = None
-        self.ref_counter = 0
+        self.lock_ref = 0
         self.last_access_time = time.time()

     def __lt__(self, other: "TreeNode"):
@@ -28,7 +28,9 @@ def _key_match(key0, key1):


 class RadixCache:
-    def __init__(self, disable: bool = False):
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
         self.reset()

@@ -38,7 +40,7 @@ class RadixCache:
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
-        self.root_node.ref_counter = 1
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0

     def match_prefix(self, key):
@@ -50,16 +52,52 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]

     def insert(self, key, value=None):
         if self.disable:
-            return len(key)
+            return 0

         if value is None:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        if self.disable:
+            if del_in_memory_pool:
+                self.token_to_kv_pool.dec_refs(indices)
+            else:
+                return torch.tensor([], dtype=torch.int64), self.root_node
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
@@ -80,7 +118,7 @@ class RadixCache:

             if x == self.root_node:
                 break
-            if x.ref_counter > 0:
+            if x.lock_ref > 0:
                 continue

             num_evicted += evict_callback(x.value)
@@ -89,23 +127,23 @@ class RadixCache:
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)

-    def inc_ref_counter(self, node):
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 0:
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.ref_counter += 1
+            node.lock_ref += 1
             node = node.parent
         return delta

-    def dec_ref_counter(self, node):
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.ref_counter == 1:
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.ref_counter -= 1
+            node.lock_ref -= 1
             node = node.parent
         return delta

@@ -131,12 +169,12 @@ class RadixCache:
             last_node[0] = child
             self._match_prefix_helper(child, key[prefix_len:], value, last_node)

-    def _split_node(self, key, child, split_len):
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.ref_counter = child.ref_counter
+        new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
@@ -176,11 +214,9 @@ class RadixCache:
             self.evictable_size_ += len(value)
         return 0

-    def _print_helper(self, node, indent):
+    def _print_helper(self, node: TreeNode, indent):
         for _, child in node.children.items():
-            print(
-                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
-            )
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)

     def _delete_leaf(self, node):
@@ -211,7 +247,7 @@


 if __name__ == "__main__":
-    tree = RadixCache()
+    tree = RadixCache(None, None, False)

     tree.insert("Hello")
     tree.insert("Hello")
sglang/srt/managers/router/scheduler.py

@@ -6,15 +6,15 @@ class Scheduler:
     def __init__(
         self,
         schedule_heuristic,
-        max_running_seq,
-        max_prefill_num_token,
-        max_total_num_token,
+        max_running_seqs,
+        max_prefill_num_tokens,
+        max_total_num_tokens,
         tree_cache,
     ):
         self.schedule_heuristic = schedule_heuristic
-        self.max_running_seq = max_running_seq
-        self.max_prefill_num_token = max_prefill_num_token
-        self.max_total_num_token = max_total_num_token
+        self.max_running_seqs = max_running_seqs
+        self.max_prefill_num_tokens = max_prefill_num_tokens
+        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache

     def get_priority_queue(self, forward_queue):
@@ -27,44 +27,33 @@ class Scheduler:
             return forward_queue
         elif self.schedule_heuristic == "fcfs":
             return forward_queue
-        elif self.schedule_heuristic == "weight":
+        elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))

             node_to_weight = defaultdict(int)
-            self._calc_weight_recursive(
-                self.tree_cache.root_node, last_node_to_reqs, node_to_weight
-            )
+            for node in last_node_to_reqs:
+                node_to_weight[node] = len(last_node_to_reqs[node])
+            self.calc_weight(self.tree_cache.root_node, node_to_weight)

-            tmp_queue = []
-            self._get_weight_priority_recursive(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, tmp_queue
+            q = []
+            self.get_dfs_priority(
+                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(tmp_queue) == len(forward_queue)
-            return tmp_queue
+            assert len(q) == len(forward_queue)
+            return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

-    def _calc_weight_recursive(self, cur_node, last_node_to_reqs, node_to_weight):
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
+    def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
-            self._calc_weight_recursive(child, last_node_to_reqs, node_to_weight)
+            self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def _get_weight_priority_recursive(
-        self, cur_node, node_to_wight, last_node_to_reqs, tmp_queue
-    ):
-        visit_list = [child for child in cur_node.children.values()]
-        visit_list.sort(key=lambda x: -node_to_wight[x])
-        # for node in visit_list:
-        #     print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
+    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+        childs = [child for child in cur_node.children.values()]
+        childs.sort(key=lambda x: -node_to_priority[x])
+        for child in childs:
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+        q.extend(last_node_to_reqs[cur_node])
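On the scheduler changes above: the renamed dfs-weight heuristic counts the queued requests attached to each radix-tree node, propagates those counts toward the root, then emits requests in a depth-first order that visits heavier subtrees first, so requests sharing a cached prefix tend to be scheduled together. A standalone toy of the same weighting and traversal follows; the tree shape and request names are made up for illustration.

# Toy illustration of the dfs-weight ordering; not the sglang Scheduler.
from collections import defaultdict

class Node:
    def __init__(self, name):
        self.name = name
        self.children = []

root = Node("root")
a, b = Node("a"), Node("b")
root.children = [a, b]

# Requests waiting on each last-matched node (as in last_node_to_reqs).
last_node_to_reqs = {a: ["req1"], b: ["req2", "req3"]}

node_to_weight = defaultdict(int)
for node, reqs in last_node_to_reqs.items():
    node_to_weight[node] = len(reqs)

def calc_weight(cur):
    # Push each subtree's request count up to its parent.
    for child in cur.children:
        calc_weight(child)
        node_to_weight[cur] += node_to_weight[child]

def get_dfs_priority(cur, q):
    # Visit heavier children first, then append this node's own requests.
    for child in sorted(cur.children, key=lambda x: -node_to_weight[x]):
        get_dfs_priority(child, q)
    q.extend(last_node_to_reqs.get(cur, []))

calc_weight(root)
q = []
get_dfs_priority(root, q)
print(q)  # ['req2', 'req3', 'req1'] -- subtree b holds more requests, so it comes first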