lattice-sub 1.0.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- lattice_sub-1.0.10.dist-info/METADATA +324 -0
- lattice_sub-1.0.10.dist-info/RECORD +16 -0
- lattice_sub-1.0.10.dist-info/WHEEL +5 -0
- lattice_sub-1.0.10.dist-info/entry_points.txt +2 -0
- lattice_sub-1.0.10.dist-info/licenses/LICENSE +21 -0
- lattice_sub-1.0.10.dist-info/top_level.txt +1 -0
- lattice_subtraction/__init__.py +49 -0
- lattice_subtraction/batch.py +374 -0
- lattice_subtraction/cli.py +751 -0
- lattice_subtraction/config.py +216 -0
- lattice_subtraction/core.py +389 -0
- lattice_subtraction/io.py +177 -0
- lattice_subtraction/masks.py +397 -0
- lattice_subtraction/processing.py +221 -0
- lattice_subtraction/ui.py +256 -0
- lattice_subtraction/visualization.py +195 -0
|
@@ -0,0 +1,751 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Command-line interface for lattice subtraction.
|
|
3
|
+
|
|
4
|
+
This module provides a Click-based CLI for processing single files
|
|
5
|
+
or batch directories from the command line.
|
|
6
|
+
|
|
7
|
+
Terminal output is styled when running interactively. When piped
|
|
8
|
+
or used in a pipeline, decorative output is automatically suppressed.
|
|
9
|
+
"""
|
|
10
|
+
|
|
11
|
+
import logging
|
|
12
|
+
import re
|
|
13
|
+
import subprocess
|
|
14
|
+
import sys
|
|
15
|
+
from os import cpu_count as import_cpu_count
|
|
16
|
+
from pathlib import Path
|
|
17
|
+
from typing import Optional
|
|
18
|
+
|
|
19
|
+
import click
|
|
20
|
+
|
|
21
|
+
from .config import Config, create_default_config
|
|
22
|
+
from .core import LatticeSubtractor
|
|
23
|
+
from .batch import BatchProcessor
|
|
24
|
+
from .visualization import generate_visualizations, save_comparison_visualization
|
|
25
|
+
from .ui import get_ui, get_gpu_name
|
|
26
|
+
from .io import read_mrc
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# CUDA version to PyTorch index URL mapping
# Note: PyTorch 2.5+ bundles CUDA 12.x by default, so explicit CUDA wheels
# are often not needed. This mapping is for cases where reinstallation helps.
# Every index lives under the same root with a "cu<major><minor>" suffix;
# "12.4" is the combination marked as tested.
_PYTORCH_WHEEL_ROOT = "https://download.pytorch.org/whl"

CUDA_INDEX_URLS = {
    cuda: f"{_PYTORCH_WHEEL_ROOT}/cu{cuda.replace('.', '')}"
    for cuda in ("11.8", "12.1", "12.4", "12.6", "12.8")
}

# CUDA versions that can use newer wheel versions (backward compatible)
CUDA_FALLBACK = dict.fromkeys(("13.0", "13.1", "13.2"), "12.8")

RECOMMENDED_CUDA = "12.8"
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def detect_cuda_version() -> Optional[str]:
    """Detect CUDA version from nvidia-smi output.

    ``nvidia-smi`` is invoked twice: a driver-version query whose exit
    status is used as a "is a GPU reachable" check, then a plain call whose
    banner is searched for ``CUDA Version: X.Y``.

    Returns:
        CUDA version string (e.g., "12.4") or None if nvidia-smi is
        missing, fails, times out, or prints no CUDA version.
    """
    try:
        probe = subprocess.run(
            ["nvidia-smi", "--query-gpu=driver_version", "--format=csv,noheader"],
            capture_output=True,
            text=True,
            timeout=10,
        )
        if probe.returncode != 0:
            return None

        # Get CUDA version from nvidia-smi
        banner = subprocess.run(
            ["nvidia-smi"],
            capture_output=True,
            text=True,
            timeout=10,
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        # OSError: binary missing or not executable; SubprocessError: timeout;
        # ValueError: output could not be decoded as text. Detection is
        # best-effort, so all of these simply mean "unknown".
        return None

    if banner.returncode != 0:
        return None

    # Parse "CUDA Version: 12.4" from output
    match = re.search(r"CUDA Version:\s*(\d+\.\d+)", banner.stdout)
    return match.group(1) if match else None
|
|
80
|
+
|
|
81
|
+
|
|
82
|
+
def get_pytorch_index_url(cuda_version: str) -> Optional[str]:
    """Get PyTorch index URL for a CUDA version.

    The version is tried as given, then reduced to its ``major.minor``
    prefix (so "12.4.1" behaves like "12.4"). Each candidate is looked up
    directly in ``CUDA_INDEX_URLS`` and, failing that, through
    ``CUDA_FALLBACK``.

    Args:
        cuda_version: CUDA version string (e.g., "12.4")

    Returns:
        PyTorch index URL or None if version not supported.
    """
    candidates = [cuda_version]
    major_minor = ".".join(cuda_version.split(".")[:2])
    if major_minor != cuda_version:
        candidates.append(major_minor)

    for candidate in candidates:
        # A direct wheel index always wins over a fallback mapping
        if candidate in CUDA_INDEX_URLS:
            return CUDA_INDEX_URLS[candidate]

        # Newer CUDA versions reuse an older, backward compatible index
        fallback = CUDA_FALLBACK.get(candidate)
        if fallback in CUDA_INDEX_URLS:
            return CUDA_INDEX_URLS[fallback]

    return None
|
|
107
|
+
|
|
108
|
+
|
|
109
|
+
# Setup logging - minimal format when interactive UI is active
def setup_logging(verbose: bool, interactive: bool = False) -> None:
    """Configure root logging from the verbosity and interactivity flags.

    Verbose always selects DEBUG. Otherwise the level is WARNING when the
    interactive UI is active and INFO when it is not.
    """
    if verbose:
        level = logging.DEBUG
    elif interactive:
        # The interactive UI reports progress itself; keep the log quiet
        level = logging.WARNING
    else:
        level = logging.INFO

    logging.basicConfig(
        format="%(asctime)s | %(levelname)-8s | %(message)s",
        datefmt="%H:%M:%S",
        level=level,
    )
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
# Root command group; the docstring below is what Click prints for --help,
# so it is user-facing text rather than developer documentation.
@click.group()
@click.version_option(version="1.0.10", prog_name="lattice-sub")
def main():
    """
    Lattice Subtraction for Cryo-EM Micrographs.

    Remove periodic crystal lattice signals from micrographs to reveal
    non-periodic features like defects, particles, or molecular tags.

    GPU acceleration works automatically with PyTorch 2.5+. No setup needed!
    Use 'lattice-sub setup-gpu' to verify GPU status or troubleshoot.

    \b
    Quick Start:
    pip install lattice-sub
    lattice-sub process input.mrc -o output.mrc -p 0.56

    \b
    Examples:
    # Process single file (auto GPU detection)
    lattice-sub process input.mrc -o output.mrc --pixel-size 0.56

    # Force CPU processing
    lattice-sub process input.mrc -o output.mrc -p 0.56 --cpu

    # Batch process directory (GPU handles parallelism)
    lattice-sub batch input_dir/ output_dir/ --pixel-size 0.56

    # Batch with visualizations
    lattice-sub batch input_dir/ output_dir/ -p 0.56 --vis viz_dir/

    # CPU batch with parallel workers (use -j only with --cpu)
    lattice-sub batch input_dir/ output_dir/ -p 0.56 --cpu -j 8

    # Generate visualizations for existing files
    lattice-sub visualize input_dir/ output_dir/ viz_dir/

    # Create config file
    lattice-sub init-config params.yaml --pixel-size 0.56

    # Quiet mode (no banner/colors)
    lattice-sub process input.mrc -o output.mrc -p 0.56 --quiet
    """
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
@main.command("setup-gpu")
@click.option(
    "-y", "--yes",
    is_flag=True,
    help="Skip confirmation prompt",
)
@click.option(
    "--force",
    is_flag=True,
    help="Force reinstall even if GPU is already working",
)
def setup_gpu(yes: bool, force: bool):
    """
    One-time GPU setup - installs PyTorch with CUDA support.

    Detects your CUDA version and installs the appropriate PyTorch
    wheels for GPU acceleration. You only need to run this once.

    Note: PyTorch 2.5+ often bundles CUDA support by default. This
    command will first check if your GPU is already working.

    \b
    Example:
    lattice-sub setup-gpu # Interactive
    lattice-sub setup-gpu -y # Skip confirmation
    lattice-sub setup-gpu --force # Reinstall even if working
    """
    # Local import: only needed to render the pip command for display
    import shlex

    # First, check if GPU is already working
    click.echo("\nChecking current GPU status...")
    try:
        import torch

        if torch.cuda.is_available():
            gpu_name = torch.cuda.get_device_name(0)
            if not force:
                click.echo(f"\n✓ GPU already enabled: {gpu_name}")
                click.echo(f" PyTorch version: {torch.__version__}")
                click.echo("\n No setup needed! Your GPU is ready to use.")
                click.echo(" Use --force to reinstall anyway.")
                # SystemExit is not an Exception subclass, so the handlers
                # below do not intercept it.
                sys.exit(0)
            click.echo(f"\n GPU currently working: {gpu_name}")
            click.echo(" Proceeding with reinstall due to --force...")
    except ImportError:
        click.echo(" PyTorch not installed, proceeding with setup...")
    except Exception as e:
        click.echo(f" Could not check GPU: {e}")

    click.echo("\nDetecting CUDA version...", nl=False)

    cuda_version = detect_cuda_version()

    if cuda_version is None:
        click.echo(" not found")
        click.echo("\n✗ No NVIDIA GPU detected.")
        click.echo(" Make sure nvidia-smi works and NVIDIA drivers are installed.")
        click.echo(" The package will run on CPU without GPU setup.")
        sys.exit(1)

    click.echo(f" found CUDA {cuda_version}")

    # Check for fallback version
    effective_version = CUDA_FALLBACK.get(cuda_version, cuda_version)
    if effective_version != cuda_version:
        click.echo(f" (Using CUDA {effective_version} wheels - backward compatible)")

    index_url = get_pytorch_index_url(cuda_version)
    if index_url is None:
        supported = ", ".join(sorted(CUDA_INDEX_URLS.keys()))
        click.echo(f"\n✗ CUDA {cuda_version} is not supported.")
        click.echo(f" Supported versions: {supported}")
        click.echo("\n However, your GPU may already work with the bundled CUDA.")
        click.echo(" Try running: python -c \"import torch; print(torch.cuda.is_available())\"")
        sys.exit(1)

    # Build pip command. Going through "<this interpreter> -m pip" installs
    # into the environment lattice-sub is running from; a bare "pip" on PATH
    # may belong to a different Python or be missing entirely.
    pip_argv = [
        sys.executable, "-m", "pip", "install", "torch",
        "--index-url", index_url,
        "--force-reinstall",
    ]
    pip_cmd = " ".join(shlex.quote(arg) for arg in pip_argv)

    # Show what will happen
    click.echo("\nThis will install PyTorch with GPU support:")
    click.echo(f" {pip_cmd}")

    if effective_version in ["12.4", "12.6", "12.8"]:
        click.echo(f"\n✓ CUDA {cuda_version} is well supported")
    else:
        click.echo(f"\nNote: CUDA 12.x is recommended, but {cuda_version} should work")

    click.echo("\nThis is a one-time setup. You won't need to run this again.")

    # Confirm unless --yes
    if not yes:
        if not click.confirm("\nProceed?", default=True):
            click.echo("Cancelled.")
            sys.exit(0)

    # Run pip install (output streams straight to the terminal)
    click.echo("\nInstalling PyTorch with CUDA support...")

    try:
        subprocess.run(pip_argv, check=True)
    except subprocess.CalledProcessError as e:
        click.echo(f"\n✗ Installation failed with exit code {e.returncode}")
        click.echo(" Try running the pip command manually:")
        click.echo(f" {pip_cmd}")
        sys.exit(1)

    # Verify installation in a fresh interpreter. A torch module already
    # imported by this process still refers to the files that were just
    # replaced, and reloading it in place is not reliable.
    click.echo("\nVerifying GPU access...")
    probe_code = (
        "import torch; "
        "print(torch.cuda.get_device_name(0) if torch.cuda.is_available() else '')"
    )
    try:
        probe = subprocess.run(
            [sys.executable, "-c", probe_code],
            capture_output=True,
            text=True,
            timeout=120,
        )
    except (OSError, ValueError, subprocess.SubprocessError):
        probe = None

    if probe is None or probe.returncode != 0:
        click.echo("\n⚠ Could not verify - restart Python and check manually:")
        click.echo(" python -c \"import torch; print(torch.cuda.is_available())\"")
        return

    gpu_name = probe.stdout.strip()
    if gpu_name:
        click.echo(f"\n✓ GPU enabled: {gpu_name}")
        click.echo("\nYou can now use lattice-sub with automatic GPU acceleration!")
    else:
        click.echo("\n⚠ PyTorch installed but CUDA not available.")
        click.echo(" This may require a Python restart. Try:")
        click.echo(" python -c \"import torch; print(torch.cuda.is_available())\"")
|
|
298
|
+
|
|
299
|
+
|
|
300
|
+
@main.command()
@click.argument("input_file", type=click.Path(exists=True, dir_okay=False))
@click.option("-o", "--output", type=click.Path(dir_okay=False),
              help="Output file path. Default: sub_<input_name>")
@click.option("--pixel-size", "-p", type=float, required=True,
              help="Pixel size in Angstroms")
@click.option("--threshold", "-t", type=float, default=1.42,
              help="Peak detection threshold. Default: 1.42")
@click.option("--inside-radius", type=float, default=90.0,
              help="Inner resolution limit in Angstroms. Default: 90")
@click.option("--outside-radius", type=float, default=None,
              help="Outer resolution limit in Angstroms. Default: auto")
@click.option("--config", "-c", type=click.Path(exists=True, dir_okay=False),
              help="Path to YAML config file (overrides other options)")
@click.option("--cpu", is_flag=True,
              help="Force CPU processing (disable GPU auto-detection)")
@click.option("--diagnostics/--no-diagnostics", default=False,
              help="Save diagnostic images (mask, power spectrum)")
@click.option("-v", "--verbose", is_flag=True,
              help="Enable verbose output")
@click.option("-q", "--quiet", is_flag=True,
              help="Suppress decorative output (banner, colors)")
def process(
    input_file: str,
    output: Optional[str],
    pixel_size: float,
    threshold: float,
    inside_radius: float,
    outside_radius: Optional[float],
    config: Optional[str],
    cpu: bool,
    diagnostics: bool,
    verbose: bool,
    quiet: bool,
):
    """
    Process a single micrograph.

    INPUT_FILE: Path to input MRC file
    """
    # Console front-end and logging come up before any work is done
    ui = get_ui(quiet=quiet)
    ui.print_banner()

    setup_logging(verbose, interactive=ui.interactive)
    logger = logging.getLogger(__name__)

    source = Path(input_file)

    # Without -o the result is written next to the input as "sub_<name>"
    if output is not None:
        target = Path(output)
    else:
        target = source.parent / f"sub_{source.name}"

    # A YAML config, when given, replaces the per-option settings entirely
    if not config:
        cfg = Config(
            pixel_ang=pixel_size,
            threshold=threshold,
            inside_radius_ang=inside_radius,
            outside_radius_ang=outside_radius,
            backend="numpy" if cpu else "auto",
        )
    else:
        logger.info(f"Loading config from {config}")
        cfg = Config.from_yaml(config)

    # GPU name is only looked up when the CPU was not forced
    gpu_name = None if cpu else get_gpu_name()
    ui.print_config(cfg.pixel_ang, cfg.threshold, cfg.backend, gpu_name)
    ui.start_timer()

    logger.info(f"Processing: {source}")
    logger.info(f"Parameters: pixel={cfg.pixel_ang}Å, threshold={cfg.threshold}")

    try:
        # The image is read here so its shape can be shown by the UI
        image = read_mrc(source)
        ui.start_processing(source.name, shape=image.shape)

        result = LatticeSubtractor(cfg).process(image, return_diagnostics=diagnostics)

        # Write the subtracted micrograph, creating the folder if needed
        target.parent.mkdir(parents=True, exist_ok=True)
        result.save(target, pixel_size=cfg.pixel_ang)
        ui.print_saved(str(target))
        logger.info(f"Saved: {target}")

        # Optional diagnostic outputs next to the result file.
        # NOTE(review): the power-spectrum write is assumed to be nested in
        # this branch, since it relies on the write_mrc import made here.
        if diagnostics and result.fft_mask is not None:
            from .io import write_mrc
            import numpy as np

            mask_path = target.with_suffix(".mask.mrc")
            write_mrc(result.fft_mask.astype(np.float32), mask_path)
            ui.print_saved(str(mask_path))
            logger.info(f"Saved mask: {mask_path}")

            if result.power_spectrum is not None:
                ps_path = target.with_suffix(".power.mrc")
                write_mrc(result.power_spectrum, ps_path)
                ui.print_saved(str(ps_path))
                logger.info(f"Saved power spectrum: {ps_path}")

        ui.end_processing(str(target), success=True)
        ui.print_summary(processed=1)

    except Exception as exc:
        # Any failure is reported through the UI and the log, then exit 1
        ui.end_processing(str(target), success=False)
        ui.print_error(str(exc))
        logger.error(f"Processing failed: {exc}")
        sys.exit(1)
|
|
449
|
+
|
|
450
|
+
|
|
451
|
+
@main.command()
@click.argument("input_dir", type=click.Path(exists=True, file_okay=False))
@click.argument("output_dir", type=click.Path(file_okay=False))
@click.option(
    "--pixel-size", "-p",
    type=float,
    required=True,
    help="Pixel size in Angstroms",
)
@click.option(
    "--threshold", "-t",
    type=float,
    default=1.42,
    help="Peak detection threshold. Default: 1.42",
)
@click.option(
    "--pattern",
    type=str,
    default="*.mrc",
    help="Glob pattern for input files. Default: *.mrc",
)
@click.option(
    "--prefix",
    type=str,
    default="sub_",
    help="Output filename prefix. Default: sub_",
)
@click.option(
    "-j", "--jobs",
    type=int,
    default=None,
    help="Number of parallel workers. Default: 1 for GPU, CPU count - 1 for --cpu mode",
)
@click.option(
    "--config", "-c",
    type=click.Path(exists=True, dir_okay=False),
    help="Path to YAML config file",
)
@click.option(
    "-r", "--recursive",
    is_flag=True,
    help="Search subdirectories recursively",
)
@click.option(
    "--vis",
    type=click.Path(file_okay=False),
    default=None,
    help="Generate comparison visualizations in this directory",
)
@click.option(
    "-v", "--verbose",
    is_flag=True,
    help="Enable verbose output",
)
@click.option(
    "-q", "--quiet",
    is_flag=True,
    help="Suppress decorative output (banner, colors)",
)
@click.option(
    "--cpu",
    is_flag=True,
    help="Force CPU processing (disable GPU auto-detection)",
)
def batch(
    input_dir: str,
    output_dir: str,
    pixel_size: float,
    threshold: float,
    pattern: str,
    prefix: str,
    jobs: Optional[int],
    config: Optional[str],
    recursive: bool,
    vis: Optional[str],
    verbose: bool,
    quiet: bool,
    cpu: bool,
):
    """
    Batch process a directory of micrographs.

    \b
    INPUT_DIR: Directory containing input MRC files
    OUTPUT_DIR: Directory for processed output files
    """
    # Initialize UI
    ui = get_ui(quiet=quiet)
    ui.print_banner()

    setup_logging(verbose, interactive=ui.interactive)
    logger = logging.getLogger(__name__)

    # Load or create config
    if config:
        cfg = Config.from_yaml(config)
    else:
        cfg = Config(
            pixel_ang=pixel_size,
            threshold=threshold,
            backend="numpy" if cpu else "auto",
        )

    # Count files first (only used for the header shown to the user)
    input_path = Path(input_dir)
    if recursive:
        files = list(input_path.rglob(pattern))
    else:
        files = list(input_path.glob(pattern))

    num_files = len(files)

    # For GPU: single worker is optimal (GPU handles parallelism)
    # For CPU: use multiple workers to parallelize across cores
    if jobs is not None:
        num_workers = jobs
    elif cpu:
        num_workers = max(1, (import_cpu_count() or 4) - 1)
    else:
        num_workers = 1  # GPU mode: single worker is optimal

    # Print configuration
    gpu_name = get_gpu_name() if not cpu else None
    ui.print_config(cfg.pixel_ang, cfg.threshold, cfg.backend, gpu_name)
    ui.print_batch_header(num_files, output_dir, num_workers)
    ui.start_timer()

    logger.info(f"Batch processing: {input_dir} -> {output_dir}")
    logger.info(f"Pattern: {pattern}, Workers: {num_workers}")

    # Process. The resolved worker count is passed on (not the raw --jobs
    # value, which may be None) so the run matches the header printed above
    # and the defaults documented in the --jobs help text.
    processor = BatchProcessor(cfg, num_workers=num_workers, output_prefix=prefix)
    result = processor.process_directory(
        input_dir=input_dir,
        output_dir=output_dir,
        pattern=pattern,
        recursive=recursive,
        show_progress=True,  # Always show progress bar
    )

    # Report results with UI
    ui.print_batch_complete()
    ui.print_summary(processed=result.successful, failed=result.failed)

    logger.info(f"Completed: {result.successful}/{result.total} files ({result.success_rate:.1f}%)")

    if result.failed > 0:
        # The UI lists at most 5 failures, the log at most 10
        for path, error in result.failed_files[:5]:
            ui.print_error(f"{path}: {error}")
        if len(result.failed_files) > 5:
            ui.print_error(f"... and {len(result.failed_files) - 5} more failures")

        logger.warning(f"Failed files: {result.failed}")
        for path, error in result.failed_files[:10]:
            logger.warning(f" {path}: {error}")

        if len(result.failed_files) > 10:
            logger.warning(f" ... and {len(result.failed_files) - 10} more")

    # Generate visualizations if requested. A partial failure no longer
    # discards the comparison images for the files that did succeed; they
    # are skipped only when there were failures and nothing succeeded.
    nothing_to_show = result.failed > 0 and result.successful == 0
    if vis and not nothing_to_show:
        logger.info(f"Generating visualizations in: {vis}")
        viz_success, viz_total = generate_visualizations(
            input_dir=input_dir,
            output_dir=output_dir,
            viz_dir=vis,
            prefix=prefix,
            pattern=pattern,
            show_progress=True,
        )
        logger.info(f"Visualizations: {viz_success}/{viz_total} created")

    # Any failed file still makes the command exit non-zero
    if result.failed > 0:
        sys.exit(1)
|
|
623
|
+
|
|
624
|
+
|
|
625
|
+
@main.command("init-config")
@click.argument("output_file", type=click.Path(dir_okay=False))
@click.option("--pixel-size", "-p", type=float, default=0.56,
              help="Pixel size in Angstroms. Default: 0.56")
@click.option("--detector", type=click.Choice(["K3", "Falcon", "generic"]),
              default="generic", help="Detector type for default settings")
def init_config(output_file: str, pixel_size: float, detector: str):
    """
    Create a default configuration file.

    OUTPUT_FILE: Path for the YAML config file
    """
    # Build the defaults for the chosen detector and write them out
    cfg = create_default_config(pixel_ang=pixel_size, detector=detector)
    cfg.to_yaml(output_file)

    # Echo the key settings so the user can see what was written
    summary = (
        f"Created config file: {output_file}",
        f" Pixel size: {cfg.pixel_ang} Å",
        f" Threshold: {cfg.threshold}",
        f" Inside radius: {cfg.inside_radius_ang} Å",
    )
    for line in summary:
        click.echo(line)
|
|
652
|
+
|
|
653
|
+
|
|
654
|
+
@main.command("convert-config")
@click.argument("input_file", type=click.Path(exists=True, dir_okay=False))
@click.argument("output_file", type=click.Path(dir_okay=False))
def convert_config(input_file: str, output_file: str):
    """
    Convert legacy PARAMETER file to YAML format.

    \b
    INPUT_FILE: Path to legacy PARAMETER file
    OUTPUT_FILE: Path for new YAML config file
    """
    try:
        # Parse the legacy file and immediately serialise it as YAML
        Config.from_legacy_parameter_file(input_file).to_yaml(output_file)
        click.echo(f"Converted {input_file} -> {output_file}")
    except Exception as exc:
        # Report on stderr and signal failure through the exit status
        click.echo(f"Error: {exc}", err=True)
        sys.exit(1)
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
@main.command("visualize")
@click.argument("input_dir", type=click.Path(exists=True, file_okay=False))
@click.argument("output_dir", type=click.Path(exists=True, file_okay=False))
@click.argument("viz_dir", type=click.Path(file_okay=False))
@click.option("--prefix", type=str, default="sub_",
              help="Prefix used for processed files. Default: sub_")
@click.option("--pattern", type=str, default="*.mrc",
              help="Glob pattern for MRC files. Default: *.mrc")
@click.option("--dpi", type=int, default=150,
              help="Resolution for output images. Default: 150")
@click.option("-v", "--verbose", is_flag=True,
              help="Enable verbose output")
def visualize(
    input_dir: str,
    output_dir: str,
    viz_dir: str,
    prefix: str,
    pattern: str,
    dpi: int,
    verbose: bool,
):
    """
    Generate comparison visualizations for processed micrographs.

    Creates side-by-side PNG images showing original, lattice-subtracted,
    and difference images for each processed micrograph.

    \b
    INPUT_DIR: Directory containing original MRC files
    OUTPUT_DIR: Directory containing processed (sub_*) MRC files
    VIZ_DIR: Directory for output visualization PNG files

    \b
    Example:
    lattice-sub visualize raw_images/ processed/ visualizations/
    """
    # This command has no interactive UI, so logging uses the plain defaults
    setup_logging(verbose)
    logger = logging.getLogger(__name__)

    for message in (
        "Generating visualizations...",
        f" Original images: {input_dir}",
        f" Processed images: {output_dir}",
        f" Output visualizations: {viz_dir}",
    ):
        logger.info(message)

    # Directories are handed over as Path objects
    successful, total = generate_visualizations(
        input_dir=Path(input_dir),
        output_dir=Path(output_dir),
        viz_dir=Path(viz_dir),
        prefix=prefix,
        pattern=pattern,
        dpi=dpi,
        show_progress=True,
    )

    logger.info(f"Completed: {successful}/{total} visualizations created")

    # Any shortfall is reported and turns into a non-zero exit status
    missing = total - successful
    if missing > 0:
        logger.warning(f"Some visualizations failed: {missing}")
        sys.exit(1)
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
# Script entry point: hand control to the Click command group.
if __name__ == "__main__":
    main()
|