lattice-sub 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,221 @@
1
+ """
2
+ Image processing utilities.
3
+
4
+ This module contains functions for image padding, background subtraction,
5
+ and other preprocessing operations.
6
+ """
7
+
8
+ import numpy as np
9
+ from typing import Tuple
10
+
11
+
12
+ def pad_image(
13
+ image: np.ndarray,
14
+ pad_origin: Tuple[int, int],
15
+ target_size: int | None = None,
16
+ pad_value: float | None = None,
17
+ ) -> Tuple[np.ndarray, dict]:
18
+ """
19
+ Pad an image with mean border for FFT processing.
20
+
21
+ This replicates the MATLAB padarray functionality with 'pre' and 'post'
22
+ padding using the image mean value.
23
+
24
+ Args:
25
+ image: Input 2D image
26
+ pad_origin: Padding offsets (pad_y, pad_x) - pixels to add at start
27
+ target_size: Target square size. If None, auto-calculated.
28
+ pad_value: Value to use for padding. If None, uses image mean.
29
+
30
+ Returns:
31
+ Tuple of:
32
+ - Padded image
33
+ - Metadata dict with original shape and padding info for cropping
34
+ """
35
+ orig_h, orig_w = image.shape
36
+ pad_y, pad_x = pad_origin
37
+
38
+ # Auto-calculate target size if not provided
39
+ if target_size is None:
40
+ max_dim = max(orig_h, orig_w)
41
+ target_size = max_dim + pad_x * 2
42
+ # Round to nearest 10 for FFT efficiency
43
+ target_size = int(np.round(target_size / 10) * 10)
44
+
45
+ # Calculate padding amounts
46
+ pad_top = pad_y - 1 if pad_y > 0 else 0
47
+ pad_left = pad_x - 1 if pad_x > 0 else 0
48
+ pad_bottom = target_size - orig_h - pad_top
49
+ pad_right = target_size - orig_w - pad_left
50
+
51
+ # Ensure non-negative padding
52
+ pad_bottom = max(0, pad_bottom)
53
+ pad_right = max(0, pad_right)
54
+
55
+ # Use image mean for padding value if not specified
56
+ if pad_value is None:
57
+ pad_value = float(np.mean(image))
58
+
59
+ # Perform padding
60
+ padded = np.pad(
61
+ image,
62
+ ((pad_top, pad_bottom), (pad_left, pad_right)),
63
+ mode='constant',
64
+ constant_values=pad_value,
65
+ )
66
+
67
+ # Store metadata for later cropping
68
+ metadata = {
69
+ 'original_shape': (orig_h, orig_w),
70
+ 'pad_top': pad_top,
71
+ 'pad_left': pad_left,
72
+ 'pad_bottom': pad_bottom,
73
+ 'pad_right': pad_right,
74
+ 'target_size': target_size,
75
+ }
76
+
77
+ return padded, metadata
78
+
79
+
80
def crop_to_original(
    image: np.ndarray,
    metadata: dict,
) -> np.ndarray:
    """
    Undo pad_image() by slicing the original region back out.

    Args:
        image: Array that was padded by pad_image()
        metadata: The metadata dict pad_image() returned with it

    Returns:
        Slice of ``image`` spanning the original rows and columns
    """
    height, width = metadata['original_shape']
    top = metadata['pad_top']
    left = metadata['pad_left']

    row_span = slice(top, top + height)
    col_span = slice(left, left + width)
    return image[row_span, col_span]
