epub-translator 0.1.5__py3-none-any.whl → 0.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41) hide show
  1. epub_translator/__init__.py +1 -2
  2. epub_translator/data/translate.jinja +3 -0
  3. epub_translator/epub/__init__.py +1 -1
  4. epub_translator/llm/context.py +10 -1
  5. epub_translator/llm/core.py +30 -3
  6. epub_translator/segment/__init__.py +1 -0
  7. epub_translator/segment/inline_segment.py +11 -1
  8. epub_translator/segment/text_segment.py +5 -10
  9. epub_translator/segment/utils.py +0 -16
  10. epub_translator/translation/__init__.py +2 -0
  11. epub_translator/{epub_transcode.py → translation/epub_transcode.py} +2 -2
  12. epub_translator/{punctuation.py → translation/punctuation.py} +1 -1
  13. epub_translator/{translator.py → translation/translator.py} +8 -6
  14. epub_translator/{xml_interrupter.py → translation/xml_interrupter.py} +52 -28
  15. epub_translator/xml/__init__.py +1 -1
  16. epub_translator/xml/inline.py +48 -2
  17. epub_translator/xml_translator/concurrency.py +52 -0
  18. epub_translator/xml_translator/score.py +164 -0
  19. epub_translator/xml_translator/stream_mapper.py +145 -114
  20. epub_translator/xml_translator/submitter.py +5 -5
  21. epub_translator/xml_translator/translator.py +12 -18
  22. {epub_translator-0.1.5.dist-info → epub_translator-0.1.7.dist-info}/METADATA +37 -9
  23. epub_translator-0.1.7.dist-info/RECORD +63 -0
  24. epub_translator/data/mmltex/README.md +0 -67
  25. epub_translator/data/mmltex/cmarkup.xsl +0 -1106
  26. epub_translator/data/mmltex/entities.xsl +0 -459
  27. epub_translator/data/mmltex/glayout.xsl +0 -222
  28. epub_translator/data/mmltex/mmltex.xsl +0 -36
  29. epub_translator/data/mmltex/scripts.xsl +0 -375
  30. epub_translator/data/mmltex/tables.xsl +0 -130
  31. epub_translator/data/mmltex/tokens.xsl +0 -328
  32. epub_translator-0.1.5.dist-info/RECORD +0 -68
  33. /epub_translator/{language.py → translation/language.py} +0 -0
  34. /epub_translator/xml/{firendly → friendly}/__init__.py +0 -0
  35. /epub_translator/xml/{firendly → friendly}/decoder.py +0 -0
  36. /epub_translator/xml/{firendly → friendly}/encoder.py +0 -0
  37. /epub_translator/xml/{firendly → friendly}/parser.py +0 -0
  38. /epub_translator/xml/{firendly → friendly}/tag.py +0 -0
  39. /epub_translator/xml/{firendly → friendly}/transform.py +0 -0
  40. {epub_translator-0.1.5.dist-info → epub_translator-0.1.7.dist-info}/LICENSE +0 -0
  41. {epub_translator-0.1.5.dist-info → epub_translator-0.1.7.dist-info}/WHEEL +0 -0
@@ -1,6 +1,5 @@
1
- from . import language
2
1
  from .llm import LLM
3
- from .translator import FillFailedEvent, translate
2
+ from .translation import FillFailedEvent, language, translate
4
3
  from .xml_translator import SubmitKind
5
4
 
