lattice-sub 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattice_sub-1.0.10.dist-info/METADATA +324 -0
- lattice_sub-1.0.10.dist-info/RECORD +16 -0
- lattice_sub-1.0.10.dist-info/WHEEL +5 -0
- lattice_sub-1.0.10.dist-info/entry_points.txt +2 -0
- lattice_sub-1.0.10.dist-info/licenses/LICENSE +21 -0
- lattice_sub-1.0.10.dist-info/top_level.txt +1 -0
- lattice_subtraction/__init__.py +49 -0
- lattice_subtraction/batch.py +374 -0
- lattice_subtraction/cli.py +751 -0
- lattice_subtraction/config.py +216 -0
- lattice_subtraction/core.py +389 -0
- lattice_subtraction/io.py +177 -0
- lattice_subtraction/masks.py +397 -0
- lattice_subtraction/processing.py +221 -0
- lattice_subtraction/ui.py +256 -0
- lattice_subtraction/visualization.py +195 -0
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Configuration management for lattice subtraction.
|
|
3
|
+
|
|
4
|
+
This module handles loading, validation, and storage of processing parameters
|
|
5
|
+
from YAML configuration files or Python dictionaries.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass, field
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional, Literal
|
|
11
|
+
import yaml
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@dataclass
|
|
15
|
+
class Config:
|
|
16
|
+
"""
|
|
17
|
+
Configuration parameters for lattice subtraction processing.
|
|
18
|
+
|
|
19
|
+
All resolution parameters are in Angstroms. The algorithm removes lattice
|
|
20
|
+
peaks in the resolution range between inside_radius_ang and outside_radius_ang.
|
|
21
|
+
|
|
22
|
+
Attributes:
|
|
23
|
+
pixel_ang: Pixel size in Angstroms (detector-dependent, e.g., 0.56 for K3)
|
|
24
|
+
inside_radius_ang: Inner resolution limit - FFT spots within this radius
|
|
25
|
+
are preserved (low-frequency structural info). Default: 90Å
|
|
26
|
+
outside_radius_ang: Outer resolution limit - spots beyond this are preserved.
|
|
27
|
+
If None, auto-calculated as pixel_ang * 2 + 0.2
|
|
28
|
+
threshold: Peak detection threshold on log-amplitude FFT. Spots above this
|
|
29
|
+
value are identified as lattice peaks. Default: 1.42
|
|
30
|
+
expand_pixel: Morphological expansion radius for mask dilation. Default: 10
|
|
31
|
+
pad_origin_x: X padding offset in pixels. Default: 200
|
|
32
|
+
pad_origin_y: Y padding offset in pixels. Default: 200 (use 1000 for K3)
|
|
33
|
+
pad_output: If False, crop output to original size. Default: False
|
|
34
|
+
unit_cell_ang: Crystal unit cell size in Angstroms for shift calculation.
|
|
35
|
+
Default: 116Å (nucleosome repeat)
|
|
36
|
+
backend: Computation backend - 'numpy' for CPU, 'pytorch' for GPU. Default: 'numpy'
|
|
37
|
+
"""
|
|
38
|
+
|
|
39
|
+
# Required parameter
|
|
40
|
+
pixel_ang: float
|
|
41
|
+
|
|
42
|
+
# Resolution limits
|
|
43
|
+
inside_radius_ang: float = 90.0
|
|
44
|
+
outside_radius_ang: Optional[float] = None # Auto-calculated if None
|
|
45
|
+
|
|
46
|
+
# Peak detection
|
|
47
|
+
threshold: float = 1.42
|
|
48
|
+
expand_pixel: int = 10
|
|
49
|
+
|
|
50
|
+
# Padding
|
|
51
|
+
pad_origin_x: int = 200
|
|
52
|
+
pad_origin_y: int = 200
|
|
53
|
+
pad_output: bool = False
|
|
54
|
+
|
|
55
|
+
# Crystal parameters
|
|
56
|
+
unit_cell_ang: float = 116.0 # Nucleosome repeat distance
|
|
57
|
+
|
|
58
|
+
# Computation backend: 'auto' tries GPU first, then falls back to CPU
|
|
59
|
+
backend: Literal["numpy", "pytorch", "auto"] = "auto"
|
|
60
|
+
|
|
61
|
+
def __post_init__(self):
|
|
62
|
+
"""Validate and set auto-calculated parameters."""
|
|
63
|
+
if self.pixel_ang <= 0:
|
|
64
|
+
raise ValueError(f"pixel_ang must be positive, got {self.pixel_ang}")
|
|
65
|
+
|
|
66
|
+
if self.inside_radius_ang <= 0:
|
|
67
|
+
raise ValueError(f"inside_radius_ang must be positive, got {self.inside_radius_ang}")
|
|
68
|
+
|
|
69
|
+
if self.threshold <= 0:
|
|
70
|
+
raise ValueError(f"threshold must be positive, got {self.threshold}")
|
|
71
|
+
|
|
72
|
+
# Auto-calculate outside radius if not provided
|
|
73
|
+
if self.outside_radius_ang is None:
|
|
74
|
+
self.outside_radius_ang = self.pixel_ang * 2 + 0.2
|
|
75
|
+
|
|
76
|
+
if self.outside_radius_ang >= self.inside_radius_ang:
|
|
77
|
+
raise ValueError(
|
|
78
|
+
f"outside_radius_ang ({self.outside_radius_ang}) must be smaller than "
|
|
79
|
+
f"inside_radius_ang ({self.inside_radius_ang})"
|
|
80
|
+
)
|
|
81
|
+
|
|
82
|
+
@classmethod
|
|
83
|
+
def from_yaml(cls, path: str | Path) -> "Config":
|
|
84
|
+
"""
|
|
85
|
+
Load configuration from a YAML file.
|
|
86
|
+
|
|
87
|
+
Args:
|
|
88
|
+
path: Path to YAML configuration file
|
|
89
|
+
|
|
90
|
+
Returns:
|
|
91
|
+
Config instance with loaded parameters
|
|
92
|
+
|
|
93
|
+
Example YAML format:
|
|
94
|
+
pixel_ang: 0.56
|
|
95
|
+
threshold: 1.56
|
|
96
|
+
inside_radius_ang: 90
|
|
97
|
+
# outside_radius_ang: auto # Optional, auto-calculated if omitted
|
|
98
|
+
"""
|
|
99
|
+
path = Path(path)
|
|
100
|
+
if not path.exists():
|
|
101
|
+
raise FileNotFoundError(f"Configuration file not found: {path}")
|
|
102
|
+
|
|
103
|
+
with open(path, 'r') as f:
|
|
104
|
+
data = yaml.safe_load(f)
|
|
105
|
+
|
|
106
|
+
# Handle 'auto' string for outside_radius_ang
|
|
107
|
+
if data.get('outside_radius_ang') == 'auto':
|
|
108
|
+
data['outside_radius_ang'] = None
|
|
109
|
+
|
|
110
|
+
return cls(**data)
|
|
111
|
+
|
|
112
|
+
@classmethod
|
|
113
|
+
def from_legacy_parameter_file(cls, path: str | Path) -> "Config":
|
|
114
|
+
"""
|
|
115
|
+
Load configuration from legacy MATLAB PARAMETER file format.
|
|
116
|
+
|
|
117
|
+
Args:
|
|
118
|
+
path: Path to legacy PARAMETER file
|
|
119
|
+
|
|
120
|
+
Returns:
|
|
121
|
+
Config instance with loaded parameters
|
|
122
|
+
"""
|
|
123
|
+
path = Path(path)
|
|
124
|
+
params = {}
|
|
125
|
+
|
|
126
|
+
# Mapping from legacy names to new names
|
|
127
|
+
name_map = {
|
|
128
|
+
'inside_radius_ang': 'inside_radius_ang',
|
|
129
|
+
'outside_radius_ang': 'outside_radius_ang',
|
|
130
|
+
'pixel_ang': 'pixel_ang',
|
|
131
|
+
'threshold': 'threshold',
|
|
132
|
+
}
|
|
133
|
+
|
|
134
|
+
with open(path, 'r') as f:
|
|
135
|
+
for line in f:
|
|
136
|
+
line = line.strip()
|
|
137
|
+
if not line or line.startswith('!'):
|
|
138
|
+
continue
|
|
139
|
+
|
|
140
|
+
# Split on whitespace, handle comments with !
|
|
141
|
+
parts = line.split('!')
|
|
142
|
+
main_part = parts[0].strip()
|
|
143
|
+
if not main_part:
|
|
144
|
+
continue
|
|
145
|
+
|
|
146
|
+
tokens = main_part.split()
|
|
147
|
+
if len(tokens) >= 2:
|
|
148
|
+
name = tokens[0].lower()
|
|
149
|
+
try:
|
|
150
|
+
value = float(tokens[1])
|
|
151
|
+
except ValueError:
|
|
152
|
+
continue
|
|
153
|
+
|
|
154
|
+
if name in name_map:
|
|
155
|
+
params[name_map[name]] = value
|
|
156
|
+
|
|
157
|
+
return cls(**params)
|
|
158
|
+
|
|
159
|
+
def to_yaml(self, path: str | Path) -> None:
|
|
160
|
+
"""
|
|
161
|
+
Save configuration to a YAML file.
|
|
162
|
+
|
|
163
|
+
Args:
|
|
164
|
+
path: Output path for YAML file
|
|
165
|
+
"""
|
|
166
|
+
path = Path(path)
|
|
167
|
+
|
|
168
|
+
data = {
|
|
169
|
+
'pixel_ang': self.pixel_ang,
|
|
170
|
+
'inside_radius_ang': self.inside_radius_ang,
|
|
171
|
+
'outside_radius_ang': self.outside_radius_ang,
|
|
172
|
+
'threshold': self.threshold,
|
|
173
|
+
'expand_pixel': self.expand_pixel,
|
|
174
|
+
'pad_origin_x': self.pad_origin_x,
|
|
175
|
+
'pad_origin_y': self.pad_origin_y,
|
|
176
|
+
'pad_output': self.pad_output,
|
|
177
|
+
'unit_cell_ang': self.unit_cell_ang,
|
|
178
|
+
'backend': self.backend,
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
with open(path, 'w') as f:
|
|
182
|
+
yaml.dump(data, f, default_flow_style=False, sort_keys=False)
|
|
183
|
+
|
|
184
|
+
def copy(self, **updates) -> "Config":
|
|
185
|
+
"""
|
|
186
|
+
Create a copy of this config with optional updates.
|
|
187
|
+
|
|
188
|
+
Args:
|
|
189
|
+
**updates: Parameters to override
|
|
190
|
+
|
|
191
|
+
Returns:
|
|
192
|
+
New Config instance with updates applied
|
|
193
|
+
"""
|
|
194
|
+
from dataclasses import asdict
|
|
195
|
+
current = asdict(self)
|
|
196
|
+
current.update(updates)
|
|
197
|
+
return Config(**current)
|
|
198
|
+
|
|
199
|
+
|
|
200
|
+
def create_default_config(pixel_ang: float = 0.56, detector: str = "K3") -> Config:
    """
    Build a Config whose padding is tuned for a particular detector.

    Only the K3 detector (matched case-insensitively) is special-cased: its
    Y padding origin becomes 1000 px. Any other detector name ('Falcon',
    'generic', ...) gets 200 px. Every other parameter keeps the Config
    default.

    Args:
        pixel_ang: Pixel size in Angstroms
        detector: Detector type ('K3', 'Falcon', 'generic')

    Returns:
        Config with appropriate defaults for the detector
    """
    is_k3 = detector.upper() == "K3"
    if is_k3:
        y_padding = 1000
    else:
        y_padding = 200

    return Config(pixel_ang=pixel_ang, pad_origin_y=y_padding)
|
|
@@ -0,0 +1,389 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Core lattice subtraction algorithm.
|
|
3
|
+
|
|
4
|
+
This module contains the main LatticeSubtractor class that implements
|
|
5
|
+
the phase-preserving lattice removal algorithm.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from dataclasses import dataclass
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Optional
|
|
11
|
+
import numpy as np
|
|
12
|
+
|
|
13
|
+
from .config import Config
|
|
14
|
+
from .io import read_mrc, write_mrc
|
|
15
|
+
from .masks import create_fft_mask, create_fft_mask_gpu, resolution_to_pixels
|
|
16
|
+
from .processing import (
|
|
17
|
+
pad_image,
|
|
18
|
+
crop_to_original,
|
|
19
|
+
subtract_background,
|
|
20
|
+
compute_power_spectrum,
|
|
21
|
+
shift_and_average,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
@dataclass
|
|
26
|
+
class SubtractionResult:
|
|
27
|
+
"""
|
|
28
|
+
Result of lattice subtraction processing.
|
|
29
|
+
|
|
30
|
+
Attributes:
|
|
31
|
+
image: Processed image with lattice removed
|
|
32
|
+
original_shape: Shape of input image before padding
|
|
33
|
+
fft_mask: The mask used for FFT filtering (optional)
|
|
34
|
+
power_spectrum: Background-subtracted power spectrum (optional)
|
|
35
|
+
"""
|
|
36
|
+
image: np.ndarray
|
|
37
|
+
original_shape: tuple
|
|
38
|
+
fft_mask: Optional[np.ndarray] = None
|
|
39
|
+
power_spectrum: Optional[np.ndarray] = None
|
|
40
|
+
|
|
41
|
+
def save(self, path: str | Path, pixel_size: float = 1.0) -> None:
|
|
42
|
+
"""Save the processed image to an MRC file."""
|
|
43
|
+
write_mrc(self.image, path, pixel_size=pixel_size)
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
class LatticeSubtractor:
|
|
47
|
+
"""
|
|
48
|
+
Main class for lattice subtraction from cryo-EM micrographs.
|
|
49
|
+
|
|
50
|
+
This class implements the algorithm from bg_push_by_rot.m:
|
|
51
|
+
1. Pad image and compute 2D FFT
|
|
52
|
+
2. Detect lattice peaks via thresholding on log-power spectrum
|
|
53
|
+
3. Create composite mask (protect center and edges)
|
|
54
|
+
4. Inpaint masked regions with local average amplitude
|
|
55
|
+
5. Preserve original phase, replace amplitude
|
|
56
|
+
6. Inverse FFT and crop
|
|
57
|
+
|
|
58
|
+
The algorithm removes periodic crystal lattice signals while
|
|
59
|
+
preserving non-periodic features in the image.
|
|
60
|
+
|
|
61
|
+
Example:
|
|
62
|
+
>>> config = Config(pixel_ang=0.56, threshold=1.56)
|
|
63
|
+
>>> subtractor = LatticeSubtractor(config)
|
|
64
|
+
>>> result = subtractor.process("input.mrc")
|
|
65
|
+
>>> result.save("output.mrc")
|
|
66
|
+
"""
|
|
67
|
+
|
|
68
|
+
def __init__(self, config: Config):
|
|
69
|
+
"""
|
|
70
|
+
Initialize the subtractor with configuration.
|
|
71
|
+
|
|
72
|
+
Args:
|
|
73
|
+
config: Configuration parameters for processing
|
|
74
|
+
"""
|
|
75
|
+
self.config = config
|
|
76
|
+
self._setup_backend()
|
|
77
|
+
|
|
78
|
+
def _setup_backend(self) -> None:
|
|
79
|
+
"""Setup computation backend (numpy, pytorch, or auto).
|
|
80
|
+
|
|
81
|
+
Auto mode tries PyTorch+CUDA first, then PyTorch CPU, then NumPy.
|
|
82
|
+
Prints user-friendly status message about which backend is active.
|
|
83
|
+
"""
|
|
84
|
+
backend = self.config.backend
|
|
85
|
+
self._gpu_message_shown = getattr(self, '_gpu_message_shown', False)
|
|
86
|
+
|
|
87
|
+
# Auto mode: try GPU first, then CPU
|
|
88
|
+
if backend == "auto":
|
|
89
|
+
try:
|
|
90
|
+
import torch
|
|
91
|
+
if torch.cuda.is_available():
|
|
92
|
+
self.device = torch.device('cuda')
|
|
93
|
+
self.use_gpu = True
|
|
94
|
+
# Only print once per session (batch processing reuses subtractor)
|
|
95
|
+
if not self._gpu_message_shown:
|
|
96
|
+
gpu_name = torch.cuda.get_device_name(0)
|
|
97
|
+
print(f"✓ Using GPU: {gpu_name}")
|
|
98
|
+
self._gpu_message_shown = True
|
|
99
|
+
else:
|
|
100
|
+
self.device = torch.device('cpu')
|
|
101
|
+
self.use_gpu = False
|
|
102
|
+
if not self._gpu_message_shown:
|
|
103
|
+
print("ℹ Running on CPU (run 'lattice-sub setup-gpu' to enable GPU)")
|
|
104
|
+
self._gpu_message_shown = True
|
|
105
|
+
except ImportError:
|
|
106
|
+
self.device = None
|
|
107
|
+
self.use_gpu = False
|
|
108
|
+
if not self._gpu_message_shown:
|
|
109
|
+
print("ℹ Running on CPU with NumPy (PyTorch not installed)")
|
|
110
|
+
self._gpu_message_shown = True
|
|
111
|
+
|
|
112
|
+
elif backend == "pytorch":
|
|
113
|
+
try:
|
|
114
|
+
import torch
|
|
115
|
+
if torch.cuda.is_available():
|
|
116
|
+
self.device = torch.device('cuda')
|
|
117
|
+
self.use_gpu = True
|
|
118
|
+
else:
|
|
119
|
+
import warnings
|
|
120
|
+
warnings.warn(
|
|
121
|
+
"CUDA not available, falling back to CPU."
|
|
122
|
+
)
|
|
123
|
+
self.device = torch.device('cpu')
|
|
124
|
+
self.use_gpu = False
|
|
125
|
+
except ImportError:
|
|
126
|
+
import warnings
|
|
127
|
+
warnings.warn(
|
|
128
|
+
"PyTorch not available, falling back to NumPy. "
|
|
129
|
+
"Install with: pip install torch"
|
|
130
|
+
)
|
|
131
|
+
self.device = None
|
|
132
|
+
self.use_gpu = False
|
|
133
|
+
else:
|
|
134
|
+
# numpy backend
|
|
135
|
+
self.device = None
|
|
136
|
+
self.use_gpu = False
|
|
137
|
+
|
|
138
|
+
def _to_device(self, array: np.ndarray):
|
|
139
|
+
"""Move array to GPU if using PyTorch."""
|
|
140
|
+
if self.use_gpu and self.device is not None:
|
|
141
|
+
import torch
|
|
142
|
+
return torch.from_numpy(array).to(self.device)
|
|
143
|
+
return array
|
|
144
|
+
|
|
145
|
+
def _to_numpy(self, array) -> np.ndarray:
|
|
146
|
+
"""Move array from GPU to CPU if needed."""
|
|
147
|
+
if self.use_gpu and hasattr(array, 'cpu'):
|
|
148
|
+
return array.cpu().numpy()
|
|
149
|
+
return array
|
|
150
|
+
|
|
151
|
+
def process(
|
|
152
|
+
self,
|
|
153
|
+
input_path: str | Path | np.ndarray,
|
|
154
|
+
return_diagnostics: bool = False,
|
|
155
|
+
) -> SubtractionResult:
|
|
156
|
+
"""
|
|
157
|
+
Process a micrograph to remove lattice signal.
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
input_path: Path to input MRC file, or numpy array
|
|
161
|
+
return_diagnostics: If True, include mask and power spectrum in result
|
|
162
|
+
|
|
163
|
+
Returns:
|
|
164
|
+
SubtractionResult containing processed image and optional diagnostics
|
|
165
|
+
"""
|
|
166
|
+
# Load image
|
|
167
|
+
if isinstance(input_path, (str, Path)):
|
|
168
|
+
image = read_mrc(input_path)
|
|
169
|
+
else:
|
|
170
|
+
image = input_path.astype(np.float32)
|
|
171
|
+
|
|
172
|
+
original_shape = image.shape
|
|
173
|
+
|
|
174
|
+
# Pad image
|
|
175
|
+
padded, pad_meta = pad_image(
|
|
176
|
+
image,
|
|
177
|
+
pad_origin=(self.config.pad_origin_y, self.config.pad_origin_x),
|
|
178
|
+
)
|
|
179
|
+
|
|
180
|
+
# Process
|
|
181
|
+
result_padded, fft_mask, power_spec = self._process_padded(
|
|
182
|
+
padded,
|
|
183
|
+
return_diagnostics=return_diagnostics,
|
|
184
|
+
)
|
|
185
|
+
|
|
186
|
+
# Crop to original size if requested
|
|
187
|
+
if not self.config.pad_output:
|
|
188
|
+
result_image = crop_to_original(result_padded, pad_meta)
|
|
189
|
+
else:
|
|
190
|
+
result_image = result_padded
|
|
191
|
+
|
|
192
|
+
return SubtractionResult(
|
|
193
|
+
image=result_image,
|
|
194
|
+
original_shape=original_shape,
|
|
195
|
+
fft_mask=fft_mask if return_diagnostics else None,
|
|
196
|
+
power_spectrum=power_spec if return_diagnostics else None,
|
|
197
|
+
)
|
|
198
|
+
|
|
199
|
+
def _process_padded(
|
|
200
|
+
self,
|
|
201
|
+
image: np.ndarray,
|
|
202
|
+
return_diagnostics: bool = False,
|
|
203
|
+
) -> tuple:
|
|
204
|
+
"""
|
|
205
|
+
Core processing on a padded image.
|
|
206
|
+
|
|
207
|
+
This implements the algorithm from bg_push_by_rot.m.
|
|
208
|
+
"""
|
|
209
|
+
# Convert to float64 for processing precision
|
|
210
|
+
img = self._to_device(image.astype(np.float64))
|
|
211
|
+
box_size = image.shape[0]
|
|
212
|
+
|
|
213
|
+
# Step 1: Compute FFT and shift to center DC
|
|
214
|
+
if self.use_gpu:
|
|
215
|
+
import torch
|
|
216
|
+
fft_img = torch.fft.fft2(img)
|
|
217
|
+
fft_shifted = torch.fft.fftshift(fft_img)
|
|
218
|
+
# Step 2: Compute log-power spectrum
|
|
219
|
+
power_spectrum = torch.abs(torch.log(torch.abs(fft_shifted) + 1e-10))
|
|
220
|
+
else:
|
|
221
|
+
from scipy import fft
|
|
222
|
+
fft_img = fft.fft2(img)
|
|
223
|
+
fft_shifted = fft.fftshift(fft_img)
|
|
224
|
+
# Step 2: Compute log-power spectrum
|
|
225
|
+
power_spectrum = np.abs(np.log(np.abs(fft_shifted) + 1e-10))
|
|
226
|
+
|
|
227
|
+
# Step 3: Background subtraction for peak detection
|
|
228
|
+
# Move to numpy for scipy operations
|
|
229
|
+
power_np = self._to_numpy(power_spectrum)
|
|
230
|
+
subtracted = subtract_background(power_np)
|
|
231
|
+
|
|
232
|
+
# Step 4: Threshold to detect peaks
|
|
233
|
+
threshold_mask = subtracted > self.config.threshold
|
|
234
|
+
|
|
235
|
+
# Step 5: Create composite mask with radial limits
|
|
236
|
+
# Use GPU-accelerated mask creation when available
|
|
237
|
+
if self.use_gpu:
|
|
238
|
+
import torch
|
|
239
|
+
# Convert threshold mask to GPU tensor
|
|
240
|
+
threshold_tensor = torch.from_numpy(threshold_mask).to(self.device)
|
|
241
|
+
|
|
242
|
+
# Create mask entirely on GPU
|
|
243
|
+
mask_final_dev = create_fft_mask_gpu(
|
|
244
|
+
box_size=box_size,
|
|
245
|
+
pixel_ang=self.config.pixel_ang,
|
|
246
|
+
inside_radius_ang=self.config.inside_radius_ang,
|
|
247
|
+
outside_radius_ang=self.config.outside_radius_ang,
|
|
248
|
+
threshold_mask=threshold_tensor,
|
|
249
|
+
expand_pixel=self.config.expand_pixel,
|
|
250
|
+
device=self.device,
|
|
251
|
+
).float()
|
|
252
|
+
else:
|
|
253
|
+
# CPU path
|
|
254
|
+
mask_final = create_fft_mask(
|
|
255
|
+
box_size=box_size,
|
|
256
|
+
pixel_ang=self.config.pixel_ang,
|
|
257
|
+
inside_radius_ang=self.config.inside_radius_ang,
|
|
258
|
+
outside_radius_ang=self.config.outside_radius_ang,
|
|
259
|
+
threshold_mask=threshold_mask,
|
|
260
|
+
expand_pixel=self.config.expand_pixel,
|
|
261
|
+
)
|
|
262
|
+
mask_final_dev = self._to_device(mask_final.astype(np.float64))
|
|
263
|
+
|
|
264
|
+
# Step 6: Inpainting with local averaging
|
|
265
|
+
# Keep unmasked FFT values
|
|
266
|
+
fft_keep = mask_final_dev * fft_shifted
|
|
267
|
+
|
|
268
|
+
# Calculate shift distance (based on unit cell)
|
|
269
|
+
shift_pixels = int(
|
|
270
|
+
self.config.pixel_ang / self.config.unit_cell_ang * box_size
|
|
271
|
+
)
|
|
272
|
+
shift_pixels = max(1, shift_pixels) # Ensure at least 1 pixel shift
|
|
273
|
+
|
|
274
|
+
if self.use_gpu:
|
|
275
|
+
import torch
|
|
276
|
+
# Compute amplitude of kept FFT values (zeros where mask removes peaks)
|
|
277
|
+
# This matches MATLAB: abs_y2_A = abs(y2_A) where y2_A = mask_final .* y2
|
|
278
|
+
amplitude_keep = torch.abs(fft_keep)
|
|
279
|
+
|
|
280
|
+
# Shift and average (inpainting) - propagates good values into zero regions
|
|
281
|
+
# This matches MATLAB circshift averaging
|
|
282
|
+
shift_avg = (
|
|
283
|
+
torch.roll(amplitude_keep, shift_pixels, dims=0) +
|
|
284
|
+
torch.roll(amplitude_keep, -shift_pixels, dims=0) +
|
|
285
|
+
torch.roll(amplitude_keep, shift_pixels, dims=1) +
|
|
286
|
+
torch.roll(amplitude_keep, -shift_pixels, dims=1)
|
|
287
|
+
) / 4.0
|
|
288
|
+
|
|
289
|
+
# Step 7: Replace masked amplitudes, preserve ORIGINAL phase
|
|
290
|
+
# MATLAB: y2_B = ~mask_final .* shift_ave
|
|
291
|
+
# MATLAB: angle_y2_ori_B = angle(y .* ~mask_final) <- uses ORIGINAL FFT phase
|
|
292
|
+
mask_remove = ~mask_final_dev.bool()
|
|
293
|
+
inpaint_amplitude = mask_remove.float() * shift_avg
|
|
294
|
+
|
|
295
|
+
# Get original phase at masked positions FROM ORIGINAL FFT (not fft_keep)
|
|
296
|
+
# This is critical - MATLAB uses: angle(y .* ~mask_final)
|
|
297
|
+
original_phase = torch.angle(fft_shifted * mask_remove.float())
|
|
298
|
+
|
|
299
|
+
# Reconstruct: keep + inpainted with original phase
|
|
300
|
+
# MATLAB: y2 = y2_A + value_y2_B .* exp(i .* angle_y2_ori_B)
|
|
301
|
+
fft_result = fft_keep + inpaint_amplitude * torch.exp(1j * original_phase)
|
|
302
|
+
|
|
303
|
+
# Step 8: Inverse FFT
|
|
304
|
+
fft_result = torch.fft.ifftshift(fft_result)
|
|
305
|
+
result = torch.fft.ifft2(fft_result)
|
|
306
|
+
|
|
307
|
+
# Take real part
|
|
308
|
+
result = torch.real(result).float()
|
|
309
|
+
else:
|
|
310
|
+
# NumPy/SciPy path - same algorithm as GPU
|
|
311
|
+
# Compute amplitude of kept FFT values (zeros where mask removes peaks)
|
|
312
|
+
amplitude_keep = np.abs(fft_keep)
|
|
313
|
+
|
|
314
|
+
# Shift and average (inpainting) - propagates good values into zero regions
|
|
315
|
+
shift_avg = (
|
|
316
|
+
np.roll(amplitude_keep, shift_pixels, axis=0) +
|
|
317
|
+
np.roll(amplitude_keep, -shift_pixels, axis=0) +
|
|
318
|
+
np.roll(amplitude_keep, shift_pixels, axis=1) +
|
|
319
|
+
np.roll(amplitude_keep, -shift_pixels, axis=1)
|
|
320
|
+
) / 4.0
|
|
321
|
+
|
|
322
|
+
# Step 7: Replace masked amplitudes, preserve ORIGINAL phase
|
|
323
|
+
mask_remove = ~mask_final.astype(bool)
|
|
324
|
+
inpaint_amplitude = mask_remove.astype(np.float64) * shift_avg
|
|
325
|
+
|
|
326
|
+
# Get original phase at masked positions FROM ORIGINAL FFT
|
|
327
|
+
original_phase = np.angle(fft_shifted * mask_remove.astype(np.float64))
|
|
328
|
+
|
|
329
|
+
# Reconstruct: keep + inpainted with original phase
|
|
330
|
+
fft_result = fft_keep + inpaint_amplitude * np.exp(1j * original_phase)
|
|
331
|
+
|
|
332
|
+
# Step 8: Inverse FFT
|
|
333
|
+
from scipy import fft
|
|
334
|
+
fft_result = fft.ifftshift(fft_result)
|
|
335
|
+
result = fft.ifft2(fft_result)
|
|
336
|
+
|
|
337
|
+
# Take real part
|
|
338
|
+
result = np.real(result).astype(np.float32)
|
|
339
|
+
|
|
340
|
+
# Move results to numpy
|
|
341
|
+
result_np = self._to_numpy(result)
|
|
342
|
+
|
|
343
|
+
# For diagnostics, get mask as numpy array
|
|
344
|
+
if return_diagnostics:
|
|
345
|
+
if self.use_gpu:
|
|
346
|
+
mask_np = self._to_numpy(mask_final_dev).astype(bool)
|
|
347
|
+
else:
|
|
348
|
+
mask_np = mask_final.astype(bool)
|
|
349
|
+
power_np = subtracted
|
|
350
|
+
else:
|
|
351
|
+
mask_np = None
|
|
352
|
+
power_np = None
|
|
353
|
+
|
|
354
|
+
return result_np, mask_np, power_np
|
|
355
|
+
|
|
356
|
+
def process_array(
|
|
357
|
+
self,
|
|
358
|
+
image: np.ndarray,
|
|
359
|
+
return_diagnostics: bool = False,
|
|
360
|
+
) -> SubtractionResult:
|
|
361
|
+
"""
|
|
362
|
+
Process a numpy array directly.
|
|
363
|
+
|
|
364
|
+
Args:
|
|
365
|
+
image: Input 2D numpy array
|
|
366
|
+
return_diagnostics: If True, include mask and power spectrum
|
|
367
|
+
|
|
368
|
+
Returns:
|
|
369
|
+
SubtractionResult
|
|
370
|
+
"""
|
|
371
|
+
return self.process(image, return_diagnostics=return_diagnostics)
|
|
372
|
+
|
|
373
|
+
|
|
374
|
+
def process_micrograph(
    input_path: str | Path,
    output_path: str | Path,
    config: Config,
) -> None:
    """
    Remove the lattice from one micrograph file and write the result.

    Builds a LatticeSubtractor for ``config``, processes ``input_path`` and
    saves the processed image to ``output_path``, using ``config.pixel_ang``
    as the MRC pixel size.

    Args:
        input_path: Path to input MRC file
        output_path: Path for output MRC file
        config: Processing configuration
    """
    processed = LatticeSubtractor(config).process(input_path)
    processed.save(output_path, pixel_size=config.pixel_ang)
|