epub-translator 0.1.1__py3-none-any.whl → 0.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. epub_translator/__init__.py +2 -2
  2. epub_translator/data/fill.jinja +143 -38
  3. epub_translator/epub/__init__.py +1 -1
  4. epub_translator/epub/metadata.py +122 -0
  5. epub_translator/epub/spines.py +3 -2
  6. epub_translator/epub/zip.py +11 -9
  7. epub_translator/epub_transcode.py +108 -0
  8. epub_translator/llm/__init__.py +1 -0
  9. epub_translator/llm/context.py +109 -0
  10. epub_translator/llm/core.py +32 -113
  11. epub_translator/llm/executor.py +25 -31
  12. epub_translator/llm/increasable.py +1 -1
  13. epub_translator/llm/types.py +0 -3
  14. epub_translator/segment/__init__.py +26 -0
  15. epub_translator/segment/block_segment.py +124 -0
  16. epub_translator/segment/common.py +29 -0
  17. epub_translator/segment/inline_segment.py +356 -0
  18. epub_translator/{xml_translator → segment}/text_segment.py +8 -8
  19. epub_translator/segment/utils.py +43 -0
  20. epub_translator/translator.py +147 -183
  21. epub_translator/utils.py +33 -0
  22. epub_translator/xml/__init__.py +2 -0
  23. epub_translator/xml/const.py +1 -0
  24. epub_translator/xml/deduplication.py +3 -3
  25. epub_translator/xml/self_closing.py +182 -0
  26. epub_translator/xml/utils.py +42 -0
  27. epub_translator/xml/xml.py +7 -0
  28. epub_translator/xml/xml_like.py +8 -33
  29. epub_translator/xml_interrupter.py +165 -0
  30. epub_translator/xml_translator/__init__.py +1 -2
  31. epub_translator/xml_translator/callbacks.py +34 -0
  32. epub_translator/xml_translator/{const.py → common.py} +0 -1
  33. epub_translator/xml_translator/hill_climbing.py +104 -0
  34. epub_translator/xml_translator/stream_mapper.py +253 -0
  35. epub_translator/xml_translator/submitter.py +26 -72
  36. epub_translator/xml_translator/translator.py +162 -113
  37. epub_translator/xml_translator/validation.py +458 -0
  38. {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/METADATA +72 -9
  39. epub_translator-0.1.3.dist-info/RECORD +66 -0
  40. epub_translator/epub/placeholder.py +0 -53
  41. epub_translator/iter_sync.py +0 -24
  42. epub_translator/xml_translator/fill.py +0 -128
  43. epub_translator/xml_translator/format.py +0 -282
  44. epub_translator/xml_translator/fragmented.py +0 -125
  45. epub_translator/xml_translator/group.py +0 -183
  46. epub_translator/xml_translator/progressive_locking.py +0 -256
  47. epub_translator/xml_translator/utils.py +0 -29
  48. epub_translator-0.1.1.dist-info/RECORD +0 -58
  49. {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/LICENSE +0 -0
  50. {epub_translator-0.1.1.dist-info → epub_translator-0.1.3.dist-info}/WHEEL +0 -0
--- epub_translator/epub/placeholder.py
+++ /dev/null
@@ -1,53 +0,0 @@
- from collections.abc import Callable
- from xml.etree.ElementTree import Element
-
- from .math import xml_to_latex
-
- _MATH_TAG = "math"
- _EXPRESSION_TAG = "expression"
-
- _PLACEHOLDER_TAGS = frozenset((_EXPRESSION_TAG,))
-
-
- def is_placeholder_tag(tag: str) -> bool:
-     return tag in _PLACEHOLDER_TAGS
-
-
- class Placeholder:
-     def __init__(self, root: Element):
-         self._raw_elements: dict[int, Element] = {}
-         self._root: Element = self._replace(
-             element=root,
-             replace=self._replace_raw,
-         )
-         assert id(self._root) == id(root)
-
-     def recover(self) -> None:
-         self._replace(
-             element=self._root,
-             replace=self._recover_to_raw,
-         )
-
-     def _replace(self, element: Element, replace: Callable[[Element], Element | None]) -> Element:
-         replaced = replace(element)
-         if replaced is not None:
-             return replaced
-         if len(element):
-             element[:] = [self._replace(child, replace) for child in element]
-         return element
-
-     def _replace_raw(self, element: Element) -> Element | None:
-         if element.tag == _MATH_TAG:
-             replaced = Element(_EXPRESSION_TAG)
-             replaced.text = xml_to_latex(element)
-             replaced.tail = element.tail
-             self._raw_elements[id(replaced)] = element
-             return replaced
-         return None
-
-     def _recover_to_raw(self, replaced: Element) -> Element | None:
-         raw = self._raw_elements.get(id(replaced))
-         if raw is not None:
-             del self._raw_elements[id(replaced)]
-             return raw
-         return None
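
For context, the removed Placeholder swapped MathML <math> subtrees for lightweight <expression> placeholders (carrying the LaTeX text produced by xml_to_latex) before translation, then restored the originals afterwards. A minimal sketch of that round trip against the 0.1.1 layout (hypothetical usage; the module no longer exists in 0.1.3):

from xml.etree.ElementTree import Element, SubElement, tostring

from epub_translator.epub.placeholder import Placeholder  # 0.1.1 import path, removed in 0.1.3

root = Element("p")
root.text = "Euler's number: "
math = SubElement(root, "math")      # MathML subtree that should not reach the LLM verbatim
SubElement(math, "mi").text = "e"

holder = Placeholder(root)           # <math> is replaced in place by <expression>...LaTeX...</expression>
print(tostring(root, encoding="unicode"))

# ... translate the surrounding text here ...

holder.recover()                     # the original <math> element is swapped back in
print(tostring(root, encoding="unicode"))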
--- epub_translator/iter_sync.py
+++ /dev/null
@@ -1,24 +0,0 @@
- from collections.abc import Generator, Iterable
- from typing import Generic, TypeVar
-
- T = TypeVar("T")
-
-
- class IterSync(Generic[T]):
-     def __init__(self) -> None:
-         super().__init__()
-         self._queue: list[T] = []
-
-     @property
-     def tail(self) -> T | None:
-         if not self._queue:
-             return None
-         return self._queue[-1]
-
-     def take(self) -> T:
-         return self._queue.pop()
-
-     def iter(self, elements: Iterable[T]) -> Generator[T, None, None]:
-         for element in elements:
-             self._queue.insert(0, element)
-             yield element
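
IterSync is a small FIFO shim: iter() records every element it yields, so a second consumer can later pull the same elements back with take() (oldest first) or peek at the next one via tail. A short usage sketch, again assuming the 0.1.1 module path:

from epub_translator.iter_sync import IterSync  # removed in 0.1.3

sync: IterSync[str] = IterSync()

for chunk in sync.iter(["a", "b", "c"]):
    pass  # each yielded chunk is also queued for later retrieval

assert sync.take() == "a"  # oldest yielded element comes back first
assert sync.tail == "b"    # tail peeks at what the next take() would return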
--- epub_translator/xml_translator/fill.py
+++ /dev/null
@@ -1,128 +0,0 @@
- from xml.etree.ElementTree import Element
-
- from ..utils import normalize_whitespace
- from ..xml import plain_text
- from .const import DATA_ORIGIN_LEN_KEY, ID_KEY
- from .format import format
- from .text_segment import TextSegment, combine_text_segments
-
-
- class XMLFill:
-     def __init__(self, text_segments: list[TextSegment]) -> None:
-         self._request_element = Element("xml")
-         self._text_segments: dict[tuple[int, ...], list[TextSegment]] = {} # generated id stack -> text segments
-
-         raw2generated: dict[int, Element] = {}
-         raw2generated_ids: dict[int, int] = {}
-
-         for combined_element, sub_raw2generated in combine_text_segments(text_segments):
-             unwrapped_parent_ids: set[int] = set()
-             sub_element, parents = self._unwrap_parents(combined_element)
-             self._request_element.append(sub_element)
-             for parent in parents:
-                 unwrapped_parent_ids.add(id(parent))
-
-             for raw_id, generated_element in sub_raw2generated.items():
-                 if raw_id in unwrapped_parent_ids:
-                     continue
-                 if id(generated_element) in unwrapped_parent_ids:
-                     continue
-                 generated_id = len(raw2generated)
-                 raw2generated[raw_id] = generated_element
-                 raw2generated_ids[raw_id] = generated_id
-
-                 generated_text = normalize_whitespace(
-                     text=plain_text(generated_element),
-                 )
-                 generated_element.attrib = {
-                     ID_KEY: str(generated_id),
-                     DATA_ORIGIN_LEN_KEY: str(len(generated_text)),
-                 }
-
-         for text_segment in text_segments:
-             generated_id_stack: list[int] = []
-             for parent in text_segment.parent_stack:
-                 generated_id = raw2generated_ids.get(id(parent), None)
-                 if generated_id is not None:
-                     generated_id_stack.append(generated_id)
-             generated_key = tuple(generated_id_stack)
-             text_segments_stack = self._text_segments.get(generated_key, None)
-             if text_segments_stack is None:
-                 text_segments_stack = []
-                 self._text_segments[generated_key] = text_segments_stack
-             text_segments_stack.append(text_segment)
-
-         for text_segments_stack in self._text_segments.values():
-             text_segments_stack.reverse() # reversed so segments can be consumed with .pop()
-
-     def _unwrap_parents(self, element: Element):
-         parents: list[Element] = []
-         while True:
-             if len(element) != 1:
-                 break
-             child = element[0]
-             if not element.text:
-                 break
-             if not child.tail:
-                 break
-             parents.append(element)
-             element = child
-         element.tail = None
-         return element, parents
-
-     @property
-     def request_element(self) -> Element:
-         return self._request_element
-
-     def submit_response_text(self, text: str, errors_limit: int) -> Element:
-         submitted_element = format(
-             template_ele=self._request_element,
-             validated_text=text,
-             errors_limit=errors_limit,
-         )
-         self._fill_submitted_texts(
-             generated_ids_stack=[],
-             element=submitted_element,
-         )
-         return submitted_element
-
-     def _fill_submitted_texts(self, generated_ids_stack: list[int], element: Element):
-         current_stack = generated_ids_stack
-         generated_id = self._generated_id(element)
-         if generated_id >= 0:
-             current_stack = generated_ids_stack + [generated_id]
-
-         generated_key = tuple(current_stack)
-         text_segments_stack = self._text_segments.get(generated_key, None)
-         text = self._normalize_text(element.text)
-
-         if text_segments_stack and text is not None:
-             text_segment = text_segments_stack.pop()
-             text_segment.text = text
-
-         for child_element in element:
-             self._fill_submitted_texts(
-                 generated_ids_stack=current_stack,
-                 element=child_element,
-             )
-             tail = self._normalize_text(child_element.tail)
-             if text_segments_stack and tail is not None:
-                 text_segment = text_segments_stack.pop()
-                 text_segment.text = tail
-
-     def _generated_id(self, element: Element) -> int:
-         str_id = element.get(ID_KEY, None)
-         if str_id is None:
-             return -1
-         try:
-             return int(str_id)
-         except ValueError:
-             return -1
-
-     def _normalize_text(self, text: str | None) -> str | None:
-         if text is None:
-             return None
-         text = normalize_whitespace(text)
-         if not text.strip():
-             return None
-         return text
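
XMLFill packs translatable text segments into a single <xml> request whose sub-elements carry numeric id attributes, then maps the model's reply back onto the original segments by matching those ids (submit_response_text → _fill_submitted_texts). The id-matching idea in isolation, as a self-contained sketch; the attribute name and element tags here are illustrative, not the package's actual constants:

from xml.etree.ElementTree import Element, SubElement, fromstring

# Toy request: every translatable span gets a numeric id, as XMLFill does via ID_KEY.
request = Element("xml")
for i, text in enumerate(["Hello", "world"]):
    SubElement(request, "p", {"id": str(i)}).text = text

# Toy model reply: same ids, translated text.
response = fromstring('<xml><p id="0">Bonjour</p><p id="1">le monde</p></xml>')

# Write translations back by id, the same matching step the removed
# _fill_submitted_texts performed against real TextSegment objects.
translated = {node.get("id"): node.text for node in response}
for node in request:
    node.text = translated.get(node.get("id"), node.text)

print([node.text for node in request])  # ['Bonjour', 'le monde']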
--- epub_translator/xml_translator/format.py
+++ /dev/null
@@ -1,282 +0,0 @@
- from xml.etree.ElementTree import Element
-
- from ..utils import normalize_whitespace
- from ..xml import decode_friendly
- from .const import ID_KEY
-
-
- def format(template_ele: Element, validated_text: str, errors_limit: int) -> Element:
-     context = _ValidationContext()
-     validated_ele = _extract_xml_element(validated_text)
-     context.validate(raw_ele=template_ele, validated_ele=validated_ele)
-     error_message = context.errors(limit=errors_limit)
-     if error_message:
-         raise ValidationError(message=error_message, validated_ele=validated_ele)
-     return validated_ele
-
-
- class ValidationError(Exception):
-     def __init__(self, message: str, validated_ele: Element | None = None) -> None:
-         super().__init__(message)
-         self.validated_ele = validated_ele
-
-
- def _extract_xml_element(text: str) -> Element:
-     first_xml_element: Element | None = None
-     all_xml_elements: int = 0
-
-     for xml_element in decode_friendly(text, tags="xml"):
-         if first_xml_element is None:
-             first_xml_element = xml_element
-         all_xml_elements += 1
-
-     if first_xml_element is None:
-         raise ValidationError(
-             "No complete <xml>...</xml> block found. Please ensure you have properly closed the XML with </xml> tag."
-         )
-     if all_xml_elements > 1:
-         raise ValidationError(
-             f"Found {all_xml_elements} <xml>...</xml> blocks. "
-             "Please return only one XML block without any examples or explanations."
-         )
-     return first_xml_element
-
-
- class _ValidationContext:
-     def __init__(self) -> None:
-         self._tag_text_dict: dict[int, str] = {}
-         self._errors: dict[tuple[int, ...], list[str]] = {}
-
-     def validate(self, raw_ele: Element, validated_ele: Element):
-         self._validate_ele(ids_path=[], raw_ele=raw_ele, validated_ele=validated_ele)
-
-     def errors(self, limit: int) -> str | None:
-         if not self._errors:
-             return
-
-         keys = list(self._errors.keys())
-         keys.sort(key=lambda k: (len(k), k)) # the model should correct shallow errors before deeper ones
-         keys = keys[:limit]
-         max_len_key = max((len(key) for key in keys), default=0)
-
-         for i in range(len(keys)):
-             key = keys[i]
-             if len(key) < max_len_key:
-                 key_list = list(key)
-                 while len(key_list) < max_len_key:
-                     key_list.append(-1)
-                 keys[i] = tuple(key_list)
-
-         content: list[str] = []
-         total_errors = sum(len(messages) for messages in self._errors.values())
-         remain_errors = total_errors
-
-         for key in sorted(keys): # depth-first ordering keeps related errors together
-             raw_key = tuple(k for k in key if k >= 0)
-             indent: str = f"{' ' * len(raw_key)}"
-             errors_list = self._errors[raw_key]
-             parent_text: str
-
-             if len(raw_key) > 0:
-                 parent_text = self._tag_text_dict[raw_key[-1]]
-             else:
-                 parent_text = "the root tag"
-
-             if len(errors_list) == 1:
-                 error = errors_list[0]
-                 content.append(f"{indent}- errors in {parent_text}: {error}.")
-             else:
-                 content.append(f"{indent}- errors in {parent_text}:")
-                 for error in errors_list:
-                     content.append(f"{indent} - {error}.")
-             remain_errors -= len(errors_list)
-
-         content.insert(0, f"Found {total_errors} error(s) in your response XML structure.")
-         if remain_errors > 0:
-             content.append(f"\n... and {remain_errors} more error(s).")
-
-         return "\n".join(content)
-
-     def _validate_ele(self, ids_path: list[int], raw_ele: Element, validated_ele: Element):
-         raw_id_map = self._build_id_map(raw_ele)
-         validated_id_map = self._build_id_map(validated_ele)
-         lost_ids: list[int] = []
-         extra_ids: list[int] = []
-
-         for id, sub_raw in raw_id_map.items():
-             sub_validated = validated_id_map.get(id, None)
-             if sub_validated is None:
-                 lost_ids.append(id)
-             else:
-                 self._validate_id_ele(
-                     id=id,
-                     ids_path=ids_path,
-                     raw_ele=sub_raw,
-                     validated_ele=sub_validated,
-                 )
-
-         for id in validated_id_map.keys():
-             if id not in raw_id_map:
-                 extra_ids.append(id)
-
-         if lost_ids or extra_ids:
-             messages: list[str] = []
-             lost_ids.sort()
-             extra_ids.sort()
-
-             if lost_ids:
-                 tags = [self._str_tag(raw_id_map[id]) for id in lost_ids]
-                 # Provide context from source XML
-                 context_info = self._get_source_context(raw_ele, lost_ids)
-                 messages.append(f"lost sub-tags {' '.join(tags)}")
-                 if context_info:
-                     messages.append(f"Source structure was: {context_info}")
-
-             if extra_ids:
-                 tags = [self._str_tag(validated_id_map[id]) for id in extra_ids]
-                 messages.append(f"extra sub-tags {' '.join(tags)}")
-
-             if messages:
-                 self._add_error(
-                     ids_path=ids_path,
-                     message="find " + " and ".join(messages),
-                 )
-         else:
-             raw_element_empty = not self._has_text_content(raw_ele)
-             validated_ele_empty = not self._has_text_content(validated_ele)
-
-             if raw_element_empty and not validated_ele_empty:
-                 self._add_error(
-                     ids_path=ids_path,
-                     message="shouldn't have text content",
-                 )
-             elif not raw_element_empty and validated_ele_empty:
-                 self._add_error(
-                     ids_path=ids_path,
-                     message="text content is missing",
-                 )
-
-     def _validate_id_ele(self, ids_path: list[int], id: int, raw_ele: Element, validated_ele: Element):
-         if raw_ele.tag == validated_ele.tag:
-             self._tag_text_dict[id] = self._str_tag(raw_ele)
-             raw_has_text = self._has_direct_text(raw_ele.text)
-             validated_has_text = self._has_direct_text(validated_ele.text)
-
-             if raw_has_text and not validated_has_text:
-                 self._add_error(
-                     ids_path=ids_path + [id],
-                     message="missing text content before child elements",
-                 )
-             elif not raw_has_text and validated_has_text:
-                 self._add_error(
-                     ids_path=ids_path + [id],
-                     message="shouldn't have text content before child elements",
-                 )
-             raw_has_tail = self._has_direct_text(raw_ele.tail)
-             validated_has_tail = self._has_direct_text(validated_ele.tail)
-
-             if raw_has_tail and not validated_has_tail:
-                 self._add_error(
-                     ids_path=ids_path + [id],
-                     message="missing text content after the element",
-                 )
-             elif not raw_has_tail and validated_has_tail:
-                 self._add_error(
-                     ids_path=ids_path + [id],
-                     message="shouldn't have text content after the element",
-                 )
-
-             self._validate_ele(
-                 ids_path=ids_path + [id],
-                 raw_ele=raw_ele,
-                 validated_ele=validated_ele,
-             )
-         else:
-             self._add_error(
-                 ids_path=ids_path,
-                 message=f'got <{validated_ele.tag} id="{id}">',
-             )
-
-     def _add_error(self, ids_path: list[int], message: str):
-         key = tuple(ids_path)
-         if key not in self._errors:
-             self._errors[key] = []
-         self._errors[key].append(message)
-
-     def _build_id_map(self, ele: Element):
-         id_map: dict[int, Element] = {}
-         for child_ele in ele:
-             id_text = child_ele.get(ID_KEY, None)
-             if id_text is not None:
-                 id = int(id_text)
-                 if id < 0:
-                     raise ValueError(f"Invalid id {id} found. IDs must be non-negative integers.")
-             if id_text is not None:
-                 id_map[id] = child_ele
-         return id_map
-
-     def _has_text_content(self, ele: Element) -> bool:
-         text = "".join(self._plain_text(ele))
-         text = normalize_whitespace(text)
-         text = text.strip()
-         return len(text) > 0
-
-     def _has_direct_text(self, text: str | None) -> bool:
-         if text is None:
-             return False
-         normalized = normalize_whitespace(text).strip()
-         return len(normalized) > 0
-
-     def _plain_text(self, ele: Element):
-         if ele.text:
-             yield ele.text
-         for child in ele:
-             if child.get(ID_KEY, None) is not None:
-                 yield from self._plain_text(child)
-             if child.tail:
-                 yield child.tail
-
-     def _str_tag(self, ele: Element) -> str:
-         ele_id = ele.get(ID_KEY)
-         content: str
-         if ele_id is not None:
-             content = f'<{ele.tag} id="{ele_id}"'
-         else:
-             content = f"<{ele.tag}"
-         if len(ele) > 0:
-             content += f"> ... </{ele.tag}>"
-         else:
-             content += " />"
-         return content
-
-     def _get_source_context(self, parent: Element, lost_ids: list[int]) -> str:
-         """Generate context showing where lost tags appeared in source XML."""
-         if not lost_ids:
-             return ""
-
-         # Build a simple representation of the source structure
-         children_with_ids = []
-         for child in parent:
-             child_id_str = child.get(ID_KEY)
-             if child_id_str is not None:
-                 child_id = int(child_id_str)
-                 is_lost = child_id in lost_ids
-                 tag_str = f'<{child.tag} id="{child_id}">'
-
-                 # Show text before/inside/after
-                 parts = []
-                 if child.text and child.text.strip():
-                     preview = child.text.strip()[:20]
-                     if is_lost:
-                         parts.append(f'[{preview}...]')
-                     else:
-                         parts.append(f'{preview}...')
-
-                 if is_lost:
-                     children_with_ids.append(f'{tag_str}*MISSING*')
-                 else:
-                     children_with_ids.append(tag_str)
-
-         if children_with_ids:
-             return f"[{' '.join(children_with_ids)}]"
-         return ""
--- epub_translator/xml_translator/fragmented.py
+++ /dev/null
@@ -1,125 +0,0 @@
- from collections.abc import Generator, Iterable, Iterator
- from enum import Enum, auto
- from xml.etree.ElementTree import Element
-
- from tiktoken import Encoding
-
- from .utils import expand_left_element_texts, expand_right_element_texts, normalize_text_in_element
-
-
- def group_fragmented_elements(
-     encoding: Encoding,
-     elements: Iterable[Element],
-     group_max_tokens: int,
- ) -> Generator[list[Element], None, None]:
-     remain_tokens_count: int = group_max_tokens
-     elements_buffer: list[Element] = []
-
-     for element in elements:
-         if remain_tokens_count <= 0:
-             remain_tokens_count = group_max_tokens
-             if elements_buffer:
-                 yield elements_buffer
-                 elements_buffer = []
-
-         counter = _XMLCounter(encoding, element)
-         cost_tokens_count = counter.advance_tokens(remain_tokens_count)
-         remain_tokens_count -= cost_tokens_count
-         if not counter.can_advance():
-             elements_buffer.append(element)
-             continue
-
-         if elements_buffer:
-             yield elements_buffer
-             elements_buffer = []
-
-         remain_tokens_count = group_max_tokens - cost_tokens_count
-         cost_tokens_count = counter.advance_tokens(remain_tokens_count)
-         if not counter.can_advance():
-             elements_buffer.append(element)
-             remain_tokens_count -= cost_tokens_count
-             continue
-
-         remain_tokens_count = group_max_tokens
-         yield [element]
-
-     if elements_buffer:
-         yield elements_buffer
-
-
- class _TextItemKind(Enum):
-     TEXT = auto()
-     XML_TAG = auto()
-
-
- class _XMLCounter:
-     def __init__(self, encoding: Encoding, root: Element) -> None:
-         self._encoding: Encoding = encoding
-         self._text_iter: Iterator[str] = iter(self._expand_texts(root))
-         self._remain_tokens_count: int = 0
-         self._next_text_buffer: str | None = None
-
-     def can_advance(self) -> bool:
-         if self._remain_tokens_count > 0:
-             return True
-         if self._next_text_buffer is None:
-             self._next_text_buffer = next(self._text_iter, None)
-         return self._next_text_buffer is not None
-
-     def _expand_texts(self, element: Element) -> Generator[str, None, None]:
-         xml_tags_buffer: list[str] = [] # tag fragments are too small on their own; join them so encoding works on useful chunks
-         for kind, text in self._expand_text_items(element):
-             if kind == _TextItemKind.XML_TAG:
-                 xml_tags_buffer.append(text)
-             elif kind == _TextItemKind.TEXT:
-                 if xml_tags_buffer:
-                     yield "".join(xml_tags_buffer)
-                     xml_tags_buffer = []
-                 yield text
-         if xml_tags_buffer:
-             yield "".join(xml_tags_buffer)
-
-     def _expand_text_items(self, element: Element) -> Generator[tuple[_TextItemKind, str], None, None]:
-         for text in expand_left_element_texts(element):
-             yield _TextItemKind.XML_TAG, text
-
-         text = normalize_text_in_element(element.text)
-         if text is not None:
-             yield _TextItemKind.TEXT, text
-         for child in element:
-             yield from self._expand_text_items(child)
-             tail = normalize_text_in_element(child.tail)
-             if tail is not None:
-                 yield _TextItemKind.TEXT, tail
-
-         for text in expand_right_element_texts(element):
-             yield _TextItemKind.XML_TAG, text
-
-     def advance_tokens(self, max_tokens_count: int) -> int:
-         tokens_count: int = 0
-         while tokens_count < max_tokens_count:
-             if self._remain_tokens_count > 0:
-                 will_count_tokens = max_tokens_count - tokens_count
-                 if will_count_tokens > self._remain_tokens_count:
-                     tokens_count += self._remain_tokens_count
-                     self._remain_tokens_count = 0
-                 else:
-                     tokens_count += will_count_tokens
-                     self._remain_tokens_count -= will_count_tokens
-             if tokens_count >= max_tokens_count:
-                 break
-             next_text = self._next_text()
-             if next_text is None:
-                 break
-             self._remain_tokens_count += len(self._encoding.encode(next_text))
-
-         return tokens_count
-
-     def _next_text(self) -> str | None:
-         next_text: str | None = None
-         if self._next_text_buffer is None:
-             next_text = next(self._text_iter, None)
-         else:
-             next_text = self._next_text_buffer
-             self._next_text_buffer = None
-         return next_text
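
group_fragmented_elements batches successive elements so that each batch's token count (measured with a tiktoken Encoding over the elements' tag and text fragments) stays within group_max_tokens, flushing the buffer once the budget runs out. The same budgeting pattern applied to plain strings, as a runnable sketch; the encoding name is an arbitrary choice for illustration:

from collections.abc import Generator, Iterable

import tiktoken


def group_by_token_budget(texts: Iterable[str], max_tokens: int) -> Generator[list[str], None, None]:
    # Same flush-when-over-budget pattern as group_fragmented_elements,
    # but counting whole strings instead of walking XML text and tails.
    encoding = tiktoken.get_encoding("cl100k_base")  # assumed encoding, for illustration only
    buffer: list[str] = []
    remaining = max_tokens
    for text in texts:
        cost = len(encoding.encode(text))
        if buffer and cost > remaining:
            yield buffer
            buffer = []
            remaining = max_tokens
        buffer.append(text)
        remaining -= cost
    if buffer:
        yield buffer


for batch in group_by_token_budget(["one", "two", "three", "four"], max_tokens=3):
    print(batch)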