99
+
100
+
101
def subtract_background(
    image: np.ndarray,
    median_filter_size: int = 10,
) -> np.ndarray:
    """
    Subtract smooth background from an image to reveal sharp features.

    This is the Python equivalent of bg_FastSubtract_standard.m.
    It creates a smoothed version of the image using median filtering
    and subtracts it from the original.

    Images whose largest side is 500 px or more are downsampled to ~500 px
    before median filtering and the result is upsampled back, for speed.

    Args:
        image: Input 2D image (typically log-power spectrum). Non-floating
            input (e.g. uint16) is converted to float64 before subtraction.
        median_filter_size: Size of median filter kernel. Default: 10

    Returns:
        Background-subtracted float32 image with edge regions replaced by mean
    """
    from scipy import ndimage

    image = np.asarray(image)
    # Work in floating point: with unsigned integer input, ``image - smoothed``
    # would wrap around wherever the background exceeds the pixel value, and
    # assigning the (fractional) mean back would truncate it.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float64)

    h, w = image.shape

    if max(h, w) < 500:
        # For small images, apply median filter directly
        smoothed = ndimage.median_filter(image, size=median_filter_size)
        edge = median_filter_size
    else:
        # scikit-image is only needed on this path, so import it lazily.
        from skimage.transform import resize

        # For large images, downsample -> filter -> upsample
        shrink_factor = 500 / max(h, w)
        # Guard against a zero-length axis for very elongated images.
        small_shape = (
            max(1, int(h * shrink_factor)),
            max(1, int(w * shrink_factor)),
        )

        # Downsample
        small = resize(
            image,
            small_shape,
            order=1,  # Bilinear
            preserve_range=True,
            anti_aliasing=True,
        )

        # Apply median filter
        small = ndimage.median_filter(small, size=median_filter_size)

        # Upsample back to original size
        smoothed = resize(
            small,
            (h, w),
            order=1,
            preserve_range=True,
            anti_aliasing=True,
        )

        # Scale edge hiding region to full-resolution pixels
        edge = int(median_filter_size / shrink_factor)

    # Subtract background
    subtracted = image - smoothed

    # Hide edges with artifacts: the mean is taken over the whole image
    # before the edge bands are overwritten.
    mean_value = np.mean(subtracted)
    edge = max(1, edge)

    # Replace edge regions with mean
    subtracted[:edge, :] = mean_value
    subtracted[-edge:, :] = mean_value
    subtracted[:, :edge] = mean_value
    subtracted[:, -edge:] = mean_value

    return subtracted.astype(np.float32)
170
+
171
+
172
def compute_power_spectrum(
    fft_shifted: np.ndarray,
    log_scale: bool = True,
    epsilon: float = 1e-10,
) -> np.ndarray:
    """
    Convert a centred FFT into its amplitude spectrum, optionally log-scaled.

    The returned quantity is the magnitude ``|F|`` (not ``|F|**2``); with
    ``log_scale`` it is ``log(|F| + epsilon)``.

    Args:
        fft_shifted: Centered FFT (after fftshift)
        log_scale: When True (the default), return the natural log of the amplitude
        epsilon: Offset added before the log so zero amplitudes stay finite.
            Default: 1e-10

    Returns:
        Amplitude array, or its logarithm when ``log_scale`` is set
    """
    magnitude = np.abs(fft_shifted)
    if not log_scale:
        return magnitude
    return np.log(magnitude + epsilon)
194
+
195
+
196
def shift_and_average(
    array: np.ndarray,
    shift_pixels: int,
) -> np.ndarray:
    """
    Estimate a local background as the mean of four shifted copies.

    This is the inpainting technique from bg_push_by_rot.m that
    averages amplitude values from neighboring regions.

    Args:
        array: Input 2D array (typically FFT amplitude)
        shift_pixels: Shift distance, in pixels, applied along each axis

    Returns:
        Element-wise mean of the four copies (local background estimate)
    """
    # (offset, axis) pairs: +/- shift along rows, then +/- shift along columns.
    # np.roll wraps values around the array borders.
    moves = (
        (shift_pixels, 0),
        (-shift_pixels, 0),
        (shift_pixels, 1),
        (-shift_pixels, 1),
    )

    first_offset, first_axis = moves[0]
    accumulated = np.roll(array, first_offset, axis=first_axis)
    for offset, axis in moves[1:]:
        accumulated = accumulated + np.roll(array, offset, axis=axis)

    return accumulated / 4.0
@@ -0,0 +1,256 @@
1
+ """
2
+ Terminal UI utilities for lattice subtraction.
3
+
4
+ This module provides styled terminal output with ASCII art banner
5
+ and formatted progress messages. Output is only shown when running
6
+ interactively (TTY detected) and not suppressed by --quiet flag.
7
+
8
+ When piped or used in a pipeline, decorative output is automatically
9
+ suppressed to avoid polluting downstream processing.
10
+ """
11
+
12
+ import sys
13
+ import time
14
+ from typing import Optional
15
+
16
+
17
class Colors:
    """
    ANSI escape sequences used to style terminal text.

    Each attribute is a complete SGR sequence. Wrap text as
    ``f"{Colors.RED}text{Colors.RESET}"`` so that later output is unaffected.
    """

    # Bright foreground colours
    HEADER = '\x1b[95m'
    BLUE = '\x1b[94m'
    CYAN = '\x1b[96m'
    GREEN = '\x1b[92m'
    YELLOW = '\x1b[93m'
    RED = '\x1b[91m'

    # Text attributes
    BOLD = '\x1b[1m'
    DIM = '\x1b[2m'
    RESET = '\x1b[0m'  # SGR 0: clears every colour/attribute set above
