pytme-0.3b0-cp311-cp311-macosx_15_0_arm64.whl → pytme-0.3.1-cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +1 -5
  2. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/match_template.py +177 -226
  3. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/postprocess.py +69 -47
  4. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocess.py +10 -23
  5. {pytme-0.3b0.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +98 -28
  6. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  7. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/METADATA +15 -15
  8. pytme-0.3.1.dist-info/RECORD +133 -0
  9. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +1 -0
  10. pytme-0.3.1.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +1 -5
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +118 -99
  14. scripts/match_template.py +177 -226
  15. scripts/match_template_filters.py +1200 -0
  16. scripts/postprocess.py +69 -47
  17. scripts/preprocess.py +10 -23
  18. scripts/preprocessor_gui.py +98 -28
  19. scripts/pytme_runner.py +1223 -0
  20. scripts/refine_matches.py +156 -387
  21. tests/data/.DS_Store +0 -0
  22. tests/data/Blurring/.DS_Store +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Raw/.DS_Store +0 -0
  25. tests/data/Structures/.DS_Store +0 -0
  26. tests/preprocessing/test_frequency_filters.py +19 -10
  27. tests/preprocessing/test_utils.py +18 -0
  28. tests/test_analyzer.py +122 -122
  29. tests/test_backends.py +4 -9
  30. tests/test_density.py +0 -1
  31. tests/test_matching_cli.py +30 -30
  32. tests/test_matching_data.py +5 -5
  33. tests/test_matching_utils.py +11 -61
  34. tests/test_rotations.py +1 -1
  35. tme/__version__.py +1 -1
  36. tme/analyzer/__init__.py +1 -1
  37. tme/analyzer/_utils.py +5 -8
  38. tme/analyzer/aggregation.py +28 -9
  39. tme/analyzer/base.py +25 -36
  40. tme/analyzer/peaks.py +49 -122
  41. tme/analyzer/proxy.py +1 -0
  42. tme/backends/_jax_utils.py +31 -28
  43. tme/backends/_numpyfftw_utils.py +270 -0
  44. tme/backends/cupy_backend.py +11 -54
  45. tme/backends/jax_backend.py +72 -48
  46. tme/backends/matching_backend.py +6 -51
  47. tme/backends/mlx_backend.py +1 -27
  48. tme/backends/npfftw_backend.py +95 -90
  49. tme/backends/pytorch_backend.py +5 -26
  50. tme/density.py +7 -10
  51. tme/extensions.cpython-311-darwin.so +0 -0
  52. tme/filters/__init__.py +2 -2
  53. tme/filters/_utils.py +32 -7
  54. tme/filters/bandpass.py +225 -186
  55. tme/filters/ctf.py +138 -87
  56. tme/filters/reconstruction.py +38 -9
  57. tme/filters/wedge.py +98 -112
  58. tme/filters/whitening.py +1 -6
  59. tme/mask.py +341 -0
  60. tme/matching_data.py +20 -44
  61. tme/matching_exhaustive.py +46 -56
  62. tme/matching_optimization.py +2 -1
  63. tme/matching_scores.py +216 -412
  64. tme/matching_utils.py +82 -424
  65. tme/memory.py +1 -1
  66. tme/orientations.py +16 -8
  67. tme/parser.py +109 -29
  68. tme/preprocessor.py +2 -2
  69. tme/rotations.py +1 -1
  70. pytme-0.3b0.dist-info/RECORD +0 -122
  71. pytme-0.3b0.dist-info/licenses/LICENSE +0 -153
  72. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  73. {pytme-0.3b0.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1223 @@
+ #!python
+ """
+ PyTME Batch Runner - dataset discovery, parameter handling, and SLURM job submission.
+ """
+ import re
+ import argparse
+ import subprocess
+ from abc import ABC, abstractmethod
+
+ from pathlib import Path
+ from dataclasses import dataclass
+ from typing import Dict, List, Optional, Any
+
+ from tme.backends import backend as be
+ from tme.cli import print_entry, print_block, sanitize_name
+
+
+ @dataclass
+ class TomoFiles:
+     """Container for all files related to a single tomogram."""
+
+     #: Tomogram identifier.
+     tomo_id: str
+     #: Path to tomogram.
+     tomogram: Path
+     #: XML file with tilt angles, defocus, etc.
+     metadata: Path
+     #: Path to tomogram mask, optional.
+     mask: Optional[Path] = None
+
+     def __post_init__(self):
+         """Validate that required files exist."""
+         if not self.tomogram.exists():
+             raise FileNotFoundError(f"Tomogram not found: {self.tomogram}")
+         if not self.metadata.exists():
+             raise FileNotFoundError(f"Metadata not found: {self.metadata}")
+         if self.mask and not self.mask.exists():
+             raise FileNotFoundError(f"Mask not found: {self.mask}")
+
+
+ @dataclass
+ class AnalysisFiles:
+     """Container for files related to analysis of a single tomogram."""
+
+     #: Tomogram identifier.
+     tomo_id: str
+     #: List of TM pickle result files for this tomo_id.
+     input_files: List[Path]
+     #: Background pickle files for normalization (optional).
+     background_files: Optional[List[Path]] = None
+     #: Target mask file (optional).
+     mask: Optional[Path] = None
+
+     def __post_init__(self):
+         """Validate that required files exist."""
+         for input_file in self.input_files:
+             if not input_file.exists():
+                 raise FileNotFoundError(f"Input file not found: {input_file}")
+
+         if self.background_files:
+             for bg_file in self.background_files:
+                 if not bg_file.exists():
+                     raise FileNotFoundError(f"Background file not found: {bg_file}")
+
+         if self.mask and not self.mask.exists():
+             raise FileNotFoundError(f"Mask not found: {self.mask}")
+
+
+ class DatasetDiscovery(ABC):
+     """Base class for dataset discovery using glob patterns."""
+
+     @abstractmethod
+     def discover(self, tomo_list: Optional[List[str]] = None) -> List:
+         pass
+
+     @staticmethod
+     def parse_id_from_filename(filename: str) -> str:
+         """Extract the tomogram ID from a filename by removing technical suffixes."""
+         base = Path(filename).stem
+         # Remove technical suffixes (pixel size, binning, filtering info)
+         # Examples: "_10.00Apx", "_4.00Apx", "_bin4", "_dose_filt"
+         base = re.sub(r"_(\d+(\.\d+)?Apx|bin\d+|dose_filt)$", "", base)
+
+         # Remove common organizational prefixes if they exist
+         for prefix in ["rec_Position_", "Position_", "rec_", "tomo_"]:
+             if base.startswith(prefix):
+                 base = base[len(prefix) :]
+                 break
+         return base
+
+     def create_mapping_table(self, pattern: str) -> Dict[str, List[Path]]:
+         """Create a mapping table between tomogram ids and file paths."""
+         if pattern is None:
+             return {}
+
+         ret = {}
+         path = Path(pattern).absolute()
+         for file in list(Path(path.parent).glob(path.name)):
+             file_id = self.parse_id_from_filename(file.name)
+             if file_id not in ret:
+                 ret[file_id] = []
+             ret[file_id].append(file)
+
+         return ret
+
+
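As a quick sanity check of the ID normalization above, a sketch (assumes it is run where pytme_runner.py is importable, e.g. from the scripts/ directory; the filenames are made up):

    from pytme_runner import DatasetDiscovery

    # parse_id_from_filename is a @staticmethod, so no discovery instance is needed.
    names = ["rec_Position_12_10.00Apx.mrc", "tomo_03_bin4.mrc", "Position_7_dose_filt.mrc"]
    for name in names:
        print(DatasetDiscovery.parse_id_from_filename(name))
    # -> 12, 03, 7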
+ @dataclass
+ class TomoDatasetDiscovery(DatasetDiscovery):
+     """Find and match tomogram files using glob patterns."""
+
+     #: Glob pattern for tomogram files, e.g., "/data/tomograms/*.mrc"
+     mrc_pattern: str
+     #: Glob pattern for metadata files, e.g., "/data/metadata/*.xml"
+     metadata_pattern: str
+     #: Optional glob pattern for mask files, e.g., "/data/masks/*.mrc"
+     mask_pattern: Optional[str] = None
+
+     def discover(self, tomo_list: Optional[List[str]] = None) -> List[TomoFiles]:
+         """Find all matching tomogram files."""
+         mrc_files = self.create_mapping_table(self.mrc_pattern)
+         meta_files = self.create_mapping_table(self.metadata_pattern)
+         mask_files = self.create_mapping_table(self.mask_pattern)
+
+         if tomo_list:
+             mrc_files = {k: v for k, v in mrc_files.items() if k in tomo_list}
+             meta_files = {k: v for k, v in meta_files.items() if k in tomo_list}
+             mask_files = {k: v for k, v in mask_files.items() if k in tomo_list}
+
+         tomo_files = []
+         for key, value in mrc_files.items():
+             if key not in meta_files:
+                 print(f"No metadata for {key}, skipping it.")
+                 continue
+
+             tomo_files.append(
+                 TomoFiles(
+                     tomo_id=key,
+                     tomogram=value[0].absolute(),
+                     metadata=meta_files[key][0].absolute(),
+                     mask=mask_files.get(key, [None])[0],
+                 )
+             )
+         return tomo_files
+
+
+ @dataclass
+ class AnalysisDatasetDiscovery(DatasetDiscovery):
+     """Find and match analysis files using glob patterns."""
+
+     #: Glob patterns for TM pickle files, e.g., "/data/results/*.pickle"
+     input_patterns: List[str]
+     #: List of glob patterns for background files, e.g., ["/data/bg1/*.pickle", "/data/bg2/*.pickle"]
+     background_patterns: Optional[List[str]] = None
+     #: Glob pattern for target mask files, e.g., "/data/masks/*.mrc"
+     mask_patterns: Optional[str] = None
+
+     def __post_init__(self):
+         """Ensure patterns are lists."""
+         if isinstance(self.input_patterns, str):
+             self.input_patterns = [self.input_patterns]
+         if self.background_patterns and isinstance(self.background_patterns, str):
+             self.background_patterns = [self.background_patterns]
+
+     def discover(self, tomo_list: Optional[List[str]] = None) -> List[AnalysisFiles]:
+         """Find all matching analysis files."""
+
+         input_files_by_id = {}
+         for pattern in self.input_patterns:
+             files = self.create_mapping_table(pattern)
+             for tomo_id, file_list in files.items():
+                 if tomo_id not in input_files_by_id:
+                     input_files_by_id[tomo_id] = []
+                 input_files_by_id[tomo_id].extend(file_list)
+
+         background_files_by_id = {}
+         if self.background_patterns:
+             for pattern in self.background_patterns:
+                 bg_files = self.create_mapping_table(pattern)
+                 for tomo_id, file_list in bg_files.items():
+                     if tomo_id not in background_files_by_id:
+                         background_files_by_id[tomo_id] = []
+                     background_files_by_id[tomo_id].extend(file_list)
+
+         mask_files_by_id = {}
+         if self.mask_patterns:
+             mask_files_by_id = self.create_mapping_table(self.mask_patterns)
+
+         if tomo_list:
+             input_files_by_id = {
+                 k: v for k, v in input_files_by_id.items() if k in tomo_list
+             }
+             background_files_by_id = {
+                 k: v for k, v in background_files_by_id.items() if k in tomo_list
+             }
+             mask_files_by_id = {
+                 k: v for k, v in mask_files_by_id.items() if k in tomo_list
+             }
+
+         analysis_files = []
+         for tomo_id, input_file_list in input_files_by_id.items():
+             background_files = background_files_by_id.get(tomo_id, [])
+             mask_file = mask_files_by_id.get(tomo_id, [None])[0]
+
+             analysis_file = AnalysisFiles(
+                 tomo_id=tomo_id,
+                 input_files=[f.absolute() for f in input_file_list],
+                 background_files=(
+                     [f.absolute() for f in background_files] if background_files else []
+                 ),
+                 mask=mask_file.absolute() if mask_file else None,
+             )
+             analysis_files.append(analysis_file)
+
+         return analysis_files
+
+
+ @dataclass
+ class TMParameters:
+     """Template matching parameters."""
+
+     template: Path
+     template_mask: Optional[Path] = None
+
+     # Angular sampling (auto-calculated or explicit)
+     angular_sampling: Optional[float] = None
+     particle_diameter: Optional[float] = None
+     cone_angle: Optional[float] = None
+     cone_sampling: Optional[float] = None
+     axis_angle: float = 360.0
+     axis_sampling: Optional[float] = None
+     axis_symmetry: int = 1
+     cone_axis: int = 2
+     invert_cone: bool = False
+     no_use_optimized_set: bool = False
+
+     # Microscope parameters
+     acceleration_voltage: float = 300.0  # kV
+     spherical_aberration: float = 2.7e7  # Å
+     amplitude_contrast: float = 0.07
+     defocus: Optional[float] = None  # Å
+     phase_shift: float = 0.0  # degrees
+
+     # Processing options
+     lowpass: Optional[float] = None  # Å
+     highpass: Optional[float] = None  # Å
+     pass_format: str = "sampling_rate"  # "sampling_rate", "voxel", "frequency"
+     no_pass_smooth: bool = True
+     interpolation_order: int = 3
+     score_threshold: float = 0.0
+     score: str = "FLCSphericalMask"
+
+     # Weighting and correction
+     tilt_weighting: Optional[str] = None  # "angle", "relion", "grigorieff"
+     wedge_axes: str = "2,0"
+     whiten_spectrum: bool = False
+     scramble_phases: bool = False
+     invert_target_contrast: bool = False
+
+     # CTF parameters
+     ctf_file: Optional[Path] = None
+     no_flip_phase: bool = True
+     correct_defocus_gradient: bool = False
+
+     # Performance options
+     centering: bool = False
+     pad_edges: bool = False
+     pad_filter: bool = False
+     use_mixed_precision: bool = False
+     use_memmap: bool = False
+
+     # Analysis options
+     peak_calling: bool = False
+     num_peaks: int = 1000
+
+     # Backend selection
+     backend: str = "numpy"
+     gpu_indices: Optional[str] = None
+
+     # Reconstruction
+     reconstruction_filter: str = "ramp"
+     reconstruction_interpolation_order: int = 1
+     no_filter_target: bool = False
+
+     def __post_init__(self):
+         """Validate parameters and resolve paths."""
+         self.template = self.template.absolute()
+         if self.template_mask:
+             self.template_mask = self.template_mask.absolute()
+
+         if not self.template.exists():
+             raise FileNotFoundError(f"Template not found: {self.template}")
+         if self.template_mask and not self.template_mask.exists():
+             raise FileNotFoundError(f"Template mask not found: {self.template_mask}")
+         if self.ctf_file and not self.ctf_file.exists():
+             raise FileNotFoundError(f"CTF file not found: {self.ctf_file}")
+
+         if self.tilt_weighting and self.tilt_weighting not in [
+             "angle",
+             "relion",
+             "grigorieff",
+         ]:
+             raise ValueError(f"Invalid tilt weighting: {self.tilt_weighting}")
+
+         if self.pass_format not in ["sampling_rate", "voxel", "frequency"]:
+             raise ValueError(f"Invalid pass format: {self.pass_format}")
+
+         valid_backends = list(be._BACKEND_REGISTRY.keys())
+         if self.backend not in valid_backends:
+             raise ValueError(
+                 f"Invalid backend: {self.backend}. Choose from {valid_backends}"
+             )
+
+     def to_command_args(self, files: TomoFiles, output_path: Path) -> Dict[str, Any]:
+         """Convert parameters to pyTME command arguments."""
+         args = {
+             "target": str(files.tomogram),
+             "template": str(self.template),
+             "output": str(output_path),
+             "acceleration-voltage": self.acceleration_voltage,
+             "spherical-aberration": self.spherical_aberration,
+             "amplitude-contrast": self.amplitude_contrast,
+             "interpolation-order": self.interpolation_order,
+             "wedge-axes": self.wedge_axes,
+             "score-threshold": self.score_threshold,
+             "score": self.score,
+             "pass-format": self.pass_format,
+             "reconstruction-filter": self.reconstruction_filter,
+             "reconstruction-interpolation-order": self.reconstruction_interpolation_order,
+         }
+
+         # Optional file arguments
+         if self.template_mask:
+             args["template-mask"] = str(self.template_mask)
+         if files.mask:
+             args["target-mask"] = str(files.mask)
+         if files.metadata:
+             args["ctf-file"] = str(files.metadata)
+             args["tilt-angles"] = str(files.metadata)
+
+         # Optional parameters
+         if self.lowpass:
+             args["lowpass"] = self.lowpass
+         if self.highpass:
+             args["highpass"] = self.highpass
+         if self.tilt_weighting:
+             args["tilt-weighting"] = self.tilt_weighting
+         if self.defocus:
+             args["defocus"] = self.defocus
+         if self.phase_shift != 0:
+             args["phase-shift"] = self.phase_shift
+         if self.gpu_indices:
+             args["gpu-indices"] = self.gpu_indices
+         if self.backend != "numpy":
+             args["backend"] = self.backend
+
+         # Angular sampling
+         if self.angular_sampling:
+             args["angular-sampling"] = self.angular_sampling
+         elif self.particle_diameter:
+             args["particle-diameter"] = self.particle_diameter
+         elif self.cone_angle:
+             args["cone-angle"] = self.cone_angle
+             if self.cone_sampling:
+                 args["cone-sampling"] = self.cone_sampling
+             if self.axis_sampling:
+                 args["axis-sampling"] = self.axis_sampling
+             if self.axis_angle != 360.0:
+                 args["axis-angle"] = self.axis_angle
+             if self.axis_symmetry != 1:
+                 args["axis-symmetry"] = self.axis_symmetry
+             if self.cone_axis != 2:
+                 args["cone-axis"] = self.cone_axis
+         else:
+             # Default fallback
+             args["angular-sampling"] = 15.0
+
+         args["num-peaks"] = self.num_peaks
+         return {k: v for k, v in args.items() if v is not None}
+
+     def get_flags(self) -> List[str]:
+         """Get boolean flags for pyTME command."""
+         flags = []
+         if self.whiten_spectrum:
+             flags.append("whiten-spectrum")
+         if self.scramble_phases:
+             flags.append("scramble-phases")
+         if self.invert_target_contrast:
+             flags.append("invert-target-contrast")
+         if self.centering:
+             flags.append("centering")
+         if self.pad_edges:
+             flags.append("pad-edges")
+         if self.pad_filter:
+             flags.append("pad-filter")
+         if not self.no_pass_smooth:
+             flags.append("no-pass-smooth")
+         if self.use_mixed_precision:
+             flags.append("use-mixed-precision")
+         if self.use_memmap:
+             flags.append("use-memmap")
+         if self.peak_calling:
+             flags.append("peak-calling")
+         if not self.no_flip_phase:
+             flags.append("no-flip-phase")
+         if self.correct_defocus_gradient:
+             flags.append("correct-defocus-gradient")
+         if self.invert_cone:
+             flags.append("invert-cone")
+         if self.no_use_optimized_set:
+             flags.append("no-use-optimized-set")
+         if self.no_filter_target:
+             flags.append("no-filter-target")
+         return flags
+
+
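A minimal sketch of how TMParameters maps onto match_template arguments (assumes pyTME and pytme_runner are importable; the touched files exist only to satisfy the validation in __post_init__, and all values are illustrative):

    import tempfile
    from pathlib import Path
    from pytme_runner import TMParameters, TomoFiles

    with tempfile.TemporaryDirectory() as tmp:
        tmp = Path(tmp)
        for name in ("template.mrc", "tomo_01.mrc", "tomo_01.xml"):
            (tmp / name).touch()  # empty stand-ins for the existence checks

        params = TMParameters(
            template=tmp / "template.mrc",
            particle_diameter=290.0,  # drives automatic angular sampling
            tilt_weighting="grigorieff",
        )
        files = TomoFiles(
            tomo_id="01",
            tomogram=tmp / "tomo_01.mrc",
            metadata=tmp / "tomo_01.xml",
        )
        args = params.to_command_args(files, tmp / "tomo_01.pickle")
        # args["particle-diameter"] == 290.0; the metadata file doubles as
        # --ctf-file and --tilt-angles, as in to_command_args above.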
+ @dataclass
+ class AnalysisParameters:
+     """Parameters for template matching analysis and peak calling."""
+
+     # Peak calling
+     peak_caller: str = "PeakCallerMaximumFilter"
+     num_peaks: int = 1000
+     min_score: float = 0.0
+     max_score: Optional[float] = None
+     min_distance: int = 5
+     min_boundary_distance: int = 0
+     mask_edges: bool = False
+     n_false_positives: Optional[int] = None
+
+     # Output format
+     output_format: str = "relion4"
+     output_directory: Optional[str] = None
+     angles_clockwise: bool = False
+
+     # Advanced options
+     extraction_box_size: Optional[int] = None
+
+     def to_command_args(
+         self, files: AnalysisFiles, output_path: Path
+     ) -> Dict[str, Any]:
+         """Convert parameters to analyze_template_matching command arguments."""
+         args = {
+             "input-files": " ".join([str(f) for f in files.input_files]),
+             "output-prefix": str(output_path.parent / output_path.stem),
+             "peak-caller": self.peak_caller,
+             "num-peaks": self.num_peaks,
+             "min-score": self.min_score,
+             "min-distance": self.min_distance,
+             "min-boundary-distance": self.min_boundary_distance,
+             "output-format": self.output_format,
+         }
+
+         # Optional parameters
+         if self.max_score is not None:
+             args["max-score"] = self.max_score
+         if self.n_false_positives is not None:
+             args["n-false-positives"] = self.n_false_positives
+         if self.extraction_box_size is not None:
+             args["extraction-box-size"] = self.extraction_box_size
+         if files.mask:
+             args["target-mask"] = str(files.mask)
+
+         # Background files
+         if files.background_files:
+             args["background-files"] = " ".join(
+                 [str(f) for f in files.background_files]
+             )
+
+         return {k: v for k, v in args.items() if v is not None}
+
+     def get_flags(self) -> List[str]:
+         """Get boolean flags for analyze_template_matching command."""
+         flags = []
+         if self.mask_edges:
+             flags.append("mask-edges")
+         if self.angles_clockwise:
+             flags.append("angles-clockwise")
+         return flags
+
+
+ @dataclass
+ class ComputeResources:
+     """Compute resource requirements for a job."""
+
+     cpus: int = 4
+     memory_gb: int = 128
+     gpu_count: int = 0
+     gpu_type: Optional[str] = None  # e.g., "3090", "A100"
+     time_limit: str = "05:00:00"
+     partition: str = "gpu-el8"
+     constraint: Optional[str] = None
+     qos: str = "normal"
+
+     def to_slurm_args(self) -> Dict[str, str]:
+         """Convert to SLURM sbatch arguments."""
+         args = {
+             "ntasks": "1",
+             "nodes": "1",
+             "ntasks-per-node": "1",
+             "cpus-per-task": str(self.cpus),
+             "mem": f"{self.memory_gb}G",
+             "time": self.time_limit,
+             "partition": self.partition,
+             "qos": self.qos,
+             "export": "none",
+         }
+
+         if self.gpu_count > 0:
+             args["gres"] = f"gpu:{self.gpu_count}"
+             if self.gpu_type:
+                 args["constraint"] = f"gpu={self.gpu_type}"
+
+         if self.constraint and not self.gpu_type:
+             args["constraint"] = self.constraint
+
+         return args
+
+
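For orientation, a GPU request like the one used for matching jobs renders to sbatch arguments roughly as follows (a sketch; values mirror the dataclass defaults above):

    from pytme_runner import ComputeResources

    res = ComputeResources(cpus=4, memory_gb=64, gpu_count=1, gpu_type="3090")
    print(res.to_slurm_args())
    # {'ntasks': '1', 'nodes': '1', 'ntasks-per-node': '1', 'cpus-per-task': '4',
    #  'mem': '64G', 'time': '05:00:00', 'partition': 'gpu-el8', 'qos': 'normal',
    #  'export': 'none', 'gres': 'gpu:1', 'constraint': 'gpu=3090'}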
+ @dataclass
+ class AbstractTask(ABC):
+     """Abstract task specification."""
+
+     files: object
+     parameters: object
+     resources: ComputeResources
+     output_dir: Path
+
+     @property
+     def tomo_id(self) -> str:
+         return self.files.tomo_id
+
+     @property
+     @abstractmethod
+     def executable(self) -> str:
+         pass
+
+     @property
+     @abstractmethod
+     def output_file(self) -> Path:
+         pass
+
+     def to_command_args(self):
+         return self.parameters.to_command_args(self.files, self.output_file)
+
+     def create_output_dir(self) -> None:
+         """Ensure output directory exists."""
+         self.output_dir.mkdir(parents=True, exist_ok=True)
+
+
+ @dataclass
+ class TemplateMatchingTask(AbstractTask):
+     """Template matching task."""
+
+     @property
+     def output_file(self) -> Path:
+         original_stem = self.files.tomogram.stem
+         return self.output_dir / f"{original_stem}.pickle"
+
+     @property
+     def executable(self):
+         return "match_template"
+
+
+ class AnalysisTask(AbstractTask):
+     """Analysis task for processing TM results."""
+
+     @property
+     def output_file(self) -> Path:
+         """Generate output filename based on format."""
+         prefix = self.files.input_files[0].stem
+
+         format_extensions = {
+             "orientations": ".tsv",
+             "relion4": ".star",
+             "relion5": ".star",
+             "pickle": ".pickle",
+             "alignment": "",
+             "extraction": "",
+             "average": ".mrc",
+         }
+
+         extension = format_extensions.get(self.parameters.output_format, ".tsv")
+         return self.output_dir / f"{prefix}{extension}"
+
+     @property
+     def executable(self):
+         return "postprocess"
+
+
+ class ExecutionBackend(ABC):
+     """Abstract base class for execution backends."""
+
+     @abstractmethod
+     def submit_job(self, task) -> str:
+         """Submit a single job and return a job ID or status string."""
+         pass
+
+     @abstractmethod
+     def submit_jobs(self, tasks: List) -> List[str]:
+         """Submit multiple jobs and return a list of job IDs."""
+         pass
+
+
+ class SlurmBackend(ExecutionBackend):
+     """SLURM execution backend for cluster job submission."""
+
+     def __init__(
+         self,
+         force: bool = True,
+         dry_run: bool = False,
+         script_dir: Path = Path("./slurm_scripts"),
+         environment_setup: str = "module load pyTME",
+     ):
+         """
+         Initialize SLURM backend.
+
+         Parameters
+         ----------
+         force : bool, optional
+             Rerun completed jobs, defaults to True.
+         dry_run : bool, optional
+             Generate scripts but do not submit them, defaults to False.
+         script_dir : Path, optional
+             Directory for generated scripts, defaults to ./slurm_scripts.
+         environment_setup : str, optional
+             Command to set up the pyTME environment, defaults to "module load pyTME".
+         """
+         self.force = force
+         self.dry_run = dry_run
+         self.environment_setup = environment_setup
+         self.script_dir = Path(script_dir) if script_dir else Path("./slurm_scripts")
+         self.script_dir.mkdir(exist_ok=True, parents=True)
+
+     def create_sbatch_script(self, task) -> Path:
+         """Generate a SLURM sbatch script for a task."""
+         script_path = self.script_dir / f"pytme_{task.tomo_id}.sh"
+
+         # Ensure output directory exists
+         task.create_output_dir()
+
+         slurm_args = task.resources.to_slurm_args()
+         slurm_args.update(
+             {
+                 "output": f"{task.output_dir}/{task.tomo_id}_%j.out",
+                 "error": f"{task.output_dir}/{task.tomo_id}_%j.err",
+                 "job-name": f"pytme_{task.executable}_{task.tomo_id}",
+                 "chdir": str(task.output_dir),
+             }
+         )
+
+         script_lines = ["#!/bin/bash", "", "# SLURM directives"]
+         for param, value in slurm_args.items():
+             script_lines.append(f"#SBATCH --{param}={value}")
+
+         script_lines.extend(
+             [
+                 "",
+                 "# Environment setup",
+                 "\n".join(self.environment_setup.split(";")),
+                 "",
+                 "# Run template matching",
+             ]
+         )
+
+         command_parts = [task.executable]
+         cmd_args = task.to_command_args()
+         for arg, value in cmd_args.items():
+             command_parts.append(f"--{arg} {value}")
+
+         for flag in task.parameters.get_flags():
+             command_parts.append(f"--{flag}")
+
+         command = " \\\n    ".join(command_parts)
+         script_lines.append(command)
+
+         with open(script_path, "w") as f:
+             f.write("\n".join(script_lines) + "\n")
+         script_path.chmod(0o755)
+
+         print(f"Generated SLURM script: {script_path.name}")
+         return script_path
+
+     def submit_job(self, task) -> str:
+         """Submit a single SLURM job."""
+         script_path = self.create_sbatch_script(task)
+
+         if self.dry_run:
+             return f"DRY_RUN:{script_path}"
+
+         try:
+             if Path(task.output_file).exists() and not self.force:
+                 return f"ERROR: {str(task.output_file)} exists and force was not set."
+
+             result = subprocess.run(
+                 ["sbatch", str(script_path)], capture_output=True, text=True, check=True
+             )
+
+             # Parse job ID from sbatch output
+             # Typical output: "Submitted batch job 123456"
+             job_id = result.stdout.strip().split()[-1]
+             print(f"Submitted job {job_id} for {task.tomo_id}")
+             return job_id
+
+         except subprocess.CalledProcessError as e:
+             error_msg = f"Failed to submit {script_path}: {e.stderr}"
+             return f"ERROR:{error_msg}"
+         except Exception as e:
+             error_msg = f"Submission error for {script_path}: {e}"
+             return f"ERROR:{error_msg}"
+
+     def submit_jobs(self, tasks: List) -> List[str]:
+         """Submit multiple SLURM jobs."""
+         job_ids = []
+         for task in tasks:
+             job_id = self.submit_job(task)
+             job_ids.append(job_id)
+         return job_ids
+
+
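Putting the pieces together, a dry run writes the sbatch script without calling sbatch (a sketch reusing the hypothetical `params` and `files` from the TMParameters example above; nothing is submitted):

    from pathlib import Path
    from pytme_runner import ComputeResources, SlurmBackend, TemplateMatchingTask

    task = TemplateMatchingTask(
        files=files,  # from the TMParameters sketch above
        parameters=params,
        resources=ComputeResources(gpu_count=1),
        output_dir=Path("./matching_results"),
    )
    backend = SlurmBackend(dry_run=True, script_dir=Path("./slurm_scripts"))
    print(backend.submit_job(task))  # -> "DRY_RUN:slurm_scripts/pytme_01.sh"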
+ def add_compute_resources(
+     parser,
+     default_cpus=4,
+     default_memory=32,
+     default_time="02:00:00",
+     default_partition="cpu",
+     include_gpu=False,
+ ):
+     """Add compute resource arguments to a parser."""
+     compute_group = parser.add_argument_group("Compute Resources")
+     compute_group.add_argument(
+         "--cpus", type=int, default=default_cpus, help="Number of CPUs per job"
+     )
+     compute_group.add_argument(
+         "--memory", type=int, default=default_memory, help="Memory per job in GB"
+     )
+     compute_group.add_argument(
+         "--time-limit", default=default_time, help="Time limit (HH:MM:SS)"
+     )
+     compute_group.add_argument(
+         "--partition", default=default_partition, help="SLURM partition"
+     )
+     compute_group.add_argument(
+         "--qos", default="normal", help="SLURM quality of service"
+     )
+
+     if include_gpu:
+         compute_group.add_argument(
+             "--gpu-count", type=int, default=1, help="Number of GPUs per job"
+         )
+         compute_group.add_argument(
+             "--gpu-type",
+             default="3090",
+             help="GPU type constraint (e.g., '3090', 'A100')",
+         )
+
+     return compute_group
+
+
+ def add_job_submission(parser, default_output_dir="./results"):
+     """Add job submission arguments to a parser."""
+     job_group = parser.add_argument_group("Job Submission")
+     job_group.add_argument(
+         "--output-dir",
+         type=Path,
+         default=Path(default_output_dir),
+         help="Output directory for results",
+     )
+     job_group.add_argument(
+         "--script-dir",
+         type=Path,
+         default=Path("./scripts"),
+         help="Directory for generated SLURM scripts",
+     )
+     job_group.add_argument(
+         "--environment-setup",
+         default="module load pyTME",
+         help="Command(s) to set up the pyTME environment",
+     )
+     job_group.add_argument(
+         "--dry-run", action="store_true", help="Generate scripts but do not submit jobs"
+     )
+     job_group.add_argument("--force", action="store_true", help="Rerun completed jobs")
+
+     return job_group
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(
+         description="Batch runner for PyTME.",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     subparsers = parser.add_subparsers(
+         dest="command", help="Available commands", required=True
+     )
+
+     matching_parser = subparsers.add_parser(
+         "matching",
+         help="Run template matching",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     # Input files for matching
+     tm_input_group = matching_parser.add_argument_group("Input Files")
+     tm_input_group.add_argument(
+         "--tomograms",
+         required=True,
+         help="Glob pattern for tomogram files (e.g., '/data/tomograms/*.mrc')",
+     )
+     tm_input_group.add_argument(
+         "--metadata",
+         required=True,
+         help="Glob pattern for metadata files (e.g., '/data/metadata/*.xml')",
+     )
+     tm_input_group.add_argument(
+         "--masks", help="Glob pattern for target mask files (e.g., '/data/masks/*.mrc')"
+     )
+     tm_input_group.add_argument(
+         "--template", required=True, type=Path, help="Template file (MRC, PDB, etc.)"
+     )
+     tm_input_group.add_argument("--template-mask", type=Path, help="Template mask file")
+     tm_input_group.add_argument(
+         "--tomo-list",
+         type=Path,
+         help="File with list of tomogram IDs to process (one per line)",
+     )
+
+     # Template matching parameters
+     tm_group = matching_parser.add_argument_group("Template Matching")
+     angular_group = tm_group.add_mutually_exclusive_group()
+     angular_group.add_argument(
+         "--angular-sampling", type=float, help="Angular sampling in degrees"
+     )
+     angular_group.add_argument(
+         "--particle-diameter",
+         type=float,
+         help="Particle diameter in units of sampling rate (typically Ångstrom)",
+     )
+
+     tm_group.add_argument(
+         "--score",
+         default="FLCSphericalMask",
+         help="Template matching scoring function. Use FLC if the mask is not spherical.",
+     )
+     tm_group.add_argument(
+         "--score-threshold", type=float, default=0.0, help="Minimum score threshold"
+     )
+
+     # Microscope parameters
+     scope_group = matching_parser.add_argument_group("Microscope Parameters")
+     scope_group.add_argument(
+         "--voltage", type=float, default=300.0, help="Acceleration voltage in kV"
+     )
+     scope_group.add_argument(
+         "--spherical-aberration",
+         type=float,
+         default=2.7,
+         help="Spherical aberration in mm",
+     )
+     scope_group.add_argument(
+         "--amplitude-contrast", type=float, default=0.07, help="Amplitude contrast"
+     )
+
+     # Processing options
+     proc_group = matching_parser.add_argument_group("Processing Options")
+     proc_group.add_argument(
+         "--lowpass",
+         type=float,
+         help="Lowpass filter in units of sampling rate (typically Ångstrom).",
+     )
+     proc_group.add_argument(
+         "--highpass",
+         type=float,
+         help="Highpass filter in units of sampling rate (typically Ångstrom).",
+     )
+     proc_group.add_argument(
+         "--tilt-weighting",
+         choices=["angle", "relion", "grigorieff"],
+         help="Tilt weighting scheme",
+     )
+     proc_group.add_argument(
+         "--backend",
+         default="cupy",
+         choices=list(be._BACKEND_REGISTRY.keys()),
+         help="Computation backend",
+     )
+     proc_group.add_argument(
+         "--whiten-spectrum", action="store_true", help="Apply spectral whitening"
+     )
+     proc_group.add_argument(
+         "--scramble-phases",
+         action="store_true",
+         help="Scramble template phases for noise estimation",
+     )
+
+     _ = add_compute_resources(
+         matching_parser,
+         default_cpus=4,
+         default_memory=64,
+         include_gpu=True,
+         default_time="05:00:00",
+         default_partition="gpu-el8",
+     )
+     _ = add_job_submission(matching_parser, "./matching_results")
+
+     analysis_parser = subparsers.add_parser(
+         "analysis",
+         help="Analyze template matching results",
+         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+     )
+
+     # Input files for analysis
+     analysis_input_group = analysis_parser.add_argument_group("Input Files")
+     analysis_input_group.add_argument(
+         "--input-file",
+         "--input-files",
+         required=True,
+         nargs="+",
+         help="Glob pattern(s) for results from one or multiple match_template.py runs.",
+     )
+     analysis_input_group.add_argument(
+         "--background-file",
+         "--background-files",
+         required=False,
+         nargs="+",
+         default=[],
+         help="Glob pattern(s) for match_template.py runs used for normalization, "
+         "for instance from --scramble-phases or a different template.",
+     )
+     analysis_input_group.add_argument(
+         "--masks", help="Glob pattern for target mask files (e.g., '/data/masks/*.mrc')"
+     )
+     analysis_input_group.add_argument(
+         "--tomo-list",
+         type=Path,
+         help="File with list of tomogram IDs to process (one per line)",
+     )
+
+     # Peak calling parameters
+     peak_group = analysis_parser.add_argument_group("Peak Calling")
+     peak_group.add_argument(
+         "--peak-caller",
+         choices=[
+             "PeakCallerSort",
+             "PeakCallerMaximumFilter",
+             "PeakCallerFast",
+             "PeakCallerRecursiveMasking",
+             "PeakCallerScipy",
+         ],
+         default="PeakCallerMaximumFilter",
+         help="Peak caller for local maxima identification",
+     )
+     peak_group.add_argument(
+         "--num-peaks",
+         type=int,
+         default=1000,
+         help="Maximum number of peaks to identify",
+     )
+     peak_group.add_argument(
+         "--min-score",
+         type=float,
+         default=None,
+         help="Minimum score from which peaks will be considered",
+     )
+     peak_group.add_argument(
+         "--max-score",
+         type=float,
+         default=None,
+         help="Maximum score until which peaks will be considered",
+     )
+     peak_group.add_argument(
+         "--min-distance", type=int, default=None, help="Minimum distance between peaks"
+     )
+     peak_group.add_argument(
+         "--min-boundary-distance",
+         type=int,
+         default=None,
+         help="Minimum distance of peaks to target edges",
+     )
+     peak_group.add_argument(
+         "--mask-edges",
+         action="store_true",
+         default=False,
+         help="Ignore candidates whose scores were computed from padded densities. "
+         "Superseded by min-boundary-distance.",
+     )
+     peak_group.add_argument(
+         "--n-false-positives",
+         type=int,
+         default=None,
+         help="Number of accepted false-positive picks used to determine the minimum score",
+     )
+
+     # Output options
+     output_group = analysis_parser.add_argument_group("Output Options")
+     output_group.add_argument(
+         "--output-format",
+         choices=[
+             "orientations",
+             "relion4",
+             "relion5",
+             "alignment",
+             "extraction",
+             "average",
+             "pickle",
+         ],
+         default="relion4",
+         help="Output format for analysis results",
+     )
+     output_group.add_argument(
+         "--angles-clockwise",
+         action="store_true",
+         help="Report Euler angles in the clockwise format expected by RELION",
+     )
+
+     advanced_group = analysis_parser.add_argument_group("Advanced Options")
+     advanced_group.add_argument(
+         "--extraction-box-size",
+         type=int,
+         default=None,
+         help="Box size for extracted subtomograms (for the extraction output format)",
+     )
+
+     _ = add_compute_resources(
+         analysis_parser,
+         default_cpus=2,
+         default_memory=16,
+         include_gpu=False,
+         default_time="01:00:00",
+         default_partition="htc-el8",
+     )
+     _ = add_job_submission(analysis_parser, "./analysis_results")
+
+     args = parser.parse_args()
+     if args.tomo_list is not None:
+         with open(args.tomo_list, mode="r") as f:
+             args.tomo_list = [line.strip() for line in f if line.strip()]
+
+     args.output_dir = args.output_dir.absolute()
+     args.script_dir = args.script_dir.absolute()
+     return args
+
+
+ def run_matching(args, resources):
+     discovery = TomoDatasetDiscovery(
+         mrc_pattern=args.tomograms,
+         metadata_pattern=args.metadata,
+         mask_pattern=args.masks,
+     )
+     files = discovery.discover(tomo_list=args.tomo_list)
+     print_block(
+         name="Discovering Dataset",
+         data={
+             "Tomogram Pattern": args.tomograms,
+             "Metadata Pattern": args.metadata,
+             "Mask Pattern": args.masks,
+             "Valid Runs": len(files),
+         },
+         label_width=30,
+     )
+     if not files:
+         print("No tomograms found! Check your patterns.")
+         return
+
+     params = TMParameters(
+         template=args.template,
+         template_mask=args.template_mask,
+         angular_sampling=args.angular_sampling,
+         particle_diameter=args.particle_diameter,
+         score=args.score,
+         score_threshold=args.score_threshold,
+         acceleration_voltage=args.voltage,
+         spherical_aberration=args.spherical_aberration * 1e7,  # mm to Ångstrom
+         amplitude_contrast=args.amplitude_contrast,
+         lowpass=args.lowpass,
+         highpass=args.highpass,
+         tilt_weighting=args.tilt_weighting,
+         backend=args.backend,
+         whiten_spectrum=args.whiten_spectrum,
+         scramble_phases=args.scramble_phases,
+     )
+     print_params = params.to_command_args(files[0], "")
+     _ = print_params.pop("target")
+     _ = print_params.pop("output")
+     print_params.update({k: True for k in params.get_flags()})
+     print_params = {
+         sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+     }
+     print_block(name="Matching Parameters", data=print_params, label_width=30)
+     print("\n" + "-" * 80)
+
+     tasks = []
+     for tomo_file in files:
+         task = TemplateMatchingTask(
+             files=tomo_file,
+             parameters=params,
+             resources=resources,
+             output_dir=args.output_dir,
+         )
+         tasks.append(task)
+
+     return tasks
+
+
+ def run_analysis(args, resources):
+     discovery = AnalysisDatasetDiscovery(
+         input_patterns=args.input_file,
+         background_patterns=args.background_file,
+         mask_patterns=args.masks,
+     )
+     files = discovery.discover(tomo_list=args.tomo_list)
+     print_block(
+         name="Discovering Dataset",
+         data={
+             "Input Patterns": args.input_file,
+             "Background Patterns": args.background_file,
+             "Mask Pattern": args.masks,
+             "Valid Runs": len(files),
+         },
+         label_width=30,
+     )
+     if not files:
+         print("No TM results found! Check your patterns.")
+         return
+
+     params = AnalysisParameters(
+         peak_caller=args.peak_caller,
+         num_peaks=args.num_peaks,
+         min_score=args.min_score,
+         max_score=args.max_score,
+         min_distance=args.min_distance,
+         min_boundary_distance=args.min_boundary_distance,
+         mask_edges=args.mask_edges,
+         n_false_positives=args.n_false_positives,
+         output_format=args.output_format,
+         angles_clockwise=args.angles_clockwise,
+         extraction_box_size=args.extraction_box_size,
+     )
+     print_params = params.to_command_args(files[0], Path(""))
+     _ = print_params.pop("input-files", None)
+     _ = print_params.pop("background-files", None)
+     _ = print_params.pop("output-prefix", None)
+     print_params.update({k: True for k in params.get_flags()})
+     print_params = {
+         sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+     }
+     print_block(name="Analysis Parameters", data=print_params, label_width=30)
+     print("\n" + "-" * 80)
+
+     tasks = []
+     for file in files:
+         task = AnalysisTask(
+             files=file,
+             parameters=params,
+             resources=resources,
+             output_dir=args.output_dir,
+         )
+         tasks.append(task)
+
+     return tasks
+
+
+ def main():
+     print_entry()
+
+     args = parse_args()
+
+     resources = ComputeResources(
+         cpus=args.cpus,
+         memory_gb=args.memory,
+         time_limit=args.time_limit,
+         partition=args.partition,
+         gpu_count=getattr(args, "gpu_count", 0),
+         gpu_type=getattr(args, "gpu_type", None),
+     )
+
+     func = run_matching
+     if args.command == "analysis":
+         func = run_analysis
+
+     try:
+         tasks = func(args, resources)
+     except Exception as e:
+         exit(f"Error: {e}")
+
+     if tasks is None:
+         exit(-1)
+
+     print_params = resources.to_slurm_args()
+     print_params = {
+         sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+     }
+     print_block(name="Compute Resources", data=print_params, label_width=30)
+     print("\n" + "-" * 80 + "\n")
+
+     backend = SlurmBackend(
+         force=args.force,
+         dry_run=args.dry_run,
+         script_dir=args.script_dir,
+         environment_setup=args.environment_setup,
+     )
+     job_ids = backend.submit_jobs(tasks)
+     if args.dry_run:
+         print(
+             f"\nDry run complete. Generated {len(tasks)} scripts in {args.script_dir}"
+         )
+         return 0
+
+     successful_jobs = [j for j in job_ids if not j.startswith("ERROR")]
+     print(f"\nSubmitted {len(successful_jobs)} jobs successfully.")
+     if successful_jobs:
+         print(f"Job IDs:\n{','.join(successful_jobs).strip()}")
+
+     if len(successful_jobs) == len(job_ids):
+         return 0
+
+     print("\nThe following issues arose during submission:")
+     for j in job_ids:
+         if j.startswith("ERROR"):
+             print(j)
+
+
+ if __name__ == "__main__":
+     main()
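
For orientation, typical invocations of the new runner look like this (patterns, paths, and values are illustrative; both runs use --dry-run, so scripts are generated but nothing is submitted):

    pytme_runner.py matching \
        --tomograms '/data/tomograms/*.mrc' \
        --metadata '/data/metadata/*.xml' \
        --template template.mrc \
        --particle-diameter 290 \
        --tilt-weighting grigorieff \
        --gpu-count 1 --dry-run

    pytme_runner.py analysis \
        --input-files '/data/matching_results/*.pickle' \
        --output-format relion4 \
        --num-peaks 1000 --dry-run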