pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3.0__cp311-cp311-macosx_15_0_arm64.whl

This diff compares the contents of two publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those package versions as they appear in their respective public registries.
Files changed (75)
  1. pytme-0.3.0.data/scripts/estimate_memory_usage.py +76 -0
  2. pytme-0.3.0.data/scripts/match_template.py +1106 -0
  3. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/postprocess.py +320 -190
  4. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocess.py +21 -31
  5. {pytme-0.2.9.data → pytme-0.3.0.data}/scripts/preprocessor_gui.py +85 -19
  6. pytme-0.3.0.data/scripts/pytme_runner.py +771 -0
  7. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/METADATA +22 -20
  8. pytme-0.3.0.dist-info/RECORD +126 -0
  9. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/entry_points.txt +2 -1
  10. pytme-0.3.0.dist-info/licenses/LICENSE +339 -0
  11. scripts/estimate_memory_usage.py +76 -0
  12. scripts/eval.py +93 -0
  13. scripts/extract_candidates.py +224 -0
  14. scripts/match_template.py +349 -378
  15. pytme-0.2.9.data/scripts/match_template.py → scripts/match_template_filters.py +213 -148
  16. scripts/postprocess.py +320 -190
  17. scripts/preprocess.py +21 -31
  18. scripts/preprocessor_gui.py +85 -19
  19. scripts/pytme_runner.py +771 -0
  20. scripts/refine_matches.py +625 -0
  21. tests/preprocessing/test_frequency_filters.py +28 -14
  22. tests/test_analyzer.py +41 -36
  23. tests/test_backends.py +1 -0
  24. tests/test_matching_cli.py +109 -53
  25. tests/test_matching_data.py +5 -5
  26. tests/test_matching_exhaustive.py +1 -2
  27. tests/test_matching_optimization.py +4 -9
  28. tests/test_matching_utils.py +1 -1
  29. tests/test_orientations.py +0 -1
  30. tme/__version__.py +1 -1
  31. tme/analyzer/__init__.py +2 -0
  32. tme/analyzer/_utils.py +26 -21
  33. tme/analyzer/aggregation.py +396 -222
  34. tme/analyzer/base.py +127 -0
  35. tme/analyzer/peaks.py +189 -201
  36. tme/analyzer/proxy.py +123 -0
  37. tme/backends/__init__.py +4 -3
  38. tme/backends/_cupy_utils.py +25 -24
  39. tme/backends/_jax_utils.py +20 -18
  40. tme/backends/cupy_backend.py +13 -26
  41. tme/backends/jax_backend.py +24 -23
  42. tme/backends/matching_backend.py +4 -3
  43. tme/backends/mlx_backend.py +4 -3
  44. tme/backends/npfftw_backend.py +34 -30
  45. tme/backends/pytorch_backend.py +18 -4
  46. tme/cli.py +126 -0
  47. tme/density.py +9 -7
  48. tme/extensions.cpython-311-darwin.so +0 -0
  49. tme/filters/__init__.py +3 -3
  50. tme/filters/_utils.py +36 -10
  51. tme/filters/bandpass.py +229 -188
  52. tme/filters/compose.py +5 -4
  53. tme/filters/ctf.py +516 -254
  54. tme/filters/reconstruction.py +91 -32
  55. tme/filters/wedge.py +196 -135
  56. tme/filters/whitening.py +37 -42
  57. tme/matching_data.py +28 -39
  58. tme/matching_exhaustive.py +31 -27
  59. tme/matching_optimization.py +5 -4
  60. tme/matching_scores.py +25 -15
  61. tme/matching_utils.py +158 -28
  62. tme/memory.py +4 -3
  63. tme/orientations.py +22 -9
  64. tme/parser.py +114 -33
  65. tme/preprocessor.py +6 -5
  66. tme/rotations.py +10 -7
  67. tme/structure.py +4 -3
  68. pytme-0.2.9.data/scripts/estimate_ram_usage.py +0 -97
  69. pytme-0.2.9.dist-info/RECORD +0 -119
  70. pytme-0.2.9.dist-info/licenses/LICENSE +0 -153
  71. scripts/estimate_ram_usage.py +0 -97
  72. tests/data/Maps/.DS_Store +0 -0
  73. tests/data/Structures/.DS_Store +0 -0
  74. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/WHEEL +0 -0
  75. {pytme-0.2.9.dist-info → pytme-0.3.0.dist-info}/top_level.txt +0 -0
