pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
scripts/refine_matches.py CHANGED
@@ -1,52 +1,123 @@
1
1
  #!python3
2
- """ CLI to refine template matching candidates.
2
+ """ Iterative template matching parameter tuning.
3
3
 
4
- Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+ Copyright (c) 2024 European Molecular Biology Laboratory
5
5
 
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
  import argparse
9
+ import subprocess
10
+ from sys import exit
9
11
  from time import time
12
+ from shutil import copyfile
13
+ from typing import Tuple, List, Dict
10
14
 
11
15
  import numpy as np
12
- from numpy.typing import NDArray
13
-
14
- from tme.backends import backend
15
- from tme import Density, Structure
16
- from tme.matching_data import MatchingData
17
- from tme.analyzer import MaxScoreOverRotations, MaxScoreOverTranslations
18
- from tme.matching_utils import (
19
- load_pickle,
20
- get_rotation_matrices,
21
- compute_parallelization_schedule,
22
- )
23
- from tme.matching_exhaustive import scan_subsets, MATCHING_EXHAUSTIVE_REGISTER
24
-
25
- from postprocess import Orientations
26
- from match_template import load_and_validate_mask
16
+ from scipy import optimize
17
+ from sklearn.metrics import roc_auc_score
18
+
19
+ from tme import Orientations, Density
20
+ from tme.matching_utils import generate_tempfile_name, load_pickle, write_pickle, create_mask
21
+ from tme.matching_exhaustive import MATCHING_EXHAUSTIVE_REGISTER
22
+
23
def parse_range(x: str) -> range:
    """Parse a ``start:stop[:step]`` string into a :class:`range`.

    Parameters
    ----------
    x : str
        Colon-separated integers, e.g. ``"0:50:5"``. The step is optional
        and defaults to 1, mirroring Python slice notation. Whitespace
        around the individual numbers is ignored.

    Returns
    -------
    range
        ``range(start, stop, step)``.

    Raises
    ------
    ValueError
        If ``x`` does not consist of two or three integers separated by
        colons, or if the step is zero.
    """
    parts = [part.strip() for part in x.split(":")]
    if len(parts) not in (2, 3):
        raise ValueError(
            f"Expected a range of the form start:stop[:step], got {x!r}."
        )
    if len(parts) == 2:
        parts.append("1")
    start, stop, step = (int(part) for part in parts)
    return range(start, stop, step)
27
26
 
