pytme 0.2.9__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119):
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
@@ -0,0 +1,622 @@
1
+ #!python
2
+ """ CLI to simplify analysing the output of match_template.py.
3
+
4
+ Copyright (c) 2023 European Molecular Biology Laboratory
5
+
6
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
+ """
8
+ import argparse
9
+ from sys import exit
10
+ from os import getcwd
11
+ from typing import List, Tuple
12
+ from os.path import join, abspath, splitext
13
+
14
+ import numpy as np
15
+ from numpy.typing import NDArray
16
+ from scipy.special import erfcinv
17
+
18
+ from tme import Density, Structure, Orientations
19
+ from tme.matching_utils import load_pickle, centered_mask
20
+ from tme.matching_optimization import create_score_object, optimize_match
21
+ from tme.rotations import euler_to_rotationmatrix, euler_from_rotationmatrix
22
+ from tme.analyzer import (
23
+ PeakCallerSort,
24
+ PeakCallerMaximumFilter,
25
+ PeakCallerFast,
26
+ PeakCallerRecursiveMasking,
27
+ PeakCallerScipy,
28
+ )
29
+
30
+
31
# Registry of the peak callers selectable via --peak_caller. The keyword names
# double as the CLI choice strings; insertion order defines the order in which
# argparse lists the choices.
PEAK_CALLERS = dict(
    PeakCallerSort=PeakCallerSort,
    PeakCallerMaximumFilter=PeakCallerMaximumFilter,
    PeakCallerFast=PeakCallerFast,
    PeakCallerRecursiveMasking=PeakCallerRecursiveMasking,
    PeakCallerScipy=PeakCallerScipy,
)
38
+
39
+
40
def parse_args(argv: List[str] = None):
    """
    Parse and post-process the command line arguments of the CLI.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. If None, argparse falls back to ``sys.argv[1:]``.

    Returns
    -------
    argparse.Namespace
        Parsed arguments, with the following adjustments applied:

        - ``subtomogram_box_size`` is rounded up to an even value for RELION
          output formats.
        - ``orientations`` is loaded via :py:meth:`Orientations.from_file` if given.
        - ``background_file`` holds exactly one entry (possibly ``None``) per
          ``input_file``.

    Raises
    ------
    ValueError
        If ``--background_file`` is neither given once nor once per ``--input_file``.
    """
    parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")

    input_group = parser.add_argument_group("Input")
    output_group = parser.add_argument_group("Output")
    peak_group = parser.add_argument_group("Peak Calling")
    additional_group = parser.add_argument_group("Additional Parameters")

    input_group.add_argument(
        "--input_file",
        required=True,
        nargs="+",
        help="Path to the output of match_template.py.",
    )
    input_group.add_argument(
        "--background_file",
        required=False,
        nargs="+",
        help="Path to an output of match_template.py used for normalization. "
        "For instance from --scramble_phases or a different template.",
    )
    input_group.add_argument(
        "--target_mask",
        required=False,
        type=str,
        help="Path to an optional mask applied to template matching scores.",
    )
    input_group.add_argument(
        "--orientations",
        required=False,
        type=str,
        help="Path to file generated using output_format orientations. Can be filtered "
        "to exclude false-positive peaks. If this file is provided, peak calling "
        "is skipped and corresponding parameters ignored.",
    )

    output_group.add_argument(
        "--output_prefix",
        required=True,
        help="Output filename, extension will be added based on output_format.",
    )
    output_group.add_argument(
        "--output_format",
        choices=[
            "orientations",
            "relion4",
            "relion5",
            "alignment",
            "extraction",
            "average",
        ],
        default="orientations",
        help="Available output formats: "
        "orientations (translation, rotation, and score), "
        "relion4 (RELION 4 star format), "
        "relion5 (RELION 5 star format), "
        "alignment (aligned template to target based on orientations), "
        "extraction (extract regions around peaks from targets, i.e. subtomograms), "
        "average (extract matched regions from target and average them).",
    )

    peak_group.add_argument(
        "--peak_caller",
        choices=list(PEAK_CALLERS.keys()),
        default="PeakCallerScipy",
        help="Peak caller for local maxima identification.",
    )
    peak_group.add_argument(
        "--minimum_score",
        type=float,
        default=None,
        help="Minimum score from which peaks will be considered.",
    )
    peak_group.add_argument(
        "--maximum_score",
        type=float,
        default=None,
        help="Maximum score until which peaks will be considered.",
    )
    peak_group.add_argument(
        "--min_distance",
        type=int,
        default=5,
        help="Minimum distance between peaks.",
    )
    peak_group.add_argument(
        "--min_boundary_distance",
        type=int,
        default=0,
        help="Minimum distance of peaks to target edges.",
    )
    peak_group.add_argument(
        "--mask_edges",
        action="store_true",
        default=False,
        help="Whether candidates should not be identified from scores that were "
        "computed from padded densities. Superseded by min_boundary_distance.",
    )
    peak_group.add_argument(
        "--num_peaks",
        type=int,
        default=1000,
        required=False,
        help="Upper limit of peaks to call, subject to filtering parameters. Default 1000. "
        "If minimum_score is provided all peaks scoring higher will be reported.",
    )
    peak_group.add_argument(
        "--peak_oversampling",
        type=int,
        default=1,
        help="1 / factor equals voxel precision, e.g. 2 detects half voxel "
        "translations. Useful for matching structures to electron density maps.",
    )

    additional_group.add_argument(
        "--subtomogram_box_size",
        type=int,
        default=None,
        help="Subtomogram box size, by default equal to the centered template. Will be "
        "padded to even values if output_format is relion.",
    )
    additional_group.add_argument(
        "--mask_subtomograms",
        action="store_true",
        default=False,
        help="Whether to mask subtomograms using the template mask. The mask will be "
        "rotated according to determined angles.",
    )
    additional_group.add_argument(
        "--invert_target_contrast",
        action="store_true",
        default=False,
        help="Whether to invert the target contrast.",
    )
    additional_group.add_argument(
        "--n_false_positives",
        type=int,
        default=None,
        required=False,
        help="Number of accepted false-positives picks to determine minimum score.",
    )
    additional_group.add_argument(
        "--local_optimization",
        action="store_true",
        required=False,
        help="[Experimental] Perform local optimization of candidates. Useful when the "
        "number of identified candidates is small (< 10).",
    )

    args = parser.parse_args(argv)

    # RELION expects even box sizes. The valid choices are "relion4"/"relion5",
    # so comparing against a bare "relion" would never match.
    if (
        args.output_format in ("relion4", "relion5")
        and args.subtomogram_box_size is not None
    ):
        args.subtomogram_box_size += args.subtomogram_box_size % 2

    if args.orientations is not None:
        args.orientations = Orientations.from_file(filename=args.orientations)

    # Broadcast a single (or absent) background file to every input file.
    if args.background_file is None:
        args.background_file = [None]
    if len(args.background_file) == 1:
        args.background_file = args.background_file * len(args.input_file)
    elif len(args.background_file) not in (0, len(args.input_file)):
        raise ValueError(
            "--background_file needs to be specified once or for each --input_file."
        )

    return args
