epub-translator 0.1.0__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (50)
  1. epub_translator/__init__.py +2 -2
  2. epub_translator/data/fill.jinja +143 -38
  3. epub_translator/epub/__init__.py +1 -1
  4. epub_translator/epub/metadata.py +122 -0
  5. epub_translator/epub/spines.py +3 -2
  6. epub_translator/epub/zip.py +11 -9
  7. epub_translator/epub_transcode.py +108 -0
  8. epub_translator/llm/__init__.py +1 -0
  9. epub_translator/llm/context.py +109 -0
  10. epub_translator/llm/core.py +39 -62
  11. epub_translator/llm/executor.py +25 -31
  12. epub_translator/llm/increasable.py +1 -1
  13. epub_translator/llm/types.py +0 -3
  14. epub_translator/segment/__init__.py +26 -0
  15. epub_translator/segment/block_segment.py +124 -0
  16. epub_translator/segment/common.py +29 -0
  17. epub_translator/segment/inline_segment.py +356 -0
  18. epub_translator/{xml_translator → segment}/text_segment.py +8 -8
  19. epub_translator/segment/utils.py +43 -0
  20. epub_translator/translator.py +150 -183
  21. epub_translator/utils.py +33 -0
  22. epub_translator/xml/__init__.py +2 -0
  23. epub_translator/xml/const.py +1 -0
  24. epub_translator/xml/deduplication.py +3 -3
  25. epub_translator/xml/self_closing.py +182 -0
  26. epub_translator/xml/utils.py +42 -0
  27. epub_translator/xml/xml.py +7 -0
  28. epub_translator/xml/xml_like.py +145 -115
  29. epub_translator/xml_interrupter.py +165 -0
  30. epub_translator/xml_translator/__init__.py +1 -2
  31. epub_translator/xml_translator/callbacks.py +34 -0
  32. epub_translator/xml_translator/{const.py → common.py} +0 -1
  33. epub_translator/xml_translator/hill_climbing.py +104 -0
  34. epub_translator/xml_translator/stream_mapper.py +253 -0
  35. epub_translator/xml_translator/submitter.py +26 -72
  36. epub_translator/xml_translator/translator.py +157 -107
  37. epub_translator/xml_translator/validation.py +458 -0
  38. {epub_translator-0.1.0.dist-info → epub_translator-0.1.3.dist-info}/METADATA +72 -9
  39. epub_translator-0.1.3.dist-info/RECORD +66 -0
  40. epub_translator/epub/placeholder.py +0 -53
  41. epub_translator/iter_sync.py +0 -24
  42. epub_translator/xml_translator/fill.py +0 -128
  43. epub_translator/xml_translator/format.py +0 -282
  44. epub_translator/xml_translator/fragmented.py +0 -125
  45. epub_translator/xml_translator/group.py +0 -183
  46. epub_translator/xml_translator/progressive_locking.py +0 -256
  47. epub_translator/xml_translator/utils.py +0 -29
  48. epub_translator-0.1.0.dist-info/RECORD +0 -58
  49. {epub_translator-0.1.0.dist-info → epub_translator-0.1.3.dist-info}/LICENSE +0 -0
  50. {epub_translator-0.1.0.dist-info → epub_translator-0.1.3.dist-info}/WHEEL +0 -0
@@ -1,10 +1,14 @@
  import io
  import re
+ import warnings
  from typing import IO
  from xml.etree.ElementTree import Element, fromstring, tostring

+ from .self_closing import self_close_void_elements, unclose_void_elements
  from .xml import iter_with_stack

+ _XML_NAMESPACE_URI = "http://www.w3.org/XML/1998/namespace"
+
  _COMMON_NAMESPACES = {
      "http://www.w3.org/1999/xhtml": "xhtml",
      "http://www.idpf.org/2007/ops": "epub",
@@ -14,6 +18,7 @@ _COMMON_NAMESPACES = {
      "http://www.idpf.org/2007/opf": "opf",
      "http://www.w3.org/2000/svg": "svg",
      "urn:oasis:names:tc:opendocument:xmlns:container": "container",
+     "http://www.w3.org/XML/1998/namespace": "xml",  # Reserved XML namespace
  }

  _ROOT_NAMESPACES = {
@@ -27,32 +32,26 @@ _ENCODING_PATTERN = re.compile(r'encoding\s*=\s*["\']([^"\']+)["\']', re.IGNOREC
  _FIRST_ELEMENT_PATTERN = re.compile(r"<(?![?!])[a-zA-Z]")
  _NAMESPACE_IN_TAG = re.compile(r"\{([^}]+)\}")

- # HTML defines a set of void (self-closing) tags; these must be rewritten as non-self-closing because the EPUB format does not support them
- # https://www.tutorialspoint.com/which-html-tags-are-self-closing
- _EMPTY_TAGS = (
-     "br",
-     "hr",
-     "input",
-     "col",
-     "base",
-     "meta",
-     "area",
- )
-
- _EMPTY_TAG_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^>]*?)\s*/?>")
-

  class XMLLikeNode:
-     def __init__(self, file: IO[bytes]) -> None:
+     def __init__(self, file: IO[bytes], is_html_like: bool = False) -> None:
          raw_content = file.read()
-         self._encoding: str = _detect_encoding(raw_content)
+         self._is_html_like = is_html_like
+         self._encoding: str = self._detect_encoding(raw_content)
          content = raw_content.decode(self._encoding)
-         self._header, xml_content = _extract_header(content)
+         self._header, xml_content = self._extract_header(content)
+         self._namespaces: dict[str, str] = {}
+         self._tag_to_namespace: dict[str, str] = {}
+         self._attr_to_namespace: dict[str, str] = {}
+
          try:
+             # No need to branch on document type: this function is extremely defensive and can turn near-garbage into parseable XML
+             xml_content = self_close_void_elements(xml_content)
-             self.element = fromstring(xml_content)
+             self.element = self._extract_and_clean_namespaces(
+                 element=fromstring(xml_content),
+             )
          except Exception as error:
              raise ValueError("Failed to parse XML-like content") from error
-         self._namespaces: dict[str, str] = _extract_and_clean_namespaces(self.element)

      @property
      def encoding(self) -> str:
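
Note: the removed `_EMPTY_TAGS` / `_EMPTY_TAG_PATTERN` logic moves into the new `xml/self_closing.py` module (file 25 in the list above). A rough sketch of what the two imported helpers plausibly do, inferred from how this diff uses them rather than from the module itself:

    import re

    # Assumed tag list; the real module may well cover more void elements (img, link, ...).
    _VOID_TAGS = ("br", "hr", "input", "col", "base", "meta", "area")
    _VOID_PATTERN = re.compile(
        r"<(" + "|".join(_VOID_TAGS) + r")(\s[^<>]*?)?\s*/?>",
        re.IGNORECASE,
    )

    def self_close_void_elements(markup: str) -> str:
        # Normalize "<br>", "<br >", and "<br/>" to "<br/>" so the XML parser accepts them.
        return _VOID_PATTERN.sub(lambda m: f"<{m.group(1)}{m.group(2) or ''}/>", markup)

    def unclose_void_elements(markup: str) -> str:
        # The inverse direction, used for text/html output: "<br/>" back to "<br>".
        return _VOID_PATTERN.sub(lambda m: f"<{m.group(1)}{m.group(2) or ''}>", markup)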
@@ -62,115 +61,146 @@ class XMLLikeNode:
      def namespaces(self) -> list[str]:
          return list(self._namespaces.keys())

-     def save(self, file: IO[bytes], is_html_like: bool = False) -> None:
+     def save(self, file: IO[bytes]) -> None:
          writer = io.TextIOWrapper(file, encoding=self._encoding, write_through=True)
          try:
              if self._header:
                  writer.write(self._header)

-             content = _serialize_with_namespaces(element=self.element, namespaces=self._namespaces)
-             if is_html_like:
-                 content = re.sub(
-                     pattern=_EMPTY_TAG_PATTERN,
-                     repl=lambda m: f"<{m.group(1)}{m.group(2)}>",
-                     string=content,
-                 )
-             else:
-                 content = re.sub(
-                     pattern=_EMPTY_TAG_PATTERN,
-                     repl=lambda m: f"<{m.group(1)}{m.group(2)} />",
-                     string=content,
-                 )
+             content = self._serialize_with_namespaces(self.element)
+
+             # For non-standard HTML files (text/html), convert back from <br/> to <br>
+             # to maintain compatibility with HTML parsers that don't support XHTML.
+             # For XHTML files (application/xhtml+xml), keep the self-closing format.
+             if self._is_html_like:
+                 content = unclose_void_elements(content)
+
              writer.write(content)

          finally:
              writer.detach()

+     def _detect_encoding(self, raw_content: bytes) -> str:
+         if raw_content.startswith(b"\xef\xbb\xbf"):
+             return "utf-8-sig"
+         elif raw_content.startswith(b"\xff\xfe"):
+             return "utf-16-le"
+         elif raw_content.startswith(b"\xfe\xff"):
+             return "utf-16-be"
+
+         # Try to extract the encoding from the XML declaration, scanning only the first 1024 bytes
+         header_bytes = raw_content[:1024]
+         for try_encoding in ("utf-8", "utf-16-le", "utf-16-be", "iso-8859-1"):
+             try:
+                 header_str = header_bytes.decode(try_encoding)
+                 match = _ENCODING_PATTERN.search(header_str)
+                 if match:
+                     declared_encoding = match.group(1).lower()
+                     try:
+                         raw_content.decode(declared_encoding)
+                         return declared_encoding
+                     except (LookupError, UnicodeDecodeError):
+                         pass
+             except UnicodeDecodeError:
+                 continue

- def _detect_encoding(raw_content: bytes) -> str:
-     if raw_content.startswith(b"\xef\xbb\xbf"):
-         return "utf-8-sig"
-     elif raw_content.startswith(b"\xff\xfe"):
-         return "utf-16-le"
-     elif raw_content.startswith(b"\xfe\xff"):
-         return "utf-16-be"
-
-     # Try to extract the encoding from the XML declaration, scanning only the first 1024 bytes
-     header_bytes = raw_content[:1024]
-     for try_encoding in ("utf-8", "utf-16-le", "utf-16-be", "iso-8859-1"):
          try:
-             header_str = header_bytes.decode(try_encoding)
-             match = _ENCODING_PATTERN.search(header_str)
-             if match:
-                 declared_encoding = match.group(1).lower()
-                 try:
-                     raw_content.decode(declared_encoding)
-                     return declared_encoding
-                 except (LookupError, UnicodeDecodeError):
-                     pass
+             raw_content.decode("utf-8")
+             return "utf-8"
          except UnicodeDecodeError:
-             continue
-
-     try:
-         raw_content.decode("utf-8")
-         return "utf-8"
-     except UnicodeDecodeError:
-         pass
-     return "iso-8859-1"
-
-
- def _extract_header(content: str) -> tuple[str, str]:
-     match = _FIRST_ELEMENT_PATTERN.search(content)
-     if match:
-         split_pos = match.start()
-         header = content[:split_pos]
-         xml_content = content[split_pos:]
-         return header, xml_content
-     return "", content
-
-
- def _extract_and_clean_namespaces(element: Element):
-     namespaces: dict[str, str] = {}
-     for _, elem in iter_with_stack(element):
-         match = _NAMESPACE_IN_TAG.match(elem.tag)
-         if match:
-             namespace_uri = match.group(1)
-             if namespace_uri not in namespaces:
-                 prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(namespaces)}")
-                 namespaces[namespace_uri] = prefix
+             pass
+         return "iso-8859-1"

-             tag_name = elem.tag[len(match.group(0)) :]
-             elem.tag = tag_name
-
-         for attr_key in list(elem.attrib.keys()):
-             match = _NAMESPACE_IN_TAG.match(attr_key)
+     def _extract_header(self, content: str) -> tuple[str, str]:
+         match = _FIRST_ELEMENT_PATTERN.search(content)
+         if match:
+             split_pos = match.start()
+             header = content[:split_pos]
+             xml_content = content[split_pos:]
+             return header, xml_content
+         return "", content
+
+     def _extract_and_clean_namespaces(self, element: Element) -> Element:
+         for _, elem in iter_with_stack(element):
+             match = _NAMESPACE_IN_TAG.match(elem.tag)
              if match:
                  namespace_uri = match.group(1)
-                 if namespace_uri not in namespaces:
-                     prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(namespaces)}")
-                     namespaces[namespace_uri] = prefix
-
-                 attr_name = attr_key[len(match.group(0)) :]
-                 attr_value = elem.attrib.pop(attr_key)
-                 elem.attrib[attr_name] = attr_value
-     return namespaces
-
-
- def _serialize_with_namespaces(
-     element: Element,
-     namespaces: dict[str, str],
- ) -> str:
-     for namespace_uri, prefix in namespaces.items():
-         if namespace_uri in _ROOT_NAMESPACES:
-             element.attrib["xmlns"] = namespace_uri
-         else:
-             element.attrib[f"xmlns:{prefix}"] = namespace_uri
-     xml_string = tostring(element, encoding="unicode")
-     for namespace_uri, prefix in namespaces.items():
-         if namespace_uri in _ROOT_NAMESPACES:
-             xml_string = xml_string.replace(f"{{{namespace_uri}}}", "")
-         else:
-             xml_string = xml_string.replace(f"{{{namespace_uri}}}", f"{prefix}:")
-             pattern = r'\s+xmlns:(ns\d+)="' + re.escape(namespace_uri) + r'"'
-             xml_string = re.sub(pattern, "", xml_string)
-     return xml_string
+                 if namespace_uri not in self._namespaces:
+                     prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(self._namespaces)}")
+                     self._namespaces[namespace_uri] = prefix
+
+                 tag_name = elem.tag[len(match.group(0)) :]
+
+                 # Record tag -> namespace mapping (warn if conflict)
+                 if tag_name in self._tag_to_namespace and self._tag_to_namespace[tag_name] != namespace_uri:
+                     warnings.warn(
+                         f"Tag '{tag_name}' has multiple namespaces: "
+                         f"{self._tag_to_namespace[tag_name]} and {namespace_uri}. "
+                         f"Using the first one.",
+                         stacklevel=2,
+                     )
+                 else:
+                     self._tag_to_namespace[tag_name] = namespace_uri
+
+                 # Clean: remove the namespace URI completely
+                 elem.tag = tag_name
+
+             for attr_key in list(elem.attrib.keys()):
+                 match = _NAMESPACE_IN_TAG.match(attr_key)
+                 if match:
+                     namespace_uri = match.group(1)
+                     if namespace_uri not in self._namespaces:
+                         prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(self._namespaces)}")
+                         self._namespaces[namespace_uri] = prefix
+
+                     attr_name = attr_key[len(match.group(0)) :]
+                     attr_value = elem.attrib.pop(attr_key)
+
+                     # Record attr -> namespace mapping (warn if conflict)
+                     if attr_name in self._attr_to_namespace and self._attr_to_namespace[attr_name] != namespace_uri:
+                         warnings.warn(
+                             f"Attribute '{attr_name}' has multiple namespaces: "
+                             f"{self._attr_to_namespace[attr_name]} and {namespace_uri}. "
+                             f"Using the first one.",
+                             stacklevel=2,
+                         )
+                     else:
+                         self._attr_to_namespace[attr_name] = namespace_uri
+
+                     # Clean: remove the namespace URI completely
+                     elem.attrib[attr_name] = attr_value
+         return element
+
+     def _serialize_with_namespaces(self, element: Element) -> str:
+         # First, add namespace declarations to the root element (before serialization)
+         for namespace_uri, prefix in self._namespaces.items():
+             # Skip the reserved xml namespace - it's implicit
+             if namespace_uri == _XML_NAMESPACE_URI:
+                 continue
+             if namespace_uri in _ROOT_NAMESPACES:
+                 element.attrib["xmlns"] = namespace_uri
+             else:
+                 element.attrib[f"xmlns:{prefix}"] = namespace_uri
+
+         # Serialize the element tree as-is (tags are simple names without prefixes)
+         xml_string = tostring(element, encoding="unicode")
+
+         # Now restore namespace prefixes in the serialized string:
+         # for each tag that should carry a namespace prefix, wrap it with the prefix
+         for tag_name, namespace_uri in self._tag_to_namespace.items():
+             if namespace_uri not in _ROOT_NAMESPACES:
+                 # Get the prefix for this namespace
+                 prefix = self._namespaces[namespace_uri]
+                 # Replace opening, closing, and self-closing tags
+                 xml_string = xml_string.replace(f"<{tag_name} ", f"<{prefix}:{tag_name} ")
+                 xml_string = xml_string.replace(f"<{tag_name}>", f"<{prefix}:{tag_name}>")
+                 xml_string = xml_string.replace(f"</{tag_name}>", f"</{prefix}:{tag_name}>")
+                 xml_string = xml_string.replace(f"<{tag_name}/>", f"<{prefix}:{tag_name}/>")
+
+         # Similarly for attributes (though less common in EPUB)
+         for attr_name, namespace_uri in self._attr_to_namespace.items():
+             if namespace_uri not in _ROOT_NAMESPACES:
+                 prefix = self._namespaces[namespace_uri]
+                 xml_string = xml_string.replace(f' {attr_name}="', f' {prefix}:{attr_name}="')
+
+         return xml_string
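
Net effect of the rewrite above: namespace URIs are stripped down to plain tag and attribute names at parse time, and prefixes plus xmlns declarations are restored by string replacement at save time. A minimal usage sketch under that reading (file name hypothetical, module path inferred from the file list):

    import io

    from epub_translator.xml.xml_like import XMLLikeNode

    with open("chapter1.xhtml", "rb") as f:
        node = XMLLikeNode(f, is_html_like=False)

    # Tags are plain names now: "p", not "{http://www.w3.org/1999/xhtml}p".
    first_paragraph = node.element.find(".//p")

    out = io.BytesIO()
    node.save(out)  # re-adds xmlns declarations and epub:/svg: style prefixes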
@@ -0,0 +1,165 @@
+ from collections.abc import Generator, Iterable
+ from typing import cast
+ from xml.etree.ElementTree import Element
+
+ from .segment import TextSegment
+ from .utils import ensure_list, normalize_whitespace
+
+ _ID_KEY = "__XML_INTERRUPTER_ID"
+ _MATH_TAG = "math"
+ _EXPRESSION_TAG = "expression"
+
+
+ class XMLInterrupter:
+     def __init__(self) -> None:
+         self._next_id: int = 1
+         self._last_interrupted_id: str | None = None
+         self._placeholder2interrupted: dict[int, Element] = {}
+         self._raw_text_segments: dict[str, list[TextSegment]] = {}
+
+     def interrupt_source_text_segments(
+         self, text_segments: Iterable[TextSegment]
+     ) -> Generator[TextSegment, None, None]:
+         for text_segment in text_segments:
+             yield from self._expand_source_text_segment(text_segment)
+
+         if self._last_interrupted_id is not None:
+             merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+             if merged_text_segment:
+                 yield merged_text_segment
+
+     def interrupt_translated_text_segments(
+         self, text_segments: Iterable[TextSegment]
+     ) -> Generator[TextSegment, None, None]:
+         for text_segment in text_segments:
+             yield from self._expand_translated_text_segment(text_segment)
+
+     def interrupt_block_element(self, element: Element) -> Element:
+         interrupted_element = self._placeholder2interrupted.pop(id(element), None)
+         if interrupted_element is None:
+             return element
+         else:
+             return interrupted_element
+
+     def _expand_source_text_segment(self, text_segment: TextSegment):
+         interrupted_index = self._interrupted_index(text_segment)
+         interrupted_id: str | None = None
+
+         if interrupted_index is not None:
+             interrupted_element = text_segment.parent_stack[interrupted_index]
+             interrupted_id = interrupted_element.get(_ID_KEY)
+             if interrupted_id is None:
+                 interrupted_id = str(self._next_id)
+                 interrupted_element.set(_ID_KEY, interrupted_id)
+                 self._next_id += 1
+             text_segments = ensure_list(
+                 target=self._raw_text_segments,
+                 key=interrupted_id,
+             )
+             text_segments.append(text_segment)
+
+         if self._last_interrupted_id is not None and interrupted_id != self._last_interrupted_id:
+             merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+             if merged_text_segment:
+                 yield merged_text_segment
+
+         self._last_interrupted_id = interrupted_id
+
+         if interrupted_index is None:
+             yield text_segment
+
+     def _pop_and_merge_from_buffered(self, interrupted_id: str) -> TextSegment | None:
+         merged_text_segment: TextSegment | None = None
+         text_segments = self._raw_text_segments.get(interrupted_id, None)
+         if text_segments:
+             text_segment = text_segments[0]
+             interrupted_index = cast(int, self._interrupted_index(text_segment))
+             interrupted_element = text_segment.parent_stack[cast(int, interrupted_index)]
+             placeholder_element = Element(
+                 _EXPRESSION_TAG,
+                 {
+                     _ID_KEY: cast(str, interrupted_element.get(_ID_KEY)),
+                 },
+             )
+             raw_parent_stack = text_segment.parent_stack[:interrupted_index]
+             parent_stack = raw_parent_stack + [placeholder_element]
+             merged_text_segment = TextSegment(
+                 text="".join(t.text for t in text_segments),
+                 parent_stack=parent_stack,
+                 left_common_depth=text_segments[0].left_common_depth,
+                 right_common_depth=text_segments[-1].right_common_depth,
+                 block_depth=len(parent_stack),
+                 position=text_segments[0].position,
+             )
+             self._placeholder2interrupted[id(placeholder_element)] = interrupted_element
+
+             # TODO: this is tricky to get right; disabled for now
+             # parent_element: Element | None = None
+             # if interrupted_index > 0:
+             #     parent_element = text_segment.parent_stack[interrupted_index - 1]
+
+             # if (
+             #     not self._is_inline_math(interrupted_element)
+             #     or parent_element is None
+             #     or self._has_no_math_texts(parent_element)
+             # ):
+             #     # Block-level formulas need not be repeated; repeating them is jarring. Inline formulas interleaved with the translation, however, read more smoothly.
+             #     self._raw_text_segments.pop(interrupted_id, None)
+             # else:
+             #     for text_segment in text_segments:
+             #         # Unwind the original stack, keeping only the part relative to the interrupted element; this matches the format expected for translated segments
+             #         text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+             #         text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+             #         text_segment.block_depth = 1
+             #         text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+             for text_segment in text_segments:
+                 # Unwind the original stack, keeping only the part relative to the interrupted element; this matches the format expected for translated segments
+                 text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+                 text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+                 text_segment.block_depth = 1
+                 text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+
+         return merged_text_segment
+
+     def _interrupted_index(self, text_segment: TextSegment) -> int | None:
+         interrupted_index: int | None = None
+         for i, parent_element in enumerate(text_segment.parent_stack):
+             if parent_element.tag == _MATH_TAG:
+                 interrupted_index = i
+                 break
+         return interrupted_index
+
+     def _expand_translated_text_segment(self, text_segment: TextSegment):
+         interrupted_id = text_segment.block_parent.attrib.pop(_ID_KEY, None)
+         if interrupted_id is None:
+             yield text_segment
+             return
+
+         raw_text_segments = self._raw_text_segments.pop(interrupted_id, None)
+         if not raw_text_segments:
+             return
+
+         raw_block = raw_text_segments[0].parent_stack[0]
+         if not self._is_inline_math(raw_block):
+             return
+
+         for raw_text_segment in raw_text_segments:
+             raw_text_segment.block_parent.attrib.pop(_ID_KEY, None)
+             yield raw_text_segment
+
+     def _has_no_math_texts(self, element: Element):
+         if element.tag == _MATH_TAG:
+             return True
+         if element.text and normalize_whitespace(element.text).strip():
+             return False
+         for child_element in element:
+             if not self._has_no_math_texts(child_element):
+                 return False
+             if child_element.tail and normalize_whitespace(child_element.tail).strip():
+                 return False
+         return True
+
+     def _is_inline_math(self, element: Element) -> bool:
+         if element.tag != _MATH_TAG:
+             return False
+         return element.get("display", "").lower() != "block"
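
In outline: source segments whose parent stack contains a `<math>` element are buffered, a single merged segment under an `<expression>` placeholder is yielded in their place (so the model never sees raw MathML), and `interrupt_block_element` swaps the original subtree back afterwards. A stdlib-only illustration of that placeholder trick, not the package API:

    from xml.etree.ElementTree import Element, fromstring

    _stash: dict[str, Element] = {}

    def hide_math(parent: Element) -> None:
        # Depth-first: swap each <math> subtree for an <expression id=...> stub,
        # carrying over the tail text so surrounding prose is not lost.
        for i, child in enumerate(list(parent)):
            if child.tag == "math":
                key = str(len(_stash) + 1)
                stub = Element("expression", {"id": key})
                stub.tail = child.tail
                _stash[key] = child
                parent[i] = stub
            else:
                hide_math(child)

    def restore_math(parent: Element) -> None:
        for i, child in enumerate(list(parent)):
            if child.tag == "expression":
                original = _stash.pop(child.get("id", ""))
                original.tail = child.tail
                parent[i] = original
            else:
                restore_math(child)

    root = fromstring("<p>Euler: <math><mi>e</mi></math> is famous.</p>")
    hide_math(root)     # <p>Euler: <expression id="1" /> is famous.</p>
    restore_math(root)  # the original <math> subtree is back in place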
@@ -1,3 +1,2 @@
- from .group import XMLGroupContext
- from .submitter import submit_text_segments
+ from .callbacks import FillFailedEvent
  from .translator import XMLTranslator
@@ -0,0 +1,34 @@
+ from collections.abc import Callable, Iterable
+ from dataclasses import dataclass
+ from xml.etree.ElementTree import Element
+
+ from ..segment import TextSegment
+
+
+ @dataclass
+ class FillFailedEvent:
+     error_message: str
+     retried_count: int
+     over_maximum_retries: bool
+
+
+ @dataclass
+ class Callbacks:
+     interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+     interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+     interrupt_block_element: Callable[[Element], Element]
+     on_fill_failed: Callable[[FillFailedEvent], None]
+
+
+ def warp_callbacks(
+     interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+     interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+     interrupt_block_element: Callable[[Element], Element] | None,
+     on_fill_failed: Callable[[FillFailedEvent], None] | None,
+ ) -> Callbacks:
+     return Callbacks(
+         interrupt_source_text_segments=interrupt_source_text_segments or (lambda x: x),
+         interrupt_translated_text_segments=interrupt_translated_text_segments or (lambda x: x),
+         interrupt_block_element=interrupt_block_element or (lambda x: x),
+         on_fill_failed=on_fill_failed or (lambda event: None),
+     )
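
A small usage sketch (module path from the file list; `warp_callbacks` is the published spelling): every omitted hook becomes an identity function or a no-op, so `XMLTranslator` can invoke the callbacks unconditionally instead of checking for `None`:

    from xml.etree.ElementTree import Element

    from epub_translator.xml_translator.callbacks import FillFailedEvent, warp_callbacks

    callbacks = warp_callbacks(None, None, None, None)

    element = Element("p")
    assert callbacks.interrupt_block_element(element) is element  # identity default

    event = FillFailedEvent(error_message="oops", retried_count=1, over_maximum_retries=False)
    callbacks.on_fill_failed(event)  # no-op default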
@@ -1,2 +1 @@
- ID_KEY: str = "id"
  DATA_ORIGIN_LEN_KEY = "data-orig-len"
@@ -0,0 +1,104 @@
+ from collections.abc import Generator
+ from dataclasses import dataclass
+ from xml.etree.ElementTree import Element
+
+ from tiktoken import Encoding
+
+ from ..segment import BlockSegment, BlockSubmitter, TextSegment, search_text_segments
+ from ..xml import plain_text
+ from .common import DATA_ORIGIN_LEN_KEY
+ from .stream_mapper import InlineSegmentMapping
+ from .validation import LEVEL_DEPTH, generate_error_message, nest_as_errors_group, truncate_errors_group
+
+
+ @dataclass
+ class _BlockStatus:
+     weight: int
+     submitter: BlockSubmitter
+
+
+ # Hill climbing: pick the parts with the highest completion out of what the LLM submits.
+ # By rejecting relatively low-completion submissions for each sub-part, it locks each sub-part so it can only move toward higher completion.
+ class HillClimbing:
+     def __init__(
+         self,
+         encoding: Encoding,
+         max_fill_displaying_errors: int,
+         block_segment: BlockSegment,
+     ) -> None:
+         self._encoding: Encoding = encoding
+         self._max_fill_displaying_errors: int = max_fill_displaying_errors
+         self._block_statuses: dict[int, _BlockStatus] = {}
+         self._block_segment: BlockSegment = block_segment
+
+     def request_element(self) -> Element:
+         element = self._block_segment.create_element()
+         for child_element in element:
+             text = plain_text(child_element)
+             tokens = self._encoding.encode(text)
+             child_element.set(DATA_ORIGIN_LEN_KEY, str(len(tokens)))
+         return element
+
+     def gen_mappings(self) -> Generator[InlineSegmentMapping | None, None, None]:
+         for inline_segment in self._block_segment:
+             id = inline_segment.id
+             assert id is not None
+             status = self._block_statuses.get(id, None)
+             text_segments: list[TextSegment] | None = None
+             if status is None:
+                 yield None
+             else:
+                 submitted_element = status.submitter.submitted_element
+                 text_segments = list(search_text_segments(submitted_element))
+                 yield inline_segment.parent, text_segments
+
+     def submit(self, element: Element) -> str | None:
+         error_message, block_weights = self._validate_block_weights_and_error_message(element)
+
+         for submitter in self._block_segment.submit(element):
+             weight: int = 0  # absent from block_weights means no errors: the block is complete
+             if block_weights:
+                 weight = block_weights.get(submitter.id, 0)
+             status = self._block_statuses.get(submitter.id, None)
+             if status is None:
+                 self._block_statuses[submitter.id] = _BlockStatus(
+                     weight=weight,
+                     submitter=submitter,
+                 )
+             elif weight < status.weight:
+                 status.weight = weight
+                 status.submitter = submitter
+
+         return error_message
+
+     def _validate_block_weights_and_error_message(self, element: Element) -> tuple[str | None, dict[int, int] | None]:
+         errors_group = nest_as_errors_group(
+             errors=self._block_segment.validate(element),
+         )
+         if errors_group is None:
+             return None, None
+
+         block_weights: dict[int, int] = {}
+         for block_group in errors_group.block_groups:
+             block_id = block_group.block_id
+             status = self._block_statuses.get(block_id, None)
+             block_weights[block_id] = block_group.weight
+             if status is not None and status.weight > block_group.weight:
+                 # Completion improved this round (weight dropped), so rank these errors lower, yielding attention to the parts whose completion has not improved yet
+                 for child_error in block_group.errors:
+                     child_error.level -= LEVEL_DEPTH
+
+         origin_errors_count = errors_group.errors_count
+         errors_group = truncate_errors_group(
+             errors_group=errors_group,
+             max_errors=self._max_fill_displaying_errors,
+         )
+         if errors_group is None:
+             return None, block_weights
+
+         message = generate_error_message(
+             encoding=self._encoding,
+             errors_group=errors_group,
+             omitted_count=origin_errors_count - errors_group.errors_count,
+         )
+         return message, block_weights
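
The heart of the climb is the acceptance rule in `submit`: a block's stored submission is replaced only when the new error weight is strictly lower, so each block's best-known version never regresses across retries. A stripped-down illustration of that rule (not the package API):

    best: dict[int, tuple[int, str]] = {}  # block_id -> (error weight, submission)

    def accept(block_id: int, weight: int, submission: str) -> bool:
        current = best.get(block_id)
        if current is None or weight < current[0]:
            best[block_id] = (weight, submission)  # climb: strictly fewer errors
            return True
        return False  # reject: completion would not improve

    accept(1, 5, "draft A")  # True: the first submission is always kept
    accept(1, 7, "draft B")  # False: more errors than the stored draft
    accept(1, 2, "draft C")  # True: weight can only move downward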