28
+
29
+
30
# ASCII Art Banner
# Raw string (r"""...""") so the backslashes in the artwork are kept literally
# instead of being interpreted as escape sequences. Printed by
# TerminalUI.print_banner when output is interactive.
BANNER = r"""
.__ __ __ .__ ___.
| | _____ _/ |__/ |_|__| ____ ____ ________ _\_ |__
| | \__ \\ __\ __\ |/ ___\/ __ \ ______ / ___/ | \ __ \
| |__/ __ \| | | | | \ \__\ ___/ /_____/ \___ \| | / \_\ \
|____(____ /__| |__| |__|\___ >___ > /____ >____/|___ /
\/ \/ \/ \/ \/
"""
39
+
40
+ # Import version from package to keep it in sync
41
+ from . import __version__ as VERSION
42
+
43
+
44
def is_interactive() -> bool:
    """
    Report whether standard output is attached to a terminal.

    A False result means stdout is piped or redirected. In that case we are
    probably feeding another program, and decorative output should be skipped.
    """
    stdout = sys.stdout
    return stdout.isatty()
52
+
53
+
54
class TerminalUI:
    """
    Manages styled terminal output for the CLI.

    Decorative output (banner, colors, progress indicators) is only shown when:
    - Running in an interactive terminal (TTY detected)
    - Not suppressed by quiet mode

    When piped or in a script, output is automatically minimal.

    Messages written to stderr (errors, warnings) are only coloured when
    stderr itself is a TTY. Redirecting stderr to a file therefore never
    captures ANSI escape codes, even when stdout is a terminal.
    """

    def __init__(self, quiet: bool = False):
        """
        Initialize the terminal UI.

        Args:
            quiet: If True, suppress all decorative output even in interactive mode
        """
        self.quiet = quiet
        self.interactive = is_interactive() and not quiet
        self.use_colors = self.interactive
        self._start_time: Optional[float] = None
        self._file_start_time: Optional[float] = None

    def _colorize(self, text: str, color: str) -> str:
        """Apply color if colors are enabled."""
        if self.use_colors:
            return f"{color}{text}{Colors.RESET}"
        return text

    def _stderr_colors(self) -> bool:
        """Return True if text written to stderr should carry ANSI colour codes.

        ``use_colors`` reflects stdout's TTY status, which says nothing about
        stderr, so stderr is checked separately.
        """
        stderr = sys.stderr
        return bool(self.use_colors and stderr is not None and stderr.isatty())

    def print_banner(self) -> None:
        """Print the ASCII art banner."""
        if not self.interactive:
            return

        print()
        print(self._colorize(BANNER, Colors.CYAN))
        tagline = f" Phase-preserving FFT inpainting for cryo-EM | v{VERSION}"
        print(self._colorize(tagline, Colors.DIM))
        print()

    def print_config(self, pixel_size: float, threshold: float,
                     backend: str, gpu_name: Optional[str] = None) -> None:
        """Print configuration summary."""
        if not self.interactive:
            return

        print(self._colorize(" Configuration", Colors.BOLD))
        print(self._colorize(" -------------", Colors.DIM))
        print(f" Pixel size: {self._colorize(f'{pixel_size} A', Colors.YELLOW)}")
        print(f" Threshold: {self._colorize(str(threshold), Colors.YELLOW)}")

        # Determine backend display string
        if backend == "auto":
            # Check if GPU is actually available for auto mode
            if gpu_name:
                backend_str = f"Auto → GPU ({gpu_name})"
                print(f" Backend: {self._colorize(backend_str, Colors.GREEN)}")
            else:
                backend_str = "Auto → CPU"
                print(f" Backend: {self._colorize(backend_str, Colors.BLUE)}")
        elif backend == "pytorch" and gpu_name:
            backend_str = f"PyTorch CUDA ({gpu_name})"
            print(f" Backend: {self._colorize(backend_str, Colors.GREEN)}")
        elif backend == "pytorch":
            print(f" Backend: {self._colorize('PyTorch CUDA', Colors.GREEN)}")
        else:
            print(f" Backend: {self._colorize('NumPy (CPU)', Colors.BLUE)}")
        print()

    def start_timer(self) -> None:
        """Start the overall timer."""
        self._start_time = time.time()

    def start_processing(self, filename: str, shape: Optional[tuple] = None) -> None:
        """Indicate start of processing a file (the per-file timer always starts)."""
        self._file_start_time = time.time()

        if not self.interactive:
            return

        print(f" {self._colorize('>', Colors.CYAN)} Processing: {self._colorize(filename, Colors.BOLD)}")
        if shape:
            print(f" {self._colorize('|-', Colors.DIM)} Size: {shape[0]} x {shape[1]}")

    def end_processing(self, output_path: str, success: bool = True) -> None:
        """Indicate end of processing, reporting time since start_processing()."""
        elapsed = time.time() - self._file_start_time if self._file_start_time else 0

        if not self.interactive:
            return

        if success:
            status = self._colorize(f"[OK] Complete ({elapsed:.2f}s)", Colors.GREEN)
        else:
            status = self._colorize("[FAIL]", Colors.RED)

        print(f" {self._colorize('`-', Colors.DIM)} {status}")
        print()

    def print_batch_header(self, num_files: int, output_dir: str, num_workers: int = 1) -> None:
        """Print batch processing header."""
        if not self.interactive:
            return

        print(self._colorize(" Batch Processing", Colors.BOLD))
        print(self._colorize(" ----------------", Colors.DIM))
        print(f" Files: {self._colorize(str(num_files), Colors.YELLOW)}")
        print(f" Output: {output_dir}")
        print(f" Workers: {num_workers}")
        print()

    def print_batch_progress(self, current: int, total: int, filename: str,
                             elapsed: Optional[float] = None) -> None:
        """Print batch progress update.

        ``elapsed`` is shown whenever it is given, including 0.0; pass None to omit it.
        """
        if not self.interactive:
            return

        progress = f"[{current}/{total}]"
        time_str = f" ({elapsed:.1f}s)" if elapsed is not None else ""
        print(f" {self._colorize(progress, Colors.CYAN)} {filename}{time_str}")

    def print_batch_complete(self) -> None:
        """Print batch completion message with time since start_timer()."""
        if not self.interactive:
            return

        elapsed = time.time() - self._start_time if self._start_time else 0
        print()
        print(f" {self._colorize('[OK]', Colors.GREEN)} {self._colorize('Batch complete', Colors.BOLD)} ({elapsed:.1f}s)")
        print()

    def print_summary(self, processed: int, failed: int = 0) -> None:
        """Print final summary with time since start_timer()."""
        if not self.interactive:
            return

        elapsed = time.time() - self._start_time if self._start_time else 0

        print(self._colorize(" Summary", Colors.BOLD))
        print(self._colorize(" -------", Colors.DIM))
        print(f" Processed: {self._colorize(str(processed), Colors.GREEN)}")
        if failed > 0:
            print(f" Failed: {self._colorize(str(failed), Colors.RED)}")
        print(f" Time: {elapsed:.1f}s")
        print()

    def print_error(self, message: str) -> None:
        """Print error message to stderr (always shown unless quiet)."""
        if self.quiet:
            return
        prefix = self._colorize("Error:", Colors.RED) if self._stderr_colors() else "Error:"
        print(f" {prefix} {message}", file=sys.stderr)

    def print_warning(self, message: str) -> None:
        """Print warning message to stderr (interactive mode only)."""
        if not self.interactive:
            return
        prefix = self._colorize("Warning:", Colors.YELLOW) if self._stderr_colors() else "Warning:"
        print(f" {prefix} {message}", file=sys.stderr)

    def print_success(self, message: str) -> None:
        """Print success message."""
        if not self.interactive:
            return
        check = self._colorize("[OK]", Colors.GREEN)
        print(f" {check} {message}")

    def print_info(self, message: str) -> None:
        """Print info message."""
        if not self.interactive:
            return
        print(f" {self._colorize('>', Colors.CYAN)} {message}")

    def print_saved(self, path: str) -> None:
        """Print saved file notification."""
        if not self.interactive:
            return
        print(f" {self._colorize('|-', Colors.DIM)} Saved: {path}")