207
+
208
+
209
def load_template(
    filepath: str,
    sampling_rate: NDArray,
    centering: bool = True,
    target_shape: Tuple[int] = None,
):
    """
    Load a template from a density map or, failing that, an atomic structure.

    Parameters
    ----------
    filepath : str
        Path to the template. Any file :py:meth:`Density.from_file` fails to
        read is passed to :py:meth:`Structure.from_file` instead.
    sampling_rate : NDArray
        Sampling rate used to rasterize structures with
        :py:meth:`Density.from_structure`. Not used for density maps.
    centering : bool, optional
        Whether density maps are re-centered via ``Density.centered(0)``.
        Structures are never re-centered. Default True.
    target_shape : tuple of int, optional
        Accepted for call compatibility; not used by this function.

    Returns
    -------
    tuple
        ``(template, center, translation, template_is_density)``. For density
        maps, ``center`` is the geometric box center ``(shape - 1) / 2`` of the
        returned template and ``translation`` is the shift reported by
        ``centered`` (zeros if ``centering`` is False). For structures,
        ``center`` is the center of mass and ``translation`` is all zeros.
    """

    def _box_center(density):
        # Geometric center of the voxel grid.
        return np.divide(np.subtract(density.shape, 1), 2)

    try:
        template = Density.from_file(filepath)
    except Exception:
        # Not readable as a density map: interpret the file as a structure.
        structure = Structure.from_file(filepath)
        center = structure.center_of_mass()
        template = Density.from_structure(structure, sampling_rate=sampling_rate)
        return template, center, np.zeros_like(center), False

    translation = np.zeros_like(_box_center(template))
    if centering:
        template, translation = template.centered(0)
    return template, _box_center(template), translation, True
