epub-translator 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. epub_translator/__init__.py +9 -2
  2. epub_translator/data/fill.jinja +143 -38
  3. epub_translator/epub/__init__.py +1 -1
  4. epub_translator/epub/metadata.py +122 -0
  5. epub_translator/epub/spines.py +3 -2
  6. epub_translator/epub/zip.py +11 -9
  7. epub_translator/epub_transcode.py +108 -0
  8. epub_translator/llm/__init__.py +1 -0
  9. epub_translator/llm/context.py +109 -0
  10. epub_translator/llm/core.py +32 -113
  11. epub_translator/llm/executor.py +25 -31
  12. epub_translator/llm/increasable.py +1 -1
  13. epub_translator/llm/types.py +0 -3
  14. epub_translator/punctuation.py +34 -0
  15. epub_translator/segment/__init__.py +26 -0
  16. epub_translator/segment/block_segment.py +124 -0
  17. epub_translator/segment/common.py +29 -0
  18. epub_translator/segment/inline_segment.py +356 -0
  19. epub_translator/{xml_translator → segment}/text_segment.py +7 -72
  20. epub_translator/segment/utils.py +43 -0
  21. epub_translator/translator.py +152 -184
  22. epub_translator/utils.py +33 -0
  23. epub_translator/xml/__init__.py +3 -0
  24. epub_translator/xml/const.py +1 -0
  25. epub_translator/xml/deduplication.py +3 -3
  26. epub_translator/xml/inline.py +67 -0
  27. epub_translator/xml/self_closing.py +182 -0
  28. epub_translator/xml/utils.py +42 -0
  29. epub_translator/xml/xml.py +7 -0
  30. epub_translator/xml/xml_like.py +8 -33
  31. epub_translator/xml_interrupter.py +165 -0
  32. epub_translator/xml_translator/__init__.py +3 -3
  33. epub_translator/xml_translator/callbacks.py +34 -0
  34. epub_translator/xml_translator/{const.py → common.py} +0 -1
  35. epub_translator/xml_translator/hill_climbing.py +104 -0
  36. epub_translator/xml_translator/stream_mapper.py +253 -0
  37. epub_translator/xml_translator/submitter.py +352 -91
  38. epub_translator/xml_translator/translator.py +182 -114
  39. epub_translator/xml_translator/validation.py +458 -0
  40. {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/METADATA +134 -21
  41. epub_translator-0.1.4.dist-info/RECORD +68 -0
  42. epub_translator/epub/placeholder.py +0 -53
  43. epub_translator/iter_sync.py +0 -24
  44. epub_translator/xml_translator/fill.py +0 -128
  45. epub_translator/xml_translator/format.py +0 -282
  46. epub_translator/xml_translator/fragmented.py +0 -125
  47. epub_translator/xml_translator/group.py +0 -183
  48. epub_translator/xml_translator/progressive_locking.py +0 -256
  49. epub_translator/xml_translator/utils.py +0 -29
  50. epub_translator-0.1.1.dist-info/RECORD +0 -58
  51. {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/LICENSE +0 -0
  52. {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/WHEEL +0 -0
@@ -0,0 +1,182 @@
import re

# Some non-standard EPUB generators use HTML-style tags without self-closing
# syntax; we convert them to XML-compatible format before parsing.
# These are the HTML5 void elements that must be self-closing in XHTML.
_VOID_TAGS = (
    "area",
    "base",
    "br",
    "col",
    "embed",
    "hr",
    "img",
    "input",
    "link",
    "meta",
    "param",
    "source",
    "track",
    "wbr",
)


def self_close_void_elements(xml_content: str) -> str:
    """
    Convert void HTML elements to self-closing format for XML parsing.

    This function handles non-standard HTML where void elements are not self-closed.
    For illegal cases like <meta>content</meta>, the content is removed.

    Args:
        xml_content: HTML/XHTML content string

    Returns:
        Content with void elements in self-closing format

    Example:
        <meta charset="utf-8"> → <meta charset="utf-8" />
        <br> → <br />
        <meta>illegal</meta> → <meta />
    """
    # NOTE(review): void tags inside comments or CDATA sections are rewritten
    # too — acceptable for this deliberately defensive normalizer.
    for tag in _VOID_TAGS:
        xml_content = _fix_void_element(xml_content, tag)
    return xml_content


def _fix_void_element(content: str, tag_name: str) -> str:
    """
    Fix every occurrence of one specific void element in the content.

    Strategy:
    1. Find <tag ...> (not already self-closed)
    2. Check if there's a matching </tag>
    3. If yes, remove everything between them and make it self-closing
    4. If no, just make the opening tag self-closing
    """
    result: list[str] = []
    pos = 0
    prefix = f"<{tag_name}"
    closing_tag = f"</{tag_name}>"

    while pos < len(content):
        tag_start = content.find(prefix, pos)
        if tag_start == -1:
            result.append(content[pos:])
            break

        # Verify it's a complete tag match (not a prefix like <br matching <brain>).
        # The character after tag_name must be >, /, or whitespace.
        check_pos = tag_start + len(prefix)
        if check_pos < len(content) and content[check_pos] not in (">", "/", " ", "\t", "\n", "\r"):
            result.append(content[pos:check_pos])
            pos = check_pos
            continue

        result.append(content[pos:tag_start])
        tag_end = _find_tag_end(content, tag_start)
        if tag_end == -1:
            # Truncated tag at end of input: keep the remainder verbatim.
            result.append(content[tag_start:])
            break

        opening_tag = content[tag_start : tag_end + 1]

        # Already self-closed, or malformed (no terminating '>'): keep as-is.
        if opening_tag.rstrip().endswith("/>") or not opening_tag.endswith(">"):
            result.append(opening_tag)
            pos = tag_end + 1
            continue

        # Emit the self-closed form, preserving any attributes.
        attrs_part = opening_tag[len(prefix) : -1].rstrip()
        if attrs_part:
            result.append(f"{prefix}{attrs_part} />")
        else:
            result.append(f"{prefix} />")

        closing_pos = content.find(closing_tag, tag_end + 1)
        if closing_pos != -1:
            # Illegal <tag>...</tag> pair: drop the inner content as documented.
            pos = closing_pos + len(closing_tag)
        else:
            pos = tag_end + 1

    return "".join(result)


def _find_tag_end(content: str, start_pos: int) -> int:
    """
    Find the end of an HTML tag (the position of >).

    Handles quotes: ignores > inside quoted attribute values.
    Returns -1 when no terminating '>' exists (truncated input).
    """
    pos = start_pos
    in_quote: str | None = None  # None, '"', or "'"

    while pos < len(content):
        char = content[pos]

        if in_quote:
            if char == in_quote:
                # NOTE(review): backslash-escaping is not real HTML attribute
                # semantics, but is kept for behavioral compatibility.
                if pos > 0 and content[pos - 1] == "\\":
                    pos += 1
                    continue
                else:
                    in_quote = None
        else:
            if char in ('"', "'"):
                in_quote = char
            elif char == ">":
                return pos

        pos += 1

    return -1  # Not found


# For saving: match self-closing tags like <br /> or <br/>.
# The \b after the tag-name alternation prevents prefix matches against longer
# tag names (e.g. "base" matching inside <basefont ... /> or "col" inside
# <colgroup ... />), which would wrongly strip self-closing from non-void tags.
_VOID_TAG_CLOSE_PATTERN = re.compile(r"<(" + "|".join(_VOID_TAGS) + r")\b([^>]*?)\s*/>")


def unclose_void_elements(xml_content: str) -> str:
    """
    Convert void elements from self-closing to unclosed format for HTML compatibility.

    Transforms self-closed void elements like <br /> back to <br> for
    compatibility with HTML parsers that don't support XHTML syntax.
    Used only for text/html media type files.

    Args:
        xml_content: HTML/XHTML content string

    Returns:
        Content with void elements in unclosed format

    Example:
        <meta charset="utf-8" /> → <meta charset="utf-8">
        <br /> → <br>
        <img src="test.png" /> → <img src="test.png">
    """

    def replacer(m: re.Match) -> str:
        tag_name = m.group(1)
        attrs = m.group(2).rstrip()  # Remove trailing whitespace before "/>"
        if attrs:
            return f"<{tag_name}{attrs}>"
        else:
            return f"<{tag_name}>"

    # NOTE(review): attributes whose quoted values contain '>' are not matched
    # by the pattern and stay self-closed — a rare, harmless limitation.
    return re.sub(
        pattern=_VOID_TAG_CLOSE_PATTERN,
        repl=replacer,
        string=xml_content,
    )
@@ -0,0 +1,42 @@
1
+ from collections.abc import Generator
2
+ from xml.etree.ElementTree import Element
3
+
4
+ from ..utils import normalize_whitespace
5
+ from .const import ID_KEY
6
+
7
+
def normalize_text_in_element(text: str | None) -> str | None:
    """Collapse whitespace in *text*; return None for missing or blank text."""
    if text is None:
        return None
    normalized = normalize_whitespace(text)
    return normalized if normalized.strip() else None
15
+
16
+
17
+ def append_text_in_element(origin_text: str | None, append_text: str) -> str:
18
+ if origin_text is None:
19
+ return append_text
20
+ else:
21
+ return origin_text + append_text
22
+
23
+
def index_of_parent(parent: Element, checked_element: Element) -> int:
    """
    Return the index of *checked_element* among *parent*'s direct children.

    Uses identity (``is``) rather than equality: ``Element`` only compares by
    identity by default, so ``is`` makes the intent explicit and matches the
    sibling helper ``index_in_parent`` (which returns None instead of raising).

    Raises:
        ValueError: if *checked_element* is not a direct child of *parent*.
    """
    for i, child in enumerate(parent):
        if child is checked_element:
            return i
    raise ValueError("Element not found in parent.")
29
+
30
+
def expand_left_element_texts(element: Element) -> Generator[str, None, None]:
    """Yield the text pieces of a synthetic opening tag for *element*.

    NOTE(review): the id value "99" appears to be a fixed placeholder
    (e.g. for size estimation), not a real id — confirm intended.
    """
    yield from ("<", element.tag, " ", ID_KEY, '="99">')
37
+
38
+
def expand_right_element_texts(element: Element) -> Generator[str, None, None]:
    """Yield the text pieces of a closing tag for *element*."""
    yield from ("</", element.tag, ">")
@@ -12,6 +12,13 @@ def find_first(element: Element, tag: str) -> Element | None:
12
12
  return None
13
13
 
14
14
 
15
+ def index_in_parent(parent: Element, element: Element) -> int | None:
16
+ for i, child in enumerate(parent):
17
+ if child is element:
18
+ return i
19
+ return None
20
+
21
+
15
22
  def iter_with_stack(element: Element) -> Generator[tuple[list[Element], Element], None, None]:
16
23
  """先序遍历:yield parent_path, element"""
17
24
  stack: list[list[Element]] = [[element]]
@@ -4,6 +4,7 @@ import warnings
4
4
  from typing import IO
5
5
  from xml.etree.ElementTree import Element, fromstring, tostring
6
6
 
7
+ from .self_closing import self_close_void_elements, unclose_void_elements
7
8
  from .xml import iter_with_stack
8
9
 
9
10
  _XML_NAMESPACE_URI = "http://www.w3.org/XML/1998/namespace"
@@ -31,28 +32,11 @@ _ENCODING_PATTERN = re.compile(r'encoding\s*=\s*["\']([^"\']+)["\']', re.IGNOREC
31
32
  _FIRST_ELEMENT_PATTERN = re.compile(r"<(?![?!])[a-zA-Z]")
32
33
  _NAMESPACE_IN_TAG = re.compile(r"\{([^}]+)\}")
33
34
 
34
- # Some non-standard EPUB generators use HTML-style tags without self-closing syntax
35
- # We need to convert them to XML-compatible format before parsing
36
- _EMPTY_TAGS = (
37
- "br",
38
- "hr",
39
- "input",
40
- "col",
41
- "base",
42
- "meta",
43
- "area",
44
- )
45
-
46
- # For reading: match tags like <br> or <br class="x"> (but not <br/> or <body>)
47
- _EMPTY_TAG_OPEN_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^/>]*)>")
48
-
49
- # For saving: match self-closing tags like <br />
50
- _EMPTY_TAG_CLOSE_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^>]*?)\s*/>")
51
-
52
35
 
