pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +30 -41
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +35 -21
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +96 -24
- pytme-0.3.1.dev20250731.data/scripts/pytme_runner.py +1223 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +5 -7
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +59 -49
- scripts/estimate_ram_usage.py +97 -0
- scripts/extract_candidates.py +118 -99
- scripts/match_template.py +30 -41
- scripts/match_template_devel.py +1339 -0
- scripts/postprocess.py +35 -21
- scripts/preprocessor_gui.py +96 -24
- scripts/pytme_runner.py +644 -190
- scripts/refine_matches.py +158 -390
- tests/data/.DS_Store +0 -0
- tests/data/Blurring/.DS_Store +0 -0
- tests/data/Maps/.DS_Store +0 -0
- tests/data/Raw/.DS_Store +0 -0
- tests/data/Structures/.DS_Store +0 -0
- tests/preprocessing/test_utils.py +18 -0
- tests/test_analyzer.py +2 -3
- tests/test_backends.py +3 -9
- tests/test_density.py +0 -1
- tests/test_extensions.py +0 -1
- tests/test_matching_utils.py +10 -60
- tests/test_orientations.py +0 -12
- tests/test_rotations.py +1 -1
- tme/__version__.py +1 -1
- tme/analyzer/_utils.py +4 -4
- tme/analyzer/aggregation.py +35 -15
- tme/analyzer/peaks.py +11 -10
- tme/backends/_jax_utils.py +64 -18
- tme/backends/_numpyfftw_utils.py +270 -0
- tme/backends/cupy_backend.py +16 -55
- tme/backends/jax_backend.py +79 -40
- tme/backends/matching_backend.py +17 -51
- tme/backends/mlx_backend.py +1 -27
- tme/backends/npfftw_backend.py +71 -65
- tme/backends/pytorch_backend.py +1 -26
- tme/density.py +58 -5
- tme/extensions.cpython-311-darwin.so +0 -0
- tme/filters/ctf.py +22 -21
- tme/filters/wedge.py +10 -7
- tme/mask.py +341 -0
- tme/matching_data.py +31 -19
- tme/matching_exhaustive.py +37 -47
- tme/matching_optimization.py +2 -1
- tme/matching_scores.py +229 -411
- tme/matching_utils.py +73 -422
- tme/memory.py +1 -1
- tme/orientations.py +24 -13
- tme/rotations.py +1 -1
- pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
- {pytme-0.3b0.post1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
- {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
@@ -1,769 +0,0 @@
|
|
1
|
-
#!python
|
2
|
-
"""
|
3
|
-
PyTME Batch Runner - Refactored Core Classes
|
4
|
-
"""
|
5
|
-
import re
|
6
|
-
import argparse
|
7
|
-
import subprocess
|
8
|
-
from abc import ABC, abstractmethod
|
9
|
-
|
10
|
-
from pathlib import Path
|
11
|
-
from dataclasses import dataclass
|
12
|
-
from typing import Dict, List, Optional, Any
|
13
|
-
|
14
|
-
from tme.backends import backend as be
|
15
|
-
from tme.cli import print_entry, print_block, sanitize_name
|
16
|
-
|
17
|
-
|
18
|
-
@dataclass
class TomoFiles:
    """Bundle of the on-disk files that belong to one tomogram.

    Raises
    ------
    FileNotFoundError
        On construction, if the tomogram, the metadata file or a supplied
        mask does not exist.
    """

    #: Tomogram identifier.
    tomo_id: str
    #: Path to tomogram.
    tomogram: Path
    #: XML file with tilt angles, defocus, etc.
    metadata: Path
    #: Path to tomogram mask, optional.
    mask: Optional[Path] = None

    def __post_init__(self):
        """Check that every referenced file is present on disk."""
        to_check = [("Tomogram", self.tomogram), ("Metadata", self.metadata)]
        # The mask is optional and only verified when one was supplied.
        if self.mask:
            to_check.append(("Mask", self.mask))

        for label, location in to_check:
            if not location.exists():
                raise FileNotFoundError(f"{label} not found: {location}")
|
39
|
-
|
40
|
-
|
41
|
-
class TomoDatasetDiscovery:
    """Find and match tomogram files using glob patterns.

    Tomograms, metadata files and optional masks are paired through the
    identifier that :py:meth:`parse_id_from_filename` derives from each
    file name.
    """

    # A single trailing technical suffix, e.g. "_10.00Apx", "_4bin2",
    # "_bin4" or "_dose_filt".
    _SUFFIX_RE = re.compile(
        r"_(?:\d+(?:\.\d+)?(?:Apx|bin\d*|dose_filt)|bin\d+|dose_filt)$"
    )
    # Organizational prefixes; the longest one is listed first so it wins.
    _PREFIXES = ("rec_Position_", "Position_", "rec_", "tomo_")

    def __init__(
        self,
        mrc_pattern: str,
        metadata_pattern: str,
        mask_pattern: Optional[str] = None,
    ):
        """
        Initialize with glob patterns for file discovery.

        Parameters
        ----------
        mrc_pattern: str
            Glob pattern for tomogram files, e.g., "/data/tomograms/*.mrc"
        metadata_pattern: str
            Glob pattern for metadata files, e.g., "/data/metadata/*.xml"
        mask_pattern: str
            Optional glob pattern for mask files, e.g., "/data/masks/*.mrc"
        """
        self.mrc_pattern = mrc_pattern
        self.metadata_pattern = metadata_pattern
        self.mask_pattern = mask_pattern

    @staticmethod
    def parse_id_from_filename(filename: str) -> str:
        """Extract the tomogram ID from filename by removing technical suffixes.

        The file extension is dropped, then at most one trailing technical
        suffix (pixel size, binning, dose filtering) and at most one leading
        organizational prefix are removed.

        Parameters
        ----------
        filename : str
            File name or path of a tomogram, metadata or mask file.

        Returns
        -------
        str
            The identifier shared by all files of one tomogram.
        """
        base = Path(filename).stem
        # Examples: "_10.00Apx", "_4.00Apx", "_bin4", "_dose_filt"
        base = TomoDatasetDiscovery._SUFFIX_RE.sub("", base)

        # Only the first matching prefix is stripped.
        for prefix in TomoDatasetDiscovery._PREFIXES:
            if base.startswith(prefix):
                return base[len(prefix) :]
        return base

    def _create_mapping_table(self, pattern: Optional[str]) -> Dict[str, Path]:
        """Create a mapping table between tomogram ids and file paths.

        Parameters
        ----------
        pattern : str or None
            Glob pattern; only its final path component is expanded.

        Returns
        -------
        dict
            Maps each tomogram id to one path. Empty if ``pattern`` is None.
        """
        if pattern is None:
            return {}

        path = Path(pattern).absolute()
        candidates: Dict[str, List[Path]] = {}
        # glob yields files in arbitrary order; sorting makes the
        # "first match wins" rule below reproducible between runs.
        for file in sorted(path.parent.glob(path.name)):
            file_id = self.parse_id_from_filename(file.name)
            candidates.setdefault(file_id, []).append(file)

        # This could all be done in one line but we want the messages.
        ret = {}
        for key, value in candidates.items():
            if len(value) > 1:
                print(f"Found id {key} multiple times at {value}. Using {value[0]}.")
            ret[key] = value[0]
        return ret

    def discover_tomograms(
        self, tomo_list: Optional[List[str]] = None, require_mask: bool = False
    ) -> "List[TomoFiles]":
        """Find all matching tomogram files.

        Parameters
        ----------
        tomo_list : list of str, optional
            If given and non-empty, only tomograms with these ids are returned.
        require_mask : bool, optional
            Skip tomograms for which no mask file was found, defaults to False.

        Returns
        -------
        list of TomoFiles
            One entry per tomogram that has a metadata file (and a mask, if
            ``require_mask`` is set).
        """
        mrc_files = self._create_mapping_table(self.mrc_pattern)
        meta_files = self._create_mapping_table(self.metadata_pattern)
        mask_files = self._create_mapping_table(self.mask_pattern)

        # A set gives constant-time membership tests for long id lists.
        wanted = set(tomo_list) if tomo_list else None

        tomo_files = []
        for key, value in mrc_files.items():
            if wanted is not None and key not in wanted:
                continue
            if key not in meta_files:
                print(f"No metadata for {key}, skipping it for now.")
                continue

            mask = mask_files.get(key)
            if require_mask and mask is None:
                print(f"No mask for {key}, skipping it for now.")
                continue

            tomo_files.append(
                TomoFiles(
                    tomo_id=key,
                    tomogram=value.absolute(),
                    metadata=meta_files[key].absolute(),
                    mask=mask,
                )
            )
        return tomo_files
|
130
|
-
|
131
|
-
|
132
|
-
@dataclass
class TMParameters:
    """Template matching parameters.

    The fields are translated into ``match_template`` command line options
    by :py:meth:`to_command_args` (options with a value) and
    :py:meth:`get_flags` (boolean switches).

    Raises
    ------
    FileNotFoundError
        On construction, if the template, template mask or CTF file is
        given but does not exist.
    ValueError
        On construction, if ``tilt_weighting``, ``pass_format`` or
        ``backend`` is not one of the accepted values.
    """

    template: Path
    template_mask: Optional[Path] = None

    # Angular sampling (auto-calculated or explicit)
    angular_sampling: Optional[float] = None
    particle_diameter: Optional[float] = None
    cone_angle: Optional[float] = None
    cone_sampling: Optional[float] = None
    axis_angle: float = 360.0
    axis_sampling: Optional[float] = None
    axis_symmetry: int = 1
    cone_axis: int = 2
    invert_cone: bool = False
    no_use_optimized_set: bool = False

    # Microscope parameters
    acceleration_voltage: float = 300.0  # kV
    spherical_aberration: float = 2.7e7  # Å
    amplitude_contrast: float = 0.07
    defocus: Optional[float] = None  # Å
    phase_shift: float = 0.0  # Dg

    # Processing options
    lowpass: Optional[float] = None  # Å
    highpass: Optional[float] = None  # Å
    pass_format: str = "sampling_rate"  # "sampling_rate", "voxel", "frequency"
    no_pass_smooth: bool = True
    interpolation_order: int = 3
    score_threshold: float = 0.0
    score: str = "FLCSphericalMask"

    # Weighting and correction
    tilt_weighting: Optional[str] = None  # "angle", "relion", "grigorieff"
    wedge_axes: str = "2,0"
    whiten_spectrum: bool = False
    scramble_phases: bool = False
    invert_target_contrast: bool = False

    # CTF parameters
    ctf_file: Optional[Path] = None
    no_flip_phase: bool = True
    correct_defocus_gradient: bool = False

    # Performance options
    centering: bool = False
    pad_edges: bool = False
    pad_filter: bool = False
    use_mixed_precision: bool = False
    use_memmap: bool = False

    # Analysis options
    peak_calling: bool = False
    num_peaks: int = 1000

    # Backend selection
    backend: str = "numpy"
    gpu_indices: Optional[str] = None

    # Reconstruction
    reconstruction_filter: str = "ramp"
    reconstruction_interpolation_order: int = 1
    no_filter_target: bool = False

    def __post_init__(self):
        """Make file paths absolute and validate the parameters."""
        self.template = self.template.absolute()
        if self.template_mask:
            self.template_mask = self.template_mask.absolute()
        if self.ctf_file:
            self.ctf_file = self.ctf_file.absolute()

        if not self.template.exists():
            raise FileNotFoundError(f"Template not found: {self.template}")
        if self.template_mask and not self.template_mask.exists():
            raise FileNotFoundError(f"Template mask not found: {self.template_mask}")
        if self.ctf_file and not self.ctf_file.exists():
            raise FileNotFoundError(f"CTF file not found: {self.ctf_file}")

        if self.tilt_weighting and self.tilt_weighting not in [
            "angle",
            "relion",
            "grigorieff",
        ]:
            raise ValueError(f"Invalid tilt weighting: {self.tilt_weighting}")

        if self.pass_format not in ["sampling_rate", "voxel", "frequency"]:
            raise ValueError(f"Invalid pass format: {self.pass_format}")

        valid_backends = list(be._BACKEND_REGISTRY.keys())
        if self.backend not in valid_backends:
            raise ValueError(
                f"Invalid backend: {self.backend}. Choose from {valid_backends}"
            )

    def to_command_args(
        self, tomo_files: "TomoFiles", output_path: Path
    ) -> Dict[str, Any]:
        """Convert parameters to pyTME command arguments.

        Parameters
        ----------
        tomo_files : TomoFiles
            Files of the tomogram to match against; its ``tomogram``,
            ``metadata`` and ``mask`` attributes are read.
        output_path : Path
            Destination passed as the ``output`` option.

        Returns
        -------
        dict
            Option name (without leading dashes) to value, in the order the
            options should appear on the command line.
        """
        args = {
            "target": str(tomo_files.tomogram),
            "template": str(self.template),
            "output": str(output_path),
            "acceleration-voltage": self.acceleration_voltage,
            "spherical-aberration": self.spherical_aberration,
            "amplitude-contrast": self.amplitude_contrast,
            "interpolation-order": self.interpolation_order,
            "wedge-axes": self.wedge_axes,
            "score-threshold": self.score_threshold,
            "score": self.score,
            "pass-format": self.pass_format,
            "reconstruction-filter": self.reconstruction_filter,
            "reconstruction-interpolation-order": self.reconstruction_interpolation_order,
        }

        # Optional file arguments
        if self.template_mask:
            args["template-mask"] = str(self.template_mask)
        if tomo_files.mask:
            args["target-mask"] = str(tomo_files.mask)
        # An explicitly configured CTF file takes precedence over the
        # per-tomogram metadata file; tilt angles always come from the latter.
        if self.ctf_file:
            args["ctf-file"] = str(self.ctf_file)
        elif tomo_files.metadata:
            args["ctf-file"] = str(tomo_files.metadata)
        if tomo_files.metadata:
            args["tilt-angles"] = str(tomo_files.metadata)

        # Optional parameters: (option, value, whether to emit it).
        optional = (
            ("lowpass", self.lowpass, bool(self.lowpass)),
            ("highpass", self.highpass, bool(self.highpass)),
            ("tilt-weighting", self.tilt_weighting, bool(self.tilt_weighting)),
            ("defocus", self.defocus, bool(self.defocus)),
            ("phase-shift", self.phase_shift, self.phase_shift != 0),
            ("gpu-indices", self.gpu_indices, bool(self.gpu_indices)),
            ("backend", self.backend, self.backend != "numpy"),
        )
        for option, value, emit in optional:
            if emit:
                args[option] = value

        # Angular sampling: the first configured scheme wins.
        if self.angular_sampling:
            args["angular-sampling"] = self.angular_sampling
        elif self.particle_diameter:
            args["particle-diameter"] = self.particle_diameter
        elif self.cone_angle:
            args["cone-angle"] = self.cone_angle
            if self.cone_sampling:
                args["cone-sampling"] = self.cone_sampling
            if self.axis_sampling:
                args["axis-sampling"] = self.axis_sampling
            # The remaining cone options are only emitted when they differ
            # from the field defaults.
            if self.axis_angle != 360.0:
                args["axis-angle"] = self.axis_angle
            if self.axis_symmetry != 1:
                args["axis-symmetry"] = self.axis_symmetry
            if self.cone_axis != 2:
                args["cone-axis"] = self.cone_axis
        else:
            # Default fallback
            args["angular-sampling"] = 15.0

        args["num-peaks"] = self.num_peaks
        return args

    def get_flags(self) -> List[str]:
        """Get boolean flags for pyTME command.

        Returns
        -------
        list of str
            Names (without leading dashes) of the switches to pass.
        """
        # NOTE(review): "no-pass-smooth" and "no-flip-phase" are emitted when
        # the attribute is False; presumably this mirrors store_false options
        # of match_template - confirm against that script.
        switches = (
            ("whiten-spectrum", self.whiten_spectrum),
            ("scramble-phases", self.scramble_phases),
            ("invert-target-contrast", self.invert_target_contrast),
            ("centering", self.centering),
            ("pad-edges", self.pad_edges),
            ("pad-filter", self.pad_filter),
            ("no-pass-smooth", not self.no_pass_smooth),
            ("use-mixed-precision", self.use_mixed_precision),
            ("use-memmap", self.use_memmap),
            ("peak-calling", self.peak_calling),
            ("no-flip-phase", not self.no_flip_phase),
            ("correct-defocus-gradient", self.correct_defocus_gradient),
            ("invert-cone", self.invert_cone),
            ("no-use-optimized-set", self.no_use_optimized_set),
            ("no-filter-target", self.no_filter_target),
        )
        return [flag for flag, enabled in switches if enabled]
|
331
|
-
|
332
|
-
|
333
|
-
@dataclass
class ComputeResources:
    """Compute resource requirements for a job."""

    cpus: int = 4
    memory_gb: int = 128
    gpu_count: int = 0
    gpu_type: Optional[str] = None  # e.g., "3090", "A100"
    time_limit: str = "05:00:00"
    partition: str = "gpu-el8"
    constraint: Optional[str] = None
    qos: str = "normal"

    def to_slurm_args(self) -> Dict[str, str]:
        """Convert to SLURM sbatch arguments.

        Returns
        -------
        dict
            sbatch option name (without leading dashes) to value. ``gres``
            is only present when GPUs are requested. ``constraint`` holds
            the GPU type when GPUs of a specific type are requested and the
            user supplied ``constraint`` otherwise.
        """
        args = {
            "ntasks": "1",
            "nodes": "1",
            "ntasks-per-node": "1",
            "cpus-per-task": str(self.cpus),
            "mem": f"{self.memory_gb}G",
            "time": self.time_limit,
            "partition": self.partition,
            "qos": self.qos,
            "export": "none",
        }

        gpu_constraint = None
        if self.gpu_count > 0:
            args["gres"] = f"gpu:{self.gpu_count}"
            if self.gpu_type:
                gpu_constraint = f"gpu={self.gpu_type}"

        # The GPU type still takes precedence, but a user constraint is no
        # longer dropped when a GPU type is set without requesting any GPU.
        if gpu_constraint:
            args["constraint"] = gpu_constraint
        elif self.constraint:
            args["constraint"] = self.constraint

        return args
|
369
|
-
|
370
|
-
|
371
|
-
@dataclass
class TemplateMatchingTask:
    """Everything needed to run template matching on one tomogram."""

    #: Input files of the tomogram to process.
    tomo_files: TomoFiles
    #: Template matching settings.
    parameters: TMParameters
    #: Compute resources requested for the job.
    resources: ComputeResources
    #: Directory that receives the result file.
    output_dir: Path

    @property
    def tomo_id(self) -> str:
        """Identifier of the tomogram this task operates on."""
        files = self.tomo_files
        return files.tomo_id

    @property
    def output_file(self) -> Path:
        """Path of the result file, ``<output_dir>/<tomo_id>.pickle``."""
        filename = f"{self.tomo_id}.pickle"
        return self.output_dir / filename

    def create_output_dir(self) -> None:
        """Create the output directory, including parents, if it is missing."""
        self.output_dir.mkdir(exist_ok=True, parents=True)
|
391
|
-
|
392
|
-
|
393
|
-
class ExecutionBackend(ABC):
    """Interface that every job execution backend has to implement."""

    @abstractmethod
    def submit_job(self, task) -> str:
        """Submit one task and return its job ID or a status string."""

    @abstractmethod
    def submit_jobs(self, tasks: List) -> List[str]:
        """Submit several tasks and return one job ID or status per task."""
|
405
|
-
|
406
|
-
|
407
|
-
class SlurmBackend(ExecutionBackend):
    """SLURM execution backend for cluster job submission."""

    def __init__(
        self,
        force: bool = True,
        dry_run: bool = False,
        script_dir: Path = Path("./slurm_scripts"),
        environment_setup: str = "module load pyTME",
    ):
        """
        Initialize SLURM backend.

        Parameters
        ----------
        force : bool, optional
            Rerun completed jobs, defaults to True.
        dry_run : bool, optional
            Generate scripts but do not submit, defaults to False.
        script_dir: str, optional
            Directory to save generated scripts, defaults to ./slurm_scripts.
            It is created if it does not exist.
        environment_setup : str, optional
            Command to set up pyTME environment, defaults to module load pyTME.
            Several commands can be separated by ``;``.
        """
        self.force = force
        self.dry_run = dry_run
        self.environment_setup = environment_setup
        self.script_dir = Path(script_dir) if script_dir else Path("./slurm_scripts")
        self.script_dir.mkdir(exist_ok=True, parents=True)

    def create_sbatch_script(self, task) -> Path:
        """Generate SLURM sbatch script for a template matching task.

        The task's output directory is created as a side effect and the
        script is written to ``<script_dir>/pytme_<tomo_id>.sh``.

        Parameters
        ----------
        task : TemplateMatchingTask
            Task to generate the script for.

        Returns
        -------
        Path
            Location of the generated, executable script.
        """
        import shlex

        script_path = self.script_dir / f"pytme_{task.tomo_id}.sh"

        # Ensure output directory exists
        task.create_output_dir()

        slurm_args = task.resources.to_slurm_args()
        slurm_args.update(
            {
                "output": f"{task.output_dir}/{task.tomo_id}_%j.out",
                "error": f"{task.output_dir}/{task.tomo_id}_%j.err",
                "job-name": f"pytme_{task.tomo_id}",
                "chdir": str(task.output_dir),
            }
        )

        script_lines = ["#!/bin/bash", "", "# SLURM directives"]
        script_lines.extend(
            f"#SBATCH --{param}={value}" for param, value in slurm_args.items()
        )
        script_lines.extend(
            [
                "",
                "# Environment setup",
                # One setup command per script line.
                "\n".join(self.environment_setup.split(";")),
                "",
                "# Run template matching",
            ]
        )

        command_parts = ["match_template"]
        cmd_args = task.parameters.to_command_args(task.tomo_files, task.output_file)
        for arg, value in cmd_args.items():
            # Quote values so paths containing spaces or shell metacharacters
            # reach match_template as a single, unmodified argument.
            command_parts.append(f"--{arg} {shlex.quote(str(value))}")
        command_parts.extend(f"--{flag}" for flag in task.parameters.get_flags())

        script_lines.append(" \\\n ".join(command_parts))

        with open(script_path, "w") as f:
            f.write("\n".join(script_lines) + "\n")
        script_path.chmod(0o755)

        print(f"Generated SLURM script: {Path(script_path).name}")
        return script_path

    def submit_job(self, task) -> str:
        """Submit a single SLURM job.

        Parameters
        ----------
        task : TemplateMatchingTask
            Task to submit.

        Returns
        -------
        str
            The SLURM job ID on success, ``DRY_RUN:<script>`` for dry runs,
            or a message starting with ``ERROR`` if nothing was submitted.
        """
        script_path = self.create_sbatch_script(task)

        if self.dry_run:
            return f"DRY_RUN:{script_path}"

        if Path(task.output_file).exists() and not self.force:
            return "ERROR: File exists and force was not requested."

        try:
            result = subprocess.run(
                ["sbatch", str(script_path)], capture_output=True, text=True, check=True
            )
        except subprocess.CalledProcessError as e:
            return f"ERROR:Failed to submit {script_path}: {e.stderr}"
        except OSError as e:
            # Raised e.g. when the sbatch executable cannot be started.
            return f"ERROR:Submission error for {script_path}: {e}"

        # Parse job ID from sbatch output
        # Typical output: "Submitted batch job 123456"
        fields = result.stdout.split()
        if not fields:
            return f"ERROR:Submission error for {script_path}: sbatch printed no job ID"

        job_id = fields[-1]
        print(f"Submitted job {job_id} for {task.tomo_id}")
        return job_id

    def submit_jobs(self, tasks: List) -> List[str]:
        """Submit multiple SLURM jobs.

        Returns
        -------
        list of str
            Result of :py:meth:`submit_job` for each task, in input order.
        """
        return [self.submit_job(task) for task in tasks]
|
521
|
-
|
522
|
-
|
523
|
-
def parse_args(argv=None):
    """Parse the command line of the batch runner.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. Defaults to None, in which case argparse reads
        ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed arguments. ``tomo_list`` is replaced by the list of non-empty
        lines of the given file, and ``output_dir`` and ``script_dir`` are
        made absolute.
    """
    parser = argparse.ArgumentParser(
        description="Batch runner for match_template.py",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    input_group = parser.add_argument_group("Input Files")
    input_group.add_argument(
        "--tomograms",
        required=True,
        help="Glob pattern for tomogram files (e.g., '/data/tomograms/*.mrc')",
    )
    input_group.add_argument(
        "--metadata",
        required=True,
        help="Glob pattern for metadata files (e.g., '/data/metadata/*.xml')",
    )
    input_group.add_argument(
        "--masks", help="Glob pattern for mask files (e.g., '/data/masks/*.mrc')"
    )
    input_group.add_argument(
        "--template", required=True, type=Path, help="Template file (MRC, PDB, etc.)"
    )
    input_group.add_argument("--template-mask", type=Path, help="Template mask file")
    input_group.add_argument(
        "--tomo-list",
        type=Path,
        help="File with list of tomogram IDs to process (one per line)",
    )

    tm_group = parser.add_argument_group("Template Matching")
    # Angular sampling and particle diameter are alternative ways to choose
    # the rotational search density, so only one of them may be given.
    angular_group = tm_group.add_mutually_exclusive_group()
    angular_group.add_argument(
        "--angular-sampling", type=float, help="Angular sampling in degrees"
    )
    angular_group.add_argument(
        "--particle-diameter",
        type=float,
        help="Particle diameter in units of sampling rate (typically Ångstrom)",
    )

    tm_group.add_argument(
        "--score",
        default="FLCSphericalMask",
        help="Template matching scoring function. Use FLC if mask is not spherical.",
    )
    tm_group.add_argument(
        "--score-threshold", type=float, default=0.0, help="Minimum score threshold"
    )

    scope_group = parser.add_argument_group("Microscope Parameters")
    scope_group.add_argument(
        "--voltage", type=float, default=300.0, help="Acceleration voltage in kV"
    )
    scope_group.add_argument(
        "--spherical-aberration",
        type=float,
        default=2.7,
        help="Spherical aberration in mm",
    )
    scope_group.add_argument(
        "--amplitude-contrast", type=float, default=0.07, help="Amplitude contrast"
    )

    proc_group = parser.add_argument_group("Processing Options")
    proc_group.add_argument(
        "--lowpass",
        type=float,
        help="Lowpass filter in units of sampling rate (typically Ångstrom).",
    )
    proc_group.add_argument(
        "--highpass",
        type=float,
        help="Highpass filter in units of sampling rate (typically Ångstrom).",
    )
    proc_group.add_argument(
        "--tilt-weighting",
        choices=["angle", "relion", "grigorieff"],
        help="Tilt weighting scheme",
    )
    proc_group.add_argument(
        "--backend",
        default="cupy",
        choices=list(be._BACKEND_REGISTRY.keys()),
        help="Computation backend",
    )
    proc_group.add_argument(
        "--whiten-spectrum", action="store_true", help="Apply spectral whitening"
    )
    proc_group.add_argument(
        "--scramble-phases",
        action="store_true",
        help="Scramble template phases for noise estimation",
    )

    compute_group = parser.add_argument_group("Compute Resources")
    compute_group.add_argument(
        "--cpus", type=int, default=4, help="Number of CPUs per job"
    )
    compute_group.add_argument(
        "--memory", type=int, default=64, help="Memory per job in GB"
    )
    compute_group.add_argument(
        "--gpu-count", type=int, default=1, help="Number of GPUs per job"
    )
    compute_group.add_argument(
        "--gpu-type", default="3090", help="GPU type constraint (e.g., '3090', 'A100')"
    )
    compute_group.add_argument(
        "--time-limit", default="05:00:00", help="Time limit (HH:MM:SS)"
    )
    compute_group.add_argument("--partition", default="gpu-el8", help="SLURM partition")

    job_group = parser.add_argument_group("Job Submission")
    job_group.add_argument(
        "--output-dir",
        type=Path,
        default=Path("./batch_results"),
        help="Output directory for results",
    )
    job_group.add_argument(
        "--script-dir",
        type=Path,
        default=Path("./slurm_scripts"),
        help="Directory for generated SLURM scripts",
    )
    job_group.add_argument(
        "--environment-setup",
        default="module load pyTME",
        help="Command(s) to set up pyTME environment",
    )
    job_group.add_argument(
        "--dry-run", action="store_true", help="Generate scripts but do not submit jobs"
    )
    job_group.add_argument("--force", action="store_true", help="Rerun completed jobs")
    args = parser.parse_args(argv)

    # Replace the file path by the ids it lists, ignoring blank lines.
    if args.tomo_list is not None:
        with open(args.tomo_list, mode="r") as f:
            args.tomo_list = [line.strip() for line in f if line.strip()]

    args.output_dir = args.output_dir.absolute()
    args.script_dir = args.script_dir.absolute()

    return args
|
668
|
-
|
669
|
-
|
670
|
-
def main():
    """Discover tomograms, build one matching task each and submit them.

    Prints a summary of the discovered dataset, the matching parameters and
    the compute resources, then generates (and, unless ``--dry-run`` is
    given, submits) one SLURM job per tomogram.

    Raises
    ------
    SystemExit
        With a non-zero status if any step fails, so that calling shells and
        pipelines can detect the failure.
    """

    def _sorted_display(raw):
        # Display names, ordered by the original option name.
        return {sanitize_name(k): raw[k] for k in sorted(raw)}

    print_entry()

    args = parse_args()
    try:
        discovery = TomoDatasetDiscovery(
            mrc_pattern=args.tomograms,
            metadata_pattern=args.metadata,
            mask_pattern=args.masks,
        )
        tomo_files = discovery.discover_tomograms(tomo_list=args.tomo_list)
        print_block(
            name="Discovering Dataset",
            data={
                "Tomogram Pattern": args.tomograms,
                "Metadata Pattern": args.metadata,
                "Mask Pattern": args.masks,
                "Valid Runs": len(tomo_files),
            },
            label_width=30,
        )
        if not tomo_files:
            print("No tomograms found! Check your patterns.")
            return

        # NOTE(review): the voltage is scaled by 1e3 here although the
        # TMParameters default (300.0) is annotated as kV - confirm which
        # unit match_template expects.
        params = TMParameters(
            template=args.template,
            template_mask=args.template_mask,
            angular_sampling=args.angular_sampling,
            particle_diameter=args.particle_diameter,
            score=args.score,
            score_threshold=args.score_threshold,
            acceleration_voltage=args.voltage * 1e3,  # keV to eV
            spherical_aberration=args.spherical_aberration * 1e7,  # Convert mm to Å
            amplitude_contrast=args.amplitude_contrast,
            lowpass=args.lowpass,
            highpass=args.highpass,
            tilt_weighting=args.tilt_weighting,
            backend=args.backend,
            whiten_spectrum=args.whiten_spectrum,
            scramble_phases=args.scramble_phases,
        )
        # The first tomogram only serves as an example for the summary; the
        # per-tomogram entries (target, output) are removed before printing.
        print_params = params.to_command_args(tomo_files[0], "")
        print_params.pop("target")
        print_params.pop("output")
        print_params.update({k: True for k in params.get_flags()})
        print_block(
            name="Matching Parameters",
            data=_sorted_display(print_params),
            label_width=30,
        )
        print("\n" + "-" * 80)

        resources = ComputeResources(
            cpus=args.cpus,
            memory_gb=args.memory,
            gpu_count=args.gpu_count,
            gpu_type=args.gpu_type,
            time_limit=args.time_limit,
            partition=args.partition,
        )
        print_block(
            name="Compute Resources",
            data=_sorted_display(resources.to_slurm_args()),
            label_width=30,
        )
        print("\n" + "-" * 80 + "\n")

        tasks = [
            TemplateMatchingTask(
                tomo_files=tomo_file,
                parameters=params,
                resources=resources,
                output_dir=args.output_dir,
            )
            for tomo_file in tomo_files
        ]

        backend = SlurmBackend(
            force=args.force,
            dry_run=args.dry_run,
            script_dir=args.script_dir,
            environment_setup=args.environment_setup,
        )
        job_ids = backend.submit_jobs(tasks)
        if args.dry_run:
            print(
                f"\nDry run complete. Generated {len(tasks)} scripts in {args.script_dir}"
            )
        else:
            successful_jobs = [j for j in job_ids if not j.startswith("ERROR")]
            failed_jobs = [j for j in job_ids if j.startswith("ERROR")]
            print(f"\nSubmitted {len(successful_jobs)} jobs successfully.")
            if successful_jobs:
                print(f"Job IDs:\n{','.join(successful_jobs).strip()}")
            # Surface failed submissions instead of silently dropping them.
            if failed_jobs:
                print(f"Failed to submit {len(failed_jobs)} jobs:")
                for message in failed_jobs:
                    print(f"  {message}")

    except Exception as e:
        # SystemExit with a message prints it to stderr and exits with
        # status 1, so failures are no longer reported as success.
        raise SystemExit(f"Error: {e}") from e
|
766
|
-
|
767
|
-
|
768
|
-
# Run the batch runner only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|