231
+
232
+
233
def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
    """
    Merge the score maps of several match_template.py runs by voxel-wise maximum.

    Every score map is normalized with ``normalize_under_mask`` before the
    comparison. The mask is the target mask, further restricted to voxels at
    least ``args.min_boundary_distance`` away from the box edges.

    Parameters
    ----------
    data : list
        Output of :py:func:`load_match_template_output` for ``foreground_paths[0]``.
        ``data[0]`` is replaced by the merged scores.
    foreground_paths : list of str
        Paths to match_template.py outputs. The first entry must correspond to ``data``.
    background_paths : list of str or None
        Background output for each foreground path, aligned with ``foreground_paths``.
    args : argparse.Namespace
        Parsed CLI arguments. ``args.min_boundary_distance`` may be set when
        ``--mask_edges`` is given.

    Returns
    -------
    tuple
        ``(data, entities)``. ``entities`` stores, for each voxel, the index into
        ``foreground_paths`` of the run with the highest normalized score. It is
        ``None`` when nothing can be merged, i.e. when no paths are given or
        ``data`` does not contain a score map.
    """
    if len(foreground_paths) == 0:
        return data, None

    # Only score maps (data[0] with the same rank as data[2]) can be merged
    # voxel-wise. Returning None lets callers skip entity assignment.
    if data[0].ndim != data[2].ndim:
        return data, None

    from tme.matching_exhaustive import normalize_under_mask

    def _norm_scores(data, args):
        """Normalize ``data[0]`` under the target/boundary mask and return it."""
        _, _, sampling_rate, cli_args = data[-1]

        # The template is only needed to derive a boundary distance from its size.
        if args.mask_edges and args.min_boundary_distance == 0:
            template, *_ = load_template(
                filepath=cli_args.template,
                sampling_rate=sampling_rate,
                centering=not cli_args.no_centering,
            )
            max_shape = np.max(template.shape)
            args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))

        target_mask = 1
        if args.target_mask is not None:
            target_mask = Density.from_file(args.target_mask).data
        elif cli_args.target_mask is not None:
            target_mask = Density.from_file(cli_args.target_mask).data

        mask = np.ones_like(data[0])
        np.multiply(mask, target_mask, out=mask)

        # Exclude voxels closer than min_boundary_distance to any box edge.
        boundary = int(args.min_boundary_distance)
        if boundary > 0:
            inner = tuple(
                slice(boundary, max(boundary, x - boundary)) for x in mask.shape
            )
            boundary_mask = np.zeros_like(mask)
            boundary_mask[inner] = 1
            np.multiply(mask, boundary_mask, out=mask)

        normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
        return data[0]

    entities = np.zeros_like(data[0])
    data[0] = _norm_scores(data=data, args=args)

    # foreground_paths[0] is already represented by data (entity 0); only the
    # remaining runs need to be loaded, and each is labeled by its path index.
    for index in range(1, len(foreground_paths)):
        new_scores = _norm_scores(
            data=load_match_template_output(
                foreground_paths[index], background_paths[index]
            ),
            args=args,
        )
        improved = new_scores > data[0]
        entities[improved] = index
        data[0][improved] = new_scores[improved]

    return data, entities