233
+
234
+
235
def get_ui(quiet: bool = False) -> TerminalUI:
    """
    Build a TerminalUI for the current process.

    Args:
        quiet: When True, decorative output stays off even on a terminal

    Returns:
        A fresh TerminalUI. Interactivity is detected when it is constructed.
    """
    ui = TerminalUI(quiet=quiet)
    return ui
246
+
247
+
248
def get_gpu_name() -> Optional[str]:
    """
    Return the name of CUDA device 0, or None if it cannot be determined.

    This is a best-effort probe. Each of the following results in None rather
    than an exception:
    - PyTorch is not installed.
    - CUDA is unavailable.
    - The CUDA runtime raises a RuntimeError while being queried.
    """
    try:
        import torch
    except ImportError:
        return None

    try:
        if torch.cuda.is_available():
            return torch.cuda.get_device_name(0)
    except RuntimeError:
        # torch.cuda reports driver/runtime initialisation failures as
        # RuntimeError; treat them as "no usable GPU".
        pass
    return None
@@ -0,0 +1,195 @@
1
+ """
2
+ Visualization utilities for lattice subtraction results.
3
+
4
+ This module provides functions to create comparison visualizations
5
+ showing original, processed, and difference images.
6
+ """
7
+
8
+ import logging
9
+ from pathlib import Path
10
+ from typing import Optional, Tuple
11
+
12
+ import numpy as np
13
+
14
# Silence matplotlib's verbose debug logging: raise both its root logger and
# its font-manager logger to WARNING.
for _noisy_logger in ('matplotlib', 'matplotlib.font_manager'):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