6
5
  __all__ = [
@@ -13,6 +13,9 @@ Translation rules:
13
13
  {% if user_prompt -%}
14
14
  User may provide additional requirements in <rules> tags before the source text. Follow them, but prioritize the rules above if conflicts arise.
15
15
 
16
+ <rules>
17
+ {{ user_prompt }}
18
+ </rules>
16
19
  {% endif -%}
17
20
 
18
21
  Output only the translated text, nothing else.
@@ -1,4 +1,4 @@
1
1
  from .metadata import read_metadata, write_metadata
2
2
  from .spines import search_spine_paths
3
- from .toc import read_toc, write_toc
3
+ from .toc import Toc, read_toc, write_toc
4
4
  from .zip import Zip
@@ -1,5 +1,6 @@
1
1
  import hashlib
2
2
  import json
3
+ import threading
3
4
  import uuid
4
5
  from pathlib import Path
5
6
  from typing import Self
@@ -8,6 +9,9 @@ from .executor import LLMExecutor
8
9
  from .increasable import Increasable, Increaser
9
10
  from .types import Message, MessageRole
10
11
 
12
+ # Global lock for cache file commit operations
13
+ _CACHE_COMMIT_LOCK = threading.Lock()
14
+
11
15
 
12
16
  class LLMContext:
13
17
  def __init__(
@@ -101,7 +105,12 @@ class LLMContext:
101
105
  # Remove the .[context-id].txt suffix to get permanent name
102
106
  permanent_name = temp_file.name.rsplit(".", 2)[0] + ".txt"
103
107
  permanent_file = temp_file.parent / permanent_name
104
- temp_file.rename(permanent_file)
108
+
109
+ with _CACHE_COMMIT_LOCK: # 多线程下的线程安全
110
+ if permanent_file.exists():
111
+ temp_file.unlink()
112
+ else:
113
+ temp_file.rename(permanent_file)
105
114
 
106
115
  def _rollback(self) -> None:
107
116
  for temp_file in self._temp_files:
@@ -1,4 +1,5 @@
1
1
  import datetime
2
+ import threading
2
3
  from collections.abc import Generator
3
4
  from importlib.resources import files
4
5
  from logging import DEBUG, FileHandler, Formatter, Logger, getLogger
@@ -14,6 +15,11 @@ from .executor import LLMExecutor
14
15
  from .increasable import Increasable
15
16
  from .types import Message
16
17
 
18
+ # Global state for logger filename generation
19
+ _LOGGER_LOCK = threading.Lock()
20
+ _LAST_TIMESTAMP: str | None = None
21
+ _LOGGER_SUFFIX_ID: int = 1
22
+
17
23
 
18
24
  class LLM:
19
25
  def __init__(
@@ -95,13 +101,34 @@ class LLM:
95
101
  return dir_path.resolve()
96
102
 
97
103
  def _create_logger(self) -> Logger | None:
104
+ # pylint: disable=global-statement
105
+ global _LAST_TIMESTAMP, _LOGGER_SUFFIX_ID
106
+
98
107
  if self._logger_save_path is None:
99
108
  return None
100
109
 
101
110
  now = datetime.datetime.now(datetime.UTC)
102
- timestamp = now.strftime("%Y-%m-%d %H-%M-%S %f")
103
- file_path = self._logger_save_path / f"request {timestamp}.log"
104
- logger = getLogger(f"LLM Request {timestamp}")
111
+ # Use second-level precision for collision detection
112
+ timestamp_key = now.strftime("%Y-%m-%d %H-%M-%S")
113
+
114
+ with _LOGGER_LOCK:
115
+ if _LAST_TIMESTAMP == timestamp_key:
116
+ _LOGGER_SUFFIX_ID += 1
117
+ suffix_id = _LOGGER_SUFFIX_ID
118
+ else:
119
+ _LAST_TIMESTAMP = timestamp_key
120
+ _LOGGER_SUFFIX_ID = 1
121
+ suffix_id = 1
122
+
123
+ if suffix_id == 1:
124
+ file_name = f"request {timestamp_key}.log"
125
+ logger_name = f"LLM Request {timestamp_key}"
126
+ else:
127
+ file_name = f"request {timestamp_key}_{suffix_id}.log"
128
+ logger_name = f"LLM Request {timestamp_key}_{suffix_id}"
129
+
130
+ file_path = self._logger_save_path / file_name
131
+ logger = getLogger(logger_name)
105
132
  logger.setLevel(DEBUG)
106
133
  handler = FileHandler(file_path, encoding="utf-8")
107
134
  handler.setLevel(DEBUG)
@@ -21,6 +21,7 @@ from .text_segment import (
21
21
  TextPosition,
22
22
  TextSegment,
23
23
  combine_text_segments,
24
+ find_block_depth,
24
25
  incision_between,
25
26
  search_text_segments,
26
27
  )
@@ -47,6 +47,7 @@ def search_inline_segments(text_segments: Iterable[TextSegment]) -> Generator["I
47
47
  inline_segment = _pop_stack_data(stack_data)
48
48
  stack_data = None
49
49
  if inline_segment:
50
+ inline_segment.id = 0
50
51
  yield inline_segment
51
52
 
52
53
  if stack_data is None:
@@ -73,6 +74,7 @@ def search_inline_segments(text_segments: Iterable[TextSegment]) -> Generator["I
73
74
  if stack_data is not None:
74
75
  inline_segment = _pop_stack_data(stack_data)
75
76
  if inline_segment:
77
+ inline_segment.id = 0
76
78
  yield inline_segment
77
79
 
78
80
 
@@ -115,7 +117,7 @@ class InlineSegment:
115
117
  self._child_tag2ids: dict[str, list[int]] = {}
116
118
  self._child_tag2count: dict[str, int] = {}
117
119
 
118
- next_temp_id: int = 0
120
+ next_temp_id: int = 1
119
121
  terms = nest((child.parent.tag, child) for child in children if isinstance(child, InlineSegment))
120
122
 
121
123
  for tag, child_terms in terms.items():
@@ -162,6 +164,14 @@ class InlineSegment:
162
164
  elif isinstance(child, InlineSegment):
163
165
  yield from child
164
166
 
167
+ def clone(self) -> "InlineSegment":
168
+ cloned_segment = InlineSegment(
169
+ depth=len(self._parent_stack),
170
+ children=[child.clone() for child in self._children],
171
+ )
172
+ cloned_segment.id = self.id
173
+ return cloned_segment
174
+
165
175
  def recreate_ids(self, id_generator: IDGenerator) -> None:
166
176
  self._child_tag2count.clear()
167
177
  self._child_tag2ids.clear()
@@ -4,7 +4,7 @@ from enum import Enum, auto
4
4
  from typing import Self
5
5
  from xml.etree.ElementTree import Element
6
6
 
7
- from ..xml import expand_left_element_texts, expand_right_element_texts, is_inline_tag, normalize_text_in_element
7
+ from ..xml import expand_left_element_texts, expand_right_element_texts, is_inline_element, normalize_text_in_element
8
8
 
9
9
 
10
10
  class TextPosition(Enum):
@@ -33,10 +33,6 @@ class TextSegment:
33
33
  def block_parent(self) -> Element:
34
34
  return self.parent_stack[self.block_depth - 1]
35
35
 
36
- @property
37
- def xml_text(self) -> str:
38
- return "".join(_expand_xml_texts(self))
39
-
40
36
  def strip_block_parents(self) -> Self:
41
37
  self.parent_stack = self.parent_stack[self.block_depth - 1 :]
42
38
  self.block_depth = 1
@@ -104,7 +100,7 @@ def search_text_segments(root: Element) -> Generator[TextSegment, None, None]:
104
100
  def _search_text_segments(stack: list[Element], element: Element) -> Generator[TextSegment, None, None]:
105
101
  text = normalize_text_in_element(element.text)
106
102
  next_stack = stack + [element]
107
- next_block_depth = _find_block_depth(next_stack)
103
+ next_block_depth = find_block_depth(next_stack)
108
104
 
109
105
  if text is not None:
110
106
  yield TextSegment(
@@ -129,12 +125,11 @@ def _search_text_segments(stack: list[Element], element: Element) -> Generator[T
129
125
  )
130
126
 
131
127
 
132
- def _find_block_depth(parent_stack: list[Element]) -> int:
128
+ def find_block_depth(parent_stack: list[Element]) -> int:
133
129
  index: int = 0
134
- for i in range(len(parent_stack) - 1, -1, -1):
135
- if not is_inline_tag(parent_stack[i].tag):
130
+ for i in range(len(parent_stack)):
131
+ if not is_inline_element(parent_stack[i]):
136
132
  index = i
137
- break
138
133
  return index + 1 # depth is a count not index
139
134
 
140
135
 
@@ -8,22 +8,6 @@ def element_fingerprint(element: Element) -> str:
8
8
  return f"<{element.tag} {' '.join(attrs)}/>"
9
9
 
10
10
 
11
- def unwrap_parents(element: Element) -> tuple[Element, list[Element]]:
12
- parents: list[Element] = []
13
- while True:
14
- if len(element) != 1:
15
- break
16
- child = element[0]
17
- if not element.text:
18
- break
19
- if not child.tail:
20
- break
21
- parents.append(element)
22
- element = child
23
- element.tail = None
24
- return element, parents
25
-
26
-
27
11
  def id_in_element(element: Element) -> int | None:
28
12
  id_str = element.get(ID_KEY, None)
29
13
  if id_str is None:
@@ -0,0 +1,2 @@
1
+ from . import language
2
+ from .translator import FillFailedEvent, translate
@@ -6,8 +6,8 @@ EPUB 数据结构与 XML 的编码/解码转换
6
6
 
7
7
  from xml.etree.ElementTree import Element
8
8
 
9
- from .epub.metadata import MetadataField
10
- from .epub.toc import Toc
9
+ from ..epub import Toc
10
+ from ..epub.metadata import MetadataField
11
11
 
12
12
 
13
13
  def encode_toc(toc: Toc) -> Element:
@@ -1,6 +1,6 @@
1
1
  from xml.etree.ElementTree import Element
2
2
 
3
- from .xml import iter_with_stack
3
+ from ..xml import iter_with_stack
4
4
 
5
5
  _QUOTE_MAPPING = {
6
6
  # 法语引号
@@ -5,7 +5,7 @@ from importlib.metadata import version as get_package_version
5
5
  from os import PathLike
6
6
  from pathlib import Path
7
7
 
8
- from .epub import (
8
+ from ..epub import (
9
9
  Zip,
10
10
  read_metadata,
11
11
  read_toc,
@@ -13,12 +13,12 @@ from .epub import (
13
13
  write_metadata,
14
14
  write_toc,
15
15
  )
16
+ from ..llm import LLM
17
+ from ..xml import XMLLikeNode, deduplicate_ids_in_element, find_first
18
+ from ..xml_translator import FillFailedEvent, SubmitKind, TranslationTask, XMLTranslator
16
19
  from .epub_transcode import decode_metadata, decode_toc_list, encode_metadata, encode_toc_list
17
- from .llm import LLM
18
20
  from .punctuation import unwrap_french_quotes
19
- from .xml import XMLLikeNode, deduplicate_ids_in_element, find_first
20
21
  from .xml_interrupter import XMLInterrupter
21
- from .xml_translator import FillFailedEvent, SubmitKind, TranslationTask, XMLTranslator
22
22
 
23
23
 
24
24
  class _ElementType(Enum):
@@ -40,7 +40,8 @@ def translate(
40
40
  submit: SubmitKind,
41
41
  user_prompt: str | None = None,
42
42
  max_retries: int = 5,
43
- max_group_tokens: int = 1200,
43
+ max_group_tokens: int = 2600,
44
+ concurrency: int = 1,
44
45
  llm: LLM | None = None,
45
46
  translation_llm: LLM | None = None,
46
47
  fill_llm: LLM | None = None,
@@ -62,7 +63,7 @@ def translate(
62
63
  ignore_translated_error=False,
63
64
  max_retries=max_retries,
64
65
  max_fill_displaying_errors=10,
65
- max_group_tokens=max_group_tokens,
66
+ max_group_score=max_group_tokens,
66
67
  cache_seed_content=f"{_get_version()}:{target_language}",
67
68
  )
68
69
  with Zip(
@@ -92,6 +93,7 @@ def translate(
92
93
  current_progress = 0.0
93
94
 
94
95
  for translated_elem, context in translator.translate_elements(
96
+ concurrency=concurrency,
95
97
  interrupt_source_text_segments=interrupter.interrupt_source_text_segments,
96
98
  interrupt_translated_text_segments=interrupter.interrupt_translated_text_segments,
97
99
  interrupt_block_element=interrupter.interrupt_block_element,
@@ -1,9 +1,13 @@
1
1
  from collections.abc import Generator, Iterable
2
2
  from typing import cast
3
- from xml.etree.ElementTree import Element
3
+ from xml.etree.ElementTree import Element, tostring
4
4
 
5
- from .segment import TextSegment
6
- from .utils import ensure_list, normalize_whitespace
5
+ from bs4 import BeautifulSoup
6
+ from mathml2latex.mathml import process_mathml
7
+
8
+ from ..segment import TextSegment, combine_text_segments, find_block_depth
9
+ from ..utils import ensure_list
10
+ from ..xml import clone_element
7
11
 
8
12
  _ID_KEY = "__XML_INTERRUPTER_ID"
9
13
  _MATH_TAG = "math"
@@ -37,8 +41,10 @@ class XMLInterrupter:
37
41
  def interrupt_block_element(self, element: Element) -> Element:
38
42
  interrupted_element = self._placeholder2interrupted.pop(id(element), None)
39
43
  if interrupted_element is None:
44
+ element.attrib.pop(_ID_KEY, None)
40
45
  return element
41
46
  else:
47
+ interrupted_element.attrib.pop(_ID_KEY, None)
42
48
  return interrupted_element
43
49
 
44
50
  def _expand_source_text_segment(self, text_segment: TextSegment):
@@ -81,14 +87,18 @@ class XMLInterrupter:
81
87
  _ID_KEY: cast(str, interrupted_element.get(_ID_KEY)),
82
88
  },
83
89
  )
90
+ interrupted_display = interrupted_element.get("display", None)
91
+ if interrupted_display is not None:
92
+ placeholder_element.set("display", interrupted_display)
93
+
84
94
  raw_parent_stack = text_segment.parent_stack[:interrupted_index]
85
95
  parent_stack = raw_parent_stack + [placeholder_element]
86
96
  merged_text_segment = TextSegment(
87
- text="".join(t.text for t in text_segments),
97
+ text=self._render_latex(text_segments),
88
98
  parent_stack=parent_stack,
89
99
  left_common_depth=text_segments[0].left_common_depth,
90
100
  right_common_depth=text_segments[-1].right_common_depth,
91
- block_depth=len(parent_stack),
101
+ block_depth=find_block_depth(parent_stack),
92
102
  position=text_segments[0].position,
93
103
  )
94
104
  self._placeholder2interrupted[id(placeholder_element)] = interrupted_element
@@ -116,8 +126,8 @@ class XMLInterrupter:
116
126
  # 原始栈退光,仅留下相对 interrupted 元素的栈,这种格式与 translated 要求一致
117
127
  text_segment.left_common_depth = max(0, text_segment.left_common_depth - interrupted_index)
118
128
  text_segment.right_common_depth = max(0, text_segment.right_common_depth - interrupted_index)
119
- text_segment.block_depth = 1
120
129
  text_segment.parent_stack = text_segment.parent_stack[interrupted_index:]
130
+ text_segment.block_depth = find_block_depth(text_segment.parent_stack)
121
131
 
122
132
  return merged_text_segment
123
133
 
@@ -129,37 +139,51 @@ class XMLInterrupter:
129
139
  break
130
140
  return interrupted_index
131
141
 
142
+ def _render_latex(self, text_segments: list[TextSegment]) -> str:
143
+ math_element, _ = next(combine_text_segments(text_segments))
144
+ while math_element.tag != _MATH_TAG:
145
+ if len(math_element) == 0:
146
+ return ""
147
+ math_element = math_element[0]
148
+
149
+ math_element = clone_element(math_element)
150
+ math_element.attrib.pop(_ID_KEY, None)
151
+ math_element.tail = None
152
+ latex: str | None = None
153
+ try:
154
+ mathml_str = tostring(math_element, encoding="unicode")
155
+ soup = BeautifulSoup(mathml_str, "html.parser")
156
+ latex = process_mathml(soup)
157
+ except Exception:
158
+ pass
159
+
160
+ if latex is None:
161
+ latex = "".join(t.text for t in text_segments)
162
+ elif math_element.get("display", None) == "inline":
163
+ latex = f"${latex}$"
164
+ else:
165
+ latex = f"$${latex}$$"
166
+
167
+ return f" {latex} "
168
+
132
169
  def _expand_translated_text_segment(self, text_segment: TextSegment):
133
- interrupted_id = text_segment.block_parent.attrib.pop(_ID_KEY, None)
170
+ parent_element = text_segment.parent_stack[-1]
171
+ interrupted_id = parent_element.attrib.pop(_ID_KEY, None)
134
172
  if interrupted_id is None:
135
173
  yield text_segment
136
174
  return
137
175
 
138
- raw_text_segments = self._raw_text_segments.pop(interrupted_id, None)
139
- if not raw_text_segments:
176
+ if parent_element is text_segment.block_parent:
177
+ # Block-level math, need to be hidden
140
178
  return
141
179
 
142
- raw_block = raw_text_segments[0].parent_stack[0]
143
- if not self._is_inline_math(raw_block):
180
+ raw_text_segments = self._raw_text_segments.pop(interrupted_id, None)
181
+ if not raw_text_segments:
182
+ yield text_segment
144
183
  return
145
184
 
146
185
  for raw_text_segment in raw_text_segments:
186
+ text_basic_parent_stack = text_segment.parent_stack[:-1]
147
187
  raw_text_segment.block_parent.attrib.pop(_ID_KEY, None)
188
+ raw_text_segment.parent_stack = text_basic_parent_stack + raw_text_segment.parent_stack
148
189
  yield raw_text_segment
149
-
150
- def _has_no_math_texts(self, element: Element):
151
- if element.tag == _MATH_TAG:
152
- return True
153
- if element.text and normalize_whitespace(element.text).strip():
154
- return False
155
- for child_element in element:
156
- if not self._has_no_math_texts(child_element):
157
- return False
158
- if child_element.tail and normalize_whitespace(child_element.tail).strip():
159
- return False
160
- return True
161
-
162
- def _is_inline_math(self, element: Element) -> bool:
163
- if element.tag != _MATH_TAG:
164
- return False
165
- return element.get("display", "").lower() != "block"
@@ -1,6 +1,6 @@
1
1
  from .const import *
2
2
  from .deduplication import *
3
- from .firendly import *
3
+ from .friendly import *
4
4
  from .inline import *
5
5
  from .utils import *
6
6
  from .xml import *
@@ -1,6 +1,9 @@
1
+ from xml.etree.ElementTree import Element
2
+
1
3
  # HTML inline-level elements
2
4
  # Reference: https://developer.mozilla.org/en-US/docs/Web/HTML/Inline_elements
3
5
  # Reference: https://developer.mozilla.org/en-US/docs/Glossary/Inline-level_content
6
+ # Reference: https://developer.mozilla.org/en-US/docs/MathML/Element
4
7
  _HTML_INLINE_TAGS = frozenset(
5
8
  (
6
9
  # Inline text semantics
@@ -59,9 +62,52 @@ _HTML_INLINE_TAGS = frozenset(
59
62
  "del",
60
63
  "ins",
61
64
  "slot",
65
+ # MathML elements
66
+ # Token elements
67
+ "mi", # identifier
68
+ "mn", # number
69
+ "mo", # operator
70
+ "ms", # string literal
71
+ "mspace", # space
72
+ "mtext", # text
73
+ # General layout
74
+ "menclose", # enclosed content
75
+ "merror", # syntax error message
76
+ "mfenced", # parentheses (deprecated)
77
+ "mfrac", # fraction
78
+ "mpadded", # space around content
79
+ "mphantom", # invisible content
80
+ "mroot", # radical with index
81
+ "mrow", # grouped sub-expressions
82
+ "msqrt", # square root
83
+ "mstyle", # style change
84
+ # Scripts and limits
85
+ "mmultiscripts", # prescripts and tensor indices
86
+ "mover", # overscript
87
+ "mprescripts", # prescripts separator
88
+ "msub", # subscript
89
+ "msubsup", # subscript-superscript pair
90
+ "msup", # superscript
91
+ "munder", # underscript
92
+ "munderover", # underscript-overscript pair
93
+ # Table math
94
+ "mtable", # table or matrix
95
+ "mtr", # row in table or matrix
96
+ "mtd", # cell in table or matrix
97
+ # Semantic annotations
98
+ "annotation", # data annotation
99
+ "annotation-xml", # XML annotation
100
+ "semantics", # semantic annotation container
101
+ # Other
102
+ "maction", # bind actions to sub-expressions (deprecated)
62
103
  )
63
104
  )
64
105
 
65
106
 
66
- def is_inline_tag(tag: str) -> bool:
67
- return tag.lower() in _HTML_INLINE_TAGS
107
+ def is_inline_element(element: Element) -> bool:
108
+ if element.tag.lower() in _HTML_INLINE_TAGS:
109
+ return True
110
+ display = element.get("display", None)
111
+ if display is not None and display.lower() == "inline":
112
+ return True
113
+ return False
@@ -0,0 +1,52 @@
1
+ from collections import deque
2
+ from collections.abc import Callable, Iterable
3
+ from concurrent.futures import Future, ThreadPoolExecutor
4
+ from typing import TypeVar
5
+
6
+ P = TypeVar("P")
7
+ R = TypeVar("R")
8
+
9
+
10
+ def run_concurrency(
11
+ parameters: Iterable[P],
12
+ execute: Callable[[P], R],
13
+ concurrency: int,
14
+ ) -> Iterable[R]:
15
+ assert concurrency >= 1, "the concurrency must be at least 1"
16
+ # Fast path: concurrency == 1, no thread overhead
17
+ if concurrency == 1:
18
+ for param in parameters:
19
+ yield execute(param)
20
+ return
21
+
22
+ executor = ThreadPoolExecutor(max_workers=concurrency)
23
+ did_shutdown = False
24
+ try:
25
+ futures: deque[Future[R]] = deque()
26
+ params_iter = iter(parameters)
27
+ for _ in range(concurrency):
28
+ try:
29
+ param = next(params_iter)
30
+ future = executor.submit(execute, param)
31
+ futures.append(future)
32
+ except StopIteration:
33
+ break
34
+
35
+ while futures:
36
+ future = futures.popleft()
37
+ yield future.result()
38
+ try:
39
+ param = next(params_iter)
40
+ new_future = executor.submit(execute, param)
41
+ futures.append(new_future)
42
+ except StopIteration:
43
+ pass
44
+
45
+ except KeyboardInterrupt:
46
+ executor.shutdown(wait=False, cancel_futures=True)
47
+ did_shutdown = True
48
+ raise
49
+
50
+ finally:
51
+ if not did_shutdown:
52
+ executor.shutdown(wait=True)