evalscope 0.13.2__py3-none-any.whl → 0.15.0__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (214)
  1. evalscope/arguments.py +2 -1
  2. evalscope/backend/rag_eval/__init__.py +1 -1
  3. evalscope/backend/rag_eval/backend_manager.py +21 -5
  4. evalscope/backend/rag_eval/cmteb/arguments.py +10 -0
  5. evalscope/backend/rag_eval/ragas/arguments.py +0 -1
  6. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +7 -2
  7. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +0 -5
  8. evalscope/backend/rag_eval/utils/embedding.py +49 -3
  9. evalscope/backend/rag_eval/utils/llm.py +4 -4
  10. evalscope/backend/vlm_eval_kit/backend_manager.py +4 -2
  11. evalscope/benchmarks/__init__.py +2 -2
  12. evalscope/benchmarks/aigc/__init__.py +0 -0
  13. evalscope/benchmarks/aigc/t2i/__init__.py +0 -0
  14. evalscope/benchmarks/aigc/t2i/base.py +56 -0
  15. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +77 -0
  16. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +58 -0
  17. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +58 -0
  18. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +57 -0
  19. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +37 -0
  20. evalscope/benchmarks/aime/aime24_adapter.py +1 -1
  21. evalscope/benchmarks/aime/aime25_adapter.py +4 -4
  22. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +1 -2
  23. evalscope/benchmarks/arc/arc_adapter.py +2 -2
  24. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +1 -3
  25. evalscope/benchmarks/ceval/ceval_adapter.py +2 -2
  26. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +1 -3
  27. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +1 -1
  28. evalscope/benchmarks/competition_math/competition_math_adapter.py +1 -2
  29. evalscope/benchmarks/data_adapter.py +21 -10
  30. evalscope/benchmarks/data_collection/data_collection_adapter.py +6 -4
  31. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +2 -2
  32. evalscope/benchmarks/general_qa/general_qa_adapter.py +1 -1
  33. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +1 -1
  34. evalscope/benchmarks/live_code_bench/evaluate_utils.py +16 -21
  35. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +5 -4
  36. evalscope/benchmarks/live_code_bench/testing_util.py +369 -550
  37. evalscope/benchmarks/maritime_bench/__init__.py +0 -0
  38. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +79 -0
  39. evalscope/benchmarks/math_500/math_500_adapter.py +1 -1
  40. evalscope/benchmarks/mmlu/mmlu_adapter.py +8 -8
  41. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +1 -1
  42. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +1 -1
  43. evalscope/benchmarks/musr/musr_adapter.py +1 -1
  44. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +1 -2
  45. evalscope/benchmarks/utils.py +7 -16
  46. evalscope/cli/start_app.py +1 -1
  47. evalscope/collections/evaluator.py +20 -6
  48. evalscope/config.py +8 -4
  49. evalscope/constants.py +11 -0
  50. evalscope/evaluator/evaluator.py +2 -2
  51. evalscope/evaluator/reviewer/auto_reviewer.py +1 -1
  52. evalscope/metrics/__init__.py +49 -4
  53. evalscope/metrics/llm_judge.py +1 -1
  54. evalscope/metrics/named_metrics.py +13 -0
  55. evalscope/metrics/t2v_metrics/__init__.py +66 -0
  56. evalscope/metrics/t2v_metrics/clipscore.py +14 -0
  57. evalscope/metrics/t2v_metrics/constants.py +12 -0
  58. evalscope/metrics/t2v_metrics/itmscore.py +14 -0
  59. evalscope/metrics/t2v_metrics/models/__init__.py +0 -0
  60. evalscope/metrics/t2v_metrics/models/clipscore_models/__init__.py +30 -0
  61. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/__init__.py +0 -0
  62. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/base_model.py +6 -0
  63. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +132 -0
  64. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +286 -0
  65. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +114 -0
  66. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +86 -0
  67. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +85 -0
  68. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +62 -0
  69. evalscope/metrics/t2v_metrics/models/itmscore_models/__init__.py +26 -0
  70. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +84 -0
  71. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +97 -0
  72. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +171 -0
  73. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/__init__.py +0 -0
  74. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +80 -0
  75. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +73 -0
  76. evalscope/metrics/t2v_metrics/models/model.py +45 -0
  77. evalscope/metrics/t2v_metrics/models/utils.py +25 -0
  78. evalscope/metrics/t2v_metrics/models/vqascore_models/__init__.py +22 -0
  79. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/__init__.py +0 -0
  80. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/__init__.py +1 -0
  81. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +300 -0
  82. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/builder.py +12 -0
  83. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +82 -0
  84. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_projector/builder.py +50 -0
  85. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +218 -0
  86. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +150 -0
  87. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/__init__.py +26 -0
  88. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +465 -0
  89. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/dist_utils.py +141 -0
  90. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +22 -0
  91. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +188 -0
  92. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +106 -0
  93. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +307 -0
  94. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/utils.py +416 -0
  95. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/__init__.py +8 -0
  96. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +191 -0
  97. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +318 -0
  98. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/default.yaml +10 -0
  99. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_flant5xl.yaml +42 -0
  100. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt2.7b.yaml +42 -0
  101. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_caption_opt6.7b.yaml +42 -0
  102. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_coco.yaml +36 -0
  103. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xl.yaml +43 -0
  104. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_flant5xxl.yaml +43 -0
  105. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna13b.yaml +43 -0
  106. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_instruct_vicuna7b.yaml +43 -0
  107. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain.yaml +36 -0
  108. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl.yaml +42 -0
  109. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_no_prefix.yaml +42 -0
  110. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_iter_80k_total_100k_prefix.yaml +42 -0
  111. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xl_vitL.yaml +43 -0
  112. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_flant5xxl.yaml +42 -0
  113. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt2.7b.yaml +42 -0
  114. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_opt6.7b.yaml +42 -0
  115. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_pretrain_vitL.yaml +37 -0
  116. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna13b.yaml +43 -0
  117. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/blip2/blip2_vicuna7b.yaml +43 -0
  118. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config.json +21 -0
  119. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_config_albef.json +22 -0
  120. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/configs/models/med_large_config.json +21 -0
  121. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +208 -0
  122. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/base_model.py +231 -0
  123. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +1093 -0
  124. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/__init__.py +0 -0
  125. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2.py +211 -0
  126. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_image_text_matching.py +109 -0
  127. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +452 -0
  128. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +364 -0
  129. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +755 -0
  130. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +273 -0
  131. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +880 -0
  132. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +1844 -0
  133. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +81 -0
  134. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +56 -0
  135. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_caption.py +212 -0
  136. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_classification.py +164 -0
  137. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_feature_extractor.py +202 -0
  138. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +185 -0
  139. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +178 -0
  140. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +112 -0
  141. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_pretrain.py +371 -0
  142. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +344 -0
  143. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +858 -0
  144. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +271 -0
  145. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +503 -0
  146. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +1270 -0
  147. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +473 -0
  148. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +31 -0
  149. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/base_processor.py +27 -0
  150. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/blip_processors.py +233 -0
  151. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +392 -0
  152. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +127 -0
  153. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +17 -0
  154. evalscope/metrics/t2v_metrics/score.py +78 -0
  155. evalscope/metrics/t2v_metrics/vqascore.py +14 -0
  156. evalscope/models/__init__.py +50 -14
  157. evalscope/models/adapters/__init__.py +17 -0
  158. evalscope/models/{base_adapter.py → adapters/base_adapter.py} +17 -17
  159. evalscope/models/{chat_adapter.py → adapters/chat_adapter.py} +10 -7
  160. evalscope/models/{choice_adapter.py → adapters/choice_adapter.py} +2 -6
  161. evalscope/models/{custom_adapter.py → adapters/custom_adapter.py} +2 -4
  162. evalscope/models/{server_adapter.py → adapters/server_adapter.py} +1 -3
  163. evalscope/models/adapters/t2i_adapter.py +76 -0
  164. evalscope/models/custom/__init__.py +2 -1
  165. evalscope/models/custom/dummy_model.py +11 -13
  166. evalscope/models/local_model.py +82 -33
  167. evalscope/models/model.py +2 -42
  168. evalscope/models/register.py +26 -0
  169. evalscope/perf/arguments.py +24 -5
  170. evalscope/perf/benchmark.py +28 -42
  171. evalscope/perf/http_client.py +2 -3
  172. evalscope/perf/plugin/api/custom_api.py +1 -1
  173. evalscope/perf/plugin/api/openai_api.py +2 -2
  174. evalscope/perf/plugin/datasets/custom.py +4 -1
  175. evalscope/perf/plugin/datasets/flickr8k.py +2 -1
  176. evalscope/perf/plugin/datasets/line_by_line.py +4 -1
  177. evalscope/perf/plugin/datasets/longalpaca.py +4 -1
  178. evalscope/perf/plugin/datasets/openqa.py +4 -1
  179. evalscope/perf/plugin/datasets/random_dataset.py +13 -6
  180. evalscope/perf/utils/benchmark_util.py +14 -8
  181. evalscope/perf/utils/db_util.py +9 -3
  182. evalscope/perf/utils/log_utils.py +41 -0
  183. evalscope/report/__init__.py +1 -0
  184. evalscope/report/app.py +128 -78
  185. evalscope/report/app_arguments.py +11 -0
  186. evalscope/report/generator.py +1 -1
  187. evalscope/run.py +10 -3
  188. evalscope/summarizer.py +2 -1
  189. evalscope/third_party/thinkbench/eval.py +19 -7
  190. evalscope/utils/chat_service.py +2 -2
  191. evalscope/utils/import_utils.py +66 -0
  192. evalscope/utils/utils.py +48 -29
  193. evalscope/version.py +2 -2
  194. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/METADATA +37 -15
  195. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/RECORD +209 -96
  196. tests/aigc/__init__.py +1 -0
  197. tests/aigc/test_t2i.py +87 -0
  198. tests/cli/test_all.py +4 -4
  199. tests/cli/test_collection.py +2 -1
  200. tests/cli/test_run.py +19 -12
  201. tests/perf/test_perf.py +3 -3
  202. tests/rag/test_clip_benchmark.py +0 -1
  203. tests/rag/test_mteb.py +37 -8
  204. tests/rag/test_ragas.py +29 -26
  205. tests/vlm/test_vlmeval.py +37 -1
  206. evalscope/backend/vlm_eval_kit/custom_dataset.py +0 -46
  207. evalscope/benchmarks/live_code_bench/execute_utils.py +0 -267
  208. evalscope/metrics/code_metric.py +0 -98
  209. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +0 -58485
  210. evalscope/metrics/resources/gpt2-zhcn3-v4.json +0 -1
  211. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/LICENSE +0 -0
  212. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/WHEEL +0 -0
  213. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/entry_points.txt +0 -0
  214. {evalscope-0.13.2.dist-info → evalscope-0.15.0.dist-info}/top_level.txt +0 -0
evalscope/models/local_model.py CHANGED
@@ -1,7 +1,8 @@
-import torch
+import importlib
+from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Optional

-from evalscope.constants import DEFAULT_MODEL_CACHE_DIR, DEFAULT_MODEL_REVISION, EvalType
+from evalscope.constants import DEFAULT_MODEL_CACHE_DIR, DEFAULT_MODEL_REVISION, EvalType, ModelTask
 from evalscope.utils.logger import get_logger
 from evalscope.utils.model_utils import get_device

@@ -11,31 +12,55 @@ if TYPE_CHECKING:
 logger = get_logger()


-class LocalModel:
+class LocalModel(ABC):

     def __init__(self,
                  model_id: str,
-                 model_revision: str = DEFAULT_MODEL_REVISION,
-                 device_map: str = 'auto',
+                 model_revision: str = None,
+                 device_map: str = None,
                  torch_dtype: str = 'auto',
                  cache_dir: str = None,
                  **kwargs):
-        from modelscope import AutoModelForCausalLM, AutoTokenizer

-        model_cache_dir = cache_dir or DEFAULT_MODEL_CACHE_DIR
+        self.model_id = model_id
+        self.model_revision = model_revision or DEFAULT_MODEL_REVISION
+        self.device = device_map or get_device()
+        self.cache_dir = cache_dir or DEFAULT_MODEL_CACHE_DIR
+        self.kwargs = kwargs
+        self.model = None
+        self.tokenizer = None

         if isinstance(torch_dtype, str) and torch_dtype != 'auto':
+            import torch
             torch_dtype = eval(torch_dtype)
+        self.torch_dtype = torch_dtype
+
+        self.model_cfg = {
+            'model_id': self.model_id,
+            'device_map': self.device,
+            'torch_dtype': str(self.torch_dtype),
+        }
+
+    @abstractmethod
+    def load_model(self):
+        pass

-        self.model_id = model_id
-        self.model_revision = model_revision
-        self.device = device_map
+
+class LocalChatModel(LocalModel):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+    def load_model(self):
+        from modelscope import AutoModelForCausalLM, AutoTokenizer
+
+        logger.info(f'Loading model {self.model_id} ...')

         self.tokenizer = AutoTokenizer.from_pretrained(
             self.model_id,
-            revision=model_revision,
+            revision=self.model_revision,
             trust_remote_code=True,
-            cache_dir=model_cache_dir,
+            cache_dir=self.cache_dir,
         )

         # Fix no padding
@@ -44,18 +69,45 @@ class LocalModel:

         self.model = AutoModelForCausalLM.from_pretrained(
             self.model_id,
-            revision=model_revision,
-            device_map=device_map,
+            revision=self.model_revision,
+            device_map=self.device,
             trust_remote_code=True,
-            torch_dtype=torch_dtype,
-            cache_dir=model_cache_dir,
+            torch_dtype=self.torch_dtype,
+            cache_dir=self.cache_dir,
         )

-        self.model_cfg = {
-            'model_id': model_id,
-            'device_map': device_map,
-            'torch_dtype': str(torch_dtype),
-        }
+
+class LocalImageModel(LocalModel):
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+
+        self.pipeline_cls = kwargs.pop('pipeline_cls', None)
+        # default to DiffusionPipeline if not specified
+        if self.pipeline_cls is None:
+            if 'flux' in self.model_id.lower():
+                self.pipeline_cls = 'FluxPipeline'
+            else:
+                self.pipeline_cls = 'DiffusionPipeline'
+
+    def load_model(self):
+        # from modelscope import pipeline_cls
+        module = getattr(importlib.import_module('modelscope'), self.pipeline_cls)
+
+        logger.info(f'Loading model {self.model_id} with {self.pipeline_cls} ...')
+
+        self.model = module.from_pretrained(
+            self.model_id,
+            revision=self.model_revision,
+            torch_dtype=self.torch_dtype,
+            cache_dir=self.cache_dir,
+            **self.kwargs,
+        )
+
+        self.model.to(self.device)
+
+    def __call__(self, *args, **kwargs):
+        return self.model(*args, **kwargs)


 def get_local_model(task_cfg: 'TaskConfig') -> Optional[LocalModel]:
@@ -64,16 +116,13 @@ def get_local_model(task_cfg: 'TaskConfig') -> Optional[LocalModel]:
     """
     if task_cfg.eval_type != EvalType.CHECKPOINT:
         return None
-    else:
-        device_map = task_cfg.model_args.get('device_map', get_device())
-        cache_dir = task_cfg.model_args.get('cache_dir', None)
-        model_precision = task_cfg.model_args.get('precision', 'torch.float16')
-        model_revision = task_cfg.model_args.get('revision', DEFAULT_MODEL_REVISION)
-
-        base_model = LocalModel(
-            model_id=task_cfg.model,
-            model_revision=model_revision,
-            device_map=device_map,
-            torch_dtype=model_precision,
-            cache_dir=cache_dir)
+    elif task_cfg.model_task == ModelTask.TEXT_GENERATION:
+        base_model = LocalChatModel(model_id=task_cfg.model, **task_cfg.model_args)
+        base_model.load_model()
+        return base_model
+    elif task_cfg.model_task == ModelTask.IMAGE_GENERATION:
+        base_model = LocalImageModel(model_id=task_cfg.model, **task_cfg.model_args)
+        base_model.load_model()
         return base_model
+    else:
+        raise ValueError(f'Unsupported model task: {task_cfg.model_task} for model checkpoint.')
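Net effect: LocalModel becomes an abstract base that only stores configuration, and model/tokenizer loading moves into load_model() on the LocalChatModel and LocalImageModel subclasses, with get_local_model() dispatching on task_cfg.model_task. A minimal usage sketch of the new chat subclass (the checkpoint id below is illustrative, not taken from this diff):

    from evalscope.models.local_model import LocalChatModel

    # Hypothetical modelscope checkpoint id; any causal LM would do.
    chat_model = LocalChatModel(model_id='Qwen/Qwen2.5-0.5B-Instruct', torch_dtype='auto')
    chat_model.load_model()      # tokenizer and weights are pulled here, no longer in __init__
    print(chat_model.model_cfg)  # {'model_id': ..., 'device_map': ..., 'torch_dtype': ...}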
evalscope/models/model.py CHANGED
@@ -1,9 +1,8 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 import os
-import random
 import time
 from abc import ABC, abstractmethod
-from typing import Any
+from typing import Any, List

 from evalscope.utils.logger import get_logger

@@ -95,6 +94,7 @@ class ChatBaseModel(BaseModel):
         raise NotImplementedError


+# TODO: Remove this class after refactoring all models
 class OpenAIModel(ChatBaseModel):
     """
     APIs of OpenAI models.
@@ -187,43 +187,3 @@
             time.sleep(3)
         logger.error(f'OpenAI API call failed after {self.MAX_RETRIES} retries')
         return res
-
-
-class DummyChatModel(ChatBaseModel):
-
-    MODEL_ID = 'dummy_chat_model_0801'
-    REVISION = 'v1.0.0'
-
-    def __init__(self, model_cfg: dict, **kwargs):
-        model_cfg['model_id'] = self.MODEL_ID
-        model_cfg['revision'] = self.REVISION
-        super(DummyChatModel, self).__init__(model_cfg=model_cfg)
-
-    def predict(self, inputs: dict, **kwargs) -> dict:
-
-        debug: bool = False
-        if debug:
-            messages = inputs['messages']
-            history = inputs['history']
-
-            logger.info(f'** messages: {messages}')
-            logger.info(f'** history: {history}')
-
-        choice = random.choice(['A', 'B', 'C', 'D'])
-
-        # Build response
-        res = {
-            'choices': [{
-                'index': 0,
-                'message': {
-                    'content': choice,
-                    'role': 'assistant'
-                }
-            }],
-            'created': time.time(),
-            'model': self.MODEL_ID + '-' + self.REVISION,
-            'object': 'chat.completion',
-            'usage': {}
-        }
-
-        return res
evalscope/models/register.py CHANGED
@@ -1,3 +1,6 @@
+from evalscope.constants import OutputType
+from .adapters import *
+
 MODEL_ADAPTERS = {}


@@ -26,3 +29,26 @@ def get_model_adapter(name):
         raise ValueError(
             f"Model adapter '{name}' is not registered. Available model adapters: {list(MODEL_ADAPTERS.keys())}")
     return MODEL_ADAPTERS[name]
+
+
+def register_model_adapter_class(cls, name=None):
+    """
+    Register a model adapter class.
+    :param cls: The model adapter class to register
+    :param name: Optional name for the model adapter. If not provided, the class name will be used.
+    """
+    if name is None:
+        name = cls.__name__
+    if name in MODEL_ADAPTERS:
+        raise ValueError(f"Model adapter class '{name}' is already registered.")
+    MODEL_ADAPTERS[name] = cls
+
+
+# register all model adapters
+register_model_adapter_class(BaseModelAdapter, name='base')
+register_model_adapter_class(ChatGenerationModelAdapter, name=OutputType.GENERATION)
+register_model_adapter_class(ContinuationLogitsModelAdapter, name=OutputType.LOGITS)
+register_model_adapter_class(MultiChoiceModelAdapter, name=OutputType.MULTIPLE_CHOICE)
+register_model_adapter_class(CustomModelAdapter, name='custom')
+register_model_adapter_class(ServerModelAdapter, name='server')
+register_model_adapter_class(T2IModelAdapter, name=OutputType.IMAGE_GENERATION)
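The new register_model_adapter_class() helper also lets downstream code plug in its own adapters. A hedged sketch, assuming the import paths implied by the file list above (MyCustomAdapter is illustrative and not part of the release):

    from evalscope.models.adapters import BaseModelAdapter
    from evalscope.models.register import get_model_adapter, register_model_adapter_class

    class MyCustomAdapter(BaseModelAdapter):
        """Illustrative adapter; not shipped with evalscope."""

        def predict(self, *args, **kwargs):
            raise NotImplementedError

    # Register under a custom name and fetch it back through the registry.
    register_model_adapter_class(MyCustomAdapter, name='my_custom')
    assert get_model_adapter('my_custom') is MyCustomAdapter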
evalscope/perf/arguments.py CHANGED
@@ -35,6 +35,7 @@ class Arguments:
     log_every_n_query: int = 10  # Log every N queries
     debug: bool = False  # Debug mode
     wandb_api_key: Optional[str] = None  # WandB API key for logging
+    swanlab_api_key: Optional[str] = None  # SwanLab API key for logging
     name: Optional[str] = None  # Name for the run

     # Output settings
@@ -46,6 +47,7 @@
     prefix_length: int = 0  # Length of the prefix, only for random dataset
     prompt: Optional[str] = None  # The prompt text
     query_template: Optional[str] = None  # Template for the query
+    apply_chat_template: Optional[bool] = None  # Whether to apply chat template

     # Dataset settings
     dataset: str = 'openqa'  # Dataset type (default: 'line_by_line')
@@ -57,10 +59,10 @@
     max_tokens: Optional[int] = 2048  # Maximum number of tokens in the response
     min_tokens: Optional[int] = None  # Minimum number of tokens in the response
     n_choices: Optional[int] = None  # Number of response choices
-    seed: Optional[int] = 42  # Random seed for reproducibility
+    seed: Optional[int] = 0  # Random seed for reproducibility
     stop: Optional[List[str]] = field(default_factory=list)  # Stop sequences for the response
     stop_token_ids: Optional[List[str]] = field(default_factory=list)  # Stop token IDs for the response
-    stream: Optional[bool] = False  # Whether to stream the response
+    stream: Optional[bool] = True  # Whether to stream the response
     temperature: float = 0.0  # Temperature setting for the response
     top_p: Optional[float] = None  # Top-p (nucleus) sampling setting for the response
     top_k: Optional[int] = None  # Top-k sampling setting for the response
@@ -76,12 +78,26 @@
         return Arguments(**args_dict)

     def __post_init__(self):
+        # Set the default headers
         self.headers = self.headers or {}  # Default to empty dictionary
         if self.api_key:
             # Assuming the API key is used as a Bearer token
             self.headers['Authorization'] = f'Bearer {self.api_key}'
+
+        # Set the model ID based on the model name
         self.model_id = os.path.basename(self.model)

+        # Set the URL based on the dataset type
+        if self.api.startswith('local'):
+            if self.dataset.startswith('speed_benchmark'):
+                self.url = f'http://127.0.0.1:{self.port}/v1/completions'
+            else:
+                self.url = f'http://127.0.0.1:{self.port}/v1/chat/completions'
+
+        # Set the apply_chat_template flag based on the URL
+        if self.apply_chat_template is None:
+            self.apply_chat_template = self.url.strip('/').endswith('chat/completions')
+
     def __str__(self):
         return json.dumps(self.to_dict(), indent=4, default=str, ensure_ascii=False)

@@ -135,7 +151,8 @@ def add_argument(parser: argparse.ArgumentParser):
     parser.add_argument('--log-every-n-query', type=int, default=10, help='Logging every n query')
     parser.add_argument('--debug', action='store_true', default=False, help='Debug request send')
     parser.add_argument('--wandb-api-key', type=str, default=None, help='The wandb API key')
-    parser.add_argument('--name', type=str, help='The wandb db result name and result db name')
+    parser.add_argument('--swanlab-api-key', type=str, default=None, help='The swanlab API key')
+    parser.add_argument('--name', type=str, help='The wandb/swanlab db result name and result db name')

     # Prompt settings
     parser.add_argument('--max-prompt-length', type=int, default=sys.maxsize, help='Maximum input prompt length')
@@ -143,6 +160,8 @@
     parser.add_argument('--prefix-length', type=int, default=0, help='The prefix length')
     parser.add_argument('--prompt', type=str, required=False, default=None, help='Specified the request prompt')
     parser.add_argument('--query-template', type=str, default=None, help='Specify the query template')
+    parser.add_argument(
+        '--apply-chat-template', type=argparse.BooleanOptionalAction, default=None, help='Apply chat template to the prompt')  # noqa: E501

     # Output settings
     parser.add_argument('--outputs-dir', help='Outputs dir.', default='outputs')
@@ -159,10 +178,10 @@
     parser.add_argument(
         '--min-tokens', type=int, help='The minimum number of tokens that can be generated', default=None)
     parser.add_argument('--n-choices', type=int, help='How many completion choices to generate', default=None)
-    parser.add_argument('--seed', type=int, help='The random seed', default=42)
+    parser.add_argument('--seed', type=int, help='The random seed', default=0)
     parser.add_argument('--stop', nargs='*', help='The stop tokens', default=None)
    parser.add_argument('--stop-token-ids', nargs='*', help='Set the stop token IDs', default=None)
-    parser.add_argument('--stream', action='store_true', help='Stream output with SSE', default=False)
+    parser.add_argument('--stream', action=argparse.BooleanOptionalAction, help='Stream output with SSE', default=True)
     parser.add_argument('--temperature', type=float, help='The sample temperature', default=0.0)
     parser.add_argument('--top-p', type=float, help='Sampling top p', default=None)
     parser.add_argument('--top-k', type=int, help='Sampling top k', default=None)
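Net effect of the Arguments changes: streaming is on by default (--stream now uses argparse.BooleanOptionalAction, so --no-stream turns it off; this requires Python 3.9+), the default seed drops from 42 to 0, and apply_chat_template is inferred in __post_init__ when left unset. A small sketch of that inference, assuming the remaining dataclass fields keep their defaults (the port value is illustrative):

    from evalscope.perf.arguments import Arguments

    args = Arguments(model='my-model', api='local', port=8877)
    # For a local API the URL is rewritten to http://127.0.0.1:8877/v1/chat/completions
    # (the default dataset is 'openqa', not a speed benchmark), so
    # args.apply_chat_template is inferred as True.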
evalscope/perf/benchmark.py CHANGED
@@ -18,6 +18,7 @@ from evalscope.perf.utils.benchmark_util import BenchmarkData, BenchmarkMetrics
 from evalscope.perf.utils.db_util import create_result_table, get_result_db_path, insert_benchmark_data, summary_result
 from evalscope.perf.utils.handler import add_signal_handlers, exception_handler
 from evalscope.perf.utils.local_server import start_app
+from evalscope.perf.utils.log_utils import init_swanlab, init_wandb
 from evalscope.utils.logger import get_logger

 logger = get_logger()
@@ -56,7 +57,7 @@ async def get_requests(args: Arguments) -> AsyncGenerator[dict, None]:

     if args.prompt:
         prompt = load_prompt(args.prompt)
-        messages = [{'role': 'user', 'content': prompt}]
+        messages = [{'role': 'user', 'content': prompt}] if args.apply_chat_template else prompt
         generator = generate_requests_from_prompt(messages)
     elif args.dataset:
         generator = generate_requests_from_dataset()
@@ -81,6 +82,7 @@ async def send_request(
     client = AioHttpClient(args)
     async with client:
         benchmark_data = BenchmarkData(request=request)
+        benchmark_data.start_time = time.perf_counter()
         collected_messages = []
         try:
             async for is_error, state_code, response_data in client.post(request):
@@ -106,24 +108,18 @@


 @exception_handler
-async def statistic_benchmark_metric_worker(benchmark_data_queue: asyncio.Queue, args: Arguments):
+async def statistic_benchmark_metric(benchmark_data_queue: asyncio.Queue, args: Arguments):
     metrics = BenchmarkMetrics(concurrency=args.parallel)

     api_plugin_class = ApiRegistry(args.api)
     api_plugin = api_plugin_class(args.tokenizer_path)

     result_db_path = get_result_db_path(args)
-    # Initialize wandb
-    if args.wandb_api_key:
-        import datetime
-        import wandb
-        os.environ['WANDB_SILENT'] = 'true'
-        os.environ['WANDB_DIR'] = args.outputs_dir

-        wandb.login(key=args.wandb_api_key)
-        current_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
-        name = args.name if args.name else f'{args.model_id}_{current_time}'
-        wandb.init(project='perf_benchmark', name=name, config=args.to_dict())
+    if args.wandb_api_key:
+        init_wandb(args)
+    if args.swanlab_api_key:
+        init_swanlab(args)

     collected_benchmark_data = []

@@ -146,9 +142,13 @@ async def statistic_benchmark_metric_worker(benchmark_data_queue: asyncio.Queue,
         # Create a message with the updated metrics
         message = metrics.create_message()

-        # Log the message to wandb if the api key is provided
+        # Log the message to wandb\swanlab if the api key is provided
         if args.wandb_api_key:
+            import wandb
             wandb.log(message)
+        if args.swanlab_api_key:
+            import swanlab
+            swanlab.log(message)

         # Log the message to the logger every n queries
         if int(metrics.n_total_queries) % args.log_every_n_query == 0:
@@ -169,17 +169,12 @@


 @exception_handler
-async def start_server(args: Arguments) -> bool:
+async def connect_test(args: Arguments) -> bool:
     if args.api.startswith('local'):
         # start local server
         server = threading.Thread(target=start_app, args=(copy.deepcopy(args), ), daemon=True)
         server.start()

-        if args.dataset.startswith('speed_benchmark'):
-            args.url = f'http://127.0.0.1:{args.port}/v1/completions'
-        else:
-            args.url = f'http://127.0.0.1:{args.port}/v1/chat/completions'
-
     if (not args.no_test_connection) and (not await test_connection(args)):
         raise TimeoutError('Test connection failed')

@@ -192,31 +187,22 @@ async def benchmark(args: Arguments) -> None:

     # init queue
     benchmark_data_queue = asyncio.Queue()
-
     # reset event
     data_process_completed_event.clear()
-
+    # test connection
+    await connect_test(args)
+    # start statistic benchmark metric
+    statistic_benchmark_metric_task = asyncio.create_task(statistic_benchmark_metric(benchmark_data_queue, args))
+    # start send request
     semaphore = asyncio.Semaphore(args.parallel)
+    send_request_tasks: List[asyncio.Task] = []
+    async for request in get_requests(args):
+        task = asyncio.create_task(send_request(semaphore, request, benchmark_data_queue, args))
+        send_request_tasks.append(task)

-    async def create_send_request_tasks():
-        tasks: List[asyncio.Task] = []
-        async for request in get_requests(args):
-            task = asyncio.create_task(send_request(semaphore, request, benchmark_data_queue, args))
-            tasks.append(task)
-        return tasks
-
-    async def run_tasks():
-        await start_server(args)
-
-        statistic_benchmark_metric_task = asyncio.create_task(
-            statistic_benchmark_metric_worker(benchmark_data_queue, args))
-        send_request_tasks = await create_send_request_tasks()
-
-        await asyncio.gather(*send_request_tasks, return_exceptions=True)
-        await benchmark_data_queue.join()
-        data_process_completed_event.set()
-
-        metrics, result_db_path = await statistic_benchmark_metric_task
-        summary_result(args, metrics, result_db_path)
+    await asyncio.gather(*send_request_tasks, return_exceptions=True)
+    await benchmark_data_queue.join()
+    data_process_completed_event.set()

-    await run_tasks()
+    metrics, result_db_path = await statistic_benchmark_metric_task
+    summary_result(args, metrics, result_db_path)
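The inline wandb setup removed above now sits behind init_wandb()/init_swanlab() in the new evalscope/perf/utils/log_utils.py, whose contents are not shown in this excerpt. A sketch of an equivalent wandb helper, reconstructed from the removed lines (the shipped implementation may differ):

    import datetime
    import os

    def init_wandb(args):
        # Mirrors the inline code removed from statistic_benchmark_metric above.
        import wandb
        os.environ['WANDB_SILENT'] = 'true'
        os.environ['WANDB_DIR'] = args.outputs_dir
        wandb.login(key=args.wandb_api_key)
        current_time = datetime.datetime.now().strftime('%Y%m%d_%H%M%S')
        name = args.name if args.name else f'{args.model_id}_{current_time}'
        wandb.init(project='perf_benchmark', name=name, config=args.to_dict())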
evalscope/perf/http_client.py CHANGED
@@ -24,7 +24,6 @@ class AioHttpClient:
         self.connect_timeout = args.connect_timeout
         self.client = aiohttp.ClientSession(
             timeout=aiohttp.ClientTimeout(connect=self.connect_timeout, sock_read=self.read_timeout),
-            connector=aiohttp.TCPConnector(limit=1),
             trace_configs=[self._create_trace_config()] if args.debug else [])

     def _create_trace_config(self):
@@ -144,7 +143,7 @@ async def test_connection(args: Arguments) -> bool:
     async def attempt_connection():
         client = AioHttpClient(args)
         async with client:
-            if 'chat/completions' in args.url:
+            if args.apply_chat_template:
                 request = {
                     'messages': [{
                         'role': 'user',
@@ -164,7 +163,7 @@
             is_error, state_code, response_data = await asyncio.wait_for(
                 attempt_connection(), timeout=args.connect_timeout)
             if not is_error:
-                logger.info('Connection successful.')
+                logger.info('Test connection successful.')
                 return True
             logger.warning(f'Retrying... <{state_code}> {response_data}')
         except Exception as e:
evalscope/perf/plugin/api/custom_api.py CHANGED
@@ -24,7 +24,7 @@ class CustomPlugin(ApiPluginBase):
         """
         super().__init__(model_path=mode_path)
         if mode_path is not None:
-            from transformers import AutoTokenizer
+            from modelscope import AutoTokenizer
             self.tokenizer = AutoTokenizer.from_pretrained(mode_path)
         else:
             self.tokenizer = None
evalscope/perf/plugin/api/openai_api.py CHANGED
@@ -24,7 +24,7 @@ class OpenaiPlugin(ApiPluginBase):
         """
         super().__init__(model_path=mode_path)
         if mode_path is not None:
-            from transformers import AutoTokenizer
+            from modelscope import AutoTokenizer
             self.tokenizer = AutoTokenizer.from_pretrained(mode_path)
         else:
             self.tokenizer = None
@@ -70,7 +70,7 @@
     def __compose_query_from_parameter(self, payload: Dict, param: Arguments):
         payload['model'] = param.model
         if param.max_tokens is not None:
-            payload['max_completion_tokens'] = param.max_tokens
+            payload['max_tokens'] = param.max_tokens
         if param.min_tokens is not None:
             payload['min_tokens'] = param.min_tokens
         if param.frequency_penalty is not None:
evalscope/perf/plugin/datasets/custom.py CHANGED
@@ -18,4 +18,7 @@ class CustomDatasetPlugin(DatasetPluginBase):
             prompt = item.strip()
             if len(prompt) > self.query_parameters.min_prompt_length and len(
                     prompt) < self.query_parameters.max_prompt_length:
-                yield [{'role': 'user', 'content': prompt}]
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/flickr8k.py CHANGED
@@ -30,6 +30,7 @@ class FlickrDatasetPlugin(DatasetPluginBase):

         for item in dataset:
             pil_image = item['jpg']
+            text = item['txt']
             base64_iamge = PIL_to_base64(pil_image)

             yield [{
@@ -38,7 +39,7 @@
                 'content': [
                     {
                         'type': 'text',
-                        'text': 'Describe the image'
+                        'text': text,
                     },
                     {
                         'type': 'image_url',
evalscope/perf/plugin/datasets/line_by_line.py CHANGED
@@ -19,4 +19,7 @@ class LineByLineDatasetPlugin(DatasetPluginBase):
             prompt = item.strip()
             if len(prompt) > self.query_parameters.min_prompt_length and len(
                     prompt) < self.query_parameters.max_prompt_length:
-                yield [{'role': 'user', 'content': prompt}]
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/longalpaca.py CHANGED
@@ -24,4 +24,7 @@ class LongAlpacaDatasetPlugin(DatasetPluginBase):
             prompt = item['instruction'].strip()
             if len(prompt) > self.query_parameters.min_prompt_length and len(
                     prompt) < self.query_parameters.max_prompt_length:
-                yield [{'role': 'user', 'content': prompt}]
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/openqa.py CHANGED
@@ -29,4 +29,7 @@ class OpenqaDatasetPlugin(DatasetPluginBase):
             prompt = item['question'].strip()
             if (len(prompt) > self.query_parameters.min_prompt_length
                     and len(prompt) < self.query_parameters.max_prompt_length):
-                yield [{'role': 'user', 'content': prompt}]
+                if self.query_parameters.apply_chat_template:
+                    yield [{'role': 'user', 'content': prompt}]
+                else:
+                    yield prompt
evalscope/perf/plugin/datasets/random_dataset.py CHANGED
@@ -23,8 +23,12 @@ class RandomDatasetPlugin(DatasetPluginBase):
         self.number = self.query_parameters.number or 1

     def build_messages(self) -> Iterator[List[Dict]]:
-        min_prompt_length = self.query_parameters.min_prompt_length - self.template_len
-        max_prompt_length = self.query_parameters.max_prompt_length - self.template_len + 1
+        if self.query_parameters.apply_chat_template:
+            min_prompt_length = self.query_parameters.min_prompt_length - self.template_len
+            max_prompt_length = self.query_parameters.max_prompt_length - self.template_len + 1
+        else:
+            min_prompt_length = self.query_parameters.min_prompt_length
+            max_prompt_length = self.query_parameters.max_prompt_length + 1

         assert min_prompt_length >= 0, f'min_prompt_length should be greater than or equal to the template length {self.template_len}.'  # noqa: E501
         assert max_prompt_length >= min_prompt_length, 'max_prompt_length should be greater than or equal to min_prompt_length.'  # noqa: E501
@@ -34,10 +38,13 @@
         offsets = np.random.randint(0, self.tokenizer.vocab_size, size=self.number)

         for i in range(self.number):
-            prompt_ids = (offsets[i] + i + np.arange(input_lens[i])) % self.tokenizer.vocab_size
-            prompt = self.tokenizer.decode(
-                self.prefix_ids + prompt_ids.tolist(), skip_special_tokens=False, clean_up_tokenization_spaces=False)
-            yield [{'role': 'user', 'content': prompt}]
+            prompt_ids = ((offsets[i] + i + np.arange(input_lens[i])) % self.tokenizer.vocab_size).tolist()
+            prompt = self.tokenizer.decode(self.prefix_ids + prompt_ids)
+
+            if self.query_parameters.apply_chat_template:
+                yield [{'role': 'user', 'content': prompt}]
+            else:
+                yield prompt

     def get_random_inputs(self, length: int) -> List[int]:
         if length <= 0:
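Across all of the dataset plugins above the change is the same: with apply_chat_template enabled a plugin yields an OpenAI-style message list for /v1/chat/completions, otherwise it yields the bare prompt string for /v1/completions. Roughly, the resulting request bodies differ like this (the model name is a placeholder; field names follow the OpenAI schema):

    prompt = 'What is the capital of France?'

    # apply_chat_template=True -> chat endpoint payload
    chat_payload = {'model': 'my-model', 'messages': [{'role': 'user', 'content': prompt}]}

    # apply_chat_template=False -> plain completions payload
    completion_payload = {'model': 'my-model', 'prompt': prompt}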