pytme 0.2.9.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74) hide show
  1. pytme-0.3b0.post1.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3b0.post1.data/scripts/match_template.py +1098 -0
  3. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/postprocess.py +318 -189
  4. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.post1.data → pytme-0.3b0.post1.data}/scripts/preprocessor_gui.py +12 -12
  6. pytme-0.3b0.post1.data/scripts/pytme_runner.py +769 -0
  7. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/METADATA +21 -20
  8. pytme-0.3b0.post1.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3b0.post1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +341 -378
  15. pytme-0.2.9.post1.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +318 -189
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +12 -12
  19. scripts/pytme_runner.py +769 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -54
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +395 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -204
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/filters/__init__.py +3 -3
  49. tme/filters/_utils.py +36 -10
  50. tme/filters/bandpass.py +229 -188
  51. tme/filters/compose.py +5 -4
  52. tme/filters/ctf.py +516 -254
  53. tme/filters/reconstruction.py +91 -32
  54. tme/filters/wedge.py +196 -135
  55. tme/filters/whitening.py +37 -42
  56. tme/matching_data.py +28 -39
  57. tme/matching_exhaustive.py +31 -27
  58. tme/matching_optimization.py +5 -4
  59. tme/matching_scores.py +25 -15
  60. tme/matching_utils.py +54 -9
  61. tme/memory.py +4 -3
  62. tme/orientations.py +22 -9
  63. tme/parser.py +114 -33
  64. tme/preprocessor.py +6 -5
  65. tme/rotations.py +10 -7
  66. tme/structure.py +4 -3
  67. pytme-0.2.9.post1.data/scripts/estimate_ram_usage.py +0 -97
  68. pytme-0.2.9.post1.dist-info/RECORD +0 -119
  69. pytme-0.2.9.post1.dist-info/licenses/LICENSE +0 -153
  70. scripts/estimate_ram_usage.py +0 -97
  71. tests/data/Maps/.DS_Store +0 -0
  72. tests/data/Structures/.DS_Store +0 -0
  73. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/WHEEL +0 -0
  74. {pytme-0.2.9.post1.dist-info → pytme-0.3b0.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,769 @@
1
+ #!python3
2
+ """
3
+ PyTME Batch Runner - Refactored Core Classes
4
+ """
5
+ import re
6
+ import argparse
7
+ import subprocess
8
+ from abc import ABC, abstractmethod
9
+
10
+ from pathlib import Path
11
+ from dataclasses import dataclass
12
+ from typing import Dict, List, Optional, Any
13
+
14
+ from tme.backends import backend as be
15
+ from tme.cli import print_entry, print_block, sanitize_name
16
+
17
+
18
@dataclass
class TomoFiles:
    """All on-disk inputs belonging to one tomogram."""

    #: Tomogram identifier.
    tomo_id: str
    #: Path to tomogram.
    tomogram: Path
    #: XML file with tilt angles, defocus, etc.
    metadata: Path
    #: Path to tomogram mask, optional.
    mask: Optional[Path] = None

    def __post_init__(self):
        """Check that every referenced file is present on disk."""
        # Tomogram and metadata are mandatory; probe both the same way.
        for label, location in (("Tomogram", self.tomogram), ("Metadata", self.metadata)):
            if not location.exists():
                raise FileNotFoundError(f"{label} not found: {location}")
        # The mask is optional and only validated when one was supplied.
        if self.mask and not self.mask.exists():
            raise FileNotFoundError(f"Mask not found: {self.mask}")
39
+
40
+
41
class TomoDatasetDiscovery:
    """Find and match tomogram files using glob patterns."""

    def __init__(
        self,
        mrc_pattern: str,
        metadata_pattern: str,
        mask_pattern: Optional[str] = None,
    ):
        """
        Initialize with glob patterns for file discovery.

        Parameters
        ----------
        mrc_pattern: str
            Glob pattern for tomogram files, e.g., "/data/tomograms/*.mrc"
        metadata_pattern: str
            Glob pattern for metadata files, e.g., "/data/metadata/*.xml"
        mask_pattern: str
            Optional glob pattern for mask files, e.g., "/data/masks/*.mrc"
        """
        self.mrc_pattern = mrc_pattern
        self.metadata_pattern = metadata_pattern
        self.mask_pattern = mask_pattern

    @staticmethod
    def parse_id_from_filename(filename: str) -> str:
        """Extract the tomogram ID from filename by removing technical suffixes."""
        base = Path(filename).stem
        # Remove technical suffixes (pixel size, binning, filtering info).
        # Examples: "_10.00Apx", "_4.00Apx", "_bin4", "_dose_filt".
        # BUGFIX: digits are now optional (\d* instead of \d+) so bare
        # "_bin4"/"_dose_filt" suffixes — documented above — are stripped too.
        base = re.sub(r"_\d*(\.\d+)?(Apx|bin\d*|dose_filt)$", "", base)

        # Remove common organizational prefixes if they exist
        for prefix in ["rec_Position_", "Position_", "rec_", "tomo_"]:
            if base.startswith(prefix):
                base = base[len(prefix):]
                break
        return base

    def _create_mapping_table(self, pattern: str) -> Dict:
        """Create a mapping table between tomogram ids and file paths.

        Returns an id -> Path dict; with a None pattern an empty dict is
        returned so callers can treat missing patterns uniformly.
        """
        if pattern is None:
            return {}

        ret = {}
        path = Path(pattern).absolute()
        for file in Path(path.parent).glob(path.name):
            file_id = self.parse_id_from_filename(file.name)
            ret.setdefault(file_id, []).append(file)

        # Deduplicate, keeping the first hit; print so the user sees the choice.
        for key, value in ret.items():
            if len(value) > 1:
                print(f"Found id {key} multiple times at {value}. Using {value[0]}.")
            ret[key] = value[0]
        return ret

    def discover_tomograms(
        self, tomo_list: Optional[List[str]] = None, require_mask: bool = False
    ) -> List["TomoFiles"]:
        """Find all matching tomogram files.

        Parameters
        ----------
        tomo_list : list of str, optional
            Restrict discovery to these tomogram IDs.
        require_mask : bool, optional
            Skip tomograms without a matching mask, defaults to False.

        Returns
        -------
        list of TomoFiles
            One entry per tomogram that has matching metadata (and a mask,
            when require_mask is set).
        """
        mrc_files = self._create_mapping_table(self.mrc_pattern)
        meta_files = self._create_mapping_table(self.metadata_pattern)
        mask_files = self._create_mapping_table(self.mask_pattern)

        if tomo_list:
            mrc_files = {k: v for k, v in mrc_files.items() if k in tomo_list}
            meta_files = {k: v for k, v in meta_files.items() if k in tomo_list}
            mask_files = {k: v for k, v in mask_files.items() if k in tomo_list}

        tomo_files = []
        for key, value in mrc_files.items():
            if key not in meta_files:
                print(f"No metadata for {key}, skipping it for now.")
                continue
            # BUGFIX: require_mask was previously accepted but ignored.
            if require_mask and key not in mask_files:
                print(f"No mask for {key}, skipping it for now.")
                continue

            tomo_files.append(
                TomoFiles(
                    tomo_id=key,
                    tomogram=value.absolute(),
                    metadata=meta_files[key].absolute(),
                    mask=mask_files.get(key),
                )
            )
        return tomo_files
130
+
131
+
132
@dataclass
class TMParameters:
    """Template matching parameters.

    Mirrors the CLI of ``match_template``: fields are translated into
    ``--key value`` arguments by :meth:`to_command_args` and into bare
    boolean ``--flag`` switches by :meth:`get_flags`. Paths are resolved
    and values validated in ``__post_init__``.
    """

    # Template volume/structure and optional mask (resolved to absolute paths).
    template: Path
    template_mask: Optional[Path] = None

    # Angular sampling (auto-calculated or explicit). Exactly one of
    # angular_sampling / particle_diameter / cone_angle drives the choice
    # in to_command_args; otherwise a default of 15.0 degrees is used.
    angular_sampling: Optional[float] = None
    particle_diameter: Optional[float] = None
    cone_angle: Optional[float] = None
    cone_sampling: Optional[float] = None
    axis_angle: float = 360.0
    axis_sampling: Optional[float] = None
    axis_symmetry: int = 1
    cone_axis: int = 2
    invert_cone: bool = False
    no_use_optimized_set: bool = False

    # Microscope parameters
    acceleration_voltage: float = 300.0  # kV
    spherical_aberration: float = 2.7e7  # Å
    amplitude_contrast: float = 0.07
    defocus: Optional[float] = None  # Å
    phase_shift: float = 0.0  # Dg

    # Processing options
    lowpass: Optional[float] = None  # Å
    highpass: Optional[float] = None  # Å
    pass_format: str = "sampling_rate"  # "sampling_rate", "voxel", "frequency"
    no_pass_smooth: bool = True
    interpolation_order: int = 3
    score_threshold: float = 0.0
    score: str = "FLCSphericalMask"

    # Weighting and correction
    tilt_weighting: Optional[str] = None  # "angle", "relion", "grigorieff"
    wedge_axes: str = "2,0"
    whiten_spectrum: bool = False
    scramble_phases: bool = False
    invert_target_contrast: bool = False

    # CTF parameters
    ctf_file: Optional[Path] = None
    no_flip_phase: bool = True
    correct_defocus_gradient: bool = False

    # Performance options
    centering: bool = False
    pad_edges: bool = False
    pad_filter: bool = False
    use_mixed_precision: bool = False
    use_memmap: bool = False

    # Analysis options
    peak_calling: bool = False
    num_peaks: int = 1000

    # Backend selection
    backend: str = "numpy"
    gpu_indices: Optional[str] = None

    # Reconstruction
    reconstruction_filter: str = "ramp"
    reconstruction_interpolation_order: int = 1
    no_filter_target: bool = False

    def __post_init__(self):
        """Resolve paths to absolute form and validate parameter values.

        Raises
        ------
        FileNotFoundError
            If template, template mask, or CTF file does not exist.
        ValueError
            If tilt weighting, pass format, or backend is unrecognized.
        """
        self.template = self.template.absolute()
        if self.template_mask:
            self.template_mask = self.template_mask.absolute()

        if not self.template.exists():
            raise FileNotFoundError(f"Template not found: {self.template}")
        if self.template_mask and not self.template_mask.exists():
            raise FileNotFoundError(f"Template mask not found: {self.template_mask}")
        if self.ctf_file and not self.ctf_file.exists():
            raise FileNotFoundError(f"CTF file not found: {self.ctf_file}")

        if self.tilt_weighting and self.tilt_weighting not in [
            "angle",
            "relion",
            "grigorieff",
        ]:
            raise ValueError(f"Invalid tilt weighting: {self.tilt_weighting}")

        if self.pass_format not in ["sampling_rate", "voxel", "frequency"]:
            raise ValueError(f"Invalid pass format: {self.pass_format}")

        # Valid backends come from the pyTME backend registry at runtime.
        valid_backends = list(be._BACKEND_REGISTRY.keys())
        if self.backend not in valid_backends:
            raise ValueError(
                f"Invalid backend: {self.backend}. Choose from {valid_backends}"
            )

    def to_command_args(
        self, tomo_files: TomoFiles, output_path: Path
    ) -> Dict[str, Any]:
        """Convert parameters to pyTME command arguments.

        Parameters
        ----------
        tomo_files : TomoFiles
            Input files for a single tomogram; its metadata file doubles
            as both the CTF file and the tilt-angle file.
        output_path : Path
            Destination for the matching results.

        Returns
        -------
        dict
            Mapping of ``match_template`` option names (without leading
            dashes) to their values. Boolean switches are handled
            separately by :meth:`get_flags`.
        """
        args = {
            "target": str(tomo_files.tomogram),
            "template": str(self.template),
            "output": str(output_path),
            "acceleration-voltage": self.acceleration_voltage,
            "spherical-aberration": self.spherical_aberration,
            "amplitude-contrast": self.amplitude_contrast,
            "interpolation-order": self.interpolation_order,
            "wedge-axes": self.wedge_axes,
            "score-threshold": self.score_threshold,
            "score": self.score,
            "pass-format": self.pass_format,
            "reconstruction-filter": self.reconstruction_filter,
            "reconstruction-interpolation-order": self.reconstruction_interpolation_order,
        }

        # Optional file arguments
        if self.template_mask:
            args["template-mask"] = str(self.template_mask)
        if tomo_files.mask:
            args["target-mask"] = str(tomo_files.mask)
        if tomo_files.metadata:
            # The metadata file is reused for both CTF and tilt information.
            args["ctf-file"] = str(tomo_files.metadata)
            args["tilt-angles"] = str(tomo_files.metadata)

        # Optional parameters.
        # NOTE(review): truthiness checks — a value of exactly 0 (e.g. a
        # defocus of 0.0) is treated as unset and omitted; confirm intended.
        if self.lowpass:
            args["lowpass"] = self.lowpass
        if self.highpass:
            args["highpass"] = self.highpass
        if self.tilt_weighting:
            args["tilt-weighting"] = self.tilt_weighting
        if self.defocus:
            args["defocus"] = self.defocus
        if self.phase_shift != 0:
            args["phase-shift"] = self.phase_shift
        if self.gpu_indices:
            args["gpu-indices"] = self.gpu_indices
        if self.backend != "numpy":
            args["backend"] = self.backend

        # Angular sampling: explicit sampling wins, then particle diameter,
        # then a cone description; otherwise fall back to 15 degrees.
        if self.angular_sampling:
            args["angular-sampling"] = self.angular_sampling
        elif self.particle_diameter:
            args["particle-diameter"] = self.particle_diameter
        elif self.cone_angle:
            args["cone-angle"] = self.cone_angle
            if self.cone_sampling:
                args["cone-sampling"] = self.cone_sampling
            if self.axis_sampling:
                args["axis-sampling"] = self.axis_sampling
            if self.axis_angle != 360.0:
                args["axis-angle"] = self.axis_angle
            if self.axis_symmetry != 1:
                args["axis-symmetry"] = self.axis_symmetry
            if self.cone_axis != 2:
                args["cone-axis"] = self.cone_axis
        else:
            # Default fallback
            args["angular-sampling"] = 15.0

        args["num-peaks"] = self.num_peaks
        return args

    def get_flags(self) -> List[str]:
        """Get boolean flags for pyTME command.

        Returns
        -------
        list of str
            Flag names (without leading dashes) to append to the command.
        """
        flags = []
        if self.whiten_spectrum:
            flags.append("whiten-spectrum")
        if self.scramble_phases:
            flags.append("scramble-phases")
        if self.invert_target_contrast:
            flags.append("invert-target-contrast")
        if self.centering:
            flags.append("centering")
        if self.pad_edges:
            flags.append("pad-edges")
        if self.pad_filter:
            flags.append("pad-filter")
        # NOTE(review): double negative — the "no-pass-smooth" flag is
        # emitted only when no_pass_smooth is False (same for
        # no_flip_phase below); confirm against the match_template CLI.
        if not self.no_pass_smooth:
            flags.append("no-pass-smooth")
        if self.use_mixed_precision:
            flags.append("use-mixed-precision")
        if self.use_memmap:
            flags.append("use-memmap")
        if self.peak_calling:
            flags.append("peak-calling")
        if not self.no_flip_phase:
            flags.append("no-flip-phase")
        if self.correct_defocus_gradient:
            flags.append("correct-defocus-gradient")
        if self.invert_cone:
            flags.append("invert-cone")
        if self.no_use_optimized_set:
            flags.append("no-use-optimized-set")
        if self.no_filter_target:
            flags.append("no-filter-target")
        return flags
331
+
332
+
333
@dataclass
class ComputeResources:
    """Resource request for a single cluster job."""

    # Defaults describe a modest single-node CPU allocation.
    cpus: int = 4
    memory_gb: int = 128
    gpu_count: int = 0
    gpu_type: Optional[str] = None  # e.g., "3090", "A100"
    time_limit: str = "05:00:00"
    partition: str = "gpu-el8"
    constraint: Optional[str] = None
    qos: str = "normal"

    def to_slurm_args(self) -> Dict[str, str]:
        """Translate the resource request into sbatch keyword arguments."""
        request = dict(
            [
                ("ntasks", "1"),
                ("nodes", "1"),
                ("ntasks-per-node", "1"),
                ("cpus-per-task", str(self.cpus)),
                ("mem", f"{self.memory_gb}G"),
                ("time", self.time_limit),
                ("partition", self.partition),
                ("qos", self.qos),
                ("export", "none"),
            ]
        )

        if self.gpu_count > 0:
            request["gres"] = f"gpu:{self.gpu_count}"
            # A concrete GPU model is expressed as a node constraint.
            if self.gpu_type:
                request["constraint"] = f"gpu={self.gpu_type}"

        # A generic constraint applies only when no GPU model was chosen.
        if self.constraint and not self.gpu_type:
            request["constraint"] = self.constraint

        return request
369
+
370
+
371
@dataclass
class TemplateMatchingTask:
    """Everything needed to run template matching on one tomogram."""

    # Input files, matching parameters, and cluster resources for one run.
    tomo_files: TomoFiles
    parameters: TMParameters
    resources: ComputeResources
    output_dir: Path

    @property
    def tomo_id(self) -> str:
        """Identifier of the underlying tomogram."""
        return self.tomo_files.tomo_id

    @property
    def output_file(self) -> Path:
        """Pickle file the matching results will be written to."""
        return self.output_dir / f"{self.tomo_id}.pickle"

    def create_output_dir(self) -> None:
        """Create the output directory (including parents) if missing."""
        self.output_dir.mkdir(parents=True, exist_ok=True)
391
+
392
+
393
class ExecutionBackend(ABC):
    """Interface that every job-submission backend must implement."""

    @abstractmethod
    def submit_job(self, task) -> str:
        """Submit one task; return a job ID or status string."""
        ...

    @abstractmethod
    def submit_jobs(self, tasks: List) -> List[str]:
        """Submit several tasks; return one ID/status string per task."""
        ...
405
+
406
+
407
class SlurmBackend(ExecutionBackend):
    """Execution backend that submits jobs to a SLURM cluster via sbatch."""

    def __init__(
        self,
        force: bool = True,
        dry_run: bool = False,
        script_dir: Path = Path("./slurm_scripts"),
        environment_setup: str = "module load pyTME",
    ):
        """
        Initialize SLURM backend.

        Parameters
        ----------
        force : bool, optional
            Rerun completed jobs, defaults to True.
        dry_run : bool, optional
            Generate scripts but do not submit, defaults to False.
        script_dir: str, optional
            Directory to save generated scripts, defaults to ./slurm_scripts,
        environment_setup : str, optional
            Command to set up pyTME environment, defaults to module load pyTME.
        """
        self.dry_run = dry_run
        self.force = force
        self.environment_setup = environment_setup
        # Fall back to the default directory if a falsy value slipped through.
        self.script_dir = Path(script_dir) if script_dir else Path("./slurm_scripts")
        self.script_dir.mkdir(exist_ok=True, parents=True)

    def create_sbatch_script(self, task) -> Path:
        """Write an sbatch script for *task* and return its path."""
        script_path = self.script_dir / f"pytme_{task.tomo_id}.sh"

        # Results directory must exist before SLURM tries to write logs there.
        task.create_output_dir()

        directives = task.resources.to_slurm_args()
        directives["output"] = f"{task.output_dir}/{task.tomo_id}_%j.out"
        directives["error"] = f"{task.output_dir}/{task.tomo_id}_%j.err"
        directives["job-name"] = f"pytme_{task.tomo_id}"
        directives["chdir"] = str(task.output_dir)

        lines = ["#!/bin/bash", "", "# SLURM directives"]
        lines += [f"#SBATCH --{key}={val}" for key, val in directives.items()]
        lines += [
            "",
            "# Environment setup",
            # Semicolon-separated setup commands become one line each.
            "\n".join(self.environment_setup.split(";")),
            "",
            "# Run template matching",
        ]

        # Assemble the match_template invocation, one option per line.
        pieces = ["match_template"]
        arg_map = task.parameters.to_command_args(task.tomo_files, task.output_file)
        pieces += [f"--{key} {val}" for key, val in arg_map.items()]
        pieces += [f"--{flag}" for flag in task.parameters.get_flags()]
        lines.append(" \\\n ".join(pieces))

        with open(script_path, "w") as handle:
            handle.write("\n".join(lines) + "\n")
        script_path.chmod(0o755)

        print(f"Generated SLURM script: {Path(script_path).name}")
        return script_path

    def submit_job(self, task) -> str:
        """Generate the script for *task* and hand it to sbatch."""
        script_path = self.create_sbatch_script(task)

        if self.dry_run:
            return f"DRY_RUN:{script_path}"

        try:
            if Path(task.output_file).exists() and not self.force:
                return "ERROR: File exists and force was not requested."

            completed = subprocess.run(
                ["sbatch", str(script_path)], capture_output=True, text=True, check=True
            )

            # sbatch reports "Submitted batch job <id>"; the id is the last token.
            job_id = completed.stdout.strip().split()[-1]
            print(f"Submitted job {job_id} for {task.tomo_id}")
            return job_id

        except subprocess.CalledProcessError as e:
            error_msg = f"Failed to submit {script_path}: {e.stderr}"
            return f"ERROR:{error_msg}"
        except Exception as e:
            error_msg = f"Submission error for {script_path}: {e}"
            return f"ERROR:{error_msg}"

    def submit_jobs(self, tasks: List) -> List[str]:
        """Submit every task in order and collect the resulting job IDs."""
        return [self.submit_job(task) for task in tasks]
521
+
522
+
523
def parse_args():
    """Parse command line arguments for the batch runner.

    Returns
    -------
    argparse.Namespace
        Parsed arguments; ``tomo_list`` is replaced by the list of IDs read
        from the file (if given), and ``output_dir``/``script_dir`` are
        made absolute.
    """
    parser = argparse.ArgumentParser(
        description="Batch runner for match_template.py",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    # Input file patterns and the template to match.
    input_group = parser.add_argument_group("Input Files")
    input_group.add_argument(
        "--tomograms",
        required=True,
        help="Glob pattern for tomogram files (e.g., '/data/tomograms/*.mrc')",
    )
    input_group.add_argument(
        "--metadata",
        required=True,
        help="Glob pattern for metadata files (e.g., '/data/metadata/*.xml')",
    )
    input_group.add_argument(
        "--masks", help="Glob pattern for mask files (e.g., '/data/masks/*.mrc')"
    )
    input_group.add_argument(
        "--template", required=True, type=Path, help="Template file (MRC, PDB, etc.)"
    )
    input_group.add_argument("--template-mask", type=Path, help="Template mask file")
    input_group.add_argument(
        "--tomo-list",
        type=Path,
        help="File with list of tomogram IDs to process (one per line)",
    )

    # Template matching: angular sampling is either explicit or derived
    # from the particle diameter, hence mutually exclusive.
    tm_group = parser.add_argument_group("Template Matching")
    angular_group = tm_group.add_mutually_exclusive_group()
    angular_group.add_argument(
        "--angular-sampling", type=float, help="Angular sampling in degrees"
    )
    angular_group.add_argument(
        "--particle-diameter",
        type=float,
        help="Particle diameter in units of sampling rate (typically Ångstrom)",
    )

    tm_group.add_argument(
        "--score",
        default="FLCSphericalMask",
        help="Template matching scoring function. Use FLC if mask is not spherical.",
    )
    tm_group.add_argument(
        "--score-threshold", type=float, default=0.0, help="Minimum score threshold"
    )

    # Microscope parameters (units converted downstream in main()).
    scope_group = parser.add_argument_group("Microscope Parameters")
    scope_group.add_argument(
        "--voltage", type=float, default=300.0, help="Acceleration voltage in kV"
    )
    scope_group.add_argument(
        "--spherical-aberration",
        type=float,
        default=2.7,
        help="Spherical aberration in mm",
    )
    scope_group.add_argument(
        "--amplitude-contrast", type=float, default=0.07, help="Amplitude contrast"
    )

    # Processing options forwarded to match_template.
    proc_group = parser.add_argument_group("Processing Options")
    proc_group.add_argument(
        "--lowpass",
        type=float,
        help="Lowpass filter in units of sampling rate (typically Ångstrom).",
    )
    proc_group.add_argument(
        "--highpass",
        type=float,
        help="Highpass filter in units of sampling rate (typically Ångstrom).",
    )
    proc_group.add_argument(
        "--tilt-weighting",
        choices=["angle", "relion", "grigorieff"],
        help="Tilt weighting scheme",
    )
    # NOTE(review): CLI default is "cupy" while TMParameters defaults to
    # "numpy" — intentional for cluster GPU use, but worth confirming.
    proc_group.add_argument(
        "--backend",
        default="cupy",
        choices=list(be._BACKEND_REGISTRY.keys()),
        help="Computation backend",
    )
    proc_group.add_argument(
        "--whiten-spectrum", action="store_true", help="Apply spectral whitening"
    )
    proc_group.add_argument(
        "--scramble-phases",
        action="store_true",
        help="Scramble template phases for noise estimation",
    )

    # Per-job SLURM resource requests.
    compute_group = parser.add_argument_group("Compute Resources")
    compute_group.add_argument(
        "--cpus", type=int, default=4, help="Number of CPUs per job"
    )
    compute_group.add_argument(
        "--memory", type=int, default=64, help="Memory per job in GB"
    )
    compute_group.add_argument(
        "--gpu-count", type=int, default=1, help="Number of GPUs per job"
    )
    compute_group.add_argument(
        "--gpu-type", default="3090", help="GPU type constraint (e.g., '3090', 'A100')"
    )
    compute_group.add_argument(
        "--time-limit", default="05:00:00", help="Time limit (HH:MM:SS)"
    )
    compute_group.add_argument("--partition", default="gpu-el8", help="SLURM partition")

    # Job submission behavior and output locations.
    job_group = parser.add_argument_group("Job Submission")
    job_group.add_argument(
        "--output-dir",
        type=Path,
        default=Path("./batch_results"),
        help="Output directory for results",
    )
    job_group.add_argument(
        "--script-dir",
        type=Path,
        default=Path("./slurm_scripts"),
        help="Directory for generated SLURM scripts",
    )
    job_group.add_argument(
        "--environment-setup",
        default="module load pyTME",
        help="Command(s) to set up pyTME environment",
    )
    job_group.add_argument(
        "--dry-run", action="store_true", help="Generate scripts but do not submit jobs"
    )
    job_group.add_argument("--force", action="store_true", help="Rerun completed jobs")
    args = parser.parse_args()

    # Expand the tomo-list file into a list of IDs, skipping blank lines.
    if args.tomo_list is not None:
        with open(args.tomo_list, mode="r") as f:
            args.tomo_list = [line.strip() for line in f if line.strip()]

    # Absolute paths so generated scripts work regardless of submit cwd.
    args.output_dir = args.output_dir.absolute()
    args.script_dir = args.script_dir.absolute()

    return args
668
+
669
+
670
def main():
    """Discover tomograms, build tasks, and submit them via SLURM.

    Workflow: glob-based dataset discovery, parameter/resource assembly,
    per-tomogram task creation, then script generation and (unless
    --dry-run) sbatch submission. Errors are printed, not raised.
    """
    print_entry()

    args = parse_args()
    try:
        discovery = TomoDatasetDiscovery(
            mrc_pattern=args.tomograms,
            metadata_pattern=args.metadata,
            mask_pattern=args.masks,
        )
        tomo_files = discovery.discover_tomograms(tomo_list=args.tomo_list)
        print_block(
            name="Discovering Dataset",
            data={
                "Tomogram Pattern": args.tomograms,
                "Metadata Pattern": args.metadata,
                "Mask Pattern": args.masks,
                "Valid Runs": len(tomo_files),
            },
            label_width=30,
        )
        if not tomo_files:
            print("No tomograms found! Check your patterns.")
            return

        params = TMParameters(
            template=args.template,
            template_mask=args.template_mask,
            angular_sampling=args.angular_sampling,
            particle_diameter=args.particle_diameter,
            score=args.score,
            score_threshold=args.score_threshold,
            # kV -> V (same factor as keV -> eV).
            # NOTE(review): TMParameters documents this field in kV with a
            # default of 300.0 — confirm which unit match_template expects.
            acceleration_voltage=args.voltage * 1e3,
            spherical_aberration=args.spherical_aberration * 1e7,  # Convert mm to Å
            amplitude_contrast=args.amplitude_contrast,
            lowpass=args.lowpass,
            highpass=args.highpass,
            tilt_weighting=args.tilt_weighting,
            backend=args.backend,
            whiten_spectrum=args.whiten_spectrum,
            scramble_phases=args.scramble_phases,
        )
        # Render the effective parameters once, using the first tomogram as
        # a stand-in target; target/output are per-task and thus dropped.
        print_params = params.to_command_args(tomo_files[0], "")
        _ = print_params.pop("target")
        _ = print_params.pop("output")
        print_params.update({k: True for k in params.get_flags()})
        print_params = {
            sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
        }
        print_block(name="Matching Parameters", data=print_params, label_width=30)
        print("\n" + "-" * 80)

        resources = ComputeResources(
            cpus=args.cpus,
            memory_gb=args.memory,
            gpu_count=args.gpu_count,
            gpu_type=args.gpu_type,
            time_limit=args.time_limit,
            partition=args.partition,
        )
        print_params = resources.to_slurm_args()
        print_params = {
            sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
        }
        print_block(name="Compute Resources", data=print_params, label_width=30)
        print("\n" + "-" * 80 + "\n")

        # One task per tomogram; all share the same parameters and resources.
        tasks = []
        for tomo_file in tomo_files:
            task = TemplateMatchingTask(
                tomo_files=tomo_file,
                parameters=params,
                resources=resources,
                output_dir=args.output_dir,
            )
            tasks.append(task)

        backend = SlurmBackend(
            force=args.force,
            dry_run=args.dry_run,
            script_dir=args.script_dir,
            environment_setup=args.environment_setup,
        )
        job_ids = backend.submit_jobs(tasks)
        if args.dry_run:
            print(
                f"\nDry run complete. Generated {len(tasks)} scripts in {args.script_dir}"
            )
        else:
            # Submission failures are reported as "ERROR..." strings.
            successful_jobs = [j for j in job_ids if not j.startswith("ERROR")]
            print(f"\nSubmitted {len(successful_jobs)} jobs successfully.")
            if successful_jobs:
                print(f"Job IDs:\n{','.join(successful_jobs).strip()}")

    except Exception as e:
        # Best-effort CLI: report and exit normally rather than traceback.
        print(f"Error: {e}")


if __name__ == "__main__":
    main()