# Module logger, named after this module.
logger = logging.getLogger(__name__)
19
+
20
+
21
def create_comparison_figure(
    original: np.ndarray,
    processed: np.ndarray,
    title: str = "Lattice Subtraction Comparison",
    figsize: Tuple[int, int] = (18, 6),
    dpi: int = 150,
):
    """
    Create a comparison figure showing original, processed, and difference images.

    Args:
        original: Original image array
        processed: Processed (lattice-subtracted) image array
        title: Figure title
        figsize: Figure size in inches (width, height)
        dpi: Resolution for saving. Not used by this function; the resolution
            of a saved file is whatever is passed to ``savefig``.

    Returns:
        matplotlib Figure object
    """
    import matplotlib.pyplot as plt

    original = np.asarray(original)
    processed = np.asarray(processed)

    # Compute the difference in floating point. Integer inputs (e.g. uint8 or
    # uint16 MRC data) would otherwise wrap around wherever processed >
    # original. Float32 inputs stay float32.
    work_dtype = np.result_type(original.dtype, processed.dtype, np.float32)
    difference = original.astype(work_dtype, copy=False) - processed.astype(work_dtype, copy=False)

    # Create figure
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    # Shared contrast window (1st-99th percentile of the original) for the
    # first two panels so they are directly comparable.
    vmin, vmax = np.percentile(original, [1, 99])

    # Original
    axes[0].imshow(original, cmap='gray', vmin=vmin, vmax=vmax)
    axes[0].set_title(f'Original\n{original.shape}')
    axes[0].axis('off')

    # Lattice Subtracted
    axes[1].imshow(processed, cmap='gray', vmin=vmin, vmax=vmax)
    axes[1].set_title(f'Lattice Subtracted\n{processed.shape}')
    axes[1].axis('off')

    # Difference (removed lattice), symmetric +/- 3 standard deviations
    diff_std = np.std(difference)
    axes[2].imshow(
        difference,
        cmap='RdBu_r',
        vmin=-diff_std * 3,
        vmax=diff_std * 3
    )
    axes[2].set_title('Difference (Removed Lattice)')
    axes[2].axis('off')

    # Title: use the figure's own methods rather than pyplot's "current
    # figure" state.
    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    return fig
