lattice-sub 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattice_sub-1.0.10.dist-info/METADATA +324 -0
- lattice_sub-1.0.10.dist-info/RECORD +16 -0
- lattice_sub-1.0.10.dist-info/WHEEL +5 -0
- lattice_sub-1.0.10.dist-info/entry_points.txt +2 -0
- lattice_sub-1.0.10.dist-info/licenses/LICENSE +21 -0
- lattice_sub-1.0.10.dist-info/top_level.txt +1 -0
- lattice_subtraction/__init__.py +49 -0
- lattice_subtraction/batch.py +374 -0
- lattice_subtraction/cli.py +751 -0
- lattice_subtraction/config.py +216 -0
- lattice_subtraction/core.py +389 -0
- lattice_subtraction/io.py +177 -0
- lattice_subtraction/masks.py +397 -0
- lattice_subtraction/processing.py +221 -0
- lattice_subtraction/ui.py +256 -0
- lattice_subtraction/visualization.py +195 -0
|
@@ -0,0 +1,221 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Image processing utilities.
|
|
3
|
+
|
|
4
|
+
This module contains functions for image padding, background subtraction,
|
|
5
|
+
and other preprocessing operations.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
from typing import Tuple
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def pad_image(
|
|
13
|
+
image: np.ndarray,
|
|
14
|
+
pad_origin: Tuple[int, int],
|
|
15
|
+
target_size: int | None = None,
|
|
16
|
+
pad_value: float | None = None,
|
|
17
|
+
) -> Tuple[np.ndarray, dict]:
|
|
18
|
+
"""
|
|
19
|
+
Pad an image with mean border for FFT processing.
|
|
20
|
+
|
|
21
|
+
This replicates the MATLAB padarray functionality with 'pre' and 'post'
|
|
22
|
+
padding using the image mean value.
|
|
23
|
+
|
|
24
|
+
Args:
|
|
25
|
+
image: Input 2D image
|
|
26
|
+
pad_origin: Padding offsets (pad_y, pad_x) - pixels to add at start
|
|
27
|
+
target_size: Target square size. If None, auto-calculated.
|
|
28
|
+
pad_value: Value to use for padding. If None, uses image mean.
|
|
29
|
+
|
|
30
|
+
Returns:
|
|
31
|
+
Tuple of:
|
|
32
|
+
- Padded image
|
|
33
|
+
- Metadata dict with original shape and padding info for cropping
|
|
34
|
+
"""
|
|
35
|
+
orig_h, orig_w = image.shape
|
|
36
|
+
pad_y, pad_x = pad_origin
|
|
37
|
+
|
|
38
|
+
# Auto-calculate target size if not provided
|
|
39
|
+
if target_size is None:
|
|
40
|
+
max_dim = max(orig_h, orig_w)
|
|
41
|
+
target_size = max_dim + pad_x * 2
|
|
42
|
+
# Round to nearest 10 for FFT efficiency
|
|
43
|
+
target_size = int(np.round(target_size / 10) * 10)
|
|
44
|
+
|
|
45
|
+
# Calculate padding amounts
|
|
46
|
+
pad_top = pad_y - 1 if pad_y > 0 else 0
|
|
47
|
+
pad_left = pad_x - 1 if pad_x > 0 else 0
|
|
48
|
+
pad_bottom = target_size - orig_h - pad_top
|
|
49
|
+
pad_right = target_size - orig_w - pad_left
|
|
50
|
+
|
|
51
|
+
# Ensure non-negative padding
|
|
52
|
+
pad_bottom = max(0, pad_bottom)
|
|
53
|
+
pad_right = max(0, pad_right)
|
|
54
|
+
|
|
55
|
+
# Use image mean for padding value if not specified
|
|
56
|
+
if pad_value is None:
|
|
57
|
+
pad_value = float(np.mean(image))
|
|
58
|
+
|
|
59
|
+
# Perform padding
|
|
60
|
+
padded = np.pad(
|
|
61
|
+
image,
|
|
62
|
+
((pad_top, pad_bottom), (pad_left, pad_right)),
|
|
63
|
+
mode='constant',
|
|
64
|
+
constant_values=pad_value,
|
|
65
|
+
)
|
|
66
|
+
|
|
67
|
+
# Store metadata for later cropping
|
|
68
|
+
metadata = {
|
|
69
|
+
'original_shape': (orig_h, orig_w),
|
|
70
|
+
'pad_top': pad_top,
|
|
71
|
+
'pad_left': pad_left,
|
|
72
|
+
'pad_bottom': pad_bottom,
|
|
73
|
+
'pad_right': pad_right,
|
|
74
|
+
'target_size': target_size,
|
|
75
|
+
}
|
|
76
|
+
|
|
77
|
+
return padded, metadata
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def crop_to_original(
    image: np.ndarray,
    metadata: dict,
) -> np.ndarray:
    """
    Undo the padding applied by pad_image() by slicing out the original region.

    Args:
        image: Padded image
        metadata: Metadata dict from pad_image(); only the 'original_shape',
            'pad_top' and 'pad_left' entries are read.

    Returns:
        The sub-array of ``image`` with the original dimensions (a slice of
        the input, not a copy)
    """
    height, width = metadata['original_shape']
    top = metadata['pad_top']
    left = metadata['pad_left']

    row_window = slice(top, top + height)
    col_window = slice(left, left + width)
    return image[row_window, col_window]
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def subtract_background(
    image: np.ndarray,
    median_filter_size: int = 10,
) -> np.ndarray:
    """
    Subtract smooth background from an image to reveal sharp features.

    This is the Python equivalent of bg_FastSubtract_standard.m.
    It creates a smoothed version of the image using median filtering
    and subtracts it from the original. Images whose longest side is 500 px
    or more are downsampled so the longest side is about 500 px before
    filtering, and the filtered background is upsampled back. The filter
    size is therefore measured in downsampled pixels on that path.

    Args:
        image: Input 2D image (typically log-power spectrum). Non-float
            input is converted to float32 before any arithmetic.
        median_filter_size: Size of median filter kernel. Default: 10

    Returns:
        float32 background-subtracted image. A border of width ``edge`` is
        replaced by the mean of the subtracted image to hide filter edge
        artifacts. ``edge`` is the filter size, rescaled to full resolution
        on the downsampled path, and at least 1 px.
    """
    from scipy import ndimage

    image = np.asarray(image)
    # Integer input (especially unsigned) would wrap around on subtraction,
    # and would truncate the float mean written into the border below, so
    # promote it to float32. Floating-point input is used unchanged.
    if not np.issubdtype(image.dtype, np.floating):
        image = image.astype(np.float32)

    h, w = image.shape

    if max(h, w) < 500:
        # For small images, apply median filter directly
        smoothed = ndimage.median_filter(image, size=median_filter_size)
        edge = median_filter_size
    else:
        # scikit-image is only needed on this resampling path
        from skimage.transform import resize

        # For large images, downsample -> filter -> upsample
        shrink_factor = 500 / max(h, w)
        # Clamp to >= 1 px so extreme aspect ratios cannot yield an empty shape
        small_shape = (
            max(1, int(h * shrink_factor)),
            max(1, int(w * shrink_factor)),
        )

        # Downsample
        small = resize(
            image,
            small_shape,
            order=1,  # Bilinear
            preserve_range=True,
            anti_aliasing=True,
        )

        # Apply median filter
        small = ndimage.median_filter(small, size=median_filter_size)

        # Upsample back to original size
        smoothed = resize(
            small,
            (h, w),
            order=1,
            preserve_range=True,
            anti_aliasing=True,
        )

        # The filter acted on downsampled pixels; scale the artifact border
        # back to full resolution.
        edge = int(median_filter_size / shrink_factor)

    # Subtract background
    subtracted = image - smoothed

    # Hide edges with artifacts: overwrite a border with the global mean
    mean_value = np.mean(subtracted)
    edge = max(1, edge)

    subtracted[:edge, :] = mean_value
    subtracted[-edge:, :] = mean_value
    subtracted[:, :edge] = mean_value
    subtracted[:, -edge:] = mean_value

    return subtracted.astype(np.float32)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def compute_power_spectrum(
    fft_shifted: np.ndarray,
    log_scale: bool = True,
    epsilon: float = 1e-10,
) -> np.ndarray:
    """
    Return the amplitude spectrum of a centered FFT, optionally log-scaled.

    Args:
        fft_shifted: Centered FFT (after fftshift)
        log_scale: When True (default), return ``log(|F| + epsilon)``;
            otherwise return ``|F|`` as-is.
        epsilon: Offset added before taking the log so zero-amplitude bins
            stay finite. Default: 1e-10

    Returns:
        ``|F|`` or its natural log, with the same shape as the input
    """
    magnitude = np.abs(fft_shifted)
    if not log_scale:
        return magnitude
    return np.log(magnitude + epsilon)
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
def shift_and_average(
    array: np.ndarray,
    shift_pixels: int,
) -> np.ndarray:
    """
    Estimate a local background as the mean of four cyclically shifted copies.

    This is the inpainting technique from bg_push_by_rot.m. Each output pixel
    is the average of the values ``shift_pixels`` away along +/- rows and
    +/- columns. The shifts wrap around the array edges (np.roll).

    Args:
        array: Input 2D array (typically FFT amplitude)
        shift_pixels: Number of pixels to shift in each direction

    Returns:
        Averaged array (local background estimate), same shape as ``array``
    """
    # Accumulate in the order: +rows, -rows, +cols, -cols
    total = np.roll(array, shift_pixels, axis=0)
    for axis, offset in ((0, -shift_pixels), (1, shift_pixels), (1, -shift_pixels)):
        total = total + np.roll(array, offset, axis=axis)

    return total / 4.0
