epub-translator 0.0.6__py3-none-any.whl → 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- epub_translator/__init__.py +3 -1
- epub_translator/data/fill.jinja +66 -0
- epub_translator/data/mmltex/README.md +67 -0
- epub_translator/data/mmltex/cmarkup.xsl +1106 -0
- epub_translator/data/mmltex/entities.xsl +459 -0
- epub_translator/data/mmltex/glayout.xsl +222 -0
- epub_translator/data/mmltex/mmltex.xsl +36 -0
- epub_translator/data/mmltex/scripts.xsl +375 -0
- epub_translator/data/mmltex/tables.xsl +130 -0
- epub_translator/data/mmltex/tokens.xsl +328 -0
- epub_translator/data/translate.jinja +15 -12
- epub_translator/epub/__init__.py +4 -2
- epub_translator/epub/common.py +43 -0
- epub_translator/epub/math.py +193 -0
- epub_translator/epub/placeholder.py +53 -0
- epub_translator/epub/spines.py +42 -0
- epub_translator/epub/toc.py +505 -0
- epub_translator/epub/zip.py +67 -0
- epub_translator/iter_sync.py +24 -0
- epub_translator/language.py +23 -0
- epub_translator/llm/__init__.py +2 -1
- epub_translator/llm/core.py +175 -0
- epub_translator/llm/error.py +38 -35
- epub_translator/llm/executor.py +159 -136
- epub_translator/llm/increasable.py +28 -28
- epub_translator/llm/types.py +17 -0
- epub_translator/serial/__init__.py +2 -0
- epub_translator/serial/chunk.py +52 -0
- epub_translator/serial/segment.py +17 -0
- epub_translator/serial/splitter.py +50 -0
- epub_translator/template.py +35 -33
- epub_translator/translator.py +205 -168
- epub_translator/utils.py +7 -0
- epub_translator/xml/__init__.py +4 -3
- epub_translator/xml/deduplication.py +38 -0
- epub_translator/xml/firendly/__init__.py +2 -0
- epub_translator/xml/firendly/decoder.py +75 -0
- epub_translator/xml/firendly/encoder.py +84 -0
- epub_translator/xml/firendly/parser.py +177 -0
- epub_translator/xml/firendly/tag.py +118 -0
- epub_translator/xml/firendly/transform.py +36 -0
- epub_translator/xml/xml.py +52 -0
- epub_translator/xml/xml_like.py +176 -0
- epub_translator/xml_translator/__init__.py +3 -0
- epub_translator/xml_translator/const.py +2 -0
- epub_translator/xml_translator/fill.py +128 -0
- epub_translator/xml_translator/format.py +282 -0
- epub_translator/xml_translator/fragmented.py +125 -0
- epub_translator/xml_translator/group.py +183 -0
- epub_translator/xml_translator/progressive_locking.py +256 -0
- epub_translator/xml_translator/submitter.py +102 -0
- epub_translator/xml_translator/text_segment.py +263 -0
- epub_translator/xml_translator/translator.py +178 -0
- epub_translator/xml_translator/utils.py +29 -0
- epub_translator-0.1.0.dist-info/METADATA +283 -0
- epub_translator-0.1.0.dist-info/RECORD +58 -0
- epub_translator/data/format.jinja +0 -33
- epub_translator/epub/content_parser.py +0 -162
- epub_translator/epub/html/__init__.py +0 -1
- epub_translator/epub/html/dom_operator.py +0 -62
- epub_translator/epub/html/empty_tags.py +0 -23
- epub_translator/epub/html/file.py +0 -80
- epub_translator/epub/html/texts_searcher.py +0 -46
- epub_translator/llm/node.py +0 -201
- epub_translator/translation/__init__.py +0 -2
- epub_translator/translation/chunk.py +0 -118
- epub_translator/translation/splitter.py +0 -78
- epub_translator/translation/store.py +0 -36
- epub_translator/translation/translation.py +0 -231
- epub_translator/translation/types.py +0 -45
- epub_translator/translation/utils.py +0 -11
- epub_translator/xml/decoder.py +0 -71
- epub_translator/xml/encoder.py +0 -95
- epub_translator/xml/parser.py +0 -172
- epub_translator/xml/tag.py +0 -93
- epub_translator/xml/transform.py +0 -34
- epub_translator/xml/utils.py +0 -12
- epub_translator/zip_context.py +0 -74
- epub_translator-0.0.6.dist-info/METADATA +0 -170
- epub_translator-0.0.6.dist-info/RECORD +0 -36
- {epub_translator-0.0.6.dist-info → epub_translator-0.1.0.dist-info}/LICENSE +0 -0
- {epub_translator-0.0.6.dist-info → epub_translator-0.1.0.dist-info}/WHEEL +0 -0
epub_translator/llm/core.py
ADDED
@@ -0,0 +1,175 @@
+import datetime
+import hashlib
+import json
+from collections.abc import Callable, Generator
+from importlib.resources import files
+from logging import DEBUG, FileHandler, Formatter, Logger, getLogger
+from os import PathLike
+from pathlib import Path
+
+from jinja2 import Environment, Template
+from tiktoken import Encoding, get_encoding
+
+from ..template import create_env
+from .executor import LLMExecutor
+from .increasable import Increasable
+from .types import Message, MessageRole, R
+
+
+class LLM:
+  def __init__(
+    self,
+    key: str,
+    url: str,
+    model: str,
+    token_encoding: str,
+    cache_path: PathLike | None = None,
+    timeout: float | None = None,
+    top_p: float | tuple[float, float] | None = None,
+    temperature: float | tuple[float, float] | None = None,
+    retry_times: int = 5,
+    retry_interval_seconds: float = 6.0,
+    log_dir_path: PathLike | None = None,
+  ):
+    prompts_path = Path(str(files("epub_translator"))) / "data"
+    self._templates: dict[str, Template] = {}
+    self._encoding: Encoding = get_encoding(token_encoding)
+    self._env: Environment = create_env(prompts_path)
+    self._logger_save_path: Path | None = None
+    self._cache_path: Path | None = None
+
+    if cache_path is not None:
+      self._cache_path = Path(cache_path)
+      if not self._cache_path.exists():
+        self._cache_path.mkdir(parents=True, exist_ok=True)
+      elif not self._cache_path.is_dir():
+        self._cache_path = None
+
+    if log_dir_path is not None:
+      self._logger_save_path = Path(log_dir_path)
+      if not self._logger_save_path.exists():
+        self._logger_save_path.mkdir(parents=True, exist_ok=True)
+      elif not self._logger_save_path.is_dir():
+        self._logger_save_path = None
+
+    self._executor = LLMExecutor(
+      url=url,
+      model=model,
+      api_key=key,
+      timeout=timeout,
+      top_p=Increasable(top_p),
+      temperature=Increasable(temperature),
+      retry_times=retry_times,
+      retry_interval_seconds=retry_interval_seconds,
+      create_logger=self._create_logger,
+    )
+
+  @property
+  def encoding(self) -> Encoding:
+    return self._encoding
+
+  def request(
+    self,
+    input: str | list[Message],
+    parser: Callable[[str], R] = lambda x: x,
+    max_tokens: int | None = None,
+  ) -> R:
+    messages: list[Message]
+    if isinstance(input, str):
+      messages = [Message(role=MessageRole.USER, message=input)]
+    else:
+      messages = input
+
+    # Check cache if cache_path is set
+    if self._cache_path is not None:
+      cache_key = self._compute_messages_hash(messages)
+      cache_file = self._cache_path / f"{cache_key}.txt"
+
+      if cache_file.exists():
+        cached_content = cache_file.read_text(encoding="utf-8")
+        return parser(cached_content)
+
+    # Make the actual request
+    response = self._executor.request(
+      messages=messages,
+      parser=lambda x: x,
+      max_tokens=max_tokens,
+    )
+
+    # Save to cache if cache_path is set
+    if self._cache_path is not None:
+      cache_key = self._compute_messages_hash(messages)
+      cache_file = self._cache_path / f"{cache_key}.txt"
+      cache_file.write_text(response, encoding="utf-8")
+
+    return parser(response)
+
+  def template(self, template_name: str) -> Template:
+    template = self._templates.get(template_name, None)
+    if template is None:
+      template = self._env.get_template(template_name)
+      self._templates[template_name] = template
+    return template
+
+  def _compute_messages_hash(self, messages: list[Message]) -> str:
+    """Compute SHA-512 hash of messages for cache key."""
+    messages_dict = [{"role": msg.role.value, "message": msg.message} for msg in messages]
+    messages_json = json.dumps(messages_dict, ensure_ascii=False, sort_keys=True)
+    return hashlib.sha512(messages_json.encode("utf-8")).hexdigest()
+
+  def _create_logger(self) -> Logger | None:
+    if self._logger_save_path is None:
+      return None
+
+    now = datetime.datetime.now(datetime.timezone.utc)
+    timestamp = now.strftime("%Y-%m-%d %H-%M-%S %f")
+    file_path = self._logger_save_path / f"request {timestamp}.log"
+    logger = getLogger(f"LLM Request {timestamp}")
+    logger.setLevel(DEBUG)
+    handler = FileHandler(file_path, encoding="utf-8")
+    handler.setLevel(DEBUG)
+    handler.setFormatter(Formatter("%(asctime)s %(message)s", "%H:%M:%S"))
+    logger.addHandler(handler)
+
+    return logger
+
+  def _search_quotes(self, kind: str, response: str) -> Generator[str, None, None]:
+    start_marker = f"```{kind}"
+    end_marker = "```"
+    start_index = 0
+
+    while True:
+      start_index = self._find_ignore_case(
+        raw=response,
+        sub=start_marker,
+        start=start_index,
+      )
+      if start_index == -1:
+        break
+
+      end_index = self._find_ignore_case(
+        raw=response,
+        sub=end_marker,
+        start=start_index + len(start_marker),
+      )
+      if end_index == -1:
+        break
+
+      extracted_text = response[start_index + len(start_marker) : end_index].strip()
+      yield extracted_text
+      start_index = end_index + len(end_marker)
+
+  def _find_ignore_case(self, raw: str, sub: str, start: int = 0):
+    if not sub:
+      return 0 if 0 >= start else -1
+
+    raw_len, sub_len = len(raw), len(sub)
+    for i in range(start, raw_len - sub_len + 1):
+      match = True
+      for j in range(sub_len):
+        if raw[i + j].lower() != sub[j].lower():
+          match = False
+          break
+      if match:
+        return i
+    return -1
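
For orientation, a minimal usage sketch of the new LLM facade, assuming epub_translator.llm re-exports LLM (its __init__.py also changed in this release); the key, URL, model name, and cache directory are placeholders, not defaults shipped by the package:

from epub_translator.llm import LLM

llm = LLM(
  key="sk-...",                       # placeholder API key
  url="https://api.openai.com/v1",    # placeholder endpoint
  model="gpt-4o-mini",                # placeholder model name
  token_encoding="cl100k_base",
  cache_path="./cache",               # enables response caching
  temperature=(0.3, 0.9),             # escalates toward 0.9 on retries
)

# A plain string becomes a single USER message; an identical request is
# later served from ./cache/<sha512-of-messages>.txt instead of the API.
text = llm.request("Translate to French: Hello, world.")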
epub_translator/llm/error.py
CHANGED
@@ -1,49 +1,52 @@
-import openai
 import httpx
+import openai
 import requests


 def is_retry_error(err: Exception) -> bool:
-  …
+  if _is_openai_retry_error(err):
+    return True
+  if _is_httpx_retry_error(err):
+    return True
+  if _is_request_retry_error(err):
+    return True
+  return False


 # https://help.openai.com/en/articles/6897213-openai-library-error-types-guidance
 def _is_openai_retry_error(err: Exception) -> bool:
-  …
+  if isinstance(err, openai.Timeout):
+    return True
+  if isinstance(err, openai.APIConnectionError):
+    return True
+  if isinstance(err, openai.InternalServerError):
+    return err.status_code in (502, 503, 504)
+  return False


 # https://www.python-httpx.org/exceptions/
 def _is_httpx_retry_error(err: Exception) -> bool:
-  …
+  if isinstance(err, httpx.RemoteProtocolError):
+    return True
+  if isinstance(err, httpx.StreamError):
+    return True
+  if isinstance(err, httpx.TimeoutException):
+    return True
+  if isinstance(err, httpx.NetworkError):
+    return True
+  if isinstance(err, httpx.ProtocolError):
+    return True
+  return False


 # https://requests.readthedocs.io/en/latest/api/#exceptions
 def _is_request_retry_error(err: Exception) -> bool:
-  …
+  if isinstance(err, requests.ConnectionError):
+    return True
+  if isinstance(err, requests.ConnectTimeout):
+    return True
+  if isinstance(err, requests.ReadTimeout):
+    return True
+  if isinstance(err, requests.Timeout):
+    return True
+  return False
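
This predicate feeds the executor's retry loop: transient transport errors are retried, everything else propagates. A condensed sketch of that pattern (with_retries is a hypothetical helper; retry count and sleep interval are illustrative):

from time import sleep
from epub_translator.llm.error import is_retry_error

def with_retries(call, retry_times=5, interval=6.0):
  # Retry only errors classified as transient; re-raise everything else.
  for i in range(retry_times + 1):
    try:
      return call()
    except Exception as err:
      if not is_retry_error(err) or i == retry_times:
        raise
      sleep(interval)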
epub_translator/llm/executor.py
CHANGED
@@ -1,150 +1,173 @@
-from …
+from collections.abc import Callable
 from io import StringIO
-from time import sleep
-from pydantic import SecretStr
 from logging import Logger
-from …
-from …
+from time import sleep
+from typing import cast
+
+from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam

-from .increasable import Increasable, Increaser
 from .error import is_retry_error
+from .increasable import Increasable, Increaser
+from .types import Message, MessageRole, R


 class LLMExecutor:
-  …
-    logger.debug(f"[[Request]]:\n{self._input2str(input)}\n")
-
-    try:
-      for i in range(self._retry_times + 1):
-        try:
-          response = self._invoke_model(
-            input=input,
-            top_p=top_p.current,
-            temperature=temperature.current,
-            max_tokens=max_tokens,
-          )
-          if logger is not None:
-            logger.debug(f"[[Response]]:\n{response}\n")
+  def __init__(
+    self,
+    api_key: str,
+    url: str,
+    model: str,
+    timeout: float | None,
+    top_p: Increasable,
+    temperature: Increasable,
+    retry_times: int,
+    retry_interval_seconds: float,
+    create_logger: Callable[[], Logger | None],
+  ) -> None:
+    self._model_name: str = model
+    self._timeout: float | None = timeout
+    self._top_p: Increasable = top_p
+    self._temperature: Increasable = temperature
+    self._retry_times: int = retry_times
+    self._retry_interval_seconds: float = retry_interval_seconds
+    self._create_logger: Callable[[], Logger | None] = create_logger
+    self._client = OpenAI(
+      api_key=api_key,
+      base_url=url,
+      timeout=timeout,
+    )
+
+  def request(self, messages: list[Message], parser: Callable[[str], R], max_tokens: int | None) -> R:
+    result: R | None = None
+    last_error: Exception | None = None
+    did_success = False
+    top_p: Increaser = self._top_p.context()
+    temperature: Increaser = self._temperature.context()
+    logger = self._create_logger()
+
+    if logger is not None:
+      logger.debug(f"[[Request]]:\n{self._input2str(messages)}\n")

+    try:
+      for i in range(self._retry_times + 1):
+        try:
+          response = self._invoke_model(
+            input_messages=messages,
+            top_p=top_p.current,
+            temperature=temperature.current,
+            max_tokens=max_tokens,
+          )
+          if logger is not None:
+            logger.debug(f"[[Response]]:\n{response}\n")
+
+        except Exception as err:
+          last_error = err
+          if not is_retry_error(err):
+            raise err
+          if logger is not None:
+            logger.warning(f"request failed with connection error, retrying... ({i + 1} times)")
+          if self._retry_interval_seconds > 0.0 and i < self._retry_times:
+            sleep(self._retry_interval_seconds)
+          continue
+
+        try:
+          result = parser(response)
+          did_success = True
+          break
+
+        except Exception as err:
+          last_error = err
+          warn_message = f"request failed with parsing error, retrying... ({i + 1} times)"
+          if logger is not None:
+            logger.warning(warn_message)
+          print(warn_message)
+          top_p.increase()
+          temperature.increase()
+          if self._retry_interval_seconds > 0.0 and i < self._retry_times:
+            sleep(self._retry_interval_seconds)
+          continue
+
+    except KeyboardInterrupt as err:
+      if last_error is not None and logger is not None:
+        logger.debug(f"[[Error]]:\n{last_error}\n")
       raise err
-      if logger is not None:
-        logger.warning(f"request failed with connection error, retrying... ({i + 1} times)")
-      if self._retry_interval_seconds > 0.0 and \
-          i < self._retry_times:
-        sleep(self._retry_interval_seconds)
-      continue

-  …
-    raise ValueError(f"Unsupported input type: {type(input)}")
-
-    buffer = StringIO()
-    is_first = True
-    for message in input:
-      if not is_first:
-        buffer.write("\n\n")
-      if isinstance(message, SystemMessage):
-        buffer.write("System:\n")
-        buffer.write(message.content)
-      elif isinstance(message, HumanMessage):
-        buffer.write("User:\n")
-        buffer.write(message.content)
-      elif isinstance(message, AIMessage):
-        buffer.write("Assistant:\n")
-        buffer.write(message.content)
-      else:
-        buffer.write(str(message))
-      is_first = False
-
-    return buffer.getvalue()
-
-  def _invoke_model(
+    if not did_success:
+      if last_error is None:
+        raise RuntimeError("Request failed with unknown error")
+      else:
+        raise last_error
+
+    return cast(R, result)
+
+  def _input2str(self, input: str | list[Message]) -> str:
+    if isinstance(input, str):
+      return input
+    if not isinstance(input, list):
+      raise ValueError(f"Unsupported input type: {type(input)}")
+
+    buffer = StringIO()
+    is_first = True
+    for message in input:
+      if not is_first:
+        buffer.write("\n\n")
+      if message.role == MessageRole.SYSTEM:
+        buffer.write("System:\n")
+        buffer.write(message.message)
+      elif message.role == MessageRole.USER:
+        buffer.write("User:\n")
+        buffer.write(message.message)
+      elif message.role == MessageRole.ASSISTANT:
+        buffer.write("Assistant:\n")
+        buffer.write(message.message)
+      else:
+        buffer.write(str(message))
+      is_first = False
+
+    return buffer.getvalue()
+
+  def _invoke_model(
     self,
-    …
+    input_messages: list[Message],
     top_p: float | None,
     temperature: float | None,
     max_tokens: int | None,
-  …
+  ):
+    messages: list[ChatCompletionMessageParam] = []
+    for item in input_messages:
+      if item.role == MessageRole.SYSTEM:
+        messages.append(
+          {
+            "role": "system",
+            "content": item.message,
+          }
+        )
+      elif item.role == MessageRole.USER:
+        messages.append(
+          {
+            "role": "user",
+            "content": item.message,
+          }
+        )
+      elif item.role == MessageRole.ASSISTANT:
+        messages.append(
+          {
+            "role": "assistant",
+            "content": item.message,
+          }
+        )
+
+    stream = self._client.chat.completions.create(
+      model=self._model_name,
+      messages=messages,
+      stream=True,
+      top_p=top_p,
+      temperature=temperature,
+      max_tokens=max_tokens,
+    )
+    buffer = StringIO()
+    for chunk in stream:
+      if chunk.choices and chunk.choices[0].delta.content:
+        buffer.write(chunk.choices[0].delta.content)
+    return buffer.getvalue()
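
Note the two-stage retry in request: transport errors (per is_retry_error) are retried at the same sampling settings, while parser errors retry after nudging top_p and temperature upward. Since LLM.request hands the executor an identity parser and applies its own parser after caching, the escalation is driven by parsers passed to LLMExecutor.request directly. A sketch of exploiting that with a strict parser (strict_json is hypothetical, not part of the package):

import json
from epub_translator.llm.types import Message, MessageRole

def strict_json(text: str) -> dict:
  # Raising here makes LLMExecutor.request retry with top_p and
  # temperature moved toward their upper bounds.
  parsed = json.loads(text)
  if not isinstance(parsed, dict):
    raise ValueError("expected a JSON object")
  return parsed

# executor: an LLMExecutor as constructed inside LLM.__init__ above
# data = executor.request(
#   messages=[Message(role=MessageRole.USER, message="Reply with a JSON object.")],
#   parser=strict_json,
#   max_tokens=None,
# )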
epub_translator/llm/increasable.py
CHANGED
@@ -1,35 +1,35 @@
 class Increaser:
-  …
+  def __init__(self, value_range: tuple[float, float] | None):
+    self._value_range: tuple[float, float] | None = value_range
+    self._current: float | None = value_range[0] if value_range is not None else None

-  …
+  @property
+  def current(self) -> float | None:
+    return self._current
+
+  def increase(self):
+    if self._value_range is not None and self._current is not None:
+      _, end_value = self._value_range
+      self._current = self._current + 0.5 * (end_value - self._current)

-  def increase(self):
-    if self._value_range is None:
-      return
-    _, end_value = self._value_range
-    self._current = self._current + 0.5 * (end_value - self._current)

 class Increasable:
-  …
+  def __init__(self, param: float | tuple[float, float] | None):
+    self._value_range: tuple[float, float] | None = None

-  …
+    if isinstance(param, int):
+      param = float(param)
+    if isinstance(param, float):
+      param = (param, param)
+    if isinstance(param, tuple):
+      if len(param) != 2:
+        raise ValueError(f"Expected a tuple of length 2, got {len(param)}")
+      begin, end = param
+      if isinstance(begin, int):
+        begin = float(begin)
+      if isinstance(end, int):
+        end = float(end)
+      self._value_range = (begin, end)

-  …
+  def context(self) -> Increaser:
+    return Increaser(self._value_range)
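
Increaser halves the remaining distance to the upper bound on each call, so the value approaches the bound without ever exceeding it, and each context() starts a fresh, independent escalation. A quick demonstration of that sequence (values are approximate due to floating point):

from epub_translator.llm.increasable import Increasable

inc = Increasable((0.3, 0.9)).context()
print(inc.current)  # 0.3
inc.increase()
print(inc.current)  # ~0.6   (0.3 + 0.5 * (0.9 - 0.3))
inc.increase()
print(inc.current)  # ~0.75
inc.increase()
print(inc.current)  # ~0.825, converging toward 0.9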
epub_translator/llm/types.py
ADDED
@@ -0,0 +1,17 @@
+from dataclasses import dataclass
+from enum import Enum, auto
+from typing import TypeVar
+
+R = TypeVar("R")
+
+
+@dataclass
+class Message:
+  role: "MessageRole"
+  message: str
+
+
+class MessageRole(Enum):
+  SYSTEM = auto()
+  USER = auto()
+  ASSISTANT = auto()
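
These types take over from the LangChain-style message classes (SystemMessage/HumanMessage/AIMessage) visible in the removed executor code. Building a multi-turn request against the LLM facade sketched earlier:

from epub_translator.llm.types import Message, MessageRole

messages = [
  Message(role=MessageRole.SYSTEM, message="You are a translator."),
  Message(role=MessageRole.USER, message="Translate to German: Good morning."),
]
# reply = llm.request(messages)  # llm as constructed in the earlier sketch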