pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl → 0.3b0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py → pytme-0.3b0.data/scripts/estimate_memory_usage.py +16 -33
  2. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/match_template.py +224 -223
  3. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/postprocess.py +283 -163
  4. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocess.py +11 -8
  5. {pytme-0.2.9.data → pytme-0.3b0.data}/scripts/preprocessor_gui.py +10 -9
  6. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/METADATA +11 -9
  7. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/RECORD +61 -58
  8. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/entry_points.txt +1 -1
  9. scripts/{estimate_ram_usage.py → estimate_memory_usage.py} +16 -33
  10. scripts/extract_candidates.py +224 -0
  11. scripts/match_template.py +224 -223
  12. scripts/postprocess.py +283 -163
  13. scripts/preprocess.py +11 -8
  14. scripts/preprocessor_gui.py +10 -9
  15. scripts/refine_matches.py +626 -0
  16. tests/preprocessing/test_frequency_filters.py +9 -4
  17. tests/test_analyzer.py +143 -138
  18. tests/test_matching_cli.py +85 -29
  19. tests/test_matching_exhaustive.py +1 -2
  20. tests/test_matching_optimization.py +4 -9
  21. tests/test_orientations.py +0 -1
  22. tme/__version__.py +1 -1
  23. tme/analyzer/__init__.py +2 -0
  24. tme/analyzer/_utils.py +25 -17
  25. tme/analyzer/aggregation.py +385 -220
  26. tme/analyzer/base.py +138 -0
  27. tme/analyzer/peaks.py +150 -88
  28. tme/analyzer/proxy.py +122 -0
  29. tme/backends/__init__.py +4 -3
  30. tme/backends/_cupy_utils.py +25 -24
  31. tme/backends/_jax_utils.py +4 -3
  32. tme/backends/cupy_backend.py +4 -13
  33. tme/backends/jax_backend.py +6 -8
  34. tme/backends/matching_backend.py +4 -3
  35. tme/backends/mlx_backend.py +4 -3
  36. tme/backends/npfftw_backend.py +7 -5
  37. tme/backends/pytorch_backend.py +14 -4
  38. tme/cli.py +126 -0
  39. tme/density.py +4 -3
  40. tme/filters/__init__.py +1 -1
  41. tme/filters/_utils.py +4 -3
  42. tme/filters/bandpass.py +6 -4
  43. tme/filters/compose.py +5 -4
  44. tme/filters/ctf.py +426 -214
  45. tme/filters/reconstruction.py +58 -28
  46. tme/filters/wedge.py +139 -61
  47. tme/filters/whitening.py +36 -36
  48. tme/matching_data.py +4 -3
  49. tme/matching_exhaustive.py +17 -16
  50. tme/matching_optimization.py +5 -4
  51. tme/matching_scores.py +4 -3
  52. tme/matching_utils.py +6 -4
  53. tme/memory.py +4 -3
  54. tme/orientations.py +9 -6
  55. tme/parser.py +5 -4
  56. tme/preprocessor.py +4 -3
  57. tme/rotations.py +10 -7
  58. tme/structure.py +4 -3
  59. tests/data/Maps/.DS_Store +0 -0
  60. tests/data/Structures/.DS_Store +0 -0
  61. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/WHEEL +0 -0
  62. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/licenses/LICENSE +0 -0
  63. {pytme-0.2.9.dist-info → pytme-0.3b0.dist-info}/top_level.txt +0 -0
scripts/postprocess.py
@@ -1,22 +1,23 @@
 #!python
-""" CLI to simplify analysing the output of match_template.py.
+"""CLI to simplify analysing the output of match_template.py.

-    Copyright (c) 2023 European Molecular Biology Laboratory
+Copyright (c) 2023 European Molecular Biology Laboratory

-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 import argparse
 from sys import exit
 from os import getcwd
-from typing import List, Tuple
-from os.path import join, abspath, splitext
+from typing import Tuple, List
+from os.path import join, splitext

 import numpy as np
 from numpy.typing import NDArray
 from scipy.special import erfcinv

 from tme import Density, Structure, Orientations
-from tme.matching_utils import load_pickle, centered_mask
+from tme.cli import sanitize_name, print_block, print_entry
+from tme.matching_utils import load_pickle, centered_mask, write_pickle
 from tme.matching_optimization import create_score_object, optimize_match
 from tme.rotations import euler_to_rotationmatrix, euler_from_rotationmatrix
 from tme.analyzer import (
@@ -38,7 +39,10 @@ PEAK_CALLERS = {


 def parse_args():
-    parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")
+    parser = argparse.ArgumentParser(
+        description="Analyze template matching outputs",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )

     input_group = parser.add_argument_group("Input")
     output_group = parser.add_argument_group("Output")
@@ -47,15 +51,18 @@ def parse_args():

     input_group.add_argument(
         "--input_file",
+        "--input_files",
         required=True,
         nargs="+",
-        help="Path to the output of match_template.py.",
+        help="Path to one or multiple runs of match_template.py.",
     )
     input_group.add_argument(
         "--background_file",
+        "--background_files",
         required=False,
         nargs="+",
-        help="Path to an output of match_template.py used for normalization. "
+        default=[],
+        help="Path to one or multiple runs of match_template.py for normalization. "
         "For instance from --scramble_phases or a different template.",
     )
     input_group.add_argument(
@@ -87,6 +94,7 @@
             "alignment",
             "extraction",
             "average",
+            "pickle",
         ],
         default="orientations",
         help="Available output formats: "
@@ -95,7 +103,8 @@
         "relion5 (RELION 5 star format), "
         "alignment (aligned template to target based on orientations), "
         "extraction (extract regions around peaks from targets, i.e. subtomograms), "
-        "average (extract matched regions from target and average them).",
+        "average (extract matched regions from target and average them), "
+        "pickle (results of applying mask and background correction for inspection).",
     )

     peak_group.add_argument(
@@ -107,7 +116,7 @@
     peak_group.add_argument(
         "--minimum_score",
         type=float,
-        default=None,
+        default=0.0,
         help="Minimum score from which peaks will be considered.",
     )
     peak_group.add_argument(
@@ -152,11 +161,10 @@
     )

     additional_group.add_argument(
-        "--subtomogram_box_size",
+        "--extraction_box_size",
         type=int,
         default=None,
-        help="Subtomogram box size, by default equal to the centered template. Will be "
-        "padded to even values if output_format is relion.",
+        help="Box size of extracted subtomograms, defaults to the centered template.",
     )
     additional_group.add_argument(
         "--mask_subtomograms",
@@ -188,21 +196,9 @@

     args = parser.parse_args()

-    if args.output_format == "relion" and args.subtomogram_box_size is not None:
-        args.subtomogram_box_size += args.subtomogram_box_size % 2
-
     if args.orientations is not None:
         args.orientations = Orientations.from_file(filename=args.orientations)

-    if args.background_file is None:
-        args.background_file = [None]
-    if len(args.background_file) == 1:
-        args.background_file = args.background_file * len(args.input_file)
-    elif len(args.background_file) not in (0, len(args.input_file)):
-        raise ValueError(
-            "--background_file needs to be specified once or for each --input_file."
-        )
-
     return args


@@ -210,7 +206,6 @@ def load_template(
     filepath: str,
     sampling_rate: NDArray,
     centering: bool = True,
-    target_shape: Tuple[int] = None,
 ):
     try:
         template = Density.from_file(filepath)
@@ -230,72 +225,211 @@
     return template, center, translation, template_is_density


-def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
-    if len(foreground_paths) == 0:
-        return data, 1
-
+def load_matching_output(path: str) -> List:
+    data = load_pickle(path)
     if data[0].ndim != data[2].ndim:
-        return data, 1
+        data = _peaks_to_volume(data)
+    return list(data)

-    from tme.matching_exhaustive import normalize_under_mask

-    def _norm_scores(data, args):
-        target_origin, _, sampling_rate, cli_args = data[-1]
+def _peaks_to_volume(data):
+    # Emulate the output of analyzer aggregators
+    translations = data[0].astype(int)
+    keep = (translations < 0).sum(axis=1) == 0

-        _, template_extension = splitext(cli_args.template)
-        ret = load_template(
-            filepath=cli_args.template,
-            sampling_rate=sampling_rate,
-            centering=not cli_args.no_centering,
-        )
-        template, center_of_mass, translation, template_is_density = ret
+    translations = translations[keep]
+    rotations = data[1][keep]

-        if args.mask_edges and args.min_boundary_distance == 0:
-            max_shape = np.max(template.shape)
-            args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))
+    unique_rotations, rotation_map = np.unique(rotations, axis=0, return_inverse=True)
+    rotation_mapping = {
+        i: unique_rotations[i] for i in range(unique_rotations.shape[0])
+    }

-        target_mask = 1
-        if args.target_mask is not None:
-            target_mask = Density.from_file(args.target_mask).data
-        elif cli_args.target_mask is not None:
-            target_mask = Density.from_file(args.target_mask).data
+    out_shape = np.max(translations, axis=0) + 1

-        mask = np.ones_like(data[0])
-        np.multiply(mask, target_mask, out=mask)
+    scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
+    scores_out[tuple(translations.T)] = data[2][keep]

-        cropped_shape = np.subtract(
-            mask.shape, np.multiply(args.min_boundary_distance, 2)
-        ).astype(int)
-        mask[cropped_shape] = 0
-        normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
-        return data[0]
-
-    entities = np.zeros_like(data[0])
-    data[0] = _norm_scores(data=data, args=args)
-    for index, filepath in enumerate(foreground_paths):
-        new_scores = _norm_scores(
-            data=load_match_template_output(filepath, background_paths[index]),
-            args=args,
-        )
-        indices = new_scores > data[0]
-        entities[indices] = index + 1
-        data[0][indices] = new_scores[indices]
+    rotations_out = np.full(out_shape, fill_value=-1, dtype=np.int32)
+    rotations_out[tuple(translations.T)] = rotation_map

-    return data, entities
+    offset = np.zeros((scores_out.ndim), dtype=int)
+    return (scores_out, offset, rotations_out, rotation_mapping, data[-1])
+
+
+def prepare_pickle_merge(paths):
+    new_rotation_mapping, out_shape = {}, None
+    for path in paths:
+        data = load_matching_output(path)
+        scores, _, rotations, rotation_mapping, *_ = data
+        if np.allclose(scores.shape, 0):
+            continue
+
+        if out_shape is None:
+            out_shape = scores.shape
+
+        if out_shape is not None and not np.allclose(out_shape, scores.shape):
+            print(
+                f"\nScore spaces have different sizes {out_shape} and {scores.shape}. "
+                "Assuming that both boxes are aligned at the origin, but please "
+                "make sure this is intentional."
+            )
+            out_shape = np.maximum(out_shape, scores.shape)
+
+        for key, value in rotation_mapping.items():
+            if key not in new_rotation_mapping:
+                new_rotation_mapping[key] = len(new_rotation_mapping)
+
+    return new_rotation_mapping, out_shape
+
+
+def simple_stats(arr, decimals=3):
+    return {
+        "mean": round(float(arr.mean()), decimals),
+        "std": round(float(arr.std()), decimals),
+        "max": round(float(arr.max()), decimals),
+    }
+
+
+def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
+    # Determine output array shape and create consistent rotation map
+    new_rotation_mapping, out_shape = prepare_pickle_merge(foregrounds)
+
+    if out_shape is None:
+        exit("No valid score spaces found. Check messages above.")
+
+    print("\nFinished conversion - now aggregating over entities.")
+    entities = np.full(out_shape, fill_value=-1, dtype=np.int32)
+    scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
+    rotations_out = np.full(out_shape, fill_value=-1, dtype=np.int32)
+
+    # We reload to avoid potential memory bottlenecks
+    for entity_index, foreground in enumerate(foregrounds):
+        data = load_matching_output(foreground)
+        scores, _, rotations, rotation_mapping, *_ = data
+
+        # We could normalize to unit sdev, but that might lead to unexpected
+        # results for flat background distributions
+        scores -= scores.mean()
+        indices = tuple(slice(0, x) for x in scores.shape)
+
+        indices_update = scores > scores_out[indices]
+        scores_out[indices][indices_update] = scores[indices_update]
+
+        lookup_table = np.arange(len(rotation_mapping) + 1, dtype=rotations_out.dtype)
+
+        # Maps rotation matrix to rotation index in rotations array
+        for key, _ in rotation_mapping.items():
+            lookup_table[key] = new_rotation_mapping[key]
+
+        updated_rotations = rotations[indices_update].astype(int)
+        if len(updated_rotations):
+            rotations_out[indices][indices_update] = lookup_table[updated_rotations]
+
+        entities[indices][indices_update] = entity_index

+    data = list(data)
+    data[0] = scores_out
+    data[2] = rotations_out

-def load_match_template_output(foreground_path, background_path):
-    data = load_pickle(foreground_path)
-    if background_path is not None:
-        data_background = load_pickle(background_path)
-        data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
-        np.fmax(data[0], 0, out=data[0])
-    return data
+    fg = simple_stats(data[0])
+    print(f"> Foreground {', '.join(str(k) + ' ' + str(v) for k, v in fg.items())}.")
+
+    if not len(backgrounds):
+        print("\nScore statistics per entity")
+        for i in range(len(foregrounds)):
+            mask = entities == i
+            avg = "No occurrences"
+            if mask.sum() != 0:
+                fg = simple_stats(data[0][mask])
+                avg = ", ".join(str(k) + " " + str(v) for k, v in fg.items())
+            print(f"> Entity {i}: {avg}.")
+        return data, entities
+
+    print("\nComputing and applying background correction.")
+    _, out_shape_norm = prepare_pickle_merge(backgrounds)
+
+    if not np.allclose(out_shape, out_shape_norm):
+        print(
+            f"Foreground and background have different sizes {out_shape} and "
+            f"{out_shape_norm}. Assuming that boxes are aligned at the origin and "
+            "dropping scores beyond the overlap, but make sure this is intentional."
+        )
+
+    scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
+    for background in backgrounds:
+        data_norm = load_matching_output(background)
+
+        scores = data_norm[0]
+        scores -= scores.mean()
+        indices = tuple(slice(0, x) for x in scores.shape)
+        indices_update = scores > scores_norm[indices]
+        scores_norm[indices][indices_update] = scores[indices_update]
+
+    # Set translations to zero that do not have background distribution
+    update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
+    scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
+    scores_out[update] = data[0][update] - scores_norm[update]
+    scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
+    np.fmax(scores_out, 0, out=scores_out)
+    data[0] = scores_out
+
+    fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
+    print(f"> Background {', '.join(str(k) + ' ' + str(v) for k, v in bg.items())}.")
+    print(f"> Normalized {', '.join(str(k) + ' ' + str(v) for k, v in fg.items())}.")
+
+    print("\nScore statistics per entity")
+    for i in range(len(foregrounds)):
+        mask = entities == i
+        avg = "No occurrences"
+        if mask.sum() != 0:
+            fg = simple_stats(data[0][mask])
+            avg = ", ".join(str(k) + " " + str(v) for k, v in fg.items())
+        print(f"> Entity {i}: {avg}.")
+
+    return data, entities


 def main():
     args = parse_args()
-    data = load_match_template_output(args.input_file[0], args.background_file[0])
+    print_entry()
+
+    cli_kwargs = {
+        key: value
+        for key, value in sorted(vars(args).items())
+        if value is not None and key not in ("input_file", "background_file")
+    }
+    print_block(
+        name="Parameters",
+        data={sanitize_name(k): v for k, v in cli_kwargs.items()},
+        label_width=25,
+    )
+    print("\n" + "-" * 80)
+
+    print_block(
+        name=sanitize_name("Foreground entities"),
+        data={i: k for i, k in enumerate(args.input_file)},
+        label_width=25,
+    )
+
+    if len(args.background_file):
+        print_block(
+            name=sanitize_name("Background entities"),
+            data={i: k for i, k in enumerate(args.background_file)},
+            label_width=25,
+        )
+
+    data, entities = normalize_input(args.input_file, args.background_file)
+
+    if args.target_mask:
+        target_mask = Density.from_file(args.target_mask, use_memmap=True).data
+        if target_mask.shape != data[0].shape:
+            print(
+                f"Shape of target mask and scores do not match {target_mask.shape} "
+                f"{data[0].shape}. Skipping mask application."
+            )
+        else:
+            np.multiply(data[0], target_mask, out=data[0])

     target_origin, _, sampling_rate, cli_args = data[-1]

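The normalize_input/prepare_pickle_merge pair replaces the old merge_outputs logic: every foreground run is mean-centered and reduced with an element-wise maximum while an entity array records which run won each voxel; background runs are reduced the same way and then removed via (fg - bg) / (1 - bg), clipped at zero. A minimal NumPy sketch of that correction, on toy 1D score spaces with illustrative values:

    import numpy as np

    foreground_runs = [np.array([0.2, 0.9, 0.4]), np.array([0.3, 0.5, 0.8])]
    background_runs = [np.array([0.1, 0.2, 0.3])]

    def reduce_runs(runs):
        # Mean-center each run, keep the element-wise maximum across runs,
        # and record which run (entity) supplied it.
        scores = np.zeros(runs[0].shape, dtype=np.float32)
        entities = np.full(runs[0].shape, -1, dtype=np.int32)
        for index, run in enumerate(runs):
            run = run - run.mean()
            update = run > scores
            scores[update] = run[update]
            entities[update] = index
        return scores, entities

    fg, entities = reduce_runs(foreground_runs)
    bg, _ = reduce_runs(background_runs)

    # Correction as in the diff: subtract, rescale, clip negatives to zero
    corrected = np.fmax((fg - bg) / (1 - bg), 0)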
@@ -327,96 +461,83 @@ def main():
         max_shape = np.max(template.shape)
         args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))

-    entities = None
-    if len(args.input_file) > 1:
-        data, entities = merge_outputs(
-            data=data,
-            foreground_paths=args.input_file,
-            background_paths=args.background_file,
-            args=args,
-        )
+    if args.output_format == "pickle":
+        write_pickle(data, f"{args.output_prefix}.pickle")
+        exit(0)

     orientations = args.orientations
     if orientations is None:
         translations, rotations, scores, details = [], [], [], []
-        # Output is MaxScoreOverRotations
-        if data[0].ndim == data[2].ndim:
-            scores, offset, rotation_array, rotation_mapping, meta = data
-
-            if args.target_mask is not None:
-                target_mask = Density.from_file(args.target_mask)
-                scores = scores * target_mask.data
-
-            cropped_shape = np.subtract(
-                scores.shape, np.multiply(args.min_boundary_distance, 2)
-            ).astype(int)
-
-            if args.min_boundary_distance > 0:
-                scores = centered_mask(scores, new_shape=cropped_shape)
-
-            if args.n_false_positives is not None:
-                # Rickgauer et al. 2017
-                cropped_slice = tuple(
-                    slice(
-                        int(args.min_boundary_distance),
-                        int(x - args.min_boundary_distance),
-                    )
-                    for x in scores.shape
-                )
-                args.n_false_positives = max(args.n_false_positives, 1)
-                n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
-                minimum_score = np.multiply(
-                    erfcinv(2 * args.n_false_positives / n_correlations),
-                    np.sqrt(2) * np.std(scores[cropped_slice]),
+
+        # Data processed by normalize_input is guaranteed to have this shape
+        scores, offset, rotation_array, rotation_mapping, meta = data
+
+        cropped_shape = np.subtract(
+            scores.shape, np.multiply(args.min_boundary_distance, 2)
+        ).astype(int)
+
+        if args.min_boundary_distance > 0:
+            scores = centered_mask(scores, new_shape=cropped_shape)
+
+        if args.n_false_positives is not None:
+            # Rickgauer et al. 2017
+            cropped_slice = tuple(
+                slice(
+                    int(args.min_boundary_distance),
+                    int(x - args.min_boundary_distance),
                 )
-                print(f"Determined minimum score cutoff: {minimum_score}.")
-                minimum_score = max(minimum_score, 0)
-                args.minimum_score = minimum_score
-
-            args.batch_dims = None
-            if hasattr(cli_args, "target_batch"):
-                args.batch_dims = cli_args.target_batch
-
-            peak_caller_kwargs = {
-                "shape": scores.shape,
-                "num_peaks": args.num_peaks,
-                "min_distance": args.min_distance,
-                "min_boundary_distance": args.min_boundary_distance,
-                "batch_dims": args.batch_dims,
-                "minimum_score": args.minimum_score,
-                "maximum_score": args.maximum_score,
-            }
-
-            peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
-            peak_caller(
-                scores,
-                rotation_matrix=np.eye(template.data.ndim),
-                mask=template.data,
-                rotation_mapping=rotation_mapping,
-                rotations=rotation_array,
+                for x in scores.shape
             )
-            candidates = peak_caller.merge(
-                candidates=[tuple(peak_caller)], **peak_caller_kwargs
+            args.n_false_positives = max(args.n_false_positives, 1)
+            n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
+            minimum_score = np.multiply(
+                erfcinv(2 * args.n_false_positives / n_correlations),
+                np.sqrt(2) * np.std(scores[cropped_slice]),
             )
-            if len(candidates) == 0:
-                candidates = [[], [], [], []]
-                print("Found no peaks, consider changing peak calling parameters.")
-                exit(-1)
-
-            for translation, _, score, detail in zip(*candidates):
-                rotation_index = rotation_array[tuple(translation)]
-                rotation = rotation_mapping.get(
-                    rotation_index, np.zeros(template.data.ndim, int)
-                )
-                if rotation.ndim == 2:
-                    rotation = euler_from_rotationmatrix(rotation)
-                rotations.append(rotation)
+            print(f"Determined minimum score cutoff: {minimum_score}.")
+            minimum_score = max(minimum_score, 0)
+            args.minimum_score = minimum_score
+
+        args.batch_dims = None
+        if hasattr(cli_args, "target_batch"):
+            args.batch_dims = cli_args.target_batch
+
+        peak_caller_kwargs = {
+            "shape": scores.shape,
+            "num_peaks": args.num_peaks,
+            "min_distance": args.min_distance,
+            "min_boundary_distance": args.min_boundary_distance,
+            "batch_dims": args.batch_dims,
+            "minimum_score": args.minimum_score,
+            "maximum_score": args.maximum_score,
+        }
+
+        peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
+        state = peak_caller.init_state()
+        state = peak_caller(
+            state,
+            scores,
+            rotation_matrix=np.eye(template.data.ndim),
+            mask=template.data,
+            rotation_mapping=rotation_mapping,
+            rotations=rotation_array,
+        )
+        candidates = peak_caller.merge(
+            results=[peak_caller.result(state)], **peak_caller_kwargs
+        )
+        if len(candidates) == 0:
+            candidates = [[], [], [], []]
+            print("Found no peaks, consider changing peak calling parameters.")
+            exit(-1)

-        else:
-            candidates = data
-            translation, rotation, *_ = data
-            for i in range(translation.shape[0]):
-                rotations.append(euler_from_rotationmatrix(rotation[i]))
+        for translation, _, score, detail in zip(*candidates):
+            rotation_index = rotation_array[tuple(translation)]
+            rotation = rotation_mapping.get(
+                rotation_index, np.zeros(template.data.ndim, int)
+            )
+            if rotation.ndim == 2:
+                rotation = euler_from_rotationmatrix(rotation)
+            rotations.append(rotation)

     if len(rotations):
         rotations = np.vstack(rotations).astype(float)
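The relocated --n_false_positives branch keeps the noise-model cutoff of Rickgauer et al. 2017: for N correlations (voxels in the cropped score space times sampled rotations) and a tolerated count n of false positives, the threshold under a centered Gaussian with standard deviation sigma is sigma * sqrt(2) * erfcinv(2n / N). A minimal sketch with illustrative numbers:

    import numpy as np
    from scipy.special import erfcinv

    # Choose the threshold at which a centered Gaussian noise model yields an
    # expected n_false_positives among n_correlations sampled scores.
    n_false_positives = 1
    n_correlations = int(1e8)  # voxels in cropped score space * rotations
    sigma = 0.05               # standard deviation of the masked score space

    minimum_score = erfcinv(2 * n_false_positives / n_correlations) * np.sqrt(2) * sigma
    print(f"Determined minimum score cutoff: {minimum_score}.")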
@@ -520,9 +641,9 @@ def main():
             )

         extraction_shape = template.shape
-        if args.subtomogram_box_size is not None:
+        if args.extraction_box_size is not None:
             extraction_shape = np.repeat(
-                args.subtomogram_box_size, len(extraction_shape)
+                args.extraction_box_size, len(extraction_shape)
             )

         orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
@@ -588,7 +709,6 @@ def main():
         filepath=cli_args.template,
         sampling_rate=sampling_rate,
         centering=not cli_args.no_centering,
-        target_shape=target.shape,
     )

     for index, (translation, angles, *_) in enumerate(orientations):
scripts/preprocess.py
@@ -1,14 +1,15 @@
 #!python
-""" Preprocessing routines for template matching.
+"""Preprocessing routines for template matching.

-    Copyright (c) 2023 European Molecular Biology Laboratory
+Copyright (c) 2023 European Molecular Biology Laboratory

-    Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
+Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
 import argparse
 import numpy as np

 from tme import Density, Structure
+from tme.cli import print_entry
 from tme.backends import backend as be
 from tme.filters import BandPassFilter

@@ -59,7 +60,8 @@ def parse_args():
         dest="input_sampling_rate",
         type=float,
         required=False,
-        help="Sampling rate of the input file.",
+        help="Sampling rate of the input file. Defaults to header for volume "
+        "and to --sampling_rate for atomic structures.",
     )

     modulation_group = parser.add_argument_group("Modulation")
@@ -125,6 +127,7 @@

 def main():
     args = parse_args()
+    print_entry()

     try:
         data = Structure.from_file(args.data)
@@ -190,11 +193,11 @@ def main():
         shape_is_real_fourier=False,
         sampling_rate=data.sampling_rate,
     )(shape=data.shape)["data"]
-    bpf_mask = be.to_numpy_array(bpf_mask)
+    bpf_mask = be.to_backend_array(bpf_mask)

-    data_ft = np.fft.rfftn(data.data, s=data.shape)
-    data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
-    data.data = np.fft.irfftn(data_ft, s=data.shape).real
+    data_ft = be.rfftn(be.to_backend_array(data.data), s=data.shape)
+    data_ft = be.multiply(data_ft, bpf_mask, out=data_ft)
+    data.data = be.to_numpy_array(be.irfftn(data_ft, s=data.shape).real)

     data = data.resample(args.sampling_rate, method="spline", order=3)
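The final hunk routes the band-pass application through the backend layer (be.rfftn, be.multiply, be.irfftn), so the same code path also runs on non-NumPy backends. On the NumPy backend it stays numerically equivalent to the following sketch; the array and the all-pass mask are illustrative stand-ins for the real map and the BandPassFilter output:

    import numpy as np

    data = np.random.rand(32, 32, 32).astype(np.float32)
    # Placeholder for BandPassFilter(...)(shape=data.shape)["data"]; a real
    # mask would attenuate frequencies outside the chosen band.
    bpf_mask = np.ones((32, 32, 17), dtype=np.float32)  # rfftn halves the last axis

    # Filter in Fourier space, then transform back to a real-valued map
    data_ft = np.fft.rfftn(data, s=data.shape)
    data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
    filtered = np.fft.irfftn(data_ft, s=data.shape).real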