scripts/postprocess.py CHANGED
@@ -1,22 +1,23 @@
1
1
  #!python3
2
- """ CLI to simplify analysing the output of match_template.py.
2
+ """CLI to simplify analysing the output of match_template.py.
3
3
 
4
- Copyright (c) 2023 European Molecular Biology Laboratory
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
5
 
6
- Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
  import argparse
9
9
  from sys import exit
10
10
  from os import getcwd
11
- from typing import List, Tuple
12
- from os.path import join, abspath, splitext
11
+ from typing import Tuple, List
12
+ from os.path import join, splitext, basename
13
13
 
14
14
  import numpy as np
15
15
  from numpy.typing import NDArray
16
16
  from scipy.special import erfcinv
17
17
 
18
18
  from tme import Density, Structure, Orientations
19
- from tme.matching_utils import load_pickle, centered_mask
19
+ from tme.cli import sanitize_name, print_block, print_entry
20
+ from tme.matching_utils import load_pickle, centered_mask, write_pickle
20
21
  from tme.matching_optimization import create_score_object, optimize_match
21
22
  from tme.rotations import euler_to_rotationmatrix, euler_from_rotationmatrix
22
23
  from tme.analyzer import (
@@ -38,7 +39,10 @@ PEAK_CALLERS = {
38
39
 
39
40
 
40
41
  def parse_args():
41
- parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")
42
+ parser = argparse.ArgumentParser(
43
+ description="Analyze template matching outputs",
44
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
45
+ )
42
46
 
43
47
  input_group = parser.add_argument_group("Input")
44
48
  output_group = parser.add_argument_group("Output")
@@ -46,20 +50,23 @@ def parse_args():
46
50
  additional_group = parser.add_argument_group("Additional Parameters")
47
51
 
48
52
  input_group.add_argument(
49
- "--input_file",
53
+ "--input-file",
54
+ "--input-files",
50
55
  required=True,
51
56
  nargs="+",
52
- help="Path to the output of match_template.py.",
57
+ help="Path to one or multiple runs of match_template.py.",
53
58
  )
54
59
  input_group.add_argument(
55
- "--background_file",
60
+ "--background-file",
61
+ "--background-files",
56
62
  required=False,
57
63
  nargs="+",
58
- help="Path to an output of match_template.py used for normalization. "
64
+ default=[],
65
+ help="Path to one or multiple runs of match_template.py for normalization. "
59
66
  "For instance from --scramble_phases or a different template.",
60
67
  )
61
68
  input_group.add_argument(
62
- "--target_mask",
69
+ "--target-mask",
63
70
  required=False,
64
71
  type=str,
65
72
  help="Path to an optional mask applied to template matching scores.",
@@ -74,12 +81,14 @@ def parse_args():
74
81
  )
75
82
 
76
83
  output_group.add_argument(
77
- "--output_prefix",
78
- required=True,
79
- help="Output filename, extension will be added based on output_format.",
84
+ "--output-prefix",
85
+ required=False,
86
+ default=None,
87
+ help="Output prefix. Defaults to basename of first input. Extension is "
88
+ "added with respect to chosen output format.",
80
89
  )
81
90
  output_group.add_argument(
82
- "--output_format",
91
+ "--output-format",
83
92
  choices=[
84
93
  "orientations",
85
94
  "relion4",
@@ -87,6 +96,7 @@ def parse_args():
87
96
  "alignment",
88
97
  "extraction",
89
98
  "average",
99
+ "pickle",
90
100
  ],
91
101
  default="orientations",
92
102
  help="Available output formats: "
@@ -95,48 +105,49 @@ def parse_args():
95
105
  "relion5 (RELION 5 star format), "
96
106
  "alignment (aligned template to target based on orientations), "
97
107
  "extraction (extract regions around peaks from targets, i.e. subtomograms), "
98
- "average (extract matched regions from target and average them).",
108
+ "average (extract matched regions from target and average them)."
109
+ "pickle (results of applying mask and background correction for inspection).",
99
110
  )
100
111
 
101
112
  peak_group.add_argument(
102
- "--peak_caller",
113
+ "--peak-caller",
103
114
  choices=list(PEAK_CALLERS.keys()),
104
- default="PeakCallerScipy",
115
+ default="PeakCallerMaximumFilter",
105
116
  help="Peak caller for local maxima identification.",
106
117
  )
107
118
  peak_group.add_argument(
108
- "--minimum_score",
119
+ "--min-score",
109
120
  type=float,
110
- default=None,
121
+ default=0.0,
111
122
  help="Minimum score from which peaks will be considered.",
112
123
  )
113
124
  peak_group.add_argument(
114
- "--maximum_score",
125
+ "--max-score",
115
126
  type=float,
116
127
  default=None,
117
128
  help="Maximum score until which peaks will be considered.",
118
129
  )
119
130
  peak_group.add_argument(
120
- "--min_distance",
131
+ "--min-distance",
121
132
  type=int,
122
133
  default=5,
123
134
  help="Minimum distance between peaks.",
124
135
  )
125
136
  peak_group.add_argument(
126
- "--min_boundary_distance",
137
+ "--min-boundary-distance",
127
138
  type=int,
128
139
  default=0,
129
140
  help="Minimum distance of peaks to target edges.",
130
141
  )
131
142
  peak_group.add_argument(
132
- "--mask_edges",
143
+ "--mask-edges",
133
144
  action="store_true",
134
145
  default=False,
135
146
  help="Whether candidates should not be identified from scores that were "
136
147
  "computed from padded densities. Superseded by min_boundary_distance.",
137
148
  )
138
149
  peak_group.add_argument(
139
- "--num_peaks",
150
+ "--num-peaks",
140
151
  type=int,
141
152
  default=1000,
142
153
  required=False,
@@ -144,7 +155,7 @@ def parse_args():
144
155
  "If minimum_score is provided all peaks scoring higher will be reported.",
145
156
  )
146
157
  peak_group.add_argument(
147
- "--peak_oversampling",
158
+ "--peak-oversampling",
148
159
  type=int,
149
160
  default=1,
150
161
  help="1 / factor equals voxel precision, e.g. 2 detects half voxel "
@@ -152,34 +163,33 @@ def parse_args():
152
163
  )
153
164
 
154
165
  additional_group.add_argument(
155
- "--subtomogram_box_size",
166
+ "--extraction-box-size",
156
167
  type=int,
157
168
  default=None,
158
- help="Subtomogram box size, by default equal to the centered template. Will be "
159
- "padded to even values if output_format is relion.",
169
+ help="Box size of extracted subtomograms, defaults to the centered template.",
160
170
  )
161
171
  additional_group.add_argument(
162
- "--mask_subtomograms",
172
+ "--mask-subtomograms",
163
173
  action="store_true",
164
174
  default=False,
165
175
  help="Whether to mask subtomograms using the template mask. The mask will be "
166
176
  "rotated according to determined angles.",
167
177
  )
168
178
  additional_group.add_argument(
169
- "--invert_target_contrast",
179
+ "--invert-target-contrast",
170
180
  action="store_true",
171
181
  default=False,
172
182
  help="Whether to invert the target contrast.",
173
183
  )
174
184
  additional_group.add_argument(
175
- "--n_false_positives",
185
+ "--n-false-positives",
176
186
  type=int,
177
187
  default=None,
178
188
  required=False,
179
189
  help="Number of accepted false-positives picks to determine minimum score.",
180
190
  )
181
191
  additional_group.add_argument(
182
- "--local_optimization",
192
+ "--local-optimization",
183
193
  action="store_true",
184
194
  required=False,
185
195
  help="[Experimental] Perform local optimization of candidates. Useful when the "
@@ -188,21 +198,12 @@ def parse_args():
188
198
 
189
199
  args = parser.parse_args()
190
200
 
191
- if args.output_format == "relion" and args.subtomogram_box_size is not None:
192
- args.subtomogram_box_size += args.subtomogram_box_size % 2
201
+ if args.output_prefix is None:
202
+ args.output_prefix = splitext(basename(args.input_file[0]))[0]
193
203
 
194
204
  if args.orientations is not None:
195
205
  args.orientations = Orientations.from_file(filename=args.orientations)
196
206
 
197
- if args.background_file is None:
198
- args.background_file = [None]
199
- if len(args.background_file) == 1:
200
- args.background_file = args.background_file * len(args.input_file)
201
- elif len(args.background_file) not in (0, len(args.input_file)):
202
- raise ValueError(
203
- "--background_file needs to be specified once or for each --input_file."
204
- )
205
-
206
207
  return args
207
208
 
208
209
 
@@ -210,7 +211,6 @@ def load_template(
210
211
  filepath: str,
211
212
  sampling_rate: NDArray,
212
213
  centering: bool = True,
213
- target_shape: Tuple[int] = None,
214
214
  ):
215
215
  try:
216
216
  template = Density.from_file(filepath)
@@ -230,80 +230,228 @@ def load_template(
230
230
  return template, center, translation, template_is_density
231
231
 
232
232
 
233
- def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
234
- if len(foreground_paths) == 0:
235
- return data, 1
236
-
233
+ def load_matching_output(path: str) -> List:
234
+ data = load_pickle(path)
237
235
  if data[0].ndim != data[2].ndim:
238
- return data, 1
236
+ data = _peaks_to_volume(data)
237
+ return list(data)
239
238
 
240
- from tme.matching_exhaustive import normalize_under_mask
241
239
 
242
- def _norm_scores(data, args):
243
- target_origin, _, sampling_rate, cli_args = data[-1]
240
+ def _peaks_to_volume(data):
241
+ # Emulate the output of analyzer aggregators
242
+ translations = data[0].astype(int)
243
+ keep = (translations < 0).sum(axis=1) == 0
244
244
 
245
- _, template_extension = splitext(cli_args.template)
246
- ret = load_template(
247
- filepath=cli_args.template,
248
- sampling_rate=sampling_rate,
249
- centering=not cli_args.no_centering,
250
- )
251
- template, center_of_mass, translation, template_is_density = ret
245
+ translations = translations[keep]
246
+ rotations = data[1][keep]
252
247
 
253
- if args.mask_edges and args.min_boundary_distance == 0:
254
- max_shape = np.max(template.shape)
255
- args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))
248
+ unique_rotations, rotation_map = np.unique(rotations, axis=0, return_inverse=True)
249
+ rotation_mapping = {
250
+ i: unique_rotations[i] for i in range(unique_rotations.shape[0])
251
+ }
256
252
 
257
- target_mask = 1
258
- if args.target_mask is not None:
259
- target_mask = Density.from_file(args.target_mask).data
260
- elif cli_args.target_mask is not None:
261
- target_mask = Density.from_file(args.target_mask).data
253
+ out_shape = np.max(translations, axis=0) + 1
262
254
 
263
- mask = np.ones_like(data[0])
264
- np.multiply(mask, target_mask, out=mask)
255
+ scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
256
+ scores_out[tuple(translations.T)] = data[2][keep]
265
257
 
266
- cropped_shape = np.subtract(
267
- mask.shape, np.multiply(args.min_boundary_distance, 2)
268
- ).astype(int)
269
- mask[cropped_shape] = 0
270
- normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
271
- return data[0]
272
-
273
- entities = np.zeros_like(data[0])
274
- data[0] = _norm_scores(data=data, args=args)
275
- for index, filepath in enumerate(foreground_paths):
276
- new_scores = _norm_scores(
277
- data=load_match_template_output(filepath, background_paths[index]),
278
- args=args,
279
- )
280
- indices = new_scores > data[0]
281
- entities[indices] = index + 1
282
- data[0][indices] = new_scores[indices]
258
+ rotations_out = np.full(out_shape, fill_value=-1, dtype=np.int32)
259
+ rotations_out[tuple(translations.T)] = rotation_map
283
260
 
284
- return data, entities
261
+ offset = np.zeros((scores_out.ndim), dtype=int)
262
+ return (scores_out, offset, rotations_out, rotation_mapping, data[-1])
263
+
264
+
265
+ def prepare_pickle_merge(paths):
266
+ new_rotation_mapping, out_shape = {}, None
267
+ for path in paths:
268
+ data = load_matching_output(path)
269
+ scores, _, rotations, rotation_mapping, *_ = data
270
+ if np.allclose(scores.shape, 0):
271
+ continue
272
+
273
+ if out_shape is None:
274
+ out_shape = scores.shape
275
+
276
+ if out_shape is not None and not np.allclose(out_shape, scores.shape):
277
+ print(
278
+ f"\nScore spaces have different sizes {out_shape} and {scores.shape}. "
279
+ "Assuming that both boxes are aligned at the origin, but please "
280
+ "make sure this is intentional."
281
+ )
282
+ out_shape = np.maximum(out_shape, scores.shape)
283
+
284
+ for key, value in rotation_mapping.items():
285
+ if key not in new_rotation_mapping:
286
+ new_rotation_mapping[key] = len(new_rotation_mapping)
287
+
288
+ return new_rotation_mapping, out_shape
289
+
290
+
291
+ def simple_stats(arr, decimals=3):
292
+ return {
293
+ "mean": round(float(arr.mean()), decimals),
294
+ "std": round(float(arr.std()), decimals),
295
+ "max": round(float(arr.max()), decimals),
296
+ }
297
+
298
+
299
+ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
300
+ # Determine output array shape and create consistent rotation map
301
+ new_rotation_mapping, out_shape = prepare_pickle_merge(foregrounds)
302
+
303
+ if out_shape is None:
304
+ exit("No valid score spaces found. Check messages aboves.")
305
+
306
+ print("\nFinished conversion - Now aggregating over entities.")
307
+ entities = np.full(out_shape, fill_value=-1, dtype=np.int32)
308
+ scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
309
+ rotations_out = np.full(out_shape, fill_value=-1, dtype=np.int32)
310
+
311
+ # We reload to avoid potential memory bottlenecks
312
+ for entity_index, foreground in enumerate(foregrounds):
313
+ data = load_matching_output(foreground)
314
+ scores, _, rotations, rotation_mapping, *_ = data
315
+
316
+ # We could normalize to unit sdev, but that might lead to unexpected
317
+ # results for flat background distributions
318
+ scores -= scores.mean()
319
+ indices = tuple(slice(0, x) for x in scores.shape)
320
+
321
+ indices_update = scores > scores_out[indices]
322
+ scores_out[indices][indices_update] = scores[indices_update]
323
+
324
+ lookup_table = np.arange(len(rotation_mapping) + 1, dtype=rotations_out.dtype)
325
+
326
+ # Maps rotation matrix to rotation index in rotations array
327
+ for key, _ in rotation_mapping.items():
328
+ lookup_table[key] = new_rotation_mapping[key]
329
+
330
+ updated_rotations = rotations[indices_update].astype(int)
331
+ if len(updated_rotations):
332
+ rotations_out[indices][indices_update] = lookup_table[updated_rotations]
333
+
334
+ entities[indices][indices_update] = entity_index
335
+
336
+ data = list(data)
337
+ data[0] = scores_out
338
+ data[2] = rotations_out
285
339
 
340
+ fg = simple_stats(data[0])
341
+ print(f"> Foreground {', '.join(str(k) + ' ' + str(v) for k, v in fg.items())}.")
286
342
 
287
- def load_match_template_output(foreground_path, background_path):
288
- data = load_pickle(foreground_path)
289
- if background_path is not None:
290
- data_background = load_pickle(background_path)
291
- data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
292
- np.fmax(data[0], 0, out=data[0])
293
- return data
343
+ if not len(backgrounds):
344
+ print("\nScore statistics per entity")
345
+ for i in range(len(foregrounds)):
346
+ mask = entities == i
347
+ avg = "No occurences"
348
+ if mask.sum() != 0:
349
+ fg = simple_stats(data[0][mask])
350
+ avg = ", ".join(str(k) + " " + str(v) for k, v in fg.items())
351
+ print(f"> Entity {i}: {avg}.")
352
+ return data, entities
353
+
354
+ print("\nComputing and applying background correction.")
355
+ _, out_shape_norm = prepare_pickle_merge(backgrounds)
356
+
357
+ if not np.allclose(out_shape, out_shape_norm):
358
+ print(
359
+ f"Foreground and background have different sizes {out_shape} and "
360
+ f"{out_shape_norm}. Assuming that boxes are aligned at the origin and "
361
+ "dropping scores beyond, but make sure this is intentional."
362
+ )
363
+
364
+ scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
365
+ for background in backgrounds:
366
+ data_norm = load_matching_output(background)
367
+
368
+ scores = data_norm[0]
369
+ scores -= scores.mean()
370
+
371
+ indices = tuple(slice(0, x) for x in scores.shape)
372
+ indices_update = scores > scores_norm[indices]
373
+ scores_norm[indices][indices_update] = scores[indices_update]
374
+
375
+ # Set translations to zero that do not have background distribution
376
+ update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
377
+ scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
378
+ scores_out[update] = data[0][update] - scores_norm[update]
379
+ # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
380
+ scores_out = np.fmax(scores_out, 0, out=scores_out)
381
+ data[0] = scores_out
382
+
383
+ fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
384
+ print(f"> Background {', '.join(str(k) + ' ' + str(v) for k, v in bg.items())}.")
385
+ print(f"> Normalized {', '.join(str(k) + ' ' + str(v) for k, v in fg.items())}.")
386
+
387
+ print("\nScore statistics per entity")
388
+ for i in range(len(foregrounds)):
389
+ mask = entities == i
390
+ avg = "No occurences"
391
+ if mask.sum() != 0:
392
+ fg = simple_stats(data[0][mask])
393
+ avg = ", ".join(str(k) + " " + str(v) for k, v in fg.items())
394
+ print(f"> Entity {i}: {avg}.")
395
+
396
+ return data, entities
294
397
 
295
398
 
296
399
  def main():
297
400
  args = parse_args()
298
- data = load_match_template_output(args.input_file[0], args.background_file[0])
401
+ print_entry()
402
+
403
+ cli_kwargs = {
404
+ key: value
405
+ for key, value in sorted(vars(args).items())
406
+ if value is not None and key not in ("input_file", "background_file")
407
+ }
408
+ print_block(
409
+ name="Parameters",
410
+ data={sanitize_name(k): v for k, v in cli_kwargs.items()},
411
+ label_width=25,
412
+ )
413
+ print("\n" + "-" * 80)
414
+
415
+ print_block(
416
+ name=sanitize_name("Foreground entities"),
417
+ data={i: k for i, k in enumerate(args.input_file)},
418
+ label_width=25,
419
+ )
420
+
421
+ if len(args.background_file):
422
+ print_block(
423
+ name=sanitize_name("Background entities"),
424
+ data={i: k for i, k in enumerate(args.background_file)},
425
+ label_width=25,
426
+ )
427
+
428
+ data, entities = normalize_input(args.input_file, args.background_file)
429
+
430
+ if args.output_format == "pickle":
431
+ write_pickle(data, f"{args.output_prefix}.pickle")
432
+ exit(0)
433
+
434
+ if args.target_mask:
435
+ target_mask = Density.from_file(args.target_mask, use_memmap=True).data
436
+ if target_mask.shape != data[0].shape:
437
+ print(
438
+ f"Shape of target mask and scores do not match {target_mask} "
439
+ f"{data[0].shape}. Skipping mask application"
440
+ )
441
+ else:
442
+ np.multiply(data[0], target_mask, out=data[0])
299
443
 
300
444
  target_origin, _, sampling_rate, cli_args = data[-1]
301
445
 
446
+ # Backwards compatibility with pre v0.3.0b
447
+ if hasattr(cli_args, "no_centering"):
448
+ cli_args.centering = not cli_args.no_centering
449
+
302
450
  _, template_extension = splitext(cli_args.template)
303
451
  ret = load_template(
304
452
  filepath=cli_args.template,
305
453
  sampling_rate=sampling_rate,
306
- centering=not cli_args.no_centering,
454
+ centering=cli_args.centering,
307
455
  )
308
456
  template, center_of_mass, translation, template_is_density = ret
309
457
 
@@ -327,96 +475,79 @@ def main():
327
475
  max_shape = np.max(template.shape)
328
476
  args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))
