lattice-sub 1.0.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,216 @@
1
+ """
2
+ Configuration management for lattice subtraction.
3
+
4
+ This module handles loading, validation, and storage of processing parameters
5
+ from YAML configuration files or Python dictionaries.
6
+ """
7
+
8
+ from dataclasses import dataclass, field
9
+ from pathlib import Path
10
+ from typing import Optional, Literal
11
+ import yaml
12
+
13
+
14
+ @dataclass
15
+ class Config:
16
+ """
17
+ Configuration parameters for lattice subtraction processing.
18
+
19
+ All resolution parameters are in Angstroms. The algorithm removes lattice
20
+ peaks in the resolution range between inside_radius_ang and outside_radius_ang.
21
+
22
+ Attributes:
23
+ pixel_ang: Pixel size in Angstroms (detector-dependent, e.g., 0.56 for K3)
24
+ inside_radius_ang: Inner resolution limit - FFT spots within this radius
25
+ are preserved (low-frequency structural info). Default: 90Å
26
+ outside_radius_ang: Outer resolution limit - spots beyond this are preserved.
27
+ If None, auto-calculated as pixel_ang * 2 + 0.2
28
+ threshold: Peak detection threshold on log-amplitude FFT. Spots above this
29
+ value are identified as lattice peaks. Default: 1.42
30
+ expand_pixel: Morphological expansion radius for mask dilation. Default: 10
31
+ pad_origin_x: X padding offset in pixels. Default: 200
32
+ pad_origin_y: Y padding offset in pixels. Default: 200 (use 1000 for K3)
33
+ pad_output: If False, crop output to original size. Default: False
34
+ unit_cell_ang: Crystal unit cell size in Angstroms for shift calculation.
35
+ Default: 116Å (nucleosome repeat)
36
+ backend: Computation backend - 'numpy' for CPU, 'pytorch' for GPU. Default: 'numpy'
37
+ """
38
+
39
+ # Required parameter
40
+ pixel_ang: float
41
+
42
+ # Resolution limits
43
+ inside_radius_ang: float = 90.0
44
+ outside_radius_ang: Optional[float] = None # Auto-calculated if None
45
+
46
+ # Peak detection
47
+ threshold: float = 1.42
48
+ expand_pixel: int = 10
49
+
50
+ # Padding
51
+ pad_origin_x: int = 200
52
+ pad_origin_y: int = 200
53
+ pad_output: bool = False
54
+
55
+ # Crystal parameters
56
+ unit_cell_ang: float = 116.0 # Nucleosome repeat distance
57
+
58
+ # Computation backend: 'auto' tries GPU first, then falls back to CPU
59
+ backend: Literal["numpy", "pytorch", "auto"] = "auto"
60
+
61
+ def __post_init__(self):
62
+ """Validate and set auto-calculated parameters."""
63
+ if self.pixel_ang <= 0:
64
+ raise ValueError(f"pixel_ang must be positive, got {self.pixel_ang}")
65
+
66
+ if self.inside_radius_ang <= 0:
67
+ raise ValueError(f"inside_radius_ang must be positive, got {self.inside_radius_ang}")
68
+
69
+ if self.threshold <= 0:
70
+ raise ValueError(f"threshold must be positive, got {self.threshold}")
71
+
72
+ # Auto-calculate outside radius if not provided
73
+ if self.outside_radius_ang is None:
74
+ self.outside_radius_ang = self.pixel_ang * 2 + 0.2
75
+
76
+ if self.outside_radius_ang >= self.inside_radius_ang:
77
+ raise ValueError(
78
+ f"outside_radius_ang ({self.outside_radius_ang}) must be smaller than "
79
+ f"inside_radius_ang ({self.inside_radius_ang})"
80
+ )
81
+
82
+ @classmethod
83
+ def from_yaml(cls, path: str | Path) -> "Config":
84
+ """
85
+ Load configuration from a YAML file.
86
+
87
+ Args:
88
+ path: Path to YAML configuration file
89
+
90
+ Returns:
91
+ Config instance with loaded parameters
92
+
93
+ Example YAML format:
94
+ pixel_ang: 0.56
95
+ threshold: 1.56
96
+ inside_radius_ang: 90
97
+ # outside_radius_ang: auto # Optional, auto-calculated if omitted
98
+ """
99
+ path = Path(path)
100
+ if not path.exists():
101
+ raise FileNotFoundError(f"Configuration file not found: {path}")
102
+
103
+ with open(path, 'r') as f:
104
+ data = yaml.safe_load(f)
105
+
106
+ # Handle 'auto' string for outside_radius_ang
107
+ if data.get('outside_radius_ang') == 'auto':
108
+ data['outside_radius_ang'] = None
109
+
110
+ return cls(**data)
111
+
112
+ @classmethod
113
+ def from_legacy_parameter_file(cls, path: str | Path) -> "Config":
114
+ """
115
+ Load configuration from legacy MATLAB PARAMETER file format.
116
+
117
+ Args:
118
+ path: Path to legacy PARAMETER file
119
+
120
+ Returns:
121
+ Config instance with loaded parameters
122
+ """
123
+ path = Path(path)
124
+ params = {}
125
+
126
+ # Mapping from legacy names to new names
127
+ name_map = {
128
+ 'inside_radius_ang': 'inside_radius_ang',
129
+ 'outside_radius_ang': 'outside_radius_ang',
130
+ 'pixel_ang': 'pixel_ang',
131
+ 'threshold': 'threshold',
132
+ }
133
+
134
+ with open(path, 'r') as f:
135
+ for line in f:
136
+ line = line.strip()
137
+ if not line or line.startswith('!'):
138
+ continue
139
+
140
+ # Split on whitespace, handle comments with !
141
+ parts = line.split('!')
142
+ main_part = parts[0].strip()
143
+ if not main_part:
144
+ continue
145
+
146
+ tokens = main_part.split()
147
+ if len(tokens) >= 2:
148
+ name = tokens[0].lower()
149
+ try:
150
+ value = float(tokens[1])
151
+ except ValueError:
152
+ continue
153
+
154
+ if name in name_map:
155
+ params[name_map[name]] = value
156
+
157
+ return cls(**params)
158
+
159
+ def to_yaml(self, path: str | Path) -> None:
160
+ """
161
+ Save configuration to a YAML file.
162
+
163
+ Args:
164
+ path: Output path for YAML file
165
+ """
166
+ path = Path(path)
167
+
168
+ data = {
169
+ 'pixel_ang': self.pixel_ang,
170
+ 'inside_radius_ang': self.inside_radius_ang,
171
+ 'outside_radius_ang': self.outside_radius_ang,
172
+ 'threshold': self.threshold,
173
+ 'expand_pixel': self.expand_pixel,
174
+ 'pad_origin_x': self.pad_origin_x,
175
+ 'pad_origin_y': self.pad_origin_y,
176
+ 'pad_output': self.pad_output,
177
+ 'unit_cell_ang': self.unit_cell_ang,
178
+ 'backend': self.backend,
179
+ }
180
+
181
+ with open(path, 'w') as f:
182
+ yaml.dump(data, f, default_flow_style=False, sort_keys=False)
183
+
184
+ def copy(self, **updates) -> "Config":
185
+ """
186
+ Create a copy of this config with optional updates.
187
+
188
+ Args:
189
+ **updates: Parameters to override
190
+
191
+ Returns:
192
+ New Config instance with updates applied
193
+ """
194
+ from dataclasses import asdict
195
+ current = asdict(self)
196
+ current.update(updates)
197
+ return Config(**current)
198
+
199
+
200
def create_default_config(pixel_ang: float = 0.56, detector: str = "K3") -> Config:
    """
    Build a Config whose Y padding offset depends on the detector.

    The detector name is compared case-insensitively: "K3" selects a
    pad_origin_y of 1000, any other name selects 200. All remaining
    parameters keep the Config defaults.

    Args:
        pixel_ang: Pixel size in Angstroms
        detector: Detector type ('K3', 'Falcon', 'generic')

    Returns:
        Config with appropriate defaults for the detector
    """
    is_k3 = detector.upper() == "K3"
    return Config(pixel_ang=pixel_ang, pad_origin_y=1000 if is_k3 else 200)
