lattice-sub 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattice_sub-1.0.10.dist-info/METADATA +324 -0
- lattice_sub-1.0.10.dist-info/RECORD +16 -0
- lattice_sub-1.0.10.dist-info/WHEEL +5 -0
- lattice_sub-1.0.10.dist-info/entry_points.txt +2 -0
- lattice_sub-1.0.10.dist-info/licenses/LICENSE +21 -0
- lattice_sub-1.0.10.dist-info/top_level.txt +1 -0
- lattice_subtraction/__init__.py +49 -0
- lattice_subtraction/batch.py +374 -0
- lattice_subtraction/cli.py +751 -0
- lattice_subtraction/config.py +216 -0
- lattice_subtraction/core.py +389 -0
- lattice_subtraction/io.py +177 -0
- lattice_subtraction/masks.py +397 -0
- lattice_subtraction/processing.py +221 -0
- lattice_subtraction/ui.py +256 -0
- lattice_subtraction/visualization.py +195 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""
|
|
2
|
+
MRC file I/O utilities.
|
|
3
|
+
|
|
4
|
+
This module wraps the mrcfile library for reading and writing MRC format files
|
|
5
|
+
commonly used in cryo-EM.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from pathlib import Path
|
|
9
|
+
from typing import Optional
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
import mrcfile
|
|
14
|
+
except ImportError:
|
|
15
|
+
raise ImportError(
|
|
16
|
+
"mrcfile is required for MRC I/O. Install with: pip install mrcfile"
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def read_mrc(
|
|
21
|
+
path: str | Path,
|
|
22
|
+
as_float32: bool = True,
|
|
23
|
+
) -> np.ndarray:
|
|
24
|
+
"""
|
|
25
|
+
Read a 2D micrograph from an MRC file.
|
|
26
|
+
|
|
27
|
+
Args:
|
|
28
|
+
path: Path to MRC file
|
|
29
|
+
as_float32: If True, convert to float32. Default: True
|
|
30
|
+
|
|
31
|
+
Returns:
|
|
32
|
+
2D numpy array containing the image data
|
|
33
|
+
|
|
34
|
+
Raises:
|
|
35
|
+
FileNotFoundError: If file does not exist
|
|
36
|
+
ValueError: If file contains 3D data (use read_mrc_stack instead)
|
|
37
|
+
"""
|
|
38
|
+
path = Path(path)
|
|
39
|
+
if not path.exists():
|
|
40
|
+
raise FileNotFoundError(f"MRC file not found: {path}")
|
|
41
|
+
|
|
42
|
+
with mrcfile.open(path, mode='r', permissive=True) as mrc:
|
|
43
|
+
data = mrc.data.copy()
|
|
44
|
+
|
|
45
|
+
# Handle 3D MRC files (single slice)
|
|
46
|
+
if data.ndim == 3:
|
|
47
|
+
if data.shape[0] == 1:
|
|
48
|
+
data = data[0]
|
|
49
|
+
else:
|
|
50
|
+
raise ValueError(
|
|
51
|
+
f"Expected 2D micrograph, got 3D stack with shape {data.shape}. "
|
|
52
|
+
"Use read_mrc_stack() for 3D data."
|
|
53
|
+
)
|
|
54
|
+
|
|
55
|
+
if as_float32:
|
|
56
|
+
data = data.astype(np.float32)
|
|
57
|
+
|
|
58
|
+
return data
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def read_mrc_stack(
|
|
62
|
+
path: str | Path,
|
|
63
|
+
as_float32: bool = True,
|
|
64
|
+
) -> np.ndarray:
|
|
65
|
+
"""
|
|
66
|
+
Read a 3D stack from an MRC file.
|
|
67
|
+
|
|
68
|
+
Args:
|
|
69
|
+
path: Path to MRC file
|
|
70
|
+
as_float32: If True, convert to float32. Default: True
|
|
71
|
+
|
|
72
|
+
Returns:
|
|
73
|
+
3D numpy array with shape (nz, ny, nx)
|
|
74
|
+
"""
|
|
75
|
+
path = Path(path)
|
|
76
|
+
if not path.exists():
|
|
77
|
+
raise FileNotFoundError(f"MRC file not found: {path}")
|
|
78
|
+
|
|
79
|
+
with mrcfile.open(path, mode='r', permissive=True) as mrc:
|
|
80
|
+
data = mrc.data.copy()
|
|
81
|
+
|
|
82
|
+
if as_float32:
|
|
83
|
+
data = data.astype(np.float32)
|
|
84
|
+
|
|
85
|
+
return data
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def read_mrc_header(path: str | Path) -> dict:
|
|
89
|
+
"""
|
|
90
|
+
Read only the header information from an MRC file.
|
|
91
|
+
|
|
92
|
+
Args:
|
|
93
|
+
path: Path to MRC file
|
|
94
|
+
|
|
95
|
+
Returns:
|
|
96
|
+
Dictionary containing header information including:
|
|
97
|
+
- shape: (nx, ny, nz)
|
|
98
|
+
- pixel_size: voxel size in Angstroms
|
|
99
|
+
- mode: data type mode
|
|
100
|
+
- statistics: (min, max, mean, rms)
|
|
101
|
+
"""
|
|
102
|
+
path = Path(path)
|
|
103
|
+
|
|
104
|
+
with mrcfile.open(path, mode='r', permissive=True) as mrc:
|
|
105
|
+
header = {
|
|
106
|
+
'shape': (int(mrc.header.nx), int(mrc.header.ny), int(mrc.header.nz)),
|
|
107
|
+
'pixel_size': (
|
|
108
|
+
float(mrc.voxel_size.x),
|
|
109
|
+
float(mrc.voxel_size.y),
|
|
110
|
+
float(mrc.voxel_size.z)
|
|
111
|
+
),
|
|
112
|
+
'mode': int(mrc.header.mode),
|
|
113
|
+
'statistics': (
|
|
114
|
+
float(mrc.header.dmin),
|
|
115
|
+
float(mrc.header.dmax),
|
|
116
|
+
float(mrc.header.dmean),
|
|
117
|
+
float(mrc.header.rms),
|
|
118
|
+
),
|
|
119
|
+
}
|
|
120
|
+
|
|
121
|
+
return header
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def write_mrc(
|
|
125
|
+
data: np.ndarray,
|
|
126
|
+
path: str | Path,
|
|
127
|
+
pixel_size: float = 1.0,
|
|
128
|
+
overwrite: bool = True,
|
|
129
|
+
) -> None:
|
|
130
|
+
"""
|
|
131
|
+
Write a 2D or 3D array to an MRC file.
|
|
132
|
+
|
|
133
|
+
Args:
|
|
134
|
+
data: 2D or 3D numpy array to write
|
|
135
|
+
path: Output file path
|
|
136
|
+
pixel_size: Pixel/voxel size in Angstroms. Default: 1.0
|
|
137
|
+
overwrite: If True, overwrite existing file. Default: True
|
|
138
|
+
|
|
139
|
+
Raises:
|
|
140
|
+
FileExistsError: If file exists and overwrite=False
|
|
141
|
+
ValueError: If data has invalid shape
|
|
142
|
+
"""
|
|
143
|
+
path = Path(path)
|
|
144
|
+
|
|
145
|
+
if path.exists() and not overwrite:
|
|
146
|
+
raise FileExistsError(f"File already exists: {path}")
|
|
147
|
+
|
|
148
|
+
# Ensure parent directory exists
|
|
149
|
+
path.parent.mkdir(parents=True, exist_ok=True)
|
|
150
|
+
|
|
151
|
+
# Convert to float32 for compatibility
|
|
152
|
+
if data.dtype not in (np.float32, np.int16, np.uint16, np.int8, np.uint8):
|
|
153
|
+
data = data.astype(np.float32)
|
|
154
|
+
|
|
155
|
+
# Ensure contiguous array
|
|
156
|
+
data = np.ascontiguousarray(data)
|
|
157
|
+
|
|
158
|
+
with mrcfile.new(path, overwrite=overwrite) as mrc:
|
|
159
|
+
mrc.set_data(data)
|
|
160
|
+
mrc.voxel_size = pixel_size
|
|
161
|
+
|
|
162
|
+
# Update statistics
|
|
163
|
+
mrc.update_header_stats()
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def get_pixel_size_from_mrc(path: str | Path) -> float:
    """
    Return the X-axis pixel size recorded in an MRC file's header.

    Args:
        path: Path to MRC file

    Returns:
        Pixel size in Angstroms, taken from the X component of the
        ``pixel_size`` tuple reported by ``read_mrc_header``.
    """
    x_size, _y_size, _z_size = read_mrc_header(path)['pixel_size']
    return x_size
|
|
@@ -0,0 +1,397 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Mask generation utilities for FFT processing.
|
|
3
|
+
|
|
4
|
+
This module provides functions for creating circular masks and
|
|
5
|
+
performing morphological operations on masks.
|
|
6
|
+
|
|
7
|
+
GPU-accelerated versions are available when PyTorch with CUDA is present.
|
|
8
|
+
"""
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from typing import Tuple, Optional, Union
|
|
12
|
+
|
|
13
|
+
# Try to import torch for GPU operations
|
|
14
|
+
try:
|
|
15
|
+
import torch
|
|
16
|
+
TORCH_AVAILABLE = True
|
|
17
|
+
except ImportError:
|
|
18
|
+
TORCH_AVAILABLE = False
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def create_circular_mask(
|
|
22
|
+
shape: Tuple[int, int],
|
|
23
|
+
radius: float,
|
|
24
|
+
center: Tuple[float, float] | None = None,
|
|
25
|
+
invert: bool = False,
|
|
26
|
+
) -> np.ndarray:
|
|
27
|
+
"""
|
|
28
|
+
Create a circular binary mask.
|
|
29
|
+
|
|
30
|
+
This is the Python equivalent of bg_drill_hole.m, but optimized using
|
|
31
|
+
vectorized NumPy operations instead of nested loops.
|
|
32
|
+
|
|
33
|
+
Args:
|
|
34
|
+
shape: Output mask shape (height, width)
|
|
35
|
+
radius: Radius of the circular region in pixels
|
|
36
|
+
center: Center coordinates (y, x). If None, uses image center.
|
|
37
|
+
invert: If True, mask is 0 inside circle, 1 outside. Default: False
|
|
38
|
+
|
|
39
|
+
Returns:
|
|
40
|
+
Boolean mask array where True indicates the circular region
|
|
41
|
+
(or its complement if invert=True)
|
|
42
|
+
|
|
43
|
+
Example:
|
|
44
|
+
>>> mask = create_circular_mask((100, 100), radius=30)
|
|
45
|
+
>>> mask.shape
|
|
46
|
+
(100, 100)
|
|
47
|
+
"""
|
|
48
|
+
h, w = shape
|
|
49
|
+
|
|
50
|
+
if center is None:
|
|
51
|
+
center = (h // 2, w // 2)
|
|
52
|
+
|
|
53
|
+
cy, cx = center
|
|
54
|
+
|
|
55
|
+
# Create coordinate grids
|
|
56
|
+
y, x = np.ogrid[:h, :w]
|
|
57
|
+
|
|
58
|
+
# Calculate distance from center
|
|
59
|
+
dist_sq = (y - cy) ** 2 + (x - cx) ** 2
|
|
60
|
+
|
|
61
|
+
# Create mask
|
|
62
|
+
mask = dist_sq < radius ** 2
|
|
63
|
+
|
|
64
|
+
if invert:
|
|
65
|
+
mask = ~mask
|
|
66
|
+
|
|
67
|
+
return mask
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def create_radial_band_mask(
|
|
71
|
+
shape: Tuple[int, int],
|
|
72
|
+
inner_radius: float,
|
|
73
|
+
outer_radius: float,
|
|
74
|
+
center: Tuple[float, float] | None = None,
|
|
75
|
+
) -> np.ndarray:
|
|
76
|
+
"""
|
|
77
|
+
Create an annular (ring) mask between two radii.
|
|
78
|
+
|
|
79
|
+
Args:
|
|
80
|
+
shape: Output mask shape (height, width)
|
|
81
|
+
inner_radius: Inner radius of the ring
|
|
82
|
+
outer_radius: Outer radius of the ring
|
|
83
|
+
center: Center coordinates. If None, uses image center.
|
|
84
|
+
|
|
85
|
+
Returns:
|
|
86
|
+
Boolean mask that is True in the annular region
|
|
87
|
+
"""
|
|
88
|
+
inner = create_circular_mask(shape, inner_radius, center)
|
|
89
|
+
outer = create_circular_mask(shape, outer_radius, center)
|
|
90
|
+
|
|
91
|
+
return outer & ~inner
|
|
92
|
+
|
|
93
|
+
|
|
94
|
+
def resolution_to_pixels(
    resolution_ang: float,
    pixel_size_ang: float,
    box_size: int,
) -> float:
    """
    Convert a resolution in Angstroms to a radius in Fourier pixels.

    radius_pixels = (pixel_size / resolution) * box_size, i.e. the spatial
    frequency in cycles/pixel scaled by the FFT box size.

    Args:
        resolution_ang: Resolution in Angstroms
        pixel_size_ang: Pixel size in Angstroms
        box_size: Size of the FFT box

    Returns:
        Radius in Fourier pixels corresponding to the resolution
    """
    cycles_per_pixel = pixel_size_ang / resolution_ang
    return cycles_per_pixel * box_size
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def dilate_mask(
    mask: np.ndarray,
    radius: int,
) -> np.ndarray:
    """
    Dilate a binary mask using a circular structuring element.

    This replicates the MATLAB filter2(circle, mask) approach using
    scipy's ndimage for better performance. The structuring element is
    ``create_circular_mask((2*radius+1, 2*radius+1), radius - 1)`` with the
    centre pixel always included, so the result always contains the input.

    Args:
        mask: Input binary mask
        radius: Dilation radius in pixels. Values <= 0 perform no dilation.

    Returns:
        Dilated boolean mask
    """
    from scipy import ndimage

    if radius <= 0:
        # A 1x1 element is the identity; negative radii would otherwise
        # produce an invalid (negative-size) structuring element.
        return np.array(mask, dtype=bool)

    # Create circular structuring element
    size = radius * 2 + 1
    struct = create_circular_mask((size, size), radius - 1)
    # For radius == 1 the disk of radius 0 is empty, and scipy's dilation
    # with an all-False element returns an all-False mask (wiping the
    # input). Including the origin keeps the operation a true dilation.
    struct[radius, radius] = True

    # Perform dilation
    return ndimage.binary_dilation(mask, structure=struct)
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def erode_mask(
    mask: np.ndarray,
    radius: int,
) -> np.ndarray:
    """
    Erode a binary mask using a circular structuring element.

    The structuring element is ``create_circular_mask((2*radius+1,
    2*radius+1), radius - 1)`` with the centre pixel always included, so the
    result is always contained in the input.

    Args:
        mask: Input binary mask
        radius: Erosion radius in pixels. Values <= 0 perform no erosion.

    Returns:
        Eroded boolean mask
    """
    from scipy import ndimage

    if radius <= 0:
        # A 1x1 element is the identity; negative radii would otherwise
        # produce an invalid (negative-size) structuring element.
        return np.array(mask, dtype=bool)

    size = radius * 2 + 1
    struct = create_circular_mask((size, size), radius - 1)
    # For radius == 1 the disk of radius 0 is empty, and erosion with an
    # all-False element is vacuously True everywhere (it would return an
    # all-True mask). Including the origin keeps the result within the input.
    struct[radius, radius] = True

    return ndimage.binary_erosion(mask, structure=struct)
|
|
166
|
+
|
|
167
|
+
|
|
168
|
+
def create_fft_mask(
    box_size: int,
    pixel_ang: float,
    inside_radius_ang: float,
    outside_radius_ang: float,
    threshold_mask: np.ndarray,
    expand_pixel: int = 10,
) -> np.ndarray:
    """
    Create the composite FFT mask for lattice spot removal.

    This combines:
    1. The threshold-based peak detection mask
    2. Central protection zone (low frequencies)
    3. Outer protection zone (near-Nyquist)
    4. Morphological expansion for smooth transitions

    Args:
        box_size: Size of the FFT (square)
        pixel_ang: Pixel size in Angstroms
        inside_radius_ang: Inner resolution limit (protect center)
        outside_radius_ang: Outer resolution limit (protect edges)
        threshold_mask: Mask from peak thresholding (truthy = peak), shape
            (box_size, box_size). Non-bool input is converted with
            ``astype(bool)`` semantics.
        expand_pixel: Expansion size for morphological dilation; the
            dilation radius used is ``expand_pixel // 2 - 1``.

    Returns:
        Final mask where True = keep, False = replace with inpainted values

    Raises:
        ValueError: If threshold_mask does not have shape (box_size, box_size)
    """
    shape = (box_size, box_size)

    # ``~`` on an integer/float 0-1 mask is a bitwise NOT (or a TypeError),
    # not a logical one; coerce so the boolean algebra below is correct.
    threshold_mask = np.asarray(threshold_mask, dtype=bool)
    if threshold_mask.shape != shape:
        raise ValueError(
            f"threshold_mask shape {threshold_mask.shape} does not match "
            f"FFT box shape {shape}"
        )

    # Convert resolution to Fourier pixels
    inner_radius = resolution_to_pixels(inside_radius_ang, pixel_ang, box_size)
    outer_radius = resolution_to_pixels(outside_radius_ang, pixel_ang, box_size)

    # Clamp outer radius to valid range
    outer_radius = min(outer_radius, box_size // 2 - 1)

    # Create radial masks
    mask_center = create_circular_mask(shape, inner_radius)  # Protect center
    mask_outside = create_circular_mask(shape, outer_radius)  # Within processing region

    # Keep a pixel if it is not a peak, lies outside the processing radius,
    # or lies inside the protected low-frequency centre.
    mask_final = ~threshold_mask | ~mask_outside | mask_center

    # Grow the removal regions (the False pixels) by dilating their complement.
    if expand_pixel > 0:
        rad_expand = expand_pixel // 2 - 1
        if rad_expand > 0:
            mask_final = ~dilate_mask(~mask_final, rad_expand)

    return mask_final
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
# =============================================================================
|
|
234
|
+
# GPU-Accelerated Mask Functions
|
|
235
|
+
# =============================================================================
|
|
236
|
+
|
|
237
|
+
def create_circular_mask_gpu(
    shape: Tuple[int, int],
    radius: float,
    center: Tuple[float, float] | None = None,
    invert: bool = False,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Create a circular binary mask on GPU.

    PyTorch counterpart of create_circular_mask: a pixel is inside the circle
    when its squared distance to ``center`` (computed in float32) is strictly
    less than ``radius ** 2``.

    Args:
        shape: Output mask shape (height, width)
        radius: Radius of the circular region in pixels
        center: Center coordinates (y, x). If None, uses image center.
        invert: If True, mask is 0 inside circle, 1 outside.
        device: PyTorch device. If None, uses CUDA if available, else CPU.

    Returns:
        Boolean tensor on specified device

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for GPU mask operations")

    height, width = shape

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    cy, cx = (height // 2, width // 2) if center is None else center

    # (H, 1) and (1, W) float32 coordinate vectors broadcast to (H, W).
    ys = torch.arange(height, device=device, dtype=torch.float32)[:, None]
    xs = torch.arange(width, device=device, dtype=torch.float32)[None, :]
    inside = (ys - cy) ** 2 + (xs - cx) ** 2 < radius ** 2

    return ~inside if invert else inside
|
|
286
|
+
|
|
287
|
+
|
|
288
|
+
def dilate_mask_gpu(
    mask: "torch.Tensor",
    radius: int,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Dilate a binary mask on GPU with a circular structuring element.

    The structuring element is the same disk as
    ``create_circular_mask((2*radius+1, 2*radius+1), radius - 1)`` (offsets
    with squared distance < (radius - 1) ** 2) with the centre always
    included. Dilation is computed as a 2D convolution of the 0/1 mask with
    that disk followed by ``> 0.5``; zero padding treats pixels outside the
    image as False.

    Args:
        mask: Input boolean tensor (H, W). Values > 0.5 count as set.
        radius: Dilation radius in pixels. Values <= 0 perform no dilation.
        device: PyTorch device. If None, uses mask's device. The mask is
            moved to this device before processing.

    Returns:
        Dilated boolean mask tensor on ``device``

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for GPU mask operations")

    if device is None:
        device = mask.device

    # Honour the requested device and binarise exactly like the previous
    # max-pool implementation did (float value > 0.5).
    binary = mask.to(device=device).float() > 0.5
    if radius <= 0:
        return binary

    # Disk-shaped kernel (symmetric, so correlation == convolution).
    offsets = torch.arange(-radius, radius + 1, device=device, dtype=torch.float32)
    dist_sq = offsets.unsqueeze(1) ** 2 + offsets.unsqueeze(0) ** 2
    kernel = dist_sq < (radius - 1) ** 2
    # Include the origin so small radii still yield a true dilation.
    kernel[radius, radius] = True
    weight = kernel.float().unsqueeze(0).unsqueeze(0)  # (1, 1, k, k)

    # conv2d expects 4D input: (N, C, H, W)
    hits = torch.nn.functional.conv2d(
        binary.float().unsqueeze(0).unsqueeze(0),
        weight,
        padding=radius,
    )

    # Any set pixel under the disk -> dilated. Threshold at 0.5 to be robust
    # to tiny floating-point error from FFT/TF32 convolution algorithms.
    return hits.squeeze(0).squeeze(0) > 0.5
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
def create_fft_mask_gpu(
    box_size: int,
    pixel_ang: float,
    inside_radius_ang: float,
    outside_radius_ang: float,
    threshold_mask: "torch.Tensor",
    expand_pixel: int = 10,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Create the composite FFT mask for lattice spot removal on GPU.

    GPU-accelerated version of create_fft_mask. All operations stay on the
    chosen device to avoid CPU-GPU data transfers.

    Args:
        box_size: Size of the FFT (square)
        pixel_ang: Pixel size in Angstroms
        inside_radius_ang: Inner resolution limit (protect center)
        outside_radius_ang: Outer resolution limit (protect edges)
        threshold_mask: Tensor or NumPy array from peak thresholding
            (truthy = peak), shape (box_size, box_size). Converted to bool.
        expand_pixel: Expansion size for morphological dilation; the
            dilation radius used is ``expand_pixel // 2 - 1``.
        device: PyTorch device. If None, uses threshold_mask's device
            (CPU for NumPy input).

    Returns:
        Final mask tensor where True = keep, False = replace with inpainted values

    Raises:
        ImportError: If PyTorch is not installed.
        ValueError: If threshold_mask does not have shape (box_size, box_size)
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for GPU mask operations")

    # Convert to a tensor *before* reading ``.device``: NumPy < 2.0 arrays
    # have no ``.device`` attribute. ascontiguousarray also makes
    # negative-stride views (e.g. from np.flip) acceptable to torch.
    if not isinstance(threshold_mask, torch.Tensor):
        threshold_mask = torch.as_tensor(np.ascontiguousarray(threshold_mask))
    if device is None:
        device = threshold_mask.device

    # Logical (not bitwise) negation below requires a bool tensor.
    threshold_mask = threshold_mask.to(device=device, dtype=torch.bool)

    shape = (box_size, box_size)
    if tuple(threshold_mask.shape) != shape:
        raise ValueError(
            f"threshold_mask shape {tuple(threshold_mask.shape)} does not "
            f"match FFT box shape {shape}"
        )

    # Convert resolution to Fourier pixels
    inner_radius = resolution_to_pixels(inside_radius_ang, pixel_ang, box_size)
    outer_radius = resolution_to_pixels(outside_radius_ang, pixel_ang, box_size)

    # Clamp outer radius to valid range
    outer_radius = min(outer_radius, box_size // 2 - 1)

    # Create radial masks on the target device
    mask_center = create_circular_mask_gpu(shape, inner_radius, device=device)
    mask_outside = create_circular_mask_gpu(shape, outer_radius, device=device)

    # Combine masks (same logic as CPU version): keep non-peaks, pixels
    # outside the processing radius, and the protected centre.
    mask_final = ~threshold_mask | ~mask_outside | mask_center

    # Grow the removal regions (the False pixels) by dilating their complement.
    if expand_pixel > 0:
        rad_expand = expand_pixel // 2 - 1
        if rad_expand > 0:
            mask_final = ~dilate_mask_gpu(~mask_final, rad_expand, device=device)

    return mask_final
|