epub-translator 0.1.1__py3-none-any.whl → 0.1.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- epub_translator/__init__.py +9 -2
- epub_translator/data/fill.jinja +143 -38
- epub_translator/epub/__init__.py +1 -1
- epub_translator/epub/metadata.py +122 -0
- epub_translator/epub/spines.py +3 -2
- epub_translator/epub/zip.py +11 -9
- epub_translator/epub_transcode.py +108 -0
- epub_translator/llm/__init__.py +1 -0
- epub_translator/llm/context.py +109 -0
- epub_translator/llm/core.py +32 -113
- epub_translator/llm/executor.py +25 -31
- epub_translator/llm/increasable.py +1 -1
- epub_translator/llm/types.py +0 -3
- epub_translator/punctuation.py +34 -0
- epub_translator/segment/__init__.py +26 -0
- epub_translator/segment/block_segment.py +124 -0
- epub_translator/segment/common.py +29 -0
- epub_translator/segment/inline_segment.py +356 -0
- epub_translator/{xml_translator → segment}/text_segment.py +7 -72
- epub_translator/segment/utils.py +43 -0
- epub_translator/translator.py +152 -184
- epub_translator/utils.py +33 -0
- epub_translator/xml/__init__.py +3 -0
- epub_translator/xml/const.py +1 -0
- epub_translator/xml/deduplication.py +3 -3
- epub_translator/xml/inline.py +67 -0
- epub_translator/xml/self_closing.py +182 -0
- epub_translator/xml/utils.py +42 -0
- epub_translator/xml/xml.py +7 -0
- epub_translator/xml/xml_like.py +8 -33
- epub_translator/xml_interrupter.py +165 -0
- epub_translator/xml_translator/__init__.py +3 -3
- epub_translator/xml_translator/callbacks.py +34 -0
- epub_translator/xml_translator/{const.py → common.py} +0 -1
- epub_translator/xml_translator/hill_climbing.py +104 -0
- epub_translator/xml_translator/stream_mapper.py +253 -0
- epub_translator/xml_translator/submitter.py +352 -91
- epub_translator/xml_translator/translator.py +182 -114
- epub_translator/xml_translator/validation.py +458 -0
- {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/METADATA +134 -21
- epub_translator-0.1.4.dist-info/RECORD +68 -0
- epub_translator/epub/placeholder.py +0 -53
- epub_translator/iter_sync.py +0 -24
- epub_translator/xml_translator/fill.py +0 -128
- epub_translator/xml_translator/format.py +0 -282
- epub_translator/xml_translator/fragmented.py +0 -125
- epub_translator/xml_translator/group.py +0 -183
- epub_translator/xml_translator/progressive_locking.py +0 -256
- epub_translator/xml_translator/utils.py +0 -29
- epub_translator-0.1.1.dist-info/RECORD +0 -58
- {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/LICENSE +0 -0
- {epub_translator-0.1.1.dist-info → epub_translator-0.1.4.dist-info}/WHEEL +0 -0
epub_translator/llm/core.py
CHANGED
|
@@ -1,104 +1,18 @@
|
|
|
1
1
|
import datetime
|
|
2
|
-
import
|
|
3
|
-
import json
|
|
4
|
-
import uuid
|
|
5
|
-
from collections.abc import Callable, Generator
|
|
2
|
+
from collections.abc import Generator
|
|
6
3
|
from importlib.resources import files
|
|
7
4
|
from logging import DEBUG, FileHandler, Formatter, Logger, getLogger
|
|
8
5
|
from os import PathLike
|
|
9
6
|
from pathlib import Path
|
|
10
|
-
from typing import Self
|
|
11
7
|
|
|
12
8
|
from jinja2 import Environment, Template
|
|
13
9
|
from tiktoken import Encoding, get_encoding
|
|
14
10
|
|
|
15
11
|
from ..template import create_env
|
|
12
|
+
from .context import LLMContext
|
|
16
13
|
from .executor import LLMExecutor
|
|
17
14
|
from .increasable import Increasable
|
|
18
|
-
from .types import Message
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
class LLMContext:
|
|
22
|
-
"""Context manager for LLM requests with transactional caching."""
|
|
23
|
-
|
|
24
|
-
def __init__(
|
|
25
|
-
self,
|
|
26
|
-
executor: LLMExecutor,
|
|
27
|
-
cache_path: Path | None,
|
|
28
|
-
) -> None:
|
|
29
|
-
self._executor = executor
|
|
30
|
-
self._cache_path = cache_path
|
|
31
|
-
self._context_id = uuid.uuid4().hex[:12]
|
|
32
|
-
self._temp_files: list[Path] = []
|
|
33
|
-
|
|
34
|
-
def __enter__(self) -> Self:
|
|
35
|
-
return self
|
|
36
|
-
|
|
37
|
-
def __exit__(self, exc_type, exc_val, exc_tb) -> None:
|
|
38
|
-
if exc_type is None:
|
|
39
|
-
# Success: commit all temporary cache files
|
|
40
|
-
self._commit()
|
|
41
|
-
else:
|
|
42
|
-
# Failure: rollback (delete) all temporary cache files
|
|
43
|
-
self._rollback()
|
|
44
|
-
|
|
45
|
-
def request(
|
|
46
|
-
self,
|
|
47
|
-
input: str | list[Message],
|
|
48
|
-
parser: Callable[[str], R] = lambda x: x,
|
|
49
|
-
max_tokens: int | None = None,
|
|
50
|
-
) -> R:
|
|
51
|
-
messages: list[Message]
|
|
52
|
-
if isinstance(input, str):
|
|
53
|
-
messages = [Message(role=MessageRole.USER, message=input)]
|
|
54
|
-
else:
|
|
55
|
-
messages = input
|
|
56
|
-
|
|
57
|
-
cache_key: str | None = None
|
|
58
|
-
if self._cache_path is not None:
|
|
59
|
-
cache_key = self._compute_messages_hash(messages)
|
|
60
|
-
permanent_cache_file = self._cache_path / f"{cache_key}.txt"
|
|
61
|
-
if permanent_cache_file.exists():
|
|
62
|
-
cached_content = permanent_cache_file.read_text(encoding="utf-8")
|
|
63
|
-
return parser(cached_content)
|
|
64
|
-
|
|
65
|
-
temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
|
|
66
|
-
if temp_cache_file.exists():
|
|
67
|
-
cached_content = temp_cache_file.read_text(encoding="utf-8")
|
|
68
|
-
return parser(cached_content)
|
|
69
|
-
|
|
70
|
-
# Make the actual request
|
|
71
|
-
response = self._executor.request(
|
|
72
|
-
messages=messages,
|
|
73
|
-
parser=lambda x: x,
|
|
74
|
-
max_tokens=max_tokens,
|
|
75
|
-
)
|
|
76
|
-
|
|
77
|
-
# Save to temporary cache if cache_path is set
|
|
78
|
-
if self._cache_path is not None and cache_key is not None:
|
|
79
|
-
temp_cache_file = self._cache_path / f"{cache_key}.{self._context_id}.txt"
|
|
80
|
-
temp_cache_file.write_text(response, encoding="utf-8")
|
|
81
|
-
self._temp_files.append(temp_cache_file)
|
|
82
|
-
|
|
83
|
-
return parser(response)
|
|
84
|
-
|
|
85
|
-
def _compute_messages_hash(self, messages: list[Message]) -> str:
|
|
86
|
-
messages_dict = [{"role": msg.role.value, "message": msg.message} for msg in messages]
|
|
87
|
-
messages_json = json.dumps(messages_dict, ensure_ascii=False, sort_keys=True)
|
|
88
|
-
return hashlib.sha512(messages_json.encode("utf-8")).hexdigest()
|
|
89
|
-
|
|
90
|
-
def _commit(self) -> None:
|
|
91
|
-
for temp_file in self._temp_files:
|
|
92
|
-
if temp_file.exists():
|
|
93
|
-
# Remove the .[context-id].txt suffix to get permanent name
|
|
94
|
-
permanent_name = temp_file.name.rsplit(".", 2)[0] + ".txt"
|
|
95
|
-
permanent_file = temp_file.parent / permanent_name
|
|
96
|
-
temp_file.rename(permanent_file)
|
|
97
|
-
|
|
98
|
-
def _rollback(self) -> None:
|
|
99
|
-
for temp_file in self._temp_files:
|
|
100
|
-
if temp_file.exists():
|
|
101
|
-
temp_file.unlink()
|
|
15
|
+
from .types import Message
|
|
102
16
|
|
|
103
17
|
|
|
104
18
|
class LLM:
|
|
@@ -108,42 +22,28 @@ class LLM:
|
|
|
108
22
|
url: str,
|
|
109
23
|
model: str,
|
|
110
24
|
token_encoding: str,
|
|
111
|
-
cache_path: PathLike | None = None,
|
|
112
25
|
timeout: float | None = None,
|
|
113
26
|
top_p: float | tuple[float, float] | None = None,
|
|
114
27
|
temperature: float | tuple[float, float] | None = None,
|
|
115
28
|
retry_times: int = 5,
|
|
116
29
|
retry_interval_seconds: float = 6.0,
|
|
117
|
-
|
|
30
|
+
cache_path: PathLike | str | None = None,
|
|
31
|
+
log_dir_path: PathLike | str | None = None,
|
|
118
32
|
) -> None:
|
|
119
33
|
prompts_path = Path(str(files("epub_translator"))) / "data"
|
|
120
34
|
self._templates: dict[str, Template] = {}
|
|
121
35
|
self._encoding: Encoding = get_encoding(token_encoding)
|
|
122
36
|
self._env: Environment = create_env(prompts_path)
|
|
123
|
-
self.
|
|
124
|
-
self.
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
self._cache_path = Path(cache_path)
|
|
128
|
-
if not self._cache_path.exists():
|
|
129
|
-
self._cache_path.mkdir(parents=True, exist_ok=True)
|
|
130
|
-
elif not self._cache_path.is_dir():
|
|
131
|
-
self._cache_path = None
|
|
132
|
-
|
|
133
|
-
if log_dir_path is not None:
|
|
134
|
-
self._logger_save_path = Path(log_dir_path)
|
|
135
|
-
if not self._logger_save_path.exists():
|
|
136
|
-
self._logger_save_path.mkdir(parents=True, exist_ok=True)
|
|
137
|
-
elif not self._logger_save_path.is_dir():
|
|
138
|
-
self._logger_save_path = None
|
|
37
|
+
self._top_p: Increasable = Increasable(top_p)
|
|
38
|
+
self._temperature: Increasable = Increasable(temperature)
|
|
39
|
+
self._cache_path: Path | None = self._ensure_dir_path(cache_path)
|
|
40
|
+
self._logger_save_path: Path | None = self._ensure_dir_path(log_dir_path)
|
|
139
41
|
|
|
140
42
|
self._executor = LLMExecutor(
|
|
141
43
|
url=url,
|
|
142
44
|
model=model,
|
|
143
45
|
api_key=key,
|
|
144
46
|
timeout=timeout,
|
|
145
|
-
top_p=Increasable(top_p),
|
|
146
|
-
temperature=Increasable(temperature),
|
|
147
47
|
retry_times=retry_times,
|
|
148
48
|
retry_interval_seconds=retry_interval_seconds,
|
|
149
49
|
create_logger=self._create_logger,
|
|
@@ -153,20 +53,29 @@ class LLM:
|
|
|
153
53
|
def encoding(self) -> Encoding:
|
|
154
54
|
return self._encoding
|
|
155
55
|
|
|
156
|
-
def context(self) -> LLMContext:
|
|
56
|
+
def context(self, cache_seed_content: str | None = None) -> LLMContext:
|
|
157
57
|
return LLMContext(
|
|
158
58
|
executor=self._executor,
|
|
159
59
|
cache_path=self._cache_path,
|
|
60
|
+
cache_seed_content=cache_seed_content,
|
|
61
|
+
top_p=self._top_p,
|
|
62
|
+
temperature=self._temperature,
|
|
160
63
|
)
|
|
161
64
|
|
|
162
65
|
def request(
|
|
163
66
|
self,
|
|
164
67
|
input: str | list[Message],
|
|
165
|
-
parser: Callable[[str], R] = lambda x: x,
|
|
166
68
|
max_tokens: int | None = None,
|
|
167
|
-
|
|
69
|
+
temperature: float | None = None,
|
|
70
|
+
top_p: float | None = None,
|
|
71
|
+
) -> str:
|
|
168
72
|
with self.context() as ctx:
|
|
169
|
-
return ctx.request(
|
|
73
|
+
return ctx.request(
|
|
74
|
+
input=input,
|
|
75
|
+
max_tokens=max_tokens,
|
|
76
|
+
temperature=temperature,
|
|
77
|
+
top_p=top_p,
|
|
78
|
+
)
|
|
170
79
|
|
|
171
80
|
def template(self, template_name: str) -> Template:
|
|
172
81
|
template = self._templates.get(template_name, None)
|
|
@@ -175,6 +84,16 @@ class LLM:
|
|
|
175
84
|
self._templates[template_name] = template
|
|
176
85
|
return template
|
|
177
86
|
|
|
87
|
+
def _ensure_dir_path(self, path: PathLike | str | None) -> Path | None:
|
|
88
|
+
if path is None:
|
|
89
|
+
return None
|
|
90
|
+
dir_path = Path(path)
|
|
91
|
+
if not dir_path.exists():
|
|
92
|
+
dir_path.mkdir(parents=True, exist_ok=True)
|
|
93
|
+
elif not dir_path.is_dir():
|
|
94
|
+
return None
|
|
95
|
+
return dir_path.resolve()
|
|
96
|
+
|
|
178
97
|
def _create_logger(self) -> Logger | None:
|
|
179
98
|
if self._logger_save_path is None:
|
|
180
99
|
return None
|
epub_translator/llm/executor.py
CHANGED
|
@@ -2,14 +2,12 @@ from collections.abc import Callable
|
|
|
2
2
|
from io import StringIO
|
|
3
3
|
from logging import Logger
|
|
4
4
|
from time import sleep
|
|
5
|
-
from typing import cast
|
|
6
5
|
|
|
7
6
|
from openai import OpenAI
|
|
8
7
|
from openai.types.chat import ChatCompletionMessageParam
|
|
9
8
|
|
|
10
9
|
from .error import is_retry_error
|
|
11
|
-
from .
|
|
12
|
-
from .types import Message, MessageRole, R
|
|
10
|
+
from .types import Message, MessageRole
|
|
13
11
|
|
|
14
12
|
|
|
15
13
|
class LLMExecutor:
|
|
@@ -19,16 +17,12 @@ class LLMExecutor:
|
|
|
19
17
|
url: str,
|
|
20
18
|
model: str,
|
|
21
19
|
timeout: float | None,
|
|
22
|
-
top_p: Increasable,
|
|
23
|
-
temperature: Increasable,
|
|
24
20
|
retry_times: int,
|
|
25
21
|
retry_interval_seconds: float,
|
|
26
22
|
create_logger: Callable[[], Logger | None],
|
|
27
23
|
) -> None:
|
|
28
24
|
self._model_name: str = model
|
|
29
25
|
self._timeout: float | None = timeout
|
|
30
|
-
self._top_p: Increasable = top_p
|
|
31
|
-
self._temperature: Increasable = temperature
|
|
32
26
|
self._retry_times: int = retry_times
|
|
33
27
|
self._retry_interval_seconds: float = retry_interval_seconds
|
|
34
28
|
self._create_logger: Callable[[], Logger | None] = create_logger
|
|
@@ -38,15 +32,29 @@ class LLMExecutor:
|
|
|
38
32
|
timeout=timeout,
|
|
39
33
|
)
|
|
40
34
|
|
|
41
|
-
def request(
|
|
42
|
-
|
|
35
|
+
def request(
|
|
36
|
+
self,
|
|
37
|
+
messages: list[Message],
|
|
38
|
+
max_tokens: int | None,
|
|
39
|
+
temperature: float | None,
|
|
40
|
+
top_p: float | None,
|
|
41
|
+
cache_key: str | None,
|
|
42
|
+
) -> str:
|
|
43
|
+
response: str = ""
|
|
43
44
|
last_error: Exception | None = None
|
|
44
45
|
did_success = False
|
|
45
|
-
top_p: Increaser = self._top_p.context()
|
|
46
|
-
temperature: Increaser = self._temperature.context()
|
|
47
46
|
logger = self._create_logger()
|
|
48
47
|
|
|
49
48
|
if logger is not None:
|
|
49
|
+
parameters: list[str] = [
|
|
50
|
+
f"\t\ntemperature={temperature}",
|
|
51
|
+
f"\t\ntop_p={top_p}",
|
|
52
|
+
f"\t\nmax_tokens={max_tokens}",
|
|
53
|
+
]
|
|
54
|
+
if cache_key is not None:
|
|
55
|
+
parameters.append(f"\t\ncache_key={cache_key}")
|
|
56
|
+
|
|
57
|
+
logger.debug(f"[[Parameters]]:{''.join(parameters)}\n")
|
|
50
58
|
logger.debug(f"[[Request]]:\n{self._input2str(messages)}\n")
|
|
51
59
|
|
|
52
60
|
try:
|
|
@@ -54,8 +62,8 @@ class LLMExecutor:
|
|
|
54
62
|
try:
|
|
55
63
|
response = self._invoke_model(
|
|
56
64
|
input_messages=messages,
|
|
57
|
-
|
|
58
|
-
|
|
65
|
+
temperature=temperature,
|
|
66
|
+
top_p=top_p,
|
|
59
67
|
max_tokens=max_tokens,
|
|
60
68
|
)
|
|
61
69
|
if logger is not None:
|
|
@@ -71,22 +79,8 @@ class LLMExecutor:
|
|
|
71
79
|
sleep(self._retry_interval_seconds)
|
|
72
80
|
continue
|
|
73
81
|
|
|
74
|
-
|
|
75
|
-
|
|
76
|
-
did_success = True
|
|
77
|
-
break
|
|
78
|
-
|
|
79
|
-
except Exception as err:
|
|
80
|
-
last_error = err
|
|
81
|
-
warn_message = f"request failed with parsing error, retrying... ({i + 1} times)"
|
|
82
|
-
if logger is not None:
|
|
83
|
-
logger.warning(warn_message)
|
|
84
|
-
print(warn_message)
|
|
85
|
-
top_p.increase()
|
|
86
|
-
temperature.increase()
|
|
87
|
-
if self._retry_interval_seconds > 0.0 and i < self._retry_times:
|
|
88
|
-
sleep(self._retry_interval_seconds)
|
|
89
|
-
continue
|
|
82
|
+
did_success = True
|
|
83
|
+
break
|
|
90
84
|
|
|
91
85
|
except KeyboardInterrupt as err:
|
|
92
86
|
if last_error is not None and logger is not None:
|
|
@@ -99,7 +93,7 @@ class LLMExecutor:
|
|
|
99
93
|
else:
|
|
100
94
|
raise last_error
|
|
101
95
|
|
|
102
|
-
return
|
|
96
|
+
return response
|
|
103
97
|
|
|
104
98
|
def _input2str(self, input: str | list[Message]) -> str:
|
|
105
99
|
if isinstance(input, str):
|
|
@@ -133,7 +127,7 @@ class LLMExecutor:
|
|
|
133
127
|
top_p: float | None,
|
|
134
128
|
temperature: float | None,
|
|
135
129
|
max_tokens: int | None,
|
|
136
|
-
):
|
|
130
|
+
) -> str:
|
|
137
131
|
messages: list[ChatCompletionMessageParam] = []
|
|
138
132
|
for item in input_messages:
|
|
139
133
|
if item.role == MessageRole.SYSTEM:
|
|
@@ -21,7 +21,7 @@ class Increasable:
|
|
|
21
21
|
param = float(param)
|
|
22
22
|
if isinstance(param, float):
|
|
23
23
|
param = (param, param)
|
|
24
|
-
if isinstance(param, tuple):
|
|
24
|
+
if isinstance(param, (tuple, list)):
|
|
25
25
|
if len(param) != 2:
|
|
26
26
|
raise ValueError(f"Expected a tuple of length 2, got {len(param)}")
|
|
27
27
|
begin, end = param
|
epub_translator/llm/types.py
CHANGED
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
from xml.etree.ElementTree import Element
|
|
2
|
+
|
|
3
|
+
from .xml import iter_with_stack
|
|
4
|
+
|
|
5
|
+
_QUOTE_MAPPING = {
|
|
6
|
+
# 法语引号
|
|
7
|
+
"«": "",
|
|
8
|
+
"»": "",
|
|
9
|
+
"‹": "«",
|
|
10
|
+
"›": "»",
|
|
11
|
+
# 中文书书名号
|
|
12
|
+
"《": "",
|
|
13
|
+
"》": "",
|
|
14
|
+
"〈": "《",
|
|
15
|
+
"〉": "》",
|
|
16
|
+
}
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def _strip_quotes(text: str):
|
|
20
|
+
for char in text:
|
|
21
|
+
mapped = _QUOTE_MAPPING.get(char, None)
|
|
22
|
+
if mapped is None:
|
|
23
|
+
yield char
|
|
24
|
+
elif mapped:
|
|
25
|
+
yield mapped
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
def unwrap_french_quotes(element: Element) -> Element:
    """Remove one level of quoting from every text node in *element*'s subtree.

    Outer quotes are stripped and inner quotes are promoted to the outer
    form (see ``_QUOTE_MAPPING``), applied to both the ``text`` and ``tail``
    of each descendant. Mutates *element* in place and returns it.
    """
    def _cleaned(value: str) -> str:
        return "".join(_strip_quotes(value))

    for _, node in iter_with_stack(element):
        if node.text:
            node.text = _cleaned(node.text)
        if node.tail:
            node.tail = _cleaned(node.tail)
    return element
