pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those package versions as they appear in the registry.
Files changed (54)
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +23 -10
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_backends.py +3 -9
  20. tests/test_density.py +0 -1
  21. tests/test_matching_utils.py +10 -60
  22. tests/test_rotations.py +1 -1
  23. tme/__version__.py +1 -1
  24. tme/analyzer/_utils.py +4 -4
  25. tme/analyzer/aggregation.py +13 -3
  26. tme/analyzer/peaks.py +11 -10
  27. tme/backends/_jax_utils.py +15 -13
  28. tme/backends/_numpyfftw_utils.py +270 -0
  29. tme/backends/cupy_backend.py +5 -44
  30. tme/backends/jax_backend.py +58 -37
  31. tme/backends/matching_backend.py +6 -51
  32. tme/backends/mlx_backend.py +1 -27
  33. tme/backends/npfftw_backend.py +68 -65
  34. tme/backends/pytorch_backend.py +1 -26
  35. tme/density.py +2 -6
  36. tme/extensions.cpython-311-darwin.so +0 -0
  37. tme/filters/ctf.py +22 -21
  38. tme/filters/wedge.py +10 -7
  39. tme/mask.py +341 -0
  40. tme/matching_data.py +7 -19
  41. tme/matching_exhaustive.py +34 -47
  42. tme/matching_optimization.py +2 -1
  43. tme/matching_scores.py +206 -411
  44. tme/matching_utils.py +73 -422
  45. tme/memory.py +1 -1
  46. tme/orientations.py +4 -6
  47. tme/rotations.py +1 -1
  48. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  49. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
  50. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
  51. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  52. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
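The bulk of the release is the rewritten scripts/pytme_runner.py shown below, which splits the batch runner into "matching" and "analysis" subcommands. As a rough sketch of how the new entry point might be driven (the paths are placeholders; the flags come from the argparse definitions in the diff):

```python
# Hypothetical dry-run of the new "matching" subcommand; every flag below is
# defined in the diffed parser, but all paths are placeholders.
import subprocess

subprocess.run(
    [
        "python", "pytme_runner.py", "matching",
        "--tomograms", "/data/tomograms/*.mrc",
        "--metadata", "/data/metadata/*.xml",
        "--template", "template.mrc",
        "--output-dir", "./matching_results",
        "--dry-run",
    ],
    check=True,
)
```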
scripts/pytme_runner.py CHANGED
@@ -38,30 +38,40 @@ class TomoFiles:
             raise FileNotFoundError(f"Mask not found: {self.mask}")
 
 
-class TomoDatasetDiscovery:
-    """Find and match tomogram files using glob patterns."""
+@dataclass
+class AnalysisFiles:
+    """Container for files related to analysis of a single tomogram."""
 
-    def __init__(
-        self,
-        mrc_pattern: str,
-        metadata_pattern: str,
-        mask_pattern: Optional[str] = None,
-    ):
-        """
-        Initialize with glob patterns for file discovery.
+    #: Tomogram identifier.
+    tomo_id: str
+    #: List of TM pickle result files for this tomo_id.
+    input_files: List[Path]
+    #: Background pickle files for normalization (optional).
+    background_files: List[Path] = None
+    #: Target mask file (optional).
+    mask: Optional[Path] = None
 
-        Parameters
-        ----------
-        mrc_pattern: str
-            Glob pattern for tomogram files, e.g., "/data/tomograms/*.mrc"
-        metadata_pattern: str
-            Glob pattern for metadata files, e.g., "/data/metadata/*.xml"
-        mask_pattern: str
-            Optional glob pattern for mask files, e.g., "/data/masks/*.mrc"
-        """
-        self.mrc_pattern = mrc_pattern
-        self.metadata_pattern = metadata_pattern
-        self.mask_pattern = mask_pattern
+    def __post_init__(self):
+        """Validate that required files exist."""
+        for input_file in self.input_files:
+            if not input_file.exists():
+                raise FileNotFoundError(f"Input file not found: {input_file}")
+
+        if self.background_files:
+            for bg_file in self.background_files:
+                if not bg_file.exists():
+                    raise FileNotFoundError(f"Background file not found: {bg_file}")
+
+        if self.mask and not self.mask.exists():
+            raise FileNotFoundError(f"Mask not found: {self.mask}")
+
+
+class DatasetDiscovery(ABC):
+    """Base class for dataset discovery using glob patterns."""
+
+    @abstractmethod
+    def discover(self, tomo_list: Optional[List[str]] = None) -> List:
+        pass
 
     @staticmethod
     def parse_id_from_filename(filename: str) -> str:
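For orientation, a minimal sketch of the new container (assuming AnalysisFiles is imported from the runner module; all paths are placeholders). Validation is eager, so missing files fail at construction rather than at submission:

```python
# Sketch only: AnalysisFiles validates in __post_init__, so a missing pickle
# raises before any job is submitted (paths are placeholders).
from pathlib import Path
# from pytme_runner import AnalysisFiles  # hypothetical import path

try:
    files = AnalysisFiles(
        tomo_id="TS_001",
        input_files=[Path("/data/results/TS_001.pickle")],
        background_files=[Path("/data/background/TS_001.pickle")],
        mask=Path("/data/masks/TS_001.mrc"),
    )
except FileNotFoundError as e:
    print(e)  # e.g. "Input file not found: /data/results/TS_001.pickle"
```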
@@ -78,7 +88,7 @@ class TomoDatasetDiscovery:
                 break
         return base
 
-    def _create_mapping_table(self, pattern: str) -> Dict:
+    def create_mapping_table(self, pattern: str) -> Dict[str, List[Path]]:
        """Create a mapping table between tomogram ids and file paths."""
        if pattern is None:
            return {}
@@ -91,21 +101,25 @@ class TomoDatasetDiscovery:
                 ret[file_id] = []
             ret[file_id].append(file)
 
-        # This could all be done in one line but we want the messages.
-        for key in ret.keys():
-            value = ret[key]
-            if len(value) > 1:
-                print(f"Found id {key} multiple times at {value}. Using {value[0]}.")
-            ret[key] = value[0]
         return ret
 
-    def discover_tomograms(
-        self, tomo_list: Optional[List[str]] = None, require_mask: bool = False
-    ) -> List[TomoFiles]:
+
+@dataclass
+class TomoDatasetDiscovery(DatasetDiscovery):
+    """Find and match tomogram files using glob patterns."""
+
+    #: Glob pattern for tomogram files, e.g., "/data/tomograms/*.mrc"
+    mrc_pattern: str
+    #: Glob pattern for metadata files, e.g., "/data/metadata/*.xml"
+    metadata_pattern: str
+    #: Optional glob pattern for mask files, e.g., "/data/masks/*.mrc"
+    mask_pattern: Optional[str] = None
+
+    def discover(self, tomo_list: Optional[List[str]] = None) -> List[TomoFiles]:
         """Find all matching tomogram files."""
-        mrc_files = self._create_mapping_table(self.mrc_pattern)
-        meta_files = self._create_mapping_table(self.metadata_pattern)
-        mask_files = self._create_mapping_table(self.mask_pattern)
+        mrc_files = self.create_mapping_table(self.mrc_pattern)
+        meta_files = self.create_mapping_table(self.metadata_pattern)
+        mask_files = self.create_mapping_table(self.mask_pattern)
 
         if tomo_list:
             mrc_files = {k: v for k, v in mrc_files.items() if k in tomo_list}
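A usage sketch of the reworked discovery: create_mapping_table now keeps every match per tomogram id (Dict[str, List[Path]]) instead of collapsing duplicates with a printed warning, and discover() takes the first entry per id (patterns here are placeholders):

```python
# Sketch: discover() pairs tomograms with metadata (and optional masks) by
# parsed id; duplicate ids survive in the mapping table and the first hit
# per id is used downstream (patterns are placeholders).
discovery = TomoDatasetDiscovery(
    mrc_pattern="/data/tomograms/*.mrc",
    metadata_pattern="/data/metadata/*.xml",
    mask_pattern=None,
)
for tomo in discovery.discover(tomo_list=["TS_001", "TS_002"]):
    print(tomo.tomo_id, tomo.tomogram, tomo.metadata)
```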
@@ -121,14 +135,85 @@ class TomoDatasetDiscovery:
             tomo_files.append(
                 TomoFiles(
                     tomo_id=key,
-                    tomogram=value.absolute(),
-                    metadata=meta_files[key].absolute(),
-                    mask=mask_files.get(key),
+                    tomogram=value[0].absolute(),
+                    metadata=meta_files[key][0].absolute(),
+                    mask=mask_files.get(key, [""])[0],
                 )
             )
         return tomo_files
 
 
+@dataclass
+class AnalysisDatasetDiscovery(DatasetDiscovery):
+    """Find and match analysis files using glob patterns."""
+
+    #: Glob pattern for TM pickle files, e.g., "/data/results/*.pickle"
+    input_patterns: List[str]
+    #: List of glob patterns for background files, e.g., ["/data/bg1/*.pickle", "/data/bg2/*.pickle"]
+    background_patterns: List[str] = None
+    #: Target masks, e.g., "/data/masks/*.mrc"
+    mask_patterns: Optional[str] = None
+
+    def __post_init__(self):
+        """Ensure patterns are lists."""
+        if isinstance(self.input_patterns, str):
+            self.input_patterns = [self.input_patterns]
+        if self.background_patterns and isinstance(self.background_patterns, str):
+            self.background_patterns = [self.background_patterns]
+
+    def discover(self, tomo_list: Optional[List[str]] = None) -> List[AnalysisFiles]:
+        """Find all matching analysis files."""
+
+        input_files_by_id = {}
+        for pattern in self.input_patterns:
+            files = self.create_mapping_table(pattern)
+            for tomo_id, file_list in files.items():
+                if tomo_id not in input_files_by_id:
+                    input_files_by_id[tomo_id] = []
+                input_files_by_id[tomo_id].extend(file_list)
+
+        background_files_by_id = {}
+        if self.background_patterns:
+            for pattern in self.background_patterns:
+                bg_files = self.create_mapping_table(pattern)
+                for tomo_id, file_list in bg_files.items():
+                    if tomo_id not in background_files_by_id:
+                        background_files_by_id[tomo_id] = []
+                    background_files_by_id[tomo_id].extend(file_list)
+
+        mask_files_by_id = {}
+        if self.mask_patterns:
+            mask_files_by_id = self.create_mapping_table(self.mask_patterns)
+
+        if tomo_list:
+            input_files_by_id = {
+                k: v for k, v in input_files_by_id.items() if k in tomo_list
+            }
+            background_files_by_id = {
+                k: v for k, v in background_files_by_id.items() if k in tomo_list
+            }
+            mask_files_by_id = {
+                k: v for k, v in mask_files_by_id.items() if k in tomo_list
+            }
+
+        analysis_files = []
+        for tomo_id, input_file_list in input_files_by_id.items():
+            background_files = background_files_by_id.get(tomo_id, [])
+            mask_file = mask_files_by_id.get(tomo_id, [None])[0]
+
+            analysis_file = AnalysisFiles(
+                tomo_id=tomo_id,
+                input_files=[f.absolute() for f in input_file_list],
+                background_files=(
+                    [f.absolute() for f in background_files] if background_files else []
+                ),
+                mask=mask_file.absolute() if mask_file else None,
+            )
+            analysis_files.append(analysis_file)
+
+        return analysis_files
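The analysis-side discovery accepts several input patterns and merges their hits per tomogram id; a sketch of the resulting grouping (all patterns are placeholders):

```python
# Sketch: results from two runs are merged per tomogram id; background runs
# and masks attach where an id matches (all patterns are placeholders).
discovery = AnalysisDatasetDiscovery(
    input_patterns=["/run_a/*.pickle", "/run_b/*.pickle"],
    background_patterns=["/scrambled/*.pickle"],
    mask_patterns="/data/masks/*.mrc",
)
for entry in discovery.discover():
    print(entry.tomo_id, len(entry.input_files), len(entry.background_files))
```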
+
+
 @dataclass
 class TMParameters:
     """Template matching parameters."""
@@ -225,12 +310,10 @@ class TMParameters:
                 f"Invalid backend: {self.backend}. Choose from {valid_backends}"
             )
 
-    def to_command_args(
-        self, tomo_files: TomoFiles, output_path: Path
-    ) -> Dict[str, Any]:
+    def to_command_args(self, files: TomoFiles, output_path: Path) -> Dict[str, Any]:
         """Convert parameters to pyTME command arguments."""
         args = {
-            "target": str(tomo_files.tomogram),
+            "target": str(files.tomogram),
             "template": str(self.template),
             "output": str(output_path),
             "acceleration-voltage": self.acceleration_voltage,
@@ -248,11 +331,11 @@ class TMParameters:
         # Optional file arguments
         if self.template_mask:
             args["template-mask"] = str(self.template_mask)
-        if tomo_files.mask:
-            args["target-mask"] = str(tomo_files.mask)
-        if tomo_files.metadata:
-            args["ctf-file"] = str(tomo_files.metadata)
-            args["tilt-angles"] = str(tomo_files.metadata)
+        if files.mask:
+            args["target-mask"] = str(files.mask)
+        if files.metadata:
+            args["ctf-file"] = str(files.metadata)
+            args["tilt-angles"] = str(files.metadata)
 
         # Optional parameters
         if self.lowpass:
@@ -292,7 +375,7 @@ class TMParameters:
             args["angular-sampling"] = 15.0
 
         args["num-peaks"] = self.num_peaks
-        return args
+        return {k: v for k, v in args.items() if v is not None}
 
     def get_flags(self) -> List[str]:
         """Get boolean flags for pyTME command."""
@@ -330,6 +413,71 @@ class TMParameters:
         return flags
 
 
+@dataclass
+class AnalysisParameters:
+    """Parameters for template matching analysis and peak calling."""
+
+    # Peak calling
+    peak_caller: str = "PeakCallerMaximumFilter"
+    num_peaks: int = 1000
+    min_score: float = 0.0
+    max_score: Optional[float] = None
+    min_distance: int = 5
+    min_boundary_distance: int = 0
+    mask_edges: bool = False
+    n_false_positives: Optional[int] = None
+
+    # Output format
+    output_format: str = "relion4"
+    output_directory: Optional[str] = None
+    angles_clockwise: bool = False
+
+    # Advanced options
+    extraction_box_size: Optional[int] = None
+
+    def to_command_args(
+        self, files: AnalysisFiles, output_path: Path
+    ) -> Dict[str, Any]:
+        """Convert parameters to analyze_template_matching command arguments."""
+        args = {
+            "input-files": " ".join([str(f) for f in files.input_files]),
+            "output-prefix": str(output_path.parent / output_path.stem),
+            "peak-caller": self.peak_caller,
+            "num-peaks": self.num_peaks,
+            "min-score": self.min_score,
+            "min-distance": self.min_distance,
+            "min-boundary-distance": self.min_boundary_distance,
+            "output-format": self.output_format,
+        }
+
+        # Optional parameters
+        if self.max_score is not None:
+            args["max-score"] = self.max_score
+        if self.n_false_positives is not None:
+            args["n-false-positives"] = self.n_false_positives
+        if self.extraction_box_size is not None:
+            args["extraction-box-size"] = self.extraction_box_size
+        if files.mask:
+            args["target-mask"] = str(files.mask)
+
+        # Background files
+        if files.background_files:
+            args["background-files"] = " ".join(
+                [str(f) for f in files.background_files]
+            )
+
+        return {k: v for k, v in args.items() if v is not None}
+
+    def get_flags(self) -> List[str]:
+        """Get boolean flags for analyze_template_matching command."""
+        flags = []
+        if self.mask_edges:
+            flags.append("mask-edges")
+        if self.angles_clockwise:
+            flags.append("angles-clockwise")
+        return flags
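Downstream, these parameters render into the arguments of the analysis executable; a sketch of the mapping (the files object is a placeholder AnalysisFiles):

```python
# Sketch: None-valued entries are dropped by to_command_args; boolean options
# are reported separately via get_flags (files is a placeholder).
from pathlib import Path

params = AnalysisParameters(num_peaks=500, output_format="relion5")
args = params.to_command_args(files, Path("out/TS_001.star"))
# -> {"input-files": "...", "output-prefix": "out/TS_001",
#     "peak-caller": "PeakCallerMaximumFilter", "num-peaks": 500, ...}
print(params.get_flags())  # [] unless mask_edges/angles_clockwise are set
```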
+
+
 @dataclass
 class ComputeResources:
     """Compute resource requirements for a job."""
@@ -369,27 +517,75 @@ class ComputeResources:
 
 
 @dataclass
-class TemplateMatchingTask:
-    """A complete template matching task."""
+class AbstractTask(ABC):
+    """Abstract task specification"""
 
-    tomo_files: TomoFiles
-    parameters: TMParameters
+    files: object
+    parameters: object
     resources: ComputeResources
     output_dir: Path
 
     @property
     def tomo_id(self) -> str:
-        return self.tomo_files.tomo_id
+        return self.files.tomo_id
+
+    @abstractmethod
+    def executable(self) -> str:
+        pass
 
     @property
+    @abstractmethod
     def output_file(self) -> Path:
-        return self.output_dir / f"{self.tomo_id}.pickle"
+        pass
+
+    def to_command_args(self):
+        return self.parameters.to_command_args(self.files, self.output_file)
 
     def create_output_dir(self) -> None:
         """Ensure output directory exists."""
         self.output_dir.mkdir(parents=True, exist_ok=True)
 
 
+@dataclass
+class TemplateMatchingTask(AbstractTask):
+    """Template matching task."""
+
+    @property
+    def output_file(self) -> Path:
+        original_stem = self.files.tomogram.stem
+        return self.output_dir / f"{original_stem}.pickle"
+
+    @property
+    def executable(self):
+        return "match_template"
+
+
+class AnalysisTask(AbstractTask):
+    """Analysis task for processing TM results."""
+
+    @property
+    def output_file(self) -> Path:
+        """Generate output filename based on format."""
+        prefix = self.files.input_files[0].stem
+
+        format_extensions = {
+            "orientations": ".tsv",
+            "relion4": ".star",
+            "relion5": ".star",
+            "pickle": ".pickle",
+            "alignment": "",
+            "extraction": "",
+            "average": ".mrc",
+        }
+
+        extension = format_extensions.get(self.parameters.output_format, ".tsv")
+        return self.output_dir / f"{prefix}{extension}"
+
+    @property
+    def executable(self):
+        return "postprocess"
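The two concrete tasks differ mainly in executable name and output naming; a sketch, where tomo, analysis, tm_params, an_params, and res stand in for previously constructed files, parameters, and resources objects (all placeholders):

```python
# Sketch: output naming per task type (all bare names are placeholders).
from pathlib import Path

tm_task = TemplateMatchingTask(
    files=tomo, parameters=tm_params, resources=res,
    output_dir=Path("./matching_results"),
)
print(tm_task.executable, tm_task.output_file)  # match_template .../TS_001.pickle

an_task = AnalysisTask(
    files=analysis, parameters=an_params, resources=res,
    output_dir=Path("./analysis_results"),
)
print(an_task.executable, an_task.output_file)  # postprocess .../TS_001.star
```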
+
+
 class ExecutionBackend(ABC):
     """Abstract base class for execution backends."""
 
@@ -446,7 +642,7 @@ class SlurmBackend(ExecutionBackend):
             {
                 "output": f"{task.output_dir}/{task.tomo_id}_%j.out",
                 "error": f"{task.output_dir}/{task.tomo_id}_%j.err",
-                "job-name": f"pytme_{task.tomo_id}",
+                "job-name": f"pytme_{task.executable}_{task.tomo_id}",
                 "chdir": str(task.output_dir),
             }
         )
@@ -465,8 +661,8 @@ class SlurmBackend(ExecutionBackend):
             ]
         )
 
-        command_parts = ["match_template"]
-        cmd_args = task.parameters.to_command_args(task.tomo_files, task.output_file)
+        command_parts = [task.executable]
+        cmd_args = task.to_command_args()
         for arg, value in cmd_args.items():
             command_parts.append(f"--{arg} {value}")
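Roughly, the generated script body is now assembled from the task itself rather than from a hard-coded match_template call. A sketch (appending get_flags() here is an assumption; this hunk only shows the argument loop):

```python
# Sketch: per-task command assembly; the get_flags() loop is assumed, as the
# hunk above only shows the --arg value loop (task is a placeholder).
command_parts = [task.executable]  # "match_template" or "postprocess"
for arg, value in task.to_command_args().items():
    command_parts.append(f"--{arg} {value}")
for flag in task.parameters.get_flags():  # assumption: flags appended similarly
    command_parts.append(f"--{flag}")
print(" ".join(command_parts))
```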
 
@@ -492,7 +688,7 @@ class SlurmBackend(ExecutionBackend):
 
         try:
             if Path(task.output_file).exists() and not self.force:
-                return "ERROR: File exists and force was not requested."
+                return f"ERROR: {str(task.output_file)} exists and force was not set."
 
             result = subprocess.run(
                 ["sbatch", str(script_path)], capture_output=True, text=True, check=True
@@ -520,37 +716,116 @@ class SlurmBackend(ExecutionBackend):
         return job_ids
 
 
+def add_compute_resources(
+    parser,
+    default_cpus=4,
+    default_memory=32,
+    default_time="02:00:00",
+    default_partition="cpu",
+    include_gpu=False,
+):
+    """Add compute resource arguments to a parser."""
+    compute_group = parser.add_argument_group("Compute Resources")
+    compute_group.add_argument(
+        "--cpus", type=int, default=default_cpus, help="Number of CPUs per job"
+    )
+    compute_group.add_argument(
+        "--memory", type=int, default=default_memory, help="Memory per job in GB"
+    )
+    compute_group.add_argument(
+        "--time-limit", default=default_time, help="Time limit (HH:MM:SS)"
+    )
+    compute_group.add_argument(
+        "--partition", default=default_partition, help="SLURM partition"
+    )
+    compute_group.add_argument(
+        "--qos", default="normal", help="SLURM quality of service"
+    )
+
+    if include_gpu:
+        compute_group.add_argument(
+            "--gpu-count", type=int, default=1, help="Number of GPUs per job"
+        )
+        compute_group.add_argument(
+            "--gpu-type",
+            default="3090",
+            help="GPU type constraint (e.g., '3090', 'A100')",
+        )
+
+    return compute_group
+
+
+def add_job_submission(parser, default_output_dir="./results"):
+    """Add job submission arguments to a parser."""
+    job_group = parser.add_argument_group("Job Submission")
+    job_group.add_argument(
+        "--output-dir",
+        type=Path,
+        default=Path(default_output_dir),
+        help="Output directory for results",
+    )
+    job_group.add_argument(
+        "--script-dir",
+        type=Path,
+        default=Path("./scripts"),
+        help="Directory for generated SLURM scripts",
+    )
+    job_group.add_argument(
+        "--environment-setup",
+        default="module load pyTME",
+        help="Command(s) to set up pyTME environment",
+    )
+    job_group.add_argument(
+        "--dry-run", action="store_true", help="Generate scripts but do not submit jobs"
+    )
+    job_group.add_argument("--force", action="store_true", help="Rerun completed jobs")
+
+    return job_group
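Because both subcommands attach the same helpers, the compute and submission groups can be reused on any parser. A small self-contained sketch (the demo directory and parsed values are hypothetical):

```python
# Sketch: the shared argument-group helpers attach to an arbitrary parser;
# "./demo_results" and the example argv are placeholders.
import argparse

parser = argparse.ArgumentParser()
add_compute_resources(parser, default_cpus=2, include_gpu=False)
add_job_submission(parser, default_output_dir="./demo_results")
args = parser.parse_args(["--cpus", "8", "--dry-run"])
print(args.cpus, args.output_dir, args.dry_run)  # 8 demo_results True
```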
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
-        description="Batch runner for match_template.py",
+        description="Batch runner for PyTME.",
         formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
 
-    input_group = parser.add_argument_group("Input Files")
-    input_group.add_argument(
+    subparsers = parser.add_subparsers(
+        dest="command", help="Available commands", required=True
+    )
+
+    matching_parser = subparsers.add_parser(
+        "matching",
+        help="Run template matching",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Input files for matching
+    tm_input_group = matching_parser.add_argument_group("Input Files")
+    tm_input_group.add_argument(
         "--tomograms",
         required=True,
         help="Glob pattern for tomogram files (e.g., '/data/tomograms/*.mrc')",
     )
-    input_group.add_argument(
+    tm_input_group.add_argument(
         "--metadata",
         required=True,
         help="Glob pattern for metadata files (e.g., '/data/metadata/*.xml')",
     )
-    input_group.add_argument(
-        "--masks", help="Glob pattern for mask files (e.g., '/data/masks/*.mrc')"
+    tm_input_group.add_argument(
+        "--masks", help="Glob pattern for target mask files (e.g., '/data/masks/*.mrc')"
    )
-    input_group.add_argument(
+    tm_input_group.add_argument(
         "--template", required=True, type=Path, help="Template file (MRC, PDB, etc.)"
     )
-    input_group.add_argument("--template-mask", type=Path, help="Template mask file")
-    input_group.add_argument(
+    tm_input_group.add_argument("--template-mask", type=Path, help="Template mask file")
+    tm_input_group.add_argument(
         "--tomo-list",
         type=Path,
         help="File with list of tomogram IDs to process (one per line)",
     )
 
-    tm_group = parser.add_argument_group("Template Matching")
+    # Template matching parameters
+    tm_group = matching_parser.add_argument_group("Template Matching")
     angular_group = tm_group.add_mutually_exclusive_group()
     angular_group.add_argument(
         "--angular-sampling", type=float, help="Angular sampling in degrees"
@@ -570,7 +845,8 @@ def parse_args():
         "--score-threshold", type=float, default=0.0, help="Minimum score threshold"
     )
 
-    scope_group = parser.add_argument_group("Microscope Parameters")
+    # Microscope parameters
+    scope_group = matching_parser.add_argument_group("Microscope Parameters")
     scope_group.add_argument(
         "--voltage", type=float, default=300.0, help="Acceleration voltage in kV"
     )
@@ -584,7 +860,8 @@ def parse_args():
         "--amplitude-contrast", type=float, default=0.07, help="Amplitude contrast"
     )
 
-    proc_group = parser.add_argument_group("Processing Options")
+    # Processing options
+    proc_group = matching_parser.add_argument_group("Processing Options")
     proc_group.add_argument(
         "--lowpass",
         type=float,
@@ -615,154 +892,331 @@ def parse_args():
         help="Scramble template phases for noise estimation",
     )
 
-    compute_group = parser.add_argument_group("Compute Resources")
-    compute_group.add_argument(
-        "--cpus", type=int, default=4, help="Number of CPUs per job"
+    _ = add_compute_resources(
+        matching_parser,
+        default_cpus=4,
+        default_memory=64,
+        include_gpu=True,
+        default_time="05:00:00",
+        default_partition="gpu-el8",
     )
-    compute_group.add_argument(
-        "--memory", type=int, default=64, help="Memory per job in GB"
+    _ = add_job_submission(matching_parser, "./matching_results")
+
+    analysis_parser = subparsers.add_parser(
+        "analysis",
+        help="Analyze template matching results",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
-    compute_group.add_argument(
-        "--gpu-count", type=int, default=1, help="Number of GPUs per job"
+
+    # Input files for analysis
+    analysis_input_group = analysis_parser.add_argument_group("Input Files")
+    analysis_input_group.add_argument(
+        "--input-file",
+        "--input-files",
+        required=True,
+        nargs="+",
+        help="Path to one or multiple runs of match_template.py.",
    )
-    compute_group.add_argument(
-        "--gpu-type", default="3090", help="GPU type constraint (e.g., '3090', 'A100')"
+    analysis_input_group.add_argument(
+        "--background-file",
+        "--background-files",
+        required=False,
+        nargs="+",
+        default=[],
+        help="Path to one or multiple runs of match_template.py for normalization. "
+        "For instance from --scramble_phases or a different template.",
    )
-    compute_group.add_argument(
-        "--time-limit", default="05:00:00", help="Time limit (HH:MM:SS)"
+    analysis_input_group.add_argument(
+        "--masks", help="Glob pattern for target mask files (e.g., '/data/masks/*.mrc')"
    )
-    compute_group.add_argument("--partition", default="gpu-el8", help="SLURM partition")
-
-    job_group = parser.add_argument_group("Job Submission")
-    job_group.add_argument(
-        "--output-dir",
+    analysis_input_group.add_argument(
+        "--tomo-list",
         type=Path,
-        default=Path("./batch_results"),
-        help="Output directory for results",
+        help="File with list of tomogram IDs to process (one per line)",
    )
-    job_group.add_argument(
-        "--script-dir",
-        type=Path,
-        default=Path("./slurm_scripts"),
-        help="Directory for generated SLURM scripts",
+
+    # Peak calling parameters
+    peak_group = analysis_parser.add_argument_group("Peak Calling")
+    peak_group.add_argument(
+        "--peak-caller",
+        choices=[
+            "PeakCallerSort",
+            "PeakCallerMaximumFilter",
+            "PeakCallerFast",
+            "PeakCallerRecursiveMasking",
+            "PeakCallerScipy",
+        ],
+        default="PeakCallerMaximumFilter",
+        help="Peak caller for local maxima identification",
    )
-    job_group.add_argument(
-        "--environment-setup",
-        default="module load pyTME",
-        help="Command(s) to set up pyTME environment",
+    peak_group.add_argument(
+        "--num-peaks",
+        type=int,
+        default=1000,
+        help="Maximum number of peaks to identify",
    )
-    job_group.add_argument(
-        "--dry-run", action="store_true", help="Generate scripts but do not submit jobs"
+    peak_group.add_argument(
+        "--min-score",
+        type=float,
+        default=None,
+        help="Minimum score from which peaks will be considered",
+    )
+    peak_group.add_argument(
+        "--max-score",
+        type=float,
+        default=None,
+        help="Maximum score until which peaks will be considered",
+    )
+    peak_group.add_argument(
+        "--min-distance", type=int, default=None, help="Minimum distance between peaks"
+    )
+    peak_group.add_argument(
+        "--min-boundary-distance",
+        type=int,
+        default=None,
+        help="Minimum distance of peaks to target edges",
+    )
+    peak_group.add_argument(
+        "--mask-edges",
+        action="store_true",
+        default=False,
+        help="Whether candidates should not be identified from scores that were "
+        "computed from padded densities. Superseded by min_boundary_distance.",
+    )
+    peak_group.add_argument(
+        "--n-false-positives",
+        type=int,
+        default=None,
+        help="Number of accepted false-positive picks to determine minimum score",
    )
-    job_group.add_argument("--force", action="store_true", help="Rerun completed jobs")
-    args = parser.parse_args()
 
+    # Output options
+    output_group = analysis_parser.add_argument_group("Output Options")
+    output_group.add_argument(
+        "--output-format",
+        choices=[
+            "orientations",
+            "relion4",
+            "relion5",
+            "alignment",
+            "extraction",
+            "average",
+            "pickle",
+        ],
+        default="relion4",
+        help="Output format for analysis results",
+    )
+    output_group.add_argument(
+        "--angles-clockwise",
+        action="store_true",
+        help="Report Euler angles in clockwise format expected by RELION",
+    )
+
+    advanced_group = analysis_parser.add_argument_group("Advanced Options")
+    advanced_group.add_argument(
+        "--extraction-box-size",
+        type=int,
+        default=None,
+        help="Box size for extracted subtomograms (for extraction output format)",
+    )
+
+    _ = add_compute_resources(
+        analysis_parser,
+        default_cpus=2,
+        default_memory=16,
+        include_gpu=False,
+        default_time="01:00:00",
+        default_partition="htc-el8",
+    )
+    _ = add_job_submission(analysis_parser, "./analysis_results")
+
+    args = parser.parse_args()
     if args.tomo_list is not None:
         with open(args.tomo_list, mode="r") as f:
             args.tomo_list = [line.strip() for line in f if line.strip()]
 
     args.output_dir = args.output_dir.absolute()
     args.script_dir = args.script_dir.absolute()
-
     return args
 
 
+def run_matching(args, resources):
+    discovery = TomoDatasetDiscovery(
+        mrc_pattern=args.tomograms,
+        metadata_pattern=args.metadata,
+        mask_pattern=args.masks,
+    )
+    files = discovery.discover(tomo_list=args.tomo_list)
+    print_block(
+        name="Discovering Dataset",
+        data={
+            "Tomogram Pattern": args.tomograms,
+            "Metadata Pattern": args.metadata,
+            "Mask Pattern": args.masks,
+            "Valid Runs": len(files),
+        },
+        label_width=30,
+    )
+    if not files:
+        print("No tomograms found! Check your patterns.")
+        return
+
+    params = TMParameters(
+        template=args.template,
+        template_mask=args.template_mask,
+        angular_sampling=args.angular_sampling,
+        particle_diameter=args.particle_diameter,
+        score=args.score,
+        score_threshold=args.score_threshold,
+        acceleration_voltage=args.voltage,
+        spherical_aberration=args.spherical_aberration * 1e7,  # mm to Ångstrom
+        amplitude_contrast=args.amplitude_contrast,
+        lowpass=args.lowpass,
+        highpass=args.highpass,
+        tilt_weighting=args.tilt_weighting,
+        backend=args.backend,
+        whiten_spectrum=args.whiten_spectrum,
+        scramble_phases=args.scramble_phases,
+    )
+    print_params = params.to_command_args(files[0], "")
+    _ = print_params.pop("target")
+    _ = print_params.pop("output")
+    print_params.update({k: True for k in params.get_flags()})
+    print_params = {
+        sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+    }
+    print_block(name="Matching Parameters", data=print_params, label_width=30)
+    print("\n" + "-" * 80)
+
+    tasks = []
+    for tomo_file in files:
+        task = TemplateMatchingTask(
+            files=tomo_file,
+            parameters=params,
+            resources=resources,
+            output_dir=args.output_dir,
+        )
+        tasks.append(task)
+
+    return tasks
+
+
+def run_analysis(args, resources):
+    discovery = AnalysisDatasetDiscovery(
+        input_patterns=args.input_file,
+        background_patterns=args.background_file,
+        mask_patterns=args.masks,
+    )
+    files = discovery.discover(tomo_list=args.tomo_list)
+    print_block(
+        name="Discovering Dataset",
+        data={
+            "Input Patterns": args.input_file,
+            "Background Patterns": args.background_file,
+            "Mask Pattern": args.masks,
+            "Valid Runs": len(files),
+        },
+        label_width=30,
+    )
+    if not files:
+        print("No TM results found! Check your patterns.")
+        return
+
+    params = AnalysisParameters(
+        peak_caller=args.peak_caller,
+        num_peaks=args.num_peaks,
+        min_score=args.min_score,
+        max_score=args.max_score,
+        min_distance=args.min_distance,
+        min_boundary_distance=args.min_boundary_distance,
+        mask_edges=args.mask_edges,
+        n_false_positives=args.n_false_positives,
+        output_format=args.output_format,
+        angles_clockwise=args.angles_clockwise,
+        extraction_box_size=args.extraction_box_size,
+    )
+    print_params = params.to_command_args(files[0], Path(""))
+    _ = print_params.pop("input-files", None)
+    _ = print_params.pop("background-files", None)
+    _ = print_params.pop("output-prefix", None)
+    print_params.update({k: True for k in params.get_flags()})
+    print_params = {
+        sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+    }
+    print_block(name="Analysis Parameters", data=print_params, label_width=30)
+    print("\n" + "-" * 80)
+
+    tasks = []
+    for file in files:
+        task = AnalysisTask(
+            files=file,
+            parameters=params,
+            resources=resources,
+            output_dir=args.output_dir,
+        )
+        tasks.append(task)
+
+    return tasks
+
+
 def main():
     print_entry()
 
     args = parse_args()
+
+    resources = ComputeResources(
+        cpus=args.cpus,
+        memory_gb=args.memory,
+        time_limit=args.time_limit,
+        partition=args.partition,
+        gpu_count=getattr(args, "gpu_count", 0),
+        gpu_type=getattr(args, "gpu_type", None),
+    )
+
+    func = run_matching
+    if args.command == "analysis":
+        func = run_analysis
+
     try:
-        discovery = TomoDatasetDiscovery(
-            mrc_pattern=args.tomograms,
-            metadata_pattern=args.metadata,
-            mask_pattern=args.masks,
-        )
-        tomo_files = discovery.discover_tomograms(tomo_list=args.tomo_list)
-        print_block(
-            name="Discovering Dataset",
-            data={
-                "Tomogram Pattern": args.tomograms,
-                "Metadata Pattern": args.metadata,
-                "Mask Pattern": args.masks,
-                "Valid Runs": len(tomo_files),
-            },
-            label_width=30,
-        )
-        if not tomo_files:
-            print("No tomograms found! Check your patterns.")
-            return
-
-        params = TMParameters(
-            template=args.template,
-            template_mask=args.template_mask,
-            angular_sampling=args.angular_sampling,
-            particle_diameter=args.particle_diameter,
-            score=args.score,
-            score_threshold=args.score_threshold,
-            acceleration_voltage=args.voltage * 1e3,  # keV to eV
-            spherical_aberration=args.spherical_aberration * 1e7,  # Convert mm to Å
-            amplitude_contrast=args.amplitude_contrast,
-            lowpass=args.lowpass,
-            highpass=args.highpass,
-            tilt_weighting=args.tilt_weighting,
-            backend=args.backend,
-            whiten_spectrum=args.whiten_spectrum,
-            scramble_phases=args.scramble_phases,
-        )
-        print_params = params.to_command_args(tomo_files[0], "")
-        _ = print_params.pop("target")
-        _ = print_params.pop("output")
-        print_params.update({k: True for k in params.get_flags()})
-        print_params = {
-            sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
-        }
-        print_block(name="Matching Parameters", data=print_params, label_width=30)
-        print("\n" + "-" * 80)
-
-        resources = ComputeResources(
-            cpus=args.cpus,
-            memory_gb=args.memory,
-            gpu_count=args.gpu_count,
-            gpu_type=args.gpu_type,
-            time_limit=args.time_limit,
-            partition=args.partition,
+        tasks = func(args, resources)
+    except Exception as e:
+        exit(f"Error: {e}")
+
+    if tasks is None:
+        exit(-1)
+
+    print_params = resources.to_slurm_args()
+    print_params = {
+        sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
+    }
+    print_block(name="Compute Resources", data=print_params, label_width=30)
+    print("\n" + "-" * 80 + "\n")
+
+    backend = SlurmBackend(
+        force=args.force,
+        dry_run=args.dry_run,
+        script_dir=args.script_dir,
+        environment_setup=args.environment_setup,
+    )
+    job_ids = backend.submit_jobs(tasks)
+    if args.dry_run:
+        print(
+            f"\nDry run complete. Generated {len(tasks)} scripts in {args.script_dir}"
         )
-        print_params = resources.to_slurm_args()
-        print_params = {
-            sanitize_name(k): print_params[k] for k in sorted(list(print_params.keys()))
-        }
-        print_block(name="Compute Resources", data=print_params, label_width=30)
-        print("\n" + "-" * 80 + "\n")
-
-        tasks = []
-        for tomo_file in tomo_files:
-            task = TemplateMatchingTask(
-                tomo_files=tomo_file,
-                parameters=params,
-                resources=resources,
-                output_dir=args.output_dir,
-            )
-            tasks.append(task)
+    return 0
 
-        backend = SlurmBackend(
-            force=args.force,
-            dry_run=args.dry_run,
-            script_dir=args.script_dir,
-            environment_setup=args.environment_setup,
-        )
-        job_ids = backend.submit_jobs(tasks)
-        if args.dry_run:
-            print(
-                f"\nDry run complete. Generated {len(tasks)} scripts in {args.script_dir}"
-            )
-        else:
-            successful_jobs = [j for j in job_ids if not j.startswith("ERROR")]
-            print(f"\nSubmitted {len(successful_jobs)} jobs successfully.")
-            if successful_jobs:
-                print(f"Job IDs:\n{','.join(successful_jobs).strip()}")
+    successful_jobs = [j for j in job_ids if not j.startswith("ERROR")]
+    print(f"\nSubmitted {len(successful_jobs)} jobs successfully.")
+    if successful_jobs:
+        print(f"Job IDs:\n{','.join(successful_jobs).strip()}")
 
-    except Exception as e:
-        print(f"Error: {e}")
+    if len(successful_jobs) == len(job_ids):
+        return 0
+
+    print("\nThe following issues arose during submission:")
+    for j in job_ids:
+        if j.startswith("ERROR"):
+            print(j)
 
 
 if __name__ == "__main__":