53
36
  class XMLLikeNode:
54
37
  def __init__(self, file: IO[bytes], is_html_like: bool = False) -> None:
55
38
  raw_content = file.read()
39
+ self._is_html_like = is_html_like
56
40
  self._encoding: str = self._detect_encoding(raw_content)
57
41
  content = raw_content.decode(self._encoding)
58
42
  self._header, xml_content = self._extract_header(content)
@@ -60,16 +44,9 @@ class XMLLikeNode:
60
44
  self._tag_to_namespace: dict[str, str] = {}
61
45
  self._attr_to_namespace: dict[str, str] = {}
62
46
 
63
- # For non-standard HTML files, convert <br> to <br/> before parsing
64
- self._is_html_like = is_html_like
65
- if is_html_like:
66
- xml_content = re.sub(
67
- pattern=_EMPTY_TAG_OPEN_PATTERN,
68
- repl=lambda m: f"<{m.group(1)}{m.group(2)} />",
69
- string=xml_content,
70
- )
71
-
72
47
  try:
48
+ # 不必判断类型,这是一个防御性极强的函数,可做到 shit >> XML
49
+ xml_content = self_close_void_elements(xml_content)
73
50
  self.element = self._extract_and_clean_namespaces(
74
51
  element=fromstring(xml_content),
75
52
  )
