openocr-python 0.0.9__py3-none-any.whl → 0.1.0.dev0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. openocr/__init__.py +35 -1
  2. openocr/configs/dataset/rec/evaluation.yaml +41 -0
  3. openocr/configs/dataset/rec/ltb.yaml +9 -0
  4. openocr/configs/dataset/rec/mjsynth.yaml +11 -0
  5. openocr/configs/dataset/rec/openvino.yaml +25 -0
  6. openocr/configs/dataset/rec/ost.yaml +17 -0
  7. openocr/configs/dataset/rec/synthtext.yaml +7 -0
  8. openocr/configs/dataset/rec/test.yaml +77 -0
  9. openocr/configs/dataset/rec/textocr.yaml +13 -0
  10. openocr/configs/dataset/rec/textocr_horizontal.yaml +13 -0
  11. openocr/configs/dataset/rec/union14m_b.yaml +47 -0
  12. openocr/configs/dataset/rec/union14m_l_filtered.yaml +35 -0
  13. openocr/configs/rec/cmer/cmer.yml +127 -0
  14. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_base.yml +152 -0
  15. openocr/configs/rec/mdiff4str/svtrv2_mdiffdecoder_small.yml +152 -0
  16. openocr/configs/rec/unirec/focalsvtr_ardecoder_unirec.yml +114 -0
  17. openocr/configs/rec/unirec/opendoc_pipeline.yml +105 -0
  18. openocr/demo_gradio.py +28 -8
  19. openocr/demo_opendoc.py +572 -0
  20. openocr/demo_unirec.py +392 -0
  21. openocr/opendet/losses/__init__.py +5 -7
  22. openocr/opendet/preprocess/crop_resize.py +2 -1
  23. openocr/openocr.py +685 -0
  24. openocr/openrec/losses/__init__.py +8 -3
  25. openocr/openrec/losses/cmer_loss.py +12 -0
  26. openocr/openrec/losses/mdiff_loss.py +11 -0
  27. openocr/openrec/losses/unirec_loss.py +12 -0
  28. openocr/openrec/metrics/__init__.py +4 -1
  29. openocr/openrec/metrics/rec_metric_cmer.py +328 -0
  30. openocr/openrec/modeling/cmer_modeling/modeling_cmer.py +643 -0
  31. openocr/openrec/modeling/decoders/__init__.py +1 -0
  32. openocr/openrec/modeling/decoders/ctc_decoder.py +1 -1
  33. openocr/openrec/modeling/decoders/dan_decoder.py +4 -4
  34. openocr/openrec/modeling/decoders/dptr_parseq_clip_b_decoder.py +1563 -1398
  35. openocr/openrec/modeling/decoders/mdiff_decoder.py +587 -0
  36. openocr/openrec/modeling/decoders/smtr_decoder.py +99 -48
  37. openocr/openrec/modeling/unirec_modeling/configuration_unirec.py +166 -0
  38. openocr/openrec/modeling/unirec_modeling/modeling_unirec.py +433 -0
  39. openocr/openrec/optimizer/__init__.py +4 -3
  40. openocr/openrec/optimizer/lr.py +49 -0
  41. openocr/openrec/postprocess/__init__.py +2 -0
  42. openocr/openrec/postprocess/abinet_postprocess.py +1 -1
  43. openocr/openrec/postprocess/ar_postprocess.py +1 -1
  44. openocr/openrec/postprocess/cmer_postprocess.py +86 -0
  45. openocr/openrec/postprocess/cppd_postprocess.py +1 -1
  46. openocr/openrec/postprocess/igtr_postprocess.py +1 -1
  47. openocr/openrec/postprocess/lister_postprocess.py +1 -1
  48. openocr/openrec/postprocess/mgp_postprocess.py +1 -1
  49. openocr/openrec/postprocess/nrtr_postprocess.py +2 -2
  50. openocr/openrec/postprocess/smtr_postprocess.py +1 -1
  51. openocr/openrec/postprocess/srn_postprocess.py +1 -1
  52. openocr/openrec/postprocess/unirec_postprocess.py +58 -0
  53. openocr/openrec/postprocess/visionlan_postprocess.py +1 -1
  54. openocr/openrec/preprocess/__init__.py +5 -0
  55. openocr/openrec/preprocess/ce_label_encode.py +1 -1
  56. openocr/openrec/preprocess/cmer_label_encode.py +1025 -0
  57. openocr/openrec/preprocess/ctc_label_encode.py +1 -1
  58. openocr/openrec/preprocess/dptr_label_encode.py +177 -157
  59. openocr/openrec/preprocess/igtr_label_encode.py +4 -2
  60. openocr/openrec/preprocess/mdiff_label_encode.py +312 -0
  61. openocr/openrec/preprocess/rec_aug.py +128 -2
  62. openocr/openrec/preprocess/resize.py +57 -0
  63. openocr/openrec/preprocess/unirec_label_encode.py +62 -0
  64. openocr/tools/data/__init__.py +78 -55
  65. openocr/tools/data/cmer_web_dataset.py +310 -0
  66. openocr/tools/data/native_size_dataset.py +753 -0
  67. openocr/tools/data/native_size_sampler.py +158 -0
  68. openocr/tools/data/ratio_dataset_tvresize.py +2 -0
  69. openocr/tools/data/ratio_sampler.py +2 -1
  70. openocr/tools/download/download_dataset.py +38 -0
  71. openocr/tools/download/utils.py +28 -0
  72. openocr/tools/download_example_images.py +236 -0
  73. openocr/tools/engine/trainer.py +155 -39
  74. openocr/tools/eval_rec_all_ch.py +2 -2
  75. openocr/tools/infer_det.py +20 -2
  76. openocr/tools/infer_doc.py +898 -0
  77. openocr/tools/infer_doc_onnx.py +1172 -0
  78. openocr/tools/infer_e2e.py +27 -10
  79. openocr/tools/infer_rec.py +64 -15
  80. openocr/tools/infer_unirec_onnx.py +730 -0
  81. openocr/tools/to_markdown.py +468 -0
  82. openocr/tools/utils/ckpt.py +17 -5
  83. openocr/tools/utils/opendoc_onnx_utils/utils.py +1052 -0
  84. openocr_python-0.1.0.dev0.dist-info/METADATA +324 -0
  85. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/RECORD +89 -45
  86. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/WHEEL +1 -1
  87. openocr_python-0.1.0.dev0.dist-info/entry_points.txt +2 -0
  88. openocr_python-0.0.9.dist-info/METADATA +0 -149
  89. /openocr_python-0.0.9.dist-info/LICENCE → /openocr_python-0.1.0.dev0.dist-info/licenses/LICENSE +0 -0
  90. {openocr_python-0.0.9.dist-info → openocr_python-0.1.0.dev0.dist-info}/top_level.txt +0 -0
