evalscope 0.17.1__py3-none-any.whl → 1.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of evalscope might be problematic.

Files changed (302)
  1. evalscope/__init__.py +4 -1
  2. evalscope/api/benchmark/__init__.py +3 -0
  3. evalscope/api/benchmark/adapters/__init__.py +5 -0
  4. evalscope/api/benchmark/adapters/default_data_adapter.py +684 -0
  5. evalscope/api/benchmark/adapters/image_edit_adapter.py +82 -0
  6. evalscope/api/benchmark/adapters/multi_choice_adapter.py +83 -0
  7. evalscope/api/benchmark/adapters/text2image_adapter.py +156 -0
  8. evalscope/api/benchmark/adapters/vision_language_adapter.py +6 -0
  9. evalscope/api/benchmark/benchmark.py +356 -0
  10. evalscope/api/benchmark/meta.py +121 -0
  11. evalscope/api/dataset/__init__.py +2 -0
  12. evalscope/api/dataset/dataset.py +349 -0
  13. evalscope/api/dataset/loader.py +262 -0
  14. evalscope/api/dataset/utils.py +143 -0
  15. evalscope/api/evaluator/__init__.py +3 -0
  16. evalscope/api/evaluator/cache.py +378 -0
  17. evalscope/api/evaluator/evaluator.py +56 -0
  18. evalscope/api/evaluator/state.py +275 -0
  19. evalscope/api/filter/__init__.py +1 -0
  20. evalscope/api/filter/filter.py +72 -0
  21. evalscope/api/messages/__init__.py +12 -0
  22. evalscope/api/messages/chat_message.py +243 -0
  23. evalscope/api/messages/content.py +102 -0
  24. evalscope/api/messages/utils.py +35 -0
  25. evalscope/api/metric/__init__.py +2 -0
  26. evalscope/api/metric/metric.py +55 -0
  27. evalscope/api/metric/scorer.py +113 -0
  28. evalscope/api/mixin/__init__.py +1 -0
  29. evalscope/api/mixin/llm_judge_mixin.py +168 -0
  30. evalscope/api/model/__init__.py +12 -0
  31. evalscope/api/model/generate_config.py +155 -0
  32. evalscope/api/model/model.py +386 -0
  33. evalscope/api/model/model_output.py +285 -0
  34. evalscope/api/registry.py +182 -0
  35. evalscope/api/tool/__init__.py +3 -0
  36. evalscope/api/tool/tool_call.py +101 -0
  37. evalscope/api/tool/tool_info.py +173 -0
  38. evalscope/api/tool/utils.py +64 -0
  39. evalscope/app/app.py +3 -0
  40. evalscope/app/ui/app_ui.py +2 -1
  41. evalscope/app/ui/multi_model.py +50 -25
  42. evalscope/app/ui/single_model.py +26 -14
  43. evalscope/app/utils/data_utils.py +43 -27
  44. evalscope/app/utils/env_utils.py +12 -0
  45. evalscope/app/utils/text_utils.py +14 -14
  46. evalscope/app/utils/visualization.py +9 -4
  47. evalscope/arguments.py +7 -10
  48. evalscope/backend/opencompass/api_meta_template.py +2 -1
  49. evalscope/backend/opencompass/backend_manager.py +6 -5
  50. evalscope/backend/rag_eval/clip_benchmark/dataset_builder.py +10 -10
  51. evalscope/backend/rag_eval/clip_benchmark/task_template.py +8 -4
  52. evalscope/backend/rag_eval/ragas/task_template.py +2 -1
  53. evalscope/backend/rag_eval/ragas/tasks/build_distribution.py +2 -1
  54. evalscope/backend/rag_eval/ragas/tasks/build_transform.py +7 -4
  55. evalscope/backend/rag_eval/ragas/tasks/testset_generation.py +2 -1
  56. evalscope/backend/rag_eval/ragas/tasks/translate_prompt.py +2 -1
  57. evalscope/backend/rag_eval/utils/embedding.py +10 -1
  58. evalscope/backend/rag_eval/utils/llm.py +13 -12
  59. evalscope/benchmarks/__init__.py +0 -2
  60. evalscope/benchmarks/aime/aime24_adapter.py +38 -40
  61. evalscope/benchmarks/aime/aime25_adapter.py +34 -40
  62. evalscope/benchmarks/alpaca_eval/alpaca_eval_adapter.py +86 -60
  63. evalscope/benchmarks/arc/arc_adapter.py +34 -147
  64. evalscope/benchmarks/arena_hard/arena_hard_adapter.py +96 -70
  65. evalscope/benchmarks/arena_hard/utils.py +37 -1
  66. evalscope/benchmarks/bbh/bbh_adapter.py +72 -144
  67. evalscope/benchmarks/bfcl/bfcl_adapter.py +188 -171
  68. evalscope/benchmarks/bfcl/generation.py +222 -0
  69. evalscope/benchmarks/ceval/ceval_adapter.py +93 -162
  70. evalscope/benchmarks/chinese_simple_qa/csimple_qa_adapter.py +85 -82
  71. evalscope/benchmarks/cmmlu/cmmlu_adapter.py +34 -125
  72. evalscope/benchmarks/competition_math/competition_math_adapter.py +56 -108
  73. evalscope/benchmarks/data_collection/data_collection_adapter.py +187 -45
  74. evalscope/benchmarks/docmath/docmath_adapter.py +109 -51
  75. evalscope/benchmarks/docmath/utils.py +4 -5
  76. evalscope/benchmarks/drop/drop_adapter.py +88 -40
  77. evalscope/benchmarks/frames/frames_adapter.py +136 -52
  78. evalscope/benchmarks/general_arena/general_arena_adapter.py +140 -98
  79. evalscope/benchmarks/general_arena/utils.py +23 -27
  80. evalscope/benchmarks/general_mcq/general_mcq_adapter.py +40 -101
  81. evalscope/benchmarks/general_qa/general_qa_adapter.py +73 -134
  82. evalscope/benchmarks/gpqa/gpqa_adapter.py +61 -100
  83. evalscope/benchmarks/gpqa/{chain_of_thought.txt → prompt.py} +12 -5
  84. evalscope/benchmarks/gsm8k/gsm8k_adapter.py +62 -142
  85. evalscope/benchmarks/hellaswag/hellaswag_adapter.py +35 -124
  86. evalscope/benchmarks/hle/hle_adapter.py +127 -93
  87. evalscope/benchmarks/humaneval/humaneval_adapter.py +86 -55
  88. evalscope/benchmarks/ifeval/ifeval_adapter.py +69 -40
  89. evalscope/benchmarks/ifeval/instructions.py +109 -64
  90. evalscope/benchmarks/ifeval/instructions_registry.py +1 -1
  91. evalscope/benchmarks/ifeval/instructions_util.py +2 -3
  92. evalscope/benchmarks/ifeval/utils.py +6 -7
  93. evalscope/benchmarks/image_edit/gedit/__init__.py +0 -0
  94. evalscope/benchmarks/image_edit/gedit/gedit_adapter.py +138 -0
  95. evalscope/benchmarks/image_edit/gedit/utils.py +372 -0
  96. evalscope/benchmarks/image_edit/gedit/vie_prompts.py +406 -0
  97. evalscope/benchmarks/iquiz/iquiz_adapter.py +30 -65
  98. evalscope/benchmarks/live_code_bench/evaluate_utils.py +2 -2
  99. evalscope/benchmarks/live_code_bench/live_code_bench_adapter.py +121 -71
  100. evalscope/benchmarks/live_code_bench/load_utils.py +13 -21
  101. evalscope/benchmarks/live_code_bench/testing_util.py +6 -2
  102. evalscope/benchmarks/maritime_bench/maritime_bench_adapter.py +49 -75
  103. evalscope/benchmarks/math_500/math_500_adapter.py +41 -48
  104. evalscope/benchmarks/math_vista/__init__.py +0 -0
  105. evalscope/benchmarks/math_vista/math_vista_adapter.py +129 -0
  106. evalscope/benchmarks/mmlu/mmlu_adapter.py +32 -205
  107. evalscope/benchmarks/mmlu_pro/mmlu_pro_adapter.py +80 -99
  108. evalscope/benchmarks/mmlu_redux/mmlu_redux_adapter.py +64 -110
  109. evalscope/benchmarks/mmmu/__init__.py +0 -0
  110. evalscope/benchmarks/mmmu/mmmu_adapter.py +159 -0
  111. evalscope/benchmarks/mmmu_pro/__init__.py +0 -0
  112. evalscope/benchmarks/mmmu_pro/mmmu_pro_adapter.py +129 -0
  113. evalscope/benchmarks/musr/musr_adapter.py +33 -64
  114. evalscope/benchmarks/needle_haystack/needle_haystack_adapter.py +196 -152
  115. evalscope/benchmarks/process_bench/process_bench_adapter.py +144 -76
  116. evalscope/benchmarks/race/race_adapter.py +33 -119
  117. evalscope/benchmarks/simple_qa/simple_qa_adapter.py +72 -70
  118. evalscope/benchmarks/super_gpqa/{five_shot_prompt.txt → prompt.py} +14 -16
  119. evalscope/benchmarks/super_gpqa/super_gpqa_adapter.py +73 -117
  120. evalscope/benchmarks/super_gpqa/utils.py +2 -1
  121. evalscope/benchmarks/tau_bench/generation.py +147 -0
  122. evalscope/benchmarks/tau_bench/tau_bench_adapter.py +114 -60
  123. evalscope/benchmarks/text2image/__init__.py +0 -0
  124. evalscope/benchmarks/text2image/evalmuse_adapter.py +78 -0
  125. evalscope/benchmarks/text2image/genai_bench_adapter.py +53 -0
  126. evalscope/benchmarks/text2image/general_t2i_adapter.py +42 -0
  127. evalscope/benchmarks/text2image/hpdv2_adapter.py +52 -0
  128. evalscope/benchmarks/text2image/tifa_adapter.py +27 -0
  129. evalscope/benchmarks/tool_bench/tool_bench_adapter.py +91 -70
  130. evalscope/benchmarks/trivia_qa/trivia_qa_adapter.py +56 -124
  131. evalscope/benchmarks/truthful_qa/truthful_qa_adapter.py +70 -266
  132. evalscope/benchmarks/winogrande/winogrande_adapter.py +28 -54
  133. evalscope/cli/cli.py +2 -0
  134. evalscope/cli/start_app.py +7 -1
  135. evalscope/cli/start_perf.py +7 -1
  136. evalscope/cli/start_server.py +6 -3
  137. evalscope/collections/__init__.py +2 -10
  138. evalscope/collections/sampler.py +10 -10
  139. evalscope/collections/schema.py +13 -11
  140. evalscope/config.py +157 -57
  141. evalscope/constants.py +37 -61
  142. evalscope/evaluator/__init__.py +1 -1
  143. evalscope/evaluator/evaluator.py +275 -419
  144. evalscope/filters/__init__.py +2 -0
  145. evalscope/filters/extraction.py +126 -0
  146. evalscope/filters/selection.py +57 -0
  147. evalscope/metrics/__init__.py +13 -13
  148. evalscope/metrics/llm_judge.py +47 -33
  149. evalscope/metrics/math_parser.py +27 -22
  150. evalscope/metrics/metric.py +307 -0
  151. evalscope/metrics/metrics.py +22 -18
  152. evalscope/metrics/t2v_metrics/__init__.py +0 -52
  153. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/clip_model.py +4 -2
  154. evalscope/metrics/t2v_metrics/models/clipscore_models/build_mps_model/cross_modeling.py +9 -13
  155. evalscope/metrics/t2v_metrics/models/clipscore_models/clip_model.py +2 -1
  156. evalscope/metrics/t2v_metrics/models/clipscore_models/hpsv2_model.py +3 -2
  157. evalscope/metrics/t2v_metrics/models/clipscore_models/mps_model.py +2 -1
  158. evalscope/metrics/t2v_metrics/models/clipscore_models/pickscore_model.py +2 -2
  159. evalscope/metrics/t2v_metrics/models/itmscore_models/blip2_itm_model.py +2 -1
  160. evalscope/metrics/t2v_metrics/models/itmscore_models/fga_blip2_model.py +4 -2
  161. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/ImageReward.py +10 -5
  162. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward/blip_pretrain.py +4 -2
  163. evalscope/metrics/t2v_metrics/models/itmscore_models/image_reward_model.py +2 -1
  164. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/language_model/clip_t5.py +15 -9
  165. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5/model/multimodal_encoder/clip_encoder.py +4 -2
  166. evalscope/metrics/t2v_metrics/models/vqascore_models/clip_t5_model.py +15 -10
  167. evalscope/metrics/t2v_metrics/models/vqascore_models/gpt4v_model.py +9 -6
  168. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/config.py +2 -2
  169. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/gradcam.py +4 -2
  170. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/logger.py +4 -2
  171. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/optims.py +3 -9
  172. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/registry.py +16 -10
  173. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa.py +3 -2
  174. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/common/vqa_tools/vqa_eval.py +4 -2
  175. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/__init__.py +8 -4
  176. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/Qformer.py +47 -25
  177. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_qformer.py +12 -7
  178. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5.py +23 -17
  179. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/blip2_t5_instruct.py +33 -23
  180. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/fga_blip2.py +2 -1
  181. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_llama.py +46 -30
  182. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip2_models/modeling_t5.py +69 -37
  183. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/__init__.py +7 -5
  184. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip.py +6 -4
  185. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_image_text_matching.py +7 -5
  186. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_nlvr.py +3 -2
  187. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_outputs.py +5 -2
  188. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/blip_vqa.py +17 -13
  189. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/blip_models/nlvr_encoder.py +35 -19
  190. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/clip_vit.py +14 -12
  191. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/eva_vit.py +63 -52
  192. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/med.py +63 -38
  193. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/models/vit.py +6 -3
  194. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/__init__.py +6 -2
  195. evalscope/metrics/t2v_metrics/models/vqascore_models/lavis/processors/randaugment.py +3 -2
  196. evalscope/metrics/t2v_metrics/models/vqascore_models/mm_utils.py +15 -13
  197. evalscope/metrics/t2v_metrics/models/vqascore_models/vqa_model.py +3 -2
  198. evalscope/models/__init__.py +6 -29
  199. evalscope/models/image_edit_model.py +125 -0
  200. evalscope/models/mockllm.py +65 -0
  201. evalscope/models/model_apis.py +67 -0
  202. evalscope/models/modelscope.py +455 -0
  203. evalscope/models/openai_compatible.py +126 -0
  204. evalscope/models/text2image_model.py +124 -0
  205. evalscope/models/utils/openai.py +701 -0
  206. evalscope/perf/benchmark.py +4 -1
  207. evalscope/perf/http_client.py +4 -2
  208. evalscope/perf/plugin/api/custom_api.py +5 -4
  209. evalscope/perf/plugin/api/openai_api.py +11 -9
  210. evalscope/perf/plugin/datasets/custom.py +2 -1
  211. evalscope/perf/plugin/datasets/flickr8k.py +1 -1
  212. evalscope/perf/plugin/datasets/kontext_bench.py +1 -1
  213. evalscope/perf/plugin/datasets/line_by_line.py +2 -1
  214. evalscope/perf/plugin/datasets/longalpaca.py +2 -1
  215. evalscope/perf/plugin/datasets/openqa.py +4 -2
  216. evalscope/perf/utils/benchmark_util.py +15 -10
  217. evalscope/perf/utils/db_util.py +9 -6
  218. evalscope/perf/utils/local_server.py +11 -3
  219. evalscope/perf/utils/rich_display.py +16 -10
  220. evalscope/report/__init__.py +2 -3
  221. evalscope/report/combinator.py +18 -12
  222. evalscope/report/generator.py +51 -35
  223. evalscope/report/{utils.py → report.py} +8 -6
  224. evalscope/run.py +33 -47
  225. evalscope/summarizer.py +1 -1
  226. evalscope/third_party/toolbench_static/llm/swift_infer.py +0 -4
  227. evalscope/utils/__init__.py +21 -2
  228. evalscope/utils/chat_service.py +3 -2
  229. evalscope/utils/deprecation_utils.py +12 -1
  230. evalscope/utils/function_utils.py +29 -0
  231. evalscope/utils/import_utils.py +23 -1
  232. evalscope/utils/io_utils.py +142 -6
  233. evalscope/utils/json_schema.py +208 -0
  234. evalscope/utils/logger.py +51 -12
  235. evalscope/utils/model_utils.py +11 -7
  236. evalscope/utils/multi_choices.py +288 -0
  237. evalscope/utils/url_utils.py +65 -0
  238. evalscope/version.py +2 -2
  239. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/METADATA +108 -62
  240. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/RECORD +258 -226
  241. tests/benchmark/test_eval.py +385 -0
  242. tests/benchmark/test_image_edit.py +65 -0
  243. tests/{aigc → benchmark}/test_t2i.py +22 -4
  244. tests/benchmark/test_vlm.py +80 -0
  245. tests/cli/test_all.py +85 -47
  246. tests/cli/test_collection.py +20 -8
  247. tests/cli/test_custom.py +22 -15
  248. tests/cli/test_reasoning.py +81 -0
  249. tests/common.py +73 -0
  250. tests/perf/test_perf.py +4 -2
  251. tests/rag/test_clip_benchmark.py +0 -2
  252. evalscope/benchmarks/aigc/t2i/base.py +0 -56
  253. evalscope/benchmarks/aigc/t2i/evalmuse_adapter.py +0 -78
  254. evalscope/benchmarks/aigc/t2i/genai_bench_adapter.py +0 -58
  255. evalscope/benchmarks/aigc/t2i/general_t2i_adapter.py +0 -58
  256. evalscope/benchmarks/aigc/t2i/hpdv2_adapter.py +0 -57
  257. evalscope/benchmarks/aigc/t2i/tifa_adapter.py +0 -37
  258. evalscope/benchmarks/arc/ai2_arc.py +0 -151
  259. evalscope/benchmarks/benchmark.py +0 -81
  260. evalscope/benchmarks/ceval/ceval_exam.py +0 -146
  261. evalscope/benchmarks/cmmlu/cmmlu.py +0 -161
  262. evalscope/benchmarks/cmmlu/samples.jsonl +0 -5
  263. evalscope/benchmarks/competition_math/competition_math.py +0 -79
  264. evalscope/benchmarks/data_adapter.py +0 -528
  265. evalscope/benchmarks/filters.py +0 -59
  266. evalscope/benchmarks/gsm8k/gsm8k.py +0 -121
  267. evalscope/benchmarks/hellaswag/hellaswag.py +0 -112
  268. evalscope/benchmarks/humaneval/humaneval.py +0 -79
  269. evalscope/benchmarks/mmlu/mmlu.py +0 -160
  270. evalscope/benchmarks/mmlu/samples.jsonl +0 -5
  271. evalscope/benchmarks/process_bench/critique_template.txt +0 -13
  272. evalscope/benchmarks/race/race.py +0 -104
  273. evalscope/benchmarks/race/samples.jsonl +0 -5
  274. evalscope/benchmarks/super_gpqa/zero_shot_prompt.txt +0 -4
  275. evalscope/benchmarks/trivia_qa/trivia_qa.py +0 -89
  276. evalscope/benchmarks/truthful_qa/truthful_qa.py +0 -163
  277. evalscope/benchmarks/utils.py +0 -60
  278. evalscope/collections/evaluator.py +0 -375
  279. evalscope/metrics/completion_parsers.py +0 -227
  280. evalscope/metrics/named_metrics.py +0 -55
  281. evalscope/models/adapters/__init__.py +0 -14
  282. evalscope/models/adapters/base_adapter.py +0 -84
  283. evalscope/models/adapters/bfcl_adapter.py +0 -246
  284. evalscope/models/adapters/chat_adapter.py +0 -207
  285. evalscope/models/adapters/choice_adapter.py +0 -222
  286. evalscope/models/adapters/custom_adapter.py +0 -71
  287. evalscope/models/adapters/server_adapter.py +0 -236
  288. evalscope/models/adapters/t2i_adapter.py +0 -79
  289. evalscope/models/adapters/tau_bench_adapter.py +0 -189
  290. evalscope/models/custom/__init__.py +0 -4
  291. evalscope/models/custom/custom_model.py +0 -50
  292. evalscope/models/custom/dummy_model.py +0 -99
  293. evalscope/models/local_model.py +0 -128
  294. evalscope/models/register.py +0 -41
  295. tests/cli/test_run.py +0 -489
  296. /evalscope/{benchmarks/aigc → api}/__init__.py +0 -0
  297. /evalscope/benchmarks/{aigc/t2i → image_edit}/__init__.py +0 -0
  298. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/LICENSE +0 -0
  299. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/WHEEL +0 -0
  300. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/entry_points.txt +0 -0
  301. {evalscope-0.17.1.dist-info → evalscope-1.0.1.dist-info}/top_level.txt +0 -0
  302. /tests/{aigc → benchmark}/__init__.py +0 -0
evalscope/models/modelscope.py
@@ -0,0 +1,455 @@
+from __future__ import annotations
+
+import copy
+import functools
+import json
+import time
+import torch  # type: ignore
+from concurrent.futures import Future
+from dataclasses import dataclass
+from logging import getLogger
+from modelscope import AutoModelForCausalLM, AutoTokenizer
+from queue import Empty, Queue
+from threading import Thread
+from torch import Tensor  # type: ignore
+from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union, cast
+from typing_extensions import override
+
+from evalscope.api.messages import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from evalscope.api.model import (
+    ChatCompletionChoice,
+    GenerateConfig,
+    Logprob,
+    Logprobs,
+    ModelAPI,
+    ModelOutput,
+    ModelUsage,
+    TopLogprob,
+)
+from evalscope.api.tool import ToolChoice, ToolInfo
+from evalscope.utils.model_utils import get_device
+
+logger = getLogger()
+
+
+class ModelScopeAPI(ModelAPI):
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ):
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            config=config,
+        )
+
+        # collect known model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Optional[Any]:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        model_path = collect_model_arg('model_path')
+        device_map = collect_model_arg('device_map')
+        torch_dtype = collect_model_arg('precision')
+        tokenizer_path = collect_model_arg('tokenizer_path')
+        self.chat_template = collect_model_arg('chat_template')
+        self.tokenizer_call_args = collect_model_arg('tokenizer_call_args')
+        self.enable_thinking = collect_model_arg('enable_thinking')
+        if self.tokenizer_call_args is None:
+            self.tokenizer_call_args = {}
+
+        # device
+        self.device = device_map or get_device()
+
+        # torch dtype
+        DTYPE_MAP = {'float16': torch.float16, 'float32': torch.float32, 'bfloat16': torch.bfloat16, 'auto': 'auto'}
+
+        if isinstance(torch_dtype, str) and torch_dtype != 'auto':
+            torch_dtype = DTYPE_MAP.get(torch_dtype, torch.float32)
+        self.torch_dtype = torch_dtype
+
+        # model
+        model_name_or_path = model_path or model_name
+        self.model = AutoModelForCausalLM.from_pretrained(
+            model_name_or_path,
+            device_map=self.device,
+            token=self.api_key,
+            torch_dtype=self.torch_dtype,
+            trust_remote_code=True,
+            **model_args
+        )
+
+        # tokenizer
+        tokenizer_name_or_path = tokenizer_path or model_name_or_path
+        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, trust_remote_code=True)
+        # LLMs generally don't have a pad token and we need one for batching
+        if self.tokenizer.pad_token is None:
+            if self.tokenizer.eos_token is not None:
+                self.tokenizer.pad_token = self.tokenizer.eos_token
+            else:
+                # add a pad token
+                self.tokenizer.add_special_tokens({'pad_token': '[PAD]'})
+        # set padding side to left for LLMs
+        self.tokenizer.padding_side = 'left'
+        # set chat template if provided
+        if self.chat_template:
+            self.tokenizer.chat_template = self.chat_template
+            logger.info(f'Using custom chat template: {self.chat_template}')
+
+    def generate(
+        self,
+        input: List[ChatMessage],
+        tools: List[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+    ) -> ModelOutput:
+
+        # create chat
+        chat = self.ms_chat(input, tools)
+
+        assert isinstance(self.tokenizer_call_args, dict)
+        # prepare tokenizer
+        tokenizer = functools.partial(
+            self.tokenizer,
+            return_tensors='pt',
+            padding=True,
+            **self.tokenizer_call_args,
+        )
+
+        # prepare generator
+        kwargs: Dict[str, Any] = {}
+        if config.do_sample is not None:
+            kwargs['do_sample'] = config.do_sample
+        if config.n is not None:
+            if config.n > 1:
+                assert config.do_sample, 'n > 1 requires do_sample=True in GenerateConfig'
+            kwargs['num_return_sequences'] = config.n
+        if config.max_tokens is not None:
+            kwargs['max_new_tokens'] = config.max_tokens
+        if config.temperature is not None:
+            kwargs['temperature'] = config.temperature
+        if config.top_p is not None:
+            kwargs['top_p'] = config.top_p
+        if config.top_k is not None:
+            kwargs['top_k'] = config.top_k
+        if config.logprobs is not None:
+            kwargs['output_logits'] = config.logprobs
+        if 'return_dict_in_generate' in kwargs:
+            assert kwargs['return_dict_in_generate']
+        if config.stop_seqs is not None:
+            from transformers.generation import StopStringCriteria  # type: ignore
+
+            stopping_criteria = [StopStringCriteria(self.tokenizer, config.stop_seqs)]
+            kwargs['stopping_criteria'] = stopping_criteria
+
+        kwargs['return_dict_in_generate'] = True
+        generator = functools.partial(self.model.generate, **kwargs)
+
+        # prepare decoder
+        decoder = functools.partial(
+            self.tokenizer.batch_decode,
+            skip_special_tokens=True,
+            clean_up_tokenization_spaces=False,
+        )
+
+        # generate
+        responses = batched_generate(
+            GenerateInput(
+                input=chat,
+                device=self.model.device,
+                tokenizer=tokenizer,
+                generator=generator,
+                decoder=decoder,
+                batch_size=config.batch_size or self.max_connections(),
+            )
+        )
+
+        choices: List[ChatCompletionChoice] = []
+        for response in responses:
+            # gather logprobs
+            final_logprobs = None
+            if config.logprobs is not None:
+                final_logprobs = extract_logprobs(
+                    response=response,
+                    top=config.top_logprobs,
+                    tokenizer=self.tokenizer,
+                )
+
+            # construct choice
+            # TODO: Handle tool calls
+            choice = ChatCompletionChoice(
+                message=ChatMessageAssistant(content=response.output, model=self.model_name, source='generate'),
+                logprobs=(Logprobs(content=final_logprobs) if final_logprobs is not None else None),
+            )
+            choices.append(choice)
+
+        # return output
+        return ModelOutput(
+            model=self.model_name,
+            choices=choices,
+            usage=ModelUsage(
+                input_tokens=response.input_tokens,
+                output_tokens=response.output_tokens,
+                total_tokens=response.total_tokens,
+            ),
+            time=response.time,
+        )
+
+    @override
+    def max_tokens(self) -> Optional[int]:
+        """Default is 2048, bump it up to a value suitable for evals."""
+        return 2048
+
+    @override
+    def max_connections(self) -> int:
+        """Effectively the batch size."""
+        return 8
+
+    def ms_chat(self, messages: List[ChatMessage], tools: List[ToolInfo]) -> str:
+        # convert to ms format
+        tools_list = []
+        ms_messages = copy.deepcopy(messages)
+        if len(tools) > 0:
+            tools_list = [json.loads(tool.model_dump_json(exclude_none=True, indent=2)) for tool in tools]
+
+        ms_messages = message_content_to_string(ms_messages)
+        # apply chat template
+        if self.tokenizer.chat_template is not None:
+            chat = self.tokenizer.apply_chat_template(
+                ms_messages,
+                add_generation_prompt=True,
+                tokenize=False,
+                tools=tools_list if len(tools_list) > 0 else None,
+                enable_thinking=self.enable_thinking,  # not all models use this, check if it is supported
+            )
+        else:
+            chat = ''
+            for message in ms_messages:
+                chat += f'{message.role}: {message.content}\n'
+        # return
+        return cast(str, chat)
+
+
+def message_content_to_string(messages: List[ChatMessage]) -> List[ChatMessage]:
+    """Convert list of content in `ChatMessageAssistant`, `ChatMessageUser` or `ChatMessageSystem` to a string."""
+    for message in messages:
+        if isinstance(message.content, list):
+            is_multimodal = any(
+                isinstance(item, (ContentAudio, ContentImage, ContentVideo)) for item in message.content
+            )
+            if is_multimodal:
+                raise NotImplementedError(
+                    'Transformer model does not support multimodal content, please provide text inputs only.'
+                )
+            message.content = message.text
+    return messages
+
+
+# return value from generate as a result of specifying return_dict_in_generate
+class ModelGenerateOutput:
+    sequences: Tensor
+    logits: tuple[Tensor]
+
+
+class Tokenizer(Protocol):
+
+    def __call__(self, input: List[str]) -> Dict[Literal['input_ids', 'attention_mask'], Tensor]:
+        ...
+
+
+class Generator(Protocol):
+
+    def __call__(self, input_ids: Tensor, attention_mask: Tensor) -> Tensor:
+        ...
+
+
+class Decoder(Protocol):
+
+    def __call__(self, sequences: Tensor) -> list[str]:
+        ...
+
+
+@dataclass
+class GenerateInput:
+    input: str
+    device: str
+    tokenizer: Tokenizer
+    generator: Generator
+    decoder: Decoder
+    batch_size: int
+
+
+@dataclass
+class GenerateOutput:
+    output: str
+    input_tokens: int
+    output_tokens: int
+    total_tokens: int
+    logprobs: Optional[torch.Tensor]
+    time: float
+
+
+@dataclass
+class _QueueItem:
+    input: GenerateInput
+    future: Future[GenerateOutput]
+
+
+batch_thread: Optional[Thread] = None
+
+batch_queue: 'Queue[_QueueItem]' = Queue()
+
+
+def batched_generate(input: GenerateInput) -> List[GenerateOutput]:
+    # start the background thread if necessary
+    global batch_thread
+    if batch_thread is None:
+        batch_thread = Thread(target=process_batches, daemon=True)
+        batch_thread.start()
+
+    # enqueue the job
+    future = Future[GenerateOutput]()
+    batch_queue.put(_QueueItem(input=input, future=future))
+
+    return future.result()
+
+
+def process_batches() -> None:
+    while True:
+        # drain the queue (wait until no new messages have shown up for 2 seconds)
+        inputs: List[Tuple[GenerateInput, Future[GenerateOutput]]] = []
+        while True:
+            try:
+                input = batch_queue.get(timeout=2)
+                inputs.append((input.input, input.future))
+                if len(inputs) == input.input.batch_size:
+                    # max batch size reached
+                    break
+            except Empty:
+                # we have exhausted the queue
+                break
+
+        # see if we have any work to do
+        if len(inputs) == 0:
+            continue
+
+        try:
+            # capture the generator and decoder functions
+            start_time = time.monotonic()
+            first_input = inputs[0][0]
+            device = first_input.device
+            tokenizer = first_input.tokenizer
+            generator = first_input.generator
+            decoder = first_input.decoder
+            num_return_sequences = generator.keywords.get('num_return_sequences', 1)
+
+            # tokenize and move to device
+            tokenized_inputs = tokenizer([item[0].input for item in inputs])
+            input_ids = tokenized_inputs['input_ids']
+            attention_mask = tokenized_inputs['attention_mask']
+            input_ids = input_ids.to(device)
+            attention_mask = attention_mask.to(device)
+
+            # generate
+            with torch.inference_mode():
+                generation_outputs = cast(
+                    ModelGenerateOutput,
+                    generator(input_ids=input_ids, attention_mask=attention_mask),
+                )
+                generate_ids = generation_outputs.sequences
+                logits = generation_outputs.logits
+
+            # get logprobs from logits
+            logprobs = None
+            if logits is not None:
+                stacked_logits = torch.stack(logits).transpose(0, 1)
+                logprobs = torch.nn.functional.log_softmax(stacked_logits, dim=-1)
+
+            # decode
+            generated_tokens = generate_ids[:, input_ids.size(dim=1):]
+            if logprobs is not None:
+                assert logprobs.shape[1] == generated_tokens.shape[1]
+            outputs = decoder(sequences=generated_tokens)
+
+            # call back futures
+            total_time = time.monotonic() - start_time
+            for input_index in range(len(inputs)):
+                choices: List[GenerateOutput] = []
+                # handle input
+                future = inputs[input_index][1]
+                input_tokens = input_ids[input_index].shape[-1]
+                # handle choices
+                for choice_index in range(num_return_sequences):
+                    output_index = input_index * num_return_sequences + choice_index
+                    # handle out of
+                    output = outputs[output_index]
+                    output_tokens = generate_ids[output_index].shape[-1] - input_tokens
+                    logprobs_tensor = logprobs[output_index] if logprobs is not None else None
+                    # create the output
+                    choices.append(
+                        GenerateOutput(
+                            output=output,
+                            input_tokens=input_tokens,
+                            output_tokens=output_tokens,
+                            total_tokens=input_tokens + output_tokens,
+                            logprobs=logprobs_tensor,
+                            time=total_time,
+                        )
+                    )
+
+                # asyncio futures are not thread safe, so we need to pass the event loop
+                # down to this point, so we can mark the future as done in a thread safe manner.
+                # see: https://docs.python.org/3/library/asyncio-dev.html#concurrency-and-multithreading
+                future.set_result(choices)
+
+        except Exception as ex:
+            for inp in inputs:
+                future = inp[1]
+                future.set_exception(ex)
+
+
+def extract_logprobs(
+    response: GenerateOutput,
+    top: Optional[int],
+    tokenizer,
+) -> List[Logprob]:
+    assert response.logprobs is not None
+    k = top or 1
+    topk_values, topk_inds = response.logprobs.topk(k=k, dim=-1)
+    final_logprobs = []
+    for toks, vals in zip(topk_inds, topk_values):
+        top_logprobs: List[TopLogprob] = []
+        for tok, val in zip(toks, vals):
+            # TODO: you get byte artifacts converting single ids to tokens like this...
+            # but `tokenizer.decode` strips spaces. There must be a better way to do this.
+            token_str = tokenizer.convert_ids_to_tokens(tok.item())
+            top_logprobs.append(TopLogprob(
+                token=token_str,
+                logprob=val,
+                bytes=list(map(ord, token_str)),
+            ))
        final_logprobs.append(
            Logprob(
                token=top_logprobs[0].token,
                logprob=top_logprobs[0].logprob,
                bytes=top_logprobs[0].bytes,
                top_logprobs=top_logprobs,
            )
        )
    return final_logprobs
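
Note: the new `evalscope/models/modelscope.py` arrives alongside the removal of `evalscope/models/local_model.py` and `evalscope/models/adapters/*`, providing a `ModelAPI` implementation that loads checkpoints through ModelScope's `AutoModelForCausalLM`/`AutoTokenizer` and batches generate requests on a background thread. A minimal construction sketch follows; the class name and config fields come from the diff above, while the model id, `precision`, and `device_map` values are illustrative assumptions:

    from evalscope.api.model import GenerateConfig
    from evalscope.models.modelscope import ModelScopeAPI

    # Hypothetical usage sketch (model id and keyword values are assumptions).
    api = ModelScopeAPI(
        model_name='Qwen/Qwen2.5-0.5B-Instruct',  # assumed ModelScope model id
        config=GenerateConfig(max_tokens=512, temperature=0.0),
        precision='bfloat16',  # consumed by collect_model_arg('precision')
        device_map='auto',     # consumed by collect_model_arg('device_map')
    )
    # The evaluator then drives api.generate(...) with a list of ChatMessage objects
    # plus tool metadata and receives a ModelOutput with usage and optional logprobs.
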
evalscope/models/openai_compatible.py
@@ -0,0 +1,126 @@
+import os
+from openai import APIStatusError, BadRequestError, OpenAI, PermissionDeniedError, UnprocessableEntityError
+from openai._types import NOT_GIVEN
+from openai.types.chat import ChatCompletion
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+from evalscope.api.messages import ChatMessage
+from evalscope.api.model import ChatCompletionChoice, GenerateConfig, ModelAPI, ModelOutput
+from evalscope.api.tool import ToolChoice, ToolInfo
+from evalscope.utils import get_logger
+from .utils.openai import (
+    chat_choices_from_openai,
+    collect_stream_response,
+    model_output_from_openai,
+    openai_chat_messages,
+    openai_chat_tool_choice,
+    openai_chat_tools,
+    openai_completion_params,
+    openai_handle_bad_request,
+)
+
+logger = get_logger()
+
+
+class OpenAICompatibleAPI(ModelAPI):
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ) -> None:
+
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            config=config,
+        )
+
+        # use service prefix to lookup api_key
+        self.api_key = api_key or os.environ.get('EVALSCOPE_API_KEY', None)
+        assert self.api_key, f'API key for {model_name} not found'
+
+        # use service prefix to lookup base_url
+        self.base_url = base_url or os.environ.get('EVALSCOPE_BASE_URL', None)
+        assert self.base_url, f'Base URL for {model_name} not found'
+
+        # remove trailing slash from base_url
+        self.base_url = self.base_url.rstrip('/').removesuffix('/chat/completions')
+
+        # create http client
+        self.client = OpenAI(
+            api_key=self.api_key,
+            base_url=self.base_url,
+            **model_args,
+        )
+
+    def generate(
+        self,
+        input: List[ChatMessage],
+        tools: List[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+    ) -> ModelOutput:
+        # setup request and response for ModelCall
+        request: Dict[str, Any] = {}
+        response: Dict[str, Any] = {}
+
+        tools, tool_choice, config = self.resolve_tools(tools, tool_choice, config)
+
+        # get completion params (slice off service from model name)
+        completion_params = self.completion_params(
+            config=config,
+            tools=len(tools) > 0,
+        )
+
+        request = dict(
+            messages=openai_chat_messages(input),
+            tools=openai_chat_tools(tools) if len(tools) > 0 else NOT_GIVEN,
+            tool_choice=openai_chat_tool_choice(tool_choice) if len(tools) > 0 else NOT_GIVEN,
+            **completion_params,
+        )
+
+        try:
+            # generate completion and save response for model call
+            completion = self.client.chat.completions.create(**request)
+            # handle streaming response
+            if not isinstance(completion, ChatCompletion):
+                completion = collect_stream_response(completion)
+            response = completion.model_dump()
+            self.on_response(response)
+
+            # return output and call
+            choices = self.chat_choices_from_completion(completion, tools)
+            return model_output_from_openai(completion, choices)
+
+        except (BadRequestError, UnprocessableEntityError, PermissionDeniedError) as ex:
+            return self.handle_bad_request(ex)
+
+    def resolve_tools(self, tools: List[ToolInfo], tool_choice: ToolChoice,
+                      config: GenerateConfig) -> Tuple[List[ToolInfo], ToolChoice, GenerateConfig]:
+        """Provides an opportunity for concrete classes to customize tool resolution."""
+        return tools, tool_choice, config
+
+    def completion_params(self, config: GenerateConfig, tools: bool) -> Dict[str, Any]:
+        return openai_completion_params(
+            model=self.model_name,
+            config=config,
+            tools=tools,
+        )
+
+    def on_response(self, response: Dict[str, Any]) -> None:
+        """Hook for subclasses to do custom response handling."""
+        pass
+
+    def chat_choices_from_completion(self, completion: ChatCompletion,
+                                     tools: List[ToolInfo]) -> List[ChatCompletionChoice]:
+        """Hook for subclasses to do custom chat choice processing."""
+        return chat_choices_from_openai(completion, tools)
+
+    def handle_bad_request(self, ex: APIStatusError) -> Union[ModelOutput, Exception]:
+        """Hook for subclasses to do bad request handling"""
+        return openai_handle_bad_request(self.model_name, ex)
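
Note: `evalscope/models/openai_compatible.py` adds an `OpenAICompatibleAPI` that falls back to the `EVALSCOPE_API_KEY` and `EVALSCOPE_BASE_URL` environment variables, strips a trailing `/chat/completions` from the base URL, and exposes subclass hooks (`resolve_tools`, `completion_params`, `on_response`, `chat_choices_from_completion`, `handle_bad_request`). A minimal sketch, assuming a locally served OpenAI-compatible endpoint; the URL, key, and served model name are placeholders, not values from the diff:

    import os

    from evalscope.api.model import GenerateConfig
    from evalscope.models.openai_compatible import OpenAICompatibleAPI

    # Hypothetical usage sketch (environment variable names come from the diff above).
    os.environ.setdefault('EVALSCOPE_API_KEY', 'EMPTY')
    os.environ.setdefault('EVALSCOPE_BASE_URL', 'http://127.0.0.1:8801/v1')

    api = OpenAICompatibleAPI(
        model_name='qwen2.5-7b-instruct',      # assumed served model name
        config=GenerateConfig(temperature=0.0),
    )
    # generate() builds a chat.completions request, collects streamed chunks if
    # necessary, and maps the result back to an evalscope ModelOutput.
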
evalscope/models/text2image_model.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import importlib
+import time
+import torch
+from logging import getLogger
+from typing import Any, Dict, List, Literal, Optional, Protocol, Tuple, Union, cast
+
+from evalscope.api.messages import (
+    ChatMessage,
+    ChatMessageAssistant,
+    ContentAudio,
+    ContentImage,
+    ContentText,
+    ContentVideo,
+)
+from evalscope.api.model import (
+    ChatCompletionChoice,
+    GenerateConfig,
+    Logprob,
+    Logprobs,
+    ModelAPI,
+    ModelOutput,
+    ModelUsage,
+    TopLogprob,
+)
+from evalscope.api.tool import ToolChoice, ToolInfo
+from evalscope.utils.io_utils import PIL_to_base64
+from evalscope.utils.model_utils import get_device
+
+logger = getLogger()
+
+
+class Text2ImageAPI(ModelAPI):
+
+    def __init__(
+        self,
+        model_name: str,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        config: GenerateConfig = GenerateConfig(),
+        **model_args: Any,
+    ):
+        super().__init__(
+            model_name=model_name,
+            base_url=base_url,
+            api_key=api_key,
+            config=config,
+        )
+
+        # collect known model_args (then delete them so we can pass the rest on)
+        def collect_model_arg(name: str) -> Optional[Any]:
+            nonlocal model_args
+            value = model_args.get(name, None)
+            if value is not None:
+                model_args.pop(name)
+            return value
+
+        model_path = collect_model_arg('model_path')
+        torch_dtype = collect_model_arg('precision') or collect_model_arg('torch_dtype')
+        device_map = collect_model_arg('device_map')
+        # torch dtype
+        DTYPE_MAP = {'float16': torch.float16, 'float32': torch.float32, 'bfloat16': torch.bfloat16, 'auto': 'auto'}
+
+        if isinstance(torch_dtype, str) and torch_dtype != 'auto':
+            torch_dtype = DTYPE_MAP.get(torch_dtype, torch.float32)
+        self.torch_dtype = torch_dtype
+        self.device = device_map or get_device()
+
+        self.pipeline_cls = collect_model_arg('pipeline_cls')
+        # default to DiffusionPipeline if not specified
+        if self.pipeline_cls is None:
+            if 'flux' in model_name.lower():
+                self.pipeline_cls = 'FluxPipeline'
+            else:
+                self.pipeline_cls = 'DiffusionPipeline'
+
+        model_name_or_path = model_path or model_name
+
+        # from modelscope import pipeline_cls
+        module = getattr(importlib.import_module('modelscope'), self.pipeline_cls)
+        logger.info(f'Loading model {model_name_or_path} with {self.pipeline_cls} ...')
+
+        self.model = module.from_pretrained(
+            model_name_or_path,
+            torch_dtype=self.torch_dtype,
+            **model_args,
+        )
+
+        self.model.to(self.device)
+
+    def generate(
+        self,
+        input: List[ChatMessage],
+        tools: List[ToolInfo],
+        tool_choice: ToolChoice,
+        config: GenerateConfig,
+    ) -> ModelOutput:
+
+        # prepare generator
+        kwargs: Dict[str, Any] = {}
+        if config.height is not None:
+            kwargs['height'] = config.height
+        if config.width is not None:
+            kwargs['width'] = config.width
+        if config.num_inference_steps is not None:
+            kwargs['num_inference_steps'] = config.num_inference_steps
+        if config.guidance_scale is not None:
+            kwargs['guidance_scale'] = config.guidance_scale
+        # update with extra model parameters
+        kwargs.update(config.model_extra)
+
+        # assume the first text as prompt
+        prompt = input[0].text
+        # get the first image as output
+        image = self.model(prompt=prompt, **kwargs).images[0]
+
+        image_base64 = PIL_to_base64(image)
+
+        return ModelOutput(
+            model=self.model_name,
+            choices=[ChatCompletionChoice.from_content(content=[ContentImage(image=image_base64)])],
+            time=time.time(),
+        )
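
Note: `evalscope/models/text2image_model.py` adds a `Text2ImageAPI` that resolves a diffusion pipeline class by name from `modelscope`, maps the `GenerateConfig` fields it reads above (`height`, `width`, `num_inference_steps`, `guidance_scale`) to pipeline keyword arguments, and returns the generated image as a base64 `ContentImage`. A minimal sketch with assumed values; the model id, pipeline class, and precision are illustrative, not taken from the diff:

    from evalscope.api.model import GenerateConfig
    from evalscope.models.text2image_model import Text2ImageAPI

    # Hypothetical usage sketch (class name and config fields come from the diff above).
    api = Text2ImageAPI(
        model_name='stabilityai/stable-diffusion-xl-base-1.0',  # assumed model id
        config=GenerateConfig(num_inference_steps=30, guidance_scale=7.5),
        pipeline_cls='DiffusionPipeline',  # optional; inferred from the model name when omitted
        precision='float16',
    )
    # generate() uses the first text message as the prompt and returns a ModelOutput
    # whose single choice carries the image as a base64-encoded ContentImage.
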