|
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from .block_segment import (
|
|
2
|
+
BlockContentError,
|
|
3
|
+
BlockError,
|
|
4
|
+
BlockExpectedIDsError,
|
|
5
|
+
BlockSegment,
|
|
6
|
+
BlockSubmitter,
|
|
7
|
+
BlockUnexpectedIDError,
|
|
8
|
+
BlockWrongTagError,
|
|
9
|
+
)
|
|
10
|
+
from .common import FoundInvalidIDError
|
|
11
|
+
from .inline_segment import (
|
|
12
|
+
InlineError,
|
|
13
|
+
InlineExpectedIDsError,
|
|
14
|
+
InlineLostIDError,
|
|
15
|
+
InlineSegment,
|
|
16
|
+
InlineUnexpectedIDError,
|
|
17
|
+
InlineWrongTagCountError,
|
|
18
|
+
search_inline_segments,
|
|
19
|
+
)
|
|
20
|
+
from .text_segment import (
|
|
21
|
+
TextPosition,
|
|
22
|
+
TextSegment,
|
|
23
|
+
combine_text_segments,
|
|
24
|
+
incision_between,
|
|
25
|
+
search_text_segments,
|
|
26
|
+
)
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from collections.abc import Generator
|
|
2
|
+
from dataclasses import dataclass
|
|
3
|
+
from typing import cast
|
|
4
|
+
from xml.etree.ElementTree import Element
|
|
5
|
+
|
|
6
|
+
from .common import FoundInvalidIDError, validate_id_in_element
|
|
7
|
+
from .inline_segment import InlineError, InlineSegment
|
|
8
|
+
from .text_segment import TextSegment
|
|
9
|
+
from .utils import IDGenerator, id_in_element
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
@dataclass
class BlockSubmitter:
    """Pairs an inline segment's source text with its corresponding submitted element."""
    # The inline segment's assigned integer ID.
    id: int
    # Text segments from the original document belonging to this inline segment.
    origin_text_segments: list[TextSegment]
    # The matching child of the submitted tree, run through
    # InlineSegment.assign_attributes.
    submitted_element: Element


