openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- openocr/__init__.py +35 -1
- openocr/configs/dataset/rec/evaluation.yaml +41 -0
- openocr/configs/dataset/rec/ltb.yaml +9 -0
- openocr/configs/dataset/rec/mjsynth.yaml +11 -0
- openocr/configs/dataset/rec/openvino.yaml +25 -0
- openocr/configs/dataset/rec/ost.yaml +17 -0
- openocr/configs/dataset/rec/synthtext.yaml +7 -0
- openocr/configs/dataset/rec/test.yaml +77 -0
- openocr/configs/dataset/rec/textocr.yaml +13 -0
- openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
- openocr/configs/dataset/rec/union14m_b.yaml +47 -0
- openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
- openocr/configs/rec/cmer/cmer.yml +127 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
- openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
- openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
- openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
- openocr/demo_gradio.py +28 -8
- openocr/demo_opendoc.py +572 -0
- openocr/demo_unirec.py +392 -0
- openocr/opendet/losses/__init__.py +5 -7
- openocr/opendet/preprocess/crop_resize.py +2 -1
- openocr/openocr.py +685 -0
- openocr/openrec/losses/__init__.py +8 -3
- openocr/openrec/losses/cmer_loss.py +12 -0
- openocr/openrec/losses/mdiff_loss.py +11 -0
- openocr/openrec/losses/unirec_loss.py +12 -0
- openocr/openrec/metrics/__init__.py +4 -1
- openocr/openrec/metrics/rec_metric_cmer.py +328 -0
- openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
- openocr/openrec/modeling/decoders/__init__.py +1 -0
- openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
- openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
- openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
- openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
- openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
- openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
- openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
- openocr/openrec/optimizer/__init__.py +4 -3
- openocr/openrec/optimizer/lr.py +49 -0
- openocr/openrec/postprocess/__init__.py +2 -0
- openocr/openrec/postprocess/abinet_postprocess.py +1 -1
- openocr/openrec/postprocess/ar_postprocess.py +1 -1
- openocr/openrec/postprocess/cmer_postprocess.py +86 -0
- openocr/openrec/postprocess/cppd_postprocess.py +1 -1
- openocr/openrec/postprocess/igtr_postprocess.py +1 -1
- openocr/openrec/postprocess/lister_postprocess.py +1 -1
- openocr/openrec/postprocess/mgp_postprocess.py +1 -1
- openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
- openocr/openrec/postprocess/smtr_postprocess.py +1 -1
- openocr/openrec/postprocess/srn_postprocess.py +1 -1
- openocr/openrec/postprocess/unirec_postprocess.py +58 -0
- openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
- openocr/openrec/preprocess/__init__.py +5 -0
- openocr/openrec/preprocess/ce_label_encode.py +1 -1
- openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
- openocr/openrec/preprocess/ctc_label_encode.py +1 -1
- openocr/openrec/preprocess/dptr_label_encode.py +177 -157
- openocr/openrec/preprocess/igtr_label_encode.py +4 -2
- openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
- openocr/openrec/preprocess/rec_aug.py +128 -2
- openocr/openrec/preprocess/resize.py +57 -0
- openocr/openrec/preprocess/unirec_label_encode.py +62 -0
- openocr/tools/data/__init__.py +78 -55
- openocr/tools/data/cmer_web_dataset.py +310 -0
- openocr/tools/data/native_size_dataset.py +753 -0
- openocr/tools/data/native_size_sampler.py +158 -0
- openocr/tools/data/ratio_dataset_tvresize.py +2 -0
- openocr/tools/data/ratio_sampler.py +2 -1
- openocr/tools/download/download_dataset.py +38 -0
- openocr/tools/download/utils.py +28 -0
- openocr/tools/download_example_images.py +236 -0
- openocr/tools/engine/trainer.py +155 -39
- openocr/tools/eval_rec_all_ch.py +2 -2
- openocr/tools/infer_det.py +20 -2
- openocr/tools/infer_doc.py +898 -0
- openocr/tools/infer_doc_onnx.py +1172 -0
- openocr/tools/infer_e2e.py +27 -10
- openocr/tools/infer_rec.py +64 -15
- openocr/tools/infer_unirec_onnx.py +730 -0
- openocr/tools/to_markdown.py +468 -0
- openocr/tools/utils/ckpt.py +17 -5
- openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
- openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
- openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
- openocr_python-0.0.9.dist-info/METADATA +0 -149
- /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
- {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
openocr/tools/data/__init__.py
CHANGED
|
@@ -19,13 +19,16 @@ DATASET_MODULES = {
|
|
|
19
19
|
'RatioDataSet': 'tools.data.ratio_dataset',
|
|
20
20
|
'RatioDataSetTest': 'tools.data.ratio_dataset_test',
|
|
21
21
|
'RatioDataSetTVResize': 'tools.data.ratio_dataset_tvresize',
|
|
22
|
-
'RatioDataSetTVResizeTest': 'tools.data.ratio_dataset_tvresize_test'
|
|
22
|
+
'RatioDataSetTVResizeTest': 'tools.data.ratio_dataset_tvresize_test',
|
|
23
|
+
'NaSizeDataSet': 'tools.data.native_size_dataset',
|
|
24
|
+
'CMERWebDataSet': 'tools.data.cmer_web_dataset',
|
|
23
25
|
}
|
|
24
26
|
|
|
25
27
|
# 定义支持的 Sampler 类及其对应的模块路径
|
|
26
28
|
SAMPLER_MODULES = {
|
|
27
29
|
'MultiScaleSampler': 'tools.data.multi_scale_sampler',
|
|
28
|
-
'RatioSampler': 'tools.data.ratio_sampler'
|
|
30
|
+
'RatioSampler': 'tools.data.ratio_sampler',
|
|
31
|
+
'NaSizeSampler': 'tools.data.native_size_sampler',
|
|
29
32
|
}
|
|
30
33
|
|
|
31
34
|
__all__ = [
|
|
@@ -33,9 +36,9 @@ __all__ = [
|
|
|
33
36
|
]
|
|
34
37
|
|
|
35
38
|
|
|
36
|
-
def build_dataloader(config, mode, logger, seed=None, epoch=
|
|
39
|
+
def build_dataloader(config, mode, logger, seed=None, epoch=1, task='rec'):
|
|
37
40
|
config = copy.deepcopy(config)
|
|
38
|
-
mode = mode.capitalize()
|
|
41
|
+
mode = mode.capitalize()
|
|
39
42
|
|
|
40
43
|
# 获取 dataset 配置
|
|
41
44
|
dataset_config = config[mode]['dataset']
|
|
@@ -53,61 +56,81 @@ def build_dataloader(config, mode, logger, seed=None, epoch=3, task='rec'):
|
|
|
53
56
|
|
|
54
57
|
# DataLoader 配置
|
|
55
58
|
loader_config = config[mode]['loader']
|
|
56
|
-
batch_size = loader_config['batch_size_per_card']
|
|
57
|
-
drop_last = loader_config['drop_last']
|
|
58
|
-
shuffle = loader_config['shuffle']
|
|
59
59
|
num_workers = loader_config['num_workers']
|
|
60
60
|
pin_memory = loader_config.get('pin_memory', False)
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
|
|
65
|
-
|
|
66
|
-
|
|
67
|
-
|
|
68
|
-
|
|
69
|
-
raise ValueError(
|
|
70
|
-
f'Unsupported sampler: {sampler_name}. Supported samplers: {list(SAMPLER_MODULES.keys())}'
|
|
71
|
-
)
|
|
72
|
-
|
|
73
|
-
sampler_module = importlib.import_module(SAMPLER_MODULES[sampler_name])
|
|
74
|
-
sampler_class = getattr(sampler_module, sampler_name)
|
|
75
|
-
batch_sampler = sampler_class(dataset, **sampler_config)
|
|
76
|
-
elif config['Global']['distributed'] and mode == 'Train':
|
|
77
|
-
sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)
|
|
78
|
-
|
|
79
|
-
if 'collate_fn' in loader_config:
|
|
80
|
-
from . import collate_fn
|
|
81
|
-
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
|
|
82
|
-
else:
|
|
83
|
-
collate_fn = None
|
|
84
|
-
|
|
85
|
-
if batch_sampler is None:
|
|
86
|
-
data_loader = DataLoader(
|
|
87
|
-
dataset=dataset,
|
|
88
|
-
sampler=sampler,
|
|
61
|
+
if module_name == 'CMERWebDataSet':
|
|
62
|
+
logger.info(f"Building WebLoader for {module_name} (IterableDataset mode)...")
|
|
63
|
+
import webdataset as wds
|
|
64
|
+
persistent = num_workers > 0
|
|
65
|
+
data_loader = wds.WebLoader(
|
|
66
|
+
dataset,
|
|
67
|
+
batch_size=None, # 必须为 None,因为 dataset yield 的已经是 batch
|
|
68
|
+
shuffle=False, # 外部不打乱,内部处理
|
|
89
69
|
num_workers=num_workers,
|
|
90
|
-
pin_memory=
|
|
91
|
-
|
|
92
|
-
|
|
93
|
-
drop_last=drop_last,
|
|
70
|
+
pin_memory=True,
|
|
71
|
+
prefetch_factor=4,
|
|
72
|
+
persistent_workers=persistent,
|
|
94
73
|
)
|
|
74
|
+
total_iter_steps = config['Global'].get('total_iter_steps', 1000000)
|
|
75
|
+
data_loader = data_loader.with_length(total_iter_steps)
|
|
76
|
+
return data_loader
|
|
95
77
|
else:
|
|
96
|
-
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
78
|
+
batch_size = loader_config['batch_size_per_card']
|
|
79
|
+
drop_last = loader_config['drop_last']
|
|
80
|
+
shuffle = loader_config['shuffle']
|
|
81
|
+
sampler = None
|
|
82
|
+
batch_sampler = None
|
|
83
|
+
if 'sampler' in config[mode]:
|
|
84
|
+
sampler_config = config[mode]['sampler']
|
|
85
|
+
sampler_name = sampler_config.pop('name')
|
|
86
|
+
|
|
87
|
+
if sampler_name not in SAMPLER_MODULES:
|
|
88
|
+
raise ValueError(
|
|
89
|
+
f'Unsupported sampler: {sampler_name}. Supported samplers: {list(SAMPLER_MODULES.keys())}'
|
|
90
|
+
)
|
|
91
|
+
|
|
92
|
+
sampler_module = importlib.import_module(SAMPLER_MODULES[sampler_name])
|
|
93
|
+
sampler_class = getattr(sampler_module, sampler_name)
|
|
94
|
+
batch_sampler = sampler_class(dataset, **sampler_config)
|
|
95
|
+
elif config['Global']['distributed'] and mode == 'Train':
|
|
96
|
+
sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)
|
|
97
|
+
|
|
98
|
+
if hasattr(dataset, 'collate_fn'):
|
|
99
|
+
collate_fn = dataset.collate_fn
|
|
100
|
+
logger.info(f'Using collate_fn defined in {mode} dataset.')
|
|
101
|
+
else:
|
|
102
|
+
if 'collate_fn' in loader_config:
|
|
103
|
+
from . import collate_fn
|
|
104
|
+
collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
|
|
105
|
+
else:
|
|
106
|
+
collate_fn = None
|
|
107
|
+
|
|
108
|
+
if batch_sampler is None:
|
|
109
|
+
data_loader = DataLoader(
|
|
110
|
+
dataset=dataset,
|
|
111
|
+
sampler=sampler,
|
|
112
|
+
num_workers=num_workers,
|
|
113
|
+
pin_memory=pin_memory,
|
|
114
|
+
collate_fn=collate_fn,
|
|
115
|
+
batch_size=batch_size,
|
|
116
|
+
drop_last=drop_last,
|
|
117
|
+
)
|
|
118
|
+
else:
|
|
119
|
+
data_loader = DataLoader(
|
|
120
|
+
dataset=dataset,
|
|
121
|
+
batch_sampler=batch_sampler,
|
|
122
|
+
num_workers=num_workers,
|
|
123
|
+
pin_memory=pin_memory,
|
|
124
|
+
collate_fn=collate_fn,
|
|
125
|
+
)
|
|
103
126
|
|
|
104
|
-
|
|
105
|
-
|
|
106
|
-
|
|
107
|
-
|
|
108
|
-
|
|
109
|
-
|
|
110
|
-
|
|
111
|
-
|
|
127
|
+
# 检查数据加载器是否为空
|
|
128
|
+
if len(data_loader) == 0:
|
|
129
|
+
logger.error(
|
|
130
|
+
f'No Images in {mode.lower()} dataloader. Please check:\n'
|
|
131
|
+
'\t1. The images num in the train label_file_list should be >= batch size.\n'
|
|
132
|
+
'\t2. The annotation file and path in the configuration are correct.\n'
|
|
133
|
+
'\t3. The BatchSize is not larger than the number of images.')
|
|
134
|
+
sys.exit()
|
|
112
135
|
|
|
113
|
-
|
|
136
|
+
return data_loader
|
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import glob
|
|
3
|
+
import json
|
|
4
|
+
import math
|
|
5
|
+
import random
|
|
6
|
+
import numpy as np
|
|
7
|
+
import pandas as pd
|
|
8
|
+
import torch
|
|
9
|
+
import webdataset as wds
|
|
10
|
+
from torch.utils.data import IterableDataset
|
|
11
|
+
from io import BytesIO
|
|
12
|
+
from PIL import Image
|
|
13
|
+
from functools import partial
|
|
14
|
+
from collections import Counter
|
|
15
|
+
from webdataset import handlers
|
|
16
|
+
import importlib
|
|
17
|
+
# Module-level tally of why samples were dropped, keyed by a short reason
# string (e.g. "nonpos_wh", "tokens", "bad_img_bytes"). The filtering and
# decoding helpers in this module increment it; nothing in this module resets it.
# NOTE(review): it is a plain module global, so with num_workers > 0 each
# DataLoader worker process holds its own copy; counts are per-process unless
# aggregated elsewhere.
_DROP_STATS = Counter()
|
|
19
|
+
def sanitize_keys(sample):
    """Return a copy of ``sample`` where every dotted key is also reachable by its suffix.

    For a key such as ``"000123.jpg"`` the value is additionally stored under
    ``"jpg"``, which lets later stages (``to_tuple("json", "jpg;png")``) look
    fields up by extension alone. Keys starting with a double underscore are
    ignored. An existing entry is never overwritten, so when several keys share
    a suffix the first one met in iteration order wins. The input mapping is not
    modified; a shallow copy is returned.
    """
    result = sample.copy()
    for name in list(sample):
        # Skip dunder metadata keys and keys without an extension.
        if name.startswith("__") or "." not in name:
            continue
        suffix = name.rsplit(".", 1)[-1]
        result.setdefault(suffix, sample[name])
    return result
|
|
33
|
+
# --- Helper Functions (Must be defined at module level for pickling in multiprocessing) ---
|
|
34
|
+
|
|
35
|
+
def keep_by_meta(sample, longside_max=12000, area_max=80_000_000, ar_max=20.0,
                 shortside_min=16, require_positive_wh=True, max_tokens=1536, require_tokens=True):
    """Decide from metadata alone whether ``sample`` should be kept.

    Checks are applied in order. The first failing one increments the matching
    ``_DROP_STATS`` counter and returns False:

    * ``nonpos_wh``: width or height <= 0 (only when ``require_positive_wh``)
    * ``longside``: ``max(w, h) > longside_max``
    * ``area``: ``w * h > area_max``
    * ``shortside``: ``min(w, h) < shortside_min``
    * ``ar``: ``w / h`` outside ``[1 / ar_max, ar_max]`` (a zero height counts as +inf)
    * ``no_tokens``: ``tokens`` missing (only when ``require_tokens``)
    * ``tokens``: ``int(tokens) > max_tokens``
    * ``bad_tokens_val``: ``tokens`` not int-convertible (only when ``require_tokens``)

    Any other exception is counted as ``exception`` and the sample is dropped.

    Returns:
        True to keep the sample, False to drop it.
    """
    try:
        width = int(sample.get("width", 0) or 0)
        height = int(sample.get("height", 0) or 0)

        if require_positive_wh and (width <= 0 or height <= 0):
            _DROP_STATS["nonpos_wh"] += 1
            return False

        long_side = max(width, height)
        short_side = min(width, height)
        aspect = width / height if height > 0 else math.inf

        if long_side > longside_max:
            reason = "longside"
        elif width * height > area_max:
            reason = "area"
        elif short_side < shortside_min:
            reason = "shortside"
        elif aspect > ar_max or aspect < 1.0 / ar_max:
            reason = "ar"
        else:
            reason = None
        if reason is not None:
            _DROP_STATS[reason] += 1
            return False

        token_count = sample.get("tokens", None)
        if token_count is None:
            if require_tokens:
                _DROP_STATS["no_tokens"] += 1
                return False
            return True

        try:
            if int(token_count) > max_tokens:
                _DROP_STATS["tokens"] += 1
                return False
        except Exception:
            # Unparseable token counts are only fatal when tokens are required.
            if require_tokens:
                _DROP_STATS["bad_tokens_val"] += 1
                return False
        return True
    except Exception:
        _DROP_STATS["exception"] += 1
        return False
|
|
66
|
+
|
|
67
|
+
def parse_json_tuple_meta_only(sample):
    """Flatten a ``(json_meta, image_bytes)`` pair into a metadata dict.

    ``json_meta`` may be raw bytes/bytearray (decoded as UTF-8), a JSON string,
    or an already-parsed mapping. The image bytes are passed through untouched;
    decoding happens later in the pipeline.

    Returns:
        dict with ``id``, ``img_bytes``, ``tex``, ``tokens`` (int), ``width``
        and ``height`` (int; missing/null -> 0) and ``category``.

    Raises:
        KeyError: if ``tex`` or ``tokens`` is absent.
        ValueError / TypeError: if ``tokens`` (or width/height) is not int-convertible.

    NOTE(review): because a missing ``tokens`` raises here, the ``no_tokens``
    branch of ``keep_by_meta`` is never reached for samples from this parser.
    """
    meta, raw_image = sample
    if isinstance(meta, (bytes, bytearray)):
        meta = meta.decode("utf-8")
    if isinstance(meta, str):
        meta = json.loads(meta)

    # Field order matches evaluation order, so the first missing key raises first.
    parsed = {}
    parsed["id"] = meta.get('id', '')
    parsed["img_bytes"] = raw_image
    parsed["tex"] = meta["tex"]
    parsed["tokens"] = int(meta["tokens"])
    parsed["width"] = int(meta.get("width", 0) or 0)
    parsed["height"] = int(meta.get("height", 0) or 0)
    parsed["category"] = meta.get('category', '')
    return parsed
|
|
81
|
+
|
|
82
|
+
def add_ar_bin(sample, k=5, clamp=2.0):
    """Tag ``sample`` with a quantised log2 aspect-ratio bucket.

    The bucket is ``round(k * log2(width / height))`` clipped to
    ``[-int(k * clamp), int(k * clamp)]`` and stored as ``sample["ar_bin"]``.
    A missing or zero height counts as ratio 1.0. The ratio is floored at 1e-6
    so a zero width does not reach ``log2(0)``.

    Returns:
        The same ``sample`` dict, mutated in place.
    """
    width = sample.get("width", 0)
    height = sample.get("height", 0)
    if height:
        ratio = float(width) / float(height)
    else:
        ratio = 1.0
    bucket = int(round(k * math.log2(max(ratio, 1e-6))))
    limit = int(k * clamp)
    sample["ar_bin"] = max(min(bucket, limit), -limit)
    return sample
|
|
89
|
+
|
|
90
|
+
def add_short_edge_bin(sample, se_bin_size=96, se_min=96, se_max=1536):
    """Tag ``sample`` with a short-edge size bucket as ``sample["se_bin"]``.

    The short edge ``min(width, height)`` is clamped to ``[se_min, se_max]``
    and integer-divided by ``se_bin_size``. A missing, zero or non-positive
    short edge is treated as ``se_min``.

    Returns:
        The same ``sample`` dict, mutated in place.
    """
    width = sample.get("width", 0)
    height = sample.get("height", 0)
    if width and height:
        short_edge = int(min(width, height))
    else:
        short_edge = 0
    if short_edge <= 0:
        short_edge = se_min
    short_edge = max(se_min, min(short_edge, se_max))
    sample["se_bin"] = int(short_edge // se_bin_size)
    return sample
|
|
97
|
+
|
|
98
|
+
def batch_to_inputs_decode_late_safe(batch, processor, max_length):
    """Decode the images of one raw batch and pass the survivors to ``processor``.

    Image bytes are decoded only at this point. A sample whose ``img_bytes`` is
    missing or empty is skipped and counted under ``"empty_img_bytes"`` in
    ``_DROP_STATS``. A sample whose bytes fail to open or decode is skipped and
    counted under ``"bad_img_bytes"``. Decoded images are converted to RGB.

    Args:
        batch: iterable of sample dicts; ``img_bytes`` and ``tex`` are read,
            ``id`` and ``category`` are optional.
        processor: callable invoked with ``images``, ``text``, ``ids``,
            ``categorys`` plus padding/truncation keyword arguments.
        max_length: forwarded to ``processor`` as ``max_length``.

    Returns:
        Whatever ``processor`` returns for the kept samples.

    Raises:
        handlers.SkipItem: when no image in the batch could be decoded.
            NOTE(review): how the pipeline reacts to this depends on the
            installed webdataset version; confirm it is skipped, not fatal.
    """
    kept = []
    for record in batch:
        payload = record.get("img_bytes", None)
        if not payload:
            _DROP_STATS["empty_img_bytes"] += 1
            continue
        try:
            picture = Image.open(BytesIO(payload))
            # Force a full decode now so corrupt/truncated data fails inside this try.
            picture.load()
            picture = picture.convert("RGB")
        except Exception:
            _DROP_STATS["bad_img_bytes"] += 1
            continue
        kept.append((record.get('id'), picture, record["tex"], record.get('category')))

    if not kept:
        raise handlers.SkipItem("all images in batch are broken/truncated")

    ids, images, texts, categorys = (list(column) for column in zip(*kept))
    return processor(
        images=images,
        text=texts,
        ids=ids,
        categorys=categorys,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
|
|
130
|
+
|
|
131
|
+
@wds.pipelinefilter
def bucket_len_shortedge_ratio(data_iter, batch_size, pool_size=4096, len_key="tokens",
                               se_key="se_bin", ar_key="ar_bin", len_bin_size=64, drop_last=True):
    """Batch a sample stream so each batch shares length, short-edge and aspect buckets.

    Samples are buffered until ``pool_size`` of them are collected (or the
    stream ends). Each buffer is grouped by
    ``(sample[len_key] // len_bin_size, sample[se_key], sample[ar_key])``, using
    pandas' sorted group order. Every group emits as many full batches of
    ``batch_size`` as it can. The per-group remainders of one buffer are pooled
    and emitted as mixed full batches. What is still left (< ``batch_size``)
    is yielded as a final short batch only when ``drop_last`` is False;
    otherwise it is discarded. This happens for each buffer, not only at the end
    of the stream.

    Used as a pipeline stage by calling it with every argument except
    ``data_iter`` (``wds.pipelinefilter`` curries that argument).

    Yields:
        Lists of the original sample dicts.
    """

    def _emit_batches(buffered):
        # Turn one buffer of samples into batches (see the docstring for the rules).
        if not buffered:
            return
        records = [
            {
                "idx": pos,
                "len_bin": item[len_key] // len_bin_size,
                "se_bin": item[se_key],
                "ar_bin": item[ar_key],
            }
            for pos, item in enumerate(buffered)
        ]
        frame = pd.DataFrame.from_records(records)

        remainder = []
        for _, group in frame.groupby(['len_bin', 'se_bin', 'ar_bin']):
            members = group["idx"].to_list()
            cutoff = len(members) - len(members) % batch_size
            for start in range(0, cutoff, batch_size):
                yield [buffered[j] for j in members[start:start + batch_size]]
            remainder.extend(members[cutoff:])

        # Cross-bucket batches made from the per-group leftovers.
        full_end = len(remainder) - len(remainder) % batch_size
        for start in range(0, full_end, batch_size):
            yield [buffered[j] for j in remainder[start:start + batch_size]]

        tail = remainder[full_end:]
        if tail and not drop_last:
            yield [buffered[j] for j in tail]

    buffered = []
    for sample in data_iter:
        buffered.append(sample)
        if len(buffered) >= pool_size:
            yield from _emit_batches(buffered)
            buffered = []
    if buffered:
        yield from _emit_batches(buffered)
|
|
167
|
+
|
|
168
|
+
# --- Main Dataset Class ---
|
|
169
|
+
|
|
170
|
+
class CMERWebDataSet(IterableDataset):
    """Iterable dataset streaming CMER samples from WebDataset ``.tar`` shards.

    Each item yielded by ``__iter__`` is already a processed batch (the output
    of the configured processor), so the surrounding loader must not batch again.

    Directory layouts expected under ``dataset.data_dir``:

    * train mode: ``epoch_0/*.tar``, ``epoch_1/*.tar``, ... consumed in order for
      ``dataset.epochs`` folders;
    * any other mode: sub-directories ("packages") each holding ``*.tar``
      shards, or, when there are no sub-directories, ``*.tar`` shards directly
      inside ``data_dir``.
    """

    def __init__(self, config, mode, logger, seed=None, epoch=None, task='rec'):
        """Read dataset/loader settings and instantiate the processor.

        Args:
            config: full config dict; ``config[mode]['dataset']`` and
                ``config[mode]['loader']`` are read.
            mode: config section key (e.g. ``'Train'`` / ``'Eval'``). Training
                mode is detected case-insensitively.
            logger: logger for progress and warning messages.
            seed, epoch, task: accepted for interface compatibility with the
                other datasets; unused here.
        """
        super(CMERWebDataSet, self).__init__()

        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.mode = mode
        self.logger = logger
        self.data_dir = dataset_config['data_dir']
        self.batch_size = loader_config['batch_size_per_card']
        self.shuffle = loader_config.get('shuffle', False)

        # Specific CMER params
        self.max_length = dataset_config.get('max_length', 256)
        self.shuffle_buffer = dataset_config.get('shuffle_buffer', 10000)
        self.pool_size = dataset_config.get('pool_size', 8192)
        # `mode` is the config section key and is typically capitalised
        # ('Train'), so a case-sensitive `== 'train'` never matched and training
        # silently defaulted to drop_last=False.
        self.drop_last = dataset_config.get('drop_last', self._is_train_mode())
        self.epochs = dataset_config.get('epochs', 1)

        processor_name = dataset_config.get('processor', 'CMERProcessor')
        module_path = dataset_config.get('processor_source', 'openrec.preprocess.cmer_label_encode')
        proc_module = importlib.import_module(module_path)
        ProcessorClass = getattr(proc_module, processor_name)
        # NOTE(review): processor args come from the *loader* section while the
        # processor name comes from the *dataset* section.
        proc_args = loader_config.get('processor_args', {})

        self.logger.info(f"Initializing {processor_name} with args: {proc_args}")
        self.processor = ProcessorClass(**proc_args)

        self.logger.info(f'Initialize CMER WebDataset: {self.data_dir} | Mode: {mode}')

    def _is_train_mode(self):
        """Return True when ``self.mode`` names the training split (case-insensitive)."""
        return str(self.mode).lower() == 'train'

    def _build_wds_pipeline(self, shards, epoch_idx=0):
        """Construct the WebDataset pipeline for one group of shards.

        Args:
            shards: a shard path or a list/tuple of shard paths.
            epoch_idx: index of the epoch folder the shards come from. Kept for
                backward compatibility. It is intentionally not passed to
                ``DataPipeline.with_epoch``: that argument is a sample/batch
                count, so ``with_epoch(0)`` repeated the shards forever and
                ``with_epoch(k)`` stopped after ``k`` batches.

        Returns:
            A ``wds.DataPipeline`` yielding processed batches in a single pass
            over its shards (split across nodes and workers).
        """
        if isinstance(shards, (list, tuple)):
            shard_list = shards
        else:
            shard_list = [shards]

        pipeline_stages = [
            wds.SimpleShardList(shard_list),
            wds.split_by_node,
            wds.split_by_worker,
            wds.tarfile_to_samples(),
            wds.shuffle(self.shuffle_buffer if self.shuffle else 0),
            wds.map(sanitize_keys),
            wds.to_tuple("json", "jpg;png"),
            wds.map(parse_json_tuple_meta_only),
        ]

        # Metadata filtering. It is gated on shuffle_buffer (not on the mode or
        # self.shuffle), so it also applies to eval data unless shuffle_buffer == 0.
        if self.shuffle_buffer != 0:
            pipeline_stages.append(
                wds.select(partial(
                    keep_by_meta,
                    longside_max=3840,
                    area_max=1536*1536,
                    ar_max=20.0,
                    shortside_min=0,
                    max_tokens=1536,
                    require_positive_wh=True,
                ))
            )

        # Bucketing and Batching
        pipeline_stages.extend([
            wds.map(add_ar_bin),
            wds.map(add_short_edge_bin),
            bucket_len_shortedge_ratio(
                batch_size=self.batch_size,
                pool_size=self.pool_size,
                len_key="tokens",
                se_key="se_bin",
                ar_key="ar_bin",
                len_bin_size=64,
                drop_last=self.drop_last,
            ),
            wds.map(partial(batch_to_inputs_decode_late_safe,
                            processor=self.processor,
                            max_length=self.max_length)),
        ])

        return wds.DataPipeline(*pipeline_stages)

    def __iter__(self):
        """Yield processed batches from every shard group, one group after another.

        Train mode chains ``epoch_0`` ... ``epoch_{epochs-1}``; missing or empty
        epoch folders are skipped with a warning. Other modes chain each package
        sub-directory (sorted by name for a deterministic order) or, when there
        are none, the ``*.tar`` files directly inside ``data_dir``.

        Raises:
            RuntimeError: if no shards were found at all.
        """
        all_datasets = []

        if self._is_train_mode():
            # Training logic: iterates through epoch folders (root/epoch_N/*.tar).
            for epoch_idx in range(self.epochs):
                epoch_path = os.path.join(self.data_dir, f"epoch_{epoch_idx}")
                train_shards = sorted(glob.glob(f"{epoch_path}/*.tar"))

                if not train_shards:
                    self.logger.warning(f"No .tar files found in {epoch_path}, skipping.")
                    continue

                all_datasets.append(self._build_wds_pipeline(train_shards, epoch_idx))
        elif os.path.isdir(self.data_dir):
            # Eval/Test logic: package sub-directories, else a flat shard folder.
            subdirs = [
                os.path.join(self.data_dir, name)
                for name in sorted(os.listdir(self.data_dir))
                if os.path.isdir(os.path.join(self.data_dir, name))
            ]

            if subdirs:
                for package_path in subdirs:
                    shards = sorted(glob.glob(f"{package_path}/*.tar"))
                    if shards:
                        all_datasets.append(self._build_wds_pipeline(shards))
            else:
                shards = sorted(glob.glob(f"{self.data_dir}/*.tar"))
                if shards:
                    all_datasets.append(self._build_wds_pipeline(shards))

        if not all_datasets:
            raise RuntimeError(f"No data found in {self.data_dir}")

        # Chain the per-group pipelines.
        for ds in all_datasets:
            yield from ds

    def __len__(self):
        """Return 0: the number of batches is unknown before iterating.

        NOTE(review): callers that need a length must provide it themselves
        (e.g. by wrapping the loader with an explicit length).
        """
        return 0
|