epub-translator 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
- epub_translator/__init__.py +2 -2
- epub_translator/data/fill.jinja +143 -38
- epub_translator/epub/__init__.py +1 -1
- epub_translator/epub/metadata.py +122 -0
- epub_translator/epub/spines.py +3 -2
- epub_translator/epub/zip.py +11 -9
- epub_translator/epub_transcode.py +108 -0
- epub_translator/llm/__init__.py +1 -0
- epub_translator/llm/context.py +109 -0
- epub_translator/llm/core.py +32 -113
- epub_translator/llm/executor.py +25 -31
- epub_translator/llm/increasable.py +1 -1
- epub_translator/llm/types.py +0 -3
- epub_translator/segment/__init__.py +26 -0
- epub_translator/segment/block_segment.py +124 -0
- epub_translator/segment/common.py +29 -0
- epub_translator/segment/inline_segment.py +356 -0
- epub_translator/{xml_translator → segment}/text_segment.py +8 -8
- epub_translator/segment/utils.py +43 -0
- epub_translator/translator.py +147 -183
- epub_translator/utils.py +33 -0
- epub_translator/xml/__init__.py +2 -0
- epub_translator/xml/const.py +1 -0
- epub_translator/xml/deduplication.py +3 -3
- epub_translator/xml/self_closing.py +182 -0
- epub_translator/xml/utils.py +42 -0
- epub_translator/xml/xml.py +7 -0
- epub_translator/xml/xml_like.py +8 -33
- epub_translator/xml_interrupter.py +165 -0
- epub_translator/xml_translator/__init__.py +1 -2
- epub_translator/xml_translator/callbacks.py +34 -0
- epub_translator/xml_translator/{const.py → common.py} +0 -1
- epub_translator/xml_translator/hill_climbing.py +104 -0
- epub_translator/xml_translator/stream_mapper.py +253 -0
- epub_translator/xml_translator/submitter.py +26 -72
- epub_translator/xml_translator/translator.py +162 -113
- epub_translator/xml_translator/validation.py +458 -0
- {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/METADATA +72 -9
- epub_translator-0.1.3.dist-info/RECORD +66 -0
- epub_translator/epub/placeholder.py +0 -53
- epub_translator/iter_sync.py +0 -24
- epub_translator/xml_translator/fill.py +0 -128
- epub_translator/xml_translator/format.py +0 -282
- epub_translator/xml_translator/fragmented.py +0 -125
- epub_translator/xml_translator/group.py +0 -183
- epub_translator/xml_translator/progressive_locking.py +0 -256
- epub_translator/xml_translator/utils.py +0 -29
- epub_translator-0.1.1.dist-info/RECORD +0 -58
- {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/LICENSE +0 -0
- {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/WHEEL +0 -0
epub_translator/xml/xml_like.py
CHANGED
@@ -4,6 +4,7 @@ import warnings
 from typing import IO
 from xml.etree.ElementTree import Element, fromstring, tostring
 
+from .self_closing import self_close_void_elements, unclose_void_elements
 from .xml import iter_with_stack
 
 _XML_NAMESPACE_URI = "http://www.w3.org/XML/1998/namespace"
@@ -31,28 +32,11 @@ _ENCODING_PATTERN = re.compile(r'encoding\s*=\s*["\']([^"\']+)["\']', re.IGNOREC
 _FIRST_ELEMENT_PATTERN = re.compile(r"<(?![?!])[a-zA-Z]")
 _NAMESPACE_IN_TAG = re.compile(r"\{([^}]+)\}")
 
-# Some non-standard EPUB generators use HTML-style tags without self-closing syntax
-# We need to convert them to XML-compatible format before parsing
-_EMPTY_TAGS = (
-    "br",
-    "hr",
-    "input",
-    "col",
-    "base",
-    "meta",
-    "area",
-)
-
-# For reading: match tags like <br> or <br class="x"> (but not <br/> or <body>)
-_EMPTY_TAG_OPEN_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^/>]*)>")
-
-# For saving: match self-closing tags like <br />
-_EMPTY_TAG_CLOSE_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^>]*?)\s*/>")
-
 
 class XMLLikeNode:
     def __init__(self, file: IO[bytes], is_html_like: bool = False) -> None:
         raw_content = file.read()
+        self._is_html_like = is_html_like
         self._encoding: str = self._detect_encoding(raw_content)
         content = raw_content.decode(self._encoding)
         self._header, xml_content = self._extract_header(content)
@@ -60,16 +44,9 @@ class XMLLikeNode:
         self._tag_to_namespace: dict[str, str] = {}
         self._attr_to_namespace: dict[str, str] = {}
 
-        # For non-standard HTML files, convert <br> to <br/> before parsing
-        self._is_html_like = is_html_like
-        if is_html_like:
-            xml_content = re.sub(
-                pattern=_EMPTY_TAG_OPEN_PATTERN,
-                repl=lambda m: f"<{m.group(1)}{m.group(2)} />",
-                string=xml_content,
-            )
-
         try:
+            # No need to check the file type; this is an extremely defensive function that can do shit >> XML
+            xml_content = self_close_void_elements(xml_content)
             self.element = self._extract_and_clean_namespaces(
                 element=fromstring(xml_content),
             )
@@ -92,13 +69,11 @@
 
         content = self._serialize_with_namespaces(self.element)
 
-        # For non-standard HTML files, convert back from <br/> to <br>
+        # For non-standard HTML files (text/html), convert back from <br/> to <br>
+        # to maintain compatibility with HTML parsers that don't support XHTML
+        # For XHTML files (application/xhtml+xml), keep self-closing format
        if self._is_html_like:
-            content = re.sub(
-                pattern=_EMPTY_TAG_CLOSE_PATTERN,
-                repl=lambda m: f"<{m.group(1)}{m.group(2)}>",
-                string=content,
-            )
+            content = unclose_void_elements(content)
 
         writer.write(content)
 
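The _EMPTY_TAGS machinery removed above now lives in the new xml/self_closing.py module (+182 lines, not expanded in this diff), which the read path now calls unconditionally. A rough sketch of the round trip that module performs: the tag list matches the removed 0.1.1 code, but the function bodies below are illustrative assumptions, not the module's actual implementation.

import re

# Void-tag list as in the removed 0.1.1 code; the real self_closing.py may cover more cases.
_VOID_TAGS = ("br", "hr", "input", "col", "base", "meta", "area")
_OPEN = re.compile(r"<(" + "|".join(_VOID_TAGS) + r")(\s[^/>]*)?>")
_CLOSED = re.compile(r"<(" + "|".join(_VOID_TAGS) + r")(\s[^>]*?)?\s*/>")

def self_close_void_elements_sketch(markup: str) -> str:
    # Reading: rewrite HTML-style <br> / <br class="x"> as XML-parsable <br />
    return _OPEN.sub(lambda m: f"<{m.group(1)}{m.group(2) or ''} />", markup)

def unclose_void_elements_sketch(markup: str) -> str:
    # Saving: rewrite <br /> back to <br> for HTML parsers that reject XHTML syntax
    return _CLOSED.sub(lambda m: f"<{m.group(1)}{m.group(2) or ''}>", markup)

Centralizing the conversion is what lets __init__ drop the is_html_like branch on the read path while save() keeps it for the write path.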
epub_translator/xml_interrupter.py
ADDED
@@ -0,0 +1,165 @@
+from collections.abc import Generator, Iterable
+from typing import cast
+from xml.etree.ElementTree import Element
+
+from .segment import TextSegment
+from .utils import ensure_list, normalize_whitespace
+
+_ID_KEY = "__XML_INTERRUPTER_ID"
+_MATH_TAG = "math"
+_EXPRESSION_TAG = "expression"
+
+
+class XMLInterrupter:
+    def __init__(self) -> None:
+        self._next_id: int = 1
+        self._last_interrupted_id: str | None = None
+        self._placeholder2interrupted: dict[int, Element] = {}
+        self._raw_text_segments: dict[str, list[TextSegment]] = {}
+
+    def interrupt_source_text_segments(
+        self, text_segments: Iterable[TextSegment]
+    ) -> Generator[TextSegment, None, None]:
+        for text_segment in text_segments:
+            yield from self._expand_source_text_segment(text_segment)
+
+        if self._last_interrupted_id is not None:
+            merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+            if merged_text_segment:
+                yield merged_text_segment
+
+    def interrupt_translated_text_segments(
+        self, text_segments: Iterable[TextSegment]
+    ) -> Generator[TextSegment, None, None]:
+        for text_segment in text_segments:
+            yield from self._expand_translated_text_segment(text_segment)
+
+    def interrupt_block_element(self, element: Element) -> Element:
+        interrupted_element = self._placeholder2interrupted.pop(id(element), None)
+        if interrupted_element is None:
+            return element
+        else:
+            return interrupted_element
+
+    def _expand_source_text_segment(self, text_segment: TextSegment):
+        interrupted_index = self._interrupted_index(text_segment)
+        interrupted_id: str | None = None
+
+        if interrupted_index is not None:
+            interrupted_element = text_segment.parent_stack[interrupted_index]
+            interrupted_id = interrupted_element.get(_ID_KEY)
+            if interrupted_id is None:
+                interrupted_id = str(self._next_id)
+                interrupted_element.set(_ID_KEY, interrupted_id)
+                self._next_id += 1
+            text_segments = ensure_list(
+                target=self._raw_text_segments,
+                key=interrupted_id,
+            )
+            text_segments.append(text_segment)
+
+        if self._last_interrupted_id is not None and interrupted_id != self._last_interrupted_id:
+            merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+            if merged_text_segment:
+                yield merged_text_segment
+
+        self._last_interrupted_id = interrupted_id
+
+        if interrupted_index is None:
+            yield text_segment
+
+    def _pop_and_merge_from_buffered(self, interrupted_id: str) -> TextSegment | None:
+        merged_text_segment: TextSegment | None = None
+        text_segments = self._raw_text_segments.get(interrupted_id, None)
+        if text_segments:
+            text_segment = text_segments[0]
+            interrupted_index = cast(int, self._interrupted_index(text_segment))
+            interrupted_element = text_segment.parent_stack[cast(int, interrupted_index)]
+            placeholder_element = Element(
+                _EXPRESSION_TAG,
+                {
+                    _ID_KEY: cast(str, interrupted_element.get(_ID_KEY)),
+                },
+            )
+            raw_parent_stack = text_segment.parent_stack[:interrupted_index]
+            parent_stack = raw_parent_stack + [placeholder_element]
+            merged_text_segment = TextSegment(
+                text="".join(t.text for t in text_segments),
+                parent_stack=parent_stack,
+                left_common_depth=text_segments[0].left_common_depth,
+                right_common_depth=text_segments[-1].right_common_depth,
+                block_depth=len(parent_stack),
+                position=text_segments[0].position,
+            )
+            self._placeholder2interrupted[id(placeholder_element)] = interrupted_element
+
+            # TODO: This is tricky to get right; disabled for now
+            # parent_element: Element | None = None
+            # if interrupted_index > 0:
+            #     parent_element = text_segment.parent_stack[interrupted_index - 1]
+
+            # if (
+            #     not self._is_inline_math(interrupted_element)
+            #     or parent_element is None
+            #     or self._has_no_math_texts(parent_element)
+            # ):
+            #     # Block-level formulas need not be repeated (repeating them is jarring), but inline formulas interleaved in the translated text read more smoothly.
+            #     self._raw_text_segments.pop(interrupted_id, None)
+            # else:
+            #     for text_segment in text_segments:
+            #         # Strip the original stack, keeping only the part relative to the interrupted element; this matches the format the translated side expects
+            #         text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+            #         text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+            #         text_segment.block_depth = 1
+            #         text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+            for text_segment in text_segments:
+                # Strip the original stack, keeping only the part relative to the interrupted element; this matches the format the translated side expects
+                text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+                text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+                text_segment.block_depth = 1
+                text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+
+        return merged_text_segment
+
+    def _interrupted_index(self, text_segment: TextSegment) -> int | None:
+        interrupted_index: int | None = None
+        for i, parent_element in enumerate(text_segment.parent_stack):
+            if parent_element.tag == _MATH_TAG:
+                interrupted_index = i
+                break
+        return interrupted_index
+
+    def _expand_translated_text_segment(self, text_segment: TextSegment):
+        interrupted_id = text_segment.block_parent.attrib.pop(_ID_KEY, None)
+        if interrupted_id is None:
+            yield text_segment
+            return
+
+        raw_text_segments = self._raw_text_segments.pop(interrupted_id, None)
+        if not raw_text_segments:
+            return
+
+        raw_block = raw_text_segments[0].parent_stack[0]
+        if not self._is_inline_math(raw_block):
+            return
+
+        for raw_text_segment in raw_text_segments:
+            raw_text_segment.block_parent.attrib.pop(_ID_KEY, None)
+            yield raw_text_segment
+
+    def _has_no_math_texts(self, element: Element):
+        if element.tag == _MATH_TAG:
+            return True
+        if element.text and normalize_whitespace(element.text).strip():
+            return False
+        for child_element in element:
+            if not self._has_no_math_texts(child_element):
+                return False
+            if child_element.tail and normalize_whitespace(child_element.tail).strip():
+                return False
+        return True
+
+    def _is_inline_math(self, element: Element) -> bool:
+        if element.tag != _MATH_TAG:
+            return False
+        return element.get("display", "").lower() != "block"
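XMLInterrupter pulls <math> subtrees out of the text-segment stream before they reach the model: on the source side it buffers every segment under a <math> ancestor and emits a single merged segment whose parent stack ends in an <expression> placeholder; on the translated side it restores the buffered originals (inline formulas only), and interrupt_block_element swaps each placeholder back. Its three hooks line up with the Callbacks dataclass added below. A hypothetical wiring; whether translator.py connects them exactly this way is not visible in this diff:

from epub_translator.xml_interrupter import XMLInterrupter
from epub_translator.xml_translator.callbacks import warp_callbacks

interrupter = XMLInterrupter()
callbacks = warp_callbacks(
    interrupt_source_text_segments=interrupter.interrupt_source_text_segments,
    interrupt_translated_text_segments=interrupter.interrupt_translated_text_segments,
    interrupt_block_element=interrupter.interrupt_block_element,
    on_fill_failed=None,  # warp_callbacks substitutes a no-op
)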
epub_translator/xml_translator/callbacks.py
ADDED
@@ -0,0 +1,34 @@
+from collections.abc import Callable, Iterable
+from dataclasses import dataclass
+from xml.etree.ElementTree import Element
+
+from ..segment import TextSegment
+
+
+@dataclass
+class FillFailedEvent:
+    error_message: str
+    retried_count: int
+    over_maximum_retries: bool
+
+
+@dataclass
+class Callbacks:
+    interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+    interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+    interrupt_block_element: Callable[[Element], Element]
+    on_fill_failed: Callable[[FillFailedEvent], None]
+
+
+def warp_callbacks(
+    interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+    interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+    interrupt_block_element: Callable[[Element], Element] | None,
+    on_fill_failed: Callable[[FillFailedEvent], None] | None,
+) -> Callbacks:
+    return Callbacks(
+        interrupt_source_text_segments=interrupt_source_text_segments or (lambda x: x),
+        interrupt_translated_text_segments=interrupt_translated_text_segments or (lambda x: x),
+        interrupt_block_element=interrupt_block_element or (lambda x: x),
+        on_fill_failed=on_fill_failed or (lambda event: None),
+    )
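warp_callbacks normalizes the optional hooks: any hook passed as None becomes an identity function, and a missing on_fill_failed becomes a no-op, so downstream code can call every field of Callbacks unconditionally. For example, attaching only a failure logger:

from epub_translator.xml_translator.callbacks import FillFailedEvent, warp_callbacks

def log_failure(event: FillFailedEvent) -> None:
    print(f"fill failed after {event.retried_count} retries: {event.error_message}")

callbacks = warp_callbacks(
    interrupt_source_text_segments=None,
    interrupt_translated_text_segments=None,
    interrupt_block_element=None,
    on_fill_failed=log_failure,
)
assert list(callbacks.interrupt_source_text_segments([])) == []  # identity default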
epub_translator/xml_translator/hill_climbing.py
ADDED
@@ -0,0 +1,104 @@
+from collections.abc import Generator
+from dataclasses import dataclass
+from xml.etree.ElementTree import Element
+
+from tiktoken import Encoding
+
+from ..segment import BlockSegment, BlockSubmitter, TextSegment, search_text_segments
+from ..xml import plain_text
+from .common import DATA_ORIGIN_LEN_KEY
+from .stream_mapper import InlineSegmentMapping
+from .validation import LEVEL_DEPTH, generate_error_message, nest_as_errors_group, truncate_errors_group
+
+
+@dataclass
+class _BlockStatus:
+    weight: int
+    submitter: BlockSubmitter
+
+
+# Hill climbing: pick out the parts with higher completion from what the LLM submits.
+# By rejecting any submission whose completion for a sub-part is relatively lower, it locks each sub-part so it can only move toward higher completion.
+class HillClimbing:
+    def __init__(
+        self,
+        encoding: Encoding,
+        max_fill_displaying_errors: int,
+        block_segment: BlockSegment,
+    ) -> None:
+        self._encoding: Encoding = encoding
+        self._max_fill_displaying_errors: int = max_fill_displaying_errors
+        self._block_statuses: dict[int, _BlockStatus] = {}
+        self._block_segment: BlockSegment = block_segment
+
+    def request_element(self) -> Element:
+        element = self._block_segment.create_element()
+        for child_element in element:
+            text = plain_text(child_element)
+            tokens = self._encoding.encode(text)
+            child_element.set(DATA_ORIGIN_LEN_KEY, str(len(tokens)))
+        return element
+
+    def gen_mappings(self) -> Generator[InlineSegmentMapping | None, None, None]:
+        for inline_segment in self._block_segment:
+            id = inline_segment.id
+            assert id is not None
+            status = self._block_statuses.get(id, None)
+            text_segments: list[TextSegment] | None = None
+            if status is None:
+                yield None
+            else:
+                submitted_element = status.submitter.submitted_element
+                text_segments = list(search_text_segments(submitted_element))
+                yield inline_segment.parent, text_segments
+
+    def submit(self, element: Element) -> str | None:
+        error_message, block_weights = self._validate_block_weights_and_error_message(element)
+
+        for submitter in self._block_segment.submit(element):
+            weight: int = 0  # absent from block_weights means no errors, i.e. completed
+            if block_weights:
+                weight = block_weights.get(submitter.id, 0)
+            status = self._block_statuses.get(submitter.id, None)
+            if status is None:
+                self._block_statuses[submitter.id] = _BlockStatus(
+                    weight=weight,
+                    submitter=submitter,
+                )
+            elif weight < status.weight:
+                status.weight = weight
+                status.submitter = submitter
+
+        return error_message
+
+    def _validate_block_weights_and_error_message(self, element: Element) -> tuple[str | None, dict[int, int] | None]:
+        errors_group = nest_as_errors_group(
+            errors=self._block_segment.validate(element),
+        )
+        if errors_group is None:
+            return None, None
+
+        block_weights: dict[int, int] = {}
+        for block_group in errors_group.block_groups:
+            block_id = block_group.block_id
+            status = self._block_statuses.get(block_id, None)
+            block_weights[block_id] = block_group.weight
+            if status is not None and status.weight > block_group.weight:
+                # This round improved completion (weight dropped), so rank its errors later, yielding attention to parts not yet improved
+                for child_error in block_group.errors:
+                    child_error.level -= LEVEL_DEPTH
+
+        origin_errors_count = errors_group.errors_count
+        errors_group = truncate_errors_group(
+            errors_group=errors_group,
+            max_errors=self._max_fill_displaying_errors,
+        )
+        if errors_group is None:
+            return None, block_weights
+
+        message = generate_error_message(
+            encoding=self._encoding,
+            errors_group=errors_group,
+            omitted_count=origin_errors_count - errors_group.errors_count,
+        )
+        return message, block_weights
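The acceptance rule in submit is a plain hill-climbing step: each block keeps its best submission so far, and a new submission replaces it only when its validation weight strictly drops (weight 0 meaning no errors). Reduced to plain dictionaries, with Submission as a stand-in for the package's BlockSubmitter rather than a real class, the core move looks like this:

from dataclasses import dataclass

@dataclass
class Submission:  # stand-in for BlockSubmitter
    block_id: int
    weight: int    # validation-error weight; 0 means the block is complete
    payload: str

def accept(best: dict[int, Submission], candidate: Submission) -> bool:
    current = best.get(candidate.block_id)
    if current is None or candidate.weight < current.weight:
        best[candidate.block_id] = candidate  # climb only toward lower weight
        return True
    return False  # reject anything that would regress a block's completion

best: dict[int, Submission] = {}
assert accept(best, Submission(1, weight=5, payload="first draft"))
assert not accept(best, Submission(1, weight=7, payload="worse"))
assert accept(best, Submission(1, weight=2, payload="better"))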
epub_translator/xml_translator/stream_mapper.py
ADDED
@@ -0,0 +1,253 @@
+from collections.abc import Callable, Generator, Iterable, Iterator
+from xml.etree.ElementTree import Element
+
+from resource_segmentation import Group, Resource, Segment, split
+from tiktoken import Encoding
+
+from ..segment import InlineSegment, TextSegment, search_inline_segments, search_text_segments
+from .callbacks import Callbacks
+
+_PAGE_INCISION = 0
+_BLOCK_INCISION = 1
+
+_ELLIPSIS = "..."
+
+
+InlineSegmentMapping = tuple[Element, list[TextSegment]]
+InlineSegmentGroupMap = Callable[[list[InlineSegment]], list[InlineSegmentMapping | None]]
+
+
+class XMLStreamMapper:
+    def __init__(self, encoding: Encoding, max_group_tokens: int) -> None:
+        self._encoding: Encoding = encoding
+        self._max_group_tokens: int = max_group_tokens
+
+    def map_stream(
+        self,
+        elements: Iterator[Element],
+        callbacks: Callbacks,
+        map: InlineSegmentGroupMap,
+    ) -> Generator[tuple[Element, list[InlineSegmentMapping]], None, None]:
+        current_element: Element | None = None
+        mapping_buffer: list[InlineSegmentMapping] = []
+
+        for group in self._split_into_serial_groups(elements, callbacks):
+            head, body, tail = self._truncate_and_transform_group(group)
+            target_body = map(head + body + tail)[len(head) : len(head) + len(body)]
+            for origin, target in zip(body, target_body, strict=False):
+                origin_element = origin.head.root
+                if current_element is None:
+                    current_element = origin_element
+
+                if id(current_element) != id(origin_element):
+                    yield current_element, mapping_buffer
+                    current_element = origin_element
+                    mapping_buffer = []
+
+                if target:
+                    block_element, text_segments = target
+                    block_element = callbacks.interrupt_block_element(block_element)
+                    text_segments = list(callbacks.interrupt_translated_text_segments(text_segments))
+                    if text_segments:
+                        mapping_buffer.append((block_element, text_segments))
+
+        if current_element is not None:
+            yield current_element, mapping_buffer
+
+    def _split_into_serial_groups(self, elements: Iterable[Element], callbacks: Callbacks):
+        def generate():
+            for element in elements:
+                yield from split(
+                    max_segment_count=self._max_group_tokens,
+                    border_incision=_PAGE_INCISION,
+                    resources=self._expand_to_resources(element, callbacks),
+                )
+
+        generator = generate()
+        group = next(generator, None)
+        if group is None:
+            return
+
+        # head + body * N (without tail)
+        sum_count = group.head_remain_count + sum(x.count for x in self._expand_resource_segments(group.body))
+
+        while True:
+            next_group = next(generator, None)
+            if next_group is None:
+                break
+
+            next_sum_body_count = sum(x.count for x in self._expand_resource_segments(next_group.body))
+            next_sum_count = sum_count + next_sum_body_count
+
+            if next_sum_count + next_group.tail_remain_count > self._max_group_tokens:
+                yield group
+                group = next_group
+                sum_count = group.head_remain_count + next_sum_body_count
+            else:
+                group.body.extend(next_group.body)
+                group.tail = next_group.tail
+                group.tail_remain_count = next_group.tail_remain_count
+                sum_count = next_sum_count
+
+        yield group
+
+    def _truncate_and_transform_group(self, group: Group[InlineSegment]):
+        head = list(
+            self._truncate_inline_segments(
+                inline_segments=self._expand_inline_segments(group.head),
+                remain_head=False,
+                remain_count=group.head_remain_count,
+            )
+        )
+        body = list(self._expand_inline_segments(group.body))
+        tail = list(
+            self._truncate_inline_segments(
+                inline_segments=self._expand_inline_segments(group.tail),
+                remain_head=True,
+                remain_count=group.tail_remain_count,
+            )
+        )
+        return head, body, tail
+
+    def _expand_to_resources(self, element: Element, callbacks: Callbacks):
+        def expand(element: Element):
+            text_segments = search_text_segments(element)
+            text_segments = callbacks.interrupt_source_text_segments(text_segments)
+            yield from search_inline_segments(text_segments)
+
+        inline_segment_generator = expand(element)
+        start_incision = _PAGE_INCISION
+        inline_segment = next(inline_segment_generator, None)
+        if inline_segment is None:
+            return
+
+        while True:
+            next_inline_segment = next(inline_segment_generator, None)
+            if next_inline_segment is None:
+                break
+
+            if next_inline_segment.head.root is inline_segment.tail.root:
+                end_incision = _BLOCK_INCISION
+            else:
+                end_incision = _PAGE_INCISION
+
+            yield Resource(
+                count=sum(len(self._encoding.encode(t.xml_text)) for t in inline_segment),
+                start_incision=start_incision,
+                end_incision=end_incision,
+                payload=inline_segment,
+            )
+            inline_segment = next_inline_segment
+            start_incision = end_incision
+
+        yield Resource(
+            count=sum(len(self._encoding.encode(t.xml_text)) for t in inline_segment),
+            start_incision=start_incision,
+            end_incision=_PAGE_INCISION,
+            payload=inline_segment,
+        )
+
+    def _truncate_inline_segments(self, inline_segments: Iterable[InlineSegment], remain_head: bool, remain_count: int):
+        def clone_and_expand(segments: Iterable[InlineSegment]):
+            for segment in segments:
+                for child_segment in segment:
+                    yield child_segment.clone()  # the head/tail produced by splitting overlap with other groups; clone to avoid mutual interference
+
+        truncated_text_segments = self._truncate_text_segments(
+            text_segments=clone_and_expand(inline_segments),
+            remain_head=remain_head,
+            remain_count=remain_count,
+        )
+        yield from search_inline_segments(truncated_text_segments)
+
+    def _expand_inline_segments(self, items: list[Resource[InlineSegment] | Segment[InlineSegment]]):
+        for resource in self._expand_resource_segments(items):
+            yield resource.payload
+
+    def _expand_resource_segments(self, items: list[Resource[InlineSegment] | Segment[InlineSegment]]):
+        for item in items:
+            if isinstance(item, Resource):
+                yield item
+            elif isinstance(item, Segment):
+                yield from item.resources
+
+    def _truncate_text_segments(self, text_segments: Iterable[TextSegment], remain_head: bool, remain_count: int):
+        if remain_head:
+            yield from self._filter_and_remain_segments(
+                segments=text_segments,
+                remain_head=remain_head,
+                remain_count=remain_count,
+            )
+        else:
+            yield from reversed(
+                list(
+                    self._filter_and_remain_segments(
+                        segments=reversed(list(text_segments)),
+                        remain_head=remain_head,
+                        remain_count=remain_count,
+                    )
+                )
+            )
+
+    def _filter_and_remain_segments(self, segments: Iterable[TextSegment], remain_head: bool, remain_count: int):
+        for segment in segments:
+            if remain_count <= 0:
+                break
+            raw_xml_text = segment.xml_text
+            tokens = self._encoding.encode(raw_xml_text)
+            tokens_count = len(tokens)
+
+            if tokens_count > remain_count:
+                truncated_segment = self._truncate_text_segment(
+                    segment=segment,
+                    tokens=tokens,
+                    raw_xml_text=raw_xml_text,
+                    remain_head=remain_head,
+                    remain_count=remain_count,
+                )
+                if truncated_segment is not None:
+                    yield truncated_segment
+                break
+
+            yield segment
+            remain_count -= tokens_count
+
+    def _truncate_text_segment(
+        self,
+        segment: TextSegment,
+        tokens: list[int],
+        raw_xml_text: str,
+        remain_head: bool,
+        remain_count: int,
+    ) -> TextSegment | None:
+        # Typical xml_text: <tag id="99" data-origin-len="999">Some text</tag>
+        # If the cut point falls in the leading XML region, discard the whole segment
+        # If the cut point falls in the trailing XML region, keep the whole segment
+        # Only when the cut lands exactly in the text body do we actually truncate the text
+        remain_text: str
+        xml_text_head_length = raw_xml_text.find(segment.text)
+
+        if remain_head:
+            remain_xml_text = self._encoding.decode(tokens[:remain_count])  # remain_count cannot be 0 here
+            if len(remain_xml_text) <= xml_text_head_length:
+                return None
+            if len(remain_xml_text) >= xml_text_head_length + len(segment.text):
+                return segment
+            remain_text = remain_xml_text[xml_text_head_length:]
+        else:
+            xml_text_tail_length = len(raw_xml_text) - (xml_text_head_length + len(segment.text))
+            remain_xml_text = self._encoding.decode(tokens[-remain_count:])
+            if len(remain_xml_text) <= xml_text_tail_length:
+                return None
+            if len(remain_xml_text) >= xml_text_tail_length + len(segment.text):
+                return segment
+            remain_text = remain_xml_text[: len(remain_xml_text) - xml_text_tail_length]
+
+        if not remain_text.strip():
+            return None
+
+        if remain_head:
+            segment.text = f"{remain_text} {_ELLIPSIS}"
+        else:
+            segment.text = f"{_ELLIPSIS} {remain_text}"
+        return segment
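The cut-point rule in _truncate_text_segment is easiest to see on a plain string: decode the surviving tokens, then compare the surviving length against the XML prefix and suffix that wrap the body text. A standalone sketch of the head-truncation case, substituting character counts for tiktoken tokens:

def truncate_head(xml_text: str, body: str, keep: int) -> str | None:
    # e.g. xml_text = '<tag id="9">Some text</tag>', body = 'Some text'
    head_len = xml_text.find(body)   # length of the leading XML region
    kept = xml_text[:keep]           # stand-in for decoding tokens[:remain_count]
    if len(kept) <= head_len:
        return None                  # cut fell in the XML prefix: drop the segment
    if len(kept) >= head_len + len(body):
        return body                  # cut fell in the XML suffix: keep it whole
    return kept[head_len:] + " ..."  # cut fell in the body: truncate and mark

assert truncate_head('<tag id="9">Some text</tag>', 'Some text', 5) is None
assert truncate_head('<tag id="9">Some text</tag>', 'Some text', 16) == 'Some ...'
assert truncate_head('<tag id="9">Some text</tag>', 'Some text', 28) == 'Some text'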