28
27
def parse_args():
    """Parse and validate the command line arguments of this script.

    Returns
    -------
    argparse.Namespace
        Parsed arguments. ``lowpass_range`` and ``highpass_range`` are
        converted via :func:`parse_range`; the literal string ``"None"``
        becomes ``(None,)`` so callers can always iterate over them.

    Raises
    ------
    ValueError
        If neither ``--input_file`` nor both ``--target`` and ``--template``
        are given, or if both alternatives are given at once.
    """
    parser = argparse.ArgumentParser(
        description="Refine template matching candidates using deep matching.",
    )
    io_group = parser.add_argument_group("Input / Output")
    io_group.add_argument(
        "--orientations",
        required=True,
        type=str,
        help="Path to an orientations file in a supported format. See "
        "https://kosinskilab.github.io/pyTME/reference/api/tme.orientations.Orientations.from_file.html"
        " for available options.",
    )
    io_group.add_argument(
        "--output_prefix", required=True, type=str, help="Path to write output to."
    )
    io_group.add_argument(
        "--iterations",
        required=False,
        default=0,
        type=int,
        help="Number of refinement iterations to perform.",
    )
    io_group.add_argument(
        "--verbose",
        action="store_true",
        default=False,
        help="More verbose and more files written to disk.",
    )
    matching_group = parser.add_argument_group("Template Matching")
    matching_group.add_argument(
        "--input_file",
        required=False,
        type=str,
        help="Path to the output of match_template.py.",
    )
    matching_group.add_argument(
        "-m",
        "--target",
        dest="target",
        type=str,
        required=False,
        help="Path to a target in CCP4/MRC, EM, H5 or another format supported by "
        "tme.density.Density.from_file "
        "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
    )
    matching_group.add_argument(
        "--target_mask",
        dest="target_mask",
        type=str,
        required=False,
        help="Path to a mask for the target in a supported format (see target).",
    )
    matching_group.add_argument(
        "-i",
        "--template",
        dest="template",
        type=str,
        required=False,
        help="Path to a template in PDB/MMCIF or other supported formats (see target).",
    )
    matching_group.add_argument(
        "--template_mask",
        dest="template_mask",
        type=str,
        required=False,
        help="Path to a mask for the template in a supported format (see target).",
    )
    matching_group.add_argument(
        "--invert_target_contrast",
        dest="invert_target_contrast",
        action="store_true",
        default=False,
        help="Invert the target's contrast and rescale linearly between zero and one. "
        "This option is intended for targets where templates to-be-matched have "
        "negative values, e.g. tomograms.",
    )
    matching_group.add_argument(
        "--angular_sampling",
        dest="angular_sampling",
        required=True,
        default=None,
        help="Angular sampling rate using optimized rotational sets. "
        "A lower number yields more rotations. Values >= 180 sample only the identity.",
    )
    matching_group.add_argument(
        "-s",
        dest="score",
        type=str,
        default="FLCSphericalMask",
        choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
        help="Template matching scoring function.",
    )
    matching_group.add_argument(
        "-n",
        dest="cores",
        required=False,
        type=int,
        default=4,
        help="Number of cores used for template matching.",
    )
    matching_group.add_argument(
        "--use_gpu",
        dest="use_gpu",
        action="store_true",
        default=False,
        help="Whether to perform computations on the GPU.",
    )
    matching_group.add_argument(
        "--no_centering",
        dest="no_centering",
        action="store_true",
        help="Assumes the template is already centered and omits centering.",
    )
    matching_group.add_argument(
        "--no_edge_padding",
        dest="no_edge_padding",
        action="store_true",
        default=False,
        help="Whether to not pad the edges of the target. Can be set if the target"
        " has a well defined bounding box, e.g. a masked reconstruction.",
    )
    matching_group.add_argument(
        "--no_fourier_padding",
        dest="no_fourier_padding",
        action="store_true",
        default=False,
        help="Whether input arrays should not be zero-padded to full convolution shape "
        "for numerical stability. When working with very large targets, e.g. tomograms, "
        "it is safe to use this flag and benefit from the performance gain.",
    )

    peak_group = parser.add_argument_group("Peak Calling")
    peak_group.add_argument(
        "--number_of_peaks",
        type=int,
        default=100,
        required=False,
        help="Upper limit of peaks to call, subject to filtering parameters. "
        "Default 100.",
    )
    extraction_group = parser.add_argument_group("Extraction")
    extraction_group.add_argument(
        "--keep_out_of_box",
        action="store_true",
        required=False,
        help="Whether to keep orientations that fall outside the box. If the "
        "orientations are sensible, it is safe to pass this flag.",
    )

    optimization_group = parser.add_argument_group("Optimization")
    optimization_group.add_argument(
        "--lowpass",
        dest="lowpass",
        action="store_true",
        default=False,
        help="Optimize template matching lowpass filter cutoff.",
    )
    optimization_group.add_argument(
        "--highpass",
        dest="highpass",
        action="store_true",
        default=False,
        help="Optimize template matching highpass filter cutoff.",
    )
    optimization_group.add_argument(
        "--lowpass-range",
        dest="lowpass_range",
        type=str,
        default="0:50:5",
        help="Lowpass filter cutoffs to evaluate, given as start:stop[:step], "
        "or None to disable.",
    )
    optimization_group.add_argument(
        "--highpass-range",
        dest="highpass_range",
        type=str,
        default="0:50:5",
        help="Highpass filter cutoffs to evaluate, given as start:stop[:step], "
        "or None to disable.",
    )
    optimization_group.add_argument(
        "--translation-uncertainty",
        dest="translation_uncertainty",
        type=int,
        default=None,
        help="Radius in voxels of a mask centered on each extracted candidate. "
        "Refined peaks are only called within this mask.",
    )

    args = parser.parse_args()

    data_present = args.target is not None and args.template is not None
    if args.input_file is None and not data_present:
        raise ValueError(
            "Either --input_file or --target and --template need to be specified."
        )
    elif args.input_file is not None and data_present:
        raise ValueError(
            "Please specify either --input_file or --target and --template, not both."
        )

    # Downstream code iterates over these, so "None" maps to a one-element
    # tuple rather than to None itself.
    for attribute in ("lowpass_range", "highpass_range"):
        value = getattr(args, attribute)
        setattr(args, attribute, (None,) if value == "None" else parse_range(value))

    return args