@@ -92,13 +69,11 @@ class XMLLikeNode:
92
69
 
93
70
  content = self._serialize_with_namespaces(self.element)
94
71
 
95
- # For non-standard HTML files, convert back from <br/> to <br>
72
+ # For non-standard HTML files (text/html), convert back from <br/> to <br>
73
+ # to maintain compatibility with HTML parsers that don't support XHTML
74
+ # For XHTML files (application/xhtml+xml), keep self-closing format
96
75
  if self._is_html_like:
97
- content = re.sub(
98
- pattern=_EMPTY_TAG_CLOSE_PATTERN,
99
- repl=lambda m: f"<{m.group(1)}{m.group(2)}>",
100
- string=content,
101
- )
76
+ content = unclose_void_elements(content)
102
77
 
103
78
  writer.write(content)
104
79
 
@@ -0,0 +1,165 @@
from collections.abc import Generator, Iterable
from typing import cast
from xml.etree.ElementTree import Element

from .segment import TextSegment
from .utils import ensure_list, normalize_whitespace

# Marker attribute used to correlate interrupted elements across passes.
_ID_KEY = "__XML_INTERRUPTER_ID"
# Elements with this tag (MathML) are "interrupted": withheld from translation.
_MATH_TAG = "math"
# Tag of the placeholder element substituted for a withheld subtree.
_EXPRESSION_TAG = "expression"


