sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/model_runner.py
@@ -1,3 +1,5 @@
+"""ModelRunner runs the forward passes of the models."""
+
 import importlib
 import importlib.resources
 import logging
@@ -11,15 +13,19 @@ import torch
 import torch.nn as nn
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
-from vllm.distributed import initialize_model_parallel, init_distributed_environment
+from vllm.distributed import init_distributed_environment, initialize_model_parallel
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry

 from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
 from sglang.srt.server_args import ServerArgs
-from sglang.srt.utils import get_available_gpu_memory, is_multimodal_model, monkey_patch_vllm_p2p_access_check
-
+from sglang.srt.utils import (
+    get_available_gpu_memory,
+    is_multimodal_model,
+    monkey_patch_vllm_dummy_weight_loader,
+    monkey_patch_vllm_p2p_access_check,
+)

 logger = logging.getLogger("srt.model_runner")

@@ -29,7 +35,6 @@ global_server_args_dict = {}

 @dataclass
 class InputMetadata:
-    model_runner: "ModelRunner"
     forward_mode: ForwardMode
     batch_size: int
     total_num_tokens: int
@@ -60,73 +65,82 @@ class InputMetadata:
     kv_indptr: torch.Tensor = None
     kv_indices: torch.Tensor = None
     kv_last_page_len: torch.Tensor = None
-    prefill_wrapper = None
-    decode_wrapper = None
+    flashinfer_prefill_wrapper_ragged: "BatchPrefillWithRaggedKVCacheWrapper" = None
+    flashinfer_prefill_wrapper_paged: "BatchPrefillWithPagedKVCacheWrapper" = None
+    flashinfer_decode_wrapper: "BatchDecodeWithPagedKVCacheWrapper" = None

-    def init_flashinfer_args(self, tp_size):
-        from flashinfer import (
-            BatchDecodeWithPagedKVCacheWrapper,
-            BatchPrefillWithPagedKVCacheWrapper,
-        )
+    def init_flashinfer_args(self, num_qo_heads, num_kv_heads, head_dim):
+        if (
+            self.forward_mode == ForwardMode.PREFILL
+            or self.forward_mode == ForwardMode.EXTEND
+        ):
+            paged_kernel_lens = self.prefix_lens
+            self.no_prefix = torch.all(self.prefix_lens == 0)
+        else:
+            paged_kernel_lens = self.seq_lens

         self.kv_indptr = torch.zeros(
             (self.batch_size + 1,), dtype=torch.int32, device="cuda"
         )
-        self.kv_indptr[1:] = torch.cumsum(self.seq_lens, dim=0)
+        self.kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
         self.kv_last_page_len = torch.ones(
             (self.batch_size,), dtype=torch.int32, device="cuda"
         )
         req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
-        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
         self.kv_indices = torch.cat(
             [
                 self.req_to_token_pool.req_to_token[
-                    req_pool_indices_cpu[i], : seq_lens_cpu[i]
+                    req_pool_indices_cpu[i], : paged_kernel_lens_cpu[i]
                 ]
                 for i in range(self.batch_size)
             ],
             dim=0,
         ).contiguous()

-        workspace_buffer = torch.empty(
-            32 * 1024 * 1024, dtype=torch.int8, device="cuda"
-        )
         if (
             self.forward_mode == ForwardMode.PREFILL
             or self.forward_mode == ForwardMode.EXTEND
         ):
+            # extend part
             self.qo_indptr = torch.zeros(
                 (self.batch_size + 1,), dtype=torch.int32, device="cuda"
             )
             self.qo_indptr[1:] = torch.cumsum(self.extend_seq_lens, dim=0)
-            self.prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+
+            self.flashinfer_prefill_wrapper_ragged.end_forward()
+            self.flashinfer_prefill_wrapper_ragged.begin_forward(
+                self.qo_indptr,
+                self.qo_indptr.clone(),
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
             )
-            args = [
+
+            # cached part
+            self.flashinfer_prefill_wrapper_paged.end_forward()
+            self.flashinfer_prefill_wrapper_paged.begin_forward(
                 self.qo_indptr,
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
-            ]
-
-            self.prefill_wrapper.begin_forward(*args)
-        else:
-            self.decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                workspace_buffer, "NHD"
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
+                1,
             )
-            self.decode_wrapper.begin_forward(
+        else:
+            self.flashinfer_decode_wrapper.end_forward()
+            self.flashinfer_decode_wrapper.begin_forward(
                 self.kv_indptr,
                 self.kv_indices,
                 self.kv_last_page_len,
-                self.model_runner.model_config.num_attention_heads // tp_size,
-                self.model_runner.model_config.num_key_value_heads // tp_size,
-                self.model_runner.model_config.head_dim,
+                num_qo_heads,
+                num_kv_heads,
+                head_dim,
                 1,
-                "NONE",
-                "float16",
+                pos_encoding_mode="NONE",
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
            )

     def init_extend_args(self):
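The behavioral change in the hunk above is that `init_flashinfer_args` no longer imports flashinfer and builds fresh wrappers (plus a workspace buffer) for every batch; the wrappers are created once by `ModelRunner.init_flash_infer` (further down in this diff) and each batch only re-plans them with `end_forward()`/`begin_forward()`, with the per-rank head counts passed in explicitly. A minimal sketch of that reuse pattern for the decode path, assuming flashinfer and a CUDA device are available; the batch sizes, sequence lengths, and head counts below are made up for illustration:

```python
import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper

# One workspace buffer and one wrapper, created once at startup.
workspace = torch.empty(96 * 1024 * 1024, dtype=torch.uint8, device="cuda")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace, "NHD")

num_qo_heads, num_kv_heads, head_dim = 32, 32, 128  # assumed model shape

for seq_lens in ([3, 5], [4, 6, 2]):  # two consecutive decode batches
    batch_size = len(seq_lens)
    lens = torch.tensor(seq_lens, dtype=torch.int32, device="cuda")

    # Page size 1: one KV-cache slot per token, mirroring the token pool layout.
    kv_indptr = torch.zeros(batch_size + 1, dtype=torch.int32, device="cuda")
    kv_indptr[1:] = torch.cumsum(lens, dim=0)
    kv_indices = torch.arange(int(lens.sum()), dtype=torch.int32, device="cuda")
    kv_last_page_len = torch.ones(batch_size, dtype=torch.int32, device="cuda")

    decode_wrapper.end_forward()      # drop the previous batch's plan
    decode_wrapper.begin_forward(     # plan attention for the current batch
        kv_indptr,
        kv_indices,
        kv_last_page_len,
        num_qo_heads,
        num_kv_heads,
        head_dim,
        1,                            # page size
        pos_encoding_mode="NONE",
        data_type=torch.float16,
    )
    # ... the model's attention layers then call decode_wrapper.forward(...) ...
```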
@@ -150,6 +164,9 @@
         out_cache_cont_end=None,
         top_logprobs_nums=None,
         return_logprob=False,
+        flashinfer_prefill_wrapper_ragged=None,
+        flashinfer_prefill_wrapper_paged=None,
+        flashinfer_decode_wrapper=None,
     ):
         batch_size = len(req_pool_indices)
         start_loc = torch.zeros((batch_size,), dtype=torch.int32, device="cuda")
@@ -182,7 +199,6 @@
             other_kv_index = None

         ret = cls(
-            model_runner=model_runner,
             forward_mode=forward_mode,
             batch_size=batch_size,
             total_num_tokens=total_num_tokens,
@@ -200,13 +216,20 @@
             other_kv_index=other_kv_index,
             return_logprob=return_logprob,
             top_logprobs_nums=top_logprobs_nums,
+            flashinfer_prefill_wrapper_ragged=flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=flashinfer_decode_wrapper,
         )

         if forward_mode == ForwardMode.EXTEND:
             ret.init_extend_args()

-        if global_server_args_dict.get("enable_flashinfer", False):
-            ret.init_flashinfer_args(tp_size)
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            ret.init_flashinfer_args(
+                model_runner.model_config.num_attention_heads // tp_size,
+                model_runner.model_config.get_num_kv_heads(tp_size),
+                model_runner.model_config.head_dim,
+            )

         return ret

@@ -229,24 +252,27 @@ class ModelRunner:
         self.tp_size = tp_size
         self.nccl_port = nccl_port
         self.server_args = server_args
-
-        global global_server_args_dict
-        global_server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
+        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        monkey_patch_vllm_dummy_weight_loader()

         # Init torch distributed
         logger.info(f"[gpu_id={self.gpu_id}] Set cuda device.")
         torch.cuda.set_device(self.gpu_id)
         logger.info(f"[gpu_id={self.gpu_id}] Init nccl begin.")
-        monkey_patch_vllm_p2p_access_check()
+
+        if not server_args.enable_p2p_check:
+            monkey_patch_vllm_p2p_access_check(self.gpu_id)
+
+        if server_args.nccl_init_addr:
+            nccl_init_method = f"tcp://{server_args.nccl_init_addr}"
+        else:
+            nccl_init_method = f"tcp://127.0.0.1:{self.nccl_port}"
         init_distributed_environment(
             backend="nccl",
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=f"tcp://127.0.0.1:{self.nccl_port}",
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
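The constructor changes above make the NCCL rendezvous address configurable: when `server_args.nccl_init_addr` is set (multi-node tensor parallelism), it overrides the default loopback address. A tiny sketch of that selection logic in isolation; the host and port values are invented for illustration:

```python
from typing import Optional


def pick_nccl_init_method(nccl_init_addr: Optional[str], nccl_port: int) -> str:
    """Mirror of the selection above: prefer an explicit host:port, else loopback."""
    if nccl_init_addr:
        return f"tcp://{nccl_init_addr}"
    return f"tcp://127.0.0.1:{nccl_port}"


# Single-node default: every TP rank rendezvous on the local port.
assert pick_nccl_init_method(None, 28765) == "tcp://127.0.0.1:28765"
# Multi-node: all nodes point at the first node's address (values made up).
assert pick_nccl_init_method("10.0.0.1:50000", 28765) == "tcp://10.0.0.1:50000"
```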
@@ -260,9 +286,18 @@
                 "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
             )

+        # Set some global args
+        global global_server_args_dict
+        global_server_args_dict = {
+            "disable_flashinfer": server_args.disable_flashinfer,
+            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
+        }
+
+        # Load the model and create memory pool
         self.load_model()
         self.init_memory_pool(total_gpu_memory)
-        self.is_multimodal_model = is_multimodal_model(self.model_config)
+        self.init_cublas()
+        self.init_flash_infer()

     def load_model(self):
         logger.info(
@@ -278,10 +313,11 @@
             tokenizer=None,
             tokenizer_mode=None,
             trust_remote_code=self.server_args.trust_remote_code,
-            dtype=torch.float16,
+            dtype=self.server_args.dtype,
             seed=42,
             skip_tokenizer_init=True,
         )
+        self.dtype = vllm_model_config.dtype
         if self.model_config.model_overide_args is not None:
             vllm_model_config.hf_config.update(self.model_config.model_overide_args)

@@ -290,7 +326,7 @@
             device_config=device_config,
             load_config=load_config,
             lora_config=None,
-            vision_language_config=None,
+            multimodal_config=None,
             parallel_config=None,
             scheduler_config=None,
             cache_config=None,
@@ -298,6 +334,7 @@
         logger.info(
             f"[gpu_id={self.gpu_id}] Load weight end. "
             f"type={type(self.model).__name__}, "
+            f"dtype={self.dtype}, "
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

@@ -306,8 +343,14 @@
             self.gpu_id, distributed=self.tp_size > 1
         )
         head_dim = self.model_config.head_dim
-        head_num = self.model_config.num_key_value_heads // self.tp_size
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * 2
+        head_num = self.model_config.get_num_kv_heads(self.tp_size)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -319,7 +362,7 @@

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enought memory. Please try to increase --mem-fraction-static."
+                "Not enough memory. Please try to increase --mem-fraction-static."
             )

         self.req_to_token_pool = ReqToTokenPool(
@@ -328,8 +371,8 @@
         )
         self.token_to_kv_pool = TokenToKVPool(
             self.max_total_num_tokens,
-            dtype=torch.float16,
-            head_num=self.model_config.num_key_value_heads // self.tp_size,
+            dtype=self.dtype,
+            head_num=self.model_config.get_num_kv_heads(self.tp_size),
             head_dim=self.model_config.head_dim,
             layer_num=self.model_config.num_hidden_layers,
         )
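The memory-pool hunks above stop hard-coding float16 (the old `* 2 * 2`) and size the KV cache from the loaded model's dtype and from `get_num_kv_heads(tp_size)` (which, presumably, keeps at least one KV head per rank for grouped-query models instead of the old plain integer division). A worked example of the `cell_size` arithmetic with assumed 7B-class shapes and invented memory numbers; the real values come from `model_config` and `get_available_gpu_memory`, and the GB-to-bytes conversion below is only an assumption about how the budget is turned into a token count:

```python
import torch

# Assumed per-rank model shape (roughly a 7B model on one GPU, tp_size=1).
head_num = 32             # model_config.get_num_kv_heads(tp_size)
head_dim = 128
num_hidden_layers = 32
kv_dtype = torch.float16  # self.dtype, taken from the vLLM model config

# Bytes needed to hold K and V for a single token across all layers.
cell_size = (
    head_num
    * head_dim
    * num_hidden_layers
    * 2                                     # K and V
    * torch._utils._element_size(kv_dtype)  # 2 bytes for fp16
)
print(cell_size)  # 32 * 128 * 32 * 2 * 2 = 524288 bytes, i.e. 0.5 MiB per token

# Invented memory figures (GB) for illustration.
available_gpu_memory = 70.0
total_gpu_memory = 80.0
mem_fraction_static = 0.88

rest_memory = available_gpu_memory - total_gpu_memory * (1 - mem_fraction_static)
max_total_num_tokens = int(rest_memory * (1 << 30) // cell_size)
print(max_total_num_tokens)  # ~123,700 tokens of KV cache fit in the static pool
```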
@@ -338,6 +381,50 @@
             f"avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )

+    def init_cublas(self):
+        """We need to run a small matmul to init cublas. Otherwise, it will raise some errors later."""
+        dtype = torch.float16
+        device = "cuda"
+        a = torch.ones((16, 16), dtype=dtype, device=device)
+        b = torch.ones((16, 16), dtype=dtype, device=device)
+        c = a @ b
+        return c
+
+    def init_flash_infer(self):
+        if not global_server_args_dict.get("disable_flashinfer", False):
+            from flashinfer import (
+                BatchDecodeWithPagedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchPrefillWithRaggedKVCacheWrapper,
+            )
+            from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
+
+            if not _grouped_size_compiled_for_decode_kernels(
+                self.model_config.num_attention_heads // self.tp_size,
+                self.model_config.get_num_kv_heads(self.tp_size),
+            ):
+                use_tensor_cores = True
+            else:
+                use_tensor_cores = False
+
+            workspace_buffers = torch.empty(
+                2, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
+            )
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
+            )
+            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
+                workspace_buffers[1], "NHD"
+            )
+            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
+                workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
+            )
+        else:
+            self.flashinfer_prefill_wrapper_ragged = (
+                self.flashinfer_prefill_wrapper_paged
+            ) = None
+            self.flashinfer_decode_wrapper = None
+
     @torch.inference_mode()
     def forward_prefill(self, batch: Batch):
         input_metadata = InputMetadata.create(
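In `init_flash_infer` above, the decode wrapper is built with `use_tensor_cores=True` whenever flashinfer reports no precompiled decode kernel for the per-rank query/KV head combination. The quantity that matters is the grouped-query ratio after tensor-parallel sharding; a small illustration with assumed head counts (whether a given group size is precompiled is a flashinfer build detail and not asserted here):

```python
# Illustrative head counts for a GQA model (e.g. 64 query heads, 8 KV heads)
# sharded over 4 tensor-parallel ranks; the real values come from model_config.
num_attention_heads = 64
total_num_kv_heads = 8
tp_size = 4

num_qo_heads = num_attention_heads // tp_size          # 16 query heads per rank
num_kv_heads = max(1, total_num_kv_heads // tp_size)   # 2 KV heads per rank
group_size = num_qo_heads // num_kv_heads              # 8 query heads per KV head
print(num_qo_heads, num_kv_heads, group_size)

# init_flash_infer passes (num_qo_heads, num_kv_heads) to
# _grouped_size_compiled_for_decode_kernels; if that check fails, the decode
# wrapper is constructed with use_tensor_cores=True instead.
```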
@@ -351,6 +438,9 @@
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -369,6 +459,9 @@
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -389,6 +482,9 @@
             out_cache_cont_end=batch.out_cache_cont_end,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
@@ -407,6 +503,9 @@
             out_cache_loc=batch.out_cache_loc,
             top_logprobs_nums=batch.top_logprobs_nums,
             return_logprob=batch.return_logprob,
+            flashinfer_prefill_wrapper_ragged=self.flashinfer_prefill_wrapper_ragged,
+            flashinfer_prefill_wrapper_paged=self.flashinfer_prefill_wrapper_paged,
+            flashinfer_decode_wrapper=self.flashinfer_decode_wrapper,
         )
         return self.model.forward(
             batch.input_ids,
@@ -440,16 +539,29 @@ def import_model_classes():
             module = importlib.import_module(name)
             if hasattr(module, "EntryClass"):
                 entry = module.EntryClass
-                if isinstance(entry, list): # To support multiple model classes in one module
+                if isinstance(
+                    entry, list
+                ):  # To support multiple model classes in one module
                     for tmp in entry:
                         model_arch_name_to_cls[tmp.__name__] = tmp
                 else:
                     model_arch_name_to_cls[entry.__name__] = entry
+
+            # compat: some models such as chatglm has incorrect class set in config.json
+            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
+            if hasattr(module, "EntryClassRemapping") and isinstance(
+                module.EntryClassRemapping, list
+            ):
+                for remap in module.EntryClassRemapping:
+                    if isinstance(remap, tuple) and len(remap) == 2:
+                        model_arch_name_to_cls[remap[0]] = remap[1]
+
     return model_arch_name_to_cls


 def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
     model_arch_name_to_cls = import_model_classes()
+
     if model_arch not in model_arch_name_to_cls:
         raise ValueError(
             f"Unsupported architectures: {model_arch}. "
sglang/srt/managers/controller/radix_cache.py
@@ -1,3 +1,7 @@
+"""
+The radix tree data structure for managing the KV cache.
+"""
+
 import heapq
 import time
 from collections import defaultdict

sglang/srt/managers/controller/schedule_heuristic.py
@@ -1,3 +1,5 @@
+"""Request scheduler heuristic."""
+
 import random
 from collections import defaultdict