batch-store 0.0.1 (tar.gz)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- batch_store-0.0.1/PKG-INFO +17 -0
- batch_store-0.0.1/batch_store.egg-info/PKG-INFO +17 -0
- batch_store-0.0.1/batch_store.egg-info/SOURCES.txt +7 -0
- batch_store-0.0.1/batch_store.egg-info/dependency_links.txt +1 -0
- batch_store-0.0.1/batch_store.egg-info/requires.txt +9 -0
- batch_store-0.0.1/batch_store.egg-info/top_level.txt +2 -0
- batch_store-0.0.1/batch_store.py +526 -0
- batch_store-0.0.1/pyproject.toml +24 -0
- batch_store-0.0.1/setup.cfg +4 -0

batch_store-0.0.1/PKG-INFO
@@ -0,0 +1,17 @@
+Metadata-Version: 2.4
+Name: batch_store
+Version: 0.0.1
+Summary: Batch Store
+Author-email: Sheldon Lee <sheldonlee@outlook.com>
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.12
+Description-Content-Type: text/markdown
+Requires-Dist: rich
+Requires-Dist: scikit-learn
+Provides-Extra: dev
+Requires-Dist: build; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: pytest-env; extra == "dev"
+Requires-Dist: mypy; extra == "dev"
+Requires-Dist: pre-commit; extra == "dev"

batch_store-0.0.1/batch_store.egg-info/PKG-INFO
@@ -0,0 +1,17 @@
+Metadata-Version: 2.4
+Name: batch_store
+Version: 0.0.1
+Summary: Batch Store
+Author-email: Sheldon Lee <sheldonlee@outlook.com>
+Classifier: Programming Language :: Python :: 3
+Classifier: Operating System :: OS Independent
+Requires-Python: >=3.12
+Description-Content-Type: text/markdown
+Requires-Dist: rich
+Requires-Dist: scikit-learn
+Provides-Extra: dev
+Requires-Dist: build; extra == "dev"
+Requires-Dist: pytest; extra == "dev"
+Requires-Dist: pytest-env; extra == "dev"
+Requires-Dist: mypy; extra == "dev"
+Requires-Dist: pre-commit; extra == "dev"

batch_store-0.0.1/batch_store.egg-info/dependency_links.txt
@@ -0,0 +1 @@
+

batch_store-0.0.1/batch_store.py
@@ -0,0 +1,526 @@
+import asyncio
+import json
+import logging
+import shutil
+import tempfile
+import time
+import traceback
+from abc import ABC
+from pathlib import Path
+from typing import Any, Dict
+
+import httpx
+import pandas as pd
+from pydantic import BaseModel, Field
+from pydantic_settings import BaseSettings, SettingsConfigDict
+from util_common.io_util import b64str2bytes, bytes2b64str, json2bytes
+from util_common.list_util import in_batches
+from util_common.logger import setup_logger
+from util_common.path import (
+    duplicate,
+    ensure_folder,
+    remove_file,
+    remove_folder,
+    sort_paths,
+    split_basename,
+)
+from util_common.pydantic_util import show_settings_as_env
+from util_common.singleton import singleton
+
+
+class BatchStoreSettings(BaseSettings):
+    """Workflow settings class that combines all settings."""
+
+    model_config = SettingsConfigDict(
+        case_sensitive=False,
+        extra="allow",
+    )
+
+    data_root: Path = Field(default=Path("/home/sheldon/repos/docparser_trainer/data"))
+    decompress_base_url: str = "http://192.168.8.251:28001"
+    unify_base_url: str = "http://192.168.8.251:28002"
+
+
+batch_store_settings = BatchStoreSettings()
+show_settings_as_env(batch_store_settings)
+# Configure logging
+setup_logger()
+logger = logging.getLogger(__file__)
+
+
+def log_file_size(file_type: str, file: bytes):
+    size_kb = len(file) / 1024
+    logger.info(f'Writing {file_type}, file size: {size_kb:.2f} KB')
+
+
+class TagRecord(BaseModel):
+    batch_name: str
+    sample_name: str
+    page_name: str
+    tag: str
+
+
+class Page(BaseModel):
+    name: str
+    page_dir: Path
+
+
+class Sample(BaseModel):
+    name: str
+    sample_dir: Path
+    pages: list[Page]
+
+
+async def _save_unified_file(
+    save_page_dir: Path,
+    unified_page: Dict[str, Any],
+):
+    """Helper method to save unified file and its associated data."""
+
+    ensure_folder(save_page_dir)
+    logger.info(f'Writing {save_page_dir.name} ===')
+
+    if unified_page.get('text'):
+        save_page_dir.joinpath('pure.txt').write_text(unified_page['text'])
+    else:
+        save_page_dir.joinpath('pure.txt').write_text('')
+    if unified_page.get('html'):
+        save_page_dir.joinpath('raw.html').write_text(unified_page['html'])
+    else:
+        save_page_dir.joinpath('raw.html').write_text(unified_page['text'])
+
+    for file_type, file_key, file_name, ext in [
+        ('Excel', 'xlsx', 'raw', 'xlsx'),
+        ('PDF', 'pdf', 'raw', 'pdf'),
+        ('norm_image', 'norm_image', 'norm', 'png'),
+        ('char_block', 'char_block', 'char_block', 'json'),
+        ('text_block', 'text_block', 'text_block', 'json'),
+        ('table_block', 'table_block', 'table_block', 'json'),
+    ]:
+        if unified_page.get(file_key):
+            if file_key in ['char_block', 'text_block', 'table_block']:
+                content = json2bytes(unified_page[file_key])
+            else:
+                content = b64str2bytes(unified_page[file_key]['file_b64str'])
+            log_file_size(file_type, content)
+            save_page_dir.joinpath(f'{file_name}.{ext}').write_bytes(content)
+
+
+async def request_unify_pages(file_name: str, content: bytes):
+    res = await make_request(
+        UnifyClient().get_client(),
+        "/unify-pages",
+        {
+            "file_name": file_name,
+            "file_b64str": bytes2b64str(content),
+            "task_settings": {
+                "target_results": [
+                    "text",
+                    "xlsx",
+                    "html",
+                    "pdf",
+                    "norm_image",
+                ],
+            },
+            "step_settings": [
+                {
+                    "step_name": "excel",
+                    "excel_rows_limit": 500,
+                    "excel_rows_limit_exceed_schema": "truncate",
+                    "delete_invalid_rows": False,
+                },
+            ],
+        },
+    )
+    success_pages = res.json()['success_pages']
+    failed_files = res.json()['failures']
+    return success_pages, failed_files
+
+
+async def unify_pages_and_save(file_path: Path, sample_dir: Path, failed_dir: Path) -> None:
+    failed_sample_path = failed_dir / file_path.parent.name
+    try:
+        success_pages, failed_files = await request_unify_pages(
+            file_path.name, file_path.read_bytes()
+        )
+
+        for unified_page in success_pages:
+            stem, ext = split_basename(file_path.name)
+            pid = unified_page["page_id"]
+            sheet_name = unified_page["sheet_name"]
+            if sheet_name:
+                save_page_dir = sample_dir / f'{stem}-{pid}-{sheet_name}.{ext}'
+            else:
+                save_page_dir = sample_dir / f'{stem}-{pid}.{ext}'
+            await _save_unified_file(save_page_dir, unified_page)
+
+        if len(failed_files) == 0:
+            if failed_sample_path.is_dir():
+                remove_folder(failed_sample_path)
+            elif failed_sample_path.is_file():
+                remove_file(failed_sample_path)
+
+        for failure in failed_files:
+            logger.error(
+                'File normalization failed: '
+                f'{failure["file_name"]}.{failure["page_id"]}: {failure["error_msg"]}'
+            )
+            duplicate(file_path.parent, failed_sample_path)
+
+    except Exception:
+        logger.error(f"Error in unify_pages: {traceback.format_exc()}")
+        duplicate(file_path.parent, failed_sample_path)
+
+
+async def decompress_sample_and_save(
+    sample_path: Path,
+    decompressed_dir: Path,
+    failed_dir: Path | None = None,
+):
+    save_dir = decompressed_dir / sample_path.name
+    if sample_path.is_dir():
+        try:
+            with tempfile.TemporaryDirectory() as tmpdir:
+                zip_path = Path(tmpdir) / f"{sample_path.name}.zip"
+                shutil.make_archive(str(zip_path.with_suffix('')), 'zip', root_dir=sample_path)
+                content = zip_path.read_bytes()
+        except Exception:
+            logger.error(f"Error zipping sample directory {sample_path}: {traceback.format_exc()}")
+            if failed_dir:
+                duplicate(sample_path, failed_dir / sample_path.name)
+            return []
+    else:
+        content = sample_path.read_bytes()
+
+    try:
+        payload = {
+            "file_name": sample_path.name,
+            "file_b64str": bytes2b64str(content),
+        }
+        res = await make_request(DecompressClient().get_client(), "/decompress", payload)
+        response_data = res.json()
+        ensure_folder(save_dir)
+        for file in response_data['success_files']:
+            file_name = file['file_name']
+            (save_dir / file_name).write_bytes(b64str2bytes(file['file_b64str']))
+        for file in response_data['failed_files']:
+            logger.error('Decompression failed: ' f'{file["file_name"]}. {file["error_msg"]}')
+            if failed_dir:
+                duplicate(sample_path, failed_dir / sample_path.name)
+    except Exception:
+        logger.error(f"Error in decompress_sample: {traceback.format_exc()}")
+        if failed_dir:
+            duplicate(sample_path, failed_dir / sample_path.name)
+        return []
+    return sort_paths(save_dir.iterdir())
+
+
+class HTTPXClient(ABC):
+    def __init__(self, base_url: str, timeout: int = 600, max_retries: int = 3):
+        self.base_url = base_url
+        self.timeout = timeout
+        self.max_retries = max_retries
+        self.client: httpx.AsyncClient | None = None
+        self._init_client()
+
+    def _init_client(self):
+        """Initialize the HTTP client."""
+        self.client = httpx.AsyncClient(
+            base_url=self.base_url,
+            timeout=self.timeout,
+            limits=httpx.Limits(max_keepalive_connections=5, max_connections=10),
+        )
+
+    async def close(self):
+        """Close the HTTP client."""
+        if self.client is not None:
+            await self.client.aclose()
+            self.client = None
+
+    def __del__(self):
+        """Ensure client is closed when object is garbage collected."""
+        if self.client is not None:
+            try:
+                # Check if there's already a running event loop
+                try:
+                    loop = asyncio.get_running_loop()
+                    # If there's a running loop, we can't use asyncio.run()
+                    # Create a task to close the client
+                    if not loop.is_closed():
+                        loop.create_task(self.client.aclose())
+                except RuntimeError:
+                    # No running event loop, safe to use asyncio.run()
+                    asyncio.run(self.client.aclose())
+            except Exception:
+                # Ignore any exceptions during cleanup to prevent the warning
+                pass
+
+    def get_client(self):
+        """Get the client, reinitializing it if necessary."""
+        if self.client is None:
+            self._init_client()
+        return self.client
+
+
+@singleton
+class DecompressClient(HTTPXClient):
+    def __init__(self):
+        super().__init__(batch_store_settings.decompress_base_url)
+
+
+@singleton
+class UnifyClient(HTTPXClient):
+    def __init__(self):
+        super().__init__(batch_store_settings.unify_base_url)
+
+
+async def close_all_clients():
+    """Close all singleton HTTP clients."""
+    await DecompressClient().close()
+    await UnifyClient().close()
+
+
+async def make_request(
+    client: httpx.AsyncClient | None, url: str, payload: Dict[str, Any], max_retries: int = 3
+) -> httpx.Response:
+    """Make an HTTP request with retries and proper error handling."""
+    if client is None:
+        raise ValueError("Client cannot be None")
+
+    for attempt in range(max_retries):
+        try:
+            response = await client.post(url=url, json=payload)
+            response.raise_for_status()
+            logger.info(f'Request {url} success')
+            return response
+        except httpx.HTTPStatusError as e:
+            logger.warning(
+                f"HTTP error occurred: {e.response.status_code} - {e.response.text}, retrying..."
+            )
+            if attempt == max_retries - 1:
+                raise e
+        except httpx.RequestError as e:
+            logger.warning(f"Request error occurred: {str(e)}, retrying...")
+            if attempt == max_retries - 1:
+                raise e
+        except Exception as e:
+            logger.warning(f"Unexpected error occurred: {traceback.format_exc()}, retrying...")
+            if attempt == max_retries - 1:
+                raise e
+        await asyncio.sleep(1 * (attempt + 1))  # Linear backoff between retries
+    raise Exception('Failed to make request')
+
+
+class BatchStore:
+    """
+    Purpose:
+        Run unified data preprocessing and storage for the samples in a batch, and provide a few helper utilities.
+
+    Directory layout:
+        When batch_dir is not given, the data root is laid out as:
+        |-workflow_settings.data_root
+        | |-batches
+        | | |-batch-1
+        | | | |-raw:
+        | | | | |-sample_1.zip
+        | | | | |-sample_2
+        | | | | |-sample_3.pdf
+        | | | | |-sample_4.xls
+        | | | | |-...
+
+        When batch_dir is given, the layout is:
+        |-batch_dir
+        | |-raw:
+        | | |-sample_1.zip
+        | | |-sample_2
+        | | |-sample_3.pdf
+        | | |-sample_4.xls
+        | | |-...
+
+    Dependencies:
+        decompression service
+        file normalization (unify) service
+
+    Usage:
+        Set the environment variable DATA_ROOT to the data root directory.
+        Set the environment variable DECOMPRESS_BASE_URL to the decompression service URL.
+        Set the environment variable UNIFY_BASE_URL to the file normalization service URL.
+
+    Example:
+        batch_store = BatchStore('batch-1')
+        asyncio.run(batch_store.preprocess_batch())
+    """
+
+    def __init__(self, batch_name: str | None = None, batch_dir: Path | None = None) -> None:
+        if batch_name is None and batch_dir is None:
+            raise ValueError('batch_name or batch_dir must be provided')
+        if batch_name is not None and batch_dir is not None:
+            if batch_dir.name != batch_name:
+                raise ValueError('batch_name and batch_dir must have the same name')
+        self.batch_name = batch_name or batch_dir.name  # type: ignore
+        self.batch_dir = batch_dir or batch_store_settings.data_root / 'batches' / self.batch_name
+        self.raw_dir = self.batch_dir / 'raw'
+        self.decompressed_dir = self.batch_dir / 'decompressed'
+        self.unified_dir = self.batch_dir / 'unified'
+        self.failed_dir = self.batch_dir / 'failed'
+        self.results_dir = self.batch_dir / 'results'
+        self.tag_dir = self.batch_dir / 'tag'
+        self.classified_dir = self.batch_dir / 'classified'
+        self.compare_dir = self.batch_dir / 'compare'
+        self.sessions_dir = self.batch_dir / 'sessions'
+
+    async def _preprocess_sample(
+        self, sample_path: Path, fix_broken: bool = False, check_empty: bool = False
+    ) -> None:
+        save_dir = self.unified_dir / sample_path.name
+        if fix_broken and save_dir.is_dir():
+            return
+        try:
+            if check_empty is True:
+                for page_dir in sort_paths(save_dir.iterdir()):
+                    if page_dir.joinpath('pure.txt').read_text().strip():
+                        continue
+                    if not page_dir.joinpath('raw.pdf').exists():
+                        continue
+                    logger.info(f'Checking empty page: {page_dir.name} ...')
+                    success_pages, failed_files = await request_unify_pages(
+                        page_dir.joinpath('raw.pdf').name, page_dir.joinpath('raw.pdf').read_bytes()
+                    )
+                    if len(success_pages) == 1:
+                        await _save_unified_file(page_dir, success_pages[0])
+                        if page_dir.joinpath('pure.txt').read_text().strip():
+                            logger.info(f'Extract page success: {page_dir.name}')
+                        else:
+                            logger.info(f'Empty page checked! {page_dir.name}')
+                    else:
+                        logger.error(f'Check empty page failed! {page_dir.name}')
+            else:
+                file_paths = await decompress_sample_and_save(
+                    sample_path, self.decompressed_dir, self.failed_dir
+                )
+                await asyncio.gather(
+                    *[
+                        unify_pages_and_save(file_path, save_dir, self.failed_dir)
+                        for file_path in file_paths
+                    ]
+                )
+        except Exception:
+            logger.error(f"Error processing sample {sample_path}: {traceback.format_exc()}")
+            duplicate(sample_path, self.failed_dir / sample_path.name)
+
+    async def preprocess_batch(
+        self,
+        concurrency: int = 1,
+        failed_only: bool = False,  # only process samples under the failed directory
+        fix_broken: bool = False,  # skip samples that have already been processed
+        check_empty: bool = False,  # check whether a page's pure.txt is empty and, if so, retry it once
+    ) -> None:
+        """Process all samples in the batch with proper error handling."""
+        # Ensure all required directories exist
+        for dir_path in [self.raw_dir, self.decompressed_dir, self.unified_dir, self.failed_dir]:
+            ensure_folder(dir_path)
+
+        try:
+            if failed_only:
+                sample_paths = sort_paths(self.failed_dir.iterdir())
+            else:
+                sample_paths = sort_paths(self.raw_dir.iterdir())
+
+            for i, batch in enumerate(in_batches(sample_paths, concurrency)):
+                start_time = time.time()
+                await asyncio.gather(
+                    *[
+                        self._preprocess_sample(sample_path, fix_broken, check_empty)
+                        for sample_path in batch
+                    ]
+                )
+                end_time = time.time()
+                logger.info(
+                    f'### Preprocess batch {i + 1} finished in {end_time - start_time:.2f} seconds'
+                )
+        except Exception as e:
+            logger.error(f"Error in preprocess_batch: {traceback.format_exc()}")
+            raise e
+        finally:
+            await close_all_clients()
+
+    def load_unified_samples(self) -> list[Sample]:
+        samples = []
+        for sample_dir in sort_paths(self.unified_dir.iterdir()):
+            pages = []
+            for page_dir in sort_paths(sample_dir.iterdir()):
+                pages.append(Page(name=page_dir.name, page_dir=page_dir))
+            samples.append(Sample(name=sample_dir.name, sample_dir=sample_dir, pages=pages))
+        return samples
+
+    def save_sample(self, content: bytes, file_name: str):
+        save_path = self.raw_dir / file_name
+        ensure_folder(save_path.parent)
+        save_path.write_bytes(content)
+
+    def save_sample_result(
+        self,
+        sample_name: str,
+        result: list[dict] | dict,
+        test_id: str,
+        session_id: int,
+    ):
+        save_dir = self.results_dir / test_id / f'{test_id}-{session_id}'
+        save_dir.mkdir(parents=True, exist_ok=True)
+        save_path = save_dir / f'{sample_name}.json'
+        save_path.write_bytes(json2bytes(result))
+
+    def get_sample_result(self, sample_name, test_id: str, session_id: int = 0):
+        result_path = self.results_dir / test_id / f'{test_id}-{session_id}' / f'{sample_name}.json'
+        return json.loads(result_path.read_text())
+
+    def get_compare_dir(self, test_id: str, session_id: int):
+        compare_dir = self.compare_dir / test_id / f'{test_id}-{session_id}'
+        compare_dir.mkdir(parents=True, exist_ok=True)
+        return compare_dir
+
+    def backup_tag(self, tag_name: str):
+        records = []
+        for sample_dir in sort_paths(self.unified_dir.iterdir()):
+            for page_dir in sort_paths(sample_dir.iterdir()):
+                tag_path = page_dir / f"tag-{tag_name}.json"
+                if tag_path.exists():
+                    record = TagRecord(
+                        batch_name=self.batch_name,
+                        sample_name=sample_dir.name,
+                        page_name=page_dir.name,
+                        tag=tag_path.read_text(),
+                    )
+                    records.append(record.model_dump())
+
+        df = pd.DataFrame.from_records(records)
+        logger.info(f'Saving {tag_name}.csv...')
+        logger.info(f'total {df.shape[0]} records')
+        logger.info(f'first 5 records: {df.head()}')
+        logger.info(f'last 5 records: {df.tail()}')
+        ensure_folder(self.tag_dir)
+        df.to_csv(f'{self.tag_dir}/{tag_name}.csv', index=False)

    def restore_tag(self, tag_name: str):
        df = pd.read_csv(f'{self.tag_dir}/{tag_name}.csv')
        for _, row in df.iterrows():
            sample_name = row['sample_name']
            page_name = row['page_name']
            tag = row['tag']
            save_path = self.unified_dir / sample_name / page_name / f'tag-{tag_name}.json'
            try:
                save_path.write_text(tag)  # type: ignore
            except Exception:
                logger.error(f'{sample_name} {page_name} {tag_name} restore failed, skip')


if __name__ == '__main__':
    batch_store = BatchStore('batch-temp')
    asyncio.run(
        batch_store.preprocess_batch(
            concurrency=1,
            failed_only=False,
            fix_broken=False,
            check_empty=False,
        )
    )
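
For reference, a minimal usage sketch based on the BatchStore docstring in batch_store.py above. The data root, service URLs, batch name, and concurrency value below are placeholders, and the decompression and unify HTTP services must already be running for preprocess_batch to do anything useful:

    import asyncio
    import os

    # BatchStoreSettings is a pydantic-settings BaseSettings with case_sensitive=False,
    # so these environment variables override data_root / decompress_base_url / unify_base_url.
    # They must be set before batch_store is imported, because the module instantiates
    # BatchStoreSettings() at import time.
    os.environ["DATA_ROOT"] = "/data"                             # placeholder path
    os.environ["DECOMPRESS_BASE_URL"] = "http://localhost:28001"  # placeholder URL
    os.environ["UNIFY_BASE_URL"] = "http://localhost:28002"       # placeholder URL

    from batch_store import BatchStore

    # Expects raw samples under <DATA_ROOT>/batches/batch-1/raw
    store = BatchStore("batch-1")
    asyncio.run(store.preprocess_batch(concurrency=2))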

batch_store-0.0.1/pyproject.toml
@@ -0,0 +1,24 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "batch_store"
+version = "0.0.1"
+description = "Batch Store"
+readme = "README.md"
+requires-python = ">=3.12"
+classifiers = ["Programming Language :: Python :: 3", "Operating System :: OS Independent"]
+dependencies = ["rich", "scikit-learn"]
+[[project.authors]]
+name = "Sheldon Lee"
+email = "sheldonlee@outlook.com"
+
+[project.optional-dependencies]
+dev = ["build", "pytest", "pytest-env", "mypy", "pre-commit"]
+
+[tool.sys-dependencies]
+apt = []
+
+[tool.setuptools]
+py-modules = ["batch_store", "settings"]