epub-translator 0.0.7__py3-none-any.whl → 0.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- epub_translator/__init__.py +4 -2
- epub_translator/data/fill.jinja +66 -0
- epub_translator/data/mmltex/README.md +67 -0
- epub_translator/data/mmltex/cmarkup.xsl +1106 -0
- epub_translator/data/mmltex/entities.xsl +459 -0
- epub_translator/data/mmltex/glayout.xsl +222 -0
- epub_translator/data/mmltex/mmltex.xsl +36 -0
- epub_translator/data/mmltex/scripts.xsl +375 -0
- epub_translator/data/mmltex/tables.xsl +130 -0
- epub_translator/data/mmltex/tokens.xsl +328 -0
- epub_translator/data/translate.jinja +15 -12
- epub_translator/epub/__init__.py +4 -2
- epub_translator/epub/common.py +43 -0
- epub_translator/epub/math.py +193 -0
- epub_translator/epub/placeholder.py +53 -0
- epub_translator/epub/spines.py +42 -0
- epub_translator/epub/toc.py +505 -0
- epub_translator/epub/zip.py +67 -0
- epub_translator/iter_sync.py +24 -0
- epub_translator/language.py +23 -0
- epub_translator/llm/__init__.py +2 -1
- epub_translator/llm/core.py +233 -0
- epub_translator/llm/error.py +38 -35
- epub_translator/llm/executor.py +159 -136
- epub_translator/llm/increasable.py +28 -28
- epub_translator/llm/types.py +17 -0
- epub_translator/serial/__init__.py +2 -0
- epub_translator/serial/chunk.py +52 -0
- epub_translator/serial/segment.py +17 -0
- epub_translator/serial/splitter.py +50 -0
- epub_translator/template.py +35 -33
- epub_translator/translator.py +208 -178
- epub_translator/utils.py +7 -0
- epub_translator/xml/__init__.py +4 -3
- epub_translator/xml/deduplication.py +38 -0
- epub_translator/xml/firendly/__init__.py +2 -0
- epub_translator/xml/firendly/decoder.py +75 -0
- epub_translator/xml/firendly/encoder.py +84 -0
- epub_translator/xml/firendly/parser.py +177 -0
- epub_translator/xml/firendly/tag.py +118 -0
- epub_translator/xml/firendly/transform.py +36 -0
- epub_translator/xml/xml.py +52 -0
- epub_translator/xml/xml_like.py +231 -0
- epub_translator/xml_translator/__init__.py +3 -0
- epub_translator/xml_translator/const.py +2 -0
- epub_translator/xml_translator/fill.py +128 -0
- epub_translator/xml_translator/format.py +282 -0
- epub_translator/xml_translator/fragmented.py +125 -0
- epub_translator/xml_translator/group.py +183 -0
- epub_translator/xml_translator/progressive_locking.py +256 -0
- epub_translator/xml_translator/submitter.py +102 -0
- epub_translator/xml_translator/text_segment.py +263 -0
- epub_translator/xml_translator/translator.py +179 -0
- epub_translator/xml_translator/utils.py +29 -0
- epub_translator-0.1.1.dist-info/METADATA +283 -0
- epub_translator-0.1.1.dist-info/RECORD +58 -0
- epub_translator/data/format.jinja +0 -33
- epub_translator/epub/content_parser.py +0 -162
- epub_translator/epub/html/__init__.py +0 -1
- epub_translator/epub/html/dom_operator.py +0 -68
- epub_translator/epub/html/empty_tags.py +0 -23
- epub_translator/epub/html/file.py +0 -80
- epub_translator/epub/html/texts_searcher.py +0 -46
- epub_translator/llm/node.py +0 -201
- epub_translator/translation/__init__.py +0 -2
- epub_translator/translation/chunk.py +0 -118
- epub_translator/translation/splitter.py +0 -78
- epub_translator/translation/store.py +0 -36
- epub_translator/translation/translation.py +0 -231
- epub_translator/translation/types.py +0 -45
- epub_translator/translation/utils.py +0 -11
- epub_translator/xml/decoder.py +0 -71
- epub_translator/xml/encoder.py +0 -95
- epub_translator/xml/parser.py +0 -172
- epub_translator/xml/tag.py +0 -93
- epub_translator/xml/transform.py +0 -34
- epub_translator/xml/utils.py +0 -12
- epub_translator/zip_context.py +0 -74
- epub_translator-0.0.7.dist-info/METADATA +0 -170
- epub_translator-0.0.7.dist-info/RECORD +0 -36
- {epub_translator-0.0.7.dist-info → epub_translator-0.1.1.dist-info}/LICENSE +0 -0
- {epub_translator-0.0.7.dist-info → epub_translator-0.1.1.dist-info}/WHEEL +0 -0
epub_translator/llm/core.py
ADDED
@@ -0,0 +1,233 @@
+import datetime
+import hashlib
+import json
+import uuid
+from collections.abc import Callable, Generator
+from importlib.resources import files
+from logging import DEBUG, FileHandler, Formatter, Logger, getLogger
+from os import PathLike
+from pathlib import Path
+from typing import Self
+
+from jinja2 import Environment, Template
+from tiktoken import Encoding, get_encoding
+
+from ..template import create_env
+from .executor import LLMExecutor
+from .increasable import Increasable
+from .types import Message, MessageRole, R
+
+
+class LLMContext:
+    """Context manager for LLM requests with transactional caching."""
+
+    def __init__(
+        self,
+        executor: LLMExecutor,
+        cache_path: Path | None,
+    ) -> None:
+        self._executor = executor
+        self._cache_path = cache_path
+        self._context_id = uuid.uuid4().hex[:12]
+        self._temp_files: list[Path] = []
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        if exc_type is None:
+            # Success: commit all temporary cache files
+            self._commit()
+        else:
+            # Failure: rollback (delete) all temporary cache files
+            self._rollback()
+
+    def request(
+        self,
+        input: str | list[Message],
+        parser: Callable[[str], R] = lambda x: x,
+        max_tokens: int | None = None,
+    ) -> R:
+        messages: list[Message]
+        if isinstance(input, str):
+            messages = [Message(role=MessageRole.USER, message=input)]
+        else:
+            messages = input
+
+        cache_key: str | None = None
+        if self._cache_path is not None:
+            cache_key = self._compute_messages_hash(messages)
+            permanent_cache_file = self._cache_path / f"{cache_key}.txt"
+            if permanent_cache_file.exists():
+                cached_content = permanent_cache_file.read_text(encoding="utf-8")
+                return parser(cached_content)
+
+            temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
+            if temp_cache_file.exists():
+                cached_content = temp_cache_file.read_text(encoding="utf-8")
+                return parser(cached_content)
+
+        # Make the actual request
+        response = self._executor.request(
+            messages=messages,
+            parser=lambda x: x,
+            max_tokens=max_tokens,
+        )
+
+        # Save to temporary cache if cache_path is set
+        if self._cache_path is not None and cache_key is not None:
+            temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
+            temp_cache_file.write_text(response, encoding="utf-8")
+            self._temp_files.append(temp_cache_file)
+
+        return parser(response)
+
+    def _compute_messages_hash(self, messages: list[Message]) -> str:
+        messages_dict = [{"role": msg.role.value, "message": msg.message} for msg in messages]
+        messages_json = json.dumps(messages_dict, ensure_ascii=False, sort_keys=True)
+        return hashlib.sha512(messages_json.encode("utf-8")).hexdigest()
+
+    def _commit(self) -> None:
+        for temp_file in self._temp_files:
+            if temp_file.exists():
+                # Remove the .[context-id].txt suffix to get permanent name
+                permanent_name = temp_file.name.rsplit(".", 2)[0] + ".txt"
+                permanent_file = temp_file.parent / permanent_name
+                temp_file.rename(permanent_file)
+
+    def _rollback(self) -> None:
+        for temp_file in self._temp_files:
+            if temp_file.exists():
+                temp_file.unlink()
+
+
+class LLM:
+    def __init__(
+        self,
+        key: str,
+        url: str,
+        model: str,
+        token_encoding: str,
+        cache_path: PathLike | None = None,
+        timeout: float | None = None,
+        top_p: float | tuple[float, float] | None = None,
+        temperature: float | tuple[float, float] | None = None,
+        retry_times: int = 5,
+        retry_interval_seconds: float = 6.0,
+        log_dir_path: PathLike | None = None,
+    ) -> None:
+        prompts_path = Path(str(files("epub_translator"))) / "data"
+        self._templates: dict[str, Template] = {}
+        self._encoding: Encoding = get_encoding(token_encoding)
+        self._env: Environment = create_env(prompts_path)
+        self._logger_save_path: Path | None = None
+        self._cache_path: Path | None = None
+
+        if cache_path is not None:
+            self._cache_path = Path(cache_path)
+            if not self._cache_path.exists():
+                self._cache_path.mkdir(parents=True, exist_ok=True)
+            elif not self._cache_path.is_dir():
+                self._cache_path = None
+
+        if log_dir_path is not None:
+            self._logger_save_path = Path(log_dir_path)
+            if not self._logger_save_path.exists():
+                self._logger_save_path.mkdir(parents=True, exist_ok=True)
+            elif not self._logger_save_path.is_dir():
+                self._logger_save_path = None
+
+        self._executor = LLMExecutor(
+            url=url,
+            model=model,
+            api_key=key,
+            timeout=timeout,
+            top_p=Increasable(top_p),
+            temperature=Increasable(temperature),
+            retry_times=retry_times,
+            retry_interval_seconds=retry_interval_seconds,
+            create_logger=self._create_logger,
+        )
+
+    @property
+    def encoding(self) -> Encoding:
+        return self._encoding
+
+    def context(self) -> LLMContext:
+        return LLMContext(
+            executor=self._executor,
+            cache_path=self._cache_path,
+        )
+
+    def request(
+        self,
+        input: str | list[Message],
+        parser: Callable[[str], R] = lambda x: x,
+        max_tokens: int | None = None,
+    ) -> R:
+        with self.context() as ctx:
+            return ctx.request(input=input, parser=parser, max_tokens=max_tokens)
+
+    def template(self, template_name: str) -> Template:
+        template = self._templates.get(template_name, None)
+        if template is None:
+            template = self._env.get_template(template_name)
+            self._templates[template_name] = template
+        return template
+
+    def _create_logger(self) -> Logger | None:
+        if self._logger_save_path is None:
+            return None
+
+        now = datetime.datetime.now(datetime.UTC)
+        timestamp = now.strftime("%Y-%m-%d %H-%M-%S %f")
+        file_path = self._logger_save_path / f"request {timestamp}.log"
+        logger = getLogger(f"LLM Request {timestamp}")
+        logger.setLevel(DEBUG)
+        handler = FileHandler(file_path, encoding="utf-8")
+        handler.setLevel(DEBUG)
+        handler.setFormatter(Formatter("%(asctime)s %(message)s", "%H:%M:%S"))
+        logger.addHandler(handler)
+
+        return logger
+
+    def _search_quotes(self, kind: str, response: str) -> Generator[str, None, None]:
+        start_marker = f"```{kind}"
+        end_marker = "```"
+        start_index = 0
+
+        while True:
+            start_index = self._find_ignore_case(
+                raw=response,
+                sub=start_marker,
+                start=start_index,
+            )
+            if start_index == -1:
+                break
+
+            end_index = self._find_ignore_case(
+                raw=response,
+                sub=end_marker,
+                start=start_index + len(start_marker),
+            )
+            if end_index == -1:
+                break
+
+            extracted_text = response[start_index + len(start_marker) : end_index].strip()
+            yield extracted_text
+            start_index = end_index + len(end_marker)
+
+    def _find_ignore_case(self, raw: str, sub: str, start: int = 0):
+        if not sub:
+            return 0 if 0 >= start else -1
+
+        raw_len, sub_len = len(raw), len(sub)
+        for i in range(start, raw_len - sub_len + 1):
+            match = True
+            for j in range(sub_len):
+                if raw[i + j].lower() != sub[j].lower():
+                    match = False
+                    break
+            if match:
+                return i
+        return -1
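
A minimal usage sketch of the transactional cache above (not part of the package; the key, endpoint, and model strings are placeholders, and LLM is imported directly from the module file shown here):

    from pathlib import Path

    from epub_translator.llm.core import LLM

    llm = LLM(
        key="sk-...",                      # placeholder API key
        url="https://api.example.com/v1",  # placeholder OpenAI-compatible endpoint
        model="gpt-4o-mini",               # placeholder model name
        token_encoding="o200k_base",       # any tiktoken encoding name
        cache_path=Path("cache"),
    )

    # Each request inside the context writes "<sha512>.<context-id>.txt";
    # a clean exit renames the files to "<sha512>.txt" (commit), while an
    # exception unlinks them (rollback), so the permanent cache never
    # holds responses from a partially completed run.
    with llm.context() as ctx:
        first = ctx.request("Translate to French: Hello, world.")
        second = ctx.request("Translate to French: Goodbye.")

Because the cache key is a SHA-512 over the sorted role/message JSON, re-running the same context replays committed responses without touching the network.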
epub_translator/llm/error.py
CHANGED
@@ -1,49 +1,52 @@
-import openai
 import httpx
+import openai
 import requests


 def is_retry_error(err: Exception) -> bool:
+    if _is_openai_retry_error(err):
+        return True
+    if _is_httpx_retry_error(err):
+        return True
+    if _is_request_retry_error(err):
+        return True
+    return False


 # https://help.openai.com/en/articles/6897213-openai-library-error-types-guidance
 def _is_openai_retry_error(err: Exception) -> bool:
+    if isinstance(err, openai.Timeout):
+        return True
+    if isinstance(err, openai.APIConnectionError):
+        return True
+    if isinstance(err, openai.InternalServerError):
+        return err.status_code in (502, 503, 504)
+    return False


 # https://www.python-httpx.org/exceptions/
 def _is_httpx_retry_error(err: Exception) -> bool:
+    if isinstance(err, httpx.RemoteProtocolError):
+        return True
+    if isinstance(err, httpx.StreamError):
+        return True
+    if isinstance(err, httpx.TimeoutException):
+        return True
+    if isinstance(err, httpx.NetworkError):
+        return True
+    if isinstance(err, httpx.ProtocolError):
+        return True
+    return False


 # https://requests.readthedocs.io/en/latest/api/#exceptions
 def _is_request_retry_error(err: Exception) -> bool:
+    if isinstance(err, requests.ConnectionError):
+        return True
+    if isinstance(err, requests.ConnectTimeout):
+        return True
+    if isinstance(err, requests.ReadTimeout):
+        return True
+    if isinstance(err, requests.Timeout):
+        return True
+    return False
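
For intuition, a small sketch of which exceptions the predicate treats as retryable (assumes the package's pinned openai/httpx/requests versions; the exception instances are constructed only to show which branch fires):

    import httpx
    import requests

    from epub_translator.llm.error import is_retry_error

    # Transport-level failures are retryable:
    assert is_retry_error(httpx.ConnectTimeout("connect timed out"))   # httpx.TimeoutException subclass
    assert is_retry_error(requests.ConnectionError("connection refused"))

    # An ordinary programming error is not, so it propagates immediately:
    assert not is_retry_error(ValueError("malformed payload"))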
epub_translator/llm/executor.py
CHANGED
@@ -1,150 +1,173 @@
+from collections.abc import Callable
 from io import StringIO
-from time import sleep
-from pydantic import SecretStr
 from logging import Logger
+from time import sleep
+from typing import cast
+
+from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam

-from .increasable import Increasable, Increaser
 from .error import is_retry_error
+from .increasable import Increasable, Increaser
+from .types import Message, MessageRole, R


 class LLMExecutor:
-        logger.debug(f"[[Request]]:\n{self._input2str(input)}\n")
-
-        try:
-            for i in range(self._retry_times + 1):
-                try:
-                    response = self._invoke_model(
-                        input=input,
-                        top_p=top_p.current,
-                        temperature=temperature.current,
-                        max_tokens=max_tokens,
-                    )
-                    if logger is not None:
-                        logger.debug(f"[[Response]]:\n{response}\n")
+    def __init__(
+        self,
+        api_key: str,
+        url: str,
+        model: str,
+        timeout: float | None,
+        top_p: Increasable,
+        temperature: Increasable,
+        retry_times: int,
+        retry_interval_seconds: float,
+        create_logger: Callable[[], Logger | None],
+    ) -> None:
+        self._model_name: str = model
+        self._timeout: float | None = timeout
+        self._top_p: Increasable = top_p
+        self._temperature: Increasable = temperature
+        self._retry_times: int = retry_times
+        self._retry_interval_seconds: float = retry_interval_seconds
+        self._create_logger: Callable[[], Logger | None] = create_logger
+        self._client = OpenAI(
+            api_key=api_key,
+            base_url=url,
+            timeout=timeout,
+        )
+
+    def request(self, messages: list[Message], parser: Callable[[str], R], max_tokens: int | None) -> R:
+        result: R | None = None
+        last_error: Exception | None = None
+        did_success = False
+        top_p: Increaser = self._top_p.context()
+        temperature: Increaser = self._temperature.context()
+        logger = self._create_logger()
+
+        if logger is not None:
+            logger.debug(f"[[Request]]:\n{self._input2str(messages)}\n")

+        try:
+            for i in range(self._retry_times + 1):
+                try:
+                    response = self._invoke_model(
+                        input_messages=messages,
+                        top_p=top_p.current,
+                        temperature=temperature.current,
+                        max_tokens=max_tokens,
+                    )
+                    if logger is not None:
+                        logger.debug(f"[[Response]]:\n{response}\n")
+
+                except Exception as err:
+                    last_error = err
+                    if not is_retry_error(err):
+                        raise err
+                    if logger is not None:
+                        logger.warning(f"request failed with connection error, retrying... ({i + 1} times)")
+                    if self._retry_interval_seconds > 0.0 and i < self._retry_times:
+                        sleep(self._retry_interval_seconds)
+                    continue
+
+                try:
+                    result = parser(response)
+                    did_success = True
+                    break
+
+                except Exception as err:
+                    last_error = err
+                    warn_message = f"request failed with parsing error, retrying... ({i + 1} times)"
+                    if logger is not None:
+                        logger.warning(warn_message)
+                    print(warn_message)
+                    top_p.increase()
+                    temperature.increase()
+                    if self._retry_interval_seconds > 0.0 and i < self._retry_times:
+                        sleep(self._retry_interval_seconds)
+                    continue
+
+        except KeyboardInterrupt as err:
+            if last_error is not None and logger is not None:
+                logger.debug(f"[[Error]]:\n{last_error}\n")
             raise err
-            if logger is not None:
-                logger.warning(f"request failed with connection error, retrying... ({i + 1} times)")
-            if self._retry_interval_seconds > 0.0 and \
-                    i < self._retry_times:
-                sleep(self._retry_interval_seconds)
-            continue

-        raise ValueError(f"Unsupported input type: {type(input)}")
-
-        buffer = StringIO()
-        is_first = True
-        for message in input:
-            if not is_first:
-                buffer.write("\n\n")
-            if isinstance(message, SystemMessage):
-                buffer.write("System:\n")
-                buffer.write(message.content)
-            elif isinstance(message, HumanMessage):
-                buffer.write("User:\n")
-                buffer.write(message.content)
-            elif isinstance(message, AIMessage):
-                buffer.write("Assistant:\n")
-                buffer.write(message.content)
-            else:
-                buffer.write(str(message))
-            is_first = False
-
-        return buffer.getvalue()
-
-    def _invoke_model(
+        if not did_success:
+            if last_error is None:
+                raise RuntimeError("Request failed with unknown error")
+            else:
+                raise last_error
+
+        return cast(R, result)
+
+    def _input2str(self, input: str | list[Message]) -> str:
+        if isinstance(input, str):
+            return input
+        if not isinstance(input, list):
+            raise ValueError(f"Unsupported input type: {type(input)}")
+
+        buffer = StringIO()
+        is_first = True
+        for message in input:
+            if not is_first:
+                buffer.write("\n\n")
+            if message.role == MessageRole.SYSTEM:
+                buffer.write("System:\n")
+                buffer.write(message.message)
+            elif message.role == MessageRole.USER:
+                buffer.write("User:\n")
+                buffer.write(message.message)
+            elif message.role == MessageRole.ASSISTANT:
+                buffer.write("Assistant:\n")
+                buffer.write(message.message)
+            else:
+                buffer.write(str(message))
+            is_first = False
+
+        return buffer.getvalue()
+
+    def _invoke_model(
         self,
+        input_messages: list[Message],
         top_p: float | None,
         temperature: float | None,
         max_tokens: int | None,
+    ):
+        messages: list[ChatCompletionMessageParam] = []
+        for item in input_messages:
+            if item.role == MessageRole.SYSTEM:
+                messages.append(
+                    {
+                        "role": "system",
+                        "content": item.message,
+                    }
+                )
+            elif item.role == MessageRole.USER:
+                messages.append(
+                    {
+                        "role": "user",
+                        "content": item.message,
+                    }
+                )
+            elif item.role == MessageRole.ASSISTANT:
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": item.message,
+                    }
+                )
+
+        stream = self._client.chat.completions.create(
+            model=self._model_name,
+            messages=messages,
+            stream=True,
+            top_p=top_p,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+        buffer = StringIO()
+        for chunk in stream:
+            if chunk.choices and chunk.choices[0].delta.content:
+                buffer.write(chunk.choices[0].delta.content)
+        return buffer.getvalue()
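
Since the parser runs inside the retry loop, it doubles as a validator: a parser exception nudges top_p/temperature upward via Increaser.increase() and retries, while non-retryable transport errors propagate immediately. A sketch of a caller exploiting that (endpoint, key, and model are placeholders; the (start, max) tuple form mirrors what LLM passes into Increasable):

    import json

    from epub_translator.llm.executor import LLMExecutor
    from epub_translator.llm.increasable import Increasable
    from epub_translator.llm.types import Message, MessageRole

    executor = LLMExecutor(
        api_key="sk-...",                  # placeholder
        url="https://api.example.com/v1",  # placeholder endpoint
        model="gpt-4o-mini",               # placeholder model
        timeout=60.0,
        top_p=Increasable((0.6, 0.95)),
        temperature=Increasable((0.3, 1.0)),
        retry_times=3,
        retry_interval_seconds=2.0,
        create_logger=lambda: None,        # disable per-request log files
    )

    # json.loads raises on malformed output, which hits the "parsing error"
    # branch: sampling parameters are increased and the model is re-asked.
    result = executor.request(
        messages=[Message(role=MessageRole.USER, message="Reply with a JSON object only.")],
        parser=json.loads,
        max_tokens=256,
    )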