lattice-sub 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,374 @@
1
+ """
2
+ Batch processing for multiple micrographs.
3
+
4
+ This module provides parallel processing capabilities for large datasets.
5
+ """
6
+
7
+ import os
8
+ from concurrent.futures import ProcessPoolExecutor, as_completed
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import List, Tuple, Optional, Callable
12
+ import logging
13
+
14
+ from tqdm import tqdm
15
+
16
+ from .config import Config
17
+ from .core import LatticeSubtractor
18
+ from .io import read_mrc, write_mrc
19
+
20
+
21
+ logger = logging.getLogger(__name__)
22
+
23
+
24
@dataclass
class BatchResult:
    """Summary statistics for one batch-processing run."""

    total: int
    successful: int
    failed: int
    failed_files: List[Tuple[Path, str]]  # (path, error_message)

    @property
    def success_rate(self) -> float:
        """Percentage of files processed successfully; 0.0 for an empty batch."""
        if not self.total:
            return 0.0
        return 100.0 * self.successful / self.total
37
+
38
+
39
+ def _process_single_file(args: tuple) -> Tuple[Path, Optional[str]]:
40
+ """
41
+ Process a single file (for parallel execution).
42
+
43
+ Args:
44
+ args: Tuple of (input_path, output_path, config_dict)
45
+
46
+ Returns:
47
+ Tuple of (input_path, error_message or None)
48
+ """
49
+ input_path, output_path, config_dict = args
50
+
51
+ try:
52
+ # Reconstruct config from dict (can't pickle dataclass with defaults easily)
53
+ config = Config(**config_dict)
54
+
55
+ # Process
56
+ subtractor = LatticeSubtractor(config)
57
+ result = subtractor.process(input_path)
58
+ result.save(output_path, pixel_size=config.pixel_ang)
59
+
60
+ return (Path(input_path), None)
61
+
62
+ except Exception as e:
63
+ return (Path(input_path), str(e))
64
+
65
+
66
class BatchProcessor:
    """
    Parallel batch processor for micrograph datasets.

    This class handles processing of multiple MRC files in parallel,
    with progress tracking, error handling, and optional file pattern matching.
    GPU (PyTorch/CUDA) runs fall back to sequential processing because CUDA
    contexts do not survive fork-based multiprocessing.

    Example:
        >>> config = Config(pixel_ang=0.56)
        >>> processor = BatchProcessor(config, num_workers=8)
        >>> result = processor.process_directory(
        ...     input_dir="raw_micrographs/",
        ...     output_dir="subtracted/",
        ...     pattern="*.mrc"
        ... )
        >>> print(f"Processed {result.successful}/{result.total} files")
    """

    def __init__(
        self,
        config: Config,
        num_workers: Optional[int] = None,
        output_prefix: str = "sub_",
    ):
        """
        Initialize batch processor.

        Args:
            config: Processing configuration
            num_workers: Number of parallel workers. Default: CPU count - 1
            output_prefix: Prefix for output filenames. Default: "sub_"
        """
        self.config = config
        # os.cpu_count() may return None (per its documentation); fall back to
        # one worker instead of raising TypeError on the subtraction.
        self.num_workers = num_workers or max(1, (os.cpu_count() or 2) - 1)
        self.output_prefix = output_prefix

        # Convert config to a plain dict so it pickles safely into workers.
        from dataclasses import asdict
        self._config_dict = asdict(config)

    def process_directory(
        self,
        input_dir: str | Path,
        output_dir: str | Path,
        pattern: str = "*.mrc",
        recursive: bool = False,
        show_progress: bool = True,
    ) -> BatchResult:
        """
        Process all matching files in a directory.

        Args:
            input_dir: Input directory containing MRC files
            output_dir: Output directory for processed files
            pattern: Glob pattern for matching files. Default: "*.mrc"
            recursive: If True, search subdirectories. Default: False
            show_progress: If True, show progress bar. Default: True

        Returns:
            BatchResult with processing statistics
        """
        input_dir = Path(input_dir)
        output_dir = Path(output_dir)

        # Find input files
        if recursive:
            input_files = list(input_dir.rglob(pattern))
        else:
            input_files = list(input_dir.glob(pattern))

        if not input_files:
            logger.warning(f"No files matching '{pattern}' found in {input_dir}")
            return BatchResult(total=0, successful=0, failed=0, failed_files=[])

        # Create output directory
        output_dir.mkdir(parents=True, exist_ok=True)

        # Build (input, output) pairs, prefixing each output filename.
        file_pairs = []
        for input_path in input_files:
            output_name = f"{self.output_prefix}{input_path.name}"
            output_path = output_dir / output_name
            file_pairs.append((input_path, output_path))

        return self.process_file_list(file_pairs, show_progress=show_progress)

    def process_file_list(
        self,
        file_pairs: List[Tuple[Path, Path]],
        show_progress: bool = True,
    ) -> BatchResult:
        """
        Process a list of input/output file pairs.

        Args:
            file_pairs: List of (input_path, output_path) tuples
            show_progress: If True, show progress bar

        Returns:
            BatchResult with processing statistics
        """
        total = len(file_pairs)

        # GPU runs must be sequential: CUDA does not support fork-based
        # multiprocessing. With "auto" backend, probe whether PyTorch + CUDA
        # is actually available.
        use_gpu = self.config.backend == "pytorch"
        if self.config.backend == "auto":
            try:
                import torch
                use_gpu = torch.cuda.is_available()
            except ImportError:
                use_gpu = False

        if use_gpu:
            successful, failed_files = self._process_sequential(
                file_pairs, show_progress
            )
        else:
            successful, failed_files = self._process_parallel(
                file_pairs, show_progress
            )

        return BatchResult(
            total=total,
            successful=successful,
            failed=total - successful,
            failed_files=failed_files,
        )

    def _process_sequential(
        self,
        file_pairs: List[Tuple[Path, Path]],
        show_progress: bool = True,
    ) -> Tuple[int, List[Tuple[Path, str]]]:
        """Process files one at a time in this process (GPU mode).

        Returns:
            (successful_count, [(input_path, error_message), ...])
        """
        successful = 0
        failed_files = []

        # Create the progress bar BEFORE any CUDA initialization so it is
        # displayed immediately rather than after the (slow) device setup.
        if show_progress:
            print("", flush=True)  # Ensure clean line
            pbar = tqdm(
                total=len(file_pairs),
                desc="  Processing",
                unit="file",
                ncols=80,
                leave=True,
            )
        else:
            pbar = None

        # Now initialize the subtractor (this triggers CUDA init).
        subtractor = LatticeSubtractor(self.config)

        for input_path, output_path in file_pairs:
            try:
                result = subtractor.process(input_path)
                result.save(output_path, pixel_size=self.config.pixel_ang)
                successful += 1
            except Exception as e:
                # Record the failure but keep going with the rest of the batch.
                failed_files.append((input_path, str(e)))
                logger.error(f"Failed to process {input_path}: {e}")

            if pbar:
                pbar.update(1)

        if pbar:
            pbar.close()

        return successful, failed_files

    def _process_parallel(
        self,
        file_pairs: List[Tuple[Path, Path]],
        show_progress: bool = True,
    ) -> Tuple[int, List[Tuple[Path, str]]]:
        """Process files across a process pool (CPU mode).

        Returns:
            (successful_count, [(input_path, error_message), ...])
        """
        successful = 0
        failed_files = []
        total = len(file_pairs)

        # Paths are passed as strings so the argument tuples pickle cheaply.
        args_list = [
            (str(inp), str(out), self._config_dict)
            for inp, out in file_pairs
        ]

        with ProcessPoolExecutor(max_workers=self.num_workers) as executor:
            futures = [
                executor.submit(_process_single_file, args)
                for args in args_list
            ]

            # Consume results as they finish so the progress bar is live.
            iterator = as_completed(futures)
            if show_progress:
                iterator = tqdm(
                    iterator,
                    total=total,
                    desc="Processing micrographs",
                    unit="file",
                )

            for future in iterator:
                # Worker returns (path, None) on success, (path, msg) on error.
                input_path, error = future.result()

                if error is None:
                    successful += 1
                else:
                    failed_files.append((input_path, error))
                    logger.error(f"Failed to process {input_path}: {error}")

        return successful, failed_files

    def process_numbered_sequence(
        self,
        input_pattern: str,
        output_dir: str | Path,
        start: int,
        end: int,
        zero_pad: int = 4,
        show_progress: bool = True,
    ) -> BatchResult:
        """
        Process a numbered sequence of files.

        This is designed to match the legacy HYPER_loop behavior for
        processing numbered file sequences. Missing numbers in the range are
        skipped silently (logged at DEBUG level).

        Args:
            input_pattern: Pattern with {num} placeholder, e.g.,
                "raw/image_{num}.mrc"
            output_dir: Output directory
            start: Starting number (inclusive)
            end: Ending number (inclusive)
            zero_pad: Number of digits for zero-padding. Default: 4
            show_progress: If True, show progress bar

        Returns:
            BatchResult

        Example:
            >>> processor.process_numbered_sequence(
            ...     input_pattern="data/mic_{num}.mrc",
            ...     output_dir="processed/",
            ...     start=1,
            ...     end=100,
            ...     zero_pad=4
            ... )
        """
        output_dir = Path(output_dir)
        output_dir.mkdir(parents=True, exist_ok=True)

        file_pairs = []

        for num in range(start, end + 1):
            num_str = str(num).zfill(zero_pad)
            input_path = Path(input_pattern.format(num=num_str))

            if input_path.exists():
                output_name = f"{self.output_prefix}{input_path.name}"
                output_path = output_dir / output_name
                file_pairs.append((input_path, output_path))
            else:
                logger.debug(f"Skipping non-existent file: {input_path}")

        if not file_pairs:
            logger.warning("No files found matching the numbered pattern")
            return BatchResult(total=0, successful=0, failed=0, failed_files=[])

        return self.process_file_list(file_pairs, show_progress=show_progress)
344
+
345
+
346
def process_directory(
    input_dir: str | Path,
    output_dir: str | Path,
    config: Config,
    pattern: str = "*.mrc",
    num_workers: Optional[int] = None,
    show_progress: bool = True,
) -> BatchResult:
    """
    Convenience function for batch processing a directory.

    Builds a throwaway :class:`BatchProcessor` and delegates to its
    ``process_directory`` method.

    Args:
        input_dir: Input directory
        output_dir: Output directory
        config: Processing configuration
        pattern: Glob pattern for files
        num_workers: Number of parallel workers
        show_progress: Show progress bar

    Returns:
        BatchResult
    """
    return BatchProcessor(config, num_workers=num_workers).process_directory(
        input_dir=input_dir,
        output_dir=output_dir,
        pattern=pattern,
        show_progress=show_progress,
    )