class XMLInterrupter:
    """Withholds <math> subtrees from the translation stream.

    Source pass: text segments whose parent stack contains a <math> element are
    buffered per interrupted element; consecutive buffered segments are flushed
    as one merged segment whose stack ends in an <expression> placeholder.
    Translated pass: placeholders are expanded back to the buffered originals
    (inline math only).  ``interrupt_block_element`` swaps a placeholder
    element back to the real interrupted element.
    """

    def __init__(self) -> None:
        # Next numeric id to assign to a newly seen interrupted element.
        self._next_id: int = 1
        # Id of the interrupted element the previous source segment belonged to
        # (None when the previous segment was not interrupted).
        self._last_interrupted_id: str | None = None
        # Maps id(placeholder Element) -> the real interrupted Element.
        self._placeholder2interrupted: dict[int, Element] = {}
        # Buffered source segments keyed by interrupted-element id.
        self._raw_text_segments: dict[str, list[TextSegment]] = {}

    def interrupt_source_text_segments(
        self, text_segments: Iterable[TextSegment]
    ) -> Generator[TextSegment, None, None]:
        """Pass through non-math segments; buffer math segments and emit merged placeholders."""
        for text_segment in text_segments:
            yield from self._expand_source_text_segment(text_segment)

        # Flush the trailing buffered run, if the stream ended inside a <math>.
        if self._last_interrupted_id is not None:
            merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
            if merged_text_segment:
                yield merged_text_segment

    def interrupt_translated_text_segments(
        self, text_segments: Iterable[TextSegment]
    ) -> Generator[TextSegment, None, None]:
        """Expand placeholder segments in the translated stream back to the buffered originals."""
        for text_segment in text_segments:
            yield from self._expand_translated_text_segment(text_segment)

    def interrupt_block_element(self, element: Element) -> Element:
        """Return the real interrupted element for a placeholder, or *element* unchanged."""
        interrupted_element = self._placeholder2interrupted.pop(id(element), None)
        if interrupted_element is None:
            return element
        else:
            return interrupted_element

    def _expand_source_text_segment(
        self, text_segment: TextSegment
    ) -> Generator[TextSegment, None, None]:
        """Buffer *text_segment* if it sits inside a <math>; otherwise yield it.

        When the current segment belongs to a different interrupted element
        than the previous one, the previous buffered run is flushed first.
        """
        interrupted_index = self._interrupted_index(text_segment)
        interrupted_id: str | None = None

        if interrupted_index is not None:
            interrupted_element = text_segment.parent_stack[interrupted_index]
            interrupted_id = interrupted_element.get(_ID_KEY)
            if interrupted_id is None:
                # First time this <math> element is seen: tag it with a fresh id.
                interrupted_id = str(self._next_id)
                interrupted_element.set(_ID_KEY, interrupted_id)
                self._next_id += 1
            text_segments = ensure_list(
                target=self._raw_text_segments,
                key=interrupted_id,
            )
            text_segments.append(text_segment)

        if self._last_interrupted_id is not None and interrupted_id != self._last_interrupted_id:
            merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
            if merged_text_segment:
                yield merged_text_segment

        self._last_interrupted_id = interrupted_id

        if interrupted_index is None:
            yield text_segment

    def _pop_and_merge_from_buffered(self, interrupted_id: str) -> TextSegment | None:
        """Merge the buffered run for *interrupted_id* into one placeholder segment.

        The merged segment's parent stack is the original stack above the
        interrupted element, terminated by a fresh <expression> placeholder
        carrying the interrupted element's id.  Returns None when nothing
        is buffered for this id.
        """
        merged_text_segment: TextSegment | None = None
        text_segments = self._raw_text_segments.get(interrupted_id, None)
        if text_segments:
            text_segment = text_segments[0]
            interrupted_index = cast(int, self._interrupted_index(text_segment))
            interrupted_element = text_segment.parent_stack[cast(int, interrupted_index)]
            placeholder_element = Element(
                _EXPRESSION_TAG,
                {
                    _ID_KEY: cast(str, interrupted_element.get(_ID_KEY)),
                },
            )
            raw_parent_stack = text_segment.parent_stack[:interrupted_index]
            parent_stack = raw_parent_stack + [placeholder_element]
            merged_text_segment = TextSegment(
                text="".join(t.text for t in text_segments),
                parent_stack=parent_stack,
                left_common_depth=text_segments[0].left_common_depth,
                right_common_depth=text_segments[-1].right_common_depth,
                block_depth=len(parent_stack),
                position=text_segments[0].position,
            )
            self._placeholder2interrupted[id(placeholder_element)] = interrupted_element

            # TODO: tricky to get right — disabled for now.
            # parent_element: Element | None = None
            # if interrupted_index > 0:
            #     parent_element = text_segment.parent_stack[interrupted_index - 1]

            # if (
            #     not self._is_inline_math(interrupted_element)
            #     or parent_element is None
            #     or self._has_no_math_texts(parent_element)
            # ):
            #     # Block-level formulas need not reappear (it reads abruptly when
            #     # they do), but inline formulas interleaved in the translation
            #     # help the reader follow the text.
            #     self._raw_text_segments.pop(interrupted_id, None)
            # else:
            #     for text_segment in text_segments:
            #         # Strip the original stack, keeping only the part relative to the
            #         # interrupted element; this matches the format required for
            #         # translated segments.
            #         text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
            #         text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
            #         text_segment.block_depth = 1
            #         text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
            for text_segment in text_segments:
                # Strip the original stack, keeping only the part relative to the
                # interrupted element; this matches the format required for
                # translated segments.
                text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
                text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
                text_segment.block_depth = 1
                text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]

        return merged_text_segment

    def _interrupted_index(self, text_segment: TextSegment) -> int | None:
        """Return the index of the first <math> ancestor in the parent stack, or None."""
        interrupted_index: int | None = None
        for i, parent_element in enumerate(text_segment.parent_stack):
            if parent_element.tag == _MATH_TAG:
                interrupted_index = i
                break
        return interrupted_index

    def _expand_translated_text_segment(
        self, text_segment: TextSegment
    ) -> Generator[TextSegment, None, None]:
        """Yield the segment, or — for placeholder segments — the buffered inline-math originals.

        Block-level math placeholders are dropped entirely (nothing is yielded
        for them); inline math is re-emitted so it stays interleaved in the
        translated text.
        """
        interrupted_id = text_segment.block_parent.attrib.pop(_ID_KEY, None)
        if interrupted_id is None:
            yield text_segment
            return

        raw_text_segments = self._raw_text_segments.pop(interrupted_id, None)
        if not raw_text_segments:
            return

        raw_block = raw_text_segments[0].parent_stack[0]
        if not self._is_inline_math(raw_block):
            return

        for raw_text_segment in raw_text_segments:
            # Remove the internal marker attribute before re-emitting.
            raw_text_segment.block_parent.attrib.pop(_ID_KEY, None)
            yield raw_text_segment

    def _has_no_math_texts(self, element: Element) -> bool:
        """Return True when *element* contains no visible text outside <math> subtrees."""
        if element.tag == _MATH_TAG:
            return True
        if element.text and normalize_whitespace(element.text).strip():
            return False
        for child_element in element:
            if not self._has_no_math_texts(child_element):
                return False
            if child_element.tail and normalize_whitespace(child_element.tail).strip():
                return False
        return True

    def _is_inline_math(self, element: Element) -> bool:
        """True for a <math> element whose display attribute is not "block" (MathML inline)."""
        if element.tag != _MATH_TAG:
            return False
        return element.get("display", "").lower() != "block"