@dataclass
class BlockWrongTagError:
    """A validated element's tag differs from the expected segment tag."""
    # (block_id, block_element) for the mismatched segment; None means the root element.
    block: tuple[int, Element] | None
    expected_tag: str
    instead_tag: str


@dataclass
class BlockUnexpectedIDError:
    """The validated tree contains an ID that no known inline segment uses."""
    id: int
    element: Element


@dataclass
class BlockExpectedIDsError:
    """Expected segment IDs that were never seen while walking the validated tree."""
    # Remaining expected IDs mapped to their parent elements.
    id2element: dict[int, Element]


@dataclass
class BlockContentError:
    """A matched child element failed its inline-level validation."""
    id: int
    element: Element
    # Errors yielded by InlineSegment.validate for this child.
    errors: list[InlineError | FoundInvalidIDError]


# Union of all block-level validation error shapes yielded by BlockSegment.validate.
BlockError = BlockWrongTagError | BlockUnexpectedIDError | BlockExpectedIDsError | BlockContentError
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class BlockSegment:
    """A block-level group of inline segments under one root element.

    On construction, every inline segment (and its nested children, via
    ``recreate_ids``) receives a fresh integer ID from a shared
    ``IDGenerator`` so IDs are unique within the block. The block can then
    build an element tree (``create_element``), check a returned tree
    against the expected structure (``validate``), and extract matched
    children from a submitted tree (``submit``).
    """

    def __init__(self, root_tag: str, inline_segments: list[InlineSegment]) -> None:
        # Re-number all segments so IDs are unique within this block.
        id_generator = IDGenerator()
        for inline_segment in inline_segments:
            inline_segment.id = id_generator.next_id()
            inline_segment.recreate_ids(id_generator)

        self._root_tag: str = root_tag
        self._inline_segments: list[InlineSegment] = inline_segments
        # IDs were just assigned above, so the cast from int | None is safe.
        self._id2inline_segment: dict[int, InlineSegment] = dict((cast(int, s.id), s) for s in self._inline_segments)

    def __iter__(self) -> Generator[InlineSegment, None, None]:
        """Iterate over the block's inline segments in order."""
        yield from self._inline_segments

    def create_element(self) -> Element:
        """Build a fresh element tree: root tag with one child per inline segment."""
        root_element = Element(self._root_tag)
        for inline_segment in self._inline_segments:
            root_element.append(inline_segment.create_element())
        return root_element

    def validate(self, validated_element: Element) -> Generator[BlockError | FoundInvalidIDError, None, None]:
        """Yield every structural mismatch between *validated_element* and this block.

        Checks, in order: the root tag, each child's ID attribute, the
        child's tag against the expected segment, the child's inline-level
        content, and finally any expected IDs that never appeared.
        """
        if validated_element.tag != self._root_tag:
            yield BlockWrongTagError(
                block=None,
                expected_tag=self._root_tag,
                instead_tag=validated_element.tag,
            )

        # Track which expected IDs have not yet been matched by a child.
        remain_expected_elements: dict[int, Element] = dict(
            (id, inline_segment.parent) for id, inline_segment in self._id2inline_segment.items()
        )
        for child_validated_element in validated_element:
            element_id = validate_id_in_element(child_validated_element)
            if isinstance(element_id, FoundInvalidIDError):
                # Missing or non-integer ID attribute on the child.
                yield element_id
            else:
                inline_segment = self._id2inline_segment.get(element_id, None)
                if inline_segment is None:
                    yield BlockUnexpectedIDError(
                        id=element_id,
                        element=child_validated_element,
                    )
                else:
                    if inline_segment.parent.tag != child_validated_element.tag:
                        yield BlockWrongTagError(
                            block=(cast(int, inline_segment.id), inline_segment.parent),
                            expected_tag=inline_segment.parent.tag,
                            instead_tag=child_validated_element.tag,
                        )

                    # Even on a tag mismatch the ID is considered matched,
                    # and the child's content is still validated.
                    remain_expected_elements.pop(element_id, None)
                    inline_errors = list(inline_segment.validate(child_validated_element))

                    if inline_errors:
                        yield BlockContentError(
                            id=element_id,
                            element=child_validated_element,
                            errors=inline_errors,
                        )

        if remain_expected_elements:
            yield BlockExpectedIDsError(id2element=remain_expected_elements)

    def submit(self, target: Element) -> Generator[BlockSubmitter, None, None]:
        """Yield a BlockSubmitter for each child of *target* matching a known segment ID.

        Children with no ID attribute or with an unknown ID are skipped
        silently (this is the lenient counterpart of ``validate``).
        """
        for child_element in target:
            element_id = id_in_element(child_element)
            if element_id is None:
                continue
            inline_segment = self._id2inline_segment.get(element_id, None)
            if inline_segment is None:
                continue
            inline_segment_id = inline_segment.id
            assert inline_segment_id is not None
            yield BlockSubmitter(
                id=inline_segment_id,
                origin_text_segments=list(inline_segment),
                submitted_element=inline_segment.assign_attributes(child_element),
            )
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
from xml.etree.ElementTree import Element
|
|
3
|
+
|
|
4
|
+
from ..xml import ID_KEY
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@dataclass
class FoundInvalidIDError(Exception):
    """An element's ID attribute is missing or cannot be parsed as an integer."""
    # The offending raw attribute value; None when the attribute was absent.
    invalid_id: str | None
    # The element carrying the invalid (or missing) ID.
    element: Element
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
def validate_id_in_element(element: Element, enable_no_id: bool = False) -> int | FoundInvalidIDError:
    """Read and parse the integer ID attribute (``ID_KEY``) of *element*.

    Returns the parsed ID on success. When the attribute is absent,
    returns ``-1`` if *enable_no_id* is set, otherwise a
    ``FoundInvalidIDError`` value. A non-integer attribute also yields a
    ``FoundInvalidIDError`` — errors are returned, never raised.
    """
    raw_id = element.get(ID_KEY)
    if raw_id is None:
        if enable_no_id:
            return -1
        return FoundInvalidIDError(
            invalid_id=None,
            element=element,
        )
    try:
        return int(raw_id)
    except ValueError:
        return FoundInvalidIDError(
            invalid_id=raw_id,
            element=element,
        )
|