pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
scripts/postprocess.py CHANGED
@@ -8,9 +8,8 @@
 import argparse
 from sys import exit
 from os import getcwd
-from os.path import join, abspath
-from typing import List
-from os.path import splitext
+from typing import List, Tuple
+from os.path import join, abspath, splitext
 
 import numpy as np
 from numpy.typing import NDArray
@@ -26,9 +25,11 @@ from tme.analyzer import (
 )
 from tme.matching_utils import (
     load_pickle,
+    centered_mask,
     euler_to_rotationmatrix,
     euler_from_rotationmatrix,
 )
+from tme.matching_optimization import create_score_object, optimize_match
 
 PEAK_CALLERS = {
     "PeakCallerSort": PeakCallerSort,
@@ -40,9 +41,7 @@ PEAK_CALLERS = {
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Peak Calling for Template Matching Outputs"
-    )
+    parser = argparse.ArgumentParser(description="Analyze Template Matching Outputs")
 
     input_group = parser.add_argument_group("Input")
     output_group = parser.add_argument_group("Output")
@@ -55,6 +54,13 @@ def parse_args():
         nargs="+",
         help="Path to the output of match_template.py.",
     )
+    input_group.add_argument(
+        "--background_file",
+        required=False,
+        nargs="+",
+        help="Path to an output of match_template.py used for normalization. "
+        "For instance from --scramble_phases or a different template.",
+    )
     input_group.add_argument(
         "--target_mask",
         required=False,
@@ -86,7 +92,7 @@ def parse_args():
             "average",
         ],
         default="orientations",
-        help="Available output formats:"
+        help="Available output formats: "
         "orientations (translation, rotation, and score), "
         "alignment (aligned template to target based on orientations), "
         "extraction (extract regions around peaks from targets, i.e. subtomograms), "
@@ -181,6 +187,13 @@ def parse_args():
         required=False,
         help="Number of accepted false-positives picks to determine minimum score.",
     )
+    additional_group.add_argument(
+        "--local_optimization",
+        action="store_true",
+        required=False,
+        help="[Experimental] Perform local optimization of candidates. Useful when the "
+        "number of identified candidates is small (< 10).",
+    )
 
     args = parser.parse_args()
 
@@ -195,38 +208,53 @@
 
     if args.minimum_score is not None or args.n_false_positives is not None:
         args.number_of_peaks = np.iinfo(np.int64).max
-    else:
+    elif args.number_of_peaks is None:
         args.number_of_peaks = 1000
 
+    if args.background_file is None:
+        args.background_file = [None]
+    if len(args.background_file) == 1:
+        args.background_file = args.background_file * len(args.input_file)
+    elif len(args.background_file) not in (0, len(args.input_file)):
+        raise ValueError(
+            "--background_file needs to be specified once or for each --input_file."
+        )
+
     return args
 
 
-def load_template(filepath: str, sampling_rate: NDArray, center: bool = True):
+def load_template(
+    filepath: str,
+    sampling_rate: NDArray,
+    centering: bool = True,
+    target_shape: Tuple[int] = None,
+):
     try:
         template = Density.from_file(filepath)
-        center_of_mass = template.center_of_mass(template.data)
+        center = np.divide(np.subtract(template.shape, 1), 2)
         template_is_density = True
-    except ValueError:
+    except Exception:
         template = Structure.from_file(filepath)
-        center_of_mass = template.center_of_mass()[::-1]
+        center = template.center_of_mass()[::-1]
         template = Density.from_structure(template, sampling_rate=sampling_rate)
         template_is_density = False
 
-    translation = np.zeros_like(center_of_mass)
-    if center:
+    translation = np.zeros_like(center)
+    if centering and template_is_density:
         template, translation = template.centered(0)
+        center = np.divide(np.subtract(template.shape, 1), 2)
 
-    return template, center_of_mass, translation, template_is_density
+    return template, center, translation, template_is_density
 
 
-def merge_outputs(data, filepaths: List[str], args):
-    if len(filepaths) == 0:
+def merge_outputs(data, foreground_paths: List[str], background_paths: List[str], args):
+    if len(foreground_paths) == 0:
         return data, 1
 
     if data[0].ndim != data[2].ndim:
         return data, 1
 
-    from tme.matching_exhaustive import _normalize_under_mask
+    from tme.matching_exhaustive import normalize_under_mask
 
     def _norm_scores(data, args):
         target_origin, _, sampling_rate, cli_args = data[-1]
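Note on the new convention: `load_template` now anchors templates at the geometric voxel center, `(shape - 1) / 2`, rather than the mass-weighted centroid, so even-sized axes center between voxels. A minimal standalone sketch (plain NumPy, illustrative shape only):

    import numpy as np

    shape = (64, 64, 64)
    # Geometric center used by load_template in 0.2.2; a 64-voxel axis
    # centers at 31.5, i.e. between the two middle voxels.
    center = np.divide(np.subtract(shape, 1), 2)
    print(center)  # [31.5 31.5 31.5]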
@@ -235,7 +263,7 @@ def merge_outputs(data, filepaths: List[str], args):
         ret = load_template(
             filepath=cli_args.template,
             sampling_rate=sampling_rate,
-            center=not cli_args.no_centering,
+            centering=not cli_args.no_centering,
         )
         template, center_of_mass, translation, template_is_density = ret
 
@@ -256,13 +284,16 @@ def merge_outputs(data, filepaths: List[str], args):
             mask.shape, np.multiply(args.min_boundary_distance, 2)
         ).astype(int)
         mask[cropped_shape] = 0
-        _normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
+        normalize_under_mask(template=data[0], mask=mask, mask_intensity=mask.sum())
         return data[0]
 
     entities = np.zeros_like(data[0])
     data[0] = _norm_scores(data=data, args=args)
-    for index, filepath in enumerate(filepaths):
-        new_scores = _norm_scores(data=load_pickle(filepath), args=args)
+    for index, filepath in enumerate(foreground_paths):
+        new_scores = _norm_scores(
+            data=load_match_template_output(filepath, background_paths[index]),
+            args=args,
+        )
         indices = new_scores > data[0]
         entities[indices] = index + 1
         data[0][indices] = new_scores[indices]
@@ -270,9 +301,18 @@ def merge_outputs(data, filepaths: List[str], args):
     return data, entities
 
 
+def load_match_template_output(foreground_path, background_path):
+    data = load_pickle(foreground_path)
+    if background_path is not None:
+        data_background = load_pickle(background_path)
+        data[0] = (data[0] - data_background[0]) / (1 - data_background[0])
+        np.fmax(data[0], 0, out=data[0])
+    return data
+
+
 def main():
     args = parse_args()
-    data = load_pickle(args.input_file[0])
+    data = load_match_template_output(args.input_file[0], args.background_file[0])
 
     target_origin, _, sampling_rate, cli_args = data[-1]
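The background division in `load_match_template_output` rescales foreground scores against a noise reference and clamps negatives to zero: a voxel that scores no better than background maps to 0, while a perfect score of 1 stays 1. A self-contained example with made-up values (plain NumPy, not pytme API):

    import numpy as np

    foreground = np.array([0.5, 0.2, 0.9])   # scores from the actual template
    background = np.array([0.2, 0.3, 0.1])   # e.g. from --scramble_phases

    normalized = (foreground - background) / (1 - background)
    np.fmax(normalized, 0, out=normalized)   # clamp negatives, as in the diff
    print(normalized)                        # [0.375, 0.0, ~0.889]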
 
@@ -280,7 +320,7 @@ def main():
     ret = load_template(
         filepath=cli_args.template,
         sampling_rate=sampling_rate,
-        center=not cli_args.no_centering,
+        centering=not cli_args.no_centering,
     )
     template, center_of_mass, translation, template_is_density = ret
 
@@ -310,7 +350,14 @@ def main():
         max_shape = np.max(template.shape)
         args.min_boundary_distance = np.ceil(np.divide(max_shape, 2))
 
-    # data, entities = merge_outputs(data=data, filepaths=args.input_file[1:], args=args)
+    entities = None
+    if len(args.input_file) > 1:
+        data, entities = merge_outputs(
+            data=data,
+            foreground_paths=args.input_file,
+            background_paths=args.background_file,
+            args=args,
+        )
 
     orientations = args.orientations
     if orientations is None:
@@ -323,57 +370,69 @@ def main():
         target_mask = Density.from_file(args.target_mask)
         scores = scores * target_mask.data
 
-        if args.n_false_positives is not None:
-            args.n_false_positives = max(args.n_false_positives, 1)
-            cropped_shape = np.subtract(
-                scores.shape, np.multiply(args.min_boundary_distance, 2)
-            ).astype(int)
+        cropped_shape = np.subtract(
+            scores.shape, np.multiply(args.min_boundary_distance, 2)
+        ).astype(int)
 
-            cropped_shape = tuple(
+        if args.min_boundary_distance > 0:
+            scores = centered_mask(scores, new_shape=cropped_shape)
+
+        if args.n_false_positives is not None:
+            # Rickgauer et al. 2017
+            cropped_slice = tuple(
                 slice(
                     int(args.min_boundary_distance),
                     int(x - args.min_boundary_distance),
                 )
                 for x in scores.shape
             )
-            # Rickgauer et al. 2017
-            n_correlations = np.size(scores[cropped_shape]) * len(rotation_mapping)
+            args.n_false_positives = max(args.n_false_positives, 1)
+            n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
             minimum_score = np.multiply(
                 erfcinv(2 * args.n_false_positives / n_correlations),
-                np.sqrt(2) * np.std(scores[cropped_shape]),
+                np.sqrt(2) * np.std(scores[cropped_slice]),
             )
             print(f"Determined minimum score cutoff: {minimum_score}.")
             minimum_score = max(minimum_score, 0)
             args.minimum_score = minimum_score
 
-        peak_caller = PEAK_CALLERS[args.peak_caller](
-            number_of_peaks=args.number_of_peaks,
-            min_distance=args.min_distance,
-            min_boundary_distance=args.min_boundary_distance,
-        )
-        if args.minimum_score is not None:
-            args.number_of_peaks = np.inf
+        args.batch_dims = None
+        if hasattr(cli_args, "target_batch"):
+            args.batch_dims = cli_args.target_batch
+
+        peak_caller_kwargs = {
+            "number_of_peaks": args.number_of_peaks,
+            "min_distance": args.min_distance,
+            "min_boundary_distance": args.min_boundary_distance,
+            "batch_dims": args.batch_dims,
+            "minimum_score": args.minimum_score,
+            "maximum_score": args.maximum_score,
+        }
 
+        peak_caller = PEAK_CALLERS[args.peak_caller](**peak_caller_kwargs)
         peak_caller(
             scores,
-            rotation_matrix=np.eye(3),
+            rotation_matrix=np.eye(template.data.ndim),
             mask=template.data,
             rotation_mapping=rotation_mapping,
             rotation_array=rotation_array,
-            minimum_score=args.minimum_score,
         )
         candidates = peak_caller.merge(
-            candidates=[tuple(peak_caller)],
-            number_of_peaks=args.number_of_peaks,
-            min_distance=args.min_distance,
-            min_boundary_distance=args.min_boundary_distance,
+            candidates=[tuple(peak_caller)], **peak_caller_kwargs
        )
         if len(candidates) == 0:
-            print("Found no peaks. Consider changing peak calling parameters.")
+            candidates = [[], [], [], []]
+            print("Found no peaks, consider changing peak calling parameters.")
             exit(-1)
 
         for translation, _, score, detail in zip(*candidates):
-            rotations.append(rotation_mapping[rotation_array[tuple(translation)]])
+            rotation_index = rotation_array[tuple(translation)]
+            rotation = rotation_mapping.get(
+                rotation_index, np.zeros(template.data.ndim, int)
+            )
+            if rotation.ndim == 2:
+                rotation = euler_from_rotationmatrix(rotation)
+            rotations.append(rotation)
 
     else:
         candidates = data
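The relocated cutoff computation keeps the Rickgauer et al. 2017 estimate: for a tolerated number of false positives among `n_correlations` sampled correlations, the threshold is `erfcinv(2 * n_false_positives / n_correlations) * sqrt(2) * std`. A rough standalone illustration with hypothetical numbers:

    import numpy as np
    from scipy.special import erfcinv

    n_false_positives = 1
    n_correlations = 1e6 * 1500   # cropped voxels times number of rotations
    score_std = 0.05              # std of the cropped score map

    minimum_score = erfcinv(2 * n_false_positives / n_correlations)
    minimum_score *= np.sqrt(2) * score_std
    print(round(minimum_score, 2))  # roughly 0.3 for these values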
@@ -381,8 +440,13 @@ def main():
         for i in range(translation.shape[0]):
             rotations.append(euler_from_rotationmatrix(rotation[i]))
 
-    rotations = np.vstack(rotations).astype(float)
+    if len(rotations):
+        rotations = np.vstack(rotations).astype(float)
     translations, scores, details = candidates[0], candidates[2], candidates[3]
+
+    if entities is not None:
+        details = entities[tuple(translations.T)]
+
     orientations = Orientations(
         translations=translations,
         rotations=rotations,
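When several runs are merged, `entities` records which input produced the winning score at each voxel, and `entities[tuple(translations.T)]` looks up one label per peak via NumPy integer-array indexing. A toy sketch of the indexing (invented values):

    import numpy as np

    entities = np.zeros((4, 4), dtype=int)
    entities[2, 3] = 1                         # voxel won by the second input

    translations = np.array([[2, 3], [0, 0]])  # one row per peak
    print(entities[tuple(translations.T)])     # [1 0]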
@@ -390,14 +454,55 @@ def main():
         details=details,
     )
 
-    if args.minimum_score is not None:
+    if args.minimum_score is not None and len(orientations.scores):
         keep = orientations.scores >= args.minimum_score
         orientations = orientations[keep]
 
-    if args.maximum_score is not None:
+    if args.maximum_score is not None and len(orientations.scores):
         keep = orientations.scores <= args.maximum_score
         orientations = orientations[keep]
 
+    if args.peak_oversampling > 1:
+        peak_caller = peak_caller = PEAK_CALLERS[args.peak_caller]()
+        if data[0].ndim != data[2].ndim:
+            print(
+                "Input pickle does not contain template matching scores."
+                " Cannot oversample peaks."
+            )
+            exit(-1)
+        orientations.translations = peak_caller.oversample_peaks(
+            scores=data[0],
+            peak_positions=orientations.translations,
+            oversampling_factor=args.peak_oversampling,
+        )
+
+    if args.local_optimization:
+        target = Density.from_file(cli_args.target)
+        orientations.translations = orientations.translations.astype(np.float32)
+        orientations.rotations = orientations.rotations.astype(np.float32)
+        for index, (translation, angles, *_) in enumerate(orientations):
+            score_object = create_score_object(
+                score="FLC",
+                target=target.data.copy(),
+                template=template.data.copy(),
+                template_mask=template_mask.data.copy(),
+            )
+
+            center = np.divide(template.shape, 2)
+            init_translation = np.subtract(translation, center)
+            bounds_translation = tuple((x - 5, x + 5) for x in init_translation)
+
+            translation, rotation_matrix, score = optimize_match(
+                score_object=score_object,
+                optimization_method="basinhopping",
+                bounds_translation=bounds_translation,
+                maxiter=3,
+                x0=[*init_translation, *angles],
+            )
+            orientations.translations[index] = np.add(translation, center)
+            orientations.rotations[index] = angles
+            orientations.scores[index] = score * -1
+
     if args.output_format == "orientations":
         orientations.to_file(filename=f"{args.output_prefix}.tsv", file_format="text")
         exit(0)
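In the `--local_optimization` branch, candidate translations are expressed relative to the template center, refined within ±5 voxels, and mapped back to target coordinates; `optimize_match` minimizes a negated score, hence the final `score * -1`. A minimal sketch of the coordinate round trip (plain NumPy, the `refined` step stands in for the optimizer; no pytme calls):

    import numpy as np

    template_shape = (32, 32, 32)
    center = np.divide(template_shape, 2)            # [16. 16. 16.]

    peak = np.array([100.0, 80.0, 60.0])             # peak in target coordinates
    init_translation = np.subtract(peak, center)     # optimizer works in this frame
    bounds = tuple((x - 5, x + 5) for x in init_translation)

    refined = init_translation + 1.5                 # hypothetical optimizer output
    print(np.add(refined, center))                   # back to target coordinates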
@@ -506,7 +611,7 @@ def main():
             return_orientations=True,
         )
         out = np.zeros_like(template.data)
-        out = np.zeros(np.multiply(template.shape, 2).astype(int))
+        # out = np.zeros(np.multiply(template.shape, 2).astype(int))
         for index in range(len(cand_slices)):
             from scipy.spatial.transform import Rotation
 
@@ -515,7 +620,6 @@ def main():
             )
             rotation_matrix = rotation.inv().as_matrix()
 
-            # rotation_matrix = euler_to_rotationmatrix(orientations.rotations[index])
             subset = Density(target.data[obs_slices[index]])
             subset = subset.rigid_transform(rotation_matrix=rotation_matrix, order=1)
 
@@ -526,35 +630,30 @@ def main():
         ret.to_file(f"{args.output_prefix}_average.mrc")
         exit(0)
 
-    if args.peak_oversampling > 1:
-        peak_caller = peak_caller = PEAK_CALLERS[args.peak_caller]()
-        if data[0].ndim != data[2].ndim:
-            print(
-                "Input pickle does not contain template matching scores."
-                " Cannot oversample peaks."
-            )
-            exit(-1)
-        orientations.translations = peak_caller.oversample_peaks(
-            score_space=data[0],
-            translations=orientations.translations,
-            oversampling_factor=args.oversampling_factor,
-        )
+    template, center, *_ = load_template(
+        filepath=cli_args.template,
+        sampling_rate=sampling_rate,
+        centering=not cli_args.no_centering,
+        target_shape=target.shape,
+    )
 
     for index, (translation, angles, *_) in enumerate(orientations):
         rotation_matrix = euler_to_rotationmatrix(angles)
         if template_is_density:
-            translation = np.subtract(translation, center_of_mass)
+            translation = np.subtract(translation, center)
             transformed_template = template.rigid_transform(
                 rotation_matrix=rotation_matrix
             )
-            new_origin = np.add(target_origin / sampling_rate, translation)
-            transformed_template.origin = np.multiply(new_origin, sampling_rate)
+            transformed_template.origin = np.add(
+                target_origin, np.multiply(translation, sampling_rate)
+            )
+
         else:
             template = Structure.from_file(cli_args.template)
             new_center_of_mass = np.add(
                 np.multiply(translation, sampling_rate), target_origin
             )
-            translation = np.subtract(new_center_of_mass, center_of_mass)
+            translation = np.subtract(new_center_of_mass, center)
             transformed_template = template.rigid_transform(
                 translation=translation[::-1],
                 rotation_matrix=rotation_matrix[::-1, ::-1],