285
+
286
+
287
def load_match_template_output(foreground_path, background_path):
    """
    Load a match_template.py output and optionally correct it by a background run.

    Parameters
    ----------
    foreground_path : str
        Pickle file written by match_template.py.
    background_path : str or None
        Pickle of a background run, e.g. from ``--scramble_phases``. If given,
        ``data[0]`` is replaced by ``(fg - bg) / (1 - bg)`` clipped from below
        at 0. ``np.fmax`` also maps NaNs from the division to 0.

    Returns
    -------
    The loaded pickle content with ``data[0]`` corrected when a background is
    given. It must support item assignment; presumably a list (TODO confirm
    against ``load_pickle``).
    """
    data = load_pickle(foreground_path)
    if background_path is None:
        return data

    background = load_pickle(background_path)[0]
    corrected = (data[0] - background) / (1 - background)
    np.fmax(corrected, 0, out=corrected)
    data[0] = corrected
    return data
294
+
295
+
296
def main():
    """
    Entry point of the postprocessing CLI.

    The function proceeds in these steps:

    1. Load the (optionally background-corrected and merged) output of
       match_template.py.
    2. Determine candidate orientations by peak calling, unless
       ``--orientations`` is given.
    3. Filter candidates by score and optionally refine them.
    4. Write the result in the format chosen by ``--output_format``.

    Output branches terminate the process via :py:func:`sys.exit`.
    """
    args = parse_args()
    data = load_match_template_output(args.input_file[0], args.background_file[0])

    # The last element of the pickle holds metadata of the matching run.
    target_origin, _, sampling_rate, cli_args = data[-1]

    _, template_extension = splitext(cli_args.template)
    template, _, translation, template_is_density = load_template(
        filepath=cli_args.template,
        sampling_rate=sampling_rate,
        centering=not cli_args.no_centering,
    )

    template_mask = template.empty
    template_mask.data[:] = 1
    if cli_args.template_mask is not None:
        template_mask = Density.from_file(cli_args.template_mask)
        template_mask.pad(template.shape, center=False)
        origin_translation = np.divide(
            np.subtract(template.origin, template_mask.origin), template.sampling_rate
        )
        translation = np.add(translation, origin_translation)

        # Apply the same centering shift to the mask as to the template.
        template_mask = template_mask.rigid_transform(
            rotation_matrix=np.eye(template_mask.data.ndim),
            translation=-translation,
            order=1,
        )

    if args.mask_edges and args.min_boundary_distance == 0:
        max_shape = np.max(template.shape)
        args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))

    entities = None
    if len(args.input_file) > 1:
        data, entities = merge_outputs(
            data=data,
            foreground_paths=args.input_file,
            background_paths=args.background_file,
            args=args,
        )

    orientations = args.orientations
    if orientations is None:
        translations, rotations, scores, details = [], [], [], []
        # Output is MaxScoreOverRotations
        if data[0].ndim == data[2].ndim:
            scores, offset, rotation_array, rotation_mapping, meta = data

            if args.target_mask is not None:
                target_mask = Density.from_file(args.target_mask)
                scores = scores * target_mask.data

            cropped_shape = np.subtract(
                scores.shape, np.multiply(args.min_boundary_distance, 2)
            ).astype(int)

            if args.min_boundary_distance > 0:
                scores = centered_mask(scores, new_shape=cropped_shape)

            if args.n_false_positives is not None:
                # Rickgauer et al. 2017
                cropped_slice = tuple(
                    slice(
                        int(args.min_boundary_distance),
                        int(x - args.min_boundary_distance),
                    )
                    for x in scores.shape
                )
                args.n_false_positives = max(args.n_false_positives, 1)
                n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
                minimum_score = np.multiply(
                    erfcinv(2 * args.n_false_positives / n_correlations),
                    np.sqrt(2) * np.std(scores[cropped_slice]),
                )
                print(f"Determined minimum score cutoff: {minimum_score}.")
                minimum_score = max(minimum_score, 0)
                args.minimum_score = minimum_score

            args.batch_dims = None
            if hasattr(cli_args, "target_batch"):
                args.batch_dims = cli_args.target_batch

            peak_caller_kwargs = {
                "shape": scores.shape,
                "num_peaks": args.num_peaks,
                "min_distance": args.min_distance,
                "min_boundary_distance": args.min_boundary_distance,
                "batch_dims": args.batch_dims,
                "minimum_score": args.minimum_score,
                "maximum_score": args.maximum_score,
            }

            peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
            peak_caller(
                scores,
                rotation_matrix=np.eye(template.data.ndim),
                mask=template.data,
                rotation_mapping=rotation_mapping,
                rotations=rotation_array,
            )
            candidates = peak_caller.merge(
                candidates=[tuple(peak_caller)], **peak_caller_kwargs
            )
            # candidates holds parallel sequences (positions, ..., scores, details);
            # treat both an empty container and empty position arrays as "no peaks".
            if len(candidates) == 0 or len(candidates[0]) == 0:
                print("Found no peaks, consider changing peak calling parameters.")
                exit(-1)

            for peak_position, *_ in zip(*candidates):
                rotation_index = rotation_array[tuple(peak_position)]
                rotation = rotation_mapping.get(
                    rotation_index, np.zeros(template.data.ndim, int)
                )
                if rotation.ndim == 2:
                    rotation = euler_from_rotationmatrix(rotation)
                rotations.append(rotation)

        else:
            # Peak caller output: candidates are stored directly in the pickle.
            candidates = data
            peak_translations, peak_rotations, *_ = data
            for i in range(peak_translations.shape[0]):
                rotations.append(euler_from_rotationmatrix(peak_rotations[i]))

        if len(rotations):
            rotations = np.vstack(rotations).astype(float)
        translations, scores, details = candidates[0], candidates[2], candidates[3]

        if entities is not None:
            details = entities[tuple(translations.T)]

        orientations = Orientations(
            translations=translations,
            rotations=rotations,
            scores=scores,
            details=details,
        )

    if args.minimum_score is not None and len(orientations.scores):
        keep = orientations.scores >= args.minimum_score
        orientations = orientations[keep]

    if args.maximum_score is not None and len(orientations.scores):
        keep = orientations.scores <= args.maximum_score
        orientations = orientations[keep]

    if args.peak_oversampling > 1:
        if data[0].ndim != data[2].ndim:
            print(
                "Input pickle does not contain template matching scores."
                " Cannot oversample peaks."
            )
            exit(-1)
        # Use the score map's shape: `scores` may have been rebound to the
        # candidate scores above, or be undefined when --orientations is given.
        peak_caller = PEAK_CALLERS[args.peak_caller](shape=data[0].shape)
        orientations.translations = peak_caller.oversample_peaks(
            scores=data[0],
            peak_positions=orientations.translations,
            oversampling_factor=args.peak_oversampling,
        )

    if args.local_optimization:
        target = Density.from_file(cli_args.target, use_memmap=True)
        orientations.translations = orientations.translations.astype(np.float32)
        orientations.rotations = orientations.rotations.astype(np.float32)
        center = np.divide(template.shape, 2)
        for index, (translation, angles, *_) in enumerate(orientations):
            score_object = create_score_object(
                score="FLC",
                target=target.data.copy(),
                template=template.data.copy(),
                template_mask=template_mask.data.copy(),
            )

            init_translation = np.subtract(translation, center)
            bounds_translation = tuple((x - 5, x + 5) for x in init_translation)

            translation, rotation_matrix, score = optimize_match(
                score_object=score_object,
                optimization_method="basinhopping",
                bounds_translation=bounds_translation,
                maxiter=3,
                x0=[*init_translation, *angles],
            )
            orientations.translations[index] = np.add(translation, center)
            # Store the refined rotation rather than the starting angles.
            orientations.rotations[index] = euler_from_rotationmatrix(rotation_matrix)
            # The optimizer minimizes, so the returned score is negated.
            orientations.scores[index] = score * -1

    if args.output_format in ("orientations", "relion4", "relion5"):
        file_format, extension = "text", "tsv"

        version = None
        if args.output_format in ("relion4", "relion5"):
            version = "# version 40001"
            file_format, extension = "star", "star"

        if args.output_format == "relion5":
            version = "# version 50001"
            target = Density.from_file(cli_args.target, use_memmap=True)
            orientations.translations = np.subtract(
                orientations.translations, np.divide(target.shape, 2).astype(int)
            )
            orientations.translations = np.multiply(
                orientations.translations, target.sampling_rate
            )

        orientations.to_file(
            filename=f"{args.output_prefix}.{extension}",
            file_format=file_format,
            source_path=cli_args.target,
            version=version,
        )
        exit(0)

    target = Density.from_file(cli_args.target)
    if args.invert_target_contrast:
        target.data = target.data * -1

    if args.output_format == "extraction":
        if not np.all(np.divide(target.shape, template.shape) > 2):
            print(
                "Target might be too small relative to template to extract"
                " meaningful particles."
                f" Target : {target.shape}, Template : {template.shape}."
            )

        extraction_shape = template.shape
        if args.subtomogram_box_size is not None:
            extraction_shape = np.repeat(
                args.subtomogram_box_size, len(extraction_shape)
            )
        extraction_shape = tuple(int(x) for x in extraction_shape)

        orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
            target_shape=target.shape,
            extraction_shape=extraction_shape,
            drop_out_of_box=True,
            return_orientations=True,
        )

        working_directory = getcwd()

        # Build and write one subtomogram at a time instead of holding all in memory.
        for index, (cand_slice, obs_slice) in enumerate(zip(cand_slices, obs_slices)):
            subtomogram = np.zeros(extraction_shape)
            # Fill regions outside the target with the mean of the extracted data.
            subtomogram[:] = np.mean(target.data[obs_slice])
            subtomogram[cand_slice] = target.data[obs_slice]

            cand_start = [x.start for x in cand_slice]
            out_density = Density(
                data=subtomogram,
                sampling_rate=sampling_rate,
                origin=np.multiply(cand_start, sampling_rate),
            )
            if args.mask_subtomograms:
                rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
                mask_transformed = template_mask.rigid_transform(
                    rotation_matrix=rotation_matrix, order=1
                )
                out_density.data = out_density.data * mask_transformed.data
            out_density.to_file(
                join(working_directory, f"{args.output_prefix}_{index}.mrc")
            )

        exit(0)

    if args.output_format == "average":
        orientations, cand_slices, obs_slices = orientations.get_extraction_slices(
            target_shape=target.shape,
            extraction_shape=template.shape,
            drop_out_of_box=True,
            return_orientations=True,
        )
        if len(cand_slices) == 0:
            print("No candidates lie fully within the target. Cannot compute an average.")
            exit(-1)

        out = np.zeros_like(template.data)
        for index in range(len(cand_slices)):
            subset = Density(target.data[obs_slices[index]])
            rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])

            # Rotate each particle back into the template frame before summing.
            subset = subset.rigid_transform(
                rotation_matrix=np.linalg.inv(rotation_matrix),
                order=1,
                use_geometric_center=True,
            )
            np.add(out, subset.data, out=out)

        out /= len(cand_slices)
        ret = Density(out, sampling_rate=template.sampling_rate, origin=0)
        ret.pad(template.shape, center=True)
        ret.to_file(f"{args.output_prefix}.mrc")
        exit(0)

    # output_format == "alignment"
    template, center, *_ = load_template(
        filepath=cli_args.template,
        sampling_rate=sampling_rate,
        centering=not cli_args.no_centering,
        target_shape=target.shape,
    )

    for index, (translation, angles, *_) in enumerate(orientations):
        rotation_matrix = euler_to_rotationmatrix(angles)
        if template_is_density:
            transformed_template = template.rigid_transform(
                rotation_matrix=rotation_matrix, use_geometric_center=True
            )
            # Just adapting the coordinate system not the in-box position
            shift = np.multiply(np.subtract(translation, center), sampling_rate)
            transformed_template.origin = np.add(target_origin, shift)

        else:
            template = Structure.from_file(cli_args.template)
            new_center_of_mass = np.add(
                np.multiply(translation, sampling_rate), target_origin
            )
            translation = np.subtract(new_center_of_mass, center)
            transformed_template = template.rigid_transform(
                translation=translation,
                rotation_matrix=rotation_matrix,
            )
        # template_extension should contain '.'
        transformed_template.to_file(
            f"{args.output_prefix}_{index}{template_extension}"
        )
619
+
620
+
621
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()