sglang 0.2.15__py3-none-any.whl → 0.3.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +10 -6
- sglang/bench_serving.py +33 -38
- sglang/global_config.py +0 -4
- sglang/lang/backend/runtime_endpoint.py +13 -6
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +3 -6
- sglang/launch_server_llavavid.py +7 -8
- sglang/srt/{model_config.py → configs/model_config.py} +5 -0
- sglang/srt/constrained/__init__.py +2 -0
- sglang/srt/constrained/fsm_cache.py +29 -38
- sglang/srt/constrained/jump_forward.py +0 -1
- sglang/srt/conversation.py +4 -1
- sglang/srt/hf_transformers_utils.py +2 -4
- sglang/srt/layers/attention_backend.py +480 -0
- sglang/srt/layers/flashinfer_utils.py +235 -0
- sglang/srt/layers/logits_processor.py +64 -77
- sglang/srt/layers/radix_attention.py +11 -161
- sglang/srt/layers/sampler.py +40 -35
- sglang/srt/layers/torchao_utils.py +75 -0
- sglang/srt/layers/{decode_attention.py → triton_attention/decode_attention.py} +67 -63
- sglang/srt/layers/{extend_attention.py → triton_attention/extend_attention.py} +40 -132
- sglang/srt/layers/{prefill_attention.py → triton_attention/prefill_attention.py} +13 -7
- sglang/srt/lora/lora.py +403 -0
- sglang/srt/lora/lora_config.py +43 -0
- sglang/srt/lora/lora_manager.py +256 -0
- sglang/srt/managers/controller_multi.py +1 -5
- sglang/srt/managers/controller_single.py +0 -5
- sglang/srt/managers/io_struct.py +16 -1
- sglang/srt/managers/policy_scheduler.py +122 -5
- sglang/srt/managers/schedule_batch.py +110 -74
- sglang/srt/managers/tokenizer_manager.py +24 -15
- sglang/srt/managers/tp_worker.py +181 -115
- sglang/srt/model_executor/cuda_graph_runner.py +60 -133
- sglang/srt/model_executor/forward_batch_info.py +35 -312
- sglang/srt/model_executor/model_runner.py +118 -141
- sglang/srt/models/baichuan.py +416 -0
- sglang/srt/models/chatglm.py +6 -8
- sglang/srt/models/commandr.py +1 -5
- sglang/srt/models/dbrx.py +1 -5
- sglang/srt/models/deepseek.py +1 -5
- sglang/srt/models/deepseek_v2.py +1 -5
- sglang/srt/models/exaone.py +8 -43
- sglang/srt/models/gemma.py +1 -5
- sglang/srt/models/gemma2.py +1 -5
- sglang/srt/models/gpt_bigcode.py +1 -5
- sglang/srt/models/grok.py +1 -5
- sglang/srt/models/internlm2.py +1 -5
- sglang/srt/models/{llama2.py → llama.py} +48 -26
- sglang/srt/models/llama_classification.py +14 -40
- sglang/srt/models/llama_embedding.py +7 -6
- sglang/srt/models/llava.py +38 -16
- sglang/srt/models/llavavid.py +7 -8
- sglang/srt/models/minicpm.py +1 -5
- sglang/srt/models/minicpm3.py +665 -0
- sglang/srt/models/mistral.py +2 -3
- sglang/srt/models/mixtral.py +6 -5
- sglang/srt/models/mixtral_quant.py +1 -5
- sglang/srt/models/qwen.py +1 -5
- sglang/srt/models/qwen2.py +1 -5
- sglang/srt/models/qwen2_moe.py +6 -5
- sglang/srt/models/stablelm.py +1 -5
- sglang/srt/models/xverse.py +375 -0
- sglang/srt/models/xverse_moe.py +445 -0
- sglang/srt/openai_api/adapter.py +65 -46
- sglang/srt/openai_api/protocol.py +11 -3
- sglang/srt/sampling/sampling_batch_info.py +67 -58
- sglang/srt/server.py +24 -14
- sglang/srt/server_args.py +130 -28
- sglang/srt/utils.py +12 -0
- sglang/test/few_shot_gsm8k.py +132 -0
- sglang/test/runners.py +114 -22
- sglang/test/test_programs.py +70 -0
- sglang/test/test_utils.py +89 -1
- sglang/utils.py +38 -4
- sglang/version.py +1 -1
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/METADATA +31 -18
- sglang-0.3.1.dist-info/RECORD +129 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/WHEEL +1 -1
- sglang-0.2.15.dist-info/RECORD +0 -118
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/LICENSE +0 -0
- {sglang-0.2.15.dist-info → sglang-0.3.1.dist-info}/top_level.txt +0 -0
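
The headline changes are visible from the file list alone: the flashinfer-specific plumbing moves out of `model_runner.py` into a new `sglang/srt/layers/attention_backend.py` (with the Triton kernels relocated under `layers/triton_attention/`), a new `sglang/srt/lora/` package adds LoRA batching, `layers/torchao_utils.py` adds torchao hooks, and `server_args.py` grows the corresponding options. The sketch below shows how those options surface on `ServerArgs`; it is a minimal illustration, assuming `ServerArgs` keeps a dataclass-style constructor and a `model_path` field. Only the attribute names that appear in the diff below are confirmed.

```python
from sglang.srt.server_args import ServerArgs

# Minimal sketch, not a definitive configuration. The attribute names
# (attention_backend, sampling_backend, torchao_config, lora_paths,
# max_running_requests, max_total_tokens) are the ones ModelRunner reads
# in this diff; the constructor style, the model_path field, and the
# concrete values are assumptions.
server_args = ServerArgs(
    model_path="/path/to/model",        # assumed field; placeholder value
    attention_backend="flashinfer",     # or "triton" (no sliding-window support)
    sampling_backend="pytorch",         # assumed value
    torchao_config="",                  # assumed default (torchao disabled)
    lora_paths=None,                    # a list of adapter paths enables the LoRAManager
    max_running_requests=None,
    max_total_tokens=None,
)
```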
sglang/srt/model_executor/model_runner.py
@@ -25,12 +25,6 @@ from typing import Optional, Tuple, Type
 
 import torch
 import torch.nn as nn
-from flashinfer import (
-    BatchDecodeWithPagedKVCacheWrapper,
-    BatchPrefillWithPagedKVCacheWrapper,
-    BatchPrefillWithRaggedKVCacheWrapper,
-)
-from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 from vllm.config import DeviceConfig, LoadConfig
 from vllm.config import ModelConfig as VllmModelConfig
 from vllm.distributed import (
@@ -43,17 +37,19 @@ from vllm.distributed.parallel_state import in_the_same_node_as
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import ModelRegistry
 
-from sglang.
+from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
+from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -69,6 +65,8 @@ logger = logging.getLogger(__name__)
 
 
 class ModelRunner:
+    """ModelRunner runs the forward passes of the models."""
+
     def __init__(
         self,
         model_config: ModelConfig,
@@ -92,13 +90,15 @@ class ModelRunner:
         )
         global_server_args_dict.update(
             {
-                "
-                "
+                "attention_backend": server_args.attention_backend,
+                "sampling_backend": server_args.sampling_backend,
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "enable_mla": server_args.enable_mla,
+                "torchao_config": server_args.torchao_config,
             }
         )
 
+        # Model-specific adjustment
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -106,15 +106,19 @@ class ModelRunner:
             server_args.chunked_prefill_size = None
             server_args.mem_fraction_static *= 0.95
 
+        # Init componnets
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
+        if server_args.lora_paths is not None:
+            self.init_lora_manager()
         self.init_memory_pool(
             min_per_gpu_memory,
-            server_args.
+            server_args.max_running_requests,
             server_args.max_total_tokens,
         )
         self.init_cublas()
-        self.
+        self.init_attention_backend()
         self.init_cuda_graphs()
 
     def init_torch_distributed(self):
@@ -162,6 +166,7 @@ class ModelRunner:
         return min_per_gpu_memory
 
     def load_model(self):
+        torch.set_num_threads(1)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.gpu_id):.2f} GB"
         )
@@ -312,6 +317,17 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights"
 
+    def init_lora_manager(self):
+        self.lora_manager = LoRAManager(
+            base_model=self.model,
+            lora_paths=self.server_args.lora_paths,
+            base_hf_config=self.model_config.hf_config,
+            max_loras_per_batch=self.server_args.max_loras_per_batch,
+            load_config=self.load_config,
+            dtype=self.dtype,
+        )
+        logger.info("LoRA manager ready.")
+
     def profile_max_num_token(self, total_gpu_memory: int):
         available_gpu_memory = get_available_gpu_memory(
             self.gpu_id, distributed=self.tp_size > 1
@@ -342,8 +358,8 @@ class ModelRunner:
     def init_memory_pool(
         self,
         total_gpu_memory: int,
-        max_num_reqs: int = None,
-        max_total_tokens: int = None,
+        max_num_reqs: Optional[int] = None,
+        max_total_tokens: Optional[int] = None,
     ):
         if self.server_args.kv_cache_dtype == "auto":
             self.kv_cache_dtype = self.dtype
@@ -377,7 +393,7 @@ class ModelRunner:
                 ),
                 2048,
             ),
-
+            4096,
         )
 
         self.req_to_token_pool = ReqToTokenPool(
@@ -395,9 +411,6 @@ class ModelRunner:
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
                 layer_num=self.model_config.num_hidden_layers,
             )
-            logger.info("using MLA Triton implementaion, flashinfer is disabled")
-            # FIXME: temporarily only Triton MLA is supported
-            self.server_args.disable_flashinfer = True
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -420,118 +433,46 @@ class ModelRunner:
         c = a @ b
         return c
 
-    def
-        """Init
-        if self.server_args.
-
-
-
-
-
-            self.flashinfer_decode_wrapper = None
-            return
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            self.model_config.num_attention_heads // self.tp_size,
-            self.model_config.get_num_kv_heads(self.tp_size),
-        ):
-            use_tensor_cores = True
-        else:
-            use_tensor_cores = False
-
-        if self.sliding_window_size is None:
-            self.flashinfer_workspace_buffer = torch.empty(
-                global_config.flashinfer_workspace_size,
-                dtype=torch.uint8,
-                device="cuda",
-            )
-            self.flashinfer_prefill_wrapper_ragged = (
-                BatchPrefillWithRaggedKVCacheWrapper(
-                    self.flashinfer_workspace_buffer, "NHD"
-                )
-            )
-            self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer, "NHD"
-            )
-            self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
-                self.flashinfer_workspace_buffer,
-                "NHD",
-                use_tensor_cores=use_tensor_cores,
+    def init_attention_backend(self):
+        """Init attention kernel backend."""
+        if self.server_args.attention_backend == "flashinfer":
+            self.attn_backend = FlashInferAttnBackend(self)
+        elif self.server_args.attention_backend == "triton":
+            assert self.sliding_window_size is None, (
+                "Window attention is not supported in the triton attention backend. "
+                "Please use `--attention-backend flashinfer`."
             )
+            self.attn_backend = TritonAttnBackend(self)
         else:
-
-
-                dtype=torch.uint8,
-                device="cuda",
+            raise ValueError(
+                f"Invalid attention backend: {self.server_args.attention_backend}"
             )
-            self.flashinfer_prefill_wrapper_ragged = None
-            self.flashinfer_prefill_wrapper_paged = []
-            self.flashinfer_decode_wrapper = []
-            for i in range(2):
-                self.flashinfer_prefill_wrapper_paged.append(
-                    BatchPrefillWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer, "NHD"
-                    )
-                )
-                self.flashinfer_decode_wrapper.append(
-                    BatchDecodeWithPagedKVCacheWrapper(
-                        self.flashinfer_workspace_buffer,
-                        "NHD",
-                        use_tensor_cores=use_tensor_cores,
-                    )
-                )
 
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
+        from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
+
+        self.cuda_graph_runner = None
+
         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return
 
-
-
-        if self.server_args.disable_cuda_graph or self.server_args.disable_flashinfer:
-            self.cuda_graph_runner = None
+        if self.server_args.disable_cuda_graph:
             return
 
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
-
-        if self.server_args.disable_cuda_graph_padding:
-            batch_size_list = list(range(1, 32)) + [64, 128]
-        else:
-            batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 21)]
-
-        self.cuda_graph_runner = CudaGraphRunner(
-            self,
-            max_batch_size_to_capture=max(batch_size_list),
-            use_torch_compile=self.server_args.enable_torch_compile,
-            disable_padding=self.server_args.disable_cuda_graph_padding,
-        )
-        try:
-            self.cuda_graph_runner.capture(batch_size_list)
-        except RuntimeError as e:
-            raise Exception(
-                f"Capture cuda graph failed: {e}\n"
-                "Possible solutions:\n"
-                "1. disable cuda graph by --disable-cuda-graph\n"
-                "2. set --mem-fraction-static to a smaller value\n"
-                "3. disable torch compile by not using --enable-torch-compile\n"
-                "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
-            )
+        self.cuda_graph_runner = CudaGraphRunner(self)
 
     @torch.inference_mode()
     def forward_decode(self, batch: ScheduleBatch):
-        if
-            self.
-
-
-        ):
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch)
+
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)
 
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            ForwardMode.DECODE,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
 
         return self.model.forward(
             batch.input_ids, input_metadata.positions, input_metadata
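
The effect of this hunk is to replace the inline flashinfer wrapper setup with a strategy-style backend object chosen once at startup. A self-contained sketch of that selection pattern follows; the real `FlashInferAttnBackend` and `TritonAttnBackend` classes live in the new `sglang/srt/layers/attention_backend.py`, and the stub classes here are only illustrative stand-ins.

```python
# Illustrative stand-ins for the real backends in
# sglang/srt/layers/attention_backend.py; they only show the dispatch shape.
class _FlashInferStub:
    def __init__(self, runner):  # the real backends receive the ModelRunner
        self.runner = runner


class _TritonStub:
    def __init__(self, runner):
        self.runner = runner


def pick_attention_backend(attention_backend: str, sliding_window_size, runner=None):
    """Mirrors the dispatch in ModelRunner.init_attention_backend above."""
    if attention_backend == "flashinfer":
        return _FlashInferStub(runner)
    elif attention_backend == "triton":
        # Window attention is not supported by the triton backend in 0.3.1.
        assert sliding_window_size is None, "use --attention-backend flashinfer"
        return _TritonStub(runner)
    raise ValueError(f"Invalid attention backend: {attention_backend}")


backend = pick_attention_backend("triton", sliding_window_size=None)
print(type(backend).__name__)  # _TritonStub
```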
@@ -539,11 +480,10 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-
-            batch,
-
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
+        if self.server_args.lora_paths is not None:
+            self.lora_manager.prepare_lora_batch(batch, input_metadata.extend_seq_lens)
+
         if self.is_generation:
             return self.model.forward(
                 batch.input_ids, input_metadata.positions, input_metadata
@@ -559,11 +499,7 @@ class ModelRunner:
 
     @torch.inference_mode()
     def forward_extend_multi_modal(self, batch: ScheduleBatch):
-        input_metadata = InputMetadata.from_schedule_batch(
-            self,
-            batch,
-            forward_mode=ForwardMode.EXTEND,
-        )
+        input_metadata = InputMetadata.from_schedule_batch(self, batch)
         return self.model.forward(
             batch.input_ids,
             input_metadata.positions,
@@ -573,17 +509,68 @@ class ModelRunner:
             input_metadata.image_offsets,
         )
 
-    def forward(
-
-
-        if self.is_multimodal_model and forward_mode
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
+        assert batch.forward_mode is not None
+
+        if self.is_multimodal_model and batch.forward_mode.is_extend():
             return self.forward_extend_multi_modal(batch)
-        elif forward_mode
+        elif batch.forward_mode.is_decode():
             return self.forward_decode(batch)
-        elif forward_mode
+        elif batch.forward_mode.is_extend():
             return self.forward_extend(batch)
         else:
-            raise ValueError(f"Invaid forward mode: {forward_mode}")
+            raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
+
+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)
 
 
 @lru_cache()
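
The new `sample()` path applies biases and penalties to the logits before handing them to `Sampler`. One detail worth calling out is the repetition penalty: positive logits are divided by the penalty and negative logits are multiplied by it, so a penalty greater than 1 always pushes a penalized token toward being less likely. A tiny self-contained demonstration of that `torch.where` rule (toy numbers, not sglang code):

```python
import torch

# Toy logits for a 4-token vocabulary and a per-token repetition penalty of 1.2
# (the role played by SamplingBatchInfo.scaling_penalties for already-seen tokens).
logits = torch.tensor([2.0, -1.0, 0.5, -3.0])
scaling_penalties = torch.tensor([1.2, 1.2, 1.0, 1.2])  # 1.0 = token not penalized

penalized = torch.where(
    logits > 0,
    logits / scaling_penalties,   # positive logits shrink toward 0
    logits * scaling_penalties,   # negative logits move further below 0
)
print(penalized)  # tensor([ 1.6667, -1.2000,  0.5000, -3.6000])
```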
@@ -606,16 +593,6 @@ def import_model_classes():
                     assert entry.__name__ not in model_arch_name_to_cls
                     model_arch_name_to_cls[entry.__name__] = entry
 
-            # compat: some models such as chatglm has incorrect class set in config.json
-            # usage: [ tuple("From_Entry_Class_Name": EntryClass), ]
-            if hasattr(module, "EntryClassRemapping") and isinstance(
-                module.EntryClassRemapping, list
-            ):
-                for remap in module.EntryClassRemapping:
-                    if isinstance(remap, tuple) and len(remap) == 2:
-                        assert remap[0] not in model_arch_name_to_cls
-                        model_arch_name_to_cls[remap[0]] = remap[1]
-
     return model_arch_name_to_cls
 
 