61
236
 
62
237
 
63
- def load_template(filepath: str, sampling_rate: NDArray) -> "Density":
64
- try:
65
- template = Density.from_file(filepath)
66
- except ValueError:
67
- template = Structure.from_file(filepath)
68
- template = Density.from_structure(template, sampling_rate=sampling_rate)
238
def argdict_to_command(input_args: Dict, executable: str) -> str:
    """Render a mapping of command line flags into a shell command string.

    Parameters
    ----------
    input_args : dict
        Maps flag names (e.g. ``"-o"``) to values. ``None`` values are
        skipped, ``True`` emits the bare flag, ``False`` emits nothing, and
        any other value emits ``flag value``.
    executable : str
        Command prefix, e.g. ``"python3 $HOME/.../script.py"``. It is
        inserted verbatim so that shell expansion (``$HOME``) still applies.

    Returns
    -------
    str
        The command line, with every flag and value shell-quoted.
    """
    import shlex

    tokens = []
    for key, value in input_args.items():
        if value is None:
            continue
        if isinstance(value, bool):
            if value:
                tokens.append(str(key))
            continue
        tokens.extend((str(key), str(value)))

    # Commands are run with shell=True, so paths containing spaces or shell
    # metacharacters must be quoted. The executable is deliberately left
    # unquoted because callers rely on word splitting and $VAR expansion.
    return " ".join([executable, *(shlex.quote(token) for token in tokens)])
252
+
253
def run_command(command: str) -> None:
    """Run ``command`` through the shell and abort the script if it fails.

    Parameters
    ----------
    command : str
        Shell command line. It is executed with ``shell=True`` so that
        variables such as ``$HOME`` are expanded; only pass trusted input.

    Returns
    -------
    None
        On success (exit status 0).

    Raises
    ------
    SystemExit
        With code -1 if the command exits with a non-zero status, after
        its captured stdout and stderr have been written to stderr.
    """
    import sys

    result = subprocess.run(command, capture_output=True, shell=True)
    if result.returncode == 0:
        return None

    # Report on stderr, and replace undecodable bytes instead of raising a
    # UnicodeDecodeError that would hide the original failure.
    stdout = result.stdout.decode("utf-8", errors="replace")
    stderr = result.stderr.decode("utf-8", errors="replace")
    print(f"Error when executing: {command}.", file=sys.stderr)
    print(f"Stdout: {stdout}", file=sys.stderr)
    print(f"Stderr: {stderr}", file=sys.stderr)
    sys.exit(-1)
262
+
263
def create_stacking_argdict(args) -> Dict:
    """Collect the flags passed to ``extract_candidates.py``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``target``, ``template``, ``orientations``,
        ``candidate_stack_path`` and ``keep_out_of_box``.

    Returns
    -------
    dict
        Flag name to value, in command line order.
    """
    flag_to_attribute = (
        ("--target", "target"),
        ("--template", "template"),
        ("--orientations", "orientations"),
        ("--output_file", "candidate_stack_path"),
        ("--keep_out_of_box", "keep_out_of_box"),
    )
    return {flag: getattr(args, attribute) for flag, attribute in flag_to_attribute}
272
+
273
+
274
def create_matching_argdict(args) -> Dict:
    """Collect the flags passed to ``match_template_filters.py``.

    Fourier and edge padding are always disabled here, independent of the
    corresponding command line options.

    Parameters
    ----------
    args : argparse.Namespace
        Parsed arguments, including ``match_template_path``.

    Returns
    -------
    dict
        Flag name to value, in command line order.
    """
    pairs = [
        ("--target", args.target),
        ("--template", args.template),
        ("--template_mask", args.template_mask),
        ("-o", args.match_template_path),
        ("-a", args.angular_sampling),
        ("-s", args.score),
        ("--no_fourier_padding", True),
        ("--no_edge_padding", True),
        ("--no_centering", args.no_centering),
        ("-n", args.cores),
        ("--invert_target_contrast", args.invert_target_contrast),
        ("--use_gpu", args.use_gpu),
    ]
    return dict(pairs)
