pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
scripts/postprocess.py CHANGED
@@ -8,9 +8,8 @@
 import argparse
 from sys import exit
 from os import getcwd
-from os.path import join, abspath
 from typing import List, Tuple
-from os.path import splitext
+from os.path import join, abspath, splitext
 
 import numpy as np
 from numpy.typing import NDArray
@@ -26,6 +25,7 @@ from tme.analyzer import (
 )
 from tme.matching_utils import (
     load_pickle,
+    centered_mask,
     euler_to_rotationmatrix,
     euler_from_rotationmatrix,
 )
@@ -41,9 +41,7 @@ PEAK_CALLERS = {
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Peak Calling for Template Matching Outputs"
-    )
+    parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")
 
     input_group = parser.add_argument_group("Input")
     output_group = parser.add_argument_group("Output")
@@ -56,6 +54,13 @@ def parse_args():
         nargs="+",
         help="Path to the output of match_template.py.",
     )
+    input_group.add_argument(
+        "--background_file",
+        required=False,
+        nargs="+",
+        help="Path to an output of match_template.py used for normalization. "
+        "For instance from --scramble_phases or a different template.",
+    )
     input_group.add_argument(
         "--target_mask",
         required=False,
@@ -87,7 +92,7 @@ def parse_args():
             "average",
         ],
         default="orientations",
-        help="Available output formats:"
+        help="Available output formats: "
         "orientations (translation, rotation, and score), "
         "alignment (aligned template to target based on orientations), "
         "extraction (extract regions around peaks from targets, i.e. subtomograms), "
@@ -206,6 +211,15 @@ def parse_args():
     elif args.number_of_peaks is None:
         args.number_of_peaks = 1000
 
+    if args.background_file is None:
+        args.background_file = [None]
+    if len(args.background_file) == 1:
+        args.background_file = args.background_file * len(args.input_file)
+    elif len(args.background_file) not in (0, len(args.input_file)):
+        raise ValueError(
+            "--background_file needs to be specified once or for each --input_file."
+        )
+
     return args
 
 
@@ -233,8 +247,8 @@ def load_template(
     return template, center, translation, template_is_density
 
 
-def merge_outputs(data, filepaths: List[str], args):
-    if len(filepaths) == 0:
+def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
+    if len(foreground_paths) == 0:
         return data, 1
 
     if data[0].ndim != data[2].ndim:
@@ -275,8 +289,11 @@ def merge_outputs(data, filepaths: List[str], args):
 
     entities = np.zeros_like(data[0])
     data[0] = _norm_scores(data=data, args=args)
-    for index, filepath in enumerate(filepaths):
-        new_scores = _norm_scores(data=load_pickle(filepath), args=args)
+    for index, filepath in enumerate(foreground_paths):
+        new_scores = _norm_scores(
+            data=load_match_template_output(filepath, background_paths[index]),
+            args=args,
+        )
         indices = new_scores > data[0]
         entities[indices] = index + 1
         data[0][indices] = new_scores[indices]
@@ -284,9 +301,18 @@ def merge_outputs(data, filepaths: List[str], args):
     return data, entities
 
 
+def load_match_template_output(foreground_path, background_path):
+    data = load_pickle(foreground_path)
+    if background_path is not None:
+        data_background = load_pickle(background_path)
+        data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
+        np.fmax(data[0], 0, out=data[0])
+    return data
+
+
 def main():
     args = parse_args()
-    data = load_pickle(args.input_file[0])
+    data = load_match_template_output(args.input_file[0], args.background_file[0])
 
     target_origin, _, sampling_rate, cli_args = data[-1]
 
@@ -326,7 +352,12 @@ def main():
 
     entities = None
     if len(args.input_file) > 1:
-        data, entities = merge_outputs(data=data, filepaths=args.input_file, args=args)
+        data, entities = merge_outputs(
+            data=data,
+            foreground_paths=args.input_file,
+            background_paths=args.background_file,
+            args=args,
+        )
 
     orientations = args.orientations
     if orientations is None:
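Note: both the single-input path and merge_outputs route scores through the new load_match_template_output, which rescales each foreground score by the headroom above the matching background score (for example from a --scramble_phases run) and clamps negatives. A minimal standalone sketch of that arithmetic, with hypothetical values:

    import numpy as np

    # Hypothetical data[0] score maps from two match_template.py pickles:
    # a foreground run and a background run (e.g. --scramble_phases).
    foreground = np.array([0.30, 0.55, 0.10])
    background = np.array([0.20, 0.20, 0.20])

    # Mirrors load_match_template_output: normalize by the remaining headroom.
    normalized = (foreground - background) / (1 - background)
    np.fmax(normalized, 0, out=normalized)  # clamp negatives in place
    print(normalized)  # -> 0.125, 0.4375, 0.0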
@@ -339,24 +370,27 @@ def main():
         target_mask = Density.from_file(args.target_mask)
         scores = scores * target_mask.data
 
-    if args.n_false_positives is not None:
-        args.n_false_positives = max(args.n_false_positives, 1)
-        cropped_shape = np.subtract(
-            scores.shape, np.multiply(args.min_boundary_distance, 2)
-        ).astype(int)
+    cropped_shape = np.subtract(
+        scores.shape, np.multiply(args.min_boundary_distance, 2)
+    ).astype(int)
+
+    if args.min_boundary_distance > 0:
+        scores = centered_mask(scores, new_shape=cropped_shape)
 
-        cropped_shape = tuple(
+    if args.n_false_positives is not None:
+        # Rickgauer et al. 2017
+        cropped_slice = tuple(
             slice(
                 int(args.min_boundary_distance),
                 int(x - args.min_boundary_distance),
             )
             for x in scores.shape
         )
-        # Rickgauer et al. 2017
-        n_correlations = np.size(scores[cropped_shape]) * len(rotation_mapping)
+        args.n_false_positives = max(args.n_false_positives, 1)
+        n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
         minimum_score = np.multiply(
             erfcinv(2 * args.n_false_positives / n_correlations),
-            np.sqrt(2) * np.std(scores[cropped_shape]),
+            np.sqrt(2) * np.std(scores[cropped_slice]),
         )
         print(f"Determined minimum score cutoff: {minimum_score}.")
         minimum_score = max(minimum_score, 0)
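The relocated cutoff keeps the Gaussian noise model of Rickgauer et al. 2017: with n independent correlation scores of standard deviation sigma, a threshold of sqrt(2) * sigma * erfcinv(2k / n) admits roughly k peaks by chance. A standalone sketch with made-up numbers:

    import numpy as np
    from scipy.special import erfcinv

    # Hypothetical search: 480**3 retained voxels times 15000 rotations.
    n_correlations = 480**3 * 15000
    n_false_positives = 1
    score_std = 0.02  # np.std of the cropped score map

    minimum_score = np.multiply(
        erfcinv(2 * n_false_positives / n_correlations),
        np.sqrt(2) * score_std,
    )
    print(minimum_score)  # roughly 0.14 for these values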
@@ -371,6 +405,8 @@ def main():
         "min_distance": args.min_distance,
         "min_boundary_distance": args.min_boundary_distance,
         "batch_dims": args.batch_dims,
+        "minimum_score": args.minimum_score,
+        "maximum_score": args.maximum_score,
     }
 
     peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
@@ -380,7 +416,6 @@ def main():
         mask=template.data,
         rotation_mapping=rotation_mapping,
         rotation_array=rotation_array,
-        minimum_score=args.minimum_score,
     )
     candidates = peak_caller.merge(
         candidates=[tuple(peak_caller)], **peak_caller_kwargs
@@ -388,10 +423,16 @@ def main():
         if len(candidates) == 0:
             candidates = [[], [], [], []]
             print("Found no peaks, consider changing peak calling parameters.")
-            exit(0)
+            exit(-1)
 
         for translation, _, score, detail in zip(*candidates):
-            rotations.append(rotation_mapping[rotation_array[tuple(translation)]])
+            rotation_index = rotation_array[tuple(translation)]
+            rotation = rotation_mapping.get(
+                rotation_index, np.zeros(template.data.ndim, int)
+            )
+            if rotation.ndim == 2:
+                rotation = euler_from_rotationmatrix(rotation)
+            rotations.append(rotation)
 
     else:
         candidates = data
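The peak loop above now accepts rotation_mapping entries stored either as Euler angles or as rotation matrices, with a zero-rotation fallback for unmapped indices. A sketch of the same dispatch, using scipy in place of tme's euler_from_rotationmatrix (the "zyx" angle convention is an assumption):

    import numpy as np
    from scipy.spatial.transform import Rotation

    def to_euler(entry):
        # entry plays the role of a rotation_mapping value: Euler angles (1D)
        # or a rotation matrix (2D), normalized here to Euler angles in degrees.
        entry = np.asarray(entry)
        if entry.ndim == 2:
            entry = Rotation.from_matrix(entry).as_euler("zyx", degrees=True)
        return entry

    print(to_euler(np.eye(3)))         # -> [0. 0. 0.]
    print(to_euler([90.0, 0.0, 0.0]))  # 1D input is passed through unchanged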
@@ -430,7 +471,7 @@ def main():
             )
             exit(-1)
         orientations.translations = peak_caller.oversample_peaks(
-            score_space=data[0],
+            scores=data[0],
             peak_positions=orientations.translations,
             oversampling_factor=args.peak_oversampling,
         )
@@ -468,19 +509,7 @@ def main():
 
     target = Density.from_file(cli_args.target)
     if args.invert_target_contrast:
-        if args.output_format == "relion":
-            target.data = target.data * -1
-            target.data = np.divide(
-                np.subtract(target.data, target.data.mean()), target.data.std()
-            )
-        else:
-            target.data = (
-                -np.divide(
-                    np.subtract(target.data, target.data.min()),
-                    np.subtract(target.data.max(), target.data.min()),
-                )
-                + 1
-            )
+        target.data = target.data * -1
 
     if args.output_format in ("extraction", "relion"):
         if not np.all(np.divide(target.shape, template.shape) > 2):
@@ -505,10 +534,14 @@ def main():
 
         working_directory = getcwd()
         if args.output_format == "relion":
+            name = [
+                join(working_directory, f"{args.output_prefix}_{index}.mrc")
+                for index in range(len(cand_slices))
+            ]
             orientations.to_file(
                 filename=f"{args.output_prefix}.star",
                 file_format="relion",
-                name_prefix=join(working_directory, args.output_prefix),
+                name=name,
                 ctf_image=args.wedge_mask,
                 sampling_rate=target.sampling_rate.max(),
                 subtomogram_size=extraction_shape[0],
@@ -565,24 +598,22 @@ def main():
     if args.output_format == "average":
         orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
             target_shape=target.shape,
-            extraction_shape=np.multiply(template.shape, 2),
+            extraction_shape=template.shape,
             drop_out_of_box=True,
             return_orientations=True,
         )
         out = np.zeros_like(template.data)
-        out = np.zeros(np.multiply(template.shape, 2).astype(int))
         for index in range(len(cand_slices)):
-            from scipy.spatial.transform import Rotation
-
-            rotation = Rotation.from_euler(
-                angles=orientations.rotations[index], seq="zyx", degrees=True
-            )
-            rotation_matrix = rotation.inv().as_matrix()
-
             subset = Density(target.data[obs_slices[index]])
-            subset = subset.rigid_transform(rotation_matrix=rotation_matrix, order=1)
+            rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
 
+            subset = subset.rigid_transform(
+                rotation_matrix=np.linalg.inv(rotation_matrix),
+                order=1,
+                use_geometric_center=True,
+            )
             np.add(out, subset.data, out=out)
+
         out /= len(cand_slices)
         ret = Density(out, sampling_rate=template.sampling_rate, origin=0)
         ret.pad(template.shape, center=True)
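The rewritten averaging branch extracts template-sized subvolumes and resamples each with the inverse of its peak rotation about the geometric center, so all copies share one frame before summation. A toy version of that resampling with scipy, standing in for Density.rigid_transform (hypothetical helper, identity rotations for brevity):

    import numpy as np
    from scipy.ndimage import affine_transform

    def rotate_about_center(volume, rotation_matrix):
        # affine_transform maps output to input coordinates, so passing the
        # forward matrix R applies R^-1 to the data, matching the inverse
        # rotation used in the averaging loop above.
        center = (np.array(volume.shape) - 1) / 2
        offset = center - rotation_matrix @ center
        return affine_transform(volume, rotation_matrix, offset=offset, order=1)

    rng = np.random.default_rng(0)
    subvolumes = [rng.random((32, 32, 32)) for _ in range(3)]
    rotations = [np.eye(3)] * 3  # stand-ins for euler_to_rotationmatrix output

    average = sum(
        rotate_about_center(vol, rot) for vol, rot in zip(subvolumes, rotations)
    ) / len(subvolumes)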
@@ -596,17 +627,18 @@ def main():
             target_shape=target.shape,
         )
 
+        # Template is larger than target
         for index, (translation, angles, *_) in enumerate(orientations):
             rotation_matrix = euler_to_rotationmatrix(angles)
             if template_is_density:
-                translation = np.subtract(translation, center)
                 transformed_template = template.rigid_transform(
-                    rotation_matrix=rotation_matrix
-                )
-                transformed_template.origin = np.add(
-                    target_origin, np.multiply(translation, sampling_rate)
+                    rotation_matrix=rotation_matrix, use_geometric_center=True
                 )
 
+                # Just adapting the coordinate system not the in-box position
+                shift = np.multiply(np.subtract(translation, center), sampling_rate)
+                transformed_template.origin = np.add(target_origin, shift)
+
             else:
                 template = Structure.from_file(cli_args.template)
                 new_center_of_mass = np.add(
scripts/preprocess.py CHANGED
@@ -1,93 +1,132 @@
 #!python3
-""" Apply tme.preprocessor.Preprocessor methods to an input file based
-    on a provided yaml configuration obtaiend from preprocessor_gui.py.
+""" Preprocessing routines for template matching.
 
     Copyright (c) 2023 European Molecular Biology Laboratory
 
     Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
-import yaml
+import warnings
 import argparse
-import textwrap
-from tme import Preprocessor, Density
+import numpy as np
+
+from tme import Density, Structure
+from tme.backends import backend as be
+from tme.preprocessing.frequency_filters import BandPassFilter
 
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description=textwrap.dedent(
-            """
-            Apply preprocessing to an input file based on a provided YAML configuration.
-
-            Expected YAML file format:
-            ```yaml
-            <method_name>:
-                <parameter1>: <value1>
-                <parameter2>: <value2>
-                ...
-            ```
-            """
-        ),
-        formatter_class=argparse.RawDescriptionHelpFormatter,
+        description="Perform template matching preprocessing.",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
     )
-    parser.add_argument(
-        "-i",
-        "--input_file",
+
+    io_group = parser.add_argument_group("Input / Output")
+    io_group.add_argument(
+        "-m",
+        "--data",
+        dest="data",
         type=str,
         required=True,
-        help="Path to the input data file in CCP4/MRC format.",
+        help="Path to a file in PDB/MMCIF, CCP4/MRC, EM, H5 or a format supported by "
+        "tme.density.Density.from_file "
+        "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
     )
-    parser.add_argument(
-        "-y",
-        "--yaml_file",
+    io_group.add_argument(
+        "-o",
+        "--output",
+        dest="output",
         type=str,
         required=True,
-        help="Path to the YAML configuration file.",
+        help="Path the output should be written to.",
     )
-    parser.add_argument(
-        "-o",
-        "--output_file",
-        type=str,
+
+    box_group = parser.add_argument_group("Box")
+    box_group.add_argument(
+        "--box_size",
+        dest="box_size",
+        type=int,
         required=True,
-        help="Path to output file in CPP4/MRC format..",
+        help="Box size of the output",
     )
-    parser.add_argument(
-        "--compress", action="store_true", help="Compress the output file using gzip."
+    box_group.add_argument(
+        "--sampling_rate",
+        dest="sampling_rate",
+        type=float,
+        required=True,
+        help="Sampling rate of the output file.",
     )
 
+    modulation_group = parser.add_argument_group("Modulation")
+    modulation_group.add_argument(
+        "--invert_contrast",
+        dest="invert_contrast",
+        action="store_true",
+        required=False,
+        help="Inverts the template contrast.",
+    )
+    modulation_group.add_argument(
+        "--lowpass",
+        dest="lowpass",
+        type=float,
+        required=False,
+        default=None,
+        help="Lowpass filter the template to the given resolution. Nyquist by default. "
+        "A value of 0 disables the filter.",
+    )
+    modulation_group.add_argument(
+        "--no_centering",
+        dest="no_centering",
+        action="store_true",
+        help="Assumes the template is already centered and omits centering.",
+    )
     args = parser.parse_args()
-
     return args
 
 
 def main():
     args = parse_args()
-    with open(args.yaml_file, "r") as f:
-        preprocess_settings = yaml.safe_load(f)
 
-    if len(preprocess_settings) > 1:
-        raise NotImplementedError(
-            "Multiple preprocessing methods specified. "
-            "The script currently supports one method at a time."
+    try:
+        data = Structure.from_file(args.data)
+        data = Density.from_structure(data, sampling_rate=args.sampling_rate)
+    except NotImplementedError:
+        data = Density.from_file(args.data)
+
+    if not args.no_centering:
+        data, _ = data.centered(0)
+
+    recommended_box = be.compute_convolution_shapes([args.box_size], [1])[1][0]
+    if recommended_box != args.box_size:
+        warnings.warn(
+            f"Consider using --box_size {recommended_box} instead of {args.box_size}."
         )
 
-    method_name = list(preprocess_settings.keys())[0]
-    if not hasattr(Preprocessor, method_name):
-        raise ValueError(f"Method {method_name} does not exist in Preprocessor.")
+    data.pad(
+        np.multiply(args.box_size, np.divide(args.sampling_rate, data.sampling_rate)),
+        center=True,
+    )
+
+    bpf_mask = 1
+    lowpass = 2 * args.sampling_rate if args.lowpass is None else args.lowpass
+    if args.lowpass != 0:
+        bpf_mask = BandPassFilter(
+            lowpass=lowpass,
+            highpass=None,
+            use_gaussian=True,
+            return_real_fourier=True,
+            shape_is_real_fourier=False,
+        )(shape=data.shape)["data"]
 
-    density = Density.from_file(args.input_file)
-    output = density.empty
+    data_ft = np.fft.rfftn(data.data, s=data.shape)
+    data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
+    data.data = np.fft.irfftn(data_ft, s=data.shape).real
 
-    method_params = preprocess_settings[method_name]
-    preprocessor = Preprocessor()
-    method = getattr(preprocessor, method_name, None)
-    if not method:
-        raise ValueError(
-            f"{method} does not exist in dge.preprocessor.Preprocessor class."
-        )
+    data = data.resample(args.sampling_rate, method="spline", order=3)
 
-    output.data = method(template=density.data, **method_params)
-    output.to_file(args.output_file, gzip=args.compress)
+    if args.invert_contrast:
+        data.data = data.data * -1
 
+    data.to_file(args.output)
 
 if __name__ == "__main__":
-    main()
+    main()
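The new preprocess.py defaults the lowpass cutoff to the Nyquist resolution (twice the sampling rate) and applies the filter multiplicatively on the rfftn grid. A self-contained sketch of that pattern, with a plain Gaussian lowpass standing in for tme's BandPassFilter (the falloff profile here is an assumption, not the library's exact one):

    import numpy as np

    def gaussian_lowpass(shape, sampling_rate, resolution):
        # Gaussian mask on the np.fft.rfftn frequency grid; cutoff at 1/resolution.
        grids = np.meshgrid(
            *[np.fft.fftfreq(n, d=sampling_rate) for n in shape[:-1]],
            np.fft.rfftfreq(shape[-1], d=sampling_rate),
            indexing="ij",
        )
        freq_norm = np.sqrt(sum(g**2 for g in grids))
        cutoff = 1 / resolution
        return np.exp(-(freq_norm**2) / (2 * cutoff**2))

    data = np.random.default_rng(0).random((64, 64, 64))
    sampling_rate = 4.0  # hypothetical Angstrom per voxel
    mask = gaussian_lowpass(data.shape, sampling_rate, resolution=2 * sampling_rate)

    data_ft = np.fft.rfftn(data)
    filtered = np.fft.irfftn(data_ft * mask, s=data.shape)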