evalscope 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- evalscope/arguments.py +2 -1
- evalscope/backend/rag_eval/__init__.py +1 -1
- evalscope/backend/rag_eval/backend_manager.py +21 -5
- evalscope/backend/rag_eval/cmteb/arguments.py +10 -0
- evalscope/backend/rag_eval/ragas/arguments.py +0 -1
- evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +7 -2
- evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +0 -5
- evalscope/backend/rag_eval/utils/embedding.py +49 -3
- evalscope/backend/rag_eval/utils/llm.py +4 -4
- evalscope/backend/vlm_eval_kit/backend_manager.py +4 -2
- evalscope/benchmarks/__init__.py +2 -2
- evalscope/benchmarks/aigc/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
- evalscope/benchmarks/aigc/t2i/base.py +56 -0
- evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
- evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
- evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
- evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
- evalscope/benchmarks/aime/aime24_adapter.py +1 -1
- evalscope/benchmarks/aime/aime25_adapter.py +4 -4
- evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
- evalscope/benchmarks/arc/arc_adapter.py +2 -2
- evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
- evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
- evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
- evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
- evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
- evalscope/benchmarks/data_adapter.py +21 -10
- evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
- evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
- evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
- evalscope/benchmarks/hellaswag/hellaswag_adapter.py +1 -1
- evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
- evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +5 -4
- evalscope/benchmarks/live_code_bench/testing_util.py +369 -550
- evalscope/benchmarks/maritime_bench/__init__.py +0 -0
- evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +79 -0
- evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
- evalscope/benchmarks/mmlu/mmlu_adapter.py +8 -8
- evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
- evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +1 -1
- evalscope/benchmarks/musr/musr_adapter.py +1 -1
- evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
- evalscope/benchmarks/utils.py +7 -16
- evalscope/cli/start_app.py +1 -1
- evalscope/collections/evaluator.py +20 -6
- evalscope/config.py +8 -4
- evalscope/constants.py +11 -0
- evalscope/evaluator/evaluator.py +2 -2
- evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
- evalscope/metrics/__init__.py +49 -4
- evalscope/metrics/llm_judge.py +1 -1
- evalscope/metrics/named_metrics.py +13 -0
- evalscope/metrics/t2v_metrics/__init__.py +66 -0
- evalscope/metrics/t2v_metrics/clipscore.py +14 -0
- evalscope/metrics/t2v_metrics/constants.py +12 -0
- evalscope/metrics/t2v_metrics/itmscore.py +14 -0
- evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
- evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
- evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
- evalscope/metrics/t2v_metrics/models/model.py +45 -0
- evalscope/metrics/t2v_metrics/models/utils.py +25 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
- evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
- evalscope/metrics/t2v_metrics/score.py +78 -0
- evalscope/metrics/t2v_metrics/vqascore.py +14 -0
- evalscope/models/__init__.py +50 -14
- evalscope/models/adapters/__init__.py +17 -0
- evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
- evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
- evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
- evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
- evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
- evalscope/models/adapters/t2i_adapter.py +76 -0
- evalscope/models/custom/__init__.py +2 -1
- evalscope/models/custom/dummy_model.py +11 -13
- evalscope/models/local_model.py +82 -33
- evalscope/models/model.py +2 -42
- evalscope/models/register.py +26 -0
- evalscope/perf/arguments.py +24 -5
- evalscope/perf/benchmark.py +28 -42
- evalscope/perf/http_client.py +2 -3
- evalscope/perf/plugin/api/custom_api.py +1 -1
- evalscope/perf/plugin/api/openai_api.py +2 -2
- evalscope/perf/plugin/datasets/custom.py +4 -1
- evalscope/perf/plugin/datasets/flickr8k.py +2 -1
- evalscope/perf/plugin/datasets/line_by_line.py +4 -1
- evalscope/perf/plugin/datasets/longalpaca.py +4 -1
- evalscope/perf/plugin/datasets/openqa.py +4 -1
- evalscope/perf/plugin/datasets/random_dataset.py +13 -6
- evalscope/perf/utils/benchmark_util.py +14 -8
- evalscope/perf/utils/db_util.py +9 -3
- evalscope/perf/utils/log_utils.py +41 -0
- evalscope/report/__init__.py +1 -0
- evalscope/report/app.py +128 -78
- evalscope/report/app_arguments.py +11 -0
- evalscope/report/generator.py +1 -1
- evalscope/run.py +10 -3
- evalscope/summarizer.py +2 -1
- evalscope/third_party/thinkbench/eval.py +19 -7
- evalscope/utils/chat_service.py +2 -2
- evalscope/utils/import_utils.py +66 -0
- evalscope/utils/utils.py +48 -29
- evalscope/version.py +2 -2
- {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/METADATA +37 -15
- {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/RECORD +209 -96
- tests/aigc/__init__.py +1 -0
- tests/aigc/test_t2i.py +87 -0
- tests/cli/test_all.py +4 -4
- tests/cli/test_collection.py +2 -1
- tests/cli/test_run.py +19 -12
- tests/perf/test_perf.py +3 -3
- tests/rag/test_clip_benchmark.py +0 -1
- tests/rag/test_mteb.py +37 -8
- tests/rag/test_ragas.py +29 -26
- tests/vlm/test_vlmeval.py +37 -1
- evalscope/backend/vlm_eval_kit/custom_dataset.py +0 -46
- evalscope/benchmarks/live_code_bench/execute_utils.py +0 -267
- evalscope/metrics/code_metric.py +0 -98
- evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
- evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
- {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
- {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
- {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
- {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
evalscope/models/local_model.py
CHANGED
@@ -1,7 +1,8 @@
-import torch
+import importlib
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Optional
 
-from evalscope.constants import DEFAULT_MODEL_CACHE_DIR, DEFAULT_MODEL_REVISION, EvalType
+from evalscope.constants import DEFAULT_MODEL_CACHE_DIR, DEFAULT_MODEL_REVISION, EvalType, ModelTask
 from evalscope.utils.logger import get_logger
 from evalscope.utils.model_utils import get_device
 
@@ -11,31 +12,55 @@ if TYPE_CHECKING:
 logger = get_logger()
 
 
-class LocalModel:
+class LocalModel(ABC):
 
     def __init__(self,
                  model_id: str,
-                 model_revision: str =
-                 device_map: str =
+                 model_revision: str = None,
+                 device_map: str = None,
                  torch_dtype: str = 'auto',
                  cache_dir: str = None,
                  **kwargs):
-        from modelscope import AutoModelForCausalLM, AutoTokenizer
 
-
+        self.model_id = model_id
+        self.model_revision = model_revision or DEFAULT_MODEL_REVISION
+        self.device = device_map or get_device()
+        self.cache_dir = cache_dir or DEFAULT_MODEL_CACHE_DIR
+        self.kwargs = kwargs
+        self.model = None
+        self.tokenizer = None
 
         if isinstance(torch_dtype, str) and torch_dtype != 'auto':
+            import torch
             torch_dtype = eval(torch_dtype)
+        self.torch_dtype = torch_dtype
+
+        self.model_cfg = {
+            'model_id': self.model_id,
+            'device_map': self.device,
+            'torch_dtype': str(self.torch_dtype),
+        }
+
+    @abstractmethod
+    def load_model(self):
+        pass
 
-
-
-
+
+class LocalChatModel(LocalModel):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def load_model(self):
+        from modelscope import AutoModelForCausalLM, AutoTokenizer
+
+        logger.info(f'Loading model {self.model_id} ...')
 
         self.tokenizer = AutoTokenizer.from_pretrained(
             self.model_id,
-            revision=model_revision,
+            revision=self.model_revision,
             trust_remote_code=True,
-            cache_dir=
+            cache_dir=self.cache_dir,
         )
 
         # Fix no padding
@@ -44,18 +69,45 @@ class LocalModel:
 
         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_id,
-            revision=model_revision,
-            device_map=
+            revision=self.model_revision,
+            device_map=self.device,
             trust_remote_code=True,
-            torch_dtype=torch_dtype,
-            cache_dir=
+            torch_dtype=self.torch_dtype,
+            cache_dir=self.cache_dir,
         )
 
-
-
-
-
-
+
+class LocalImageModel(LocalModel):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.pipeline_cls = kwargs.pop('pipeline_cls', None)
+        # default to DiffusionPipeline if not specified
+        if self.pipeline_cls is None:
+            if 'flux' in self.model_id.lower():
+                self.pipeline_cls = 'FluxPipeline'
+            else:
+                self.pipeline_cls = 'DiffusionPipeline'
+
+    def load_model(self):
+        # from modelscope import pipeline_cls
+        module = getattr(importlib.import_module('modelscope'), self.pipeline_cls)
+
+        logger.info(f'Loading model {self.model_id} with {self.pipeline_cls} ...')
+
+        self.model = module.from_pretrained(
+            self.model_id,
+            revision=self.model_revision,
+            torch_dtype=self.torch_dtype,
+            cache_dir=self.cache_dir,
+            **self.kwargs,
+        )
+
+        self.model.to(self.device)
+
+    def __call__(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
 
 
 def get_local_model(task_cfg: 'TaskConfig') -> Optional[LocalModel]:
@@ -64,16 +116,13 @@ def get_local_model(task_cfg: 'TaskConfig') -> Optional[LocalModel]:
     """
     if task_cfg.eval_type != EvalType.CHECKPOINT:
         return None
-
-
-
-
-
-    base_model
-        model_id=task_cfg.model,
-        model_revision=model_revision,
-        device_map=device_map,
-        torch_dtype=model_precision,
-        cache_dir=cache_dir)
+    elif task_cfg.model_task == ModelTask.TEXT_GENERATION:
+        base_model = LocalChatModel(model_id=task_cfg.model, **task_cfg.model_args)
+        base_model.load_model()
+        return base_model
+    elif task_cfg.model_task == ModelTask.IMAGE_GENERATION:
+        base_model = LocalImageModel(model_id=task_cfg.model, **task_cfg.model_args)
+        base_model.load_model()
         return base_model
+    else:
+        raise ValueError(f'Unsupported model task: {task_cfg.model_task} for model checkpoint.')
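The net effect of this rewrite is that loading becomes lazy and task-specific: LocalModel now only stores configuration, the concrete LocalChatModel / LocalImageModel subclasses pull weights in load_model(), and get_local_model() dispatches on task_cfg.model_task. A minimal usage sketch under those assumptions (the model IDs below are placeholders, not part of this release):

from evalscope.models.local_model import LocalChatModel, LocalImageModel

# Text generation: construction is cheap, weights are only fetched by load_model().
chat = LocalChatModel(model_id='Qwen/Qwen2.5-0.5B-Instruct', torch_dtype='auto')
chat.load_model()                 # fills chat.model / chat.tokenizer via modelscope
print(chat.model_cfg)             # {'model_id': ..., 'device_map': ..., 'torch_dtype': ...}

# Image generation: the pipeline class is inferred from the model id
# ('flux' -> FluxPipeline, otherwise DiffusionPipeline) unless pipeline_cls is passed.
t2i = LocalImageModel(model_id='AI-ModelScope/stable-diffusion-v1-5')  # placeholder id
t2i.load_model()
result = t2i(prompt='a watercolor fox', num_inference_steps=20)  # __call__ delegates to the pipeline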
evalscope/models/model.py
CHANGED
@@ -1,9 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-import random
 import time
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, List
 
 from evalscope.utils.logger import get_logger
 
@@ -95,6 +94,7 @@ class ChatBaseModel(BaseModel):
         raise NotImplementedError
 
 
+# TODO: Remove this class after refactoring all models
 class OpenAIModel(ChatBaseModel):
     """
     APIs of OpenAI models.
@@ -187,43 +187,3 @@ class OpenAIModel(ChatBaseModel):
                 time.sleep(3)
         logger.error(f'OpenAI API call failed after {self.MAX_RETRIES} retries')
         return res
-
-
-class DummyChatModel(ChatBaseModel):
-
-    MODEL_ID = 'dummy_chat_model_0801'
-    REVISION = 'v1.0.0'
-
-    def __init__(self, model_cfg: dict, **kwargs):
-        model_cfg['model_id'] = self.MODEL_ID
-        model_cfg['revision'] = self.REVISION
-        super(DummyChatModel, self).__init__(model_cfg=model_cfg)
-
-    def predict(self, inputs: dict, **kwargs) -> dict:
-
-        debug: bool = False
-        if debug:
-            messages = inputs['messages']
-            history = inputs['history']
-
-            logger.info(f'** messages: {messages}')
-            logger.info(f'** history: {history}')
-
-        choice = random.choice(['A', 'B', 'C', 'D'])
-
-        # Build response
-        res = {
-            'choices': [{
-                'index': 0,
-                'message': {
-                    'content': choice,
-                    'role': 'assistant'
-                }
-            }],
-            'created': time.time(),
-            'model': self.MODEL_ID + '-' + self.REVISION,
-            'object': 'chat.completion',
-            'usage': {}
-        }
-
-        return res
evalscope/models/register.py
CHANGED
@@ -1,3 +1,6 @@
+from evalscope.constants import OutputType
+from .adapters import *
+
 MODEL_ADAPTERS = {}
 
 
@@ -26,3 +29,26 @@ def get_model_adapter(name):
         raise ValueError(
             f"Model adapter '{name}' is not registered. Available model adapters: {list(MODEL_ADAPTERS.keys())}")
     return MODEL_ADAPTERS[name]
+
+
+def register_model_adapter_class(cls, name=None):
+    """
+    Register a model adapter class.
+    :param cls: The model adapter class to register
+    :param name: Optional name for the model adapter. If not provided, the class name will be used.
+    """
+    if name is None:
+        name = cls.__name__
+    if name in MODEL_ADAPTERS:
+        raise ValueError(f"Model adapter class '{name}' is already registered.")
+    MODEL_ADAPTERS[name] = cls
+
+
+# register all model adapters
+register_model_adapter_class(BaseModelAdapter, name='base')
+register_model_adapter_class(ChatGenerationModelAdapter, name=OutputType.GENERATION)
+register_model_adapter_class(ContinuationLogitsModelAdapter, name=OutputType.LOGITS)
+register_model_adapter_class(MultiChoiceModelAdapter, name=OutputType.MULTIPLE_CHOICE)
+register_model_adapter_class(CustomModelAdapter, name='custom')
+register_model_adapter_class(ServerModelAdapter, name='server')
+register_model_adapter_class(T2IModelAdapter, name=OutputType.IMAGE_GENERATION)
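With this change the adapter registry is filled at import time, and the same helper is available to user code. A minimal sketch of registering and looking up a custom adapter (EchoAdapter is a hypothetical class used only for illustration; the exact BaseModelAdapter interface lives in evalscope/models/adapters and is not shown in this diff):

from evalscope.models.register import MODEL_ADAPTERS, get_model_adapter, register_model_adapter_class
from evalscope.models.adapters import BaseModelAdapter


class EchoAdapter(BaseModelAdapter):
    """Hypothetical adapter used only to illustrate registration."""

    def predict(self, inputs, **kwargs):
        return inputs


register_model_adapter_class(EchoAdapter, name='echo')   # re-using a registered name raises ValueError

adapter_cls = get_model_adapter('echo')                  # -> EchoAdapter
print(sorted(MODEL_ADAPTERS))                            # includes 'base', 'custom', 'echo', 'server', ...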
evalscope/perf/arguments.py
CHANGED
@@ -35,6 +35,7 @@ class Arguments:
     log_every_n_query: int = 10  # Log every N queries
     debug: bool = False  # Debug mode
     wandb_api_key: Optional[str] = None  # WandB API key for logging
+    swanlab_api_key: Optional[str] = None  # SwanLab API key for logging
     name: Optional[str] = None  # Name for the run
 
     # Output settings
@@ -46,6 +47,7 @@ class Arguments:
     prefix_length: int = 0  # Length of the prefix, only for random dataset
     prompt: Optional[str] = None  # The prompt text
     query_template: Optional[str] = None  # Template for the query
+    apply_chat_template: Optional[bool] = None  # Whether to apply chat template
 
     # Dataset settings
     dataset: str = 'openqa'  # Dataset type (default: 'line_by_line')
@@ -57,10 +59,10 @@ class Arguments:
     max_tokens: Optional[int] = 2048  # Maximum number of tokens in the response
     min_tokens: Optional[int] = None  # Minimum number of tokens in the response
     n_choices: Optional[int] = None  # Number of response choices
-    seed: Optional[int] =
+    seed: Optional[int] = 0  # Random seed for reproducibility
     stop: Optional[List[str]] = field(default_factory=list)  # Stop sequences for the response
     stop_token_ids: Optional[List[str]] = field(default_factory=list)  # Stop token IDs for the response
-    stream: Optional[bool] =
+    stream: Optional[bool] = True  # Whether to stream the response
     temperature: float = 0.0  # Temperature setting for the response
     top_p: Optional[float] = None  # Top-p (nucleus) sampling setting for the response
     top_k: Optional[int] = None  # Top-k sampling setting for the response
@@ -76,12 +78,26 @@ class Arguments:
         return Arguments(**args_dict)
 
     def __post_init__(self):
+        # Set the default headers
         self.headers = self.headers or {}  # Default to empty dictionary
         if self.api_key:
            # Assuming the API key is used as a Bearer token
            self.headers['Authorization'] = f'Bearer {self.api_key}'
+
+        # Set the model ID based on the model name
         self.model_id = os.path.basename(self.model)
 
+        # Set the URL based on the dataset type
+        if self.api.startswith('local'):
+            if self.dataset.startswith('speed_benchmark'):
+                self.url = f'http://127.0.0.1:{self.port}/v1/completions'
+            else:
+                self.url = f'http://127.0.0.1:{self.port}/v1/chat/completions'
+
+        # Set the apply_chat_template flag based on the URL
+        if self.apply_chat_template is None:
+            self.apply_chat_template = self.url.strip('/').endswith('chat/completions')
+
     def __str__(self):
         return json.dumps(self.to_dict(), indent=4, default=str, ensure_ascii=False)
 
@@ -135,7 +151,8 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('--log-every-n-query', type=int, default=10, help='Logging every n query')
     parser.add_argument('--debug', action='store_true', default=False, help='Debug request send')
     parser.add_argument('--wandb-api-key', type=str, default=None, help='The wandb API key')
-    parser.add_argument('--
+    parser.add_argument('--swanlab-api-key', type=str, default=None, help='The swanlab API key')
+    parser.add_argument('--name', type=str, help='The wandb/swanlab db result name and result db name')
 
     # Prompt settings
     parser.add_argument('--max-prompt-length', type=int, default=sys.maxsize, help='Maximum input prompt length')
@@ -143,6 +160,8 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('--prefix-length', type=int, default=0, help='The prefix length')
     parser.add_argument('--prompt', type=str, required=False, default=None, help='Specified the request prompt')
     parser.add_argument('--query-template', type=str, default=None, help='Specify the query template')
+    parser.add_argument(
+        '--apply-chat-template', type=argparse.BooleanOptionalAction, default=None, help='Apply chat template to the prompt')  # noqa: E501
 
     # Output settings
     parser.add_argument('--outputs-dir', help='Outputs dir.', default='outputs')
@@ -159,10 +178,10 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument(
         '--min-tokens', type=int, help='The minimum number of tokens that can be generated', default=None)
     parser.add_argument('--n-choices', type=int, help='How many completion choices to generate', default=None)
-    parser.add_argument('--seed', type=int, help='The random seed', default=
+    parser.add_argument('--seed', type=int, help='The random seed', default=0)
    parser.add_argument('--stop', nargs='*', help='The stop tokens', default=None)
     parser.add_argument('--stop-token-ids', nargs='*', help='Set the stop token IDs', default=None)
-    parser.add_argument('--stream', action=
+    parser.add_argument('--stream', action=argparse.BooleanOptionalAction, help='Stream output with SSE', default=True)
     parser.add_argument('--temperature', type=float, help='The sample temperature', default=0.0)
     parser.add_argument('--top-p', type=float, help='Sampling top p', default=None)
     parser.add_argument('--top-k', type=int, help='Sampling top k', default=None)
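The behavioural consequences of the new fields are easiest to see through __post_init__: --stream and a fixed --seed of 0 are now on by default, --swanlab-api-key mirrors the existing wandb option, and --apply-chat-template is a tri-state flag that, when left unset, is inferred from whether the target URL ends in chat/completions. A small sketch of that inference, assuming the dataclass fields shown above (the model name and endpoint are placeholders):

from evalscope.perf.arguments import Arguments

args = Arguments(
    model='qwen2.5-7b-instruct',                      # placeholder model name
    url='http://127.0.0.1:8000/v1/chat/completions',  # placeholder endpoint
    api='openai',
    dataset='openqa',
)

# apply_chat_template was left as None, so __post_init__ derives it from the URL suffix.
assert args.apply_chat_template is True
assert args.stream is True and args.seed == 0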
evalscope/perf/benchmark.py
CHANGED
@@ -18,6 +18,7 @@ from evalscope.perf.utils.benchmark_util import BenchmarkData, BenchmarkMetrics
 from evalscope.perf.utils.db_util import create_result_table, get_result_db_path, insert_benchmark_data, summary_result
 from evalscope.perf.utils.handler import add_signal_handlers, exception_handler
 from evalscope.perf.utils.local_server import start_app
+from evalscope.perf.utils.log_utils import init_swanlab, init_wandb
 from evalscope.utils.logger import get_logger
 
 logger = get_logger()
@@ -56,7 +57,7 @@ async def get_requests(args: Arguments) -> AsyncGenerator[dict, None]:
 
     if args.prompt:
         prompt = load_prompt(args.prompt)
-        messages = [{'role': 'user', 'content': prompt}]
+        messages = [{'role': 'user', 'content': prompt}] if args.apply_chat_template else prompt
         generator = generate_requests_from_prompt(messages)
     elif args.dataset:
         generator = generate_requests_from_dataset()
@@ -81,6 +82,7 @@ async def send_request(
     client = AioHttpClient(args)
     async with client:
         benchmark_data = BenchmarkData(request=request)
+        benchmark_data.start_time = time.perf_counter()
         collected_messages = []
         try:
             async for is_error, state_code, response_data in client.post(request):
@@ -106,24 +108,18 @@
 
 
 @exception_handler
-async def
+async def statistic_benchmark_metric(benchmark_data_queue: asyncio.Queue, args: Arguments):
     metrics = BenchmarkMetrics(concurrency=args.parallel)
 
     api_plugin_class = ApiRegistry(args.api)
     api_plugin = api_plugin_class(args.tokenizer_path)
 
     result_db_path = get_result_db_path(args)
-    # Initialize wandb
-    if args.wandb_api_key:
-        import datetime
-        import wandb
-        os.environ['WANDB_SILENT'] = 'true'
-        os.environ['WANDB_DIR'] = args.outputs_dir
 
-
-
-
-
+    if args.wandb_api_key:
+        init_wandb(args)
+    if args.swanlab_api_key:
+        init_swanlab(args)
 
     collected_benchmark_data = []
 
@@ -146,9 +142,13 @@ async def statistic_benchmark_metric_worker(benchmark_data_queue: asyncio.Queue,
             # Create a message with the updated metrics
             message = metrics.create_message()
 
-            # Log the message to wandb if the api key is provided
+            # Log the message to wandb\swanlab if the api key is provided
             if args.wandb_api_key:
+                import wandb
                 wandb.log(message)
+            if args.swanlab_api_key:
+                import swanlab
+                swanlab.log(message)
 
             # Log the message to the logger every n queries
             if int(metrics.n_total_queries) % args.log_every_n_query == 0:
@@ -169,17 +169,12 @@
 
 
 @exception_handler
-async def
+async def connect_test(args: Arguments) -> bool:
     if args.api.startswith('local'):
         # start local server
         server = threading.Thread(target=start_app, args=(copy.deepcopy(args), ), daemon=True)
         server.start()
 
-        if args.dataset.startswith('speed_benchmark'):
-            args.url = f'http://127.0.0.1:{args.port}/v1/completions'
-        else:
-            args.url = f'http://127.0.0.1:{args.port}/v1/chat/completions'
-
     if (not args.no_test_connection) and (not await test_connection(args)):
         raise TimeoutError('Test connection failed')
 
@@ -192,31 +187,22 @@ async def benchmark(args: Arguments) -> None:
 
     # init queue
     benchmark_data_queue = asyncio.Queue()
-
     # reset event
     data_process_completed_event.clear()
-
+    # test connection
+    await connect_test(args)
+    # start statistic benchmark metric
+    statistic_benchmark_metric_task = asyncio.create_task(statistic_benchmark_metric(benchmark_data_queue, args))
+    # start send request
     semaphore = asyncio.Semaphore(args.parallel)
+    send_request_tasks: List[asyncio.Task] = []
+    async for request in get_requests(args):
+        task = asyncio.create_task(send_request(semaphore, request, benchmark_data_queue, args))
+        send_request_tasks.append(task)
 
-
-
-
-            task = asyncio.create_task(send_request(semaphore, request, benchmark_data_queue, args))
-            tasks.append(task)
-        return tasks
-
-    async def run_tasks():
-        await start_server(args)
-
-        statistic_benchmark_metric_task = asyncio.create_task(
-            statistic_benchmark_metric_worker(benchmark_data_queue, args))
-        send_request_tasks = await create_send_request_tasks()
-
-        await asyncio.gather(*send_request_tasks, return_exceptions=True)
-        await benchmark_data_queue.join()
-        data_process_completed_event.set()
-
-        metrics, result_db_path = await statistic_benchmark_metric_task
-        summary_result(args, metrics, result_db_path)
+    await asyncio.gather(*send_request_tasks, return_exceptions=True)
+    await benchmark_data_queue.join()
+    data_process_completed_event.set()
 
-    await
+    metrics, result_db_path = await statistic_benchmark_metric_task
+    summary_result(args, metrics, result_db_path)
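The rewritten benchmark() flattens the old nested helpers into one coroutine built on a standard asyncio pattern: a single consumer task drains the metrics queue while semaphore-bounded sender tasks fill it, queue.join() plus an event signal shutdown, and the consumer's result is awaited at the end. The following self-contained sketch reproduces that control flow with generic names; it illustrates the pattern, not evalscope's API:

import asyncio


async def send_one(sem: asyncio.Semaphore, item: int, queue: asyncio.Queue) -> None:
    async with sem:                      # bounded concurrency, like args.parallel
        await asyncio.sleep(0.01)        # stand-in for the HTTP request
        await queue.put(item)


async def collect(queue: asyncio.Queue, done: asyncio.Event) -> list:
    results = []
    while not (done.is_set() and queue.empty()):
        try:
            item = await asyncio.wait_for(queue.get(), timeout=0.1)
        except asyncio.TimeoutError:
            continue
        results.append(item)             # stand-in for updating BenchmarkMetrics
        queue.task_done()
    return results


async def main() -> None:
    queue: asyncio.Queue = asyncio.Queue()
    done = asyncio.Event()
    collector = asyncio.create_task(collect(queue, done))      # consumer started first

    sem = asyncio.Semaphore(8)
    senders = [asyncio.create_task(send_one(sem, i, queue)) for i in range(100)]

    await asyncio.gather(*senders, return_exceptions=True)
    await queue.join()                   # every queued item has been processed
    done.set()                           # lets the consumer leave its loop
    print(len(await collector))          # -> 100


asyncio.run(main())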
evalscope/perf/http_client.py
CHANGED
@@ -24,7 +24,6 @@ class AioHttpClient:
         self.connect_timeout = args.connect_timeout
         self.client = aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(connect=self.connect_timeout, sock_read=self.read_timeout),
-            connector=aiohttp.TCPConnector(limit=1),
             trace_configs=[self._create_trace_config()] if args.debug else [])
 
     def _create_trace_config(self):
@@ -144,7 +143,7 @@ async def test_connection(args: Arguments) -> bool:
     async def attempt_connection():
         client = AioHttpClient(args)
         async with client:
-            if
+            if args.apply_chat_template:
                 request = {
                     'messages': [{
                         'role': 'user',
@@ -164,7 +163,7 @@ async def test_connection(args: Arguments) -> bool:
             is_error, state_code, response_data = await asyncio.wait_for(
                 attempt_connection(), timeout=args.connect_timeout)
             if not is_error:
-                logger.info('
+                logger.info('Test connection successful.')
                 return True
             logger.warning(f'Retrying... <{state_code}> {response_data}')
         except Exception as e:
evalscope/perf/plugin/api/custom_api.py
CHANGED
@@ -24,7 +24,7 @@ class CustomPlugin(ApiPluginBase):
         """
         super().__init__(model_path=mode_path)
         if mode_path is not None:
-            from
+            from modelscope import AutoTokenizer
             self.tokenizer = AutoTokenizer.from_pretrained(mode_path)
         else:
             self.tokenizer = None
evalscope/perf/plugin/api/openai_api.py
CHANGED
@@ -24,7 +24,7 @@ class OpenaiPlugin(ApiPluginBase):
         """
         super().__init__(model_path=mode_path)
         if mode_path is not None:
-            from
+            from modelscope import AutoTokenizer
             self.tokenizer = AutoTokenizer.from_pretrained(mode_path)
         else:
             self.tokenizer = None
@@ -70,7 +70,7 @@ class OpenaiPlugin(ApiPluginBase):
     def __compose_query_from_parameter(self, payload: Dict, param: Arguments):
         payload['model'] = param.model
         if param.max_tokens is not None:
-            payload['
+            payload['max_tokens'] = param.max_tokens
         if param.min_tokens is not None:
             payload['min_tokens'] = param.min_tokens
         if param.frequency_penalty is not None:
evalscope/perf/plugin/datasets/custom.py
CHANGED
@@ -18,4 +18,7 @@ class CustomDatasetPlugin(DatasetPluginBase):
             prompt = item.strip()
             if len(prompt) > self.query_parameters.min_prompt_length and len(
                     prompt) < self.query_parameters.max_prompt_length:
-
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/flickr8k.py
CHANGED
@@ -30,6 +30,7 @@ class FlickrDatasetPlugin(DatasetPluginBase):
 
         for item in dataset:
             pil_image = item['jpg']
+            text = item['txt']
             base64_iamge = PIL_to_base64(pil_image)
 
             yield [{
@@ -38,7 +39,7 @@ class FlickrDatasetPlugin(DatasetPluginBase):
                 'content': [
                     {
                         'type': 'text',
-                        'text':
+                        'text': text,
                     },
                     {
                         'type': 'image_url',
evalscope/perf/plugin/datasets/line_by_line.py
CHANGED
@@ -19,4 +19,7 @@ class LineByLineDatasetPlugin(DatasetPluginBase):
             prompt = item.strip()
             if len(prompt) > self.query_parameters.min_prompt_length and len(
                     prompt) < self.query_parameters.max_prompt_length:
-
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/longalpaca.py
CHANGED
@@ -24,4 +24,7 @@ class LongAlpacaDatasetPlugin(DatasetPluginBase):
             prompt = item['instruction'].strip()
             if len(prompt) > self.query_parameters.min_prompt_length and len(
                     prompt) < self.query_parameters.max_prompt_length:
-
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/openqa.py
CHANGED
@@ -29,4 +29,7 @@ class OpenqaDatasetPlugin(DatasetPluginBase):
             prompt = item['question'].strip()
             if (len(prompt) > self.query_parameters.min_prompt_length
                     and len(prompt) < self.query_parameters.max_prompt_length):
-
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/random_dataset.py
CHANGED
@@ -23,8 +23,12 @@ class RandomDatasetPlugin(DatasetPluginBase):
         self.number = self.query_parameters.number or 1
 
     def build_messages(self) -> Iterator[List[Dict]]:
-
-
+        if self.query_parameters.apply_chat_template:
+            min_prompt_length = self.query_parameters.min_prompt_length - self.template_len
+            max_prompt_length = self.query_parameters.max_prompt_length - self.template_len + 1
+        else:
+            min_prompt_length = self.query_parameters.min_prompt_length
+            max_prompt_length = self.query_parameters.max_prompt_length + 1
 
         assert min_prompt_length >= 0, f'min_prompt_length should be greater than or equal to the template length {self.template_len}.'  # noqa: E501
         assert max_prompt_length >= min_prompt_length, 'max_prompt_length should be greater than or equal to min_prompt_length.'  # noqa: E501
@@ -34,10 +38,13 @@ class RandomDatasetPlugin(DatasetPluginBase):
         offsets = np.random.randint(0, self.tokenizer.vocab_size, size=self.number)
 
        for i in range(self.number):
-            prompt_ids = (offsets[i] + i + np.arange(input_lens[i])) % self.tokenizer.vocab_size
-            prompt = self.tokenizer.decode(
-
-
+            prompt_ids = ((offsets[i] + i + np.arange(input_lens[i])) % self.tokenizer.vocab_size).tolist()
+            prompt = self.tokenizer.decode(self.prefix_ids + prompt_ids)
+
+            if self.query_parameters.apply_chat_template:
+                yield [{'role': 'user', 'content': prompt}]
+            else:
+                yield prompt
 
     def get_random_inputs(self, length: int) -> List[int]:
         if length <= 0:
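All of the dataset plugin changes above follow the same rule: with apply_chat_template enabled the plugin yields an OpenAI-style message list, otherwise it yields the bare prompt string for a plain completions endpoint. A minimal sketch of that pattern outside evalscope (the class and file path are hypothetical; the real plugins subclass DatasetPluginBase and read self.query_parameters):

from typing import Dict, Iterator, List, Union


class LinePromptSource:
    """Hypothetical stand-in for a DatasetPluginBase subclass."""

    def __init__(self, path: str, apply_chat_template: bool):
        self.path = path
        self.apply_chat_template = apply_chat_template

    def build_messages(self) -> Iterator[Union[List[Dict], str]]:
        with open(self.path) as f:
            for line in f:
                prompt = line.strip()
                if not prompt:
                    continue
                if self.apply_chat_template:
                    yield [{'role': 'user', 'content': prompt}]   # chat/completions payload
                else:
                    yield prompt                                   # raw completions payload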