290
+
291
+
292
def create_postprocessing_argdict(args) -> Dict:
    """Collect the flags passed to ``postprocess.py``.

    Parameters
    ----------
    args : argparse.Namespace
        Must provide ``match_template_path``, ``target_mask``,
        ``new_orientations_path`` and ``number_of_peaks``.

    Returns
    -------
    dict
        Flag name to value, in command line order.
    """
    target_mask = args.target_mask
    return {
        "--input_file": args.match_template_path,
        "--target_mask": target_mask,
        "--output_prefix": args.new_orientations_path,
        "--peak_caller": "PeakCallerMaximumFilter",
        "--number_of_peaks": args.number_of_peaks,
        "--output_format": "orientations",
        # Edge masking is turned off whenever an explicit target mask is given.
        "--mask_edges": target_mask is None,
    }
305
+
306
+
307
def update_orientations(old: Orientations, new: Orientations, args, **kwargs) -> Orientations:
    """Merge peaks refined on the candidate stack back into ``old``.

    Parameters
    ----------
    old : Orientations
        Orientations the candidate stack was extracted from.
    new : Orientations
        Peaks called on the candidate stack. Column 0 of ``new.translations``
        is used as the index of the originating entry in ``old``; the
        remaining columns are positions inside the extracted subvolume.
    args : argparse.Namespace
        Provides ``candidate_stack_path``, whose shape defines the subvolume
        center.
    **kwargs
        Accepted for compatibility; unused.

    Returns
    -------
    Orientations
        A copy of ``old``. Entries referenced by ``new`` receive the new score
        and a translation shifted by the peak's offset from the subvolume
        center; all other entries have their score set to 0.
    """
    shape = Density.from_file(args.candidate_stack_path, use_memmap=True).shape
    # NOTE(review): adds 1 on odd axes; presumably matches the centering
    # convention of extract_candidates.py — confirm.
    half = np.divide(shape, 2).astype(int)
    center = np.add(half, np.mod(shape, 2))

    source_index = new.translations[:, 0].astype(int)
    offsets = np.subtract(new.translations, center)[:, 1:]

    updated = old.copy()
    updated.scores[:] = 0
    updated.scores[source_index] = new.scores
    updated.translations[source_index] = np.add(
        old.translations[source_index], offsets
    )

    # TODO: the effect of --align_orientations should be handled here.
    return updated
71
323
 
72
324
 
