epub-translator 0.0.7-py3-none-any.whl → 0.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
Files changed (82)
  1. epub_translator/__init__.py +4 -2
  2. epub_translator/data/fill.jinja +66 -0
  3. epub_translator/data/mmltex/README.md +67 -0
  4. epub_translator/data/mmltex/cmarkup.xsl +1106 -0
  5. epub_translator/data/mmltex/entities.xsl +459 -0
  6. epub_translator/data/mmltex/glayout.xsl +222 -0
  7. epub_translator/data/mmltex/mmltex.xsl +36 -0
  8. epub_translator/data/mmltex/scripts.xsl +375 -0
  9. epub_translator/data/mmltex/tables.xsl +130 -0
  10. epub_translator/data/mmltex/tokens.xsl +328 -0
  11. epub_translator/data/translate.jinja +15 -12
  12. epub_translator/epub/__init__.py +4 -2
  13. epub_translator/epub/common.py +43 -0
  14. epub_translator/epub/math.py +193 -0
  15. epub_translator/epub/placeholder.py +53 -0
  16. epub_translator/epub/spines.py +42 -0
  17. epub_translator/epub/toc.py +505 -0
  18. epub_translator/epub/zip.py +67 -0
  19. epub_translator/iter_sync.py +24 -0
  20. epub_translator/language.py +23 -0
  21. epub_translator/llm/__init__.py +2 -1
  22. epub_translator/llm/core.py +233 -0
  23. epub_translator/llm/error.py +38 -35
  24. epub_translator/llm/executor.py +159 -136
  25. epub_translator/llm/increasable.py +28 -28
  26. epub_translator/llm/types.py +17 -0
  27. epub_translator/serial/__init__.py +2 -0
  28. epub_translator/serial/chunk.py +52 -0
  29. epub_translator/serial/segment.py +17 -0
  30. epub_translator/serial/splitter.py +50 -0
  31. epub_translator/template.py +35 -33
  32. epub_translator/translator.py +208 -178
  33. epub_translator/utils.py +7 -0
  34. epub_translator/xml/__init__.py +4 -3
  35. epub_translator/xml/deduplication.py +38 -0
  36. epub_translator/xml/firendly/__init__.py +2 -0
  37. epub_translator/xml/firendly/decoder.py +75 -0
  38. epub_translator/xml/firendly/encoder.py +84 -0
  39. epub_translator/xml/firendly/parser.py +177 -0
  40. epub_translator/xml/firendly/tag.py +118 -0
  41. epub_translator/xml/firendly/transform.py +36 -0
  42. epub_translator/xml/xml.py +52 -0
  43. epub_translator/xml/xml_like.py +231 -0
  44. epub_translator/xml_translator/__init__.py +3 -0
  45. epub_translator/xml_translator/const.py +2 -0
  46. epub_translator/xml_translator/fill.py +128 -0
  47. epub_translator/xml_translator/format.py +282 -0
  48. epub_translator/xml_translator/fragmented.py +125 -0
  49. epub_translator/xml_translator/group.py +183 -0
  50. epub_translator/xml_translator/progressive_locking.py +256 -0
  51. epub_translator/xml_translator/submitter.py +102 -0
  52. epub_translator/xml_translator/text_segment.py +263 -0
  53. epub_translator/xml_translator/translator.py +179 -0
  54. epub_translator/xml_translator/utils.py +29 -0
  55. epub_translator-0.1.1.dist-info/METADATA +283 -0
  56. epub_translator-0.1.1.dist-info/RECORD +58 -0
  57. epub_translator/data/format.jinja +0 -33
  58. epub_translator/epub/content_parser.py +0 -162
  59. epub_translator/epub/html/__init__.py +0 -1
  60. epub_translator/epub/html/dom_operator.py +0 -68
  61. epub_translator/epub/html/empty_tags.py +0 -23
  62. epub_translator/epub/html/file.py +0 -80
  63. epub_translator/epub/html/texts_searcher.py +0 -46
  64. epub_translator/llm/node.py +0 -201
  65. epub_translator/translation/__init__.py +0 -2
  66. epub_translator/translation/chunk.py +0 -118
  67. epub_translator/translation/splitter.py +0 -78
  68. epub_translator/translation/store.py +0 -36
  69. epub_translator/translation/translation.py +0 -231
  70. epub_translator/translation/types.py +0 -45
  71. epub_translator/translation/utils.py +0 -11
  72. epub_translator/xml/decoder.py +0 -71
  73. epub_translator/xml/encoder.py +0 -95
  74. epub_translator/xml/parser.py +0 -172
  75. epub_translator/xml/tag.py +0 -93
  76. epub_translator/xml/transform.py +0 -34
  77. epub_translator/xml/utils.py +0 -12
  78. epub_translator/zip_context.py +0 -74
  79. epub_translator-0.0.7.dist-info/METADATA +0 -170
  80. epub_translator-0.0.7.dist-info/RECORD +0 -36
  81. {epub_translator-0.0.7.dist-info → epub_translator-0.1.1.dist-info}/LICENSE +0 -0
  82. {epub_translator-0.0.7.dist-info → epub_translator-0.1.1.dist-info}/WHEEL +0 -0
epub_translator/llm/core.py
@@ -0,0 +1,233 @@
+import datetime
+import hashlib
+import json
+import uuid
+from collections.abc import Callable, Generator
+from importlib.resources import files
+from logging import DEBUG, FileHandler, Formatter, Logger, getLogger
+from os import PathLike
+from pathlib import Path
+from typing import Self
+
+from jinja2 import Environment, Template
+from tiktoken import Encoding, get_encoding
+
+from ..template import create_env
+from .executor import LLMExecutor
+from .increasable import Increasable
+from .types import Message, MessageRole, R
+
+
+class LLMContext:
+    """Context manager for LLM requests with transactional caching."""
+
+    def __init__(
+        self,
+        executor: LLMExecutor,
+        cache_path: Path | None,
+    ) -> None:
+        self._executor = executor
+        self._cache_path = cache_path
+        self._context_id = uuid.uuid4().hex[:12]
+        self._temp_files: list[Path] = []
+
+    def __enter__(self) -> Self:
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
+        if exc_type is None:
+            # Success: commit all temporary cache files
+            self._commit()
+        else:
+            # Failure: rollback (delete) all temporary cache files
+            self._rollback()
+
+    def request(
+        self,
+        input: str | list[Message],
+        parser: Callable[[str], R] = lambda x: x,
+        max_tokens: int | None = None,
+    ) -> R:
+        messages: list[Message]
+        if isinstance(input, str):
+            messages = [Message(role=MessageRole.USER, message=input)]
+        else:
+            messages = input
+
+        cache_key: str | None = None
+        if self._cache_path is not None:
+            cache_key = self._compute_messages_hash(messages)
+            permanent_cache_file = self._cache_path / f"{cache_key}.txt"
+            if permanent_cache_file.exists():
+                cached_content = permanent_cache_file.read_text(encoding="utf-8")
+                return parser(cached_content)
+
+            temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
+            if temp_cache_file.exists():
+                cached_content = temp_cache_file.read_text(encoding="utf-8")
+                return parser(cached_content)
+
+        # Make the actual request
+        response = self._executor.request(
+            messages=messages,
+            parser=lambda x: x,
+            max_tokens=max_tokens,
+        )
+
+        # Save to temporary cache if cache_path is set
+        if self._cache_path is not None and cache_key is not None:
+            temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
+            temp_cache_file.write_text(response, encoding="utf-8")
+            self._temp_files.append(temp_cache_file)
+
+        return parser(response)
+
+    def _compute_messages_hash(self, messages: list[Message]) -> str:
+        messages_dict = [{"role": msg.role.value, "message": msg.message} for msg in messages]
+        messages_json = json.dumps(messages_dict, ensure_ascii=False, sort_keys=True)
+        return hashlib.sha512(messages_json.encode("utf-8")).hexdigest()
+
+    def _commit(self) -> None:
+        for temp_file in self._temp_files:
+            if temp_file.exists():
+                # Remove the .[context-id].txt suffix to get permanent name
+                permanent_name = temp_file.name.rsplit(".", 2)[0] + ".txt"
+                permanent_file = temp_file.parent / permanent_name
+                temp_file.rename(permanent_file)
+
+    def _rollback(self) -> None:
+        for temp_file in self._temp_files:
+            if temp_file.exists():
+                temp_file.unlink()
+
+
+class LLM:
+    def __init__(
+        self,
+        key: str,
+        url: str,
+        model: str,
+        token_encoding: str,
+        cache_path: PathLike | None = None,
+        timeout: float | None = None,
+        top_p: float | tuple[float, float] | None = None,
+        temperature: float | tuple[float, float] | None = None,
+        retry_times: int = 5,
+        retry_interval_seconds: float = 6.0,
+        log_dir_path: PathLike | None = None,
+    ) -> None:
+        prompts_path = Path(str(files("epub_translator"))) / "data"
+        self._templates: dict[str, Template] = {}
+        self._encoding: Encoding = get_encoding(token_encoding)
+        self._env: Environment = create_env(prompts_path)
+        self._logger_save_path: Path | None = None
+        self._cache_path: Path | None = None
+
+        if cache_path is not None:
+            self._cache_path = Path(cache_path)
+            if not self._cache_path.exists():
+                self._cache_path.mkdir(parents=True, exist_ok=True)
+            elif not self._cache_path.is_dir():
+                self._cache_path = None
+
+        if log_dir_path is not None:
+            self._logger_save_path = Path(log_dir_path)
+            if not self._logger_save_path.exists():
+                self._logger_save_path.mkdir(parents=True, exist_ok=True)
+            elif not self._logger_save_path.is_dir():
+                self._logger_save_path = None
+
+        self._executor = LLMExecutor(
+            url=url,
+            model=model,
+            api_key=key,
+            timeout=timeout,
+            top_p=Increasable(top_p),
+            temperature=Increasable(temperature),
+            retry_times=retry_times,
+            retry_interval_seconds=retry_interval_seconds,
+            create_logger=self._create_logger,
+        )
+
+    @property
+    def encoding(self) -> Encoding:
+        return self._encoding
+
+    def context(self) -> LLMContext:
+        return LLMContext(
+            executor=self._executor,
+            cache_path=self._cache_path,
+        )
+
+    def request(
+        self,
+        input: str | list[Message],
+        parser: Callable[[str], R] = lambda x: x,
+        max_tokens: int | None = None,
+    ) -> R:
+        with self.context() as ctx:
+            return ctx.request(input=input, parser=parser, max_tokens=max_tokens)
+
+    def template(self, template_name: str) -> Template:
+        template = self._templates.get(template_name, None)
+        if template is None:
+            template = self._env.get_template(template_name)
+            self._templates[template_name] = template
+        return template
+
+    def _create_logger(self) -> Logger | None:
+        if self._logger_save_path is None:
+            return None
+
+        now = datetime.datetime.now(datetime.UTC)
+        timestamp = now.strftime("%Y-%m-%d %H-%M-%S %f")
+        file_path = self._logger_save_path / f"request {timestamp}.log"
+        logger = getLogger(f"LLM Request {timestamp}")
+        logger.setLevel(DEBUG)
+        handler = FileHandler(file_path, encoding="utf-8")
+        handler.setLevel(DEBUG)
+        handler.setFormatter(Formatter("%(asctime)s %(message)s", "%H:%M:%S"))
+        logger.addHandler(handler)
+
+        return logger
+
+    def _search_quotes(self, kind: str, response: str) -> Generator[str, None, None]:
+        start_marker = f"```{kind}"
+        end_marker = "```"
+        start_index = 0
+
+        while True:
+            start_index = self._find_ignore_case(
+                raw=response,
+                sub=start_marker,
+                start=start_index,
+            )
+            if start_index == -1:
+                break
+
+            end_index = self._find_ignore_case(
+                raw=response,
+                sub=end_marker,
+                start=start_index + len(start_marker),
+            )
+            if end_index == -1:
+                break
+
+            extracted_text = response[start_index + len(start_marker) : end_index].strip()
+            yield extracted_text
+            start_index = end_index + len(end_marker)
+
+    def _find_ignore_case(self, raw: str, sub: str, start: int = 0):
+        if not sub:
+            return 0 if 0 >= start else -1
+
+        raw_len, sub_len = len(raw), len(sub)
+        for i in range(start, raw_len - sub_len + 1):
+            match = True
+            for j in range(sub_len):
+                if raw[i + j].lower() != sub[j].lower():
+                    match = False
+                    break
+            if match:
+                return i
+        return -1
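
The new LLMContext gives response caching two-phase semantics: each response is first written to a temporary <sha512-of-messages>.<context-id>.txt file, and only a clean context exit renames it to a permanent <sha512>.txt (commit); any exception unlinks the temporary files instead (rollback), so an aborted run never pollutes the permanent cache. A minimal usage sketch, assuming llm/__init__.py re-exports LLM (all constructor values below are placeholders, not defaults from this package):

    from epub_translator.llm import LLM  # assumed re-export

    llm = LLM(
        key="sk-placeholder",
        url="https://api.example.com/v1",
        model="placeholder-model",
        token_encoding="cl100k_base",
        cache_path="./llm-cache",
    )

    # Both responses become permanent cache entries only if the whole
    # block exits cleanly; if either request raises, _rollback() deletes
    # the temporary files and the next run starts fresh.
    with llm.context() as ctx:
        title = ctx.request("Translate this title.")
        body = ctx.request("Translate this body.")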
epub_translator/llm/error.py
@@ -1,49 +1,52 @@
-import openai
 import httpx
+import openai
 import requests


 def is_retry_error(err: Exception) -> bool:
-  if _is_openai_retry_error(err):
-    return True
-  if _is_httpx_retry_error(err):
-    return True
-  if _is_request_retry_error(err):
-    return True
-  return False
+    if _is_openai_retry_error(err):
+        return True
+    if _is_httpx_retry_error(err):
+        return True
+    if _is_request_retry_error(err):
+        return True
+    return False
+

 # https://help.openai.com/en/articles/6897213-openai-library-error-types-guidance
 def _is_openai_retry_error(err: Exception) -> bool:
-  if isinstance(err, openai.Timeout):
-    return True
-  if isinstance(err, openai.APIConnectionError):
-    return True
-  if isinstance(err, openai.InternalServerError):
-    return err.status_code in (502, 503, 504)
-  return False
+    if isinstance(err, openai.Timeout):
+        return True
+    if isinstance(err, openai.APIConnectionError):
+        return True
+    if isinstance(err, openai.InternalServerError):
+        return err.status_code in (502, 503, 504)
+    return False
+

 # https://www.python-httpx.org/exceptions/
 def _is_httpx_retry_error(err: Exception) -> bool:
-  if isinstance(err, httpx.RemoteProtocolError):
-    return True
-  if isinstance(err, httpx.StreamError):
-    return True
-  if isinstance(err, httpx.TimeoutException):
-    return True
-  if isinstance(err, httpx.NetworkError):
-    return True
-  if isinstance(err, httpx.ProtocolError):
-    return True
-  return False
+    if isinstance(err, httpx.RemoteProtocolError):
+        return True
+    if isinstance(err, httpx.StreamError):
+        return True
+    if isinstance(err, httpx.TimeoutException):
+        return True
+    if isinstance(err, httpx.NetworkError):
+        return True
+    if isinstance(err, httpx.ProtocolError):
+        return True
+    return False
+

 # https://requests.readthedocs.io/en/latest/api/#exceptions
 def _is_request_retry_error(err: Exception) -> bool:
-  if isinstance(err, requests.ConnectionError):
-    return True
-  if isinstance(err, requests.ConnectTimeout):
-    return True
-  if isinstance(err, requests.ReadTimeout):
-    return True
-  if isinstance(err, requests.Timeout):
-    return True
-  return False
+    if isinstance(err, requests.ConnectionError):
+        return True
+    if isinstance(err, requests.ConnectTimeout):
+        return True
+    if isinstance(err, requests.ReadTimeout):
+        return True
+    if isinstance(err, requests.Timeout):
+        return True
+    return False
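
The error.py change appears cosmetic (import reordering plus reindentation); the predicate still partitions exceptions into retryable transport failures and fatal everything-else, exactly as the executor below consumes it. A generic sketch of that consumption pattern, not code from this package (do_request and is_retryable are stand-ins):

    import time

    def with_retries(do_request, is_retryable, attempts=5, interval=6.0):
        last_error = None
        for _ in range(attempts + 1):
            try:
                return do_request()
            except Exception as err:
                if not is_retryable(err):
                    raise  # e.g. an HTTP 400: retrying cannot help
                last_error = err
                time.sleep(interval)
        raise last_error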
epub_translator/llm/executor.py
@@ -1,150 +1,173 @@
-from typing import cast, Any, Callable
+from collections.abc import Callable
 from io import StringIO
-from time import sleep
-from pydantic import SecretStr
 from logging import Logger
-from langchain_core.messages import SystemMessage, HumanMessage, AIMessage
-from langchain_core.language_models import LanguageModelInput
-from langchain_openai import ChatOpenAI
+from time import sleep
+from typing import cast
+
+from openai import OpenAI
+from openai.types.chat import ChatCompletionMessageParam

-from .increasable import Increasable, Increaser
 from .error import is_retry_error
+from .increasable import Increasable, Increaser
+from .types import Message, MessageRole, R


 class LLMExecutor:
-  def __init__(
-    self,
-    api_key: SecretStr,
-    url: str,
-    model: str,
-    timeout: float | None,
-    top_p: Increasable,
-    temperature: Increasable,
-    retry_times: int,
-    retry_interval_seconds: float,
-    create_logger: Callable[[], Logger | None],
-  ) -> None:
-
-    self._timeout: float | None = timeout
-    self._top_p: Increasable = top_p
-    self._temperature: Increasable = temperature
-    self._retry_times: int = retry_times
-    self._retry_interval_seconds: float = retry_interval_seconds
-    self._create_logger: Callable[[], Logger | None] = create_logger
-    self._model = ChatOpenAI(
-      api_key=cast(SecretStr, api_key),
-      base_url=url,
-      model=model,
-      timeout=timeout,
-    )
-
-  def request(self, input: LanguageModelInput, parser: Callable[[str], Any], max_tokens: int | None) -> Any:
-    result: Any | None = None
-    last_error: Exception | None = None
-    did_success = False
-    top_p: Increaser = self._top_p.context()
-    temperature: Increaser = self._temperature.context()
-    logger = self._create_logger()
-
-    if logger is not None:
-      logger.debug(f"[[Request]]:\n{self._input2str(input)}\n")
-
-    try:
-      for i in range(self._retry_times + 1):
-        try:
-          response = self._invoke_model(
-            input=input,
-            top_p=top_p.current,
-            temperature=temperature.current,
-            max_tokens=max_tokens,
-          )
-          if logger is not None:
-            logger.debug(f"[[Response]]:\n{response}\n")
+    def __init__(
+        self,
+        api_key: str,
+        url: str,
+        model: str,
+        timeout: float | None,
+        top_p: Increasable,
+        temperature: Increasable,
+        retry_times: int,
+        retry_interval_seconds: float,
+        create_logger: Callable[[], Logger | None],
+    ) -> None:
+        self._model_name: str = model
+        self._timeout: float | None = timeout
+        self._top_p: Increasable = top_p
+        self._temperature: Increasable = temperature
+        self._retry_times: int = retry_times
+        self._retry_interval_seconds: float = retry_interval_seconds
+        self._create_logger: Callable[[], Logger | None] = create_logger
+        self._client = OpenAI(
+            api_key=api_key,
+            base_url=url,
+            timeout=timeout,
+        )
+
+    def request(self, messages: list[Message], parser: Callable[[str], R], max_tokens: int | None) -> R:
+        result: R | None = None
+        last_error: Exception | None = None
+        did_success = False
+        top_p: Increaser = self._top_p.context()
+        temperature: Increaser = self._temperature.context()
+        logger = self._create_logger()
+
+        if logger is not None:
+            logger.debug(f"[[Request]]:\n{self._input2str(messages)}\n")

-    except Exception as err:
-      last_error = err
-      if not is_retry_error(err):
+        try:
+            for i in range(self._retry_times + 1):
+                try:
+                    response = self._invoke_model(
+                        input_messages=messages,
+                        top_p=top_p.current,
+                        temperature=temperature.current,
+                        max_tokens=max_tokens,
+                    )
+                    if logger is not None:
+                        logger.debug(f"[[Response]]:\n{response}\n")
+
+                except Exception as err:
+                    last_error = err
+                    if not is_retry_error(err):
+                        raise err
+                    if logger is not None:
+                        logger.warning(f"request failed with connection error, retrying... ({i + 1} times)")
+                    if self._retry_interval_seconds > 0.0 and i < self._retry_times:
+                        sleep(self._retry_interval_seconds)
+                    continue
+
+                try:
+                    result = parser(response)
+                    did_success = True
+                    break
+
+                except Exception as err:
+                    last_error = err
+                    warn_message = f"request failed with parsing error, retrying... ({i + 1} times)"
+                    if logger is not None:
+                        logger.warning(warn_message)
+                    print(warn_message)
+                    top_p.increase()
+                    temperature.increase()
+                    if self._retry_interval_seconds > 0.0 and i < self._retry_times:
+                        sleep(self._retry_interval_seconds)
+                    continue
+
+        except KeyboardInterrupt as err:
+            if last_error is not None and logger is not None:
+                logger.debug(f"[[Error]]:\n{last_error}\n")
             raise err
-          if logger is not None:
-            logger.warning(f"request failed with connection error, retrying... ({i + 1} times)")
-          if self._retry_interval_seconds > 0.0 and \
-             i < self._retry_times:
-            sleep(self._retry_interval_seconds)
-          continue

-        try:
-          result = parser(response)
-          did_success = True
-          break
-
-        except Exception as err:
-          last_error = err
-          warn_message = f"request failed with parsing error, retrying... ({i + 1} times)"
-          if logger is not None:
-            logger.warning(warn_message)
-          print(warn_message)
-          top_p.increase()
-          temperature.increase()
-          if self._retry_interval_seconds > 0.0 and \
-             i < self._retry_times:
-            sleep(self._retry_interval_seconds)
-          continue
-
-    except KeyboardInterrupt as err:
-      if last_error is not None and logger is not None:
-        logger.debug(f"[[Error]]:\n{last_error}\n")
-      raise err
-
-    if not did_success:
-      if last_error is None:
-        raise RuntimeError("Request failed with unknown error")
-      else:
-        raise last_error
-
-    return result
-
-  def _input2str(self, input: LanguageModelInput) -> str:
-    if isinstance(input, str):
-      return input
-    if not isinstance(input, list):
-      raise ValueError(f"Unsupported input type: {type(input)}")
-
-    buffer = StringIO()
-    is_first = True
-    for message in input:
-      if not is_first:
-        buffer.write("\n\n")
-      if isinstance(message, SystemMessage):
-        buffer.write("System:\n")
-        buffer.write(message.content)
-      elif isinstance(message, HumanMessage):
-        buffer.write("User:\n")
-        buffer.write(message.content)
-      elif isinstance(message, AIMessage):
-        buffer.write("Assistant:\n")
-        buffer.write(message.content)
-      else:
-        buffer.write(str(message))
-      is_first = False
-
-    return buffer.getvalue()
-
-  def _invoke_model(
+        if not did_success:
+            if last_error is None:
+                raise RuntimeError("Request failed with unknown error")
+            else:
+                raise last_error
+
+        return cast(R, result)
+
+    def _input2str(self, input: str | list[Message]) -> str:
+        if isinstance(input, str):
+            return input
+        if not isinstance(input, list):
+            raise ValueError(f"Unsupported input type: {type(input)}")
+
+        buffer = StringIO()
+        is_first = True
+        for message in input:
+            if not is_first:
+                buffer.write("\n\n")
+            if message.role == MessageRole.SYSTEM:
+                buffer.write("System:\n")
+                buffer.write(message.message)
+            elif message.role == MessageRole.USER:
+                buffer.write("User:\n")
+                buffer.write(message.message)
+            elif message.role == MessageRole.ASSISTANT:
+                buffer.write("Assistant:\n")
+                buffer.write(message.message)
+            else:
+                buffer.write(str(message))
+            is_first = False
+
+        return buffer.getvalue()
+
+    def _invoke_model(
         self,
-        input: LanguageModelInput,
+        input_messages: list[Message],
         top_p: float | None,
         temperature: float | None,
         max_tokens: int | None,
-  ):
-    stream = self._model.stream(
-      input=input,
-      timeout=self._timeout,
-      top_p=top_p,
-      temperature=temperature,
-      max_tokens=max_tokens,
-    )
-    buffer = StringIO()
-    for chunk in stream:
-      data = str(chunk.content)
-      buffer.write(data)
-    return buffer.getvalue()
+    ):
+        messages: list[ChatCompletionMessageParam] = []
+        for item in input_messages:
+            if item.role == MessageRole.SYSTEM:
+                messages.append(
+                    {
+                        "role": "system",
+                        "content": item.message,
+                    }
+                )
+            elif item.role == MessageRole.USER:
+                messages.append(
+                    {
+                        "role": "user",
+                        "content": item.message,
+                    }
+                )
+            elif item.role == MessageRole.ASSISTANT:
+                messages.append(
+                    {
+                        "role": "assistant",
+                        "content": item.message,
+                    }
+                )
+
+        stream = self._client.chat.completions.create(
+            model=self._model_name,
+            messages=messages,
+            stream=True,
+            top_p=top_p,
+            temperature=temperature,
+            max_tokens=max_tokens,
+        )
+        buffer = StringIO()
+        for chunk in stream:
+            if chunk.choices and chunk.choices[0].delta.content:
+                buffer.write(chunk.choices[0].delta.content)
+        return buffer.getvalue()
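
executor.py drops the langchain stack (ChatOpenAI, SecretStr, LanguageModelInput) for the bare openai client plus the package's own Message/MessageRole types, and the streamed chunks now guard against a None delta.content before concatenation. A construction sketch under the new signature; every value is a placeholder, and reading the Increasable tuple as a (start, max) range follows the float | tuple[float, float] annotation in core.py above:

    from epub_translator.llm.executor import LLMExecutor
    from epub_translator.llm.increasable import Increasable
    from epub_translator.llm.types import Message, MessageRole

    executor = LLMExecutor(
        api_key="sk-placeholder",
        url="https://api.example.com/v1",
        model="placeholder-model",
        timeout=60.0,
        top_p=Increasable((0.6, 1.0)),
        temperature=Increasable((0.7, 1.0)),
        retry_times=5,
        retry_interval_seconds=6.0,
        create_logger=lambda: None,  # logging is optional by design
    )

    # Transport errors are retried as-is; parse failures additionally
    # bump top_p and temperature (Increaser.increase) before retrying.
    reply = executor.request(
        messages=[Message(role=MessageRole.USER, message="ping")],
        parser=lambda x: x,
        max_tokens=64,
    )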