@@ -1,3 +1,3 @@
1
- from .group import XMLGroupContext
2
- from .submitter import submit_text_segments
3
- from .translator import XMLTranslator
1
+ from .callbacks import FillFailedEvent
2
+ from .submitter import SubmitKind
3
+ from .translator import TranslationTask, XMLTranslator
@@ -0,0 +1,34 @@
from collections.abc import Callable, Iterable
from dataclasses import dataclass
from xml.etree.ElementTree import Element

from ..segment import TextSegment


@dataclass
class FillFailedEvent:
    """Event payload reported when an LLM fill attempt fails validation."""

    # Human-readable description of the failure.
    error_message: str
    # How many times this fill has been retried so far.
    retried_count: int
    # True when the retry limit has been exceeded.
    over_maximum_retries: bool


@dataclass
class Callbacks:
    """Bundle of hook callables used throughout the translation pipeline."""

    # Hook applied to the source-side text-segment stream.
    interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
    # Hook applied to the translated-side text-segment stream.
    interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
    # Hook applied to each block element before it is emitted.
    interrupt_block_element: Callable[[Element], Element]
    # Notification hook invoked on each failed fill attempt.
    on_fill_failed: Callable[[FillFailedEvent], None]


def warp_callbacks(
    interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
    interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
    interrupt_block_element: Callable[[Element], Element] | None,
    on_fill_failed: Callable[[FillFailedEvent], None] | None,
) -> Callbacks:
    """Normalize optional hooks into a ``Callbacks`` bundle with no-op defaults.

    NOTE(review): "warp" is presumably a typo for "wrap"; the name is kept
    unchanged because external callers reference it.
    """

    def _passthrough(value):
        # Identity hook: leaves the stream/element untouched.
        return value

    def _ignore(event: FillFailedEvent) -> None:
        # Default failure handler: do nothing.
        return None

    return Callbacks(
        interrupt_source_text_segments=interrupt_source_text_segments or _passthrough,
        interrupt_translated_text_segments=interrupt_translated_text_segments or _passthrough,
        interrupt_block_element=interrupt_block_element or _passthrough,
        on_fill_failed=on_fill_failed or _ignore,
    )