@@ -0,0 +1,389 @@
1
+ """
2
+ Core lattice subtraction algorithm.
3
+
4
+ This module contains the main LatticeSubtractor class that implements
5
+ the phase-preserving lattice removal algorithm.
6
+ """
7
+
8
+ from dataclasses import dataclass
9
+ from pathlib import Path
10
+ from typing import Optional
11
+ import numpy as np
12
+
13
+ from .config import Config
14
+ from .io import read_mrc, write_mrc
15
+ from .masks import create_fft_mask, create_fft_mask_gpu, resolution_to_pixels
16
+ from .processing import (
17
+ pad_image,
18
+ crop_to_original,
19
+ subtract_background,
20
+ compute_power_spectrum,
21
+ shift_and_average,
22
+ )
23
+
24
+
25
+ @dataclass
26
+ class SubtractionResult:
27
+ """
28
+ Result of lattice subtraction processing.
29
+
30
+ Attributes:
31
+ image: Processed image with lattice removed
32
+ original_shape: Shape of input image before padding
33
+ fft_mask: The mask used for FFT filtering (optional)
34
+ power_spectrum: Background-subtracted power spectrum (optional)
35
+ """
36
+ image: np.ndarray
37
+ original_shape: tuple
38
+ fft_mask: Optional[np.ndarray] = None
39
+ power_spectrum: Optional[np.ndarray] = None
40
+
41
+ def save(self, path: str | Path, pixel_size: float = 1.0) -> None:
42
+ """Save the processed image to an MRC file."""
43
+ write_mrc(self.image, path, pixel_size=pixel_size)
44
+
45
+
46
class LatticeSubtractor:
    """
    Removes periodic lattice signal from cryo-EM micrographs.

    Port of the algorithm from bg_push_by_rot.m:
      1. Pad image and compute 2D FFT
      2. Detect lattice peaks via thresholding on log-power spectrum
      3. Create composite mask (protect center and edges)
      4. Inpaint masked regions with local average amplitude
      5. Preserve original phase, replace amplitude
      6. Inverse FFT and crop

    Periodic crystal lattice signal is suppressed while non-periodic
    features of the image are kept.

    Example:
        >>> config = Config(pixel_ang=0.56, threshold=1.56)
        >>> subtractor = LatticeSubtractor(config)
        >>> result = subtractor.process("input.mrc")
        >>> result.save("output.mrc")
    """

    def __init__(self, config: Config):
        """
        Store the configuration and select the computation backend.

        Args:
            config: Configuration parameters for processing
        """
        self.config = config
        self._setup_backend()

    def _setup_backend(self) -> None:
        """
        Select the computation backend from ``config.backend``.

        Sets ``self.device`` and ``self.use_gpu``. ``use_gpu`` is True only
        when PyTorch imports successfully and reports CUDA as available; in
        every other case the NumPy/SciPy code path is used.

        - "auto": probes PyTorch/CUDA and prints a one-line status message
          (at most once per instance).
        - "pytorch": probes PyTorch/CUDA and emits a warning when falling back.
        - anything else: NumPy, silently.
        """
        requested = self.config.backend
        self._gpu_message_shown = getattr(self, '_gpu_message_shown', False)

        if requested == "auto":
            try:
                import torch
                cuda_ok = torch.cuda.is_available()
                if cuda_ok:
                    self.device = torch.device('cuda')
                    self.use_gpu = True
                else:
                    self.device = torch.device('cpu')
                    self.use_gpu = False
                # Status line is emitted only the first time for this instance.
                if not self._gpu_message_shown:
                    if cuda_ok:
                        gpu_name = torch.cuda.get_device_name(0)
                        print(f"✓ Using GPU: {gpu_name}")
                    else:
                        print("ℹ Running on CPU (run 'lattice-sub setup-gpu' to enable GPU)")
                    self._gpu_message_shown = True
            except ImportError:
                self.device = None
                self.use_gpu = False
                if not self._gpu_message_shown:
                    print("ℹ Running on CPU with NumPy (PyTorch not installed)")
                    self._gpu_message_shown = True
            return

        if requested == "pytorch":
            import warnings
            try:
                import torch
                if torch.cuda.is_available():
                    self.device = torch.device('cuda')
                    self.use_gpu = True
                else:
                    warnings.warn(
                        "CUDA not available, falling back to CPU."
                    )
                    self.device = torch.device('cpu')
                    self.use_gpu = False
            except ImportError:
                warnings.warn(
                    "PyTorch not available, falling back to NumPy. "
                    "Install with: pip install torch"
                )
                self.device = None
                self.use_gpu = False
            return

        # numpy backend (also reached for any unrecognized backend value)
        self.device = None
        self.use_gpu = False

    def _to_device(self, array: np.ndarray):
        """Return ``array`` as a tensor on ``self.device`` when the GPU path
        is active; otherwise return it unchanged."""
        if not self.use_gpu or self.device is None:
            return array
        import torch
        return torch.from_numpy(array).to(self.device)

    def _to_numpy(self, array) -> np.ndarray:
        """Return ``array`` as a NumPy array, copying it off the GPU when the
        GPU path is active and the object exposes ``.cpu()``."""
        if not (self.use_gpu and hasattr(array, 'cpu')):
            return array
        return array.cpu().numpy()

    def process(
        self,
        input_path: str | Path | np.ndarray,
        return_diagnostics: bool = False,
    ) -> SubtractionResult:
        """
        Remove lattice signal from a micrograph.

        Args:
            input_path: Path to input MRC file, or numpy array
            return_diagnostics: If True, include mask and power spectrum in result

        Returns:
            SubtractionResult containing processed image and optional diagnostics
        """
        # A str/Path is read from disk; anything else is treated as an array.
        if isinstance(input_path, (str, Path)):
            micrograph = read_mrc(input_path)
        else:
            micrograph = input_path.astype(np.float32)

        input_shape = micrograph.shape

        padded, pad_meta = pad_image(
            micrograph,
            pad_origin=(self.config.pad_origin_y, self.config.pad_origin_x),
        )

        cleaned_padded, fft_mask, power_spec = self._process_padded(
            padded,
            return_diagnostics=return_diagnostics,
        )

        # pad_output=True keeps the padded frame; otherwise crop back.
        if self.config.pad_output:
            cleaned = cleaned_padded
        else:
            cleaned = crop_to_original(cleaned_padded, pad_meta)

        return SubtractionResult(
            image=cleaned,
            original_shape=input_shape,
            fft_mask=fft_mask if return_diagnostics else None,
            power_spectrum=power_spec if return_diagnostics else None,
        )

    def _process_padded(
        self,
        image: np.ndarray,
        return_diagnostics: bool = False,
    ) -> tuple:
        """
        Run the core algorithm (from bg_push_by_rot.m) on a padded image.

        Args:
            image: Padded input image
            return_diagnostics: If True, also return the boolean FFT mask and
                the background-subtracted power spectrum

        Returns:
            Tuple ``(result, mask, power)``; ``mask`` and ``power`` are None
            when ``return_diagnostics`` is False.
        """
        cfg = self.config
        on_gpu = self.use_gpu

        # float64 for processing precision
        data = self._to_device(image.astype(np.float64))
        box_size = image.shape[0]

        # Steps 1-2: centered FFT and its absolute log-amplitude spectrum
        if on_gpu:
            import torch
            fft_shifted = torch.fft.fftshift(torch.fft.fft2(data))
            power_spectrum = torch.abs(torch.log(torch.abs(fft_shifted) + 1e-10))
        else:
            from scipy import fft
            fft_shifted = fft.fftshift(fft.fft2(data))
            power_spectrum = np.abs(np.log(np.abs(fft_shifted) + 1e-10))

        # Step 3: background subtraction runs on a NumPy array
        subtracted = subtract_background(self._to_numpy(power_spectrum))

        # Step 4: everything above the threshold counts as a lattice peak
        threshold_mask = subtracted > cfg.threshold

        # Step 5: composite mask with radial limits (shared arguments first)
        mask_kwargs = dict(
            box_size=box_size,
            pixel_ang=cfg.pixel_ang,
            inside_radius_ang=cfg.inside_radius_ang,
            outside_radius_ang=cfg.outside_radius_ang,
            expand_pixel=cfg.expand_pixel,
        )
        if on_gpu:
            import torch
            mask_final = None
            mask_final_dev = create_fft_mask_gpu(
                threshold_mask=torch.from_numpy(threshold_mask).to(self.device),
                device=self.device,
                **mask_kwargs,
            ).float()
        else:
            mask_final = create_fft_mask(threshold_mask=threshold_mask, **mask_kwargs)
            mask_final_dev = self._to_device(mask_final.astype(np.float64))

        # Step 6: keep only the unmasked FFT values
        fft_keep = mask_final_dev * fft_shifted

        # Shift distance derived from the unit cell, never less than one pixel
        shift_pixels = max(1, int(cfg.pixel_ang / cfg.unit_cell_ang * box_size))

        # Steps 6-8: inpaint amplitudes, restore phase, invert the FFT
        if on_gpu:
            result = self._inpaint_and_invert_gpu(
                fft_shifted, fft_keep, mask_final_dev, shift_pixels
            )
        else:
            result = self._inpaint_and_invert_cpu(
                fft_shifted, fft_keep, mask_final, shift_pixels
            )

        result_np = self._to_numpy(result)

        if not return_diagnostics:
            return result_np, None, None

        if on_gpu:
            mask_np = self._to_numpy(mask_final_dev).astype(bool)
        else:
            mask_np = mask_final.astype(bool)
        return result_np, mask_np, subtracted

    def _inpaint_and_invert_gpu(self, fft_shifted, fft_keep, mask_final_dev, shift_pixels):
        """PyTorch path: fill removed amplitudes, keep original phase, inverse FFT.

        Mirrors the MATLAB reference:
            y2_B            = ~mask_final .* shift_ave
            angle_y2_ori_B  = angle(y .* ~mask_final)
            y2              = y2_A + value_y2_B .* exp(i .* angle_y2_ori_B)
        """
        import torch

        # Amplitude of the kept values (zero wherever the mask removed a peak)
        kept_amplitude = torch.abs(fft_keep)

        # Average of the four axis-aligned circular shifts
        neighbour_mean = (
            torch.roll(kept_amplitude, shift_pixels, dims=0) +
            torch.roll(kept_amplitude, -shift_pixels, dims=0) +
            torch.roll(kept_amplitude, shift_pixels, dims=1) +
            torch.roll(kept_amplitude, -shift_pixels, dims=1)
        ) / 4.0

        removed = ~mask_final_dev.bool()
        filled_amplitude = removed.float() * neighbour_mean

        # Phase comes from the ORIGINAL spectrum, not from fft_keep
        phase = torch.angle(fft_shifted * removed.float())

        spectrum = fft_keep + filled_amplitude * torch.exp(1j * phase)
        spatial = torch.fft.ifft2(torch.fft.ifftshift(spectrum))
        return torch.real(spatial).float()

    def _inpaint_and_invert_cpu(self, fft_shifted, fft_keep, mask_final, shift_pixels):
        """NumPy/SciPy path: same steps as the PyTorch variant."""
        from scipy import fft

        # Amplitude of the kept values (zero wherever the mask removed a peak)
        kept_amplitude = np.abs(fft_keep)

        # Average of the four axis-aligned circular shifts
        neighbour_mean = (
            np.roll(kept_amplitude, shift_pixels, axis=0) +
            np.roll(kept_amplitude, -shift_pixels, axis=0) +
            np.roll(kept_amplitude, shift_pixels, axis=1) +
            np.roll(kept_amplitude, -shift_pixels, axis=1)
        ) / 4.0

        removed = ~mask_final.astype(bool)
        filled_amplitude = removed.astype(np.float64) * neighbour_mean

        # Phase comes from the ORIGINAL spectrum at the removed positions
        phase = np.angle(fft_shifted * removed.astype(np.float64))

        spectrum = fft_keep + filled_amplitude * np.exp(1j * phase)
        spatial = fft.ifft2(fft.ifftshift(spectrum))
        return np.real(spatial).astype(np.float32)

    def process_array(
        self,
        image: np.ndarray,
        return_diagnostics: bool = False,
    ) -> SubtractionResult:
        """
        Process an in-memory numpy array; thin alias for ``process``.

        Args:
            image: Input 2D numpy array
            return_diagnostics: If True, include mask and power spectrum

        Returns:
            SubtractionResult
        """
        return self.process(image, return_diagnostics=return_diagnostics)
372
+
373
+
374
def process_micrograph(
    input_path: str | Path,
    output_path: str | Path,
    config: Config,
) -> None:
    """
    Read one micrograph, subtract the lattice, and write the result.

    The output is saved with ``config.pixel_ang`` as its pixel size.

    Args:
        input_path: Path to input MRC file
        output_path: Path for output MRC file
        config: Processing configuration
    """
    outcome = LatticeSubtractor(config).process(input_path)
    outcome.save(output_path, pixel_size=config.pixel_ang)