sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl
This diff compares the contents of two package versions that have been publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/__init__.py +5 -1
- sglang/api.py +8 -3
- sglang/backend/anthropic.py +1 -1
- sglang/backend/litellm.py +90 -0
- sglang/backend/openai.py +148 -12
- sglang/backend/runtime_endpoint.py +18 -10
- sglang/global_config.py +11 -1
- sglang/lang/chat_template.py +9 -2
- sglang/lang/interpreter.py +161 -81
- sglang/lang/ir.py +29 -11
- sglang/lang/tracer.py +1 -1
- sglang/launch_server.py +1 -2
- sglang/launch_server_llavavid.py +31 -0
- sglang/srt/constrained/fsm_cache.py +3 -0
- sglang/srt/flush_cache.py +16 -0
- sglang/srt/hf_transformers_utils.py +83 -2
- sglang/srt/layers/extend_attention.py +17 -0
- sglang/srt/layers/fused_moe.py +485 -0
- sglang/srt/layers/logits_processor.py +12 -7
- sglang/srt/layers/radix_attention.py +10 -3
- sglang/srt/layers/token_attention.py +16 -1
- sglang/srt/managers/controller/dp_worker.py +110 -0
- sglang/srt/managers/controller/infer_batch.py +619 -0
- sglang/srt/managers/controller/manager_multi.py +191 -0
- sglang/srt/managers/controller/manager_single.py +97 -0
- sglang/srt/managers/controller/model_runner.py +462 -0
- sglang/srt/managers/controller/radix_cache.py +267 -0
- sglang/srt/managers/controller/schedule_heuristic.py +59 -0
- sglang/srt/managers/controller/tp_worker.py +791 -0
- sglang/srt/managers/detokenizer_manager.py +45 -45
- sglang/srt/managers/io_struct.py +26 -10
- sglang/srt/managers/router/infer_batch.py +130 -74
- sglang/srt/managers/router/manager.py +7 -9
- sglang/srt/managers/router/model_rpc.py +224 -135
- sglang/srt/managers/router/model_runner.py +94 -107
- sglang/srt/managers/router/radix_cache.py +54 -18
- sglang/srt/managers/router/scheduler.py +23 -34
- sglang/srt/managers/tokenizer_manager.py +183 -88
- sglang/srt/model_config.py +5 -2
- sglang/srt/models/commandr.py +15 -22
- sglang/srt/models/dbrx.py +22 -29
- sglang/srt/models/gemma.py +14 -24
- sglang/srt/models/grok.py +671 -0
- sglang/srt/models/llama2.py +24 -23
- sglang/srt/models/llava.py +85 -25
- sglang/srt/models/llavavid.py +298 -0
- sglang/srt/models/mixtral.py +254 -130
- sglang/srt/models/mixtral_quant.py +373 -0
- sglang/srt/models/qwen.py +28 -25
- sglang/srt/models/qwen2.py +17 -22
- sglang/srt/models/stablelm.py +21 -26
- sglang/srt/models/yivl.py +17 -25
- sglang/srt/openai_api_adapter.py +140 -95
- sglang/srt/openai_protocol.py +10 -1
- sglang/srt/server.py +101 -52
- sglang/srt/server_args.py +59 -11
- sglang/srt/utils.py +242 -75
- sglang/test/test_programs.py +44 -0
- sglang/test/test_utils.py +32 -1
- sglang/utils.py +95 -26
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
- sglang-0.1.17.dist-info/RECORD +81 -0
- sglang/srt/backend_config.py +0 -13
- sglang/srt/models/dbrx_config.py +0 -281
- sglang/srt/weight_utils.py +0 -402
- sglang-0.1.15.dist-info/RECORD +0 -69
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
- {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_runner.py

```diff
@@ -1,30 +1,25 @@
 import importlib
 import importlib.resources
-import inspect
 import logging
 import pkgutil
 from dataclasses import dataclass
 from functools import lru_cache
-from typing import List
+from typing import List, Optional, Type

 import numpy as np
 import torch
-
-from vllm.
-from vllm.
-from vllm.model_executor.model_loader.utils import set_default_torch_dtype
+import torch.nn as nn
+from vllm.config import DeviceConfig, LoadConfig
+from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import initialize_model_parallel
+from vllm.model_executor.model_loader import get_model
+from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.router.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
-from sglang.srt.
-from sglang.utils import get_available_gpu_memory
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model

-QUANTIZATION_CONFIG_MAPPING = {
-    "awq": AWQConfig,
-    "gptq": GPTQConfig,
-    "marlin": MarlinConfig,
-}

 logger = logging.getLogger("model_runner")

```
```diff
@@ -32,35 +27,6 @@ logger = logging.getLogger("model_runner")
 global_server_args_dict = {}


-@lru_cache()
-def import_model_classes():
-    model_arch_name_to_cls = {}
-    package_name = "sglang.srt.models"
-    package = importlib.import_module(package_name)
-    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
-        if not ispkg:
-            module = importlib.import_module(name)
-            if hasattr(module, "EntryClass"):
-                model_arch_name_to_cls[module.EntryClass.__name__] = module.EntryClass
-    return model_arch_name_to_cls
-
-
-def get_model_cls_by_arch_name(model_arch_names):
-    model_arch_name_to_cls = import_model_classes()
-
-    model_class = None
-    for arch in model_arch_names:
-        if arch in model_arch_name_to_cls:
-            model_class = model_arch_name_to_cls[arch]
-            break
-    else:
-        raise ValueError(
-            f"Unsupported architectures: {arch}. "
-            f"Supported list: {list(model_arch_name_to_cls.keys())}"
-        )
-    return model_class
-
-
 @dataclass
 class InputMetadata:
     model_runner: "ModelRunner"
```
```diff
@@ -110,8 +76,8 @@ class InputMetadata:
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
-        req_pool_indices_cpu = self.req_pool_indices.cpu().
-        seq_lens_cpu = self.seq_lens.
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
```
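The `.cpu().numpy()` conversions feed the gather that starts on the last context lines of this hunk: for each request in the batch, the KV-cache slot indices stored in its `req_to_token` row are concatenated into one flat `kv_indices` tensor. A standalone sketch of that gather with made-up shapes and values (the real code runs on CUDA and hands the result to the flashinfer wrappers):

```python
import torch

# Toy stand-ins for the pool state; all sizes here are hypothetical.
req_to_token = torch.arange(4 * 16).reshape(4, 16)   # [max_reqs, max_context_len]
req_pool_indices = torch.tensor([2, 0])              # pool rows used by this batch
seq_lens = torch.tensor([5, 3])                      # current length of each request

req_pool_indices_cpu = req_pool_indices.cpu().numpy()
seq_lens_cpu = seq_lens.cpu().numpy()

# Flat list of KV-cache slots for the whole batch, in request order --
# the same kind of computation as the torch.cat([...]) begun above.
kv_indices = torch.cat(
    [
        req_to_token[req_pool_indices_cpu[i], : seq_lens_cpu[i]]
        for i in range(len(seq_lens_cpu))
    ],
    dim=0,
)
print(kv_indices.shape)  # torch.Size([8]) == seq_lens.sum()
```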
```diff
@@ -143,7 +109,7 @@ class InputMetadata:
             self.kv_last_page_len,
             self.model_runner.model_config.num_attention_heads // tp_size,
             self.model_runner.model_config.num_key_value_heads // tp_size,
-            self.model_runner.model_config.head_dim
+            self.model_runner.model_config.head_dim,
         ]

         self.prefill_wrapper.begin_forward(*args)
```
```diff
@@ -253,113 +219,102 @@ class ModelRunner:
         tp_rank,
         tp_size,
         nccl_port,
-
-        trust_remote_code=True,
-        server_args_dict: dict = {},
+        server_args: ServerArgs,
     ):
         self.model_config = model_config
         self.mem_fraction_static = mem_fraction_static
         self.tp_rank = tp_rank
         self.tp_size = tp_size
         self.nccl_port = nccl_port
-        self.
-        self.trust_remote_code = trust_remote_code
+        self.server_args = server_args

         global global_server_args_dict
-        global_server_args_dict =
+        global_server_args_dict = {
+            "enable_flashinfer": server_args.enable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }

         # Init torch distributed
+        logger.info(f"[rank={self.tp_rank}] Set cuda device.")
         torch.cuda.set_device(self.tp_rank)
+        logger.info(f"[rank={self.tp_rank}] Init torch begin. Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")
         torch.distributed.init_process_group(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             init_method=f"tcp://127.0.0.1:{self.nccl_port}",
         )
-
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
+        logger.info(f"[rank={self.tp_rank}] Init torch end.")
+
+        total_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
+
+        if self.tp_size > 1:
+            total_local_gpu_memory = get_available_gpu_memory(self.tp_rank)
+            if total_local_gpu_memory < total_gpu_memory * 0.9:
+                raise ValueError("The memory capacity is unbalanced. Some GPUs may be occupied by other processes.")

-        total_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
         self.load_model()
         self.init_memory_pool(total_gpu_memory)

         self.is_multimodal_model = is_multimodal_model(self.model_config)

     def load_model(self):
-        "
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        with set_default_torch_dtype(torch.float16):
-            with torch.device("cuda"):
-                model = model_class(
-                    config=self.model_config.hf_config, quant_config=quant_config
-                )
-                model.load_weights(
-                    self.model_config.path,
-                    cache_dir=None,
-                    load_format=self.load_format,
-                    revision=None,
-                )
-        self.model = model.eval()
-
-        logger.info(f"Rank {self.tp_rank}: load weight end.")
+        logger.info(f"[rank={self.tp_rank}] Load weight begin.")
+
+        device_config = DeviceConfig()
+        load_config = LoadConfig(load_format=self.server_args.load_format)
+        vllm_model_config = VllmModelConfig(
+            model=self.server_args.model_path,
+            quantization=self.server_args.quantization,
+            tokenizer=None,
+            tokenizer_mode=None,
+            trust_remote_code=self.server_args.trust_remote_code,
+            dtype=torch.float16,
+            seed=42,
+            skip_tokenizer_init=True,
+        )
+        if self.model_config.model_overide_args is not None:
+            vllm_model_config.hf_config.update(self.model_config.model_overide_args)
+
+        self.model = get_model(
+            model_config=vllm_model_config,
+            device_config=device_config,
+            load_config=load_config,
+            lora_config=None,
+            vision_language_config=None,
+            parallel_config=None,
+            scheduler_config=None,
+        )
+        logger.info(f"[rank={self.tp_rank}] Load weight end. "
+                    f"Type={type(self.model).__name__}. "
+                    f"Avail mem={get_available_gpu_memory(self.tp_rank):.2f} GB")

     def profile_max_num_token(self, total_gpu_memory):
-        available_gpu_memory = get_available_gpu_memory(
-            self.tp_rank, distributed=self.tp_size > 1
-        ) * (1 << 30)
+        available_gpu_memory = get_available_gpu_memory(self.tp_rank, distributed=self.tp_size > 1)
         head_dim = self.model_config.head_dim
         head_num = self.model_config.num_key_value_heads // self.tp_size
         cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        max_num_token = int(rest_memory // cell_size)
+        max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

     def init_memory_pool(self, total_gpu_memory):
-        self.
+        self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)

-        if self.
+        if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enought memory. " "Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
-            int(self.
+            int(self.max_total_num_tokens / self.model_config.context_len * 256),
             self.model_config.context_len + 8,
         )
         self.token_to_kv_pool = TokenToKVPool(
-            self.
+            self.max_total_num_tokens,
             dtype=torch.float16,
             head_num=self.model_config.num_key_value_heads // self.tp_size,
             head_dim=self.model_config.head_dim,
```
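The unit handling is the substantive change in `profile_max_num_token`: memory figures now stay in gigabytes until the single `* (1 << 30)` conversion when the KV-token budget is computed, instead of being converted at every call site. A worked example of that arithmetic with hypothetical numbers, not taken from any real model or GPU:

```python
# Hypothetical 7B-class model on one 24 GB GPU, tp_size = 1.
head_dim = 128
num_key_value_heads = 8
num_hidden_layers = 32
mem_fraction_static = 0.8

available_gpu_memory = 22.0   # GB free right after weight loading
total_gpu_memory = 23.0       # GB free before anything was allocated

# Bytes of KV cache per token: K and V (x2), fp16 (2 bytes), across all layers.
cell_size = num_key_value_heads * head_dim * num_hidden_layers * 2 * 2

# Memory left for the KV pool after reserving (1 - mem_fraction_static) of the GPU.
rest_memory = available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static)
max_num_token = int(rest_memory * (1 << 30) // cell_size)
print(cell_size, max_num_token)  # 131072 bytes per token, ~142k tokens
```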
```diff
@@ -456,3 +411,35 @@ class ModelRunner:
             return self.forward_prefill(batch)
         else:
             raise ValueError(f"Invaid forward mode: {forward_mode}")
+
+
+@lru_cache()
+def import_model_classes():
+    model_arch_name_to_cls = {}
+    package_name = "sglang.srt.models"
+    package = importlib.import_module(package_name)
+    for _, name, ispkg in pkgutil.iter_modules(package.__path__, package_name + "."):
+        if not ispkg:
+            module = importlib.import_module(name)
+            if hasattr(module, "EntryClass"):
+                entry = module.EntryClass
+                if isinstance(entry, list):  # To support multiple model classes in one module
+                    for cls in entry:
+                        model_arch_name_to_cls[cls.__name__] = cls
+                else:
+                    model_arch_name_to_cls[entry.__name__] = entry
+    return model_arch_name_to_cls
+
+
+def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
+    model_arch_name_to_cls = import_model_classes()
+    if model_arch not in model_arch_name_to_cls:
+        raise ValueError(
+            f"Unsupported architectures: {model_arch}. "
+            f"Supported list: {list(model_arch_name_to_cls.keys())}"
+        )
+    return model_arch_name_to_cls[model_arch]
+
+
+# Monkey patch model loader
+setattr(ModelRegistry, "load_model_cls", load_model_cls_srt)
```
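Model-class resolution did not disappear; it moved below the class and is now grafted onto vllm's `ModelRegistry`, so the `get_model(...)` call in `load_model` resolves architecture names against sglang's own implementations. The convention a model file follows is unchanged, except that `EntryClass` may now be a list. A hedged sketch of such a module; the module path, class names, and the `forward` signature are illustrative only:

```python
# Hypothetical sglang/srt/models/my_model.py -- names are made up for illustration.
import torch.nn as nn


class MyModelForCausalLM(nn.Module):
    def __init__(self, config, quant_config=None):
        super().__init__()
        self.config = config  # layers would be built here

    def forward(self, input_ids, positions, input_metadata):
        raise NotImplementedError


class MyModelForSequenceClassification(MyModelForCausalLM):
    pass


# One class, or a list when a single file implements several architectures.
EntryClass = [MyModelForCausalLM, MyModelForSequenceClassification]
```

Because `import_model_classes()` is wrapped in `lru_cache`, the package scan runs once per process; `load_model_cls_srt` then simply maps the architecture name that vllm asks for to the registered class.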
sglang/srt/managers/router/radix_cache.py

```diff
@@ -11,7 +11,7 @@ class TreeNode:
         self.parent = None
         self.key = None
         self.value = None
-        self.
+        self.lock_ref = 0
         self.last_access_time = time.time()

     def __lt__(self, other: "TreeNode"):
```
```diff
@@ -28,7 +28,9 @@ def _key_match(key0, key1):


 class RadixCache:
-    def __init__(self, disable: bool = False):
+    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
         self.reset()

```
```diff
@@ -38,7 +40,7 @@ class RadixCache:
         self.root_node = TreeNode()
         self.root_node.key = []
         self.root_node.value = []
-        self.root_node.
+        self.root_node.lock_ref = 1
         self.evictable_size_ = 0

     def match_prefix(self, key):
```
```diff
@@ -50,16 +52,52 @@ class RadixCache:
         self._match_prefix_helper(self.root_node, key, value, last_node)
         if value:
             value = torch.concat(value)
+        else:
+            value = torch.tensor([], dtype=torch.int64)
         return value, last_node[0]

     def insert(self, key, value=None):
         if self.disable:
-            return
+            return 0

         if value is None:
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

+    def cache_req(
+        self,
+        token_ids,
+        last_uncached_pos,
+        req_pool_idx,
+        del_in_memory_pool=True,
+        old_last_node=None,
+    ):
+        # Insert the request into radix cache
+        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
+        new_prefix_len = self.insert(token_ids, indices.clone())
+
+        if self.disable:
+            if del_in_memory_pool:
+                self.token_to_kv_pool.dec_refs(indices)
+            else:
+                return torch.tensor([], dtype=torch.int64), self.root_node
+
+        # Radix Cache takes one ref in memory pool
+        self.token_to_kv_pool.dec_refs(indices[last_uncached_pos:new_prefix_len])
+
+        if del_in_memory_pool:
+            self.req_to_token_pool.free(req_pool_idx)
+        else:
+            cached_indices, new_last_node = self.match_prefix(token_ids)
+            assert len(cached_indices) == len(token_ids)
+
+            self.req_to_token_pool.req_to_token[
+                req_pool_idx, last_uncached_pos : len(cached_indices)
+            ] = cached_indices[last_uncached_pos:]
+            self.dec_lock_ref(old_last_node)
+            self.inc_lock_ref(new_last_node)
+            return cached_indices, new_last_node
+
     def pretty_print(self):
         self._print_helper(self.root_node, 0)
         print(f"#tokens: {self.total_size()}")
```
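`cache_req` packages the bookkeeping that callers used to do by hand: write the request's KV indices into the tree, drop the memory pool's duplicate reference for the newly cached span, and either free the request's `req_to_token` row (finished request) or re-match the prefix and move the lock (request that keeps running). A toy driver, assuming sglang 0.1.17 is importable; the stub pools below only record the calls they receive and are not the real `ReqToTokenPool`/`TokenToKVPool`:

```python
import torch

from sglang.srt.managers.router.radix_cache import RadixCache


class StubReqToTokenPool:
    """Stand-in: a request-to-token index table plus a free() that just records."""

    def __init__(self, num_reqs, max_context_len):
        self.req_to_token = torch.arange(num_reqs * max_context_len).reshape(
            num_reqs, max_context_len
        )
        self.freed = []

    def free(self, req_pool_idx):
        self.freed.append(req_pool_idx)


class StubTokenToKVPool:
    """Stand-in: records which KV indices had a reference released."""

    def __init__(self):
        self.dec_calls = []

    def dec_refs(self, indices):
        self.dec_calls.append(indices)


cache = RadixCache(StubReqToTokenPool(4, 32), StubTokenToKVPool(), disable=False)
token_ids = [7, 8, 9, 10]

# A request that keeps running: keep its row, lock the path it now shares.
prefix_indices, last_node = cache.cache_req(
    token_ids,
    last_uncached_pos=0,
    req_pool_idx=1,
    del_in_memory_pool=False,
    old_last_node=cache.root_node,
)
print(len(prefix_indices))  # 4 -- the whole sequence is now backed by the tree

# The same request finishes later: fold it in and release its req_to_token row.
cache.cache_req(token_ids, last_uncached_pos=len(prefix_indices), req_pool_idx=1)
```

The `del_in_memory_pool=False` path is what lets a still-running request switch from private KV slots to tree-owned ones without copying: its `req_to_token` row is rewritten to point at the cached indices, and the lock moves from `old_last_node` to the newly matched last node.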
```diff
@@ -80,7 +118,7 @@ class RadixCache:

             if x == self.root_node:
                 break
-            if x.
+            if x.lock_ref > 0:
                 continue

             num_evicted += evict_callback(x.value)
```
```diff
@@ -89,23 +127,23 @@ class RadixCache:
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)

-    def
+    def inc_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.
+            if node.lock_ref == 0:
                 self.evictable_size_ -= len(node.value)
                 delta -= len(node.value)
-            node.
+            node.lock_ref += 1
             node = node.parent
         return delta

-    def
+    def dec_lock_ref(self, node: TreeNode):
         delta = 0
         while node != self.root_node:
-            if node.
+            if node.lock_ref == 1:
                 self.evictable_size_ += len(node.value)
                 delta += len(node.value)
-            node.
+            node.lock_ref -= 1
             node = node.parent
         return delta

```
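The rename to `lock_ref` matches what the counter actually does: it does not count KV-pool references, it pins the path from a node up to the root so the eviction loop above (`if x.lock_ref > 0: continue`) never frees tokens a running request still needs, while `evictable_size_` tracks the unpinned total. A small standalone illustration, using the same no-pool construction as the module's `__main__` block and made-up KV indices:

```python
import torch

from sglang.srt.managers.router.radix_cache import RadixCache

tree = RadixCache(None, None, disable=False)
tree.insert([1, 2, 3, 4], torch.arange(4))   # token ids -> (fake) KV slot indices
_, last_node = tree.match_prefix([1, 2, 3, 4])

tree.inc_lock_ref(last_node)   # pin: a running request depends on this prefix
print(tree.evictable_size_)    # 0 -- nothing may be evicted while it is pinned
tree.dec_lock_ref(last_node)   # unpin once the request no longer needs it
print(tree.evictable_size_)    # 4 -- the four tokens are evictable again
```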
```diff
@@ -131,12 +169,12 @@ class RadixCache:
                 last_node[0] = child
                 self._match_prefix_helper(child, key[prefix_len:], value, last_node)

-    def _split_node(self, key, child, split_len):
+    def _split_node(self, key, child: TreeNode, split_len):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
         new_node.parent = child.parent
-        new_node.
+        new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
         new_node.value = child.value[:split_len]
         child.parent = new_node
```
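`_split_node` is what keeps the structure a compressed radix tree: when a new sequence diverges in the middle of an existing edge, that edge's node is cut at the divergence point, and the new intermediate node now also inherits `lock_ref`, so a pinned path stays pinned across the split. A standalone illustration of a split (fake KV indices again; the shared prefix keeps the same slots):

```python
import torch

from sglang.srt.managers.router.radix_cache import RadixCache

tree = RadixCache(None, None, disable=False)
tree.insert([1, 2, 3, 4], torch.tensor([10, 11, 12, 13]))   # one edge: [1, 2, 3, 4]
tree.insert([1, 2, 9, 9], torch.tensor([10, 11, 20, 21]))   # diverges after [1, 2]

# The original edge is now split: an intermediate node holds [1, 2] and has two
# children, one for the old tail [3, 4] and one for the new tail [9, 9].
tree.pretty_print()
```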
```diff
@@ -176,11 +214,9 @@ class RadixCache:
             self.evictable_size_ += len(value)
         return 0

-    def _print_helper(self, node, indent):
+    def _print_helper(self, node: TreeNode, indent):
         for _, child in node.children.items():
-            print(
-                " " * indent, len(child.key), child.key[:10], f"r={child.ref_counter}"
-            )
+            print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)

     def _delete_leaf(self, node):
```
```diff
@@ -211,7 +247,7 @@ class RadixCache:


 if __name__ == "__main__":
-    tree = RadixCache()
+    tree = RadixCache(None, None, False)

     tree.insert("Hello")
     tree.insert("Hello")
```
sglang/srt/managers/router/scheduler.py

```diff
@@ -6,15 +6,15 @@ class Scheduler:
     def __init__(
         self,
         schedule_heuristic,
-
-
-
+        max_running_seqs,
+        max_prefill_num_tokens,
+        max_total_num_tokens,
         tree_cache,
     ):
         self.schedule_heuristic = schedule_heuristic
-        self.
-        self.
-        self.
+        self.max_running_seqs = max_running_seqs
+        self.max_prefill_num_tokens = max_prefill_num_tokens
+        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache

     def get_priority_queue(self, forward_queue):
```
```diff
@@ -27,44 +27,33 @@ class Scheduler:
             return forward_queue
         elif self.schedule_heuristic == "fcfs":
             return forward_queue
-        elif self.schedule_heuristic == "weight":
+        elif self.schedule_heuristic == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in forward_queue:
                 last_node_to_reqs[req.last_node].append(req)
-            for node in last_node_to_reqs:
-                last_node_to_reqs[node].sort(key=lambda x: -len(x.prefix_indices))

             node_to_weight = defaultdict(int)
-
-
-            )
+            for node in last_node_to_reqs:
+                node_to_weight[node] = len(last_node_to_reqs[node])
+            self.calc_weight(self.tree_cache.root_node, node_to_weight)

-
-            self.
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs,
+            q = []
+            self.get_dfs_priority(
+                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
             )
-            assert len(
-            return
+            assert len(q) == len(forward_queue)
+            return q
         else:
             raise ValueError(f"Unknown schedule_heuristic: {self.schedule_heuristic}")

-    def
-        node_to_weight[cur_node] = 1
-        if cur_node in last_node_to_reqs:
-            node_to_weight[cur_node] += len(last_node_to_reqs[cur_node])
+    def calc_weight(self, cur_node, node_to_weight):
         for child in cur_node.children.values():
-            self.
+            self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]

-    def
-
-
-
-
-
-        # print(f"{node_to_wight[node]} {len(node.value) if node.value is not None else 0}")
-        for child in visit_list:
-            self._get_weight_priority_recursive(
-                child, node_to_wight, last_node_to_reqs, tmp_queue
-            )
-        tmp_queue.extend(last_node_to_reqs[cur_node])
+    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+        childs = [child for child in cur_node.children.values()]
+        childs.sort(key=lambda x: -node_to_priority[x])
+        for child in childs:
+            self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
+        q.extend(last_node_to_reqs[cur_node])
```