epub-translator 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. epub_translator/__init__.py +9 -2
  2. epub_translator/data/fill.jinja +143 -38
  3. epub_translator/epub/__init__.py +1 -1
  4. epub_translator/epub/metadata.py +122 -0
  5. epub_translator/epub/spines.py +3 -2
  6. epub_translator/epub/zip.py +11 -9
  7. epub_translator/epub_transcode.py +108 -0
  8. epub_translator/llm/__init__.py +1 -0
  9. epub_translator/llm/context.py +109 -0
  10. epub_translator/llm/core.py +32 -113
  11. epub_translator/llm/executor.py +25 -31
  12. epub_translator/llm/increasable.py +1 -1
  13. epub_translator/llm/types.py +0 -3
  14. epub_translator/punctuation.py +34 -0
  15. epub_translator/segment/__init__.py +26 -0
  16. epub_translator/segment/block_segment.py +124 -0
  17. epub_translator/segment/common.py +29 -0
  18. epub_translator/segment/inline_segment.py +356 -0
  19. epub_translator/{xml_translator → segment}/text_segment.py +7 -72
  20. epub_translator/segment/utils.py +43 -0
  21. epub_translator/translator.py +152 -184
  22. epub_translator/utils.py +33 -0
  23. epub_translator/xml/__init__.py +3 -0
  24. epub_translator/xml/const.py +1 -0
  25. epub_translator/xml/deduplication.py +3 -3
  26. epub_translator/xml/inline.py +67 -0
  27. epub_translator/xml/self_closing.py +182 -0
  28. epub_translator/xml/utils.py +42 -0
  29. epub_translator/xml/xml.py +7 -0
  30. epub_translator/xml/xml_like.py +8 -33
  31. epub_translator/xml_interrupter.py +165 -0
  32. epub_translator/xml_translator/__init__.py +3 -3
  33. epub_translator/xml_translator/callbacks.py +34 -0
  34. epub_translator/xml_translator/{const.py → common.py} +0 -1
  35. epub_translator/xml_translator/hill_climbing.py +104 -0
  36. epub_translator/xml_translator/stream_mapper.py +253 -0
  37. epub_translator/xml_translator/submitter.py +352 -91
  38. epub_translator/xml_translator/translator.py +182 -114
  39. epub_translator/xml_translator/validation.py +458 -0
  40. {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/METADATA +134 -21
  41. epub_translator-0.1.4.dist-info/RECORD +68 -0
  42. epub_translator/epub/placeholder.py +0 -53
  43. epub_translator/iter_sync.py +0 -24
  44. epub_translator/xml_translator/fill.py +0 -128
  45. epub_translator/xml_translator/format.py +0 -282
  46. epub_translator/xml_translator/fragmented.py +0 -125
  47. epub_translator/xml_translator/group.py +0 -183
  48. epub_translator/xml_translator/progressive_locking.py +0 -256
  49. epub_translator/xml_translator/utils.py +0 -29
  50. epub_translator-0.1.1.dist-info/RECORD +0 -58
  51. {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/LICENSE +0 -0
  52. {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/WHEEL +0 -0
@@ -1,104 +1,18 @@
  import datetime
- import hashlib
- import json
- import uuid
- from collections.abc import Callable, Generator
+ from collections.abc import Generator
  from importlib.resources import files
  from logging import DEBUG, FileHandler, Formatter, Logger, getLogger
  from os import PathLike
  from pathlib import Path
- from typing import Self

  from jinja2 import Environment, Template
  from tiktoken import Encoding, get_encoding

  from ..template import create_env
+ from .context import LLMContext
  from .executor import LLMExecutor
  from .increasable import Increasable
- from .types import Message, MessageRole, R
-
-
- class LLMContext:
-     """Context manager for LLM requests with transactional caching."""
-
-     def __init__(
-         self,
-         executor: LLMExecutor,
-         cache_path: Path | None,
-     ) -> None:
-         self._executor = executor
-         self._cache_path = cache_path
-         self._context_id = uuid.uuid4().hex[:12]
-         self._temp_files: list[Path] = []
-
-     def __enter__(self) -> Self:
-         return self
-
-     def __exit__(self, exc_type, exc_val, exc_tb) -> None:
-         if exc_type is None:
-             # Success: commit all temporary cache files
-             self._commit()
-         else:
-             # Failure: rollback (delete) all temporary cache files
-             self._rollback()
-
-     def request(
-         self,
-         input: str | list[Message],
-         parser: Callable[[str], R] = lambda x: x,
-         max_tokens: int | None = None,
-     ) -> R:
-         messages: list[Message]
-         if isinstance(input, str):
-             messages = [Message(role=MessageRole.USER, message=input)]
-         else:
-             messages = input
-
-         cache_key: str | None = None
-         if self._cache_path is not None:
-             cache_key = self._compute_messages_hash(messages)
-             permanent_cache_file = self._cache_path / f"{cache_key}.txt"
-             if permanent_cache_file.exists():
-                 cached_content = permanent_cache_file.read_text(encoding="utf-8")
-                 return parser(cached_content)
-
-             temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
-             if temp_cache_file.exists():
-                 cached_content = temp_cache_file.read_text(encoding="utf-8")
-                 return parser(cached_content)
-
-         # Make the actual request
-         response = self._executor.request(
-             messages=messages,
-             parser=lambda x: x,
-             max_tokens=max_tokens,
-         )
-
-         # Save to temporary cache if cache_path is set
-         if self._cache_path is not None and cache_key is not None:
-             temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
-             temp_cache_file.write_text(response, encoding="utf-8")
-             self._temp_files.append(temp_cache_file)
-
-         return parser(response)
-
-     def _compute_messages_hash(self, messages: list[Message]) -> str:
-         messages_dict = [{"role": msg.role.value, "message": msg.message} for msg in messages]
-         messages_json = json.dumps(messages_dict, ensure_ascii=False, sort_keys=True)
-         return hashlib.sha512(messages_json.encode("utf-8")).hexdigest()
-
-     def _commit(self) -> None:
-         for temp_file in self._temp_files:
-             if temp_file.exists():
-                 # Remove the .[context-id].txt suffix to get permanent name
-                 permanent_name = temp_file.name.rsplit(".", 2)[0] + ".txt"
-                 permanent_file = temp_file.parent / permanent_name
-                 temp_file.rename(permanent_file)
-
-     def _rollback(self) -> None:
-         for temp_file in self._temp_files:
-             if temp_file.exists():
-                 temp_file.unlink()
+ from .types import Message


  class LLM:
@@ -108,42 +22,28 @@ class LLM:
          url: str,
          model: str,
          token_encoding: str,
-         cache_path: PathLike | None = None,
          timeout: float | None = None,
          top_p: float | tuple[float, float] | None = None,
          temperature: float | tuple[float, float] | None = None,
          retry_times: int = 5,
          retry_interval_seconds: float = 6.0,
-         log_dir_path: PathLike | None = None,
+         cache_path: PathLike | str | None = None,
+         log_dir_path: PathLike | str | None = None,
      ) -> None:
          prompts_path = Path(str(files("epub_translator"))) / "data"
          self._templates: dict[str, Template] = {}
          self._encoding: Encoding = get_encoding(token_encoding)
          self._env: Environment = create_env(prompts_path)
-         self._logger_save_path: Path | None = None
-         self._cache_path: Path | None = None
-
-         if cache_path is not None:
-             self._cache_path = Path(cache_path)
-             if not self._cache_path.exists():
-                 self._cache_path.mkdir(parents=True, exist_ok=True)
-             elif not self._cache_path.is_dir():
-                 self._cache_path = None
-
-         if log_dir_path is not None:
-             self._logger_save_path = Path(log_dir_path)
-             if not self._logger_save_path.exists():
-                 self._logger_save_path.mkdir(parents=True, exist_ok=True)
-             elif not self._logger_save_path.is_dir():
-                 self._logger_save_path = None
+         self._top_p: Increasable = Increasable(top_p)
+         self._temperature: Increasable = Increasable(temperature)
+         self._cache_path: Path | None = self._ensure_dir_path(cache_path)
+         self._logger_save_path: Path | None = self._ensure_dir_path(log_dir_path)

          self._executor = LLMExecutor(
              url=url,
              model=model,
              api_key=key,
              timeout=timeout,
-             top_p=Increasable(top_p),
-             temperature=Increasable(temperature),
              retry_times=retry_times,
              retry_interval_seconds=retry_interval_seconds,
              create_logger=self._create_logger,
@@ -153,20 +53,29 @@
      def encoding(self) -> Encoding:
          return self._encoding

-     def context(self) -> LLMContext:
+     def context(self, cache_seed_content: str | None = None) -> LLMContext:
          return LLMContext(
              executor=self._executor,
              cache_path=self._cache_path,
+             cache_seed_content=cache_seed_content,
+             top_p=self._top_p,
+             temperature=self._temperature,
          )

      def request(
          self,
          input: str | list[Message],
-         parser: Callable[[str], R] = lambda x: x,
          max_tokens: int | None = None,
-     ) -> R:
+         temperature: float | None = None,
+         top_p: float | None = None,
+     ) -> str:
          with self.context() as ctx:
-             return ctx.request(input=input, parser=parser, max_tokens=max_tokens)
+             return ctx.request(
+                 input=input,
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+             )

      def template(self, template_name: str) -> Template:
          template = self._templates.get(template_name, None)
@@ -175,6 +84,16 @@
              self._templates[template_name] = template
          return template

+     def _ensure_dir_path(self, path: PathLike | str | None) -> Path | None:
+         if path is None:
+             return None
+         dir_path = Path(path)
+         if not dir_path.exists():
+             dir_path.mkdir(parents=True, exist_ok=True)
+         elif not dir_path.is_dir():
+             return None
+         return dir_path.resolve()
+
      def _create_logger(self) -> Logger | None:
          if self._logger_save_path is None:
              return None
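
Taken together, these hunks move the transactional cache out of this module and into the new epub_translator/llm/context.py, and LLM.request() now returns the raw response string instead of accepting a parser. Below is a minimal usage sketch of the new surface. It assumes epub_translator.llm re-exports LLM (llm/__init__.py gains one line in this release), that the `key` parameter name matches the `api_key=key` wiring above, and that the relocated LLMContext keeps the commit-on-success / rollback-on-failure behaviour of the removed in-module version; credentials, endpoint, and model names are placeholders. The new cache_seed_content argument of context() is not exercised here.

    # Sketch only, under the assumptions stated above.
    from epub_translator.llm import LLM

    llm = LLM(
        key="sk-placeholder",                  # placeholder credentials
        url="https://api.example.com/v1",      # placeholder endpoint
        model="some-chat-model",               # placeholder model name
        token_encoding="cl100k_base",
        cache_path="./llm-cache",              # str paths are now accepted alongside PathLike
        log_dir_path="./llm-logs",
    )

    # Responses made inside the block are written to temporary cache files and,
    # per the removed LLMContext shown above, only committed to permanent cache
    # entries if the block exits without an exception.
    with llm.context() as ctx:
        translated = ctx.request(
            input="Translate this paragraph.",
            max_tokens=512,
            temperature=0.7,
            top_p=0.9,
        )
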
@@ -2,14 +2,12 @@ from collections.abc import Callable
  from io import StringIO
  from logging import Logger
  from time import sleep
- from typing import cast

  from openai import OpenAI
  from openai.types.chat import ChatCompletionMessageParam

  from .error import is_retry_error
- from .increasable import Increasable, Increaser
- from .types import Message, MessageRole, R
+ from .types import Message, MessageRole


  class LLMExecutor:
@@ -19,16 +17,12 @@ class LLMExecutor:
          url: str,
          model: str,
          timeout: float | None,
-         top_p: Increasable,
-         temperature: Increasable,
          retry_times: int,
          retry_interval_seconds: float,
          create_logger: Callable[[], Logger | None],
      ) -> None:
          self._model_name: str = model
          self._timeout: float | None = timeout
-         self._top_p: Increasable = top_p
-         self._temperature: Increasable = temperature
          self._retry_times: int = retry_times
          self._retry_interval_seconds: float = retry_interval_seconds
          self._create_logger: Callable[[], Logger | None] = create_logger
@@ -38,15 +32,29 @@
              timeout=timeout,
          )

-     def request(self, messages: list[Message], parser: Callable[[str], R], max_tokens: int | None) -> R:
-         result: R | None = None
+     def request(
+         self,
+         messages: list[Message],
+         max_tokens: int | None,
+         temperature: float | None,
+         top_p: float | None,
+         cache_key: str | None,
+     ) -> str:
+         response: str = ""
          last_error: Exception | None = None
          did_success = False
-         top_p: Increaser = self._top_p.context()
-         temperature: Increaser = self._temperature.context()
         logger = self._create_logger()

         if logger is not None:
+             parameters: list[str] = [
+                 f"\t\ntemperature={temperature}",
+                 f"\t\ntop_p={top_p}",
+                 f"\t\nmax_tokens={max_tokens}",
+             ]
+             if cache_key is not None:
+                 parameters.append(f"\t\ncache_key={cache_key}")
+
+             logger.debug(f"[[Parameters]]:{''.join(parameters)}\n")
             logger.debug(f"[[Request]]:\n{self._input2str(messages)}\n")

         try:
@@ -54,8 +62,8 @@
                 try:
                     response = self._invoke_model(
                         input_messages=messages,
-                         top_p=top_p.current,
-                         temperature=temperature.current,
+                         temperature=temperature,
+                         top_p=top_p,
                         max_tokens=max_tokens,
                     )
                     if logger is not None:
@@ -71,22 +79,8 @@
                         sleep(self._retry_interval_seconds)
                     continue

-                 try:
-                     result = parser(response)
-                     did_success = True
-                     break
-
-                 except Exception as err:
-                     last_error = err
-                     warn_message = f"request failed with parsing error, retrying... ({i + 1} times)"
-                     if logger is not None:
-                         logger.warning(warn_message)
-                     print(warn_message)
-                     top_p.increase()
-                     temperature.increase()
-                     if self._retry_interval_seconds > 0.0 and i < self._retry_times:
-                         sleep(self._retry_interval_seconds)
-                     continue
+                 did_success = True
+                 break

         except KeyboardInterrupt as err:
             if last_error is not None and logger is not None:
@@ -99,7 +93,7 @@
         else:
             raise last_error

-         return cast(R, result)
+         return response

     def _input2str(self, input: str | list[Message]) -> str:
         if isinstance(input, str):
@@ -133,7 +127,7 @@
          top_p: float | None,
          temperature: float | None,
          max_tokens: int | None,
-     ):
+     ) -> str:
          messages: list[ChatCompletionMessageParam] = []
          for item in input_messages:
              if item.role == MessageRole.SYSTEM:
@@ -21,7 +21,7 @@ class Increasable:
             param = float(param)
         if isinstance(param, float):
             param = (param, param)
-         if isinstance(param, tuple):
+         if isinstance(param, (tuple, list)):
             if len(param) != 2:
                 raise ValueError(f"Expected a tuple of length 2, got {len(param)}")
             begin, end = param
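
This relaxes Increasable to accept a two-element list as well as a tuple for the (begin, end) range. A small sketch of the constructor forms this permits; the .context()/.current/.increase() calls are taken from the pre-0.1.4 executor code above, and the begin-to-end escalation is inferred from the variable names rather than confirmed by this diff.

    from epub_translator.llm.increasable import Increasable

    temperature = Increasable((0.3, 0.9))  # a single float or [0.3, 0.9] is also accepted now
    increaser = temperature.context()      # per the removed executor code: .current / .increase()
    print(increaser.current)               # presumably starts at the begin value, 0.3
    increaser.increase()                   # presumably nudges the value toward the end value, 0.9
    print(increaser.current)
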
@@ -1,8 +1,5 @@
  from dataclasses import dataclass
  from enum import Enum, auto
- from typing import TypeVar
-
- R = TypeVar("R")


  @dataclass
@@ -0,0 +1,34 @@
+ from xml.etree.ElementTree import Element
+
+ from .xml import iter_with_stack
+
+ _QUOTE_MAPPING = {
+     # French guillemets
+     "«": "",
+     "»": "",
+     "‹": "«",
+     "›": "»",
+     # Chinese book-title marks
+     "《": "",
+     "》": "",
+     "〈": "《",
+     "〉": "》",
+ }
+
+
+ def _strip_quotes(text: str):
+     for char in text:
+         mapped = _QUOTE_MAPPING.get(char, None)
+         if mapped is None:
+             yield char
+         elif mapped:
+             yield mapped
+
+
+ def unwrap_french_quotes(element: Element) -> Element:
+     for _, child_element in iter_with_stack(element):
+         if child_element.text:
+             child_element.text = "".join(_strip_quotes(child_element.text))
+         if child_element.tail:
+             child_element.tail = "".join(_strip_quotes(child_element.tail))
+     return element
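
The mapping drops the outer guillemets « » and title marks 《 》 and promotes the nested single forms ‹ › and 〈 〉 to their double counterparts, in both element text and tails. A quick illustration, assuming iter_with_stack walks descendant elements (the module path follows the new epub_translator/punctuation.py file):

    from xml.etree.ElementTree import fromstring

    from epub_translator.punctuation import unwrap_french_quotes

    element = fromstring("<div><p>« Il a dit ‹ bonjour › » et 《红楼梦》中的〈引文〉</p></div>")
    unwrap_french_quotes(element)
    print(element[0].text)
    # -> " Il a dit « bonjour »  et 红楼梦中的《引文》"
    # (outer marks are removed, nested ones promoted; spaces around dropped marks are kept)
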
@@ -0,0 +1,26 @@
+ from .block_segment import (
+     BlockContentError,
+     BlockError,
+     BlockExpectedIDsError,
+     BlockSegment,
+     BlockSubmitter,
+     BlockUnexpectedIDError,
+     BlockWrongTagError,
+ )
+ from .common import FoundInvalidIDError
+ from .inline_segment import (
+     InlineError,
+     InlineExpectedIDsError,
+     InlineLostIDError,
+     InlineSegment,
+     InlineUnexpectedIDError,
+     InlineWrongTagCountError,
+     search_inline_segments,
+ )
+ from .text_segment import (
+     TextPosition,
+     TextSegment,
+     combine_text_segments,
+     incision_between,
+     search_text_segments,
+ )
@@ -0,0 +1,124 @@
+ from collections.abc import Generator
+ from dataclasses import dataclass
+ from typing import cast
+ from xml.etree.ElementTree import Element
+
+ from .common import FoundInvalidIDError, validate_id_in_element
+ from .inline_segment import InlineError, InlineSegment
+ from .text_segment import TextSegment
+ from .utils import IDGenerator, id_in_element
+
+
+ @dataclass
+ class BlockSubmitter:
+     id: int
+     origin_text_segments: list[TextSegment]
+     submitted_element: Element
+
+
+ @dataclass
+ class BlockWrongTagError:
+     block: tuple[int, Element] | None  # (block_id, block_element); None means the root element
+     expected_tag: str
+     instead_tag: str
+
+
+ @dataclass
+ class BlockUnexpectedIDError:
+     id: int
+     element: Element
+
+
+ @dataclass
+ class BlockExpectedIDsError:
+     id2element: dict[int, Element]
+
+
+ @dataclass
+ class BlockContentError:
+     id: int
+     element: Element
+     errors: list[InlineError | FoundInvalidIDError]
+
+
+ BlockError = BlockWrongTagError | BlockUnexpectedIDError | BlockExpectedIDsError | BlockContentError
+
+
+ class BlockSegment:
+     def __init__(self, root_tag: str, inline_segments: list[InlineSegment]) -> None:
+         id_generator = IDGenerator()
+         for inline_segment in inline_segments:
+             inline_segment.id = id_generator.next_id()
+             inline_segment.recreate_ids(id_generator)
+
+         self._root_tag: str = root_tag
+         self._inline_segments: list[InlineSegment] = inline_segments
+         self._id2inline_segment: dict[int, InlineSegment] = dict((cast(int, s.id), s) for s in self._inline_segments)
+
+     def __iter__(self) -> Generator[InlineSegment, None, None]:
+         yield from self._inline_segments
+
+     def create_element(self) -> Element:
+         root_element = Element(self._root_tag)
+         for inline_segment in self._inline_segments:
+             root_element.append(inline_segment.create_element())
+         return root_element
+
+     def validate(self, validated_element: Element) -> Generator[BlockError | FoundInvalidIDError, None, None]:
+         if validated_element.tag != self._root_tag:
+             yield BlockWrongTagError(
+                 block=None,
+                 expected_tag=self._root_tag,
+                 instead_tag=validated_element.tag,
+             )
+
+         remain_expected_elements: dict[int, Element] = dict(
+             (id, inline_segment.parent) for id, inline_segment in self._id2inline_segment.items()
+         )
+         for child_validated_element in validated_element:
+             element_id = validate_id_in_element(child_validated_element)
+             if isinstance(element_id, FoundInvalidIDError):
+                 yield element_id
+             else:
+                 inline_segment = self._id2inline_segment.get(element_id, None)
+                 if inline_segment is None:
+                     yield BlockUnexpectedIDError(
+                         id=element_id,
+                         element=child_validated_element,
+                     )
+                 else:
+                     if inline_segment.parent.tag != child_validated_element.tag:
+                         yield BlockWrongTagError(
+                             block=(cast(int, inline_segment.id), inline_segment.parent),
+                             expected_tag=inline_segment.parent.tag,
+                             instead_tag=child_validated_element.tag,
+                         )
+
+                     remain_expected_elements.pop(element_id, None)
+                     inline_errors = list(inline_segment.validate(child_validated_element))
+
+                     if inline_errors:
+                         yield BlockContentError(
+                             id=element_id,
+                             element=child_validated_element,
+                             errors=inline_errors,
+                         )
+
+         if remain_expected_elements:
+             yield BlockExpectedIDsError(id2element=remain_expected_elements)
+
+     def submit(self, target: Element) -> Generator[BlockSubmitter, None, None]:
+         for child_element in target:
+             element_id = id_in_element(child_element)
+             if element_id is None:
+                 continue
+             inline_segment = self._id2inline_segment.get(element_id, None)
+             if inline_segment is None:
+                 continue
+             inline_segment_id = inline_segment.id
+             assert inline_segment_id is not None
+             yield BlockSubmitter(
+                 id=inline_segment_id,
+                 origin_text_segments=list(inline_segment),
+                 submitted_element=inline_segment.assign_attributes(child_element),
+             )
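
BlockSegment.validate() yields a union of error dataclasses rather than raising, so callers dispatch on the concrete type. Below is a hypothetical reporter illustrating that dispatch; it uses only names re-exported by the new epub_translator/segment/__init__.py shown earlier, and leaves out the construction of a real BlockSegment, since InlineSegment instances normally come from search_inline_segments.

    # Hypothetical helper, not part of the package: formats the BlockError union above.
    from xml.etree.ElementTree import Element

    from epub_translator.segment import (
        BlockContentError,
        BlockExpectedIDsError,
        BlockSegment,
        BlockUnexpectedIDError,
        BlockWrongTagError,
        FoundInvalidIDError,
    )


    def describe_block_errors(segment: BlockSegment, element: Element) -> list[str]:
        messages: list[str] = []
        for error in segment.validate(element):
            if isinstance(error, BlockWrongTagError):
                where = "root" if error.block is None else f"block #{error.block[0]}"
                messages.append(f"{where}: expected <{error.expected_tag}>, got <{error.instead_tag}>")
            elif isinstance(error, BlockUnexpectedIDError):
                messages.append(f"unexpected id {error.id}")
            elif isinstance(error, BlockExpectedIDsError):
                messages.append(f"missing ids: {sorted(error.id2element)}")
            elif isinstance(error, BlockContentError):
                messages.append(f"block #{error.id} has {len(error.errors)} inline error(s)")
            elif isinstance(error, FoundInvalidIDError):
                messages.append(f"invalid id attribute: {error.invalid_id!r}")
        return messages
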
@@ -0,0 +1,29 @@
+ from dataclasses import dataclass
+ from xml.etree.ElementTree import Element
+
+ from ..xml import ID_KEY
+
+
+ @dataclass
+ class FoundInvalidIDError(Exception):
+     invalid_id: str | None
+     element: Element
+
+
+ def validate_id_in_element(element: Element, enable_no_id: bool = False) -> int | FoundInvalidIDError:
+     id_str = element.get(ID_KEY, None)
+     if id_str is None:
+         if enable_no_id:
+             return -1
+         else:
+             return FoundInvalidIDError(
+                 invalid_id=None,
+                 element=element,
+             )
+     try:
+         return int(id_str)
+     except ValueError:
+         return FoundInvalidIDError(
+             invalid_id=id_str,
+             element=element,
+         )
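
validate_id_in_element returns either the parsed integer ID or a FoundInvalidIDError value; it never raises, despite the Exception base class. A small sketch of the three cases, assuming ID_KEY is importable from epub_translator.xml, as the relative import above implies:

    from xml.etree.ElementTree import Element

    from epub_translator.segment.common import FoundInvalidIDError, validate_id_in_element
    from epub_translator.xml import ID_KEY  # assumed re-export; common.py imports it from ..xml

    element = Element("span")
    element.set(ID_KEY, "12")
    assert validate_id_in_element(element) == 12

    element.set(ID_KEY, "twelve")
    result = validate_id_in_element(element)
    assert isinstance(result, FoundInvalidIDError) and result.invalid_id == "twelve"

    missing = Element("span")
    assert validate_id_in_element(missing, enable_no_id=True) == -1
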