78
+
79
+
80
def save_comparison_visualization(
    original_path: Path,
    processed_path: Path,
    output_path: Path,
    dpi: int = 150,
) -> None:
    """
    Create and save a comparison visualization for a single image pair.

    Args:
        original_path: Path to original MRC file (str or Path)
        processed_path: Path to processed MRC file
        output_path: Path for output PNG file (str or Path). Missing parent
            directories are created.
        dpi: Resolution passed to ``savefig``

    The figure is always closed, even if creating the directory or saving
    fails. Callers that catch errors and continue therefore do not accumulate
    open matplotlib figures.
    """
    import matplotlib.pyplot as plt
    import mrcfile

    # Accept plain strings as well as Path objects (``.name``/``.parent`` are used below).
    original_path = Path(original_path)
    output_path = Path(output_path)

    # Load images
    with mrcfile.open(original_path, 'r') as f:
        original = f.data.copy()
    with mrcfile.open(processed_path, 'r') as f:
        processed = f.data.copy()

    # Create title (long file names are truncated to 60 characters)
    name = original_path.name
    short_name = name[:60] + "..." if len(name) > 60 else name
    title = f"Lattice Subtraction Comparison: {short_name}"

    # Create and save figure
    fig = create_comparison_figure(original, processed, title=title, dpi=dpi)
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)
113
+
114
+
115
def generate_visualizations(
    input_dir: Path,
    output_dir: Path,
    viz_dir: Path,
    prefix: str = "sub_",
    pattern: str = "*.mrc",
    dpi: int = 150,
    show_progress: bool = True,
) -> Tuple[int, int]:
    """
    Generate comparison visualizations for all processed images in a directory.

    Each processed file ``{prefix}{name}`` matched in ``output_dir`` is paired
    with ``input_dir / name``. Its PNG is written to ``viz_dir`` as ``name``
    with the final extension replaced by ``.png``.

    How pairs are counted:
    - Original file missing: skipped, not counted as successful.
    - PNG already present: not regenerated, counted as successful.
    - Error while processing one file: logged, and the batch continues.

    Args:
        input_dir: Directory containing original MRC files
        output_dir: Directory containing processed MRC files
        viz_dir: Directory for output visualization PNG files (created if missing)
        prefix: Prefix used for processed files (default: "sub_")
        pattern: Glob pattern for finding processed files
        dpi: Resolution for output images
        show_progress: Show progress bar (requires tqdm; silently skipped otherwise)

    Returns:
        Tuple of (successful_count, total_count), where total_count is the
        number of processed files matched
    """
    # Not used directly. Importing here surfaces a missing matplotlib once,
    # up front, instead of as a logged failure for every file.
    import matplotlib.pyplot as plt  # noqa: F401

    input_dir = Path(input_dir)
    viz_dir = Path(viz_dir)
    viz_dir.mkdir(parents=True, exist_ok=True)

    # Find all processed files
    output_files = sorted(Path(output_dir).glob(f"{prefix}{pattern}"))

    if not output_files:
        logger.warning(f"No processed files found matching '{prefix}{pattern}' in {output_dir}")
        return 0, 0

    successful = 0
    total = len(output_files)

    # Setup iterator with optional progress bar
    iterator = output_files
    if show_progress:
        try:
            from tqdm import tqdm
        except ImportError:
            pass
        else:
            iterator = tqdm(output_files, desc="Generating visualizations", unit="file")

    for processed_path in iterator:
        try:
            # Strip the leading prefix to recover the original file name
            input_name = processed_path.name.removeprefix(prefix)
            input_path = input_dir / input_name

            if not input_path.exists():
                logger.debug(f"Original not found: {input_path}")
                continue

            # Swap only the final extension. A plain str.replace(".mrc", ".png")
            # would also rewrite ".mrc" occurring mid-name and turn ".mrcs"
            # into ".pngs".
            viz_path = viz_dir / Path(input_name).with_suffix(".png")

            # Skip if already exists
            if viz_path.exists():
                successful += 1
                continue

            # Generate visualization
            save_comparison_visualization(
                original_path=input_path,
                processed_path=processed_path,
                output_path=viz_path,
                dpi=dpi,
            )
            successful += 1

        except Exception as e:
            # Best effort: one bad file must not abort the whole batch.
            logger.error(f"Failed to create visualization for {processed_path.name}: {e}")

    return successful, total