epub-translator 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. epub_translator/__init__.py +2 -2
  2. epub_translator/data/fill.jinja +143 -38
  3. epub_translator/epub/__init__.py +1 -1
  4. epub_translator/epub/metadata.py +122 -0
  5. epub_translator/epub/spines.py +3 -2
  6. epub_translator/epub/zip.py +11 -9
  7. epub_translator/epub_transcode.py +108 -0
  8. epub_translator/llm/__init__.py +1 -0
  9. epub_translator/llm/context.py +109 -0
  10. epub_translator/llm/core.py +32 -113
  11. epub_translator/llm/executor.py +25 -31
  12. epub_translator/llm/increasable.py +1 -1
  13. epub_translator/llm/types.py +0 -3
  14. epub_translator/segment/__init__.py +26 -0
  15. epub_translator/segment/block_segment.py +124 -0
  16. epub_translator/segment/common.py +29 -0
  17. epub_translator/segment/inline_segment.py +356 -0
  18. epub_translator/{xml_translator → segment}/text_segment.py +8 -8
  19. epub_translator/segment/utils.py +43 -0
  20. epub_translator/translator.py +147 -183
  21. epub_translator/utils.py +33 -0
  22. epub_translator/xml/__init__.py +2 -0
  23. epub_translator/xml/const.py +1 -0
  24. epub_translator/xml/deduplication.py +3 -3
  25. epub_translator/xml/self_closing.py +182 -0
  26. epub_translator/xml/utils.py +42 -0
  27. epub_translator/xml/xml.py +7 -0
  28. epub_translator/xml/xml_like.py +8 -33
  29. epub_translator/xml_interrupter.py +165 -0
  30. epub_translator/xml_translator/__init__.py +1 -2
  31. epub_translator/xml_translator/callbacks.py +34 -0
  32. epub_translator/xml_translator/{const.py → common.py} +0 -1
  33. epub_translator/xml_translator/hill_climbing.py +104 -0
  34. epub_translator/xml_translator/stream_mapper.py +253 -0
  35. epub_translator/xml_translator/submitter.py +26 -72
  36. epub_translator/xml_translator/translator.py +162 -113
  37. epub_translator/xml_translator/validation.py +458 -0
  38. {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/METADATA +72 -9
  39. epub_translator-0.1.3.dist-info/RECORD +66 -0
  40. epub_translator/epub/placeholder.py +0 -53
  41. epub_translator/iter_sync.py +0 -24
  42. epub_translator/xml_translator/fill.py +0 -128
  43. epub_translator/xml_translator/format.py +0 -282
  44. epub_translator/xml_translator/fragmented.py +0 -125
  45. epub_translator/xml_translator/group.py +0 -183
  46. epub_translator/xml_translator/progressive_locking.py +0 -256
  47. epub_translator/xml_translator/utils.py +0 -29
  48. epub_translator-0.1.1.dist-info/RECORD +0 -58
  49. {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/LICENSE +0 -0
  50. {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/WHEEL +0 -0
@@ -4,6 +4,7 @@ import warnings
  from typing import IO
  from xml.etree.ElementTree import Element, fromstring, tostring

+ from .self_closing import self_close_void_elements, unclose_void_elements
  from .xml import iter_with_stack

  _XML_NAMESPACE_URI = "http://www.w3.org/XML/1998/namespace"
@@ -31,28 +32,11 @@ _ENCODING_PATTERN = re.compile(r'encoding\s*=\s*["\']([^"\']+)["\']', re.IGNOREC
  _FIRST_ELEMENT_PATTERN = re.compile(r"<(?![?!])[a-zA-Z]")
  _NAMESPACE_IN_TAG = re.compile(r"\{([^}]+)\}")

- # Some non-standard EPUB generators use HTML-style tags without self-closing syntax
- # We need to convert them to XML-compatible format before parsing
- _EMPTY_TAGS = (
-     "br",
-     "hr",
-     "input",
-     "col",
-     "base",
-     "meta",
-     "area",
- )
-
- # For reading: match tags like <br> or <br class="x"> (but not <br/> or <body>)
- _EMPTY_TAG_OPEN_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^/>]*)>")
-
- # For saving: match self-closing tags like <br />
- _EMPTY_TAG_CLOSE_PATTERN = re.compile(r"<(" + "|".join(_EMPTY_TAGS) + r")(\s[^>]*?)\s*/>")
-

  class XMLLikeNode:
      def __init__(self, file: IO[bytes], is_html_like: bool = False) -> None:
          raw_content = file.read()
+         self._is_html_like = is_html_like
          self._encoding: str = self._detect_encoding(raw_content)
          content = raw_content.decode(self._encoding)
          self._header, xml_content = self._extract_header(content)
@@ -60,16 +44,9 @@ class XMLLikeNode:
          self._tag_to_namespace: dict[str, str] = {}
          self._attr_to_namespace: dict[str, str] = {}

-         # For non-standard HTML files, convert <br> to <br/> before parsing
-         self._is_html_like = is_html_like
-         if is_html_like:
-             xml_content = re.sub(
-                 pattern=_EMPTY_TAG_OPEN_PATTERN,
-                 repl=lambda m: f"<{m.group(1)}{m.group(2)} />",
-                 string=xml_content,
-             )
-
          try:
+             # No type check needed: this function is extremely defensive and can do shit >> XML.
+             xml_content = self_close_void_elements(xml_content)
              self.element = self._extract_and_clean_namespaces(
                  element=fromstring(xml_content),
              )
@@ -92,13 +69,11 @@ class XMLLikeNode:

          content = self._serialize_with_namespaces(self.element)

-         # For non-standard HTML files, convert back from <br/> to <br>
+         # For non-standard HTML files (text/html), convert back from <br/> to <br>
+         # to maintain compatibility with HTML parsers that don't support XHTML.
+         # For XHTML files (application/xhtml+xml), keep the self-closing format.
          if self._is_html_like:
-             content = re.sub(
-                 pattern=_EMPTY_TAG_CLOSE_PATTERN,
-                 repl=lambda m: f"<{m.group(1)}{m.group(2)}>",
-                 string=content,
-             )
+             content = unclose_void_elements(content)

          writer.write(content)
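Note: the self_close_void_elements / unclose_void_elements pair replaces the inline regex handling removed above, and it now runs unconditionally on parse. A minimal sketch of what such helpers could look like is below; the tag list and regexes are illustrative, not the actual epub_translator/xml/self_closing.py implementation (which is 182 lines and presumably more defensive):

    import re

    # HTML void elements that sloppy EPUB generators often leave unclosed.
    _VOID_TAGS = ("br", "hr", "input", "col", "base", "meta", "area")

    _OPEN_PATTERN = re.compile(r"<(" + "|".join(_VOID_TAGS) + r")((?:\s[^/>]*)?)>")
    _CLOSED_PATTERN = re.compile(r"<(" + "|".join(_VOID_TAGS) + r")((?:\s[^>]*?)?)\s*/>")

    def self_close_void_elements(content: str) -> str:
        # <br> or <br class="x">  ->  <br /> or <br class="x" />
        return _OPEN_PATTERN.sub(lambda m: f"<{m.group(1)}{m.group(2)} />", content)

    def unclose_void_elements(content: str) -> str:
        # <br /> or <br class="x" />  ->  <br> or <br class="x">
        return _CLOSED_PATTERN.sub(lambda m: f"<{m.group(1)}{m.group(2)}>", content)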
@@ -0,0 +1,165 @@
+ from collections.abc import Generator, Iterable
+ from typing import cast
+ from xml.etree.ElementTree import Element
+
+ from .segment import TextSegment
+ from .utils import ensure_list, normalize_whitespace
+
+ _ID_KEY = "__XML_INTERRUPTER_ID"
+ _MATH_TAG = "math"
+ _EXPRESSION_TAG = "expression"
+
+
+ class XMLInterrupter:
+     def __init__(self) -> None:
+         self._next_id: int = 1
+         self._last_interrupted_id: str | None = None
+         self._placeholder2interrupted: dict[int, Element] = {}
+         self._raw_text_segments: dict[str, list[TextSegment]] = {}
+
+     def interrupt_source_text_segments(
+         self, text_segments: Iterable[TextSegment]
+     ) -> Generator[TextSegment, None, None]:
+         for text_segment in text_segments:
+             yield from self._expand_source_text_segment(text_segment)
+
+         if self._last_interrupted_id is not None:
+             merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+             if merged_text_segment:
+                 yield merged_text_segment
+
+     def interrupt_translated_text_segments(
+         self, text_segments: Iterable[TextSegment]
+     ) -> Generator[TextSegment, None, None]:
+         for text_segment in text_segments:
+             yield from self._expand_translated_text_segment(text_segment)
+
+     def interrupt_block_element(self, element: Element) -> Element:
+         interrupted_element = self._placeholder2interrupted.pop(id(element), None)
+         if interrupted_element is None:
+             return element
+         else:
+             return interrupted_element
+
+     def _expand_source_text_segment(self, text_segment: TextSegment):
+         interrupted_index = self._interrupted_index(text_segment)
+         interrupted_id: str | None = None
+
+         if interrupted_index is not None:
+             interrupted_element = text_segment.parent_stack[interrupted_index]
+             interrupted_id = interrupted_element.get(_ID_KEY)
+             if interrupted_id is None:
+                 interrupted_id = str(self._next_id)
+                 interrupted_element.set(_ID_KEY, interrupted_id)
+                 self._next_id += 1
+             text_segments = ensure_list(
+                 target=self._raw_text_segments,
+                 key=interrupted_id,
+             )
+             text_segments.append(text_segment)
+
+         if self._last_interrupted_id is not None and interrupted_id != self._last_interrupted_id:
+             merged_text_segment = self._pop_and_merge_from_buffered(self._last_interrupted_id)
+             if merged_text_segment:
+                 yield merged_text_segment
+
+         self._last_interrupted_id = interrupted_id
+
+         if interrupted_index is None:
+             yield text_segment
+
+     def _pop_and_merge_from_buffered(self, interrupted_id: str) -> TextSegment | None:
+         merged_text_segment: TextSegment | None = None
+         text_segments = self._raw_text_segments.get(interrupted_id, None)
+         if text_segments:
+             text_segment = text_segments[0]
+             interrupted_index = cast(int, self._interrupted_index(text_segment))
+             interrupted_element = text_segment.parent_stack[cast(int, interrupted_index)]
+             placeholder_element = Element(
+                 _EXPRESSION_TAG,
+                 {
+                     _ID_KEY: cast(str, interrupted_element.get(_ID_KEY)),
+                 },
+             )
+             raw_parent_stack = text_segment.parent_stack[:interrupted_index]
+             parent_stack = raw_parent_stack + [placeholder_element]
+             merged_text_segment = TextSegment(
+                 text="".join(t.text for t in text_segments),
+                 parent_stack=parent_stack,
+                 left_common_depth=text_segments[0].left_common_depth,
+                 right_common_depth=text_segments[-1].right_common_depth,
+                 block_depth=len(parent_stack),
+                 position=text_segments[0].position,
+             )
+             self._placeholder2interrupted[id(placeholder_element)] = interrupted_element
+
+             # TODO: tricky to get right; disabled for now.
+             # parent_element: Element | None = None
+             # if interrupted_index > 0:
+             #     parent_element = text_segment.parent_stack[interrupted_index - 1]
+
+             # if (
+             #     not self._is_inline_math(interrupted_element)
+             #     or parent_element is None
+             #     or self._has_no_math_texts(parent_element)
+             # ):
+             #     # Block-level formulas need not be repeated: a duplicate reads as abrupt.
+             #     # Inline formulas woven into the translation, however, read more smoothly.
+             #     self._raw_text_segments.pop(interrupted_id, None)
+             # else:
+             #     for text_segment in text_segments:
+             #         # Strip the original stack, keeping only the part relative to the interrupted
+             #         # element; this matches the format required of translated segments.
+             #         text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+             #         text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+             #         text_segment.block_depth = 1
+             #         text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+             for text_segment in text_segments:
+                 # Strip the original stack, keeping only the part relative to the interrupted
+                 # element; this matches the format required of translated segments.
+                 text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
+                 text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
+                 text_segment.block_depth = 1
+                 text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
+
+         return merged_text_segment
+
+     def _interrupted_index(self, text_segment: TextSegment) -> int | None:
+         interrupted_index: int | None = None
+         for i, parent_element in enumerate(text_segment.parent_stack):
+             if parent_element.tag == _MATH_TAG:
+                 interrupted_index = i
+                 break
+         return interrupted_index
+
+     def _expand_translated_text_segment(self, text_segment: TextSegment):
+         interrupted_id = text_segment.block_parent.attrib.pop(_ID_KEY, None)
+         if interrupted_id is None:
+             yield text_segment
+             return
+
+         raw_text_segments = self._raw_text_segments.pop(interrupted_id, None)
+         if not raw_text_segments:
+             return
+
+         raw_block = raw_text_segments[0].parent_stack[0]
+         if not self._is_inline_math(raw_block):
+             return
+
+         for raw_text_segment in raw_text_segments:
+             raw_text_segment.block_parent.attrib.pop(_ID_KEY, None)
+             yield raw_text_segment
+
+     def _has_no_math_texts(self, element: Element):
+         if element.tag == _MATH_TAG:
+             return True
+         if element.text and normalize_whitespace(element.text).strip():
+             return False
+         for child_element in element:
+             if not self._has_no_math_texts(child_element):
+                 return False
+             if child_element.tail and normalize_whitespace(child_element.tail).strip():
+                 return False
+         return True
+
+     def _is_inline_math(self, element: Element) -> bool:
+         if element.tag != _MATH_TAG:
+             return False
+         return element.get("display", "").lower() != "block"
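Note: _placeholder2interrupted is keyed by id(placeholder_element) because xml.etree elements carry no parent links and are not value-hashable, so object identity is the only stable handle on the synthetic <expression> placeholder. A standalone sketch of that technique (not this package's API; the names are illustrative):

    from xml.etree.ElementTree import Element

    placeholder2original: dict[int, Element] = {}

    math = Element("math", {"display": "inline"})
    placeholder = Element("expression", {"__XML_INTERRUPTER_ID": "1"})
    placeholder2original[id(placeholder)] = math

    def restore(element: Element) -> Element:
        # Mirrors XMLInterrupter.interrupt_block_element: swap a recorded
        # placeholder back for its original subtree, pass anything else through.
        return placeholder2original.pop(id(element), element)

    assert restore(placeholder) is math         # known placeholder -> original <math>
    assert restore(placeholder) is placeholder  # mapping already consumed by pop()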
@@ -1,3 +1,2 @@
- from .group import XMLGroupContext
- from .submitter import submit_text_segments
+ from .callbacks import FillFailedEvent
  from .translator import XMLTranslator
@@ -0,0 +1,34 @@
+ from collections.abc import Callable, Iterable
+ from dataclasses import dataclass
+ from xml.etree.ElementTree import Element
+
+ from ..segment import TextSegment
+
+
+ @dataclass
+ class FillFailedEvent:
+     error_message: str
+     retried_count: int
+     over_maximum_retries: bool
+
+
+ @dataclass
+ class Callbacks:
+     interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+     interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]]
+     interrupt_block_element: Callable[[Element], Element]
+     on_fill_failed: Callable[[FillFailedEvent], None]
+
+
+ def warp_callbacks(
+     interrupt_source_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+     interrupt_translated_text_segments: Callable[[Iterable[TextSegment]], Iterable[TextSegment]] | None,
+     interrupt_block_element: Callable[[Element], Element] | None,
+     on_fill_failed: Callable[[FillFailedEvent], None] | None,
+ ) -> Callbacks:
+     return Callbacks(
+         interrupt_source_text_segments=interrupt_source_text_segments or (lambda x: x),
+         interrupt_translated_text_segments=interrupt_translated_text_segments or (lambda x: x),
+         interrupt_block_element=interrupt_block_element or (lambda x: x),
+         on_fill_failed=on_fill_failed or (lambda event: None),
+     )
@@ -1,2 +1 @@
- ID_KEY: str = "id"
  DATA_ORIGIN_LEN_KEY = "data-orig-len"
@@ -0,0 +1,104 @@
+ from collections.abc import Generator
+ from dataclasses import dataclass
+ from xml.etree.ElementTree import Element
+
+ from tiktoken import Encoding
+
+ from ..segment import BlockSegment, BlockSubmitter, TextSegment, search_text_segments
+ from ..xml import plain_text
+ from .common import DATA_ORIGIN_LEN_KEY
+ from .stream_mapper import InlineSegmentMapping
+ from .validation import LEVEL_DEPTH, generate_error_message, nest_as_errors_group, truncate_errors_group
+
+
+ @dataclass
+ class _BlockStatus:
+     weight: int
+     submitter: BlockSubmitter
+
+
+ # Uses hill climbing to pick the parts with higher completion out of what the LLM submits.
+ # By rejecting relatively low-completion submissions for each sub-part, it locks every
+ # sub-part so it can only move toward higher completion.
+ class HillClimbing:
+     def __init__(
+         self,
+         encoding: Encoding,
+         max_fill_displaying_errors: int,
+         block_segment: BlockSegment,
+     ) -> None:
+         self._encoding: Encoding = encoding
+         self._max_fill_displaying_errors: int = max_fill_displaying_errors
+         self._block_statuses: dict[int, _BlockStatus] = {}
+         self._block_segment: BlockSegment = block_segment
+
+     def request_element(self) -> Element:
+         element = self._block_segment.create_element()
+         for child_element in element:
+             text = plain_text(child_element)
+             tokens = self._encoding.encode(text)
+             child_element.set(DATA_ORIGIN_LEN_KEY, str(len(tokens)))
+         return element
+
+     def gen_mappings(self) -> Generator[InlineSegmentMapping | None, None, None]:
+         for inline_segment in self._block_segment:
+             id = inline_segment.id
+             assert id is not None
+             status = self._block_statuses.get(id, None)
+             text_segments: list[TextSegment] | None = None
+             if status is None:
+                 yield None
+             else:
+                 submitted_element = status.submitter.submitted_element
+                 text_segments = list(search_text_segments(submitted_element))
+                 yield inline_segment.parent, text_segments
+
+     def submit(self, element: Element) -> str | None:
+         error_message, block_weights = self._validate_block_weights_and_error_message(element)
+
+         for submitter in self._block_segment.submit(element):
+             weight: int = 0  # absent from block_weights means no errors, i.e. already complete
+             if block_weights:
+                 weight = block_weights.get(submitter.id, 0)
+             status = self._block_statuses.get(submitter.id, None)
+             if status is None:
+                 self._block_statuses[submitter.id] = _BlockStatus(
+                     weight=weight,
+                     submitter=submitter,
+                 )
+             elif weight < status.weight:
+                 status.weight = weight
+                 status.submitter = submitter
+
+         return error_message
+
+     def _validate_block_weights_and_error_message(self, element: Element) -> tuple[str | None, dict[int, int] | None]:
+         errors_group = nest_as_errors_group(
+             errors=self._block_segment.validate(element),
+         )
+         if errors_group is None:
+             return None, None
+
+         block_weights: dict[int, int] = {}
+         for block_group in errors_group.block_groups:
+             block_id = block_group.block_id
+             status = self._block_statuses.get(block_id, None)
+             block_weights[block_id] = block_group.weight
+             if status is not None and status.weight > block_group.weight:
+                 # Completion improved this round (weight dropped): rank these errors lower,
+                 # yielding attention to the parts whose completion has not improved yet.
+                 for child_error in block_group.errors:
+                     child_error.level -= LEVEL_DEPTH
+
+         origin_errors_count = errors_group.errors_count
+         errors_group = truncate_errors_group(
+             errors_group=errors_group,
+             max_errors=self._max_fill_displaying_errors,
+         )
+         if errors_group is None:
+             return None, block_weights
+
+         message = generate_error_message(
+             encoding=self._encoding,
+             errors_group=errors_group,
+             omitted_count=origin_errors_count - errors_group.errors_count,
+         )
+         return message, block_weights
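Note: stripped of the XML machinery, the acceptance rule in submit is plain hill climbing: a candidate replaces the stored one for a block only when its validation weight is strictly lower, and weight 0 (no errors) locks the block in. An illustrative reduction (not this module's code):

    best: dict[int, tuple[int, str]] = {}  # block_id -> (weight, candidate)

    def submit(block_id: int, weight: int, candidate: str) -> None:
        # Accept only strictly improving candidates; ties keep the incumbent,
        # so each block can only move toward lower weight (higher completion).
        if block_id not in best or weight < best[block_id][0]:
            best[block_id] = (weight, candidate)

    submit(1, weight=3, candidate="draft A")  # accepted: first candidate seen
    submit(1, weight=5, candidate="draft B")  # rejected: weight regressed
    submit(1, weight=0, candidate="draft C")  # accepted: validates cleanly
    assert best[1] == (0, "draft C")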
@@ -0,0 +1,253 @@
+ from collections.abc import Callable, Generator, Iterable, Iterator
+ from xml.etree.ElementTree import Element
+
+ from resource_segmentation import Group, Resource, Segment, split
+ from tiktoken import Encoding
+
+ from ..segment import InlineSegment, TextSegment, search_inline_segments, search_text_segments
+ from .callbacks import Callbacks
+
+ _PAGE_INCISION = 0
+ _BLOCK_INCISION = 1
+
+ _ELLIPSIS = "..."
+
+
+ InlineSegmentMapping = tuple[Element, list[TextSegment]]
+ InlineSegmentGroupMap = Callable[[list[InlineSegment]], list[InlineSegmentMapping | None]]
+
+
+ class XMLStreamMapper:
+     def __init__(self, encoding: Encoding, max_group_tokens: int) -> None:
+         self._encoding: Encoding = encoding
+         self._max_group_tokens: int = max_group_tokens
+
+     def map_stream(
+         self,
+         elements: Iterator[Element],
+         callbacks: Callbacks,
+         map: InlineSegmentGroupMap,
+     ) -> Generator[tuple[Element, list[InlineSegmentMapping]], None, None]:
+         current_element: Element | None = None
+         mapping_buffer: list[InlineSegmentMapping] = []
+
+         for group in self._split_into_serial_groups(elements, callbacks):
+             head, body, tail = self._truncate_and_transform_group(group)
+             target_body = map(head + body + tail)[len(head) : len(head) + len(body)]
+             for origin, target in zip(body, target_body, strict=False):
+                 origin_element = origin.head.root
+                 if current_element is None:
+                     current_element = origin_element
+
+                 if id(current_element) != id(origin_element):
+                     yield current_element, mapping_buffer
+                     current_element = origin_element
+                     mapping_buffer = []
+
+                 if target:
+                     block_element, text_segments = target
+                     block_element = callbacks.interrupt_block_element(block_element)
+                     text_segments = list(callbacks.interrupt_translated_text_segments(text_segments))
+                     if text_segments:
+                         mapping_buffer.append((block_element, text_segments))
+
+         if current_element is not None:
+             yield current_element, mapping_buffer
+
+     def _split_into_serial_groups(self, elements: Iterable[Element], callbacks: Callbacks):
+         def generate():
+             for element in elements:
+                 yield from split(
+                     max_segment_count=self._max_group_tokens,
+                     border_incision=_PAGE_INCISION,
+                     resources=self._expand_to_resources(element, callbacks),
+                 )
+
+         generator = generate()
+         group = next(generator, None)
+         if group is None:
+             return
+
+         # head + body * N (without tail)
+         sum_count = group.head_remain_count + sum(x.count for x in self._expand_resource_segments(group.body))
+
+         while True:
+             next_group = next(generator, None)
+             if next_group is None:
+                 break
+
+             next_sum_body_count = sum(x.count for x in self._expand_resource_segments(next_group.body))
+             next_sum_count = sum_count + next_sum_body_count
+
+             if next_sum_count + next_group.tail_remain_count > self._max_group_tokens:
+                 yield group
+                 group = next_group
+                 sum_count = group.head_remain_count + next_sum_body_count
+             else:
+                 group.body.extend(next_group.body)
+                 group.tail = next_group.tail
+                 group.tail_remain_count = next_group.tail_remain_count
+                 sum_count = next_sum_count
+
+         yield group
+
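Note: the loop above greedily coalesces consecutive split() groups while head remainder + accumulated body counts + the incoming tail remainder stays within max_group_tokens. The same arithmetic on made-up numbers (budget 100):

    groups = [(10, [30, 20], 15), (0, [25], 10), (0, [40], 5)]  # (head, bodies, tail)
    budget = 100

    merged = []
    head, bodies, tail = groups[0]
    bodies = list(bodies)            # copy: this list is extended while merging
    total = head + sum(bodies)       # head + body * N, without tail
    for next_head, next_bodies, next_tail in groups[1:]:
        if total + sum(next_bodies) + next_tail > budget:
            merged.append((head, bodies, tail))   # budget exceeded: flush
            head, bodies, tail = next_head, list(next_bodies), next_tail
            total = head + sum(next_bodies)
        else:
            bodies += next_bodies                 # absorb the next group's body
            tail = next_tail
            total += sum(next_bodies)
    merged.append((head, bodies, tail))
    print(merged)  # [(10, [30, 20, 25], 10), (0, [40], 5)]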
+     def _truncate_and_transform_group(self, group: Group[InlineSegment]):
+         head = list(
+             self._truncate_inline_segments(
+                 inline_segments=self._expand_inline_segments(group.head),
+                 remain_head=False,
+                 remain_count=group.head_remain_count,
+             )
+         )
+         body = list(self._expand_inline_segments(group.body))
+         tail = list(
+             self._truncate_inline_segments(
+                 inline_segments=self._expand_inline_segments(group.tail),
+                 remain_head=True,
+                 remain_count=group.tail_remain_count,
+             )
+         )
+         return head, body, tail
+
+     def _expand_to_resources(self, element: Element, callbacks: Callbacks):
+         def expand(element: Element):
+             text_segments = search_text_segments(element)
+             text_segments = callbacks.interrupt_source_text_segments(text_segments)
+             yield from search_inline_segments(text_segments)
+
+         inline_segment_generator = expand(element)
+         start_incision = _PAGE_INCISION
+         inline_segment = next(inline_segment_generator, None)
+         if inline_segment is None:
+             return
+
+         while True:
+             next_inline_segment = next(inline_segment_generator, None)
+             if next_inline_segment is None:
+                 break
+
+             if next_inline_segment.head.root is inline_segment.tail.root:
+                 end_incision = _BLOCK_INCISION
+             else:
+                 end_incision = _PAGE_INCISION
+
+             yield Resource(
+                 count=sum(len(self._encoding.encode(t.xml_text)) for t in inline_segment),
+                 start_incision=start_incision,
+                 end_incision=end_incision,
+                 payload=inline_segment,
+             )
+             inline_segment = next_inline_segment
+             start_incision = end_incision
+
+         yield Resource(
+             count=sum(len(self._encoding.encode(t.xml_text)) for t in inline_segment),
+             start_incision=start_incision,
+             end_incision=_PAGE_INCISION,
+             payload=inline_segment,
+         )
+
+     def _truncate_inline_segments(self, inline_segments: Iterable[InlineSegment], remain_head: bool, remain_count: int):
+         def clone_and_expand(segments: Iterable[InlineSegment]):
+             for segment in segments:
+                 for child_segment in segment:
+                     # The cut head and tail overlap with other groups; clone to keep them from affecting each other.
+                     yield child_segment.clone()
+
+         truncated_text_segments = self._truncate_text_segments(
+             text_segments=clone_and_expand(inline_segments),
+             remain_head=remain_head,
+             remain_count=remain_count,
+         )
+         yield from search_inline_segments(truncated_text_segments)
+
+     def _expand_inline_segments(self, items: list[Resource[InlineSegment] | Segment[InlineSegment]]):
+         for resource in self._expand_resource_segments(items):
+             yield resource.payload
+
+     def _expand_resource_segments(self, items: list[Resource[InlineSegment] | Segment[InlineSegment]]):
+         for item in items:
+             if isinstance(item, Resource):
+                 yield item
+             elif isinstance(item, Segment):
+                 yield from item.resources
+
+     def _truncate_text_segments(self, text_segments: Iterable[TextSegment], remain_head: bool, remain_count: int):
+         if remain_head:
+             yield from self._filter_and_remain_segments(
+                 segments=text_segments,
+                 remain_head=remain_head,
+                 remain_count=remain_count,
+             )
+         else:
+             yield from reversed(
+                 list(
+                     self._filter_and_remain_segments(
+                         segments=reversed(list(text_segments)),
+                         remain_head=remain_head,
+                         remain_count=remain_count,
+                     )
+                 )
+             )
+
+     def _filter_and_remain_segments(self, segments: Iterable[TextSegment], remain_head: bool, remain_count: int):
+         for segment in segments:
+             if remain_count <= 0:
+                 break
+             raw_xml_text = segment.xml_text
+             tokens = self._encoding.encode(raw_xml_text)
+             tokens_count = len(tokens)
+
+             if tokens_count > remain_count:
+                 truncated_segment = self._truncate_text_segment(
+                     segment=segment,
+                     tokens=tokens,
+                     raw_xml_text=raw_xml_text,
+                     remain_head=remain_head,
+                     remain_count=remain_count,
+                 )
+                 if truncated_segment is not None:
+                     yield truncated_segment
+                 break
+
+             yield segment
+             remain_count -= tokens_count
+
+     def _truncate_text_segment(
+         self,
+         segment: TextSegment,
+         tokens: list[int],
+         raw_xml_text: str,
+         remain_head: bool,
+         remain_count: int,
+     ) -> TextSegment | None:
+         # A typical xml_text: <tag id="99" data-origin-len="999">Some text</tag>
+         # If the cut lands in the leading XML region, discard the segment entirely.
+         # If the cut lands in the trailing XML region, keep the segment entirely.
+         # Only when the cut lands inside the body text do we actually truncate.
+         remain_text: str
+         xml_text_head_length = raw_xml_text.find(segment.text)
+
+         if remain_head:
+             remain_xml_text = self._encoding.decode(tokens[:remain_count])  # remain_count cannot be 0 here
+             if len(remain_xml_text) <= xml_text_head_length:
+                 return None
+             if len(remain_xml_text) >= xml_text_head_length + len(segment.text):
+                 return segment
+             remain_text = remain_xml_text[xml_text_head_length:]
+         else:
+             xml_text_tail_length = len(raw_xml_text) - (xml_text_head_length + len(segment.text))
+             remain_xml_text = self._encoding.decode(tokens[-remain_count:])
+             if len(remain_xml_text) <= xml_text_tail_length:
+                 return None
+             if len(remain_xml_text) >= xml_text_tail_length + len(segment.text):
+                 return segment
+             remain_text = remain_xml_text[: len(remain_xml_text) - xml_text_tail_length]
+
+         if not remain_text.strip():
+             return None
+
+         if remain_head:
+             segment.text = f"{remain_text} {_ELLIPSIS}"
+         else:
+             segment.text = f"{_ELLIPSIS} {remain_text}"
+         return segment
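Note: the three boundary cases in _truncate_text_segment compare the decoded token prefix (or suffix) against the XML wrapper regions of xml_text. A worked example of the remain_head=True branch with made-up numbers (no tokenizer involved; the cut is a fixed character offset for illustration):

    raw_xml_text = '<tag id="9" data-orig-len="12">Hello world of books</tag>'
    text = "Hello world of books"
    head_len = raw_xml_text.find(text)  # 31 chars of leading XML

    # Pretend decoding tokens[:remain_count] yields the first 42 characters:
    remain_xml_text = raw_xml_text[:42]

    if len(remain_xml_text) <= head_len:                # cut inside <tag ...>: drop it
        result = None
    elif len(remain_xml_text) >= head_len + len(text):  # cut inside </tag>: keep whole
        result = text
    else:                                               # cut inside the body: truncate
        result = remain_xml_text[head_len:] + " ..."

    print(result)  # Hello world ...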