@@ -1,2 +1 @@
1
- ID_KEY: str = "id"
2
1
  DATA_ORIGIN_LEN_KEY = "data-orig-len"
@@ -0,0 +1,104 @@
from collections.abc import Generator
from dataclasses import dataclass
from xml.etree.ElementTree import Element

from tiktoken import Encoding

from ..segment import BlockSegment, BlockSubmitter, TextSegment, search_text_segments
from ..xml import plain_text
from .common import DATA_ORIGIN_LEN_KEY
from .stream_mapper import InlineSegmentMapping
from .validation import LEVEL_DEPTH, generate_error_message, nest_as_errors_group, truncate_errors_group


@dataclass
class _BlockStatus:
    # Best (lowest) error weight accepted so far for this block.
    weight: int
    # The submitter that produced that best-so-far submission.
    submitter: BlockSubmitter


# Hill-climbing: pick out the parts with the highest completion from what the
# LLM submits.  By rejecting any submission whose per-block completion is lower
# than the best seen so far, each sub-part can only move toward higher completion.
class HillClimbing:
    """Keeps, per block, the best LLM submission seen so far (lowest error weight)."""

    def __init__(
        self,
        encoding: Encoding,
        max_fill_displaying_errors: int,
        block_segment: BlockSegment,
    ) -> None:
        # Tokenizer used to size child texts and to build error messages.
        self._encoding: Encoding = encoding
        # Cap on how many validation errors are shown back to the LLM at once.
        self._max_fill_displaying_errors: int = max_fill_displaying_errors
        # block id -> best submission status accepted so far.
        self._block_statuses: dict[int, _BlockStatus] = {}
        self._block_segment: BlockSegment = block_segment

    def request_element(self) -> Element:
        """Build the request element, tagging each child with its original token length."""
        element = self._block_segment.create_element()
        for child_element in element:
            text = plain_text(child_element)
            tokens = self._encoding.encode(text)
            # Record the source token length so length deviations can be validated later.
            child_element.set(DATA_ORIGIN_LEN_KEY, str(len(tokens)))
        return element

    def gen_mappings(self) -> Generator[InlineSegmentMapping | None, None, None]:
        """Yield, per inline segment, the accepted mapping or None when nothing was accepted."""
        for inline_segment in self._block_segment:
            # NOTE(review): `id` shadows the builtin within this loop.
            id = inline_segment.id
            assert id is not None
            status = self._block_statuses.get(id, None)
            text_segments: list[TextSegment] | None = None
            if status is None:
                yield None
            else:
                submitted_element = status.submitter.submitted_element
                text_segments = list(search_text_segments(submitted_element))
                # Presumably InlineSegmentMapping is the (parent, text_segments) pair.
                yield inline_segment.parent, text_segments

    def submit(self, element: Element) -> str | None:
        """Accept a new LLM submission; return an error message for retry, or None if clean.

        A block's stored submission is replaced only when the new one's error
        weight is strictly lower (hill-climbing acceptance).
        """
        error_message, block_weights = self._validate_block_weights_and_error_message(element)

        for submitter in self._block_segment.submit(element):
            weight: int = 0  # Absent from block_weights means no errors: fully complete.
            if block_weights:
                weight = block_weights.get(submitter.id, 0)
            status = self._block_statuses.get(submitter.id, None)
            if status is None:
                self._block_statuses[submitter.id] = _BlockStatus(
                    weight=weight,
                    submitter=submitter,
                )
            elif weight < status.weight:
                status.weight = weight
                status.submitter = submitter

        return error_message

    def _validate_block_weights_and_error_message(self, element: Element) -> tuple[str | None, dict[int, int] | None]:
        """Validate a submission; return (error message or None, per-block weights or None).

        Returns (None, None) when the submission has no errors at all, and
        (None, weights) when all remaining errors were truncated away.
        """
        errors_group = nest_as_errors_group(
            errors=self._block_segment.validate(element),
        )
        if errors_group is None:
            return None, None

        block_weights: dict[int, int] = {}
        for block_group in errors_group.block_groups:
            block_id = block_group.block_id
            status = self._block_statuses.get(block_id, None)
            block_weights[block_id] = block_group.weight
            if status is not None and status.weight > block_group.weight:
                # This block improved this round (weight dropped): demote its
                # errors so attention goes to parts that have not improved yet.
                for child_error in block_group.errors:
                    child_error.level -= LEVEL_DEPTH

        origin_errors_count = errors_group.errors_count
        errors_group = truncate_errors_group(
            errors_group=errors_group,
            max_errors=self._max_fill_displaying_errors,
        )
        if errors_group is None:
            return None, block_weights

        message = generate_error_message(
            encoding=self._encoding,
            errors_group=errors_group,
            omitted_count=origin_errors_count - errors_group.errors_count,
        )
        return message, block_weights