pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
- pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +28 -39
- scripts/postprocess.py +23 -10
- scripts/preprocessor_gui.py +95 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +156 -386
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +13 -3
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +15 -13
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +5 -44
- tme/backends/jax_backend.py +58 -37
- tme/backends/matching_backend.py +6 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +68 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +2 -6
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +7 -19
- tme/matching_exhaustive.py +34 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +206 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +4 -6
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
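The largest single change is the batch runner pytme_runner.py, shown in full below as a new installed script (the repository copy under scripts/ was overhauled in the same release). Going by the argument definitions in that diff, a plausible "matching" invocation would be the following; paths and parameter values are illustrative, not taken from the package:

    pytme_runner.py matching \
        --tomograms '/data/tomograms/*.mrc' \
        --metadata '/data/metadata/*.xml' \
        --template template.mrc \
        --particle-diameter 290 \
        --tilt-weighting relion \
        --dry-run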
@@ -0,0 +1,1223 @@
+#!python
+"""
+PyTME Batch Runner - Refactored Core Classes
+"""
+import re
+import argparse
+import subprocess
+from abc import ABC, abstractmethod
+
+from pathlib import Path
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Any
+
+from tme.backends import backend as be
+from tme.cli import print_entry, print_block, sanitize_name
+
+
+@dataclass
+class TomoFiles:
+    """Container for all files related to a single tomogram."""
+
+    #: Tomogram identifier.
+    tomo_id: str
+    #: Path to tomogram.
+    tomogram: Path
+    #: XML file with tilt angles, defocus, etc.
+    metadata: Path
+    #: Path to tomogram mask, optional.
+    mask: Optional[Path] = None
+
+    def __post_init__(self):
+        """Validate that required files exist."""
+        if not self.tomogram.exists():
+            raise FileNotFoundError(f"Tomogram not found: {self.tomogram}")
+        if not self.metadata.exists():
+            raise FileNotFoundError(f"Metadata not found: {self.metadata}")
+        if self.mask and not self.mask.exists():
+            raise FileNotFoundError(f"Mask not found: {self.mask}")
+
+
+@dataclass
+class AnalysisFiles:
+    """Container for files related to analysis of a single tomogram."""
+
+    #: Tomogram identifier.
+    tomo_id: str
+    #: List of TM pickle result files for this tomo_id.
+    input_files: List[Path]
+    #: Background pickle files for normalization (optional).
+    background_files: List[Path] = None
+    #: Target mask file (optional).
+    mask: Optional[Path] = None
+
+    def __post_init__(self):
+        """Validate that required files exist."""
+        for input_file in self.input_files:
+            if not input_file.exists():
+                raise FileNotFoundError(f"Input file not found: {input_file}")
+
+        if self.background_files:
+            for bg_file in self.background_files:
+                if not bg_file.exists():
+                    raise FileNotFoundError(f"Background file not found: {bg_file}")
+
+        if self.mask and not self.mask.exists():
+            raise FileNotFoundError(f"Mask not found: {self.mask}")
+
+
+class DatasetDiscovery(ABC):
+    """Base class for dataset discovery using glob patterns."""
+
+    @abstractmethod
+    def discover(self, tomo_list: Optional[List[str]] = None) -> List:
+        pass
+
+    @staticmethod
+    def parse_id_from_filename(filename: str) -> str:
+        """Extract the tomogram ID from filename by removing technical suffixes."""
+        base = Path(filename).stem
+        # Remove technical suffixes (pixel size, binning, filtering info)
+        # Examples: "_10.00Apx", "_4.00Apx", "_bin4", "_dose_filt"
+        base = re.sub(r"_\d+(\.\d+)?(Apx|bin\d*|dose_filt)$", "", base)
+
+        # Remove common organizational prefixes if they exist
+        for prefix in ["rec_Position_", "Position_", "rec_", "tomo_"]:
+            if base.startswith(prefix):
+                base = base[len(prefix) :]
+                break
+        return base
+
+    def create_mapping_table(self, pattern: str) -> Dict[str, List[Path]]:
+        """Create a mapping table between tomogram ids and file paths."""
+        if pattern is None:
+            return {}
+
+        ret = {}
+        path = Path(pattern).absolute()
+        for file in list(Path(path.parent).glob(path.name)):
+            file_id = self.parse_id_from_filename(file.name)
+            if file_id not in ret:
+                ret[file_id] = []
+            ret[file_id].append(file)
+
+        return ret
+
+
+@dataclass
+class TomoDatasetDiscovery(DatasetDiscovery):
+    """Find and match tomogram files using glob patterns."""
+
+    #: Glob pattern for tomogram files, e.g., "/data/tomograms/*.mrc"
+    mrc_pattern: str
+    #: Glob pattern for metadata files, e.g., "/data/metadata/*.xml"
+    metadata_pattern: str
+    #: Optional glob pattern for mask files, e.g., "/data/masks/*.mrc"
+    mask_pattern: Optional[str] = None
+
+    def discover(self, tomo_list: Optional[List[str]] = None) -> List[TomoFiles]:
+        """Find all matching tomogram files."""
+        mrc_files = self.create_mapping_table(self.mrc_pattern)
+        meta_files = self.create_mapping_table(self.metadata_pattern)
+        mask_files = self.create_mapping_table(self.mask_pattern)
+
+        if tomo_list:
+            mrc_files = {k: v for k, v in mrc_files.items() if k in tomo_list}
+            meta_files = {k: v for k, v in meta_files.items() if k in tomo_list}
+            mask_files = {k: v for k, v in mask_files.items() if k in tomo_list}
+
+        tomo_files = []
+        for key, value in mrc_files.items():
+            if key not in meta_files:
+                print(f"No metadata for {key}, skipping it for now.")
+                continue
+
+            tomo_files.append(
+                TomoFiles(
+                    tomo_id=key,
+                    tomogram=value[0].absolute(),
+                    metadata=meta_files[key][0].absolute(),
+                    mask=mask_files.get(key, [None])[0],
+                )
+            )
+        return tomo_files
+
+
+@dataclass
+class AnalysisDatasetDiscovery(DatasetDiscovery):
+    """Find and match analysis files using glob patterns."""
+
+    #: Glob pattern for TM pickle files, e.g., "/data/results/*.pickle"
+    input_patterns: List[str]
+    #: List of glob patterns for background files, e.g., ["/data/bg1/*.pickle", "/data/bg2/*.pickle"]
+    background_patterns: List[str] = None
+    #: Target masks, e.g., "/data/masks/*.mrc"
+    mask_patterns: Optional[str] = None
+
+    def __post_init__(self):
+        """Ensure patterns are lists."""
+        if isinstance(self.input_patterns, str):
+            self.input_patterns = [self.input_patterns]
+        if self.background_patterns and isinstance(self.background_patterns, str):
+            self.background_patterns = [self.background_patterns]
+
+    def discover(self, tomo_list: Optional[List[str]] = None) -> List[AnalysisFiles]:
+        """Find all matching analysis files."""
+
+        input_files_by_id = {}
+        for pattern in self.input_patterns:
+            files = self.create_mapping_table(pattern)
+            for tomo_id, file_list in files.items():
+                if tomo_id not in input_files_by_id:
+                    input_files_by_id[tomo_id] = []
+                input_files_by_id[tomo_id].extend(file_list)
+
+        background_files_by_id = {}
+        if self.background_patterns:
+            for pattern in self.background_patterns:
+                bg_files = self.create_mapping_table(pattern)
+                for tomo_id, file_list in bg_files.items():
+                    if tomo_id not in background_files_by_id:
+                        background_files_by_id[tomo_id] = []
+                    background_files_by_id[tomo_id].extend(file_list)
+
+        mask_files_by_id = {}
+        if self.mask_patterns:
+            mask_files_by_id = self.create_mapping_table(self.mask_patterns)
+
+        if tomo_list:
+            input_files_by_id = {
+                k: v for k, v in input_files_by_id.items() if k in tomo_list
+            }
+            background_files_by_id = {
+                k: v for k, v in background_files_by_id.items() if k in tomo_list
+            }
+            mask_files_by_id = {
+                k: v for k, v in mask_files_by_id.items() if k in tomo_list
+            }
+
+        analysis_files = []
+        for tomo_id, input_file_list in input_files_by_id.items():
+            background_files = background_files_by_id.get(tomo_id, [])
+            mask_file = mask_files_by_id.get(tomo_id, [None])[0]
+
+            analysis_file = AnalysisFiles(
+                tomo_id=tomo_id,
+                input_files=[f.absolute() for f in input_file_list],
+                background_files=(
+                    [f.absolute() for f in background_files] if background_files else []
+                ),
+                mask=mask_file.absolute() if mask_file else None,
+            )
+            analysis_files.append(analysis_file)
+
+        return analysis_files
+
+
+@dataclass
+class TMParameters:
+    """Template matching parameters."""
+
+    template: Path
+    template_mask: Optional[Path] = None
+
+    # Angular sampling (auto-calculated or explicit)
+    angular_sampling: Optional[float] = None
+    particle_diameter: Optional[float] = None
+    cone_angle: Optional[float] = None
+    cone_sampling: Optional[float] = None
+    axis_angle: float = 360.0
+    axis_sampling: Optional[float] = None
+    axis_symmetry: int = 1
+    cone_axis: int = 2
+    invert_cone: bool = False
+    no_use_optimized_set: bool = False
+
+    # Microscope parameters
+    acceleration_voltage: float = 300.0  # kV
+    spherical_aberration: float = 2.7e7  # Å
+    amplitude_contrast: float = 0.07
+    defocus: Optional[float] = None  # Å
+    phase_shift: float = 0.0  # degrees
+
+    # Processing options
+    lowpass: Optional[float] = None  # Å
+    highpass: Optional[float] = None  # Å
+    pass_format: str = "sampling_rate"  # "sampling_rate", "voxel", "frequency"
+    no_pass_smooth: bool = True
+    interpolation_order: int = 3
+    score_threshold: float = 0.0
+    score: str = "FLCSphericalMask"
+
+    # Weighting and correction
+    tilt_weighting: Optional[str] = None  # "angle", "relion", "grigorieff"
+    wedge_axes: str = "2,0"
+    whiten_spectrum: bool = False
+    scramble_phases: bool = False
+    invert_target_contrast: bool = False
+
+    # CTF parameters
+    ctf_file: Optional[Path] = None
+    no_flip_phase: bool = True
+    correct_defocus_gradient: bool = False
+
+    # Performance options
+    centering: bool = False
+    pad_edges: bool = False
+    pad_filter: bool = False
+    use_mixed_precision: bool = False
+    use_memmap: bool = False
+
+    # Analysis options
+    peak_calling: bool = False
+    num_peaks: int = 1000
+
+    # Backend selection
+    backend: str = "numpy"
+    gpu_indices: Optional[str] = None
+
+    # Reconstruction
+    reconstruction_filter: str = "ramp"
+    reconstruction_interpolation_order: int = 1
+    no_filter_target: bool = False
+
+    def __post_init__(self):
+        """Validate parameters and convert units."""
+        self.template = self.template.absolute()
+        if self.template_mask:
+            self.template_mask = self.template_mask.absolute()
+
+        if not self.template.exists():
+            raise FileNotFoundError(f"Template not found: {self.template}")
+        if self.template_mask and not self.template_mask.exists():
+            raise FileNotFoundError(f"Template mask not found: {self.template_mask}")
+        if self.ctf_file and not self.ctf_file.exists():
+            raise FileNotFoundError(f"CTF file not found: {self.ctf_file}")
+
+        if self.tilt_weighting and self.tilt_weighting not in [
+            "angle",
+            "relion",
+            "grigorieff",
+        ]:
+            raise ValueError(f"Invalid tilt weighting: {self.tilt_weighting}")
+
+        if self.pass_format not in ["sampling_rate", "voxel", "frequency"]:
+            raise ValueError(f"Invalid pass format: {self.pass_format}")
+
+        valid_backends = list(be._BACKEND_REGISTRY.keys())
+        if self.backend not in valid_backends:
+            raise ValueError(
+                f"Invalid backend: {self.backend}. Choose from {valid_backends}"
+            )
+
+    def to_command_args(self, files: TomoFiles, output_path: Path) -> Dict[str, Any]:
+        """Convert parameters to pyTME command arguments."""
+        args = {
+            "target": str(files.tomogram),
+            "template": str(self.template),
+            "output": str(output_path),
+            "acceleration-voltage": self.acceleration_voltage,
+            "spherical-aberration": self.spherical_aberration,
+            "amplitude-contrast": self.amplitude_contrast,
+            "interpolation-order": self.interpolation_order,
+            "wedge-axes": self.wedge_axes,
+            "score-threshold": self.score_threshold,
+            "score": self.score,
+            "pass-format": self.pass_format,
+            "reconstruction-filter": self.reconstruction_filter,
+            "reconstruction-interpolation-order": self.reconstruction_interpolation_order,
+        }
+
+        # Optional file arguments
+        if self.template_mask:
+            args["template-mask"] = str(self.template_mask)
+        if files.mask:
+            args["target-mask"] = str(files.mask)
+        if files.metadata:
+            args["ctf-file"] = str(files.metadata)
+            args["tilt-angles"] = str(files.metadata)
+
+        # Optional parameters
+        if self.lowpass:
+            args["lowpass"] = self.lowpass
+        if self.highpass:
+            args["highpass"] = self.highpass
+        if self.tilt_weighting:
+            args["tilt-weighting"] = self.tilt_weighting
+        if self.defocus:
+            args["defocus"] = self.defocus
+        if self.phase_shift != 0:
+            args["phase-shift"] = self.phase_shift
+        if self.gpu_indices:
+            args["gpu-indices"] = self.gpu_indices
+        if self.backend != "numpy":
+            args["backend"] = self.backend
+
+        # Angular sampling
+        if self.angular_sampling:
+            args["angular-sampling"] = self.angular_sampling
+        elif self.particle_diameter:
+            args["particle-diameter"] = self.particle_diameter
+        elif self.cone_angle:
+            args["cone-angle"] = self.cone_angle
+            if self.cone_sampling:
+                args["cone-sampling"] = self.cone_sampling
+            if self.axis_sampling:
+                args["axis-sampling"] = self.axis_sampling
+            if self.axis_angle != 360.0:
+                args["axis-angle"] = self.axis_angle
+            if self.axis_symmetry != 1:
+                args["axis-symmetry"] = self.axis_symmetry
+            if self.cone_axis != 2:
+                args["cone-axis"] = self.cone_axis
+        else:
+            # Default fallback
+            args["angular-sampling"] = 15.0
+
+        args["num-peaks"] = self.num_peaks
+        return {k: v for k, v in args.items() if v is not None}
+
+    def get_flags(self) -> List[str]:
+        """Get boolean flags for pyTME command."""
+        flags = []
+        if self.whiten_spectrum:
+            flags.append("whiten-spectrum")
+        if self.scramble_phases:
+            flags.append("scramble-phases")
+        if self.invert_target_contrast:
+            flags.append("invert-target-contrast")
+        if self.centering:
+            flags.append("centering")
+        if self.pad_edges:
+            flags.append("pad-edges")
+        if self.pad_filter:
+            flags.append("pad-filter")
+        if not self.no_pass_smooth:
+            flags.append("no-pass-smooth")
+        if self.use_mixed_precision:
+            flags.append("use-mixed-precision")
+        if self.use_memmap:
+            flags.append("use-memmap")
+        if self.peak_calling:
+            flags.append("peak-calling")
+        if not self.no_flip_phase:
+            flags.append("no-flip-phase")
+        if self.correct_defocus_gradient:
+            flags.append("correct-defocus-gradient")
+        if self.invert_cone:
+            flags.append("invert-cone")
+        if self.no_use_optimized_set:
+            flags.append("no-use-optimized-set")
+        if self.no_filter_target:
+            flags.append("no-filter-target")
+        return flags
+
+
+@dataclass
+class AnalysisParameters:
+    """Parameters for template matching analysis and peak calling."""
+
+    # Peak calling
+    peak_caller: str = "PeakCallerMaximumFilter"
+    num_peaks: int = 1000
+    min_score: float = 0.0
+    max_score: Optional[float] = None
+    min_distance: int = 5
+    min_boundary_distance: int = 0
+    mask_edges: bool = False
+    n_false_positives: Optional[int] = None
+
+    # Output format
+    output_format: str = "relion4"
+    output_directory: Optional[str] = None
+    angles_clockwise: bool = False
+
+    # Advanced options
+    extraction_box_size: Optional[int] = None
+
+    def to_command_args(
+        self, files: AnalysisFiles, output_path: Path
+    ) -> Dict[str, Any]:
+        """Convert parameters to analyze_template_matching command arguments."""
+        args = {
+            "input-files": " ".join([str(f) for f in files.input_files]),
+            "output-prefix": str(output_path.parent / output_path.stem),
+            "peak-caller": self.peak_caller,
+            "num-peaks": self.num_peaks,
+            "min-score": self.min_score,
+            "min-distance": self.min_distance,
+            "min-boundary-distance": self.min_boundary_distance,
+            "output-format": self.output_format,
+        }
+
+        # Optional parameters
+        if self.max_score is not None:
+            args["max-score"] = self.max_score
+        if self.n_false_positives is not None:
+            args["n-false-positives"] = self.n_false_positives
+        if self.extraction_box_size is not None:
+            args["extraction-box-size"] = self.extraction_box_size
+        if files.mask:
+            args["target-mask"] = str(files.mask)
+
+        # Background files
+        if files.background_files:
+            args["background-files"] = " ".join(
+                [str(f) for f in files.background_files]
+            )
+
+        return {k: v for k, v in args.items() if v is not None}
+
+    def get_flags(self) -> List[str]:
+        """Get boolean flags for analyze_template_matching command."""
+        flags = []
+        if self.mask_edges:
+            flags.append("mask-edges")
+        if self.angles_clockwise:
+            flags.append("angles-clockwise")
+        return flags
+
+
+@dataclass
+class ComputeResources:
+    """Compute resource requirements for a job."""
+
+    cpus: int = 4
+    memory_gb: int = 128
+    gpu_count: int = 0
+    gpu_type: Optional[str] = None  # e.g., "3090", "A100"
+    time_limit: str = "05:00:00"
+    partition: str = "gpu-el8"
+    constraint: Optional[str] = None
+    qos: str = "normal"
+
+    def to_slurm_args(self) -> Dict[str, str]:
+        """Convert to SLURM sbatch arguments."""
+        args = {
+            "ntasks": "1",
+            "nodes": "1",
+            "ntasks-per-node": "1",
+            "cpus-per-task": str(self.cpus),
+            "mem": f"{self.memory_gb}G",
+            "time": self.time_limit,
+            "partition": self.partition,
+            "qos": self.qos,
+            "export": "none",
+        }
+
+        if self.gpu_count > 0:
+            args["gres"] = f"gpu:{self.gpu_count}"
+            if self.gpu_type:
+                args["constraint"] = f"gpu={self.gpu_type}"
+
+        if self.constraint and not self.gpu_type:
+            args["constraint"] = self.constraint
+
+        return args
+
+
+@dataclass
+class AbstractTask(ABC):
+    """Abstract task specification"""
+
+    files: object
+    parameters: object
+    resources: ComputeResources
+    output_dir: Path
+
+    @property
+    def tomo_id(self) -> str:
+        return self.files.tomo_id
+
+    @abstractmethod
+    def executable(self) -> str:
+        pass
+
+    @property
+    @abstractmethod
+    def output_file(self) -> Path:
+        pass
+
+    def to_command_args(self):
+        return self.parameters.to_command_args(self.files, self.output_file)
+
+    def create_output_dir(self) -> None:
+        """Ensure output directory exists."""
+        self.output_dir.mkdir(parents=True, exist_ok=True)
+
+
+@dataclass
+class TemplateMatchingTask(AbstractTask):
+    """Template matching task."""
+
+    @property
+    def output_file(self) -> Path:
+        original_stem = self.files.tomogram.stem
+        return self.output_dir / f"{original_stem}.pickle"
+
+    @property
+    def executable(self):
+        return "match_template"
+
+
+class AnalysisTask(AbstractTask):
+    """Analysis task for processing TM results."""
+
+    @property
+    def output_file(self) -> Path:
+        """Generate output filename based on format."""
+        prefix = self.files.input_files[0].stem
+
+        format_extensions = {
+            "orientations": ".tsv",
+            "relion4": ".star",
+            "relion5": ".star",
+            "pickle": ".pickle",
+            "alignment": "",
+            "extraction": "",
+            "average": ".mrc",
+        }
+
+        extension = format_extensions.get(self.parameters.output_format, ".tsv")
+        return self.output_dir / f"{prefix}{extension}"
+
+    @property
+    def executable(self):
+        return "postprocess"
+
+
+class ExecutionBackend(ABC):
+    """Abstract base class for execution backends."""
+
+    @abstractmethod
+    def submit_job(self, task) -> str:
+        """Submit a single job and return job ID or status."""
+        pass
+
+    @abstractmethod
+    def submit_jobs(self, tasks: List) -> List[str]:
+        """Submit multiple jobs and return list of job IDs."""
+        pass
+
+
+class SlurmBackend(ExecutionBackend):
+    """SLURM execution backend for cluster job submission."""
+
+    def __init__(
+        self,
+        force: bool = True,
+        dry_run: bool = False,
+        script_dir: Path = Path("./slurm_scripts"),
+        environment_setup: str = "module load pyTME",
+    ):
+        """
+        Initialize SLURM backend.
+
+        Parameters
+        ----------
+        force : bool, optional
+            Rerun completed jobs, defaults to True.
+        dry_run : bool, optional
+            Generate scripts but do not submit, defaults to False.
+        script_dir : Path, optional
+            Directory to save generated scripts, defaults to ./slurm_scripts.
+        environment_setup : str, optional
+            Command to set up pyTME environment, defaults to module load pyTME.
+        """
+        self.force = force
+        self.dry_run = dry_run
+        self.environment_setup = environment_setup
+        self.script_dir = Path(script_dir) if script_dir else Path("./slurm_scripts")
+        self.script_dir.mkdir(exist_ok=True, parents=True)
+
+    def create_sbatch_script(self, task) -> Path:
+        """Generate SLURM sbatch script for a template matching task."""
+        script_path = self.script_dir / f"pytme_{task.tomo_id}.sh"
+
+        # Ensure output directory exists
+        task.create_output_dir()
+
+        slurm_args = task.resources.to_slurm_args()
+        slurm_args.update(
+            {
+                "output": f"{task.output_dir}/{task.tomo_id}_%j.out",
+                "error": f"{task.output_dir}/{task.tomo_id}_%j.err",
+                "job-name": f"pytme_{task.executable}_{task.tomo_id}",
+                "chdir": str(task.output_dir),
+            }
+        )
+
+        script_lines = ["#!/bin/bash", "", "# SLURM directives"]
+        for param, value in slurm_args.items():
+            script_lines.append(f"#SBATCH --{param}={value}")
+
+        script_lines.extend(
+            [
+                "",
+                "# Environment setup",
+                "\n".join(self.environment_setup.split(";")),
+                "",
+                "# Run template matching",
+            ]
+        )
+
+        command_parts = [task.executable]
+        cmd_args = task.to_command_args()
+        for arg, value in cmd_args.items():
+            command_parts.append(f"--{arg} {value}")
+
+        for flag in task.parameters.get_flags():
+            command_parts.append(f"--{flag}")
+
+        command = " \\\n ".join(command_parts)
+        script_lines.append(command)
+
+        with open(script_path, "w") as f:
+            f.write("\n".join(script_lines) + "\n")
+        script_path.chmod(0o755)
+
+        print(f"Generated SLURM script: {Path(script_path).name}")
+        return script_path
+
+    def submit_job(self, task) -> str:
+        """Submit a single SLURM job."""
+        script_path = self.create_sbatch_script(task)
+
+        if self.dry_run:
+            return f"DRY_RUN:{script_path}"
+
+        try:
+            if Path(task.output_file).exists() and not self.force:
+                return f"ERROR: {str(task.output_file)} exists and force was not set."
+
+            result = subprocess.run(
+                ["sbatch", str(script_path)], capture_output=True, text=True, check=True
+            )
+
+            # Parse job ID from sbatch output
+            # Typical output: "Submitted batch job 123456"
+            job_id = result.stdout.strip().split()[-1]
+            print(f"Submitted job {job_id} for {task.tomo_id}")
+            return job_id
+
+        except subprocess.CalledProcessError as e:
+            error_msg = f"Failed to submit {script_path}: {e.stderr}"
+            return f"ERROR:{error_msg}"
+        except Exception as e:
+            error_msg = f"Submission error for {script_path}: {e}"
+            return f"ERROR:{error_msg}"
+
+    def submit_jobs(self, tasks: List) -> List[str]:
+        """Submit multiple SLURM jobs."""
+        job_ids = []
+        for task in tasks:
+            job_id = self.submit_job(task)
+            job_ids.append(job_id)
+        return job_ids
+
+
+def add_compute_resources(
+    parser,
+    default_cpus=4,
+    default_memory=32,
+    default_time="02:00:00",
+    default_partition="cpu",
+    include_gpu=False,
+):
+    """Add compute resource arguments to a parser."""
+    compute_group = parser.add_argument_group("Compute Resources")
+    compute_group.add_argument(
+        "--cpus", type=int, default=default_cpus, help="Number of CPUs per job"
+    )
+    compute_group.add_argument(
+        "--memory", type=int, default=default_memory, help="Memory per job in GB"
+    )
+    compute_group.add_argument(
+        "--time-limit", default=default_time, help="Time limit (HH:MM:SS)"
+    )
+    compute_group.add_argument(
+        "--partition", default=default_partition, help="SLURM partition"
+    )
+    compute_group.add_argument(
+        "--qos", default="normal", help="SLURM quality of service"
+    )
+
+    if include_gpu:
+        compute_group.add_argument(
+            "--gpu-count", type=int, default=1, help="Number of GPUs per job"
+        )
+        compute_group.add_argument(
+            "--gpu-type",
+            default="3090",
+            help="GPU type constraint (e.g., '3090', 'A100')",
+        )
+
+    return compute_group
+
+
+def add_job_submission(parser, default_output_dir="./results"):
+    """Add job submission arguments to a parser."""
+    job_group = parser.add_argument_group("Job Submission")
+    job_group.add_argument(
+        "--output-dir",
+        type=Path,
+        default=Path(default_output_dir),
+        help="Output directory for results",
+    )
+    job_group.add_argument(
+        "--script-dir",
+        type=Path,
+        default=Path("./scripts"),
+        help="Directory for generated SLURM scripts",
+    )
+    job_group.add_argument(
+        "--environment-setup",
+        default="module load pyTME",
+        help="Command(s) to set up pyTME environment",
+    )
+    job_group.add_argument(
+        "--dry-run", action="store_true", help="Generate scripts but do not submit jobs"
+    )
+    job_group.add_argument("--force", action="store_true", help="Rerun completed jobs")
+
+    return job_group
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Batch runner for PyTME.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    subparsers = parser.add_subparsers(
+        dest="command", help="Available commands", required=True
+    )
+
+    matching_parser = subparsers.add_parser(
+        "matching",
+        help="Run template matching",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Input files for matching
+    tm_input_group = matching_parser.add_argument_group("Input Files")
+    tm_input_group.add_argument(
+        "--tomograms",
+        required=True,
+        help="Glob pattern for tomogram files (e.g., '/data/tomograms/*.mrc')",
+    )
+    tm_input_group.add_argument(
+        "--metadata",
+        required=True,
+        help="Glob pattern for metadata files (e.g., '/data/metadata/*.xml')",
+    )
+    tm_input_group.add_argument(
+        "--masks", help="Glob pattern for target mask files (e.g., '/data/masks/*.mrc')"
+    )
+    tm_input_group.add_argument(
+        "--template", required=True, type=Path, help="Template file (MRC, PDB, etc.)"
+    )
+    tm_input_group.add_argument("--template-mask", type=Path, help="Template mask file")
+    tm_input_group.add_argument(
+        "--tomo-list",
+        type=Path,
+        help="File with list of tomogram IDs to process (one per line)",
+    )
+
+    # Template matching parameters
+    tm_group = matching_parser.add_argument_group("Template Matching")
+    angular_group = tm_group.add_mutually_exclusive_group()
+    angular_group.add_argument(
+        "--angular-sampling", type=float, help="Angular sampling in degrees"
+    )
+    angular_group.add_argument(
+        "--particle-diameter",
+        type=float,
+        help="Particle diameter in units of sampling rate (typically Ångstrom)",
+    )
+
+    tm_group.add_argument(
+        "--score",
+        default="FLCSphericalMask",
+        help="Template matching scoring function. Use FLC if mask is not spherical.",
+    )
+    tm_group.add_argument(
+        "--score-threshold", type=float, default=0.0, help="Minimum score threshold"
+    )
+
+    # Microscope parameters
+    scope_group = matching_parser.add_argument_group("Microscope Parameters")
+    scope_group.add_argument(
+        "--voltage", type=float, default=300.0, help="Acceleration voltage in kV"
+    )
+    scope_group.add_argument(
+        "--spherical-aberration",
+        type=float,
+        default=2.7,
+        help="Spherical aberration in mm",
+    )
+    scope_group.add_argument(
+        "--amplitude-contrast", type=float, default=0.07, help="Amplitude contrast"
+    )
+
+    # Processing options
+    proc_group = matching_parser.add_argument_group("Processing Options")
+    proc_group.add_argument(
+        "--lowpass",
+        type=float,
+        help="Lowpass filter in units of sampling rate (typically Ångstrom).",
+    )
+    proc_group.add_argument(
+        "--highpass",
+        type=float,
+        help="Highpass filter in units of sampling rate (typically Ångstrom).",
+    )
+    proc_group.add_argument(
+        "--tilt-weighting",
+        choices=["angle", "relion", "grigorieff"],
+        help="Tilt weighting scheme",
+    )
+    proc_group.add_argument(
+        "--backend",
+        default="cupy",
+        choices=list(be._BACKEND_REGISTRY.keys()),
+        help="Computation backend",
+    )
+    proc_group.add_argument(
+        "--whiten-spectrum", action="store_true", help="Apply spectral whitening"
+    )
+    proc_group.add_argument(
+        "--scramble-phases",
+        action="store_true",
+        help="Scramble template phases for noise estimation",
+    )
+
+    _ = add_compute_resources(
+        matching_parser,
+        default_cpus=4,
+        default_memory=64,
+        include_gpu=True,
+        default_time="05:00:00",
+        default_partition="gpu-el8",
+    )
+    _ = add_job_submission(matching_parser, "./matching_results")
+
+    analysis_parser = subparsers.add_parser(
+        "analysis",
+        help="Analyze template matching results",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Input files for analysis
+    analysis_input_group = analysis_parser.add_argument_group("Input Files")
+    analysis_input_group.add_argument(
+        "--input-file",
+        "--input-files",
+        required=True,
+        nargs="+",
+        help="Path to one or multiple runs of match_template.py.",
+    )
+    analysis_input_group.add_argument(
+        "--background-file",
+        "--background-files",
+        required=False,
+        nargs="+",
+        default=[],
+        help="Path to one or multiple runs of match_template.py for normalization. "
+        "For instance from --scramble-phases or a different template.",
+    )
+    analysis_input_group.add_argument(
+        "--masks", help="Glob pattern for target mask files (e.g., '/data/masks/*.mrc')"
+    )
+    analysis_input_group.add_argument(
+        "--tomo-list",
+        type=Path,
+        help="File with list of tomogram IDs to process (one per line)",
+    )
+
+    # Peak calling parameters
+    peak_group = analysis_parser.add_argument_group("Peak Calling")
+    peak_group.add_argument(
+        "--peak-caller",
+        choices=[
+            "PeakCallerSort",
+            "PeakCallerMaximumFilter",
+            "PeakCallerFast",
+            "PeakCallerRecursiveMasking",
+            "PeakCallerScipy",
+        ],
+        default="PeakCallerMaximumFilter",
+        help="Peak caller for local maxima identification",
+    )
+    peak_group.add_argument(
+        "--num-peaks",
+        type=int,
+        default=1000,
+        help="Maximum number of peaks to identify",
+    )
+    peak_group.add_argument(
+        "--min-score",
+        type=float,
+        default=None,
+        help="Minimum score from which peaks will be considered",
+    )
+    peak_group.add_argument(
+        "--max-score",
+        type=float,
+        default=None,
+        help="Maximum score until which peaks will be considered",
+    )
+    peak_group.add_argument(
+        "--min-distance", type=int, default=None, help="Minimum distance between peaks"
+    )
+    peak_group.add_argument(
+        "--min-boundary-distance",
+        type=int,
+        default=None,
+        help="Minimum distance of peaks to target edges",
+    )
+    peak_group.add_argument(
+        "--mask-edges",
+        action="store_true",
+        default=False,
+        help="Whether candidates should not be identified from scores that were "
+        "computed from padded densities. Superseded by min_boundary_distance.",
+    )
+    peak_group.add_argument(
+        "--n-false-positives",
+        type=int,
+        default=None,
+        help="Number of accepted false-positive picks to determine minimum score",
+    )
+
+    # Output options
+    output_group = analysis_parser.add_argument_group("Output Options")
+    output_group.add_argument(
+        "--output-format",
+        choices=[
+            "orientations",
+            "relion4",
+            "relion5",
+            "alignment",
+            "extraction",
+            "average",
+            "pickle",
+        ],
+        default="relion4",
+        help="Output format for analysis results",
+    )
+    output_group.add_argument(
+        "--angles-clockwise",
+        action="store_true",
+        help="Report Euler angles in clockwise format expected by RELION",
+    )
+
+    advanced_group = analysis_parser.add_argument_group("Advanced Options")
+    advanced_group.add_argument(
+        "--extraction-box-size",
+        type=int,
+        default=None,
+        help="Box size for extracted subtomograms (for extraction output format)",
+    )
+
+    _ = add_compute_resources(
+        analysis_parser,
+        default_cpus=2,
+        default_memory=16,
+        include_gpu=False,
+        default_time="01:00:00",
+        default_partition="htc-el8",
+    )
+    _ = add_job_submission(analysis_parser, "./analysis_results")
+
+    args = parser.parse_args()
+    if args.tomo_list is not None:
+        with open(args.tomo_list, mode="r") as f:
+            args.tomo_list = [line.strip() for line in f if line.strip()]
+
+    args.output_dir = args.output_dir.absolute()
+    args.script_dir = args.script_dir.absolute()
+    return args
+
+
+def run_matching(args, resources):
+    discovery = TomoDatasetDiscovery(
+        mrc_pattern=args.tomograms,
+        metadata_pattern=args.metadata,
+        mask_pattern=args.masks,
+    )
+    files = discovery.discover(tomo_list=args.tomo_list)
+    print_block(
+        name="Discovering Dataset",
+        data={
+            "Tomogram Pattern": args.tomograms,
+            "Metadata Pattern": args.metadata,
+            "Mask Pattern": args.masks,
+            "Valid Runs": len(files),
+        },
+        label_width=30,
+    )
+    if not files:
+        print("No tomograms found! Check your patterns.")
+        return
+
+    params = TMParameters(
+        template=args.template,
+        template_mask=args.template_mask,
+        angular_sampling=args.angular_sampling,
+        particle_diameter=args.particle_diameter,
+        score=args.score,
+        score_threshold=args.score_threshold,
+        acceleration_voltage=args.voltage,
+        spherical_aberration=args.spherical_aberration * 1e7,  # mm to Ångstrom
+        amplitude_contrast=args.amplitude_contrast,
+        lowpass=args.lowpass,
+        highpass=args.highpass,
+        tilt_weighting=args.tilt_weighting,
+        backend=args.backend,
+        whiten_spectrum=args.whiten_spectrum,
+        scramble_phases=args.scramble_phases,
+    )
+    print_params = params.to_command_args(files[0], "")
+    _ = print_params.pop("target")
+    _ = print_params.pop("output")
+    print_params.update({k: True for k in params.get_flags()})
+    print_params = {
+        sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+    }
+    print_block(name="Matching Parameters", data=print_params, label_width=30)
+    print("\n" + "-" * 80)
+
+    tasks = []
+    for tomo_file in files:
+        task = TemplateMatchingTask(
+            files=tomo_file,
+            parameters=params,
+            resources=resources,
+            output_dir=args.output_dir,
+        )
+        tasks.append(task)
+
+    return tasks
+
+
+def run_analysis(args, resources):
+    discovery = AnalysisDatasetDiscovery(
+        input_patterns=args.input_file,
+        background_patterns=args.background_file,
+        mask_patterns=args.masks,
+    )
+    files = discovery.discover(tomo_list=args.tomo_list)
+    print_block(
+        name="Discovering Dataset",
+        data={
+            "Input Patterns": args.input_file,
+            "Background Patterns": args.background_file,
+            "Mask Pattern": args.masks,
+            "Valid Runs": len(files),
+        },
+        label_width=30,
+    )
+    if not files:
+        print("No TM results found! Check your patterns.")
+        return
+
+    params = AnalysisParameters(
+        peak_caller=args.peak_caller,
+        num_peaks=args.num_peaks,
+        min_score=args.min_score,
+        max_score=args.max_score,
+        min_distance=args.min_distance,
+        min_boundary_distance=args.min_boundary_distance,
+        mask_edges=args.mask_edges,
+        n_false_positives=args.n_false_positives,
+        output_format=args.output_format,
+        angles_clockwise=args.angles_clockwise,
+        extraction_box_size=args.extraction_box_size,
+    )
+    print_params = params.to_command_args(files[0], Path(""))
+    _ = print_params.pop("input-files", None)
+    _ = print_params.pop("background-files", None)
+    _ = print_params.pop("output-prefix", None)
+    print_params.update({k: True for k in params.get_flags()})
+    print_params = {
+        sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+    }
+    print_block(name="Analysis Parameters", data=print_params, label_width=30)
+    print("\n" + "-" * 80)
+
+    tasks = []
+    for file in files:
+        task = AnalysisTask(
+            files=file,
+            parameters=params,
+            resources=resources,
+            output_dir=args.output_dir,
+        )
+        tasks.append(task)
+
+    return tasks
+
+
+def main():
+    print_entry()
+
+    args = parse_args()
+
+    resources = ComputeResources(
+        cpus=args.cpus,
+        memory_gb=args.memory,
+        time_limit=args.time_limit,
+        partition=args.partition,
+        gpu_count=getattr(args, "gpu_count", 0),
+        gpu_type=getattr(args, "gpu_type", None),
+    )
+
+    func = run_matching
+    if args.command == "analysis":
+        func = run_analysis
+
+    try:
+        tasks = func(args, resources)
+    except Exception as e:
+        exit(f"Error: {e}")
+
+    if tasks is None:
+        exit(-1)
+
+    print_params = resources.to_slurm_args()
+    print_params = {
+        sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+    }
+    print_block(name="Compute Resources", data=print_params, label_width=30)
+    print("\n" + "-" * 80 + "\n")
+
+    backend = SlurmBackend(
+        force=args.force,
+        dry_run=args.dry_run,
+        script_dir=args.script_dir,
+        environment_setup=args.environment_setup,
+    )
+    job_ids = backend.submit_jobs(tasks)
+    if args.dry_run:
+        print(
+            f"\nDry run complete. Generated {len(tasks)} scripts in {args.script_dir}"
+        )
+        return 0
+
+    successful_jobs = [j for j in job_ids if not j.startswith("ERROR")]
+    print(f"\nSubmitted {len(successful_jobs)} jobs successfully.")
+    if successful_jobs:
+        print(f"Job IDs:\n{','.join(successful_jobs).strip()}")
+
+    if len(successful_jobs) == len(job_ids):
+        return 0
+
+    print("\nThe following issues arose during submission:")
+    for j in job_ids:
+        if j.startswith("ERROR"):
+            print(j)
+
+
+if __name__ == "__main__":
+    main()
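Read the same way, the analysis subcommand collects pickle files per tomogram ID and hands them to postprocess. An illustrative invocation, assuming matching results and phase-scrambled background runs already exist at these placeholder paths:

    pytme_runner.py analysis \
        --input-files '/data/results/*.pickle' \
        --background-files '/data/background/*.pickle' \
        --output-format relion4 \
        --n-false-positives 1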
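For orientation, SlurmBackend.create_sbatch_script writes one script per task of roughly the shape below. This is a sketch assembled from to_slurm_args and the matching subcommand's CLI defaults, not output captured from the package; angle-bracket placeholders stand for per-task values:

    #!/bin/bash

    # SLURM directives
    #SBATCH --ntasks=1
    #SBATCH --nodes=1
    #SBATCH --ntasks-per-node=1
    #SBATCH --cpus-per-task=4
    #SBATCH --mem=64G
    #SBATCH --time=05:00:00
    #SBATCH --partition=gpu-el8
    #SBATCH --qos=normal
    #SBATCH --export=none
    #SBATCH --gres=gpu:1
    #SBATCH --constraint=gpu=3090
    #SBATCH --output=<output_dir>/<tomo_id>_%j.out
    #SBATCH --error=<output_dir>/<tomo_id>_%j.err
    #SBATCH --job-name=pytme_match_template_<tomo_id>
    #SBATCH --chdir=<output_dir>

    # Environment setup
    module load pyTME

    # Run template matching
    match_template \
     --target <tomogram.mrc> \
     --template <template.mrc> \
     --output <output_dir>/<tomo_id>.pickle \
     ...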