epub-translator 0.1.0__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- epub_translator/__init__.py +2 -2
- epub_translator/data/fill.jinja +143 -38
- epub_translator/epub/__init__.py +1 -1
- epub_translator/epub/metadata.py +122 -0
- epub_translator/epub/spines.py +3 -2
- epub_translator/epub/zip.py +11 -9
- epub_translator/epub_transcode.py +108 -0
- epub_translator/llm/__init__.py +1 -0
- epub_translator/llm/context.py +109 -0
- epub_translator/llm/core.py +39 -62
- epub_translator/llm/executor.py +25 -31
- epub_translator/llm/increasable.py +1 -1
- epub_translator/llm/types.py +0 -3
- epub_translator/segment/__init__.py +26 -0
- epub_translator/segment/block_segment.py +124 -0
- epub_translator/segment/common.py +29 -0
- epub_translator/segment/inline_segment.py +356 -0
- epub_translator/{xml_translator → segment}/text_segment.py +8 -8
- epub_translator/segment/utils.py +43 -0
- epub_translator/translator.py +150 -183
- epub_translator/utils.py +33 -0
- epub_translator/xml/__init__.py +2 -0
- epub_translator/xml/const.py +1 -0
- epub_translator/xml/deduplication.py +3 -3
- epub_translator/xml/self_closing.py +182 -0
- epub_translator/xml/utils.py +42 -0
- epub_translator/xml/xml.py +7 -0
- epub_translator/xml/xml_like.py +145 -115
- epub_translator/xml_interrupter.py +165 -0
- epub_translator/xml_translator/__init__.py +1 -2
- epub_translator/xml_translator/callbacks.py +34 -0
- epub_translator/xml_translator/{const.py → common.py} +0 -1
- epub_translator/xml_translator/hill_climbing.py +104 -0
- epub_translator/xml_translator/stream_mapper.py +253 -0
- epub_translator/xml_translator/submitter.py +26 -72
- epub_translator/xml_translator/translator.py +157 -107
- epub_translator/xml_translator/validation.py +458 -0
- {epub_translator-0.1.0.dist-info → epub_translator-0.1.3.dist-info}/METADATA +72 -9
- epub_translator-0.1.3.dist-info/RECORD +66 -0
- epub_translator/epub/placeholder.py +0 -53
- epub_translator/iter_sync.py +0 -24
- epub_translator/xml_translator/fill.py +0 -128
- epub_translator/xml_translator/format.py +0 -282
- epub_translator/xml_translator/fragmented.py +0 -125
- epub_translator/xml_translator/group.py +0 -183
- epub_translator/xml_translator/progressive_locking.py +0 -256
- epub_translator/xml_translator/utils.py +0 -29
- epub_translator-0.1.0.dist-info/RECORD +0 -58
- {epub_translator-0.1.0.dist-info → epub_translator-0.1.3.dist-info}/LICENSE +0 -0
- {epub_translator-0.1.0.dist-info → epub_translator-0.1.3.dist-info}/WHEEL +0 -0
epub_translator/xml/xml_like.py
CHANGED
@@ -1,10 +1,14 @@
 import io
 import re
+import warnings
 from typing import IO
 from xml.etree.ElementTree import Element, fromstring, tostring

+from .self_closing import self_close_void_elements, unclose_void_elements
 from .xml import iter_with_stack

+_XML_NAMESPACE_URI = "http://www.w3.org/XML/1998/namespace"
+
 _COMMON_NAMESPACES = {
     "http://www.w3.org/1999/xhtml": "xhtml",
     "http://www.idpf.org/2007/ops": "epub",
@@ -14,6 +18,7 @@ _COMMON_NAMESPACES = {
     "http://www.idpf.org/2007/opf": "opf",
     "http://www.w3.org/2000/svg": "svg",
     "urn:oasis:names:tc:opendocument:xmlns:container": "container",
+    "http://www.w3.org/XML/1998/namespace": "xml", # Reserved XML namespace
 }

 _ROOT_NAMESPACES = {
@@ -27,32 +32,26 @@ _ENCODING_PATTERN = re.compile(r'encoding\s*=\s*["\']([^"\']+)["\']', re.IGNORECASE)
 _FIRST_ELEMENT_PATTERN = re.compile(r"<(?![?!])[a-zA-Z]")
 _NAMESPACE_IN_TAG = re.compile(r"\{([^}]+)\}")

-# HTML defines a set of self-closing (void) tags; they must be rewritten as non-self-closing because the EPUB format does not support them
-# https://www.tutorialspoint.com/which-html-tags-are-self-closing
-_EMPTY_TAGS = (
-    "br",
-    "hr",
-    "input",
-    "col",
-    "base",
-    "meta",
-    "area",
-)
-
-_EMPTY_TAG_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^>]*?)\s*/?>")
-

 class XMLLikeNode:
-    def __init__(self, file: IO[bytes]) -> None:
+    def __init__(self, file: IO[bytes], is_html_like: bool = False) -> None:
         raw_content = file.read()
-        self.
+        self._is_html_like = is_html_like
+        self._encoding: str = self._detect_encoding(raw_content)
         content = raw_content.decode(self._encoding)
-        self._header, xml_content = _extract_header(content)
+        self._header, xml_content = self._extract_header(content)
+        self._namespaces: dict[str, str] = {}
+        self._tag_to_namespace: dict[str, str] = {}
+        self._attr_to_namespace: dict[str, str] = {}
+
         try:
+            # No need to check the type; this function is extremely defensive and can turn shit >> XML
+            xml_content = self_close_void_elements(xml_content)
+            self.element = self._extract_and_clean_namespaces(
+                element=fromstring(xml_content),
+            )
         except Exception as error:
             raise ValueError("Failed to parse XML-like content") from error
-        self._namespaces: dict[str, str] = _extract_and_clean_namespaces(self.element)

     @property
     def encoding(self) -> str:
@@ -62,115 +61,146 @@ class XMLLikeNode:
     def namespaces(self) -> list[str]:
         return list(self._namespaces.keys())

-    def save(self, file: IO[bytes]
+    def save(self, file: IO[bytes]) -> None:
         writer = io.TextIOWrapper(file, encoding=self._encoding, write_through=True)
         try:
             if self._header:
                 writer.write(self._header)

-            content = _serialize_with_namespaces(
-            )
-
-            content = re.sub(
-                pattern=_EMPTY_TAG_PATTERN,
-                repl=lambda m: f"<{m.group(1)}{m.group(2)} />",
-                string=content,
-            )
+            content = self._serialize_with_namespaces(self.element)
+
+            # For non-standard HTML files (text/html), convert back from <br/> to <br>
+            # to maintain compatibility with HTML parsers that don't support XHTML
+            # For XHTML files (application/xhtml+xml), keep self-closing format
+            if self._is_html_like:
+                content = unclose_void_elements(content)
+
             writer.write(content)

         finally:
             writer.detach()

+    def _detect_encoding(self, raw_content: bytes) -> str:
+        if raw_content.startswith(b"\xef\xbb\xbf"):
+            return "utf-8-sig"
+        elif raw_content.startswith(b"\xff\xfe"):
+            return "utf-16-le"
+        elif raw_content.startswith(b"\xfe\xff"):
+            return "utf-16-be"
+
+        # Try to extract the encoding from the XML declaration: only the first 1024 bytes are scanned for it
+        header_bytes = raw_content[:1024]
+        for try_encoding in ("utf-8", "utf-16-le", "utf-16-be", "iso-8859-1"):
+            try:
+                header_str = header_bytes.decode(try_encoding)
+                match = _ENCODING_PATTERN.search(header_str)
+                if match:
+                    declared_encoding = match.group(1).lower()
+                    try:
+                        raw_content.decode(declared_encoding)
+                        return declared_encoding
+                    except (LookupError, UnicodeDecodeError):
+                        pass
+            except UnicodeDecodeError:
+                continue

-def _detect_encoding(raw_content: bytes) -> str:
-    if raw_content.startswith(b"\xef\xbb\xbf"):
-        return "utf-8-sig"
-    elif raw_content.startswith(b"\xff\xfe"):
-        return "utf-16-le"
-    elif raw_content.startswith(b"\xfe\xff"):
-        return "utf-16-be"
-
-    # Try to extract the encoding from the XML declaration: only the first 1024 bytes are scanned for it
-    header_bytes = raw_content[:1024]
-    for try_encoding in ("utf-8", "utf-16-le", "utf-16-be", "iso-8859-1"):
         try:
-            if match:
-                declared_encoding = match.group(1).lower()
-                try:
-                    raw_content.decode(declared_encoding)
-                    return declared_encoding
-                except (LookupError, UnicodeDecodeError):
-                    pass
+            raw_content.decode("utf-8")
+            return "utf-8"
         except UnicodeDecodeError:
-    try:
-        raw_content.decode("utf-8")
-        return "utf-8"
-    except UnicodeDecodeError:
-        pass
-    return "iso-8859-1"
-
-
-def _extract_header(content: str) -> tuple[str, str]:
-    match = _FIRST_ELEMENT_PATTERN.search(content)
-    if match:
-        split_pos = match.start()
-        header = content[:split_pos]
-        xml_content = content[split_pos:]
-        return header, xml_content
-    return "", content
-
-
-def _extract_and_clean_namespaces(element: Element):
-    namespaces: dict[str, str] = {}
-    for _, elem in iter_with_stack(element):
-        match = _NAMESPACE_IN_TAG.match(elem.tag)
-        if match:
-            namespace_uri = match.group(1)
-            if namespace_uri not in namespaces:
-                prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(namespaces)}")
-                namespaces[namespace_uri] = prefix
+            pass
+        return "iso-8859-1"

+    def _extract_header(self, content: str) -> tuple[str, str]:
+        match = _FIRST_ELEMENT_PATTERN.search(content)
+        if match:
+            split_pos = match.start()
+            header = content[:split_pos]
+            xml_content = content[split_pos:]
+            return header, xml_content
+        return "", content
+
+    def _extract_and_clean_namespaces(self, element: Element) -> Element:
+        for _, elem in iter_with_stack(element):
+            match = _NAMESPACE_IN_TAG.match(elem.tag)
             if match:
                 namespace_uri = match.group(1)
-                if namespace_uri not in
-                    prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(
+                if namespace_uri not in self._namespaces:
+                    prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(self._namespaces)}")
+                    self._namespaces[namespace_uri] = prefix
+
+                tag_name = elem.tag[len(match.group(0)) :]
+
+                # Record tag -> namespace mapping (warn if conflict)
+                if tag_name in self._tag_to_namespace and self._tag_to_namespace[tag_name] != namespace_uri:
+                    warnings.warn(
+                        f"Tag '{tag_name}' has multiple namespaces: "
+                        f"{self._tag_to_namespace[tag_name]} and {namespace_uri}. "
+                        f"Using the first one.",
+                        stacklevel=2,
+                    )
+                else:
+                    self._tag_to_namespace[tag_name] = namespace_uri
+
+                # Clean: remove namespace URI completely
+                elem.tag = tag_name
+
+            for attr_key in list(elem.attrib.keys()):
+                match = _NAMESPACE_IN_TAG.match(attr_key)
+                if match:
+                    namespace_uri = match.group(1)
+                    if namespace_uri not in self._namespaces:
+                        prefix = _COMMON_NAMESPACES.get(namespace_uri, f"ns{len(self._namespaces)}")
+                        self._namespaces[namespace_uri] = prefix
+
+                    attr_name = attr_key[len(match.group(0)) :]
+                    attr_value = elem.attrib.pop(attr_key)
+
+                    # Record attr -> namespace mapping (warn if conflict)
+                    if attr_name in self._attr_to_namespace and self._attr_to_namespace[attr_name] != namespace_uri:
+                        warnings.warn(
+                            f"Attribute '{attr_name}' has multiple namespaces: "
+                            f"{self._attr_to_namespace[attr_name]} and {namespace_uri}. "
+                            f"Using the first one.",
+                            stacklevel=2,
+                        )
+                    else:
+                        self._attr_to_namespace[attr_name] = namespace_uri
+
+                    # Clean: remove namespace URI completely
+                    elem.attrib[attr_name] = attr_value
+        return element
+
+    def _serialize_with_namespaces(self, element: Element) -> str:
+        # First, add namespace declarations to root element (before serialization)
+        for namespace_uri, prefix in self._namespaces.items():
+            # Skip the reserved xml namespace - it's implicit
+            if namespace_uri == _XML_NAMESPACE_URI:
+                continue
+            if namespace_uri in _ROOT_NAMESPACES:
+                element.attrib["xmlns"] = namespace_uri
+            else:
+                element.attrib[f"xmlns:{prefix}"] = namespace_uri
+
+        # Serialize the element tree as-is (tags are simple names without prefixes)
+        xml_string = tostring(element, encoding="unicode")
+
+        # Now restore namespace prefixes in the serialized string
+        # For each tag that should have a namespace prefix, wrap it with the prefix
+        for tag_name, namespace_uri in self._tag_to_namespace.items():
+            if namespace_uri not in _ROOT_NAMESPACES:
+                # Get the prefix for this namespace
+                prefix = self._namespaces[namespace_uri]
+                # Replace opening and closing tags
+                xml_string = xml_string.replace(f"<{tag_name} ", f"<{prefix}:{tag_name} ")
+                xml_string = xml_string.replace(f"<{tag_name}>", f"<{prefix}:{tag_name}>")
+                xml_string = xml_string.replace(f"</{tag_name}>", f"</{prefix}:{tag_name}>")
+                xml_string = xml_string.replace(f"<{tag_name}/>", f"<{prefix}:{tag_name}/>")
+
+        # Similarly for attributes (though less common in EPUB)
+        for attr_name, namespace_uri in self._attr_to_namespace.items():
+            if namespace_uri not in _ROOT_NAMESPACES:
+                prefix = self._namespaces[namespace_uri]
+                xml_string = xml_string.replace(f' {attr_name}="', f' {prefix}:{attr_name}="')
+
+        return xml_string
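Taken together, the changes above move void-element handling into self_close_void_elements / unclose_void_elements and turn encoding, header and namespace handling into methods of XMLLikeNode. A minimal usage sketch of the reworked class follows; the sample markup and in-memory buffers are illustrative only and are not taken from the package's tests:

import io

from epub_translator.xml.xml_like import XMLLikeNode

raw = b'<html xmlns="http://www.w3.org/1999/xhtml"><body><p>Hello<br>world</p></body></html>'
node = XMLLikeNode(io.BytesIO(raw), is_html_like=True)  # void tags such as <br> are self-closed before parsing
print(node.encoding)  # detected from a BOM or the XML declaration; plain "utf-8" here

out = io.BytesIO()
node.save(out)  # with is_html_like=True, <br/> is converted back to <br> on save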
epub_translator/xml_interrupter.py
ADDED
@@ -0,0 +1,165 @@
+from collections.abc import Generator, Iterable
+from typing import cast
+from xml.etree.ElementTree import Element
+
+from .segment import TextSegment
+from .utils import ensure_list, normalize_whitespace
+
+_ID_KEY = "__XML_INTERRUPTER_ID"
+_MATH_TAG = "math"
+_EXPRESSION_TAG = "expression"
+
+
+class XMLInterrupter:
+    def __init__(self) -> None:
+        self._next_id: int = 1
+        self._last_interrupted_id: str | None = None
+        self._placeholder2interrupted: dict[int, Element] = {}
+        self._raw_text_segments: dict[str, list[TextSegment]] = {}
+
+    def interrupt_source_text_segments(
+        self, text_segments: Iterable[TextSegment]
+    ) -> Generator[TextSegment, None, None]:
+        for text_segment in text_segments:
+            yield from self._expand_source_text_segment(text_segment)
+
+        if self._last_interrupted_id is not None:
+            merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+            if merged_text_segment:
+                yield merged_text_segment
+
+    def interrupt_translated_text_segments(
+        self, text_segments: Iterable[TextSegment]
+    ) -> Generator[TextSegment, None, None]:
+        for text_segment in text_segments:
+            yield from self._expand_translated_text_segment(text_segment)
+
+    def interrupt_block_element(self, element: Element) -> Element:
+        interrupted_element = self._placeholder2interrupted.pop(id(element), None)
+        if interrupted_element is None:
+            return element
+        else:
+            return interrupted_element
+
+    def _expand_source_text_segment(self, text_segment: TextSegment):
+        interrupted_index = self._interrupted_index(text_segment)
+        interrupted_id: str | None = None
+
+        if interrupted_index is not None:
+            interrupted_element = text_segment.parent_stack[interrupted_index]
+            interrupted_id = interrupted_element.get(_ID_KEY)
+            if interrupted_id is None:
+                interrupted_id = str(self._next_id)
+                interrupted_element.set(_ID_KEY, interrupted_id)
+                self._next_id += 1
+            text_segments = ensure_list(
+                target=self._raw_text_segments,
+                key=interrupted_id,
+            )
+            text_segments.append(text_segment)
+
+        if self._last_interrupted_id is not None and interrupted_id != self._last_interrupted_id:
+            merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+            if merged_text_segment:
+                yield merged_text_segment
+
+        self._last_interrupted_id = interrupted_id
+
+        if interrupted_index is None:
+            yield text_segment
+
+    def _pop_and_merge_from_buffered(self, interrupted_id: str) -> TextSegment | None:
+        merged_text_segment: TextSegment | None = None
+        text_segments = self._raw_text_segments.get(interrupted_id, None)
+        if text_segments:
+            text_segment = text_segments[0]
+            interrupted_index = cast(int, self._interrupted_index(text_segment))
+            interrupted_element = text_segment.parent_stack[cast(int, interrupted_index)]
+            placeholder_element = Element(
+                _EXPRESSION_TAG,
+                {
+                    _ID_KEY: cast(str, interrupted_element.get(_ID_KEY)),
+                },
+            )
+            raw_parent_stack = text_segment.parent_stack[:interrupted_index]
+            parent_stack = raw_parent_stack + [placeholder_element]
+            merged_text_segment = TextSegment(
+                text="".join(t.text for t in text_segments),
+                parent_stack=parent_stack,
+                left_common_depth=text_segments[0].left_common_depth,
+                right_common_depth=text_segments[-1].right_common_depth,
+                block_depth=len(parent_stack),
+                position=text_segments[0].position,
+            )
+            self._placeholder2interrupted[id(placeholder_element)] = interrupted_element
+
+            # TODO: This is tricky to get right, so it is disabled for now
+            # parent_element: Element | None = None
+            # if interrupted_index > 0:
+            #     parent_element = text_segment.parent_stack[interrupted_index - 1]
+
+            # if (
+            #     not self._is_inline_math(interrupted_element)
+            #     or parent_element is None
+            #     or self._has_no_math_texts(parent_element)
+            # ):
+            #     # Block-level formulas need not be repeated (repeating them reads abruptly), but inline formulas interleaved with the translation help the reader read smoothly.
+            #     self._raw_text_segments.pop(interrupted_id, None)
+            # else:
+            #     for text_segment in text_segments:
+            #         # Strip the original stack entirely, keeping only the stack relative to the interrupted element; this format matches what translated segments require
+            #         text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+            #         text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+            #         text_segment.block_depth = 1
+            #         text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+            for text_segment in text_segments:
+                # Strip the original stack entirely, keeping only the stack relative to the interrupted element; this format matches what translated segments require
+                text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+                text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+                text_segment.block_depth = 1
+                text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+
+        return merged_text_segment
+
+    def _interrupted_index(self, text_segment: TextSegment) -> int | None:
+        interrupted_index: int | None = None
+        for i, parent_element in enumerate(text_segment.parent_stack):
+            if parent_element.tag == _MATH_TAG:
+                interrupted_index = i
+                break
+        return interrupted_index
+
+    def _expand_translated_text_segment(self, text_segment: TextSegment):
+        interrupted_id = text_segment.block_parent.attrib.pop(_ID_KEY, None)
+        if interrupted_id is None:
+            yield text_segment
+            return
+
+        raw_text_segments = self._raw_text_segments.pop(interrupted_id, None)
+        if not raw_text_segments:
+            return
+
+        raw_block = raw_text_segments[0].parent_stack[0]
+        if not self._is_inline_math(raw_block):
+            return
+
+        for raw_text_segment in raw_text_segments:
+            raw_text_segment.block_parent.attrib.pop(_ID_KEY, None)
+            yield raw_text_segment
+
+    def _has_no_math_texts(self, element: Element):
+        if element.tag == _MATH_TAG:
+            return True
+        if element.text and normalize_whitespace(element.text).strip():
+            return False
+        for child_element in element:
+            if not self._has_no_math_texts(child_element):
+                return False
+            if child_element.tail and normalize_whitespace(child_element.tail).strip():
+                return False
+        return True
+
+    def _is_inline_math(self, element: Element) -> bool:
+        if element.tag != _MATH_TAG:
+            return False
+        return element.get("display", "").lower() != "block"
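XMLInterrupter intercepts source text segments whose parent stack contains a <math> element: those segments are buffered under a per-element id, replaced by a single merged segment whose block parent is an <expression> placeholder, and the placeholder is later swapped back for the original element via interrupt_block_element. Whether the buffered segments are also re-emitted into the translated stream hinges on the MathML display attribute, as _is_inline_math shows. A self-contained sketch of that rule, with hypothetical sample markup:

from xml.etree.ElementTree import Element, fromstring

def is_inline_math(element: Element) -> bool:
    # Mirrors XMLInterrupter._is_inline_math: a <math> element counts as inline unless display="block"
    if element.tag != "math":
        return False
    return element.get("display", "").lower() != "block"

inline = fromstring('<math><mi>x</mi></math>')                  # interleaved back into the translation
block = fromstring('<math display="block"><mi>y</mi></math>')   # kept only at its original position
assert is_inline_math(inline) and not is_inline_math(block)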
epub_translator/xml_translator/callbacks.py
ADDED
@@ -0,0 +1,34 @@
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+from xml.etree.ElementTree import Element
+
+from ..segment import TextSegment
+
+
+@dataclass
+class FillFailedEvent:
+    error_message: str
+    retried_count: int
+    over_maximum_retries: bool
+
+
+@dataclass
+class Callbacks:
+    interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+    interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+    interrupt_block_element: Callable[[Element], Element]
+    on_fill_failed: Callable[[FillFailedEvent], None]
+
+
+def warp_callbacks(
+    interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+    interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+    interrupt_block_element: Callable[[Element], Element] | None,
+    on_fill_failed: Callable[[FillFailedEvent], None] | None,
+) -> Callbacks:
+    return Callbacks(
+        interrupt_source_text_segments=interrupt_source_text_segments or (lambda x: x),
+        interrupt_translated_text_segments=interrupt_translated_text_segments or (lambda x: x),
+        interrupt_block_element=interrupt_block_element or (lambda x: x),
+        on_fill_failed=on_fill_failed or (lambda event: None),
+    )
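warp_callbacks simply substitutes identity and no-op defaults for any hook left as None, so callers only pass what they need. The field names line up with XMLInterrupter's methods, which suggests a wiring along the lines of the hypothetical sketch below; the actual glue code lives in the translator and is not shown in this diff:

from epub_translator.xml_interrupter import XMLInterrupter
from epub_translator.xml_translator.callbacks import warp_callbacks

interrupter = XMLInterrupter()
callbacks = warp_callbacks(
    interrupt_source_text_segments=interrupter.interrupt_source_text_segments,
    interrupt_translated_text_segments=interrupter.interrupt_translated_text_segments,
    interrupt_block_element=interrupter.interrupt_block_element,
    on_fill_failed=None,  # replaced by a no-op lambda
)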
epub_translator/xml_translator/hill_climbing.py
ADDED
@@ -0,0 +1,104 @@
+from collections.abc import Generator
+from dataclasses import dataclass
+from xml.etree.ElementTree import Element
+
+from tiktoken import Encoding
+
+from ..segment import BlockSegment, BlockSubmitter, TextSegment, search_text_segments
+from ..xml import plain_text
+from .common import DATA_ORIGIN_LEN_KEY
+from .stream_mapper import InlineSegmentMapping
+from .validation import LEVEL_DEPTH, generate_error_message, nest_as_errors_group, truncate_errors_group
+
+
+@dataclass
+class _BlockStatus:
+    weight: int
+    submitter: BlockSubmitter
+
+
+# Hill climbing: from the content the LLM submits, pick out the parts with the higher degree of completion.
+# By rejecting relatively low-completion submissions for each sub-part, it locks each sub-part so it can only move toward higher completion.
+class HillClimbing:
+    def __init__(
+        self,
+        encoding: Encoding,
+        max_fill_displaying_errors: int,
+        block_segment: BlockSegment,
+    ) -> None:
+        self._encoding: Encoding = encoding
+        self._max_fill_displaying_errors: int = max_fill_displaying_errors
+        self._block_statuses: dict[int, _BlockStatus] = {}
+        self._block_segment: BlockSegment = block_segment
+
+    def request_element(self) -> Element:
+        element = self._block_segment.create_element()
+        for child_element in element:
+            text = plain_text(child_element)
+            tokens = self._encoding.encode(text)
+            child_element.set(DATA_ORIGIN_LEN_KEY, str(len(tokens)))
+        return element
+
+    def gen_mappings(self) -> Generator[InlineSegmentMapping | None, None, None]:
+        for inline_segment in self._block_segment:
+            id = inline_segment.id
+            assert id is not None
+            status = self._block_statuses.get(id, None)
+            text_segments: list[TextSegment] | None = None
+            if status is None:
+                yield None
+            else:
+                submitted_element = status.submitter.submitted_element
+                text_segments = list(search_text_segments(submitted_element))
+                yield inline_segment.parent, text_segments
+
+    def submit(self, element: Element) -> str | None:
+        error_message, block_weights = self._validate_block_weights_and_error_message(element)
+
+        for submitter in self._block_segment.submit(element):
+            weight: int = 0 # Absent from block_weights means no errors: this block is already complete
+            if block_weights:
+                weight = block_weights.get(submitter.id, 0)
+            status = self._block_statuses.get(submitter.id, None)
+            if status is None:
+                self._block_statuses[submitter.id] = _BlockStatus(
+                    weight=weight,
+                    submitter=submitter,
+                )
+            elif weight < status.weight:
+                status.weight = weight
+                status.submitter = submitter
+
+        return error_message
+
+    def _validate_block_weights_and_error_message(self, element: Element) -> tuple[str | None, dict[int, int] | None]:
+        errors_group = nest_as_errors_group(
+            errors=self._block_segment.validate(element),
+        )
+        if errors_group is None:
+            return None, None
+
+        block_weights: dict[int, int] = {}
+        for block_group in errors_group.block_groups:
+            block_id = block_group.block_id
+            status = self._block_statuses.get(block_id, None)
+            block_weights[block_id] = block_group.weight
+            if status is not None and status.weight > block_group.weight:
+                # Completion improved this round (weight dropped), so rank these errors later and yield attention to the parts whose completion has not yet improved
+                for child_error in block_group.errors:
+                    child_error.level -= LEVEL_DEPTH
+
+        origin_errors_count = errors_group.errors_count
+        errors_group = truncate_errors_group(
+            errors_group=errors_group,
+            max_errors=self._max_fill_displaying_errors,
+        )
+        if errors_group is None:
+            return None, block_weights
+
+        message = generate_error_message(
+            encoding=self._encoding,
+            errors_group=errors_group,
+            omitted_count=origin_errors_count - errors_group.errors_count,
+        )
+        return message, block_weights