sglang 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (66)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +4 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/bench_latency.py +299 -0
  6. sglang/global_config.py +4 -1
  7. sglang/lang/compiler.py +2 -2
  8. sglang/lang/interpreter.py +1 -1
  9. sglang/lang/ir.py +15 -5
  10. sglang/launch_server.py +4 -1
  11. sglang/launch_server_llavavid.py +2 -1
  12. sglang/srt/constrained/__init__.py +13 -6
  13. sglang/srt/constrained/fsm_cache.py +6 -3
  14. sglang/srt/constrained/jump_forward.py +113 -25
  15. sglang/srt/conversation.py +2 -0
  16. sglang/srt/flush_cache.py +2 -0
  17. sglang/srt/hf_transformers_utils.py +64 -9
  18. sglang/srt/layers/fused_moe.py +186 -89
  19. sglang/srt/layers/logits_processor.py +53 -25
  20. sglang/srt/layers/radix_attention.py +34 -7
  21. sglang/srt/managers/controller/dp_worker.py +6 -3
  22. sglang/srt/managers/controller/infer_batch.py +142 -67
  23. sglang/srt/managers/controller/manager_multi.py +5 -5
  24. sglang/srt/managers/controller/manager_single.py +8 -3
  25. sglang/srt/managers/controller/model_runner.py +154 -54
  26. sglang/srt/managers/controller/radix_cache.py +4 -0
  27. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  28. sglang/srt/managers/controller/tp_worker.py +140 -135
  29. sglang/srt/managers/detokenizer_manager.py +15 -19
  30. sglang/srt/managers/io_struct.py +10 -4
  31. sglang/srt/managers/tokenizer_manager.py +14 -13
  32. sglang/srt/model_config.py +83 -4
  33. sglang/srt/models/chatglm.py +399 -0
  34. sglang/srt/models/commandr.py +2 -2
  35. sglang/srt/models/dbrx.py +1 -1
  36. sglang/srt/models/gemma.py +5 -1
  37. sglang/srt/models/grok.py +204 -137
  38. sglang/srt/models/llama2.py +11 -4
  39. sglang/srt/models/llama_classification.py +104 -0
  40. sglang/srt/models/llava.py +11 -8
  41. sglang/srt/models/llavavid.py +1 -1
  42. sglang/srt/models/mixtral.py +164 -115
  43. sglang/srt/models/mixtral_quant.py +0 -1
  44. sglang/srt/models/qwen.py +1 -1
  45. sglang/srt/models/qwen2.py +1 -1
  46. sglang/srt/models/stablelm.py +1 -1
  47. sglang/srt/models/yivl.py +2 -2
  48. sglang/srt/openai_api_adapter.py +33 -23
  49. sglang/srt/openai_protocol.py +1 -1
  50. sglang/srt/server.py +60 -19
  51. sglang/srt/server_args.py +79 -44
  52. sglang/srt/utils.py +146 -37
  53. sglang/test/test_programs.py +28 -10
  54. sglang/utils.py +4 -3
  55. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/METADATA +29 -22
  56. sglang-0.1.18.dist-info/RECORD +78 -0
  57. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  58. sglang/srt/managers/router/infer_batch.py +0 -596
  59. sglang/srt/managers/router/manager.py +0 -82
  60. sglang/srt/managers/router/model_rpc.py +0 -818
  61. sglang/srt/managers/router/model_runner.py +0 -445
  62. sglang/srt/managers/router/radix_cache.py +0 -267
  63. sglang/srt/managers/router/scheduler.py +0 -59
  64. sglang-0.1.17.dist-info/RECORD +0 -81
  65. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  66. {sglang-0.1.17.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py

@@ -1,3 +1,5 @@
+"""ModelRunner runs the forward passes of the models."""
+
 import importlib
 import importlib.resources
 import logging
@@ -11,15 +13,19 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import initialize_model_parallel, init_distributed_environment
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
-
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)
 
 logger = logging.getLogger("srt.model_runner")
 
@@ -29,7 +35,6 @@ global_server_args_dict = {}
 
 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -60,73 +65,82 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None
 
-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens
 
         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()
 
-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
+            # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
             )
-            args = [
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1
             )
-            self.decode_wrapper.begin_forward(
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype
             )
 
     def init_extend_args(self):
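For readers unfamiliar with the metadata being planned here, below is a minimal sketch (plain PyTorch on CPU, with hypothetical lengths) of the index layout that `init_flashinfer_args` builds for a PREFILL/EXTEND batch: `qo_indptr` spans only the newly extended tokens, while `kv_indptr` spans the tokens already sitting in the paged KV cache.

```python
# Minimal sketch of the FlashInfer index layout for an EXTEND batch.
# Hypothetical numbers; the real code builds these on CUDA from the batch state.
import torch

extend_seq_lens = torch.tensor([3, 5], dtype=torch.int32)  # new tokens per request
prefix_lens = torch.tensor([8, 0], dtype=torch.int32)      # tokens already cached per request

qo_indptr = torch.zeros(3, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_seq_lens, dim=0)       # -> [0, 3, 8]

kv_indptr = torch.zeros(3, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)           # -> [0, 8, 8]

# The ragged prefill wrapper is planned with (qo_indptr, qo_indptr.clone()) for
# attention among the newly added tokens; the paged prefill wrapper is planned
# with (qo_indptr, kv_indptr, ...) for attention over the cached prefix.
# In decode mode, seq_lens takes the place of prefix_lens and only the decode
# wrapper is planned.
print(qo_indptr.tolist(), kv_indptr.tolist())
```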
@@ -150,6 +164,9 @@ class InputMetadata:
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -182,7 +199,6 @@
         other_kv_index = None
 
         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -200,13 +216,20 @@
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )
 
         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()
 
-        if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim
+            )
 
         return ret
 
@@ -229,24 +252,24 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()
 
         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-        monkey_patch_vllm_p2p_access_check()
+        monkey_patch_vllm_p2p_access_check(self.gpu_id)
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
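The new `nccl_init_addr` branch is what allows the tensor-parallel group to rendezvous across machines instead of only on a local port. A small hedged sketch of the selection logic (the flag plumbing lives in `server_args.py`, which is not shown in this excerpt):

```python
# Hedged sketch of the init-method selection added above; the exact CLI flag
# that populates server_args.nccl_init_addr is not visible in this diff.
def pick_init_method(nccl_init_addr, nccl_port):
    if nccl_init_addr:                        # e.g. "10.0.0.1:29500", same on every node
        return f"tcp://{nccl_init_addr}"
    return f"tcp://127.0.0.1:{nccl_port}"     # single-node default

assert pick_init_method(None, 28888) == "tcp://127.0.0.1:28888"
assert pick_init_method("10.0.0.1:29500", 28888) == "tcp://10.0.0.1:29500"
```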
@@ -260,9 +283,18 @@
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "disable_flashinfer": server_args.disable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_cublas()
+        self.init_flash_infer()
 
     def load_model(self):
         logger.info(
@@ -278,10 +310,11 @@
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=torch.float16,
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)
 
@@ -298,6 +331,7 @@
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
@@ -306,8 +340,8 @@
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -319,7 +353,7 @@
 
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
         self.req_to_token_pool = ReqToTokenPool(
@@ -328,8 +362,8 @@
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=torch.float16,
-            head_num=self.model_config.num_key_value_heads // self.tp_size,
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
@@ -338,6 +372,47 @@
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
 
+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            from flashinfer import (
+                BatchPrefillWithRaggedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchDecodeWithPagedKVCacheWrapper,
+            )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size)):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
+            workspace_buffers = torch.empty(
+                3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
+                workspace_buffers[0], "NHD"
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
+            )
+        else:
+            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_decode_wrapper = None
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
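A small numeric aside on `init_flash_infer` above (hypothetical head counts; which group sizes have precompiled decode kernels is internal to FlashInfer):

```python
# The three workspace buffers reserve 3 x 96 MiB, and the decode wrapper falls
# back to tensor-core kernels when the query-to-KV head ratio is not covered by
# FlashInfer's precompiled grouped-decode kernels.
workspace_bytes = 3 * 96 * 1024 * 1024
print(workspace_bytes // (1 << 20), "MiB")               # 288 MiB of FlashInfer workspace

num_qo_heads_per_rank = 64 // 8                          # e.g. 64 attention heads, tp_size=8
num_kv_heads_per_rank = max(1, 8 // 8)                   # e.g. 8 KV heads (GQA), assumed clamping
print(num_qo_heads_per_rank // num_kv_heads_per_rank)    # group size 8 fed to the kernel check
```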
@@ -351,6 +426,9 @@
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -369,6 +447,9 @@
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -389,6 +470,9 @@
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -407,6 +491,9 @@
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
@@ -440,16 +527,29 @@ def import_model_classes():
         module = importlib.import_module(name)
         if hasattr(module, "EntryClass"):
             entry = module.EntryClass
-            if isinstance(entry, list): # To support multiple model classes in one module
+            if isinstance(
+                entry, list
+            ):  # To support multiple model classes in one module
                 for tmp in entry:
                     model_arch_name_to_cls[tmp.__name__] = tmp
             else:
                 model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
     return model_arch_name_to_cls
 
 
 def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     model_arch_name_to_cls = import_model_classes()
+
     if model_arch not in model_arch_name_to_cls:
         raise ValueError(
             f"Unsupported architectures: {model_arch}. "
sglang/srt/managers/controller/radix_cache.py

@@ -1,3 +1,7 @@
+"""
+The radix tree data structure for managing the KV cache.
+"""
+
 import heapq
 import time
 from collections import defaultdict
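The new docstring summarizes what this file does: prefix matching over token sequences so that requests sharing a prompt prefix can reuse the corresponding KV cache. Purely as an illustration of that idea (not the sglang `RadixCache` API, and using a plain trie rather than a compressed radix tree to keep the sketch short):

```python
# Illustrative prefix-matching sketch: the longest cached prefix of a new
# request's token ids tells how many KV entries can be reused.
from typing import Dict


class TrieNode:
    def __init__(self) -> None:
        self.children: Dict[int, "TrieNode"] = {}


def insert(root: TrieNode, token_ids) -> None:
    node = root
    for t in token_ids:
        node = node.children.setdefault(t, TrieNode())


def match_prefix_len(root: TrieNode, token_ids) -> int:
    """Length of the longest cached prefix of token_ids."""
    node, n = root, 0
    for t in token_ids:
        if t not in node.children:
            break
        node, n = node.children[t], n + 1
    return n


root = TrieNode()
insert(root, [1, 2, 3, 4])                      # tokens of a finished request
print(match_prefix_len(root, [1, 2, 3, 9]))     # -> 3 tokens of KV can be reused
```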
sglang/srt/managers/controller/schedule_heuristic.py

@@ -1,3 +1,5 @@
+"""Request scheduler heuristic."""
+
 import random
 from collections import defaultdict