@@ -19,13 +19,16 @@ DATASET_MODULES = {
19
19
  'RatioDataSet': 'tools.data.ratio_dataset',
20
20
  'RatioDataSetTest': 'tools.data.ratio_dataset_test',
21
21
  'RatioDataSetTVResize': 'tools.data.ratio_dataset_tvresize',
22
- 'RatioDataSetTVResizeTest': 'tools.data.ratio_dataset_tvresize_test'
22
+ 'RatioDataSetTVResizeTest': 'tools.data.ratio_dataset_tvresize_test',
23
+ 'NaSizeDataSet': 'tools.data.native_size_dataset',
24
+ 'CMERWebDataSet': 'tools.data.cmer_web_dataset',
23
25
  }
24
26
 
25
27
  # 定义支持的 Sampler 类及其对应的模块路径
26
28
  SAMPLER_MODULES = {
27
29
  'MultiScaleSampler': 'tools.data.multi_scale_sampler',
28
- 'RatioSampler': 'tools.data.ratio_sampler'
30
+ 'RatioSampler': 'tools.data.ratio_sampler',
31
+ 'NaSizeSampler': 'tools.data.native_size_sampler',
29
32
  }
30
33
 
31
34
  __all__ = [
@@ -33,9 +36,9 @@ __all__ = [
33
36
  ]
34
37
 
35
38
 
36
- def build_dataloader(config, mode, logger, seed=None, epoch=3, task='rec'):
39
+ def build_dataloader(config, mode, logger, seed=None, epoch=1, task='rec'):
37
40
  config = copy.deepcopy(config)
38
- mode = mode.capitalize() # 确保 mode 是首字母大写形式(Train/Eval/Test)
41
+ mode = mode.capitalize()
39
42
 
40
43
  # 获取 dataset 配置
41
44
  dataset_config = config[mode]['dataset']
@@ -53,61 +56,81 @@ def build_dataloader(config, mode, logger, seed=None, epoch=3, task='rec'):
53
56
 
54
57
  # DataLoader 配置
55
58
  loader_config = config[mode]['loader']
56
- batch_size = loader_config['batch_size_per_card']
57
- drop_last = loader_config['drop_last']
58
- shuffle = loader_config['shuffle']
59
59
  num_workers = loader_config['num_workers']
60
60
  pin_memory = loader_config.get('pin_memory', False)
61
-
62
- sampler = None
63
- batch_sampler = None
64
- if 'sampler' in config[mode]:
65
- sampler_config = config[mode]['sampler']
66
- sampler_name = sampler_config.pop('name')
67
-
68
- if sampler_name not in SAMPLER_MODULES:
69
- raise ValueError(
70
- f'Unsupported sampler: {sampler_name}. Supported samplers: {list(SAMPLER_MODULES.keys())}'
71
- )
72
-
73
- sampler_module = importlib.import_module(SAMPLER_MODULES[sampler_name])
74
- sampler_class = getattr(sampler_module, sampler_name)
75
- batch_sampler = sampler_class(dataset, **sampler_config)
76
- elif config['Global']['distributed'] and mode == 'Train':
77
- sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)
78
-
79
- if 'collate_fn' in loader_config:
80
- from . import collate_fn
81
- collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
82
- else:
83
- collate_fn = None
84
-
85
- if batch_sampler is None:
86
- data_loader = DataLoader(
87
- dataset=dataset,
88
- sampler=sampler,
61
+ if module_name == 'CMERWebDataSet':
62
+ logger.info(f"Building WebLoader for {module_name} (IterableDataset mode)...")
63
+ import webdataset as wds
64
+ persistent = num_workers > 0
65
+ data_loader = wds.WebLoader(
66
+ dataset,
67
+ batch_size=None, # 必须为 None,因为 dataset yield 的已经是 batch
68
+ shuffle=False, # 外部不打乱,内部处理
89
69
  num_workers=num_workers,
90
- pin_memory=pin_memory,
91
- collate_fn=collate_fn,
92
- batch_size=batch_size,
93
- drop_last=drop_last,
70
+ pin_memory=True,
71
+ prefetch_factor=4,
72
+ persistent_workers=persistent,
94
73
  )
74
+ total_iter_steps = config['Global'].get('total_iter_steps', 1000000)
75
+ data_loader = data_loader.with_length(total_iter_steps)
76
+ return data_loader
95
77
  else:
96
- data_loader = DataLoader(
97
- dataset=dataset,
98
- batch_sampler=batch_sampler,
99
- num_workers=num_workers,
100
- pin_memory=pin_memory,
101
- collate_fn=collate_fn,
102
- )
78
+ batch_size = loader_config['batch_size_per_card']
79
+ drop_last = loader_config['drop_last']
80
+ shuffle = loader_config['shuffle']
81
+ sampler = None
82
+ batch_sampler = None
83
+ if 'sampler' in config[mode]:
84
+ sampler_config = config[mode]['sampler']
85
+ sampler_name = sampler_config.pop('name')
86
+
87
+ if sampler_name not in SAMPLER_MODULES:
88
+ raise ValueError(
89
+ f'Unsupported sampler: {sampler_name}. Supported samplers: {list(SAMPLER_MODULES.keys())}'
90
+ )
91
+
92
+ sampler_module = importlib.import_module(SAMPLER_MODULES[sampler_name])
93
+ sampler_class = getattr(sampler_module, sampler_name)
94
+ batch_sampler = sampler_class(dataset, **sampler_config)
95
+ elif config['Global']['distributed'] and mode == 'Train':
96
+ sampler = DistributedSampler(dataset=dataset, shuffle=shuffle)
97
+
98
+ if hasattr(dataset, 'collate_fn'):
99
+ collate_fn = dataset.collate_fn
100
+ logger.info(f'Using collate_fn defined in {mode} dataset.')
101
+ else:
102
+ if 'collate_fn' in loader_config:
103
+ from . import collate_fn
104
+ collate_fn = getattr(collate_fn, loader_config['collate_fn'])()
105
+ else:
106
+ collate_fn = None
107
+
108
+ if batch_sampler is None:
109
+ data_loader = DataLoader(
110
+ dataset=dataset,
111
+ sampler=sampler,
112
+ num_workers=num_workers,
113
+ pin_memory=pin_memory,
114
+ collate_fn=collate_fn,
115
+ batch_size=batch_size,
116
+ drop_last=drop_last,
117
+ )
118
+ else:
119
+ data_loader = DataLoader(
120
+ dataset=dataset,
121
+ batch_sampler=batch_sampler,
122
+ num_workers=num_workers,
123
+ pin_memory=pin_memory,
124
+ collate_fn=collate_fn,
125
+ )
103
126
 
104
- # 检查数据加载器是否为空
105
- if len(data_loader) == 0:
106
- logger.error(
107
- f'No Images in {mode.lower()} dataloader. Please check:\n'
108
- '\t1. The images num in the train label_file_list should be >= batch size.\n'
109
- '\t2. The annotation file and path in the configuration are correct.\n'
110
- '\t3. The BatchSize is not larger than the number of images.')
111
- sys.exit()
127
+ # 检查数据加载器是否为空
128
+ if len(data_loader) == 0:
129
+ logger.error(
130
+ f'No Images in {mode.lower()} dataloader. Please check:\n'
131
+ '\t1. The images num in the train label_file_list should be >= batch size.\n'
132
+ '\t2. The annotation file and path in the configuration are correct.\n'
133
+ '\t3. The BatchSize is not larger than the number of images.')
134
+ sys.exit()
112
135
 
113
- return data_loader
136
+ return data_loader
@@ -0,0 +1,310 @@
1
+ import os
2
+ import glob
3
+ import json
4
+ import math
5
+ import random
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ import webdataset as wds
10
+ from torch.utils.data import IterableDataset
11
+ from io import BytesIO
12
+ from PIL import Image
13
+ from functools import partial
14
+ from collections import Counter
15
+ from webdataset import handlers
16
+ import importlib
17
+ # Global counter for drop statistics, as used in the original code
18
+ _DROP_STATS = Counter()
19
def sanitize_keys(sample):
    """Return a shallow copy of ``sample`` with extension-only aliases added.

    For every key that contains a dot and does not start with ``"__"``, the
    text after the last dot is used as an alias key pointing at the same
    value. An alias is only added when that key is not already present, so
    existing entries (and the first dotted key seen for an extension) win.

    Args:
        sample: Mapping of sample keys to values; it is not modified.

    Returns:
        The copied mapping, including the original keys plus the aliases.
    """
    result = sample.copy()
    for name, value in list(sample.items()):
        # Keys starting with "__" and keys without an extension are left alone.
        if name.startswith("__") or "." not in name:
            continue
        suffix = name.rsplit(".", 1)[-1]
        result.setdefault(suffix, value)
    return result
33
+ # --- Helper Functions (Must be defined at module level for pickling in multiprocessing) ---
34
+
35
def keep_by_meta(sample, longside_max=12000, area_max=80_000_000, ar_max=20.0,
                 shortside_min=16, require_positive_wh=True, max_tokens=1536, require_tokens=True):
    """Decide from metadata alone whether ``sample`` should be kept.

    Every rejection increments a reason counter in the module-level
    ``_DROP_STATS``. Checks run in this order: non-positive size, long side,
    area, short side, aspect ratio, then token count.

    Args:
        sample: Mapping read with ``.get`` for ``"width"``, ``"height"`` and
            ``"tokens"``.
        longside_max: Largest allowed value of ``max(width, height)``.
        area_max: Largest allowed ``width * height``.
        ar_max: Largest allowed ``width / height``; the reciprocal is the
            smallest allowed ratio.
        shortside_min: Smallest allowed value of ``min(width, height)``.
        require_positive_wh: Reject samples whose width or height is <= 0.
        max_tokens: Largest allowed token count.
        require_tokens: Reject samples whose token count is missing or cannot
            be converted to ``int``.

    Returns:
        ``True`` to keep the sample, ``False`` to drop it. Any unexpected
        exception is counted under ``"exception"`` and drops the sample.
    """
    def _reject(reason):
        # Record why the sample was dropped and give the filter its answer.
        _DROP_STATS[reason] += 1
        return False

    try:
        width = int(sample.get("width", 0) or 0)
        height = int(sample.get("height", 0) or 0)

        if require_positive_wh and (width <= 0 or height <= 0):
            return _reject("nonpos_wh")

        long_side = max(width, height)
        short_side = min(width, height)
        area = width * height
        # A zero/negative height can only get here when require_positive_wh is off.
        aspect = (width / height) if (height > 0) else math.inf

        if long_side > longside_max:
            return _reject("longside")
        if area > area_max:
            return _reject("area")
        if short_side < shortside_min:
            return _reject("shortside")
        if (aspect > ar_max) or (aspect < 1.0 / ar_max):
            return _reject("ar")

        token_count = sample.get("tokens", None)
        if token_count is None:
            if require_tokens:
                return _reject("no_tokens")
            return True

        try:
            over_budget = int(token_count) > max_tokens
        except Exception:
            # Unparseable token counts only matter when tokens are required.
            if require_tokens:
                return _reject("bad_tokens_val")
            return True

        if over_budget:
            return _reject("tokens")
        return True
    except Exception:
        _DROP_STATS["exception"] += 1
        return False
66
+
67
def parse_json_tuple_meta_only(sample):
    """Turn a ``(json, image_bytes)`` pair into a flat metadata record.

    The JSON part is parsed when it arrives as ``bytes``/``bytearray``
    (decoded as UTF-8) or ``str``; any other value is used as-is and must
    support ``.get`` and item access. The image bytes are passed through
    untouched, without being decoded.

    Args:
        sample: Two-item sequence ``(json_part, img_bytes)``.

    Returns:
        Dict with keys ``id``, ``img_bytes``, ``tex``, ``tokens``, ``width``,
        ``height`` and ``category``. ``id`` and ``category`` default to
        ``''``; missing or falsy ``width``/``height`` become ``0``.

    Raises:
        KeyError: If the JSON has no ``"tex"`` or ``"tokens"`` entry.
    """
    meta, img_bytes = sample
    if isinstance(meta, (bytes, bytearray)):
        meta = json.loads(meta.decode("utf-8"))
    elif isinstance(meta, str):
        meta = json.loads(meta)

    record = {}
    record["id"] = meta.get('id', '')
    record["img_bytes"] = img_bytes
    record["tex"] = meta["tex"]
    record["tokens"] = int(meta["tokens"])
    record["width"] = int(meta.get("width", 0) or 0)
    record["height"] = int(meta.get("height", 0) or 0)
    record["category"] = meta.get('category', '')
    return record
81
+
82
def add_ar_bin(sample, k=5, clamp=2.0):
    """Attach a log2 aspect-ratio bucket to ``sample`` under ``"ar_bin"``.

    The bucket is ``round(k * log2(width / height))`` limited to the range
    ``[-int(k * clamp), int(k * clamp)]``. A falsy height gives ratio 1.0,
    and the ratio is floored at 1e-6 before the logarithm is taken.

    Args:
        sample: Mapping with optional ``"width"`` and ``"height"``; it is
            updated in place.
        k: Number of buckets per doubling of the aspect ratio.
        clamp: Largest absolute log2 ratio that gets its own bucket.

    Returns:
        The same ``sample`` object, with ``"ar_bin"`` set.
    """
    width = sample.get("width", 0)
    height = sample.get("height", 0)

    ratio = float(width) / float(height) if height else 1.0
    raw_bin = int(round(k * math.log2(max(ratio, 1e-6))))

    limit = int(k * clamp)
    sample["ar_bin"] = max(min(raw_bin, limit), -limit)
    return sample
89
+
90
def add_short_edge_bin(sample, se_bin_size=96, se_min=96, se_max=1536):
    """Attach a short-edge size bucket to ``sample`` under ``"se_bin"``.

    The short edge is ``min(width, height)``. When either dimension is falsy
    or the short edge is not positive, ``se_min`` is used instead. The value
    is limited to ``[se_min, se_max]`` and floor-divided by ``se_bin_size``.

    Args:
        sample: Mapping with optional ``"width"`` and ``"height"``; it is
            updated in place.
        se_bin_size: Width of one bucket.
        se_min: Lower limit (and fallback) for the short edge.
        se_max: Upper limit for the short edge.

    Returns:
        The same ``sample`` object, with ``"se_bin"`` set.
    """
    width = sample.get("width", 0)
    height = sample.get("height", 0)

    short_edge = int(min(width, height)) if (width and height) else 0
    if short_edge <= 0:
        short_edge = se_min

    # Cap first, then raise to the floor, so se_min wins if the limits cross.
    capped = min(short_edge, se_max)
    bounded = max(se_min, capped)

    sample["se_bin"] = int(bounded // se_bin_size)
    return sample
97
+
98
def batch_to_inputs_decode_late_safe(batch, processor, max_length):
    """Decode the images of a batch and hand the survivors to ``processor``.

    Samples without image bytes, and samples whose bytes cannot be opened,
    fully loaded and converted to RGB, are skipped and counted in the
    module-level ``_DROP_STATS``.

    Args:
        batch: Iterable of sample mappings with ``"img_bytes"`` and ``"tex"``
            plus optional ``"id"`` and ``"category"``.
        processor: Callable invoked with keyword arguments ``images``,
            ``text``, ``ids``, ``categorys``, ``return_tensors``, ``padding``,
            ``truncation`` and ``max_length``.
        max_length: Forwarded to ``processor`` as ``max_length``.

    Returns:
        Whatever ``processor`` returns for the decodable samples.

    Raises:
        handlers.SkipItem: When no image in the batch could be decoded.
    """
    decoded = []
    for item in batch:
        payload = item.get("img_bytes", None)
        if not payload:
            _DROP_STATS["empty_img_bytes"] += 1
            continue

        try:
            picture = Image.open(BytesIO(payload))
            # load() forces the full decode, so truncated files fail here.
            picture.load()
            picture = picture.convert("RGB")
        except Exception:
            _DROP_STATS["bad_img_bytes"] += 1
            continue

        decoded.append((item.get('id'), picture, item["tex"], item.get('category')))

    if not decoded:
        # NOTE(review): confirm the installed webdataset exposes
        # handlers.SkipItem and that the calling map stage handles it.
        raise handlers.SkipItem("all images in batch are broken/truncated")

    ids, images, texts, categorys = (list(column) for column in zip(*decoded))

    return processor(
        images=images,
        text=texts,
        ids=ids,
        categorys=categorys,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    )
130
+
131
@wds.pipelinefilter
def bucket_len_shortedge_ratio(data_iter, batch_size, pool_size=4096, len_key="tokens",
                               se_key="se_bin", ar_key="ar_bin", len_bin_size=64, drop_last=True):
    """Group a sample stream into batches of similar length, size and shape.

    Samples are collected into a pool of ``pool_size`` items. Each full pool
    (and the final partial one) is grouped by the triple
    ``(sample[len_key] // len_bin_size, sample[se_key], sample[ar_key])``.
    Full batches are emitted per group, then the per-group remainders are
    merged and emitted in chunks of ``batch_size``.

    Args:
        data_iter: Iterable of sample mappings.
        batch_size: Number of samples per emitted batch.
        pool_size: Number of samples gathered before bucketing.
        len_key: Sample key holding the length value.
        se_key: Sample key holding the short-edge bucket.
        ar_key: Sample key holding the aspect-ratio bucket.
        len_bin_size: Width of one length bucket.
        drop_last: When true, the samples left over after each pool is
            drained (fewer than ``batch_size``) are discarded; otherwise
            they are emitted as one short batch.

    Yields:
        Lists of samples.
    """
    def _drain_pool_pandas(pool):
        # Emit every batch that can be formed from one pool of samples.
        if not pool:
            return

        records = [
            {
                "idx": position,
                "len_bin": pool[position][len_key] // len_bin_size,
                "se_bin": pool[position][se_key],
                "ar_bin": pool[position][ar_key],
            }
            for position in range(len(pool))
        ]
        frame = pd.DataFrame.from_records(records)

        remainder = []
        for _, group in frame.groupby(['len_bin', 'se_bin', 'ar_bin']):
            members = group["idx"].to_list()
            full_count = (len(members) // batch_size) * batch_size
            for start in range(0, full_count, batch_size):
                yield [pool[j] for j in members[start:start + batch_size]]
            # Whatever did not fill a batch is pooled across groups below.
            remainder.extend(members[full_count:])

        while len(remainder) >= batch_size:
            yield [pool[j] for j in remainder[:batch_size]]
            remainder = remainder[batch_size:]

        if remainder and not drop_last:
            yield [pool[j] for j in remainder]

    pending = []
    for sample in data_iter:
        pending.append(sample)
        if len(pending) >= pool_size:
            yield from _drain_pool_pandas(pending)
            pending = []

    if pending:
        yield from _drain_pool_pandas(pending)
167
+
168
+ # --- Main Dataset Class ---
169
+
170
class CMERWebDataSet(IterableDataset):
    """Iterable dataset that streams CMER batches from WebDataset tar shards.

    The pipeline built here reads ``json`` + ``jpg``/``png`` entries from the
    shards, optionally filters them by metadata, buckets them and yields the
    output of the configured processor, so every item produced by iterating
    this dataset is already a batch.
    """

    def __init__(self, config, mode, logger, seed=None, epoch=None, task='rec'):
        """Read the dataset/loader settings and build the processor.

        Args:
            config: Full configuration; ``config[mode]['dataset']`` and
                ``config[mode]['loader']`` are read.
            mode: Configuration section name. It is compared
                case-insensitively with ``'train'`` to select training
                behaviour.
            logger: Logger used for info and warning messages.
            seed: Accepted for signature compatibility; not used here.
            epoch: Accepted for signature compatibility; not used here.
            task: Accepted for signature compatibility; not used here.
        """
        super(CMERWebDataSet, self).__init__()

        # Config parsing
        dataset_config = config[mode]['dataset']
        loader_config = config[mode]['loader']

        self.mode = mode
        self.logger = logger
        self.data_dir = dataset_config['data_dir']
        self.batch_size = loader_config['batch_size_per_card']
        self.shuffle = loader_config.get('shuffle', False)

        # Specific CMER params
        self.max_length = dataset_config.get('max_length', 256)
        self.shuffle_buffer = dataset_config.get('shuffle_buffer', 10000)
        self.pool_size = dataset_config.get('pool_size', 8192)
        # Incomplete batches are dropped by default only while training.
        self.drop_last = dataset_config.get('drop_last', self._is_train)
        self.epochs = dataset_config.get('epochs', 1)

        processor_name = dataset_config.get('processor', 'CMERProcessor')
        module_path = dataset_config.get('processor_source', 'openrec.preprocess.cmer_label_encode')
        proc_module = importlib.import_module(module_path)
        ProcessorClass = getattr(proc_module, processor_name)
        # The processor arguments are read from the loader section.
        proc_args = loader_config.get('processor_args', {})

        self.logger.info(f"Initializing {processor_name} with args: {proc_args}")
        self.processor = ProcessorClass(**proc_args)

        self.logger.info(f'Initialize CMER WebDataset: {self.data_dir} | Mode: {mode}')

    @property
    def _is_train(self):
        """Whether ``self.mode`` names the training split, ignoring case.

        The mode may arrive capitalized (``'Train'``) because it doubles as
        the configuration section name, so an exact match against ``'train'``
        would never succeed for it.
        """
        return str(self.mode).lower() == 'train'

    def _build_wds_pipeline(self, shards, epoch_idx=0):
        """Construct the WebDataset pipeline for one group of shards.

        Args:
            shards: A shard path, or a list/tuple of shard paths.
            epoch_idx: Value forwarded to ``with_epoch`` on the pipeline.

        Returns:
            The ``wds.DataPipeline`` after ``with_epoch(epoch_idx)``.
        """
        shard_list = shards if isinstance(shards, (list, tuple)) else [shards]

        pipeline_stages = [
            wds.SimpleShardList(shard_list),
            wds.split_by_node,
            wds.split_by_worker,
            wds.tarfile_to_samples(),
            wds.shuffle(self.shuffle_buffer if self.shuffle else 0),
            wds.map(sanitize_keys),
            wds.to_tuple("json", "jpg;png"),
            wds.map(parse_json_tuple_meta_only),
        ]

        # Metadata filtering is skipped only when shuffle_buffer is 0.
        if self.shuffle_buffer != 0:
            pipeline_stages.append(
                wds.select(partial(
                    keep_by_meta,
                    longside_max=3840,
                    area_max=1536 * 1536,
                    ar_max=20.0,
                    shortside_min=0,
                    max_tokens=1536,
                    require_positive_wh=True,
                ))
            )

        # Bucketing and batching
        pipeline_stages.extend([
            wds.map(add_ar_bin),
            wds.map(add_short_edge_bin),
            bucket_len_shortedge_ratio(
                batch_size=self.batch_size,
                pool_size=self.pool_size,
                len_key="tokens",
                se_key="se_bin",
                ar_key="ar_bin",
                len_bin_size=64,
                drop_last=self.drop_last,
            ),
            wds.map(partial(batch_to_inputs_decode_late_safe,
                            processor=self.processor,
                            max_length=self.max_length)),
        ])

        ds = wds.DataPipeline(*pipeline_stages)
        # NOTE(review): webdataset's with_epoch() appears to take a number of
        # samples per epoch, not an epoch index - confirm that passing
        # epoch_idx (0 for every eval pipeline) is what is intended here.
        return ds.with_epoch(epoch_idx)

    def __iter__(self):
        """Yield items from every shard group under ``data_dir``, in order.

        In training mode the groups are ``<data_dir>/epoch_<i>/*.tar`` for
        ``i`` in ``range(self.epochs)``; groups without tar files are skipped
        with a warning. Otherwise each sub-directory of ``data_dir`` that
        holds tar files is a group, or ``data_dir`` itself when it has no
        sub-directories.

        Raises:
            RuntimeError: If no shard group was found.
        """
        all_datasets = []

        if self._is_train:
            # Training logic: one pipeline per epoch folder.
            for epoch_idx in range(self.epochs):
                # Assuming structure: root/epoch_0/*.tar
                epoch_path = os.path.join(self.data_dir, f"epoch_{epoch_idx}")
                train_shards = sorted(glob.glob(f"{epoch_path}/*.tar"))

                if not train_shards:
                    self.logger.warning(f"No .tar files found in {epoch_path}, skipping.")
                    continue

                all_datasets.append(self._build_wds_pipeline(train_shards, epoch_idx))
        elif os.path.exists(self.data_dir):
            # Eval/Test logic: a directory of packages, or shards directly.
            # Sorted so the package order does not depend on os.listdir.
            subdirs = sorted(
                os.path.join(self.data_dir, d) for d in os.listdir(self.data_dir)
                if os.path.isdir(os.path.join(self.data_dir, d))
            )

            # Package structure if there are sub-directories, else flat.
            shard_roots = subdirs if subdirs else [self.data_dir]
            for root in shard_roots:
                shards = sorted(glob.glob(f"{root}/*.tar"))
                if shards:
                    all_datasets.append(self._build_wds_pipeline(shards))

        if not all_datasets:
            raise RuntimeError(f"No data found in {self.data_dir}")

        # Chain the pipelines one after another.
        for ds in all_datasets:
            yield from ds

    def __len__(self):
        """Return ``0`` as a placeholder; the real length is not computed."""
        return 0