pytme 0.2.0__tar.gz → 0.2.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73) hide show
  1. {pytme-0.2.0 → pytme-0.2.1}/PKG-INFO +1 -1
  2. {pytme-0.2.0 → pytme-0.2.1}/pyproject.toml +1 -1
  3. {pytme-0.2.0 → pytme-0.2.1}/pytme.egg-info/SOURCES.txt +1 -0
  4. pytme-0.2.1/scripts/extract_candidates.py +250 -0
  5. {pytme-0.2.0 → pytme-0.2.1}/scripts/match_template.py +183 -69
  6. {pytme-0.2.0 → pytme-0.2.1}/scripts/match_template_filters.py +193 -71
  7. {pytme-0.2.0 → pytme-0.2.1}/scripts/postprocess.py +107 -49
  8. {pytme-0.2.0 → pytme-0.2.1}/scripts/preprocessor_gui.py +4 -1
  9. pytme-0.2.1/scripts/refine_matches.py +422 -0
  10. pytme-0.2.1/tme/__version__.py +1 -0
  11. {pytme-0.2.0 → pytme-0.2.1}/tme/analyzer.py +259 -117
  12. {pytme-0.2.0 → pytme-0.2.1}/tme/backends/__init__.py +1 -0
  13. {pytme-0.2.0 → pytme-0.2.1}/tme/backends/cupy_backend.py +20 -13
  14. pytme-0.2.1/tme/backends/jax_backend.py +218 -0
  15. {pytme-0.2.0 → pytme-0.2.1}/tme/backends/matching_backend.py +25 -10
  16. {pytme-0.2.0 → pytme-0.2.1}/tme/backends/mlx_backend.py +13 -9
  17. {pytme-0.2.0 → pytme-0.2.1}/tme/backends/npfftw_backend.py +20 -8
  18. {pytme-0.2.0 → pytme-0.2.1}/tme/backends/pytorch_backend.py +20 -9
  19. {pytme-0.2.0 → pytme-0.2.1}/tme/density.py +79 -60
  20. {pytme-0.2.0 → pytme-0.2.1}/tme/matching_data.py +85 -61
  21. {pytme-0.2.0 → pytme-0.2.1}/tme/matching_exhaustive.py +222 -129
  22. {pytme-0.2.0 → pytme-0.2.1}/tme/matching_optimization.py +117 -76
  23. {pytme-0.2.0 → pytme-0.2.1}/tme/orientations.py +175 -55
  24. {pytme-0.2.0 → pytme-0.2.1}/tme/preprocessing/_utils.py +17 -5
  25. {pytme-0.2.0 → pytme-0.2.1}/tme/preprocessing/composable_filter.py +2 -1
  26. {pytme-0.2.0 → pytme-0.2.1}/tme/preprocessing/compose.py +1 -2
  27. {pytme-0.2.0 → pytme-0.2.1}/tme/preprocessing/frequency_filters.py +97 -41
  28. {pytme-0.2.0 → pytme-0.2.1}/tme/preprocessing/tilt_series.py +137 -87
  29. {pytme-0.2.0 → pytme-0.2.1}/tme/preprocessor.py +3 -0
  30. {pytme-0.2.0 → pytme-0.2.1}/tme/structure.py +4 -1
  31. pytme-0.2.0/scripts/extract_candidates.py +0 -218
  32. pytme-0.2.0/scripts/refine_matches.py +0 -218
  33. pytme-0.2.0/tme/__version__.py +0 -1
  34. {pytme-0.2.0 → pytme-0.2.1}/LICENSE +0 -0
  35. {pytme-0.2.0 → pytme-0.2.1}/MANIFEST.in +0 -0
  36. {pytme-0.2.0 → pytme-0.2.1}/README.md +0 -0
  37. {pytme-0.2.0 → pytme-0.2.1}/scripts/__init__.py +0 -0
  38. {pytme-0.2.0 → pytme-0.2.1}/scripts/estimate_ram_usage.py +0 -0
  39. {pytme-0.2.0 → pytme-0.2.1}/scripts/preprocess.py +0 -0
  40. {pytme-0.2.0 → pytme-0.2.1}/setup.cfg +0 -0
  41. {pytme-0.2.0 → pytme-0.2.1}/setup.py +0 -0
  42. {pytme-0.2.0 → pytme-0.2.1}/src/extensions.cpp +0 -0
  43. {pytme-0.2.0 → pytme-0.2.1}/tme/__init__.py +0 -0
  44. {pytme-0.2.0 → pytme-0.2.1}/tme/data/__init__.py +0 -0
  45. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48n309.npy +0 -0
  46. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48n527.npy +0 -0
  47. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48n9.npy +0 -0
  48. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u1.npy +0 -0
  49. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u1153.npy +0 -0
  50. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u1201.npy +0 -0
  51. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u1641.npy +0 -0
  52. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u181.npy +0 -0
  53. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u2219.npy +0 -0
  54. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u27.npy +0 -0
  55. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u2947.npy +0 -0
  56. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u3733.npy +0 -0
  57. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u4749.npy +0 -0
  58. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u5879.npy +0 -0
  59. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u7111.npy +0 -0
  60. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u815.npy +0 -0
  61. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u83.npy +0 -0
  62. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c48u8649.npy +0 -0
  63. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c600v.npy +0 -0
  64. {pytme-0.2.0 → pytme-0.2.1}/tme/data/c600vc.npy +0 -0
  65. {pytme-0.2.0 → pytme-0.2.1}/tme/data/metadata.yaml +0 -0
  66. {pytme-0.2.0 → pytme-0.2.1}/tme/data/quat_to_numpy.py +0 -0
  67. {pytme-0.2.0 → pytme-0.2.1}/tme/helpers.py +0 -0
  68. {pytme-0.2.0 → pytme-0.2.1}/tme/matching_constrained.py +0 -0
  69. {pytme-0.2.0 → pytme-0.2.1}/tme/matching_memory.py +0 -0
  70. {pytme-0.2.0 → pytme-0.2.1}/tme/matching_utils.py +0 -0
  71. {pytme-0.2.0 → pytme-0.2.1}/tme/parser.py +0 -0
  72. {pytme-0.2.0 → pytme-0.2.1}/tme/preprocessing/__init__.py +0 -0
  73. {pytme-0.2.0 → pytme-0.2.1}/tme/types.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pytme
