lattice-sub 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,177 @@
1
+ """
2
+ MRC file I/O utilities.
3
+
4
+ This module wraps the mrcfile library for reading and writing MRC format files
5
+ commonly used in cryo-EM.
6
+ """
7
+
8
+ from pathlib import Path
9
+ from typing import Optional
10
+ import numpy as np
11
+
12
+ try:
13
+ import mrcfile
14
+ except ImportError:
15
+ raise ImportError(
16
+ "mrcfile is required for MRC I/O. Install with: pip install mrcfile"
17
+ )
18
+
19
+
20
+ def read_mrc(
21
+ path: str | Path,
22
+ as_float32: bool = True,
23
+ ) -> np.ndarray:
24
+ """
25
+ Read a 2D micrograph from an MRC file.
26
+
27
+ Args:
28
+ path: Path to MRC file
29
+ as_float32: If True, convert to float32. Default: True
30
+
31
+ Returns:
32
+ 2D numpy array containing the image data
33
+
34
+ Raises:
35
+ FileNotFoundError: If file does not exist
36
+ ValueError: If file contains 3D data (use read_mrc_stack instead)
37
+ """
38
+ path = Path(path)
39
+ if not path.exists():
40
+ raise FileNotFoundError(f"MRC file not found: {path}")
41
+
42
+ with mrcfile.open(path, mode='r', permissive=True) as mrc:
43
+ data = mrc.data.copy()
44
+
45
+ # Handle 3D MRC files (single slice)
46
+ if data.ndim == 3:
47
+ if data.shape[0] == 1:
48
+ data = data[0]
49
+ else:
50
+ raise ValueError(
51
+ f"Expected 2D micrograph, got 3D stack with shape {data.shape}. "
52
+ "Use read_mrc_stack() for 3D data."
53
+ )
54
+
55
+ if as_float32:
56
+ data = data.astype(np.float32)
57
+
58
+ return data
59
+
60
+
61
+ def read_mrc_stack(
62
+ path: str | Path,
63
+ as_float32: bool = True,
64
+ ) -> np.ndarray:
65
+ """
66
+ Read a 3D stack from an MRC file.
67
+
68
+ Args:
69
+ path: Path to MRC file
70
+ as_float32: If True, convert to float32. Default: True
71
+
72
+ Returns:
73
+ 3D numpy array with shape (nz, ny, nx)
74
+ """
75
+ path = Path(path)
76
+ if not path.exists():
77
+ raise FileNotFoundError(f"MRC file not found: {path}")
78
+
79
+ with mrcfile.open(path, mode='r', permissive=True) as mrc:
80
+ data = mrc.data.copy()
81
+
82
+ if as_float32:
83
+ data = data.astype(np.float32)
84
+
85
+ return data
86
+
87
+
88
+ def read_mrc_header(path: str | Path) -> dict:
89
+ """
90
+ Read only the header information from an MRC file.
91
+
92
+ Args:
93
+ path: Path to MRC file
94
+
95
+ Returns:
96
+ Dictionary containing header information including:
97
+ - shape: (nx, ny, nz)
98
+ - pixel_size: voxel size in Angstroms
99
+ - mode: data type mode
100
+ - statistics: (min, max, mean, rms)
101
+ """
102
+ path = Path(path)
103
+
104
+ with mrcfile.open(path, mode='r', permissive=True) as mrc:
105
+ header = {
106
+ 'shape': (int(mrc.header.nx), int(mrc.header.ny), int(mrc.header.nz)),
107
+ 'pixel_size': (
108
+ float(mrc.voxel_size.x),
109
+ float(mrc.voxel_size.y),
110
+ float(mrc.voxel_size.z)
111
+ ),
112
+ 'mode': int(mrc.header.mode),
113
+ 'statistics': (
114
+ float(mrc.header.dmin),
115
+ float(mrc.header.dmax),
116
+ float(mrc.header.dmean),
117
+ float(mrc.header.rms),
118
+ ),
119
+ }
120
+
121
+ return header
122
+
123
+
124
+ def write_mrc(
125
+ data: np.ndarray,
126
+ path: str | Path,
127
+ pixel_size: float = 1.0,
128
+ overwrite: bool = True,
129
+ ) -> None:
130
+ """
131
+ Write a 2D or 3D array to an MRC file.
132
+
133
+ Args:
134
+ data: 2D or 3D numpy array to write
135
+ path: Output file path
136
+ pixel_size: Pixel/voxel size in Angstroms. Default: 1.0
137
+ overwrite: If True, overwrite existing file. Default: True
138
+
139
+ Raises:
140
+ FileExistsError: If file exists and overwrite=False
141
+ ValueError: If data has invalid shape
142
+ """
143
+ path = Path(path)
144
+
145
+ if path.exists() and not overwrite:
146
+ raise FileExistsError(f"File already exists: {path}")
147
+
148
+ # Ensure parent directory exists
149
+ path.parent.mkdir(parents=True, exist_ok=True)
150
+
151
+ # Convert to float32 for compatibility
152
+ if data.dtype not in (np.float32, np.int16, np.uint16, np.int8, np.uint8):
153
+ data = data.astype(np.float32)
154
+
155
+ # Ensure contiguous array
156
+ data = np.ascontiguousarray(data)
157
+
158
+ with mrcfile.new(path, overwrite=overwrite) as mrc:
159
+ mrc.set_data(data)
160
+ mrc.voxel_size = pixel_size
161
+
162
+ # Update statistics
163
+ mrc.update_header_stats()
164
+
165
+
166
def get_pixel_size_from_mrc(path: str | Path) -> float:
    """
    Return the pixel size recorded in an MRC file header.

    Args:
        path: Path to MRC file

    Returns:
        Pixel size in Angstroms, taken from the X component of the
        ``pixel_size`` entry produced by read_mrc_header
    """
    voxel_sizes = read_mrc_header(path)['pixel_size']
    return voxel_sizes[0]