329
477
 
330
- entities = None
331
- if len(args.input_file) > 1:
332
- data, entities = merge_outputs(
333
- data=data,
334
- foreground_paths=args.input_file,
335
- background_paths=args.background_file,
336
- args=args,
337
- )
338
-
339
478
  orientations = args.orientations
340
479
  if orientations is None:
341
480
  translations, rotations, scores, details = [], [], [], []
342
- # Output is MaxScoreOverRotations
343
- if data[0].ndim == data[2].ndim:
344
- scores, offset, rotation_array, rotation_mapping, meta = data
345
-
346
- if args.target_mask is not None:
347
- target_mask = Density.from_file(args.target_mask)
348
- scores = scores * target_mask.data
349
-
350
- cropped_shape = np.subtract(
351
- scores.shape, np.multiply(args.min_boundary_distance, 2)
352
- ).astype(int)
353
-
354
- if args.min_boundary_distance > 0:
355
- scores = centered_mask(scores, new_shape=cropped_shape)
356
-
357
- if args.n_false_positives is not None:
358
- # Rickgauer et al. 2017
359
- cropped_slice = tuple(
360
- slice(
361
- int(args.min_boundary_distance),
362
- int(x - args.min_boundary_distance),
363
- )
364
- for x in scores.shape
365
- )
366
- args.n_false_positives = max(args.n_false_positives, 1)
367
- n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
368
- minimum_score = np.multiply(
369
- erfcinv(2 * args.n_false_positives / n_correlations),
370
- np.sqrt(2) * np.std(scores[cropped_slice]),
481
+
482
+ # Data processed by normalize_input is guaranteed to have this shape
483
+ scores, offset, rotation_array, rotation_mapping, meta = data
484
+
485
+ cropped_shape = np.subtract(
486
+ scores.shape, np.multiply(args.min_boundary_distance, 2)
487
+ ).astype(int)
488
+
489
+ if args.min_boundary_distance > 0:
490
+ scores = centered_mask(scores, new_shape=cropped_shape)
491
+
492
+ if args.n_false_positives is not None:
493
+ # Rickgauer et al. 2017
494
+ cropped_slice = tuple(
495
+ slice(
496
+ int(args.min_boundary_distance),
497
+ int(x - args.min_boundary_distance),
371
498
  )
372
- print(f"Determined minimum score cutoff: {minimum_score}.")
373
- minimum_score = max(minimum_score, 0)
374
- args.minimum_score = minimum_score
375
-
376
- args.batch_dims = None
377
- if hasattr(cli_args, "target_batch"):
378
- args.batch_dims = cli_args.target_batch
379
-
380
- peak_caller_kwargs = {
381
- "shape": scores.shape,
382
- "num_peaks": args.num_peaks,
383
- "min_distance": args.min_distance,
384
- "min_boundary_distance": args.min_boundary_distance,
385
- "batch_dims": args.batch_dims,
386
- "minimum_score": args.minimum_score,
387
- "maximum_score": args.maximum_score,
388
- }
389
-
390
- peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
391
- peak_caller(
392
- scores,
393
- rotation_matrix=np.eye(template.data.ndim),
394
- mask=template.data,
395
- rotation_mapping=rotation_mapping,
396
- rotations=rotation_array,
499
+ for x in scores.shape
397
500
  )
398
- candidates = peak_caller.merge(
399
- candidates=[tuple(peak_caller)], **peak_caller_kwargs
501
+ args.n_false_positives = max(args.n_false_positives, 1)
502
+ n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
503
+ minimum_score = np.multiply(
504
+ erfcinv(2 * args.n_false_positives / n_correlations),
505
+ np.sqrt(2) * np.std(scores[cropped_slice]),
400
506
  )
401
- if len(candidates) == 0:
402
- candidates = [[], [], [], []]
403
- print("Found no peaks, consider changing peak calling parameters.")
404
- exit(-1)
405
-
406
- for translation, _, score, detail in zip(*candidates):
407
- rotation_index = rotation_array[tuple(translation)]
408
- rotation = rotation_mapping.get(
409
- rotation_index, np.zeros(template.data.ndim, int)
410
- )
411
- if rotation.ndim == 2:
412
- rotation = euler_from_rotationmatrix(rotation)
413
- rotations.append(rotation)
507
+ print(f"Determined minimum score cutoff: {minimum_score}.")
508
+ minimum_score = max(minimum_score, 0)
509
+ args.min_score = minimum_score
510
+
511
+ args.batch_dims = None
512
+ if hasattr(cli_args, "target_batch"):
513
+ args.batch_dims = cli_args.target_batch
514
+
515
+ peak_caller_kwargs = {
516
+ "shape": scores.shape,
517
+ "num_peaks": args.num_peaks,
518
+ "min_distance": args.min_distance,
519
+ "min_boundary_distance": args.min_boundary_distance,
520
+ "batch_dims": args.batch_dims,
521
+ "minimum_score": args.min_score,
522
+ "maximum_score": args.max_score,
523
+ }
524
+
525
+ peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
526
+ state = peak_caller.init_state()
527
+ state = peak_caller(
528
+ state,
529
+ scores,
530
+ rotation_matrix=np.eye(template.data.ndim),
531
+ mask=template_mask.data,
532
+ rotation_mapping=rotation_mapping,
533
+ rotations=rotation_array,
534
+ )
535
+ candidates = peak_caller.merge(
536
+ results=[peak_caller.result(state)], **peak_caller_kwargs
537
+ )
538
+ if len(candidates) == 0:
539
+ candidates = [[], [], [], []]
540
+ print("Found no peaks, consider changing peak calling parameters.")
541
+ exit(-1)
414
542
 
415
- else:
416
- candidates = data
417
- translation, rotation, *_ = data
418
- for i in range(translation.shape[0]):
419
- rotations.append(euler_from_rotationmatrix(rotation[i]))
543
+ for translation, _, score, detail in zip(*candidates):
544
+ rotation_index = rotation_array[tuple(translation)]
545
+ rotation = rotation_mapping.get(
546
+ rotation_index, np.zeros(template.data.ndim, int)
547
+ )
548
+ if rotation.ndim == 2:
549
+ rotation = euler_from_rotationmatrix(rotation)
550
+ rotations.append(rotation)
420
551
 
421
552
  if len(rotations):
422
553
  rotations = np.vstack(rotations).astype(float)
@@ -432,12 +563,12 @@ def main():
432
563
  details=details,
433
564
  )
434
565
 
435
- if args.minimum_score is not None and len(orientations.scores):
436
- keep = orientations.scores >= args.minimum_score
566
+ if args.min_score is not None and len(orientations.scores):
567
+ keep = orientations.scores >= args.min_score
437
568
  orientations = orientations[keep]
438
569
 
439
- if args.maximum_score is not None and len(orientations.scores):
440
- keep = orientations.scores <= args.maximum_score
570
+ if args.max_score is not None and len(orientations.scores):
571
+ keep = orientations.scores <= args.max_score
441
572
  orientations = orientations[keep]
442
573
 
443
574
  if args.peak_oversampling > 1:
@@ -502,7 +633,7 @@ def main():
502
633
  orientations.to_file(
503
634
  filename=f"{args.output_prefix}.{extension}",
504
635
  file_format=file_format,
505
- source_path=cli_args.target,
636
+ source_path=basename(cli_args.target),
506
637
  version=version,
507
638
  )
508
639
  exit(0)
@@ -520,9 +651,9 @@ def main():
520
651
  )
521
652
 
522
653
  extraction_shape = template.shape
523
- if args.subtomogram_box_size is not None:
654
+ if args.extraction_box_size is not None:
524
655
  extraction_shape = np.repeat(
525
- args.subtomogram_box_size, len(extraction_shape)
656
+ args.extraction_box_size, len(extraction_shape)
526
657
  )
527
658
 
528
659
  orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
@@ -587,8 +718,7 @@ def main():
587
718
  template, center, *_ = load_template(
588
719
  filepath=cli_args.template,
589
720
  sampling_rate=sampling_rate,
590
- centering=not cli_args.no_centering,
591
- target_shape=target.shape,
721
+ centering=cli_args.centering,
592
722
  )
593
723
 
594
724
  for index, (translation, angles, *_) in enumerate(orientations):