|
|
@@ -0,0 +1,256 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Terminal UI utilities for lattice subtraction.
|
|
3
|
+
|
|
4
|
+
This module provides styled terminal output with ASCII art banner
|
|
5
|
+
and formatted progress messages. Output is only shown when running
|
|
6
|
+
interactively (TTY detected) and not suppressed by --quiet flag.
|
|
7
|
+
|
|
8
|
+
When piped or used in a pipeline, decorative output is automatically
|
|
9
|
+
suppressed to avoid polluting downstream processing.
|
|
10
|
+
"""
|
|
11
|
+
|
|
12
|
+
import sys
|
|
13
|
+
import time
|
|
14
|
+
from typing import Optional
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class Colors:
    """ANSI SGR escape sequences used to style terminal output.

    ``RESET`` clears every attribute and terminates a styled span.
    """

    # Bright foreground colors
    HEADER = '\x1b[95m'
    BLUE = '\x1b[94m'
    CYAN = '\x1b[96m'
    GREEN = '\x1b[92m'
    YELLOW = '\x1b[93m'
    RED = '\x1b[91m'

    # Text attributes
    BOLD = '\x1b[1m'
    DIM = '\x1b[2m'
    RESET = '\x1b[0m'
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
# ASCII-art wordmark printed by TerminalUI.print_banner() (interactive mode only).
# Raw string so the backslashes in the art are kept literally.
|

31
|
+
BANNER = r"""
|