73
- def main():
74
- args = parse_args()
75
- meta = load_pickle(args.input_file)[-1]
76
- target_origin, _, sampling_rate, cli_args = meta
325
class DeepMatcher:
    """Evaluate template matching filter settings on a stack of candidates.

    Each evaluation runs ``match_template_filters.py`` on the candidate stack,
    calls one peak per candidate with ``postprocess.py``, maps the peaks back
    onto the input orientations and scores them against the labels stored
    in the orientations' ``details`` (1 = true positive, 0 = false positive).

    Parameters
    ----------
    args : argparse.Namespace
        Parsed command line arguments. ``candidate_stack_path``,
        ``match_template_path`` and ``new_orientations_path`` must be set.
    margin : float, optional
        Margin of a contrastive loss; stored but currently unused.
    """

    # Invoked through the shell so that $HOME is expanded.
    _MATCH_TEMPLATE = "python3 $HOME/src/pytme/scripts/match_template_filters.py"
    _POSTPROCESS = "python3 $HOME/src/pytme/scripts/postprocess.py"

    def __init__(self, args, margin: float = 0.5):
        self.args = args
        self.margin = margin
        self.orientations = Orientations.from_file(args.orientations)

        match_template_args = create_matching_argdict(args)
        match_template_args["--target"] = args.candidate_stack_path
        self.match_template_args = match_template_args

        # Insertion order defines the order of the parameter vector used by
        # get_initial_values, format_parameters, forward and __call__.
        self.filter_parameters = {}
        if args.lowpass:
            self.filter_parameters["--lowpass"] = 0
        if args.highpass:
            self.filter_parameters["--highpass"] = 200
        self.filter_parameters["--no_filter_target"] = False

        # Exactly one peak is called per extracted candidate.
        self.postprocess_args = create_postprocessing_argdict(args)
        self.postprocess_args["--number_of_peaks"] = 1

    def get_initial_values(self) -> Tuple[float]:
        """Return the starting value of each tuned parameter as floats."""
        return tuple(float(x) for x in self.filter_parameters.values())

    def format_parameters(self, parameter_values: Tuple[float]) -> Dict:
        """Pair ``parameter_values`` positionally with the tuned flags.

        Values for boolean flags are thresholded at 0.5. Surplus values (or
        surplus flags) are ignored, as ``zip`` stops at the shorter input.
        """
        ret = {}
        for value, key in zip(parameter_values, self.filter_parameters.keys()):
            ret[key] = value
            if isinstance(self.filter_parameters[key], bool):
                ret[key] = value > 0.5
        return ret

    def forward(self, x: Tuple[float]) -> float:
        """Return the loss ``1 - ROC AUC`` for parameters ``x``; lower is better.

        scikit-learn's ``roc_auc_score`` raises if only one label class is
        present among the orientations.
        """
        orientations_new = self(x)
        label, score = orientations_new.details, orientations_new.scores
        return 1.0 - roc_auc_score(label, score)

    def _evaluate(self, match_args: Dict, post_args: Dict, x, pass_index: int):
        """Run matching and peak calling once; return the updated orientations."""
        run_command(argdict_to_command(match_args, executable=self._MATCH_TEMPLATE))
        # Assumes postprocess returns one peak per candidate, in input order.
        run_command(argdict_to_command(post_args, executable=self._POSTPROCESS))

        orientations_new = Orientations.from_file(f"{post_args['--output_prefix']}.tsv")
        orientations_new = update_orientations(
            new=orientations_new, old=self.orientations, args=self.args
        )

        label, score = orientations_new.details, orientations_new.scores
        print(
            np.mean(score[label == 1]),
            np.mean(score[label == 0]),
            *x,
            pass_index,
            roc_auc_score(label, score),
            time(),
        )
        return orientations_new

    def __call__(self, x: Tuple[float]):
        """Match with parameters ``x`` and return background-corrected orientations.

        A first pass scores the candidates directly. A second pass repeats
        matching with ``--scramble_phases`` and hands the resulting scores to
        ``postprocess.py`` as ``--background_file``. The orientations of the
        second pass are returned.
        """
        import contextlib
        import os

        self.match_template_args.update(self.format_parameters(x))
        self._evaluate(self.match_template_args, self.postprocess_args, x, pass_index=0)

        background_file = generate_tempfile_name(".pickle")
        background_args = {
            **self.match_template_args,
            "--scramble_phases": True,
            "-o": background_file,
        }
        # The background must be given to postprocess, which still reads the
        # unscrambled scores from --input_file.
        corrected_postprocess_args = {
            **self.postprocess_args,
            "--background_file": background_file,
        }
        try:
            return self._evaluate(
                background_args, corrected_postprocess_args, x, pass_index=1
            )
        finally:
            with contextlib.suppress(OSError):
                os.remove(background_file)
147
494
 
148
- observations = np.zeros((len(observation_slices), *out_shape))
149
495
 
496
_SCRIPTS = "python3 $HOME/src/pytme/scripts"


def _create_candidate_stack(args) -> None:
    """Extract the candidates in ``args.orientations`` to ``args.candidate_stack_path``."""
    command = argdict_to_command(
        create_stacking_argdict(args),
        executable=f"{_SCRIPTS}/extract_candidates.py",
    )
    run_command(command)


def _write_translation_mask(args) -> None:
    """Write an ellipsoidal mask of radius ``translation_uncertainty`` to ``args.target_mask``.

    The mask is centered in each subvolume of the candidate stack and
    broadcast over the whole stack.
    """
    dens = Density.from_file(args.candidate_stack_path)
    stack_center = np.add(
        np.divide(dens.data.shape, 2).astype(int), np.mod(dens.data.shape, 2)
    ).astype(int)[1:]

    out = dens.empty
    out.data[:, ...] = create_mask(
        mask_type="ellipse",
        center=stack_center,
        radius=args.translation_uncertainty,
        shape=dens.data.shape[1:],
    )
    out.to_file(args.target_mask)