@@ -0,0 +1,397 @@
1
+ """
2
+ Mask generation utilities for FFT processing.
3
+
4
+ This module provides functions for creating circular masks and
5
+ performing morphological operations on masks.
6
+
7
+ GPU-accelerated versions are available when PyTorch with CUDA is present.
8
+ """
9
+
10
+ import numpy as np
11
+ from typing import Tuple, Optional, Union
12
+
13
+ # Try to import torch for GPU operations
14
+ try:
15
+ import torch
16
+ TORCH_AVAILABLE = True
17
+ except ImportError:
18
+ TORCH_AVAILABLE = False
19
+
20
+
21
+ def create_circular_mask(
22
+ shape: Tuple[int, int],
23
+ radius: float,
24
+ center: Tuple[float, float] | None = None,
25
+ invert: bool = False,
26
+ ) -> np.ndarray:
27
+ """
28
+ Create a circular binary mask.
29
+
30
+ This is the Python equivalent of bg_drill_hole.m, but optimized using
31
+ vectorized NumPy operations instead of nested loops.
32
+
33
+ Args:
34
+ shape: Output mask shape (height, width)
35
+ radius: Radius of the circular region in pixels
36
+ center: Center coordinates (y, x). If None, uses image center.
37
+ invert: If True, mask is 0 inside circle, 1 outside. Default: False
38
+
39
+ Returns:
40
+ Boolean mask array where True indicates the circular region
41
+ (or its complement if invert=True)
42
+
43
+ Example:
44
+ >>> mask = create_circular_mask((100, 100), radius=30)
45
+ >>> mask.shape
46
+ (100, 100)
47
+ """
48
+ h, w = shape
49
+
50
+ if center is None:
51
+ center = (h // 2, w // 2)
52
+
53
+ cy, cx = center
54
+
55
+ # Create coordinate grids
56
+ y, x = np.ogrid[:h, :w]
57
+
58
+ # Calculate distance from center
59
+ dist_sq = (y - cy) ** 2 + (x - cx) ** 2
60
+
61
+ # Create mask
62
+ mask = dist_sq < radius ** 2
63
+
64
+ if invert:
65
+ mask = ~mask
66
+
67
+ return mask
68
+
69
+
70
def create_radial_band_mask(
    shape: Tuple[int, int],
    inner_radius: float,
    outer_radius: float,
    center: Tuple[float, float] | None = None,
) -> np.ndarray:
    """
    Build a boolean ring (annulus) mask.

    The ring is the set of pixels inside the outer disk but not inside the
    inner disk, both produced by create_circular_mask.

    Args:
        shape: Output mask shape (height, width)
        inner_radius: Inner radius of the ring
        outer_radius: Outer radius of the ring
        center: Center coordinates. If None, uses image center.

    Returns:
        Boolean mask that is True in the annular region
    """
    outer_disk = create_circular_mask(shape, outer_radius, center)
    # invert=True yields the complement of the inner disk directly.
    beyond_inner = create_circular_mask(shape, inner_radius, center, invert=True)

    return outer_disk & beyond_inner
92
+
93
+
94
def resolution_to_pixels(
    resolution_ang: float,
    pixel_size_ang: float,
    box_size: int,
) -> float:
    """
    Map a real-space resolution to a radius in Fourier-space pixels.

    Computed as ``(pixel_size_ang / resolution_ang) * box_size``.

    Args:
        resolution_ang: Resolution in Angstroms
        pixel_size_ang: Pixel size in Angstroms
        box_size: Size of the FFT box

    Returns:
        Radius in Fourier pixels corresponding to the resolution
    """
    # Fraction of the sampling frequency, then scaled to the box.
    sampling_fraction = pixel_size_ang / resolution_ang
    return sampling_fraction * box_size
113
+
114
+
115
def dilate_mask(
    mask: np.ndarray,
    radius: int,
) -> np.ndarray:
    """
    Dilate a binary mask using a circular structuring element.

    The structuring element is a ``(2 * radius + 1)`` square containing a
    disk of radius ``radius - 1`` (strict inequality), applied with
    scipy.ndimage.binary_dilation.

    For ``radius`` of 0, 1 or 2 that disk holds at most its center pixel,
    so the mask is returned unchanged (as a boolean copy). Previously
    ``radius == 1`` produced an all-False structuring element, and
    dilating with an empty footprint discarded the whole mask.

    Args:
        mask: Input binary mask
        radius: Dilation radius in pixels (must be >= 0)

    Returns:
        Dilated boolean mask

    Raises:
        ValueError: If radius is negative
    """
    if radius < 0:
        raise ValueError(f"Dilation radius must be non-negative, got {radius}")

    # Degenerate footprints: nothing to grow, hand back a boolean copy.
    if radius <= 2:
        return np.asarray(mask).astype(bool)

    from scipy import ndimage

    # Create circular structuring element
    size = radius * 2 + 1
    struct = create_circular_mask((size, size), radius - 1)

    # Perform dilation
    return ndimage.binary_dilation(mask, structure=struct)
142
+
143
+
144
def erode_mask(
    mask: np.ndarray,
    radius: int,
) -> np.ndarray:
    """
    Erode a binary mask using a circular structuring element.

    The structuring element is a ``(2 * radius + 1)`` square containing a
    disk of radius ``radius - 1`` (strict inequality), applied with
    scipy.ndimage.binary_erosion.

    For ``radius`` of 0, 1 or 2 that disk holds at most its center pixel,
    so the mask is returned unchanged (as a boolean copy). Previously
    ``radius == 1`` produced an all-False structuring element, which is
    not a meaningful erosion footprint.

    Args:
        mask: Input binary mask
        radius: Erosion radius in pixels (must be >= 0)

    Returns:
        Eroded boolean mask

    Raises:
        ValueError: If radius is negative
    """
    if radius < 0:
        raise ValueError(f"Erosion radius must be non-negative, got {radius}")

    # Degenerate footprints: nothing to shrink, hand back a boolean copy.
    if radius <= 2:
        return np.asarray(mask).astype(bool)

    from scipy import ndimage

    size = radius * 2 + 1
    struct = create_circular_mask((size, size), radius - 1)

    return ndimage.binary_erosion(mask, structure=struct)
166
+
167
+
168
def create_fft_mask(
    box_size: int,
    pixel_ang: float,
    inside_radius_ang: float,
    outside_radius_ang: float,
    threshold_mask: np.ndarray,
    expand_pixel: int = 10,
) -> np.ndarray:
    """
    Create the composite FFT mask for lattice spot removal.

    This combines:
    1. The threshold-based peak detection mask
    2. Central protection zone (low frequencies)
    3. Outer protection zone (near-Nyquist)
    4. Morphological expansion for smooth transitions

    A pixel is kept when it is not a detected peak, or lies outside the
    outer radius, or lies inside the inner radius.

    Args:
        box_size: Size of the FFT (square)
        pixel_ang: Pixel size in Angstroms
        inside_radius_ang: Inner resolution limit (protect center)
        outside_radius_ang: Outer resolution limit (protect edges)
        threshold_mask: Mask from peak thresholding (True / non-zero = peak).
            It is coerced to bool, so 0/1 integer or float masks work too.
        expand_pixel: Expansion amount for the removal regions. The radius
            actually handed to dilate_mask is ``expand_pixel // 2 - 1``;
            no dilation happens when that value is not positive.

    Returns:
        Final mask where True = keep, False = replace with inpainted values
    """
    # ``~`` on a non-bool array is a bitwise NOT (e.g. ~1 == -2, still
    # truthy), which would silently mark every pixel as "keep".
    peaks = np.asarray(threshold_mask, dtype=bool)

    # Convert resolution to Fourier pixels
    inner_radius = resolution_to_pixels(inside_radius_ang, pixel_ang, box_size)
    outer_radius = resolution_to_pixels(outside_radius_ang, pixel_ang, box_size)

    # Clamp outer radius to valid range
    outer_radius = min(outer_radius, box_size // 2 - 1)

    # Create radial masks
    shape = (box_size, box_size)
    mask_center = create_circular_mask(shape, inner_radius)   # Protect center
    mask_outside = create_circular_mask(shape, outer_radius)  # Processing region

    # Keep center, keep everything beyond the outer radius, and in between
    # drop only the detected peaks.
    mask_final = ~peaks | ~mask_outside | mask_center

    # Expand the removal regions (invert, dilate, invert back)
    if expand_pixel > 0:
        rad_expand = expand_pixel // 2 - 1
        if rad_expand > 0:
            removal_regions = dilate_mask(~mask_final, rad_expand)
            mask_final = ~removal_regions

    return mask_final
231
+
232
+
233
+ # =============================================================================
234
+ # GPU-Accelerated Mask Functions
235
+ # =============================================================================
236
+
237
def create_circular_mask_gpu(
    shape: Tuple[int, int],
    radius: float,
    center: Tuple[float, float] | None = None,
    invert: bool = False,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Build a boolean disk mask as a PyTorch tensor.

    Torch counterpart of create_circular_mask: a pixel is inside the disk
    when its squared distance from the center is strictly less than
    ``radius ** 2``.

    Args:
        shape: Output mask shape (height, width)
        radius: Radius of the circular region in pixels
        center: Center coordinates (y, x). If None, uses
            (height // 2, width // 2).
        invert: If True, return the complement of the disk.
        device: PyTorch device. If None, uses CUDA if available,
            otherwise CPU.

    Returns:
        Boolean tensor on specified device

    Raises:
        ImportError: If PyTorch is not installed
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for GPU mask operations")

    rows, cols = shape

    if device is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    center_y, center_x = (rows // 2, cols // 2) if center is None else center

    # Column vector of row indices and row vector of column indices; their
    # combination broadcasts to (rows, cols).
    row_idx = torch.arange(rows, device=device, dtype=torch.float32).unsqueeze(1)
    col_idx = torch.arange(cols, device=device, dtype=torch.float32).unsqueeze(0)

    inside = (row_idx - center_y) ** 2 + (col_idx - center_x) ** 2 < radius ** 2

    return ~inside if invert else inside
286
+
287
+
288
def dilate_mask_gpu(
    mask: "torch.Tensor",
    radius: int,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Dilate a binary mask using max pooling in PyTorch.

    Max pooling with stride 1 over a square ``(2 * radius + 1)`` window
    sets a pixel whenever any pixel in that window is set. Note the window
    is square, not circular, so the result is not pixel-identical to the
    disk-based dilate_mask.

    Args:
        mask: Input boolean tensor (H, W)
        radius: Dilation radius in pixels
        device: PyTorch device to run on. If None, uses mask's device.
            When given, the mask is moved there first (the argument used
            to be accepted but ignored).

    Returns:
        Dilated boolean mask tensor of shape (H, W) on ``device``

    Raises:
        ImportError: If PyTorch is not installed
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for GPU mask operations")

    if device is None:
        device = mask.device

    # Kernel size must be odd so the window is centered on each pixel
    kernel_size = radius * 2 + 1
    padding = radius

    # max_pool2d expects a floating 4D input: (N, C, H, W)
    mask_4d = mask.to(device).float().unsqueeze(0).unsqueeze(0)

    # Max pooling acts as dilation for binary masks
    dilated = torch.nn.functional.max_pool2d(
        mask_4d,
        kernel_size=kernel_size,
        stride=1,
        padding=padding,
    )

    # Convert back to boolean and remove batch/channel dims
    return dilated.squeeze(0).squeeze(0) > 0.5
330
+
331
+
332
def create_fft_mask_gpu(
    box_size: int,
    pixel_ang: float,
    inside_radius_ang: float,
    outside_radius_ang: float,
    threshold_mask: "torch.Tensor",
    expand_pixel: int = 10,
    device: Optional["torch.device"] = None,
) -> "torch.Tensor":
    """
    Create the composite FFT mask for lattice spot removal with PyTorch.

    Torch counterpart of create_fft_mask. A pixel is kept when it is not a
    detected peak, or lies outside the outer radius, or lies inside the
    inner radius; the removed regions are then optionally dilated.

    Args:
        box_size: Size of the FFT (square)
        pixel_ang: Pixel size in Angstroms
        inside_radius_ang: Inner resolution limit (protect center)
        outside_radius_ang: Outer resolution limit (protect edges)
        threshold_mask: Mask from peak thresholding (True / non-zero =
            peak). May be a tensor or a NumPy array; it is converted to a
            boolean tensor on ``device``.
        expand_pixel: Expansion amount for the removal regions. The radius
            actually handed to dilate_mask_gpu is ``expand_pixel // 2 - 1``;
            no dilation happens when that value is not positive.
        device: PyTorch device. If None, uses threshold_mask's device
            (CPU for NumPy input).

    Returns:
        Final mask tensor where True = keep, False = replace with inpainted values

    Raises:
        ImportError: If PyTorch is not installed
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for GPU mask operations")

    # Convert array input BEFORE asking for ``.device``: a NumPy array is
    # not guaranteed to expose a usable torch device.
    if not isinstance(threshold_mask, torch.Tensor):
        threshold_mask = torch.from_numpy(np.asarray(threshold_mask))

    if device is None:
        device = threshold_mask.device

    # Enforce boolean dtype so ``~`` below is a logical NOT.
    threshold_mask = threshold_mask.to(device=device, dtype=torch.bool)

    # Convert resolution to Fourier pixels
    inner_radius = resolution_to_pixels(inside_radius_ang, pixel_ang, box_size)
    outer_radius = resolution_to_pixels(outside_radius_ang, pixel_ang, box_size)

    # Clamp outer radius to valid range
    outer_radius = min(outer_radius, box_size // 2 - 1)

    # Create radial masks on the target device
    shape = (box_size, box_size)
    mask_center = create_circular_mask_gpu(shape, inner_radius, device=device)
    mask_outside = create_circular_mask_gpu(shape, outer_radius, device=device)

    # Combine masks (same logic as CPU version)
    mask_final = ~threshold_mask | ~mask_outside | mask_center

    # Expand the removal regions (invert, dilate, invert back)
    if expand_pixel > 0:
        rad_expand = expand_pixel // 2 - 1
        if rad_expand > 0:
            removal_regions = dilate_mask_gpu(~mask_final, rad_expand, device=device)
            mask_final = ~removal_regions

    return mask_final