evalscope 0.6.1__py3-none-any.whl → 0.7.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (108)
  1. evalscope/backend/opencompass/tasks/eval_api.py +2 -1
  2. evalscope/backend/opencompass/tasks/eval_datasets.py +1 -0
  3. evalscope/backend/rag_eval/clip_benchmark/utils/webdataset_convert.py +230 -0
  4. evalscope/backend/rag_eval/clip_benchmark/utils/webdatasets.txt +43 -0
  5. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/correctness_prompt_chinese.json +87 -0
  6. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerCorrectness/long_form_answer_prompt_chinese.json +36 -0
  7. evalscope/backend/rag_eval/ragas/prompts/chinese/AnswerRelevancy/question_generation_chinese.json +26 -0
  8. evalscope/backend/rag_eval/ragas/prompts/chinese/ContextPrecision/context_precision_prompt_chinese.json +41 -0
  9. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/nli_statements_message_chinese.json +60 -0
  10. evalscope/backend/rag_eval/ragas/prompts/chinese/Faithfulness/statement_prompt_chinese.json +36 -0
  11. evalscope/backend/rag_eval/ragas/prompts/chinese/HeadlinesExtractor/prompt_chinese.json +22 -0
  12. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/concept_combination_prompt_chinese.json +35 -0
  13. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  14. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopAbstractQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  15. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  16. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  17. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalFaithfulness/faithfulness_prompt_chinese.json +34 -0
  18. evalscope/backend/rag_eval/ragas/prompts/chinese/MultiModalRelevance/relevance_prompt_chinese.json +36 -0
  19. evalscope/backend/rag_eval/ragas/prompts/chinese/NERExtractor/prompt_chinese.json +25 -0
  20. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/generate_query_reference_prompt_chinese.json +7 -0
  21. evalscope/backend/rag_eval/ragas/prompts/chinese/SingleHopSpecificQuerySynthesizer/theme_persona_matching_prompt_chinese.json +39 -0
  22. evalscope/backend/rag_eval/ragas/prompts/chinese/SummaryExtractor/prompt_chinese.json +16 -0
  23. evalscope/backend/rag_eval/ragas/prompts/chinese/ThemesExtractor/prompt_chinese.json +24 -0
  24. evalscope/backend/rag_eval/ragas/prompts/persona_prompt.py +18 -0
  25. evalscope/backend/vlm_eval_kit/backend_manager.py +23 -21
  26. evalscope/benchmarks/ceval/samples.jsonl +1 -0
  27. evalscope/benchmarks/cmmlu/samples.jsonl +5 -0
  28. evalscope/benchmarks/mmlu/samples.jsonl +5 -0
  29. evalscope/benchmarks/race/samples.jsonl +5 -0
  30. evalscope/benchmarks/trivia_qa/samples.jsonl +5 -0
  31. evalscope/cli/start_perf.py +8 -11
  32. evalscope/metrics/resources/gpt2-zhcn3-v4.bpe +58485 -0
  33. evalscope/metrics/resources/gpt2-zhcn3-v4.json +1 -0
  34. evalscope/metrics/rouge_metric.py +30 -15
  35. evalscope/perf/arguments.py +179 -0
  36. evalscope/perf/benchmark.py +245 -0
  37. evalscope/perf/http_client.py +127 -711
  38. evalscope/perf/main.py +35 -0
  39. evalscope/perf/plugin/__init__.py +2 -0
  40. evalscope/perf/plugin/api/__init__.py +3 -0
  41. evalscope/perf/{api_plugin_base.py → plugin/api/base.py} +17 -18
  42. evalscope/perf/{custom_api.py → plugin/api/custom_api.py} +25 -19
  43. evalscope/perf/{dashscope_api.py → plugin/api/dashscope_api.py} +28 -14
  44. evalscope/perf/{openai_api.py → plugin/api/openai_api.py} +51 -27
  45. evalscope/perf/plugin/datasets/__init__.py +6 -0
  46. evalscope/perf/{dataset_plugin_base.py → plugin/datasets/base.py} +13 -10
  47. evalscope/perf/plugin/datasets/custom.py +21 -0
  48. evalscope/perf/plugin/datasets/flickr8k.py +51 -0
  49. evalscope/perf/{datasets → plugin/datasets}/line_by_line.py +9 -5
  50. evalscope/perf/plugin/datasets/longalpaca.py +28 -0
  51. evalscope/perf/plugin/datasets/openqa.py +38 -0
  52. evalscope/perf/plugin/datasets/speed_benchmark.py +50 -0
  53. evalscope/perf/plugin/registry.py +54 -0
  54. evalscope/perf/{how_to_analysis_result.py → utils/analysis_result.py} +11 -5
  55. evalscope/perf/utils/benchmark_util.py +135 -0
  56. evalscope/perf/utils/chat_service.py +252 -0
  57. evalscope/perf/utils/db_util.py +200 -0
  58. evalscope/perf/utils/handler.py +46 -0
  59. evalscope/perf/utils/local_server.py +139 -0
  60. evalscope/registry/config/cfg_arena.yaml +77 -0
  61. evalscope/registry/config/cfg_arena_zhihu.yaml +63 -0
  62. evalscope/registry/config/cfg_pairwise_baseline.yaml +83 -0
  63. evalscope/registry/config/cfg_single.yaml +78 -0
  64. evalscope/registry/data/prompt_template/lmsys_v2.jsonl +8 -0
  65. evalscope/registry/data/prompt_template/prompt_templates.jsonl +8 -0
  66. evalscope/registry/data/qa_browser/battle.jsonl +634 -0
  67. evalscope/registry/data/qa_browser/category_mapping.yaml +10 -0
  68. evalscope/registry/data/question.jsonl +80 -0
  69. evalscope/third_party/longbench_write/README.md +118 -0
  70. evalscope/third_party/longbench_write/default_task.json +27 -0
  71. evalscope/third_party/longbench_write/default_task.yaml +24 -0
  72. evalscope/third_party/toolbench_static/README.md +118 -0
  73. evalscope/third_party/toolbench_static/config_default.json +15 -0
  74. evalscope/third_party/toolbench_static/config_default.yaml +12 -0
  75. evalscope/third_party/toolbench_static/requirements.txt +2 -0
  76. evalscope/utils/logger.py +18 -20
  77. evalscope/utils/utils.py +41 -42
  78. evalscope/version.py +2 -2
  79. evalscope-0.7.1.dist-info/LICENSE +203 -0
  80. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/METADATA +93 -35
  81. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/RECORD +101 -31
  82. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/WHEEL +1 -1
  83. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/top_level.txt +1 -0
  84. tests/cli/__init__.py +1 -0
  85. tests/cli/test_run.py +76 -0
  86. tests/perf/__init__.py +1 -0
  87. tests/perf/test_perf.py +96 -0
  88. tests/rag/test_clip_benchmark.py +85 -0
  89. tests/rag/test_mteb.py +136 -0
  90. tests/rag/test_ragas.py +120 -0
  91. tests/swift/__init__.py +1 -0
  92. tests/swift/test_run_swift_eval.py +146 -0
  93. tests/swift/test_run_swift_vlm_eval.py +128 -0
  94. tests/swift/test_run_swift_vlm_jugde_eval.py +157 -0
  95. tests/test_run_all.py +12 -0
  96. tests/vlm/__init__.py +1 -0
  97. tests/vlm/test_vlmeval.py +59 -0
  98. evalscope/perf/_logging.py +0 -32
  99. evalscope/perf/datasets/longalpaca_12k.py +0 -20
  100. evalscope/perf/datasets/openqa.py +0 -22
  101. evalscope/perf/plugin_registry.py +0 -35
  102. evalscope/perf/query_parameters.py +0 -42
  103. evalscope/perf/server_sent_event.py +0 -43
  104. evalscope/preprocess/tokenizers/gpt2_tokenizer.py +0 -221
  105. /evalscope/perf/{datasets → utils}/__init__.py +0 -0
  106. {evalscope-0.6.1.dist-info → evalscope-0.7.1.dist-info}/entry_points.txt +0 -0
  107. {evalscope/preprocess → tests}/__init__.py +0 -0
  108. {evalscope/preprocess/tokenizers → tests/rag}/__init__.py +0 -0
evalscope/perf/plugin/datasets/openqa.py
@@ -0,0 +1,38 @@
+ import subprocess
+ from typing import Any, Dict, Iterator, List
+
+ import json
+
+ from evalscope.perf.arguments import Arguments
+ from evalscope.perf.plugin.datasets.base import DatasetPluginBase
+ from evalscope.perf.plugin.registry import register_dataset
+
+
+ @register_dataset('openqa')
+ class OpenqaDatasetPlugin(DatasetPluginBase):
+     """Read dataset and return prompt.
+     Datasets: https://www.modelscope.cn/datasets/AI-ModelScope/HC3-Chinese/resolve/master/open_qa.jsonl
+     """
+
+     def __init__(self, query_parameters: Arguments):
+         super().__init__(query_parameters)
+
+     def build_messages(self) -> Iterator[List[Dict]]:
+         if not self.query_parameters.dataset_path:
+             subprocess.call([
+                 'modelscope',
+                 'download',
+                 '--dataset',
+                 'AI-ModelScope/HC3-Chinese',
+                 'open_qa.jsonl',
+                 '--local_dir',
+                 './data',
+             ])
+             self.query_parameters.dataset_path = './data/open_qa.jsonl'
+
+         for item in self.dataset_line_by_line(self.query_parameters.dataset_path):
+             item = json.loads(item)
+             prompt = item['question'].strip()
+             if (len(prompt) > self.query_parameters.min_prompt_length
+                     and len(prompt) < self.query_parameters.max_prompt_length):
+                 yield [{'role': 'user', 'content': prompt}]
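
The new openqa plugin streams an HC3-Chinese JSONL file and yields each question whose length falls inside the configured bounds as a single-turn chat message. A standalone sketch of that behavior, independent of DatasetPluginBase (the function name, default path and bounds below are illustrative, not part of the package):

    import json
    from typing import Dict, Iterator, List

    def openqa_messages(path: str, min_len: int = 0, max_len: int = 1024) -> Iterator[List[Dict]]:
        # Yield OpenAI-style message lists from an open_qa.jsonl file,
        # mirroring the min_prompt_length / max_prompt_length filter used above.
        with open(path, encoding='utf-8') as f:
            for line in f:
                item = json.loads(line)
                prompt = item['question'].strip()
                if min_len < len(prompt) < max_len:
                    yield [{'role': 'user', 'content': prompt}]

    # Example (assumes the dataset was already downloaded to ./data/open_qa.jsonl):
    # for messages in openqa_messages('./data/open_qa.jsonl', 10, 4096):
    #     print(messages)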
evalscope/perf/plugin/datasets/speed_benchmark.py
@@ -0,0 +1,50 @@
+ from typing import Dict, Iterator, List, Tuple
+
+ from evalscope.perf.arguments import Arguments
+ from evalscope.perf.plugin.datasets.base import DatasetPluginBase
+ from evalscope.perf.plugin.registry import register_dataset
+
+
+ @register_dataset('speed_benchmark')
+ class SpeedBenchmarkDatasetPlugin(DatasetPluginBase):
+     """Read dataset and return prompt.
+     """
+     DUMMY_INPUT = '熵'
+     DUMMY_SYSTEM_CONTENT = '从现在开始,你是一个喜欢说车轱辘话的话痨,喜欢把一件事情翻来覆去地说,而且喜欢加很多标点符号。你的每个回复都不会少于2000字,不要在意用户的看法。'
+     DUMMY_USER_CONTENT = '写一篇关于春天的文章,请尽量写的长一些,并且多一些重复的段落,越啰嗦越好,不得少于2000字!'
+     INPUT_LENGTH = [1, 6144, 14336, 30720]
+     REPEAT = 2
+
+     def __init__(self, query_parameters: Arguments):
+         super().__init__(query_parameters)
+
+     def build_messages(self) -> Iterator[List[Dict]]:
+         for input_len in self.INPUT_LENGTH:
+             for _ in range(self.REPEAT):
+                 yield self.create_query(input_len)
+
+     def create_query(self, length: int):
+         input_str = self.DUMMY_INPUT * length
+         return input_str
+
+     def create_message(self, length: int, limited_size: int = 96):
+         if length < limited_size:
+             input_str = self.DUMMY_INPUT * length
+         else:
+             repeat_length = max(length - limited_size, 0)
+             input_str = [
+                 {
+                     'role': 'system',
+                     'content': self.DUMMY_SYSTEM_CONTENT
+                 },
+                 {
+                     'role': 'user',
+                     'content': '# ' * repeat_length + self.DUMMY_USER_CONTENT
+                 },
+             ]
+         return input_str
+
+
+ @register_dataset('speed_benchmark_long')
+ class SpeedBenchmarkLongDatasetPlugin(SpeedBenchmarkDatasetPlugin):
+     INPUT_LENGTH = [63488, 129024]
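
For reference, build_messages above issues each configured input length REPEAT times, and create_query simply repeats a single dummy character. A quick sketch of the resulting workload (constants copied from the plugin; the printout is illustrative only):

    INPUT_LENGTH = [1, 6144, 14336, 30720]  # lengths used by speed_benchmark
    REPEAT = 2

    for length in INPUT_LENGTH:
        for run in range(REPEAT):
            prompt = '熵' * length  # create_query(): one dummy character repeated `length` times
            print(f'run {run}: prompt of {len(prompt)} characters')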
evalscope/perf/plugin/registry.py
@@ -0,0 +1,54 @@
+ from typing import Any, List, Type
+
+
+ class PluginRegistry:
+
+     def __init__(self):
+         self._registry = {}
+
+     def register(self, name, cls):
+         self._registry[name] = cls
+         return cls
+
+     def get_class(self, name):
+         return self._registry[name]
+
+     def all_classes(self):
+         return list(self._registry.keys())
+
+     def __call__(self, name: str) -> Any:
+         return self.get_class(name)
+
+
+ def register_dataset(name: str | List[str]):
+
+     def class_decorator(cls: Type):
+         if isinstance(name, str):
+             DatasetRegistry.register(name, cls)
+         elif isinstance(name, list):
+             for n in name:
+                 DatasetRegistry.register(n, cls)
+         else:
+             raise TypeError('name must be a string or a list of strings')
+         return cls
+
+     return class_decorator
+
+
+ def register_api(name: str | List[str]):
+
+     def class_decorator(cls: Type):
+         if isinstance(name, str):
+             ApiRegistry.register(name, cls)
+         elif isinstance(name, list):
+             for n in name:
+                 ApiRegistry.register(n, cls)
+         else:
+             raise TypeError('name must be a string or a list of strings')
+         return cls
+
+     return class_decorator
+
+
+ DatasetRegistry = PluginRegistry()
+ ApiRegistry = PluginRegistry()
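
register_dataset and register_api are thin decorators over two module-level PluginRegistry instances, so a plugin can be registered under a single name or several. A minimal sketch of registering and looking up a custom plugin (MyJsonlPlugin is a toy class for illustration; real dataset plugins subclass DatasetPluginBase):

    from evalscope.perf.plugin.registry import DatasetRegistry, register_dataset

    @register_dataset(['my_jsonl', 'my_jsonl_alias'])  # a single name or a list of names
    class MyJsonlPlugin:
        def __init__(self, query_parameters):
            self.query_parameters = query_parameters

    # Lookup returns the registered class; the registry instance is also callable.
    assert DatasetRegistry.get_class('my_jsonl') is MyJsonlPlugin
    assert DatasetRegistry('my_jsonl_alias') is MyJsonlPlugin
    print(DatasetRegistry.all_classes())  # names registered so far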
evalscope/perf/{how_to_analysis_result.py → utils/analysis_result.py}
@@ -1,12 +1,15 @@
- import sqlite3
  import base64
  import pickle
+ import sqlite3
+
  import json
- result_db_path = 'db_name.db'
+
+ result_db_path = '/mnt/data/data/user/maoyunlin.myl/eval-scope/outputs/qwen2.5_benchmark_20241111_160543.db'
  con = sqlite3.connect(result_db_path)
  query_sql = "SELECT request, response_messages, prompt_tokens, completion_tokens \
-     FROM result WHERE success='True'"
- # how to save base64.b64encode(pickle.dumps(benchmark_data["request"])).decode("ascii"),
+     FROM result WHERE success='1'"
+
+ # how to save base64.b64encode(pickle.dumps(benchmark_data["request"])).decode("ascii"),
  with con:
      rows = con.execute(query_sql).fetchall()
      if len(rows) > 0:
@@ -20,5 +23,8 @@ with con:
              response_content = ''
              for response in responses:
                  response = json.loads(response)
+                 if not response['choices']:
+                     continue
                  response_content += response['choices'][0]['delta']['content']
-             print('prompt: %s, tokens: %s, completion: %s, tokens: %s' % (request['messages'][0]['content'], row[2], response_content, row[3]))
+             print('prompt: %s, tokens: %s, completion: %s, tokens: %s' %
+                   (request['messages'][0]['content'], row[2], response_content, row[3]))
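
The updated script hard-codes a user-specific database path and now filters on success='1'. A variant of the same query that takes the path from the command line, assuming both blobs are base64-encoded pickles as the inline comment above suggests (the default path is a placeholder):

    import base64
    import json
    import pickle
    import sqlite3
    import sys

    db_path = sys.argv[1] if len(sys.argv) > 1 else './outputs/benchmark.db'  # placeholder default
    query_sql = ("SELECT request, response_messages, prompt_tokens, completion_tokens "
                 "FROM result WHERE success='1'")

    with sqlite3.connect(db_path) as con:
        for request_blob, responses_blob, prompt_tokens, completion_tokens in con.execute(query_sql):
            request = pickle.loads(base64.b64decode(request_blob))
            responses = pickle.loads(base64.b64decode(responses_blob))
            content = ''
            for response in responses:
                response = json.loads(response)
                if not response['choices']:
                    continue
                # .get() tolerates the final chunk, whose delta may carry no content field
                content += response['choices'][0]['delta'].get('content', '')
            print('prompt: %s, tokens: %s, completion: %s, tokens: %s' %
                  (request['messages'][0]['content'], prompt_tokens, content, completion_tokens))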
evalscope/perf/utils/benchmark_util.py
@@ -0,0 +1,135 @@
+ import time
+ from dataclasses import dataclass, field
+ from typing import Any, List, Optional, Tuple
+
+ import torch
+
+ from evalscope.utils.logger import get_logger
+
+ logger = get_logger()
+
+
+ @dataclass
+ class BenchmarkData:
+     request: Any = None
+     start_time: float = field(default_factory=time.perf_counter)
+     completed_time: float = 0.0
+     chunk_times: List[float] = field(default_factory=list)
+     success: bool = False
+     response_messages: List[Any] = field(default_factory=list)
+
+     # late init
+     query_latency: float = 0.0
+     first_chunk_latency: float = 0.0
+     n_chunks: int = 0
+     n_chunks_time: float = 0.0
+     max_gpu_memory_cost = 0
+
+     prompt_tokens = None
+     completion_tokens = None
+
+     def _calculate_query_stream_metric(self) -> Tuple[float, int, float]:
+         self.query_latency = self.completed_time - self.start_time
+         if len(self.chunk_times) > 1:
+             self.first_chunk_latency = self.chunk_times[0] - self.start_time
+             self.n_chunks = len(self.chunk_times) - 2
+             self.n_chunks_time = self.chunk_times[-2] - self.chunk_times[0]
+         else:
+             self.first_chunk_latency = self.query_latency
+             self.n_chunks = 1
+             self.n_chunks_time = self.query_latency
+
+     def _calculate_tokens(self, api_plugin):
+         self.prompt_tokens, self.completion_tokens = \
+             api_plugin.parse_responses(self.response_messages, request=self.request)
+
+     def update_gpu_usage(self):
+         total_memory = 0
+         for i in range(torch.cuda.device_count()):
+             total_memory += (torch.cuda.max_memory_allocated(i) / 2**30)  # GB
+         self.max_gpu_memory_cost = max(self.max_gpu_memory_cost, total_memory)
+
+
+ @dataclass
+ class BenchmarkMetrics:
+     concurrency: int = 0
+     n_succeed_queries: int = 0
+     n_failed_queries: int = 0
+     total_first_chunk_latency: float = 0.0
+     total_latency: float = 0.0
+     n_total_chunks: int = 0
+     n_total_prompt_tokens: int = 0
+     n_total_completion_tokens: int = 0
+     total_chunks_time: float = 0.0
+     start_time: Optional[float] = None
+     total_time: float = 1.0
+     n_total_queries: int = 0
+
+     avg_first_chunk_latency: float = -1
+     avg_latency: float = -1
+     n_avg_chunks: float = -1
+     avg_chunk_time: float = -1
+     avg_prompt_tokens: float = -1
+     avg_completion_tokens: float = -1
+     avg_token_per_seconds: float = -1
+     avg_time_per_token: float = -1
+     qps: float = -1
+
+     def update_metrics(self, benchmark_data: BenchmarkData, api_plugin):
+         self.n_total_queries += 1
+         if self.start_time is None:
+             self.start_time = benchmark_data.start_time
+         self.total_time = time.perf_counter() - self.start_time
+
+         if benchmark_data.success:
+             self.n_succeed_queries += 1
+
+             benchmark_data._calculate_tokens(api_plugin)
+             self.n_total_prompt_tokens += benchmark_data.prompt_tokens
+             self.n_total_completion_tokens += benchmark_data.completion_tokens
+
+             benchmark_data._calculate_query_stream_metric()
+             self.total_latency += benchmark_data.query_latency
+             self.total_first_chunk_latency += benchmark_data.first_chunk_latency
+             self.n_total_chunks += benchmark_data.n_chunks
+             self.total_chunks_time += benchmark_data.n_chunks_time
+         else:
+             self.n_failed_queries += 1
+
+         self.calculate_averages()
+
+     def calculate_averages(self):
+         if self.n_succeed_queries == 0:
+             return
+         try:
+             self.avg_first_chunk_latency = self.total_first_chunk_latency / self.n_succeed_queries
+             self.avg_latency = self.total_latency / self.n_succeed_queries
+             self.n_avg_chunks = self.n_total_chunks / self.n_succeed_queries
+             self.avg_chunk_time = self.total_chunks_time / self.n_total_chunks
+             self.avg_prompt_tokens = self.n_total_prompt_tokens / self.n_succeed_queries
+             self.avg_completion_tokens = self.n_total_completion_tokens / self.n_succeed_queries
+             self.avg_token_per_seconds = self.n_total_completion_tokens / self.total_time
+             self.avg_time_per_token = self.total_time / self.n_total_completion_tokens
+             self.qps = self.n_succeed_queries / self.total_time
+         except ZeroDivisionError as e:
+             logger.exception(e)
+             return
+
+     def create_message(self, default_ndigits=3):
+         message = {
+             'Time taken for tests (senconds)': round(self.total_time, default_ndigits),
+             'Number of concurrency': self.concurrency,
+             'Total requests': int(self.n_total_queries),
+             'Succeed requests': self.n_succeed_queries,
+             'Failed requests': self.n_failed_queries,
+             'Average QPS': round(self.qps, default_ndigits),
+             'Average latency (s)': round(self.avg_latency, default_ndigits),
+             'Average time to first token (s)': round(self.avg_first_chunk_latency, default_ndigits),
+             'Average time per output token (s)': round(self.avg_time_per_token, 5),
+             'Average package latency (s)': round(self.avg_chunk_time, default_ndigits),
+             'Average package per request': round(self.n_avg_chunks, default_ndigits),
+             'Throughput(average output tokens per second)': round(self.avg_token_per_seconds, default_ndigits),
+             'Average input tokens per request': round(self.avg_prompt_tokens, default_ndigits),
+             'Average output tokens per request': round(self.avg_completion_tokens, default_ndigits),
+         }
+         return message
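
BenchmarkMetrics aggregates one BenchmarkData record per request and derives the averages that create_message reports. A small sketch of that flow with a stub API plugin (StubApiPlugin and all numbers are made up; importing the module requires torch, which it imports at the top):

    from evalscope.perf.utils.benchmark_util import BenchmarkData, BenchmarkMetrics

    class StubApiPlugin:
        # Minimal stand-in: only the parse_responses() hook used by _calculate_tokens().
        def parse_responses(self, responses, request=None):
            return 32, 128  # pretend prompt/completion token counts

    metrics = BenchmarkMetrics(concurrency=1)
    data = BenchmarkData(request={'messages': [{'role': 'user', 'content': 'hi'}]})
    data.chunk_times = [data.start_time + 0.05, data.start_time + 0.10, data.start_time + 0.15]
    data.completed_time = data.start_time + 0.20
    data.success = True
    data.response_messages = ['{"choices": []}']  # raw payloads; the stub ignores them

    metrics.update_metrics(data, StubApiPlugin())
    print(metrics.create_message())  # latency / throughput summary for the run so far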
evalscope/perf/utils/chat_service.py
@@ -0,0 +1,252 @@
+ import os
+ import time
+ from contextlib import contextmanager
+ from functools import partial
+ from threading import Thread
+ from typing import List, Literal, Optional, Union
+
+ import torch
+ from modelscope import AutoModelForCausalLM, AutoTokenizer
+ from pydantic import BaseModel, Field
+ from transformers import TextIteratorStreamer
+
+
+ class Usage(BaseModel):
+     prompt_tokens: int = 0
+     completion_tokens: int = 0
+     total_tokens: int = 0
+
+
+ class ModelCard(BaseModel):
+     id: str
+     object: str = 'model'
+     created: int = Field(default_factory=lambda: int(time.time()))
+     owned_by: str = 'owner'
+     root: Optional[str] = None
+     parent: Optional[str] = None
+     permission: Optional[list] = None
+
+
+ class ModelList(BaseModel):
+     object: str = 'list'
+     data: List[ModelCard] = []
+
+
+ class ChatMessage(BaseModel):
+     role: Literal['user', 'assistant', 'system']
+     content: str
+
+
+ class DeltaMessage(BaseModel):
+     role: Optional[Literal['user', 'assistant', 'system']] = None
+     content: Optional[str] = None
+
+
+ class ChatCompletionRequest(BaseModel):
+     model: str
+     messages: List[ChatMessage] | str
+     temperature: Optional[float] = None
+     top_p: Optional[float] = None
+     max_tokens: Optional[int] = 2048
+     min_tokens: Optional[int] = None
+     stream: Optional[bool] = False
+
+
+ class ChatCompletionResponseChoice(BaseModel):
+     index: int
+     message: ChatMessage
+     finish_reason: Literal['stop', 'length']
+
+
+ class ChatCompletionResponseStreamChoice(BaseModel):
+     index: int
+     delta: DeltaMessage
+     finish_reason: Optional[Literal['stop', 'length']]
+
+
+ class ChatCompletionResponse(BaseModel):
+     model: str
+     object: Literal['chat.completion', 'chat.completion.chunk']
+     choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
+     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+     usage: Optional[Usage]
+
+
+ class TextCompletionRequest(BaseModel):
+     model: str
+     prompt: str
+     temperature: Optional[float] = None
+     max_tokens: Optional[int] = 2048
+     min_tokens: Optional[int] = None
+
+
+ class TextCompletionResponseChoice(BaseModel):
+     index: int
+     text: str
+     finish_reason: Literal['stop', 'length']
+
+
+ class TextCompletionResponse(BaseModel):
+     model: str
+     object: Literal['text_completion']
+     created: Optional[int] = Field(default_factory=lambda: int(time.time()))
+     choices: List[TextCompletionResponseChoice]
+     usage: Optional[Usage]
+
+
+ class ChatService:
+
+     def __init__(self, model_path, attn_implementation):
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_path,
+             trust_remote_code=True,
+             device_map='auto',
+             torch_dtype='auto',
+             attn_implementation=attn_implementation,
+         ).eval()
+         self.streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
+         self.model_id = os.path.basename(model_path)
+         self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+     def count_tokens(self, text: str) -> int:
+         # Use the tokenizer to count the number of tokens
+         return len(self.tokenizer.encode(text, add_special_tokens=False))
+
+     async def list_models(self):
+         model_card = ModelCard(id=self.model_id)
+         return ModelList(data=[model_card])
+
+     async def _chat(self, request: ChatCompletionRequest):
+         formatted_prompt, inputs, prompt_tokens = self._prepare_chat_inputs(request)
+         outputs = self.model.generate(
+             **inputs,
+             max_new_tokens=request.max_tokens,
+             min_new_tokens=request.min_tokens,
+             temperature=request.temperature,
+         )
+         outputs = outputs[0][prompt_tokens:]  # remove prompt
+         completion_tokens = len(outputs)
+         response = self.tokenizer.decode(outputs, skip_special_tokens=True)
+
+         choice_data = ChatCompletionResponseChoice(
+             index=0,
+             message=ChatMessage(role='assistant', content=response),
+             finish_reason='stop',
+         )
+         return ChatCompletionResponse(
+             model=self.model_id,
+             choices=[choice_data],
+             object='chat.completion',
+             usage=Usage(
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=prompt_tokens + completion_tokens,
+             ),
+         )
+
+     async def _text_completion(self, request: TextCompletionRequest):
+         inputs, prompt_tokens = self._prepare_text_inputs(request)
+         outputs = self.model.generate(
+             **inputs,
+             max_new_tokens=request.max_tokens,
+             min_new_tokens=request.min_tokens,
+             temperature=request.temperature,
+         )
+         outputs = outputs[0][prompt_tokens:]  # remove prompt
+         completion_tokens = len(outputs)
+         response = self.tokenizer.decode(outputs, skip_special_tokens=True)
+
+         choice_data = TextCompletionResponseChoice(
+             index=0,
+             text=response,
+             finish_reason='stop',
+         )
+         return TextCompletionResponse(
+             model=self.model_id,
+             choices=[choice_data],
+             object='text_completion',
+             usage=Usage(
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=prompt_tokens + completion_tokens,
+             ),
+         )
+
+     def _prepare_text_inputs(self, request: TextCompletionRequest):
+         inputs = self.tokenizer(request.prompt, return_tensors='pt', padding=True).to(self.device)
+         prompt_tokens = len(inputs['input_ids'][0])
+         return inputs, prompt_tokens
+
+     def _stream_chat(self, request: ChatCompletionRequest):
+         formatted_prompt, inputs, prompt_tokens = self._prepare_chat_inputs(request)
+         completion_tokens = 0
+
+         yield self._create_initial_chunk()
+
+         generation_kwargs = dict(
+             **inputs,
+             streamer=self.streamer,
+             max_new_tokens=request.max_tokens,
+             min_new_tokens=request.min_tokens,
+             temperature=request.temperature,
+         )
+         generate_partial = partial(self.model.generate, **generation_kwargs)
+
+         with self._start_generation_thread(generate_partial):
+             for new_text in self.streamer:
+                 yield self._create_chunk(new_text)
+                 completion_tokens += self.count_tokens(new_text)
+
+         yield self._create_final_chunk(prompt_tokens, completion_tokens)
+         yield '[DONE]'
+
+     def _prepare_chat_inputs(self, request: ChatCompletionRequest):
+         formatted_prompt = self.tokenizer.apply_chat_template(
+             request.messages, tokenize=False, add_generation_prompt=True)
+         inputs = self.tokenizer(formatted_prompt, return_tensors='pt', padding=True).to(self.device)
+         prompt_tokens = len(inputs['input_ids'][0])
+         return formatted_prompt, inputs, prompt_tokens
+
+     @contextmanager
+     def _start_generation_thread(self, generate_partial):
+         thread = Thread(target=generate_partial)
+         thread.start()
+         try:
+             yield
+         finally:
+             thread.join()
+
+     def _create_initial_chunk(self):
+         choice_data = ChatCompletionResponseStreamChoice(index=0, delta={'role': 'assistant'}, finish_reason=None)
+         chunk = ChatCompletionResponse(
+             model=self.model_id,
+             choices=[choice_data],
+             object='chat.completion.chunk',
+             usage=None,
+         )
+         return chunk.model_dump_json(exclude_unset=True)
+
+     def _create_chunk(self, new_text):
+         choice_data = ChatCompletionResponseStreamChoice(index=0, delta={'content': new_text}, finish_reason=None)
+         chunk = ChatCompletionResponse(
+             model=self.model_id,
+             choices=[choice_data],
+             object='chat.completion.chunk',
+             usage=None,
+         )
+         return chunk.model_dump_json(exclude_unset=True)
+
+     def _create_final_chunk(self, prompt_tokens, completion_tokens):
+         choice_data = ChatCompletionResponseStreamChoice(index=0, delta={}, finish_reason='stop')
+         chunk = ChatCompletionResponse(
+             model=self.model_id,
+             choices=[choice_data],
+             object='chat.completion.chunk',
+             usage=Usage(
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=prompt_tokens + completion_tokens,
+             ),
+         )
+         return chunk.model_dump_json(exclude_unset=True)
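
The pydantic models above define an OpenAI-compatible wire format, and ChatService loads a local checkpoint via ModelScope. A small sketch that builds one streaming chunk by hand to show the JSON that _create_chunk() emits (no model is loaded; 'demo-model' is a placeholder, and importing the module requires torch, modelscope and transformers since it imports them at the top):

    from evalscope.perf.utils.chat_service import (ChatCompletionResponse, ChatCompletionResponseStreamChoice,
                                                   DeltaMessage)

    choice = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(content='Hello'), finish_reason=None)
    chunk = ChatCompletionResponse(
        model='demo-model',  # placeholder; ChatService uses the checkpoint's basename
        choices=[choice],
        object='chat.completion.chunk',
        usage=None,
    )
    print(chunk.model_dump_json(exclude_unset=True))
    # e.g. {"model":"demo-model","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}],"usage":null}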