def _run_single_pass(args, match_deep) -> None:
    """Refine the candidates once with default filters and write the result."""
    _create_candidate_stack(args)
    print("Created image stack")
    if args.verbose:
        copyfile(args.candidate_stack_path, f"{args.output_prefix}_stack.h5")

    print("Starting matching")
    orientations = match_deep(x=())

    if args.verbose:
        copyfile(args.match_template_path, f"{args.output_prefix}_stack.pickle")
    print("Completed matching")
    orientations.to_file(f"{args.output_prefix}.tsv")


def _grid_search(args, match_deep) -> Tuple:
    """Evaluate all lowpass/highpass/no_filter_target combinations.

    Only the parameters DeepMatcher actually tunes are varied. Returns the
    best parameter tuple (ordered like ``match_deep.filter_parameters``)
    and its loss; lower loss is better.
    """
    lowpass_values = args.lowpass_range if args.lowpass else (None,)
    highpass_values = args.highpass_range if args.highpass else (None,)

    parameters, min_loss = (), None
    for lowpass in lowpass_values:
        for highpass in highpass_values:
            # Cutoffs are compared directly; combinations with
            # lowpass >= highpass are not evaluated.
            if lowpass is not None and highpass is not None and lowpass >= highpass:
                continue
            for no_filter_target in (True, False):
                candidate = {
                    "--lowpass": lowpass,
                    "--highpass": highpass,
                    "--no_filter_target": no_filter_target,
                }
                # format_parameters pairs values with the tuned flags by
                # position, so the tuple must follow filter_parameters' order.
                x = tuple(candidate[key] for key in match_deep.filter_parameters)
                loss = match_deep.forward(x)
                if min_loss is None or loss < min_loss:
                    min_loss, parameters = loss, x
    return parameters, min_loss


def main():
    """Refine template matching candidates and optionally tune filter cutoffs.

    Without ``--lowpass``/``--highpass`` or with ``--iterations 0``, the
    candidates are refined once and written to ``<output_prefix>.tsv``.
    Otherwise each iteration grid-searches the filter cutoffs on the
    extracted candidates, re-runs matching on the full target with the best
    setting and writes ``<output_prefix>_<iteration>.tsv``.
    """
    args = parse_args()

    if args.input_file is not None:
        # The last pickle entry holds match_template.py metadata whose final
        # element is its command line namespace.
        *_, cli_args = load_pickle(args.input_file)[-1]
        args.target, args.template = cli_args.target, cli_args.template

    args.candidate_stack_path = generate_tempfile_name(suffix=".h5")
    args.new_orientations_path = generate_tempfile_name()
    args.match_template_path = generate_tempfile_name()

    match_deep = DeepMatcher(args)
    tune_filters = args.lowpass or args.highpass
    if not tune_filters or args.iterations == 0:
        _run_single_pass(args, match_deep)
        exit(0)

    # args.target_mask is repurposed for the per-candidate translation mask,
    # which only matches the candidate stack; keep the user's mask for the
    # final peak calling on the full target.
    user_target_mask = args.target_mask
    if args.translation_uncertainty is not None:
        args.target_mask = generate_tempfile_name(suffix=".h5")

    for current_iteration in range(args.iterations):
        _create_candidate_stack(args)
        if args.translation_uncertainty is not None:
            _write_translation_mask(args)

        match_deep = DeepMatcher(args)
        parameters, min_loss = _grid_search(args, match_deep)
        print("Converged with parameters", parameters, min_loss)

        match_template = create_matching_argdict(args)
        match_template.update(match_deep.format_parameters(parameters))
        run_command(
            argdict_to_command(
                match_template, executable=f"{_SCRIPTS}/match_template_filters.py"
            )
        )

        # TODO: decide how these matches should be labelled, e.g. all true
        # positives, true positives up to a score threshold, kernel fitting,
        # or a share of low scores as true negatives.
        postprocess = create_postprocessing_argdict(args)
        postprocess["--target_mask"] = user_target_mask
        postprocess["--mask_edges"] = user_target_mask is None
        run_command(
            argdict_to_command(postprocess, executable=f"{_SCRIPTS}/postprocess.py")
        )

        args.orientations = f"{args.new_orientations_path}.tsv"
        orientations = Orientations.from_file(args.orientations)
        orientations.to_file(f"{args.output_prefix}_{current_iteration}.tsv")
215
623
 
216
624
 
217
625
  if __name__ == "__main__":