3
- Version: 0.2.0
3
+ Version: 0.2.1
4
4
  Summary: Python Template Matching Engine
5
5
  Author: Valentin Maurer
6
6
  Author-email: Valentin Maurer <valentin.maurer@embl-hamburg.de>
@@ -7,7 +7,7 @@ name="pytme"
7
7
  authors = [
8
8
  { name = "Valentin Maurer", email = "valentin.maurer@embl-hamburg.de" },
9
9
  ]
10
- version="0.2.0"
10
+ version="0.2.1"
11
11
  description="Python Template Matching Engine"
12
12
  readme="README.md"
13
13
  requires-python = ">=3.11"
@@ -31,6 +31,7 @@ tme/structure.py
31
31
  tme/types.py
32
32
  tme/backends/__init__.py
33
33
  tme/backends/cupy_backend.py
34
+ tme/backends/jax_backend.py
34
35
  tme/backends/matching_backend.py
35
36
  tme/backends/mlx_backend.py
36
37
  tme/backends/npfftw_backend.py
@@ -0,0 +1,250 @@
1
+ #!python3
2
+ """ Prepare orientations stack for refinement.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import argparse
9
+ from os.path import splitext
10
+
11
+ import numpy as np
12
+
13
+ from tme import Density, Orientations
14
+ from tme.matching_utils import (
15
+ load_pickle,
16
+ generate_tempfile_name,
17
+ rotation_aligning_vectors,
18
+ euler_from_rotationmatrix,
19
+ euler_to_rotationmatrix,
20
+ )
21
+
22
+
23
+ class ProgressBar:
24
+ """
25
+ ASCII progress bar.
26
+ """
27
+
28
+ def __init__(self, message : str, nchars : int, total : int):
29
+ self._size = nchars - len(message)
30
+ self._message = message
31
+ self._total = total
32
+
33
+ def update(self, cur):
34
+ x = int(cur * self._size / self._total)
35
+ print(
36
+ "%s[%s%s] %i/%i\r"
37
+ % (self._message, "#" * x, "." * (self._size - x), cur, self._total),
38
+ end=''
39
+ )
40
+
41
+
42
+ def parse_args():
43
+ parser = argparse.ArgumentParser(
44
+ description="Extract matching candidates for further refinement."
45
+ )
46
+
47
+ io_group = parser.add_argument_group("Input / Output")
48
+ io_group.add_argument(
49
+ "--input_file",
50
+ required=False,
51
+ type=str,
52
+ help="Path to the output of match_template.py.",
53
+ )
54
+ io_group.add_argument(
55
+ "--orientations",
56
+ required=True,
57
+ type=str,
58
+ help="Path to file generated by postprocess.py using output_format orientations.",
59
+ )
60
+ io_group.add_argument(
61
+ "--target",
62
+ required=False,
63
+ type=str,
64
+ help="Extract candidates from this target, can be at different sampling rate.",
65
+ )
66
+ io_group.add_argument(
67
+ "--template",
68
+ required=False,
69
+ type=str,
70
+ help="Extract candidates from this template, can be at different sampling rate.",
71
+ )
72
+ io_group.add_argument(
73
+ "-o",
74
+ "--output_file",
75
+ required=True,
76
+ type=str,
77
+ help="Path to output HDF5 file.",
78
+ )
79
+
80
+ alignment_group = parser.add_argument_group("Alignment")
81
+ alignment_group.add_argument(
82
+ "--align_orientations",
83
+ action="store_true",
84
+ required=False,
85
+ help="Whether to align extracted orientations based on their angles. Allows "
86
+ "for efficient subsequent sampling of cone angles.",
87
+ )
88
+ alignment_group.add_argument(
89
+ "--angles_are_vector",
90
+ action="store_true",
91
+ required=False,
92
+ help="Considers euler_z euler_y, euler_x as vector that will be rotated to align "
93
+ "with the z-axis (1,0,0). Only considered when --align_orientations is set."
94
+ )
95
+ alignment_group.add_argument(
96
+ "--interpolation_order",
97
+ dest="interpolation_order",
98
+ required=False,
99
+ type=int,
100
+ default=1,
101
+ help="Interpolation order for alignment, less than zero is no interpolation."
102
+ )
103
+
104
+ extraction_group = parser.add_argument_group("Extraction")
105
+ extraction_group.add_argument(
106
+ "--box_size",
107
+ required=False,
108
+ type=int,
109
+ help="Box size for extraction, defaults to two times the template.",
110
+ )
111
+ extraction_group.add_argument(
112
+ "--keep_out_of_box",
113
+ action="store_true",
114
+ required=False,
115
+ help="Whether to keep orientations that fall outside the box. If the "
116
+ "orientations are sensible, it is safe to pass this flag.",
117
+ )
118
+
119
+ args = parser.parse_args()
120
+
121
+ data_present = args.target is not None and args.template is not None
122
+ if args.input_file is None and not data_present:
123
+ raise ValueError(
124
+ "Either --input_file or --target and --template need to be specified."
125
+ )
126
+ elif args.input_file is not None and data_present:
127
+ raise ValueError(
128
+ "Please specify either --input_file or --target and --template."
129
+ )
130
+
131
+ return args
132
+
133
+
134
+ def main():
135
+ args = parse_args()
136
+ orientations = Orientations.from_file(args.orientations)
137
+
138
+ if args.input_file is not None:
139
+ data = load_pickle(args.input_file)
140
+ target_origin, _, sampling_rate, cli_args = data[-1]
141
+ args.target, args.template = cli_args.target, cli_args.template
142
+
143
+ target = Density.from_file(args.target, use_memmap=True)
144
+
145
+ try:
146
+ template = Density.from_file(args.template)
147
+ except Exception:
148
+ template = Density.from_structure(args.template, sampling_rate = target.sampling_rate)
149
+
150
+ box_size = args.box_size
151
+ if box_size is None:
152
+ box_size = np.multiply(template.shape, 2)
153
+ box_size = np.array(box_size)
154
+ box_size = np.repeat(box_size, template.data.ndim // box_size.size).astype(int)
155
+
156
+ extraction_shape = np.copy(box_size)
157
+ if args.align_orientations:
158
+ extraction_shape[:] = int(np.linalg.norm(box_size) + 1)
159
+
160
+ orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
161
+ target_shape=target.shape,
162
+ extraction_shape=extraction_shape,
163
+ drop_out_of_box=not args.keep_out_of_box,
164
+ return_orientations=True,
165
+ )
166
+
167
+ if args.align_orientations:
168
+ orientations.rotations = orientations.rotations.astype(np.float32)
169
+ for index in range(orientations.rotations.shape[0]):
170
+ rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
171
+ rotation_matrix = np.linalg.inv(rotation_matrix)
172
+ if args.angles_are_vector:
173
+ rotation_matrix = rotation_aligning_vectors(
174
+ orientations.rotations[index], target_vector=(1,0,0)
175
+ )
176
+ orientations.rotations[index] = euler_from_rotationmatrix(rotation_matrix)
177
+
178
+
179
+ filename = generate_tempfile_name()
180
+ output_dtype = target.data.dtype
181
+ if args.align_orientations is not None:
182
+ output_dtype = np.float32
183
+
184
+ target.data = target.data.astype(output_dtype)
185
+
186
+ dens = Density(
187
+ np.memmap(
188
+ filename,
189
+ mode="w+",
190
+ shape=(len(obs_slices), *box_size),
191
+ dtype=output_dtype,
192
+ ),
193
+ sampling_rate=(1, *target.sampling_rate),
194
+ origin=(0, *target.origin),
195
+ )
196
+ dens.data[:] = target.metadata["mean"]
197
+
198
+ print(target.data.shape)
199
+ # There appears to be an issue with the stack creation. Trace this further
200
+ data_subset = np.zeros(extraction_shape, dtype = target.data.dtype)
201
+ pbar = ProgressBar(message = "Orientation ", nchars = 80, total = len(obs_slices))
202
+ for index, (obs_slice, cand_slice) in enumerate(zip(obs_slices, cand_slices)):
203
+ pbar.update(index + 1)
204
+
205
+ data_subset.fill(0)
206
+ data_subset[cand_slice] = target.data[obs_slice]
207
+ target_subset = Density(
208
+ data_subset,
209
+ sampling_rate=target.sampling_rate,
210
+ origin=target.origin,
211
+ )
212
+
213
+ if args.align_orientations:
214
+ rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
215
+ target_subset = target_subset.rigid_transform(
216
+ rotation_matrix=rotation_matrix,
217
+ use_geometric_center=True,
218
+ order=args.interpolation_order,
219
+ )
220
+ target_subset.pad(box_size, center=True)
221
+
222
+ target_value = target.data[tuple(orientations.translations[index].astype(int))]
223
+ center = np.divide(target_subset.data.shape, 2).astype(int ) + np.mod(target_subset.shape, 2)
224
+ print(np.where(target_subset.data == target_value), center)
225
+ print(target_subset.data[tuple(center.astype(int))],
226
+ target_value,
227
+ target_subset.data[tuple(center.astype(int))] == target_value
228
+ )
229
+
230
+ dens.data[index] = target_subset.data
231
+ print("")
232
+
233
+ target_meta = {
234
+ k: v for k, v in target.metadata.items() if k in ("mean", "max", "min", "std")
235
+ }
236
+ dens.metadata.update(target_meta)
237
+ dens.metadata["batch_dimension"] = (0, )
238
+
239
+ dens.to_file(args.output_file)
240
+ orientations.to_file(
241
+ f"{splitext(args.output_file)[0]}_aligned.tsv",
242
+ file_format="text"
243
+ )
244
+ orientations.to_file(
245
+ f"{splitext(args.output_file)[0]}_aligned.star",
246
+ file_format="relion"
247
+ )
248
+
249
+ if __name__ == "__main__":
250
+ main()
@@ -217,7 +217,7 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
217
217
  if args.tilt_weighting not in ("angle", None):
218
218
  raise ValueError(
219
219
  "Tilt weighting schemes other than 'angle' or 'None' require "
220
- "a specification of electron doses."
220
+ "a specification of electron doses via --tilt_angles."
221
221
  )
222
222
 
223
223
  wedge = Wedge(
@@ -240,31 +240,58 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
240
240
  wedge.sampling_rate = template.sampling_rate
241
241
  template_filter.append(wedge)
242
242
  if not isinstance(wedge, WedgeReconstructed):
243
- template_filter.append(ReconstructFromTilt(
244
- reconstruction_filter = args.reconstruction_filter
245
- ))
243
+ template_filter.append(
244
+ ReconstructFromTilt(
245
+ reconstruction_filter=args.reconstruction_filter,
246
+ interpolation_order=args.reconstruction_interpolation_order,
247
+ )
248
+ )
246
249
 
247
- if args.ctf_file is not None:
250
+ if args.ctf_file is not None or args.defocus is not None:
248
251
  from tme.preprocessing.tilt_series import CTF
249
252
 
250
- ctf = CTF.from_file(args.ctf_file)
251
- n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
252
- if n_tilts_ctfs != n_tils_angles:
253
- raise ValueError(
254
- f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
255
- f"recieved {n_tils_angles} tilt angles. Expected one angle "
256
- "per micrograph."
253
+ needs_reconstruction = True
254
+ if args.ctf_file is not None:
255
+ ctf = CTF.from_file(args.ctf_file)
256
+ n_tilts_ctfs, n_tils_angles = len(ctf.defocus_x), len(wedge.angles)
257
+ if n_tilts_ctfs != n_tils_angles:
258
+ raise ValueError(
259
+ f"CTF file contains {n_tilts_ctfs} micrographs, but match_template "
260
+ f"received {n_tils_angles} tilt angles. Expected one angle "
261
+ "per micrograph."
262
+ )
263
+ ctf.angles = wedge.angles
264
+ ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
265
+ else:
266
+ needs_reconstruction = False
267
+ ctf = CTF(
268
+ defocus_x=[args.defocus],
269
+ phase_shift=[args.phase_shift],
270
+ defocus_y=None,
271
+ angles=[0],
272
+ shape=None,
273
+ return_real_fourier=True,
257
274
  )
258
- ctf.angles = wedge.angles
259
- ctf.opening_axis, ctf.tilt_axis = args.wedge_axes
260
-
261
- if isinstance(template_filter[-1], ReconstructFromTilt):
275
+ ctf.sampling_rate = template.sampling_rate
276
+ ctf.flip_phase = not args.no_flip_phase
277
+ ctf.amplitude_contrast = args.amplitude_contrast
278
+ ctf.spherical_aberration = args.spherical_aberration
279
+ ctf.acceleration_voltage = args.acceleration_voltage * 1e3
280
+ ctf.correct_defocus_gradient = args.correct_defocus_gradient
281
+
282
+ if not needs_reconstruction:
283
+ template_filter.append(ctf)
284
+ elif isinstance(template_filter[-1], ReconstructFromTilt):
262
285
  template_filter.insert(-1, ctf)
263
286
  else:
264
287
  template_filter.insert(0, ctf)
265
- template_filter.insert(1, ReconstructFromTilt(
266
- reconstruction_filter = args.reconstruction_filter
267
- ))
288
+ template_filter.insert(
289
+ 1,
290
+ ReconstructFromTilt(
291
+ reconstruction_filter=args.reconstruction_filter,
292
+ interpolation_order=args.reconstruction_interpolation_order,
293
+ ),
294
+ )
268
295
 
269
296
  if args.lowpass or args.highpass is not None:
270
297
  lowpass, highpass = args.lowpass, args.highpass
@@ -293,6 +320,14 @@ def setup_filter(args, template: Density, target: Density) -> Tuple[Compose, Com
293
320
  template_filter.append(whitening_filter)
294
321
  target_filter.append(whitening_filter)
295
322
 
323
+ needs_reconstruction = any(
324
+ [isinstance(t, ReconstructFromTilt) for t in template_filter]
325
+ )
326
+ if needs_reconstruction and args.reconstruction_filter is None:
327
+ warnings.warn(
328
+ "Consider using a --reconstruction_filter such as 'ramp' to avoid artifacts."
329
+ )
330
+
296
331
  template_filter = Compose(template_filter) if len(template_filter) else None
297
332
  target_filter = Compose(target_filter) if len(target_filter) else None
298
333
 
@@ -510,7 +545,7 @@ def parse_args():
510
545
  dest="no_pass_smooth",
511
546
  action="store_false",
512
547
  default=True,
513
- help="Whether a hard edge filter should be used for --lowpass and --highpass."
548
+ help="Whether a hard edge filter should be used for --lowpass and --highpass.",
514
549
  )
515
550
  filter_group.add_argument(
516
551
  "--pass_format",
@@ -519,7 +554,7 @@ def parse_args():
519
554
  required=False,
520
555
  choices=["sampling_rate", "voxel", "frequency"],
521
556
  help="How values passed to --lowpass and --highpass should be interpreted. "
522
- "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom."
557
+ "By default, they are assumed to be in units of sampling rate, e.g. Ångstrom.",
523
558
  )
524
559
  filter_group.add_argument(
525
560
  "--whiten_spectrum",
@@ -561,23 +596,90 @@ def parse_args():
561
596
  "grigorieff (exposure filter as defined in Grant and Grigorieff 2015)."
562
597
  "relion and grigorieff require electron doses in --tilt_angles weights column.",
563
598
  )
564
- # filter_group.add_argument(
565
- # "--ctf_file",
566
- # dest="ctf_file",
567
- # type=str,
568
- # required=False,
569
- # default=None,
570
- # help="Path to a file with CTF parameters from CTFFIND4.",
571
- # )
572
599
  filter_group.add_argument(
573
600
  "--reconstruction_filter",
574
601
  dest="reconstruction_filter",
575
602
  type=str,
576
603
  required=False,
577
- choices = ["ram-lak", "ramp", "shepp-logan", "cosine", "hamming"],
604
+ choices=["ram-lak", "ramp", "ramp-cont", "shepp-logan", "cosine", "hamming"],
578
605
  default=None,
579
606
  help="Filter applied when reconstructing (N+1)-D from N-D filters.",
580
607
  )
608
+ filter_group.add_argument(
609
+ "--reconstruction_interpolation_order",
610
+ dest="reconstruction_interpolation_order",
611
+ type=int,
612
+ default=1,
613
+ required=False,
614
+ help="Analogous to --interpolation_order but for reconstruction.",
615
+ )
616
+
617
+ ctf_group = parser.add_argument_group("Contrast Transfer Function")
618
+ ctf_group.add_argument(
619
+ "--ctf_file",
620
+ dest="ctf_file",
621
+ type=str,
622
+ required=False,
623
+ default=None,
624
+ help="Path to a file with CTF parameters from CTFFIND4. Each line will be "
625
+ "interpreted as tilt obtained at the angle specified in --tilt_angles. ",
626
+ )
627
+ ctf_group.add_argument(
628
+ "--defocus",
629
+ dest="defocus",
630
+ type=float,
631
+ required=False,
632
+ default=None,
633
+ help="Defocus in units of sampling rate (typically Ångstrom). "
634
+ "Superseded by --ctf_file.",
635
+ )
636
+ ctf_group.add_argument(
637
+ "--phase_shift",
638
+ dest="phase_shift",
639
+ type=float,
640
+ required=False,
641
+ default=0,
642
+ help="Phase shift in degrees. Superseded by --ctf_file.",
643
+ )
644
+ ctf_group.add_argument(
645
+ "--acceleration_voltage",
646
+ dest="acceleration_voltage",
647
+ type=float,
648
+ required=False,
649
+ default=300,
650
+ help="Acceleration voltage in kV, defaults to 300.",
651
+ )
652
+ ctf_group.add_argument(
653
+ "--spherical_aberration",
654
+ dest="spherical_aberration",
655
+ type=float,
656
+ required=False,
657
+ default=2.7e7,
658
+ help="Spherical aberration in units of sampling rate (typically Ångstrom).",
659
+ )
660
+ ctf_group.add_argument(
661
+ "--amplitude_contrast",
662
+ dest="amplitude_contrast",
663
+ type=float,
664
+ required=False,
665
+ default=0.07,
666
+ help="Amplitude contrast, defaults to 0.07.",
667
+ )
668
+ ctf_group.add_argument(
669
+ "--no_flip_phase",
670
+ dest="no_flip_phase",
671
+ action="store_false",
672
+ required=False,
673
+ help="Whether the phase of the computed CTF should not be flipped.",
674
+ )
675
+ ctf_group.add_argument(
676
+ "--correct_defocus_gradient",
677
+ dest="correct_defocus_gradient",
678
+ action="store_true",
679
+ required=False,
680
+ help="[Experimental] Whether to compute a more accurate 3D CTF incorporating "
681
+ "defocus gradients.",
682
+ )
581
683
 
582
684
  performance_group = parser.add_argument_group("Performance")
583
685
  performance_group.add_argument(
@@ -655,12 +757,11 @@ def parse_args():
655
757
  )
656
758
 
657
759
  args = parser.parse_args()
760
+ args.version = __version__
658
761
 
659
762
  if args.interpolation_order < 0:
660
763
  args.interpolation_order = None
661
764
 
662
- args.ctf_file = None
663
-
664
765
  if args.temp_directory is None:
665
766
  default = abspath(".")
666
767
  if os.environ.get("TMPDIR", None) is not None:
@@ -725,12 +826,13 @@ def main():
725
826
  sampling_rate=target.sampling_rate,
726
827
  )
727
828
 
728
- if not np.allclose(target.sampling_rate, template.sampling_rate):
729
- print(
730
- f"Resampling template to {target.sampling_rate}. "
731
- "Consider providing a template with the same sampling rate as the target."
732
- )
733
- template = template.resample(target.sampling_rate, order=3)
829
+ if target.sampling_rate.size == template.sampling_rate.size:
830
+ if not np.allclose(target.sampling_rate, template.sampling_rate):
831
+ print(
832
+ f"Resampling template to {target.sampling_rate}. "
833
+ "Consider providing a template with the same sampling rate as the target."
834
+ )
835
+ template = template.resample(target.sampling_rate, order=3)
734
836
 
735
837
  template_mask = load_and_validate_mask(
736
838
  mask_target=template, mask_path=args.template_mask
@@ -863,31 +965,46 @@ def main():
863
965
  if args.memory is None:
864
966
  args.memory = int(args.memory_scaling * available_memory)
865
967
 
866
- target_padding = np.zeros_like(template.shape)
867
- if args.pad_target_edges:
868
- target_padding = template.shape
968
+ callback_class = MaxScoreOverRotations
969
+ if args.peak_calling:
970
+ callback_class = PeakCallerMaximumFilter
971
+
972
+ matching_data = MatchingData(
973
+ target=target,
974
+ template=template.data,
975
+ target_mask=target_mask,
976
+ template_mask=template_mask,
977
+ invert_target=args.invert_target_contrast,
978
+ rotations=parse_rotation_logic(args=args, ndim=template.data.ndim),
979
+ )
980
+
981
+ template_filter, target_filter = setup_filter(args, template, target)
982
+ matching_data.template_filter = template_filter
983
+ matching_data.target_filter = target_filter
869
984
 
870
- template_box = template.shape
985
+ template_box = matching_data._output_template_shape
871
986
  if not args.pad_fourier:
872
987
  template_box = np.ones(len(template_box), dtype=int)
873
988
 
874
- callback_class = MaxScoreOverRotations
875
- if args.peak_calling:
876
- callback_class = PeakCallerMaximumFilter
989
+ target_padding = np.zeros(
990
+ (backend.size(matching_data._output_template_shape)), dtype=int
991
+ )
992
+ if args.pad_target_edges:
993
+ target_padding = matching_data._output_template_shape
877
994
 
878
995
  splits, schedule = compute_parallelization_schedule(
879
996
  shape1=target.shape,
880
- shape2=template_box,
881
- shape1_padding=target_padding,
997
+ shape2=tuple(int(x) for x in template_box),
998
+ shape1_padding=tuple(int(x) for x in target_padding),
882
999
  max_cores=args.cores,
883
1000
  max_ram=args.memory,
884
1001
  split_only_outer=args.use_gpu,
885
1002
  matching_method=args.score,
886
1003
  analyzer_method=callback_class.__name__,
887
1004
  backend=backend._backend_name,
888
- float_nbytes=backend.datatype_bytes(backend._default_dtype),
1005
+ float_nbytes=backend.datatype_bytes(backend._float_dtype),
889
1006
  complex_nbytes=backend.datatype_bytes(backend._complex_dtype),
890
- integer_nbytes=backend.datatype_bytes(backend._default_dtype_int),
1007
+ integer_nbytes=backend.datatype_bytes(backend._int_dtype),
891
1008
  )
892
1009
 
893
1010
  if splits is None:
@@ -898,20 +1015,6 @@ def main():
898
1015
  exit(-1)
899
1016
 
900
1017
  matching_setup, matching_score = MATCHING_EXHAUSTIVE_REGISTER[args.score]
901
- matching_data = MatchingData(target=target, template=template.data)
902
- matching_data.rotations = parse_rotation_logic(args=args, ndim=target.data.ndim)
903
-
904
- template_filter, target_filter = setup_filter(args, template, target)
905
- matching_data.template_filter = template_filter
906
- matching_data.target_filter = target_filter
907
-
908
- matching_data.template_filter = template_filter
909
- matching_data._invert_target = args.invert_target_contrast
910
- if target_mask is not None:
911
- matching_data.target_mask = target_mask
912
- if template_mask is not None:
913
- matching_data.template_mask = template_mask.data
914
-
915
1018
  n_splits = np.prod(list(splits.values()))
916
1019
  target_split = ", ".join(
917
1020
  [":".join([str(x) for x in axis]) for axis in splits.items()]
@@ -945,13 +1048,23 @@ def main():
945
1048
  "Lowpass": args.lowpass,
946
1049
  "Highpass": args.highpass,
947
1050
  "Smooth Pass": args.no_pass_smooth,
948
- "Pass Format" : args.pass_format,
1051
+ "Pass Format": args.pass_format,
949
1052
  "Spectral Whitening": args.whiten_spectrum,
950
1053
  "Wedge Axes": args.wedge_axes,
951
1054
  "Tilt Angles": args.tilt_angles,
952
1055
  "Tilt Weighting": args.tilt_weighting,
953
- "CTF": args.ctf_file,
1056
+ "Reconstruction Filter": args.reconstruction_filter,
954
1057
  }
1058
+ if args.ctf_file is not None or args.defocus is not None:
1059
+ filter_args["CTF File"] = args.ctf_file
1060
+ filter_args["Defocus"] = args.defocus
1061
+ filter_args["Phase Shift"] = args.phase_shift
1062
+ filter_args["No Flip Phase"] = args.no_flip_phase
1063
+ filter_args["Acceleration Voltage"] = args.acceleration_voltage
1064
+ filter_args["Spherical Aberration"] = args.spherical_aberration
1065
+ filter_args["Amplitude Contrast"] = args.amplitude_contrast
1066
+ filter_args["Correct Defocus"] = args.correct_defocus_gradient
1067
+
955
1068
  filter_args = {k: v for k, v in filter_args.items() if v is not None}
956
1069
  if len(filter_args):
957
1070
  print_block(
@@ -1000,15 +1113,16 @@ def main():
1000
1113
  candidates[0] *= target_mask.data
1001
1114
  with warnings.catch_warnings():
1002
1115
  warnings.simplefilter("ignore", category=UserWarning)
1116
+ nbytes = backend.datatype_bytes(backend._float_dtype)
1117
+ dtype = np.float32 if nbytes == 4 else np.float16
1118
+ rot_dim = matching_data.rotations.shape[1]
1003
1119
  candidates[3] = {
1004
1120
  x: euler_from_rotationmatrix(
1005
- np.frombuffer(i, dtype=matching_data.rotations.dtype).reshape(
1006
- candidates[0].ndim, candidates[0].ndim
1007
- )
1121
+ np.frombuffer(i, dtype=dtype).reshape(rot_dim, rot_dim)
1008
1122
  )
1009
1123
  for i, x in candidates[3].items()
1010
1124
  }
1011
- candidates.append((target.origin, template.origin, target.sampling_rate, args))
1125
+ candidates.append((target.origin, template.origin, template.sampling_rate, args))
1012
1126
  write_pickle(data=candidates, filename=args.output)
1013
1127
 
1014
1128
  runtime = time() - start