32
|
+
.__ __ __ .__ ___.
|

33
|
+
| | _____ _/ |__/ |_|__| ____ ____ ________ _\_ |__
|

34
|
+
| | \__ \\ __\ __\ |/ ___\/ __ \ ______ / ___/ | \ __ \
|

35
|
+
| |__/ __ \| | | | | \ \__\ ___/ /_____/ \___ \| | / \_\ \
|

36
|
+
|____(____ /__| |__| |__|\___ >___ > /____ >____/|___ /
|

37
|
+
\/ \/ \/ \/ \/
|

38
|
+
"""
|
|
39
|
+
|
|
40
|
+
# Import version from package to keep it in sync
|
|
41
|
+
from . import __version__ as VERSION
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def is_interactive() -> bool:
    """
    Check if running in an interactive terminal.

    Returns False if stdout is piped or redirected, which means
    we're likely part of a pipeline and should suppress decorative output.
    Also returns False when there is no usable stdout at all (``sys.stdout``
    is None, lacks ``isatty``, or has been closed) instead of raising.
    """
    stream = sys.stdout
    if stream is None:
        # No console attached (sys.stdout can be None in that case)
        return False
    try:
        return stream.isatty()
    except (AttributeError, ValueError):
        # AttributeError: replacement object without isatty();
        # ValueError: isatty() on a closed stream.
        return False
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class TerminalUI:
    """
    Manages styled terminal output for the CLI.

    Decorative output (banner, colors, progress indicators) is only shown when:
    - Running in an interactive terminal (TTY detected)
    - Not suppressed by quiet mode

    When piped or in a script, output is automatically minimal. Errors are
    the exception: print_error() writes to stderr whenever quiet mode is off,
    interactive or not.
    """

    def __init__(self, quiet: bool = False):
        """
        Initialize the terminal UI.

        Args:
            quiet: If True, suppress all decorative output even in interactive
                mode; error messages are suppressed as well.
        """
        self.quiet = quiet
        self.interactive = is_interactive() and not quiet
        self.use_colors = self.interactive
        # time.time() stamps recorded by start_timer() / start_processing();
        # None until the corresponding timer has been started.
        self._start_time: Optional[float] = None
        self._file_start_time: Optional[float] = None

    @staticmethod
    def _elapsed_since(start: Optional[float]) -> float:
        """Return seconds elapsed since ``start``, or 0 if it was never set."""
        # Compare with None explicitly: any recorded timestamp is valid,
        # regardless of its truthiness.
        if start is None:
            return 0.0
        return time.time() - start

    def _colorize(self, text: str, color: str) -> str:
        """Wrap ``text`` in ``color`` plus a reset code when colors are enabled."""
        if not self.use_colors:
            return text
        return f"{color}{text}{Colors.RESET}"

    def print_banner(self) -> None:
        """Print the ASCII art banner and version tagline (interactive only)."""
        if not self.interactive:
            return

        print()
        print(self._colorize(BANNER, Colors.CYAN))
        tagline = f" Phase-preserving FFT inpainting for cryo-EM | v{VERSION}"
        print(self._colorize(tagline, Colors.DIM))
        print()

    def print_config(self, pixel_size: float, threshold: float,
                     backend: str, gpu_name: Optional[str] = None) -> None:
        """
        Print configuration summary (interactive only).

        Args:
            pixel_size: Pixel size, printed with an ``A`` unit suffix
            threshold: Threshold value, printed as-is
            backend: "auto", "pytorch", or anything else (shown as NumPy CPU)
            gpu_name: GPU name to display next to GPU backends, if known
        """
        if not self.interactive:
            return

        print(self._colorize(" Configuration", Colors.BOLD))
        print(self._colorize(" -------------", Colors.DIM))
        print(f" Pixel size: {self._colorize(f'{pixel_size} A', Colors.YELLOW)}")
        print(f" Threshold: {self._colorize(str(threshold), Colors.YELLOW)}")

        # Choose the backend label and its color, then print a single line.
        if backend == "auto":
            # Auto mode is displayed as GPU only when a GPU name was supplied
            if gpu_name:
                backend_str, color = f"Auto → GPU ({gpu_name})", Colors.GREEN
            else:
                backend_str, color = "Auto → CPU", Colors.BLUE
        elif backend == "pytorch":
            backend_str = f"PyTorch CUDA ({gpu_name})" if gpu_name else "PyTorch CUDA"
            color = Colors.GREEN
        else:
            backend_str, color = "NumPy (CPU)", Colors.BLUE
        print(f" Backend: {self._colorize(backend_str, color)}")
        print()

    def start_timer(self) -> None:
        """Start the overall timer."""
        self._start_time = time.time()

    def start_processing(self, filename: str, shape: Optional[tuple] = None) -> None:
        """Record the per-file start time and announce the file (interactive only)."""
        # The timestamp is recorded even when output is suppressed
        self._file_start_time = time.time()

        if not self.interactive:
            return

        print(f" {self._colorize('>', Colors.CYAN)} Processing: {self._colorize(filename, Colors.BOLD)}")
        if shape:
            print(f" {self._colorize('|-', Colors.DIM)} Size: {shape[0]} x {shape[1]}")

    def end_processing(self, output_path: str, success: bool = True) -> None:
        """
        Indicate end of processing.

        Args:
            output_path: Path of the written output (not displayed)
            success: Whether the file was processed successfully
        """
        elapsed = self._elapsed_since(self._file_start_time)

        if not self.interactive:
            return

        if success:
            status = self._colorize(f"[OK] Complete ({elapsed:.2f}s)", Colors.GREEN)
        else:
            status = self._colorize("[FAIL]", Colors.RED)

        print(f" {self._colorize('`-', Colors.DIM)} {status}")
        print()

    def print_batch_header(self, num_files: int, output_dir: str, num_workers: int = 1) -> None:
        """Print batch processing header."""
        if not self.interactive:
            return

        print(self._colorize(" Batch Processing", Colors.BOLD))
        print(self._colorize(" ----------------", Colors.DIM))
        print(f" Files: {self._colorize(str(num_files), Colors.YELLOW)}")
        print(f" Output: {output_dir}")
        print(f" Workers: {num_workers}")
        print()

    def print_batch_progress(self, current: int, total: int, filename: str,
                             elapsed: Optional[float] = None) -> None:
        """
        Print batch progress update.

        Args:
            current: 1-based index of the file just handled
            total: Total number of files in the batch
            filename: Name of the file just handled
            elapsed: Seconds taken for this file; omitted from output when None
        """
        if not self.interactive:
            return

        progress = f"[{current}/{total}]"
        # 0.0 is still a measurement; only None means "no timing available".
        time_str = f" ({elapsed:.1f}s)" if elapsed is not None else ""
        print(f" {self._colorize(progress, Colors.CYAN)} {filename}{time_str}")

    def print_batch_complete(self) -> None:
        """Print batch completion message with time since start_timer()."""
        if not self.interactive:
            return

        elapsed = self._elapsed_since(self._start_time)
        print()
        print(f" {self._colorize('[OK]', Colors.GREEN)} {self._colorize('Batch complete', Colors.BOLD)} ({elapsed:.1f}s)")
        print()

    def print_summary(self, processed: int, failed: int = 0) -> None:
        """Print final summary; the failed count is shown only when non-zero."""
        if not self.interactive:
            return

        elapsed = self._elapsed_since(self._start_time)

        print(self._colorize(" Summary", Colors.BOLD))
        print(self._colorize(" -------", Colors.DIM))
        print(f" Processed: {self._colorize(str(processed), Colors.GREEN)}")
        if failed > 0:
            print(f" Failed: {self._colorize(str(failed), Colors.RED)}")
        print(f" Time: {elapsed:.1f}s")
        print()

    def print_error(self, message: str) -> None:
        """Print error message to stderr (always shown unless quiet)."""
        if self.quiet:
            return
        # _colorize() already falls back to plain text when colors are off
        prefix = self._colorize("Error:", Colors.RED)
        print(f" {prefix} {message}", file=sys.stderr)

    def print_warning(self, message: str) -> None:
        """Print warning message to stderr (interactive only)."""
        if not self.interactive:
            return
        prefix = self._colorize("Warning:", Colors.YELLOW)
        print(f" {prefix} {message}", file=sys.stderr)

    def print_success(self, message: str) -> None:
        """Print success message (interactive only)."""
        if not self.interactive:
            return
        check = self._colorize("[OK]", Colors.GREEN)
        print(f" {check} {message}")

    def print_info(self, message: str) -> None:
        """Print info message (interactive only)."""
        if not self.interactive:
            return
        print(f" {self._colorize('>', Colors.CYAN)} {message}")

    def print_saved(self, path: str) -> None:
        """Print saved file notification (interactive only)."""
        if not self.interactive:
            return
        print(f" {self._colorize('|-', Colors.DIM)} Saved: {path}")
|
|
233
|
+
|
|
234
|
+
|
|
235
|
+
def get_ui(quiet: bool = False) -> TerminalUI:
    """
    Build a TerminalUI for the current process.

    Args:
        quiet: Forwarded to TerminalUI; when True, decorative output is
            suppressed even on an interactive terminal

    Returns:
        A newly constructed TerminalUI (a fresh instance on every call)
    """
    ui = TerminalUI(quiet=quiet)
    return ui
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
def get_gpu_name() -> Optional[str]:
    """
    Return the name of CUDA device 0, or None when no GPU can be reported.

    This is a best-effort display helper. Each of the following yields None
    instead of propagating an exception to the caller:
    - torch is not installed;
    - torch fails to load at import time;
    - the CUDA query itself raises a RuntimeError.
    """
    try:
        import torch
    except (ImportError, OSError):
        # OSError typically means torch is installed but one of its native
        # libraries could not be loaded.
        return None

    try:
        if torch.cuda.is_available():
            return torch.cuda.get_device_name(0)
    except RuntimeError:
        # e.g. a driver/runtime problem surfaced only when the device is queried
        return None
    return None
|
|
@@ -0,0 +1,195 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Visualization utilities for lattice subtraction results.
|
|
3
|
+
|
|
4
|
+
This module provides functions to create comparison visualizations
|
|
5
|
+
showing original, processed, and difference images.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional, Tuple
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
|
|
14
|
+
# matplotlib (and its font manager in particular) can emit a lot of DEBUG
# output; cap both loggers at WARNING so it does not flood application logs.
for _noisy_logger in ('matplotlib', 'matplotlib.font_manager'):
    logging.getLogger(_noisy_logger).setLevel(logging.WARNING)
del _noisy_logger

# Module logger for this file's own diagnostics
logger = logging.getLogger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_comparison_figure(
    original: np.ndarray,
    processed: np.ndarray,
    title: str = "Lattice Subtraction Comparison",
    figsize: Tuple[int, int] = (18, 6),
    dpi: int = 150,
):
    """
    Create a comparison figure showing original, processed, and difference images.

    Args:
        original: Original image array
        processed: Processed (lattice-subtracted) image array, same shape
        title: Figure title
        figsize: Figure size in inches (width, height)
        dpi: Figure resolution. With matplotlib's default
            ``savefig.dpi='figure'`` setting, this is also the resolution
            ``fig.savefig`` uses when no explicit dpi is passed there.

    Returns:
        matplotlib Figure object (the caller is responsible for closing it)
    """
    import matplotlib.pyplot as plt

    # Compute difference in float64: integer images (e.g. uint16) would wrap
    # around wherever the processed value exceeds the original.
    difference = np.asarray(original, dtype=np.float64) - np.asarray(processed, dtype=np.float64)

    # Create figure
    fig, axes = plt.subplots(1, 3, figsize=figsize, dpi=dpi)

    # Contrast limits from the original (1st-99th percentile), shared by the
    # first two panels so they are directly comparable
    vmin, vmax = np.percentile(original, [1, 99])

    # Original
    axes[0].imshow(original, cmap='gray', vmin=vmin, vmax=vmax)
    axes[0].set_title(f'Original\n{original.shape}')
    axes[0].axis('off')

    # Lattice Subtracted
    axes[1].imshow(processed, cmap='gray', vmin=vmin, vmax=vmax)
    axes[1].set_title(f'Lattice Subtracted\n{processed.shape}')
    axes[1].axis('off')

    # Difference (removed lattice), symmetric +/- 3 sigma diverging scale
    diff_std = np.std(difference)
    axes[2].imshow(
        difference,
        cmap='RdBu_r',
        vmin=-diff_std * 3,
        vmax=diff_std * 3
    )
    axes[2].set_title('Difference (Removed Lattice)')
    axes[2].axis('off')

    # Title and layout applied to this figure explicitly rather than to
    # pyplot's global "current figure"
    fig.suptitle(title, fontsize=14)
    fig.tight_layout()

    return fig
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def save_comparison_visualization(
    original_path: Path,
    processed_path: Path,
    output_path: Path,
    dpi: int = 150,
) -> None:
    """
    Create and save a comparison visualization for a single image pair.

    Args:
        original_path: Path to original MRC file (str or Path)
        processed_path: Path to processed MRC file (str or Path)
        output_path: Path for output PNG file (str or Path); missing parent
            directories are created
        dpi: Resolution of the saved PNG
    """
    import matplotlib.pyplot as plt
    import mrcfile

    # Accept plain strings as well: .name and .parent below need Path objects
    original_path = Path(original_path)
    output_path = Path(output_path)

    # Load images
    with mrcfile.open(original_path, 'r') as f:
        original = f.data.copy()
    with mrcfile.open(processed_path, 'r') as f:
        processed = f.data.copy()

    # Create title (long file names are truncated to 60 characters)
    name = original_path.name
    short_name = name[:60] + "..." if len(name) > 60 else name
    title = f"Lattice Subtraction Comparison: {short_name}"

    # Create and save figure; always close it, even if saving fails, so a
    # caller that catches the error and keeps going does not accumulate
    # open pyplot figures.
    fig = create_comparison_figure(original, processed, title=title, dpi=dpi)
    try:
        output_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(output_path, dpi=dpi, bbox_inches='tight')
    finally:
        plt.close(fig)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def generate_visualizations(
    input_dir: Path,
    output_dir: Path,
    viz_dir: Path,
    prefix: str = "sub_",
    pattern: str = "*.mrc",
    dpi: int = 150,
    show_progress: bool = True,
) -> Tuple[int, int]:
    """
    Generate comparison visualizations for all processed images in a directory.

    Processed files are matched as ``<prefix><pattern>`` in ``output_dir``.
    Glob wildcards are honoured in ``pattern`` only; ``prefix`` is matched
    literally. Each match is paired with the original of the same name
    (prefix removed) in ``input_dir``. The PNG is written to ``viz_dir``
    under the original's name with its suffix replaced by ``.png``.

    Args:
        input_dir: Directory containing original MRC files
        output_dir: Directory containing processed MRC files
        viz_dir: Directory for output visualization PNG files
        prefix: Prefix used for processed files (default: "sub_")
        pattern: Glob pattern for finding processed files
        dpi: Resolution for output images
        show_progress: Show a tqdm progress bar (if tqdm is installed)

    Returns:
        Tuple of (successful_count, total_count). Existing PNGs count as
        successful. Files whose original is missing count toward the total
        only.
    """
    import glob

    # Imported eagerly (the name itself is unused here) so that a missing or
    # broken matplotlib raises once, here, instead of being logged as a
    # per-file failure by the loop below.
    import matplotlib.pyplot as plt  # noqa: F401

    viz_dir = Path(viz_dir)
    viz_dir.mkdir(parents=True, exist_ok=True)

    # Find all processed files. Escape the prefix so characters such as
    # '[', '*' or '?' are matched literally.
    output_files = sorted(Path(output_dir).glob(f"{glob.escape(prefix)}{pattern}"))

    if not output_files:
        logger.warning(f"No processed files found matching '{prefix}{pattern}' in {output_dir}")
        return 0, 0

    successful = 0
    total = len(output_files)

    # Setup iterator with optional progress bar
    iterator = output_files
    if show_progress:
        try:
            from tqdm import tqdm
        except ImportError:
            pass
        else:
            iterator = tqdm(output_files, desc="Generating visualizations", unit="file")

    for processed_path in iterator:
        try:
            # Corresponding input file: strip the leading prefix, which the
            # glob above guarantees is present
            input_name = processed_path.name[len(prefix):]
            input_path = Path(input_dir) / input_name

            if not input_path.exists():
                logger.debug("Original not found: %s", input_path)
                continue

            # Output path: swap only the final suffix (e.g. x.mrc -> x.png)
            viz_path = viz_dir / Path(input_name).with_suffix(".png").name

            # Skip if already exists
            if viz_path.exists():
                successful += 1
                continue

            # Generate visualization
            save_comparison_visualization(
                original_path=input_path,
                processed_path=processed_path,
                output_path=viz_path,
                dpi=dpi,
            )
            successful += 1

        except Exception as e:
            # Best-effort batch: log the failure and continue with the next file
            logger.error("Failed to create visualization for %s: %s", processed_path.name, e)

    return successful, total
|