pytme 0.3.1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30)
  1. {pytme-0.3.1.data → pytme-0.3.1.post2.data}/scripts/match_template.py +2 -2
  2. {pytme-0.3.1.data → pytme-0.3.1.post2.data}/scripts/postprocess.py +16 -15
  3. {pytme-0.3.1.data → pytme-0.3.1.post2.data}/scripts/preprocessor_gui.py +1 -0
  4. {pytme-0.3.1.dist-info → pytme-0.3.1.post2.dist-info}/METADATA +2 -4
  5. {pytme-0.3.1.dist-info → pytme-0.3.1.post2.dist-info}/RECORD +30 -30
  6. scripts/match_template.py +2 -2
  7. scripts/postprocess.py +16 -15
  8. scripts/preprocessor_gui.py +1 -0
  9. scripts/refine_matches.py +5 -7
  10. tests/test_analyzer.py +2 -3
  11. tests/test_extensions.py +0 -1
  12. tests/test_orientations.py +0 -12
  13. tme/analyzer/aggregation.py +22 -12
  14. tme/backends/_jax_utils.py +56 -15
  15. tme/backends/cupy_backend.py +11 -11
  16. tme/backends/jax_backend.py +27 -9
  17. tme/backends/matching_backend.py +11 -0
  18. tme/backends/npfftw_backend.py +3 -0
  19. tme/density.py +58 -1
  20. tme/matching_data.py +24 -0
  21. tme/matching_exhaustive.py +5 -2
  22. tme/matching_scores.py +23 -0
  23. tme/orientations.py +20 -7
  24. {pytme-0.3.1.data → pytme-0.3.1.post2.data}/scripts/estimate_memory_usage.py +0 -0
  25. {pytme-0.3.1.data → pytme-0.3.1.post2.data}/scripts/preprocess.py +0 -0
  26. {pytme-0.3.1.data → pytme-0.3.1.post2.data}/scripts/pytme_runner.py +0 -0
  27. {pytme-0.3.1.dist-info → pytme-0.3.1.post2.dist-info}/WHEEL +0 -0
  28. {pytme-0.3.1.dist-info → pytme-0.3.1.post2.dist-info}/entry_points.txt +0 -0
  29. {pytme-0.3.1.dist-info → pytme-0.3.1.post2.dist-info}/licenses/LICENSE +0 -0
  30. {pytme-0.3.1.dist-info → pytme-0.3.1.post2.dist-info}/top_level.txt +0 -0
@@ -359,8 +359,8 @@ def parse_args():
359
359
  "--invert-target-contrast",
360
360
  action="store_true",
361
361
  default=False,
362
- help="Invert the target's contrast for cases where templates to-be-matched have "
363
- "negative values, e.g. tomograms.",
362
+ help="Invert the target contrast. Useful for matching on tomograms if the "
363
+ "template has not been inverted.",
364
364
  )
365
365
  io_group.add_argument(
366
366
  "--scramble-phases",
@@ -188,7 +188,7 @@ def parse_args():
188
188
  )
189
189
  additional_group.add_argument(
190
190
  "--n-false-positives",
191
- type=int,
191
+ type=float,
192
192
  default=None,
193
193
  required=False,
194
194
  help="Number of accepted false-positives picks to determine minimum score.",
@@ -318,11 +318,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
318
318
  data = load_matching_output(foreground)
319
319
  scores, _, rotations, rotation_mapping, *_ = data
320
320
 
321
- # We could normalize to unit sdev, but that might lead to unexpected
322
- # results for flat background distributions
323
- # scores -= scores.mean()
324
321
  indices = tuple(slice(0, x) for x in scores.shape)
325
-
326
322
  indices_update = scores > scores_out[indices]
327
323
  scores_out[indices][indices_update] = scores[indices_update]
328
324
 
@@ -369,9 +365,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
369
365
  scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
370
366
  for background in backgrounds:
371
367
  data_norm = load_matching_output(background)
372
-
373
- scores = data_norm[0]
374
- # scores -= scores.mean()
368
+ scores, _, rotations, rotation_mapping, *_ = data_norm
375
369
 
376
370
  indices = tuple(slice(0, x) for x in scores.shape)
377
371
  indices_update = scores > scores_norm[indices]
@@ -381,9 +375,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
381
375
  update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
382
376
  scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
383
377
  scores_out[update] = data[0][update] - scores_norm[update]
378
+ scores_out = np.fmax(scores_out, 0, out=scores_out)
379
+ scores_out[update] += scores_norm[update].mean()
384
380
 
385
381
  # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
386
- scores_out = np.fmax(scores_out, 0, out=scores_out)
387
382
  data[0] = scores_out
388
383
 
389
384
  fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
@@ -485,8 +480,11 @@ def main():
485
480
  if orientations is None:
486
481
  translations, rotations, scores, details = [], [], [], []
487
482
 
488
- # Data processed by normalize_input is guaranteed to have this shape
489
- scores, offset, rotation_array, rotation_mapping, meta = data
483
+ var = None
484
+ # Data processed by normalize_input is guaranteed to have this shape)
485
+ scores, _, rotation_array, rotation_mapping, *_ = data
486
+ if len(data) == 6:
487
+ scores, _, rotation_array, rotation_mapping, var, *_ = data
490
488
 
491
489
  cropped_shape = np.subtract(
492
490
  scores.shape, np.multiply(args.min_boundary_distance, 2)
@@ -509,13 +507,16 @@ def main():
509
507
  )
510
508
  args.n_false_positives = max(args.n_false_positives, 1)
511
509
  n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
510
+ std = np.std(scores[cropped_slice])
511
+ if var is not None:
512
+ std = np.asarray(np.sqrt(var)).reshape(())
513
+
512
514
  minimum_score = np.multiply(
513
515
  erfcinv(2 * args.n_false_positives / n_correlations),
514
- np.sqrt(2) * np.std(scores[cropped_slice]),
516
+ np.sqrt(2) * std,
515
517
  )
516
- print(f"Determined minimum score cutoff: {minimum_score}.")
517
- minimum_score = max(minimum_score, 0)
518
- args.min_score = minimum_score
518
+ print(f"Determined cutoff --min-score {minimum_score}.")
519
+ args.min_score = max(minimum_score, 0)
519
520
 
520
521
  args.batch_dims = None
521
522
  if hasattr(cli_args, "batch_dims"):
@@ -474,6 +474,7 @@ def membrane_mask(
474
474
  **kwargs,
475
475
  ) -> NDArray:
476
476
  return create_mask(
477
+ center=(center_x, center_y, center_z),
477
478
  mask_type="membrane",
478
479
  shape=template.shape,
479
480
  radius=radius,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pytme
3
- Version: 0.3.1
3
+ Version: 0.3.1.post2
4
4
  Summary: Python Template Matching Engine
5
5
  Author: Valentin Maurer
6
6
  Author-email: Valentin Maurer <valentin.maurer@embl-hamburg.de>
@@ -26,9 +26,7 @@ Requires-Dist: importlib_resources
26
26
  Requires-Dist: joblib
27
27
  Provides-Extra: cupy
28
28
  Requires-Dist: cupy-cuda12x>13.0.0; extra == "cupy"
29
- Provides-Extra: cupy-voltools
30
- Requires-Dist: cupy-cuda12x>13.0.0; extra == "cupy-voltools"
31
- Requires-Dist: voltools; extra == "cupy-voltools"
29
+ Requires-Dist: voltools; extra == "cupy"
32
30
  Provides-Extra: pytorch
33
31
  Requires-Dist: torch; extra == "pytorch"
34
32
  Requires-Dist: torchvision; extra == "pytorch"
@@ -1,33 +1,33 @@
1
- pytme-0.3.1.data/scripts/estimate_memory_usage.py,sha256=Ry46LXUv3SZ0g41g3RDUg9UH6hiSnnG3mHTyaGletXE,2114
2
- pytme-0.3.1.data/scripts/match_template.py,sha256=X0Sa4dBuE3fvKer--JgY7kPSyZ7H6fa5PKi2lxRoyzM,37563
3
- pytme-0.3.1.data/scripts/postprocess.py,sha256=0QOUzqa76Wrg7JK51qEtjm7XH78UFdMuDia50HmlXwU,27638
4
- pytme-0.3.1.data/scripts/preprocess.py,sha256=eq-67cuj3WFx5jAdS56yQSVv_sCt7SRW-dqzQRbLnVE,6328
5
- pytme-0.3.1.data/scripts/preprocessor_gui.py,sha256=R73N-UTSiAplTqyRw-SLGcGusnJVv2dlAFLuqmUrl-8,44153
6
- pytme-0.3.1.data/scripts/pytme_runner.py,sha256=dqCd60puAOOOSvuCxrJC1MbfdsRS-ctMso5YIwM-JkI,40356
7
- pytme-0.3.1.dist-info/licenses/LICENSE,sha256=gXf5dRMhNSbfLPYYTY_5hsZ1r7UU1OaKQEAQUhuIBkM,18092
1
+ pytme-0.3.1.post2.data/scripts/estimate_memory_usage.py,sha256=Ry46LXUv3SZ0g41g3RDUg9UH6hiSnnG3mHTyaGletXE,2114
2
+ pytme-0.3.1.post2.data/scripts/match_template.py,sha256=Px443VNaxto_GPuL16d2TcFp7zPrU-camVdGeW4dHNA,37556
3
+ pytme-0.3.1.post2.data/scripts/postprocess.py,sha256=-n7WFjw-AoVe1BLFCgYtfnjdSpZE65o2qeYVmxO_Jh4,27708
4
+ pytme-0.3.1.post2.data/scripts/preprocess.py,sha256=eq-67cuj3WFx5jAdS56yQSVv_sCt7SRW-dqzQRbLnVE,6328
5
+ pytme-0.3.1.post2.data/scripts/preprocessor_gui.py,sha256=fw1Q0G11Eit7TRt9BNqqn5vYthOWt4GdaFS-WfeznbE,44200
6
+ pytme-0.3.1.post2.data/scripts/pytme_runner.py,sha256=dqCd60puAOOOSvuCxrJC1MbfdsRS-ctMso5YIwM-JkI,40356
7
+ pytme-0.3.1.post2.dist-info/licenses/LICENSE,sha256=gXf5dRMhNSbfLPYYTY_5hsZ1r7UU1OaKQEAQUhuIBkM,18092
8
8
  scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  scripts/estimate_memory_usage.py,sha256=UiaX30o_H59vzTQZaXISxgzWj5jwgFbMoH1E5ydVHcw,2115
10
10
  scripts/eval.py,sha256=ebJVLxbRlB6TI5YHNr0VavZ4lmaRdf8QVafyiDhh_oU,2528
11
11
  scripts/extract_candidates.py,sha256=B2O4Xm0eVJzBZOblfkH5za8fTxLIdBRDl89Qwkq4Kjk,8097
12
- scripts/match_template.py,sha256=LsoS7gPh6tx3c5uyF3v5XbATjb8jJ_jh91zmO2IruZ0,37564
12
+ scripts/match_template.py,sha256=akKSFalnhsk2OdqELCKFvrBhJ49U_UoWX4xs-42NKAY,37557
13
13
  scripts/match_template_filters.py,sha256=Gj4a1b_S5NWp_dfFEPFn0D7jGf-qYgBbnTvZZ4bwqOQ,42036
14
- scripts/postprocess.py,sha256=YNSySz76axtNvz7R5eCvS90eJjtKQQbWMLSexFTkEbg,27639
14
+ scripts/postprocess.py,sha256=_9mFpFQww0TbMQeFaDRxR_AJh2wdGfXrNbqXVUZhtZY,27709
15
15
  scripts/preprocess.py,sha256=PrtO0aWGbTvSMqCdNYuzW02FoGvSsvQzJab52DTssQ4,6329
16
- scripts/preprocessor_gui.py,sha256=eozUE6D_vJOO11Prardwm7zWrjvTXLHFaAMmY7yfNIc,44154
16
+ scripts/preprocessor_gui.py,sha256=GUtq8_jU0NGqTEq5ZyjMaeBskJIxao4rznuraNv_02s,44201
17
17
  scripts/pytme_runner.py,sha256=BD5u45vtzJo_d6ZAM2N_CKnGpfjtNg4v7C-31vbTGnU,40357
18
- scripts/refine_matches.py,sha256=1AG-E2PuTauflR_qVC1CaZKg7ev9xxGu3CIYeyJ2-WQ,12777
18
+ scripts/refine_matches.py,sha256=dLu3aW0-iAn0Qn-GoQBaZfgEGUcTWMQ2J2tA8aNm5Yo,12725
19
19
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
- tests/test_analyzer.py,sha256=HNZf12dZbuMTJ1oYhdDd1nn9NSGXV12ReEqRjNQ96UI,8393
20
+ tests/test_analyzer.py,sha256=9TVJacQF44AnMYOWfLhqQgz1V2TeKjijVT1rJb50NMw,8360
21
21
  tests/test_backends.py,sha256=65jCvPfB2YjUHkhmjkC49Rpd1SyPRQ7_lnLXHI5QC-U,17361
22
22
  tests/test_density.py,sha256=JVQunV445qij5WdcpKn-5GKqT3endzjXvBPkIaXPADo,18914
23
- tests/test_extensions.py,sha256=1Zv9dG_dmmC2mlbX91YIPyGLSToPC0202-ffLAfVcr4,5203
23
+ tests/test_extensions.py,sha256=K71MvMhLdZz66ifJYqRl0IEr7-PCXSMxKokg-4NCc8Q,5170
24
24
  tests/test_matching_cli.py,sha256=JF7LQixpFQMpbijoPwtNSe9FoYXZmBR6J3T5BvsLbR8,11202
25
25
  tests/test_matching_data.py,sha256=U6ISe4lBLDj-OzgA6QAaoO_aegCJjPtXqHbPiPZ2tkA,6091
26
26
  tests/test_matching_exhaustive.py,sha256=bRPCN0pyZk3DmXMrWRcEGAksYoch8P7fRiwE3g0yIf4,4039
27
27
  tests/test_matching_memory.py,sha256=XrBGRi4cIp3-4nN6s7tj0pBbFJrvJaG3vQdtK-6uzY0,1151
28
28
  tests/test_matching_optimization.py,sha256=bYt3d_nHewTRxKCXgGaMLhObkqBufGwrYL4gfinr_jE,7732
29
29
  tests/test_matching_utils.py,sha256=m6mkjo-rgA0q67MYQwjeyy_y6GBErh68A-lqxvQ8NT4,4948
30
- tests/test_orientations.py,sha256=ydqYZc3mOWSweFEGWMrMaygxX7-7in7forUtiNWfQKc,6767
30
+ tests/test_orientations.py,sha256=BnTWHXplucO9_f1TognyZNVQdzGgr1ckt-YGDrlHDcU,6250
31
31
  tests/test_parser.py,sha256=57oaksWrKNB4Z_22IxfW0nXMyQWLJFVsuvnJQPhMn10,993
32
32
  tests/test_rotations.py,sha256=Kll51PZm_5MWP6TcW_II6oJLgORM4piZtWbzhn2fiA8,5374
33
33
  tests/test_structure.py,sha256=Qgrex3wYl9TpFLcqMeUOkq-7w93bzEsMNStKdJdsbnw,8881
@@ -62,16 +62,16 @@ tests/preprocessing/test_utils.py,sha256=RnqlA3vHcVguWp0gmBOwBgUZK85ylNLqr_oFkf8
62
62
  tme/__init__.py,sha256=GXfIuzU4fh2KZfgUOIXnNeKlJ8kSp3EqoBr6bSik848,217
63
63
  tme/__version__.py,sha256=r4xAFihOf72W9TD-lpMi6ntWSTKTP2SlzKP1ytkjRbI,22
64
64
  tme/cli.py,sha256=48Q7QGIuCwNSms5wOnGmWjMSlZrXjAQUz_9zsIFT9tY,3652
65
- tme/density.py,sha256=iupy3gL7gwE2Q9FHcZfHY_AcbOSUAo3i9tSbC-KF-DA,82405
65
+ tme/density.py,sha256=qYwLi6zuGlSYXP3mQwfahaAVrSolq718sujeK6ZSE1k,84484
66
66
  tme/extensions.cpython-311-darwin.so,sha256=xUCjh1eMK3iKWMC04oFG3LzxBRAqHpl7idIIV4ezx3Q,415840
67
67
  tme/mask.py,sha256=4xHqWtOabgdYDDCXHpXflaZmLlmE2_9-bxjs4AaCIYM,10472
68
- tme/matching_data.py,sha256=l9eIJOkANOn-sHWeCVU-1RskMloqH0KmUFQcOQTt6qQ,30361
69
- tme/matching_exhaustive.py,sha256=YswFCaRTaPaGQVRf7SgrmsqEHlnC83u8G4y-PIlR-8w,18270
68
+ tme/matching_data.py,sha256=y6UBUykYpZHsl5NEDyKthKztuwcKFIaOIlB0eEh-Sww,31255
69
+ tme/matching_exhaustive.py,sha256=DnWl7zBfnK83MFZzLDQ9WsBvI2mZkTikxpiXIWrdmxQ,18385
70
70
  tme/matching_optimization.py,sha256=r2zMyOKyvoP8OX9xValJE0DLWbLpNdyY1MLWcOi-H1U,45715
71
- tme/matching_scores.py,sha256=ZYw7ff0y0Zs6fzUwTCQ34Gth4Ckx2ZV7qNk1bH4NXdw,34849
71
+ tme/matching_scores.py,sha256=sX4hMki9WjwgenNWJ1wNu04J808wOgGuVqbRYSA2TB8,35860
72
72
  tme/matching_utils.py,sha256=Qn65B2F8v1yHLetqpiirvUbj-A9-J1giq3lE9cSb-1E,27426
73
73
  tme/memory.py,sha256=lOD_XJtAWo3IFG5fJIdYgLSpn8o0ORO6Vua_tlm5YEk,9683
74
- tme/orientations.py,sha256=93nmt4GPiuBsMWfwSuQVHUBrgY-5k53PlaCaDXZCZjc,21345
74
+ tme/orientations.py,sha256=23HQ2jmW74cAY7RBHe-fB-T1sb8lJ7OfbnP4W5uYZ_M,21857
75
75
  tme/parser.py,sha256=d59A7meIUJ8OzfB6eaVqd7CwLj9oYnOfubqlvygbu1U,24210
76
76
  tme/preprocessor.py,sha256=7DAGRfCPc9ge-hEuYsNA_KCeajVnpWl-w4NzGQS53GM,40377
77
77
  tme/rotations.py,sha256=wVnvZrRCHb1wd5E6yDsLxiP2G0Oin66A4W14vsz4i50,10641
@@ -79,19 +79,19 @@ tme/structure.py,sha256=ojk-IQnIe1JrzxUVHlNlyLxtwzNkNykvxTaXvyBMmbs,72135
79
79
  tme/types.py,sha256=NAY7C4qxE6yz-DXVtClMvFfoOV-spWGLNfpLATZ1LcU,442
80
80
  tme/analyzer/__init__.py,sha256=aJyVGlQwO6Ij-_NZ5rBlgBQQSIj0dqpUwlmvIvnQqBM,89
81
81
  tme/analyzer/_utils.py,sha256=48Xq2Hi_P3VYl1SozYz3WwEj3BwL1QdtFYx7dTxv0bc,6344
82
- tme/analyzer/aggregation.py,sha256=tl6Vw0VhdhPYwb-Y82eAK83odL2JD0d5C3N39qJRVZM,27875
82
+ tme/analyzer/aggregation.py,sha256=3ieZlNcDp0dhnQ85DdVLsCFrRfqU4jim3JRcMwORF_U,28518
83
83
  tme/analyzer/base.py,sha256=2uEYKa0xLwggJUenAA6-l8DIO5sYtklHMg9dR3DzAOI,3887
84
84
  tme/analyzer/peaks.py,sha256=y40YVa2zJFLWD3aoP0xtuhK8yOrQV0AVH6ElWyVjXkM,32587
85
85
  tme/analyzer/proxy.py,sha256=NH0J93ESl5NAVd_22rLZGDWjCx1ov8I995Fr3mYqUSM,4162
86
86
  tme/backends/__init__.py,sha256=VUvEnpfjmZQVn_9QHkLPBoa6y8M3gDyNNL9kq4OQNGE,5249
87
87
  tme/backends/_cupy_utils.py,sha256=scxCSK9BcsDaIbHR2mFxAJRIxdc0r5JTC1gCJmVJa64,24558
88
- tme/backends/_jax_utils.py,sha256=0rybsyG2QbGDdQg-1ruDIt1F15DZcz_TguZYjbepCA4,5838
88
+ tme/backends/_jax_utils.py,sha256=pbKtaNDZEf7mAxKoQwcUXTYIkVzpnzjhF_wcbxO4AC4,6799
89
89
  tme/backends/_numpyfftw_utils.py,sha256=RnYgbSs_sL-NlwYGRqU9kziq0U5qzl9GecVx9uUeSJs,7904
90
- tme/backends/cupy_backend.py,sha256=G3rsdmxWTes8pM82TTnWvoq1bwGrNsBTbfJMSc5gfTA,8269
91
- tme/backends/jax_backend.py,sha256=eHN0PvWVpjoKaYUJ0oRcdXwF7muKrgVAUzb96QRDXWI,12069
92
- tme/backends/matching_backend.py,sha256=oWARMu_Mt6apnvpYGCbQ1_x7tl_PZoxVtVb8HJgtQ-o,32953
90
+ tme/backends/cupy_backend.py,sha256=Ms2sObxr0xc_tdHctcL659G8YWOTqPG59F3665yZH4c,8163
91
+ tme/backends/jax_backend.py,sha256=Al1J60NfZTawmDWuCxxXLoEVu7HuP1sFdQlyMXbF1DE,12676
92
+ tme/backends/matching_backend.py,sha256=B9LXxKpfNYdJBIksyGm7g7V7SC9G3mrluscz3aqISsU,33199
93
93
  tme/backends/mlx_backend.py,sha256=3aAKv7VkK_jfxwDMPQZsLp_lRU8MBJ6zNAzrs6VZd3s,6718
94
- tme/backends/npfftw_backend.py,sha256=P151VbhML5myCMX1o4ArHYRNz_Dz6XoGVyQIesEK2zI,18698
94
+ tme/backends/npfftw_backend.py,sha256=8-Hs3slqnWvsHpqICxKxb-fbmg3IbSok3sThXew-P7M,18800
95
95
  tme/backends/pytorch_backend.py,sha256=yn1hhowUMmzMLHdh2jqw2B0CvBIVcSAWOvRMMDkSHmw,14229
96
96
  tme/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
97
97
  tme/data/c48n309.npy,sha256=NwH64mOEbm3tStq5c98o81fY1vMOoq4nvXDAh7Z7iZg,296768
@@ -126,8 +126,8 @@ tme/filters/ctf.py,sha256=HNB03Pw3gn2Y8pIM2391iRHQzp2NcHZWFDC61v3K0oM,24520
126
126
  tme/filters/reconstruction.py,sha256=4gCGpuWuM-yCDQvuzyex_Q8bmsFqUI_8IoZnv3B_3UA,7564
127
127
  tme/filters/wedge.py,sha256=JWiN9IwRXnc7A7syq11kSuin6QjhlVmi38nwGHu4yxE,19541
128
128
  tme/filters/whitening.py,sha256=Zwk-0jMVqy_F8TdQ0ht5unMz2JBOePOC3okUhQpU0bo,6348
129
- pytme-0.3.1.dist-info/METADATA,sha256=uGvw-jEiKmPKGI-cQurxjdCgvwrRhcCLZzbIQmF2NqM,5079
130
- pytme-0.3.1.dist-info/WHEEL,sha256=SPbiHAOPLnBtml4sk5MIwvWF6YnGfOfy9C__W6Bpeg4,109
131
- pytme-0.3.1.dist-info/entry_points.txt,sha256=pbUSmB0J4Ghlg0w7jHfaFSvPMuvRWzeSuvdjdRPisAU,288
132
- pytme-0.3.1.dist-info/top_level.txt,sha256=ovCUR7UXXouH3zYt_fJLoqr_vtjp1wudFgjVAnztQLE,18
133
- pytme-0.3.1.dist-info/RECORD,,
129
+ pytme-0.3.1.post2.dist-info/METADATA,sha256=2l8XEWvLoIRz-UdlbcOMZHTrRiRAmQIPBODnOGcYZUI,4985
130
+ pytme-0.3.1.post2.dist-info/WHEEL,sha256=SPbiHAOPLnBtml4sk5MIwvWF6YnGfOfy9C__W6Bpeg4,109
131
+ pytme-0.3.1.post2.dist-info/entry_points.txt,sha256=pbUSmB0J4Ghlg0w7jHfaFSvPMuvRWzeSuvdjdRPisAU,288
132
+ pytme-0.3.1.post2.dist-info/top_level.txt,sha256=ovCUR7UXXouH3zYt_fJLoqr_vtjp1wudFgjVAnztQLE,18
133
+ pytme-0.3.1.post2.dist-info/RECORD,,
scripts/match_template.py CHANGED
@@ -359,8 +359,8 @@ def parse_args():
359
359
  "--invert-target-contrast",
360
360
  action="store_true",
361
361
  default=False,
362
- help="Invert the target's contrast for cases where templates to-be-matched have "
363
- "negative values, e.g. tomograms.",
362
+ help="Invert the target contrast. Useful for matching on tomograms if the "
363
+ "template has not been inverted.",
364
364
  )
365
365
  io_group.add_argument(
366
366
  "--scramble-phases",
scripts/postprocess.py CHANGED
@@ -188,7 +188,7 @@ def parse_args():
188
188
  )
189
189
  additional_group.add_argument(
190
190
  "--n-false-positives",
191
- type=int,
191
+ type=float,
192
192
  default=None,
193
193
  required=False,
194
194
  help="Number of accepted false-positives picks to determine minimum score.",
@@ -318,11 +318,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
318
318
  data = load_matching_output(foreground)
319
319
  scores, _, rotations, rotation_mapping, *_ = data
320
320
 
321
- # We could normalize to unit sdev, but that might lead to unexpected
322
- # results for flat background distributions
323
- # scores -= scores.mean()
324
321
  indices = tuple(slice(0, x) for x in scores.shape)
325
-
326
322
  indices_update = scores > scores_out[indices]
327
323
  scores_out[indices][indices_update] = scores[indices_update]
328
324
 
@@ -369,9 +365,7 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
369
365
  scores_norm = np.full(out_shape_norm, fill_value=0, dtype=np.float32)
370
366
  for background in backgrounds:
371
367
  data_norm = load_matching_output(background)
372
-
373
- scores = data_norm[0]
374
- # scores -= scores.mean()
368
+ scores, _, rotations, rotation_mapping, *_ = data_norm
375
369
 
376
370
  indices = tuple(slice(0, x) for x in scores.shape)
377
371
  indices_update = scores > scores_norm[indices]
@@ -381,9 +375,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
381
375
  update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
382
376
  scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
383
377
  scores_out[update] = data[0][update] - scores_norm[update]
378
+ scores_out = np.fmax(scores_out, 0, out=scores_out)
379
+ scores_out[update] += scores_norm[update].mean()
384
380
 
385
381
  # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
386
- scores_out = np.fmax(scores_out, 0, out=scores_out)
387
382
  data[0] = scores_out
388
383
 
389
384
  fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
@@ -485,8 +480,11 @@ def main():
485
480
  if orientations is None:
486
481
  translations, rotations, scores, details = [], [], [], []
487
482
 
488
- # Data processed by normalize_input is guaranteed to have this shape
489
- scores, offset, rotation_array, rotation_mapping, meta = data
483
+ var = None
484
+ # Data processed by normalize_input is guaranteed to have this shape)
485
+ scores, _, rotation_array, rotation_mapping, *_ = data
486
+ if len(data) == 6:
487
+ scores, _, rotation_array, rotation_mapping, var, *_ = data
490
488
 
491
489
  cropped_shape = np.subtract(
492
490
  scores.shape, np.multiply(args.min_boundary_distance, 2)
@@ -509,13 +507,16 @@ def main():
509
507
  )
510
508
  args.n_false_positives = max(args.n_false_positives, 1)
511
509
  n_correlations = np.size(scores[cropped_slice]) * len(rotation_mapping)
510
+ std = np.std(scores[cropped_slice])
511
+ if var is not None:
512
+ std = np.asarray(np.sqrt(var)).reshape(())
513
+
512
514
  minimum_score = np.multiply(
513
515
  erfcinv(2 * args.n_false_positives / n_correlations),
514
- np.sqrt(2) * np.std(scores[cropped_slice]),
516
+ np.sqrt(2) * std,
515
517
  )
516
- print(f"Determined minimum score cutoff: {minimum_score}.")
517
- minimum_score = max(minimum_score, 0)
518
- args.min_score = minimum_score
518
+ print(f"Determined cutoff --min-score {minimum_score}.")
519
+ args.min_score = max(minimum_score, 0)
519
520
 
520
521
  args.batch_dims = None
521
522
  if hasattr(cli_args, "batch_dims"):
@@ -474,6 +474,7 @@ def membrane_mask(
474
474
  **kwargs,
475
475
  ) -> NDArray:
476
476
  return create_mask(
477
+ center=(center_x, center_y, center_z),
477
478
  mask_type="membrane",
478
479
  shape=template.shape,
479
480
  radius=radius,
scripts/refine_matches.py CHANGED
@@ -10,11 +10,9 @@ import subprocess
10
10
  from sys import exit
11
11
  from os import unlink
12
12
  from time import time
13
- from os.path import join
14
13
  from typing import Tuple, List, Dict
15
14
 
16
15
  import numpy as np
17
- from scipy import optimize
18
16
  from sklearn.metrics import roc_auc_score
19
17
 
20
18
  from tme import Orientations, Density
@@ -66,7 +64,6 @@ def parse_args():
66
64
  matching_group.add_argument(
67
65
  "-i",
68
66
  "--template",
69
- dest="template",
70
67
  type=str,
71
68
  required=True,
72
69
  help="Path to a template in PDB/MMCIF or other supported formats (see target).",
@@ -102,7 +99,7 @@ def parse_args():
102
99
  )
103
100
  matching_group.add_argument(
104
101
  "-s",
105
- dest="score",
102
+ "--score",
106
103
  type=str,
107
104
  default="batchFLCSphericalMask",
108
105
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
@@ -197,6 +194,7 @@ def create_matching_argdict(args) -> Dict:
197
194
  "-n": args.cores,
198
195
  "--ctf-file": args.ctf_file,
199
196
  "--invert-target-contrast": args.invert_target_contrast,
197
+ "--backend" : args.backend,
200
198
  }
201
199
  return arg_dict
202
200
 
@@ -252,7 +250,7 @@ class DeepMatcher:
252
250
  if args.lowpass_range:
253
251
  self.filter_parameters["--lowpass"] = 0
254
252
  if args.highpass_range:
255
- self.filter_parameters["--highpass"] = 200
253
+ self.filter_parameters["--highpass"] = 0
256
254
 
257
255
  self.postprocess_args = create_postprocessing_argdict(args)
258
256
  self.log_file = f"{args.output_prefix}_optimization_log.txt"
@@ -309,14 +307,14 @@ class DeepMatcher:
309
307
 
310
308
  match_template = argdict_to_command(
311
309
  self.match_template_args,
312
- executable="match_template.py",
310
+ executable="match_template",
313
311
  )
314
312
  run_command(match_template)
315
313
 
316
314
  # Assume we get a new peak for each input in the same order
317
315
  postprocess = argdict_to_command(
318
316
  self.postprocess_args,
319
- executable="postprocess.py",
317
+ executable="postprocess",
320
318
  )
321
319
  run_command(postprocess)
322
320
 
tests/test_analyzer.py CHANGED
@@ -165,7 +165,6 @@ class TestMaxScoreOverRotations:
165
165
  assert res[0].dtype == be._float_dtype
166
166
  assert res[1].size == self.data.ndim
167
167
  assert np.allclose(res[2].shape, self.data.shape)
168
- assert len(res) == 4
169
168
 
170
169
  @pytest.mark.parametrize("use_memmap", [False, True])
171
170
  @pytest.mark.parametrize("score_threshold", [0, 1e10, -1e10])
@@ -181,7 +180,7 @@ class TestMaxScoreOverRotations:
181
180
 
182
181
  data2 = self.data * 2
183
182
  score_analyzer(state, data2, rotation_matrix=self.rotation_matrix)
184
- scores, translation_offset, rotations, mapping = score_analyzer.result(state)
183
+ scores, offset, rotations, mapping, *_ = score_analyzer.result(state)
185
184
 
186
185
  assert np.all(scores >= score_threshold)
187
186
  max_scores = np.maximum(self.data, data2)
@@ -214,7 +213,7 @@ class TestMaxScoreOverRotations:
214
213
  ret = MaxScoreOverRotations.merge(
215
214
  results=states, use_memmap=use_memmap, score_threshold=score_threshold
216
215
  )
217
- scores, translation, rotations, mapping = ret
216
+ scores, translation, rotations, mapping, *_ = ret
218
217
  assert np.all(scores >= score_threshold)
219
218
  max_scores = np.maximum(self.data, data2)
220
219
  max_scores = np.maximum(max_scores, score_threshold)
tests/test_extensions.py CHANGED
@@ -53,7 +53,6 @@ class TestExtensions:
53
53
  @pytest.mark.parametrize("min_distance", [0, 5, 10])
54
54
  def test_find_candidate_indices(self, dimension, dtype, min_distance):
55
55
  coordinates = COORDINATES[dimension].astype(dtype)
56
- print(coordinates.shape)
57
56
 
58
57
  min_distance = np.array([min_distance]).astype(dtype)[0]
59
58
 
@@ -95,18 +95,6 @@ class TestDensity:
95
95
  self.orientations.rotations, orientations_new.rotations, atol=1e-3
96
96
  )
97
97
 
98
- @pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
99
- @pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
100
- def test_file_format_io(self, input_format: str, output_format: str):
101
- _, output_file = mkstemp(suffix=f".{input_format}")
102
- _, output_file2 = mkstemp(suffix=f".{output_format}")
103
-
104
- self.orientations.to_file(output_file)
105
- orientations_new = Orientations.from_file(output_file)
106
- orientations_new.to_file(output_file2)
107
-
108
- assert True
109
-
110
98
  @pytest.mark.parametrize("drop_oob", (True, False))
111
99
  @pytest.mark.parametrize("shape", (10, 40, 80))
112
100
  @pytest.mark.parametrize("odd", (True, False))
@@ -132,12 +132,14 @@ class MaxScoreOverRotations(AbstractAnalyzer):
132
132
  - scores : BackendArray of shape `self._shape` filled with `score_threshold`.
133
133
  - rotations : BackendArray of shape `self._shape` filled with -1.
134
134
  - rotation_mapping : dict, empty mapping from rotation bytes to indices.
135
+ - ssum : BackendArray, accumulator for sum of squared scores.
135
136
  """
136
137
  scores = be.full(
137
138
  shape=self._shape, dtype=be._float_dtype, fill_value=self._score_threshold
138
139
  )
139
140
  rotations = be.full(self._shape, dtype=be._int_dtype, fill_value=-1)
140
- return scores, rotations, {}
141
+ ssum = be.full((1), dtype=be._float_dtype, fill_value=0)
142
+ return scores, rotations, {}, ssum
141
143
 
142
144
  def __call__(
143
145
  self,
@@ -156,6 +158,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
156
158
  - scores : BackendArray, current maximum scores.
157
159
  - rotations : BackendArray, current rotation indices.
158
160
  - rotation_mapping : dict, mapping from rotation bytes to indices.
161
+ - ssum : BackendArray, accumulator for sum of squared scores.
159
162
  scores : BackendArray
160
163
  Array of new scores to update analyzer with.
161
164
  rotation_matrix : BackendArray
@@ -168,7 +171,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
168
171
  # be.tobytes behaviour caused overhead for certain GPU/CUDA combinations
169
172
  # If the analyzer is not shared and each rotation is unique, we can
170
173
  # use index to rotation mapping and invert prior to merging.
171
- prev_scores, rotations, rotation_mapping = state
174
+ prev_scores, rotations, rotation_mapping, ssum = state
172
175
 
173
176
  rotation_index = len(rotation_mapping)
174
177
  rotation_matrix = be.astype(rotation_matrix, be._float_dtype)
@@ -180,13 +183,14 @@ class MaxScoreOverRotations(AbstractAnalyzer):
180
183
  rotation = be.tobytes(rotation_matrix)
181
184
  rotation_index = rotation_mapping.setdefault(rotation, rotation_index)
182
185
 
186
+ ssum = be.add(ssum, be.ssum(scores), out=ssum)
183
187
  scores, rotations = be.max_score_over_rotations(
184
188
  scores=scores,
185
189
  max_scores=prev_scores,
186
190
  rotations=rotations,
187
191
  rotation_index=rotation_index,
188
192
  )
189
- return scores, rotations, rotation_mapping
193
+ return scores, rotations, rotation_mapping, ssum
190
194
 
191
195
  @staticmethod
192
196
  def _invert_rmap(rotation_mapping: dict) -> dict:
@@ -224,6 +228,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
224
228
  - scores : BackendArray, current maximum scores.
225
229
  - rotations : BackendArray, current rotation indices.
226
230
  - rotation_mapping : dict, mapping from rotation indices to matrices.
231
+ - ssum : BackendArray, accumulator for sum of squared scores.
227
232
  targetshape : Tuple[int], optional
228
233
  Shape of the target for convolution mode correction.
229
234
  templateshape : Tuple[int], optional
@@ -240,9 +245,9 @@ class MaxScoreOverRotations(AbstractAnalyzer):
240
245
  Returns
241
246
  -------
242
247
  tuple
243
- Final result tuple (scores, offset, rotations, rotation_mapping).
248
+ Final result tuple (scores, offset, rotations, rotation_mapping, ssum).
244
249
  """
245
- scores, rotations, rotation_mapping = state
250
+ scores, rotations, rotation_mapping, ssum = state
246
251
 
247
252
  # Apply postprocessing if parameters are provided
248
253
  if fourier_shift is not None:
@@ -269,11 +274,13 @@ class MaxScoreOverRotations(AbstractAnalyzer):
269
274
  if self._inversion_mapping:
270
275
  rotation_mapping = {be.tobytes(v): k for k, v in rotation_mapping.items()}
271
276
 
277
+ n_rotations = max(len(rotation_mapping), 1)
272
278
  return (
273
279
  scores,
274
280
  be.to_numpy_array(self._offset),
275
281
  rotations,
276
282
  self._invert_rmap(rotation_mapping),
283
+ be.to_numpy_array(ssum) / (scores.size * n_rotations),
277
284
  )
278
285
 
279
286
  def _harmonize_states(states: List[Tuple]):
@@ -287,18 +294,18 @@ class MaxScoreOverRotations(AbstractAnalyzer):
287
294
  if states[i] is None:
288
295
  continue
289
296
 
290
- scores, offset, rotations, rotation_mapping = states[i]
297
+ scores, offset, rotations, rotation_mapping, ssum = states[i]
291
298
  if out_shape is None:
292
299
  out_shape = np.zeros(scores.ndim, int)
293
300
  out_shape = np.maximum(out_shape, np.add(offset, scores.shape))
294
301
 
295
302
  new_param = {}
296
303
  for key, value in rotation_mapping.items():
297
- rotation_bytes = be.tobytes(value)
304
+ rotation_bytes = np.asarray(value).tobytes()
298
305
  new_param[rotation_bytes] = key
299
306
  if rotation_bytes not in new_rotation_mapping:
300
307
  new_rotation_mapping[rotation_bytes] = len(new_rotation_mapping)
301
- states[i] = (scores, offset, rotations, new_param)
308
+ states[i] = (scores, offset, rotations, new_param, ssum)
302
309
  out_shape = tuple(int(x) for x in out_shape)
303
310
  return new_rotation_mapping, out_shape, states
304
311
 
@@ -329,11 +336,10 @@ class MaxScoreOverRotations(AbstractAnalyzer):
329
336
  if len(results) == 1:
330
337
  ret = results[0]
331
338
  if use_memmap:
332
- scores, offset, rotations, rotation_mapping = ret
339
+ scores, offset, rotations, rotation_mapping, ssum = ret
333
340
  scores = array_to_memmap(scores)
334
341
  rotations = array_to_memmap(rotations)
335
- ret = (scores, offset, rotations, rotation_mapping)
336
-
342
+ ret = (scores, offset, rotations, rotation_mapping, ssum)
337
343
  return ret
338
344
 
339
345
  # Determine output array shape and create consistent rotation map
@@ -368,6 +374,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
368
374
  )
369
375
  rotations_out = np.full(out_shape, fill_value=-1, dtype=rotations_dtype)
370
376
 
377
+ total_ssum = 0
371
378
  for i in range(len(results)):
372
379
  if results[i] is None:
373
380
  continue
@@ -385,7 +392,9 @@ class MaxScoreOverRotations(AbstractAnalyzer):
385
392
  shape=out_shape,
386
393
  dtype=rotations_dtype,
387
394
  )
388
- scores, offset, rotations, rotation_mapping = results[i]
395
+ scores, offset, rotations, rotation_mapping, ssum = results[i]
396
+
397
+ total_ssum = np.add(total_ssum, ssum)
389
398
  stops = np.add(offset, scores.shape).astype(int)
390
399
  indices = tuple(slice(*pos) for pos in zip(offset, stops))
391
400
 
@@ -428,6 +437,7 @@ class MaxScoreOverRotations(AbstractAnalyzer):
428
437
  np.zeros(scores_out.ndim, dtype=int),
429
438
  rotations_out,
430
439
  cls._invert_rmap(master_rotation_mapping),
440
+ total_ssum / len(results),
431
441
  )
432
442
 
433
443
 
@@ -17,7 +17,7 @@ from ..backends import backend as be
17
17
  from ..matching_utils import normalize_template as _normalize_template
18
18
 
19
19
 
20
- __all__ = ["scan"]
20
+ __all__ = ["scan", "setup_scan"]
21
21
 
22
22
 
23
23
  def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
@@ -112,12 +112,53 @@ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
112
112
  return arr
113
113
 
114
114
 
115
- @partial(
116
- pmap,
117
- in_axes=(0,) + (None,) * 6,
118
- static_broadcasted_argnums=[6, 7, 8, 9],
119
- axis_name="batch",
120
- )
115
+ def _mask_scores(arr, mask):
116
+ return arr.at[:].multiply(mask)
117
+
118
+
119
+ def _select_config(analyzer_kwargs, device_idx):
120
+ return analyzer_kwargs[device_idx]
121
+
122
+
123
+ def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
124
+ """Create separate scan function with initialized analyzer for each device"""
125
+ device_scans = [
126
+ partial(
127
+ scan,
128
+ fast_shape=fast_shape,
129
+ rotate_mask=rotate_mask,
130
+ analyzer=callback_class(**device_config),
131
+ ) for device_config in analyzer_kwargs
132
+ ]
133
+
134
+ @partial(
135
+ pmap,
136
+ in_axes=(0,) + (None,) * 6,
137
+ axis_name="batch",
138
+ )
139
+ def scan_combined(
140
+ target,
141
+ template,
142
+ template_mask,
143
+ rotations,
144
+ template_filter,
145
+ target_filter,
146
+ score_mask,
147
+ ):
148
+ return lax.switch(
149
+ lax.axis_index("batch"),
150
+ device_scans,
151
+ target,
152
+ template,
153
+ template_mask,
154
+ rotations,
155
+ template_filter,
156
+ target_filter,
157
+ score_mask,
158
+ )
159
+ return scan_combined
160
+
161
+
121
162
  def scan(
122
163
  target: BackendArray,
123
164
  template: BackendArray,
@@ -125,19 +166,13 @@ def scan(
125
166
  rotations: BackendArray,
126
167
  template_filter: BackendArray,
127
168
  target_filter: BackendArray,
169
+ score_mask: BackendArray,
128
170
  fast_shape: Tuple[int],
129
171
  rotate_mask: bool,
130
- analyzer_class: object,
131
- analyzer_kwargs: Tuple[Tuple],
172
+ analyzer: object,
132
173
  ) -> Tuple[BackendArray, BackendArray]:
133
174
  eps = jnp.finfo(template.dtype).resolution
134
175
 
135
- kwargs = lax.switch(
136
- lax.axis_index("batch"),
137
- [lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
138
- )
139
- analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
140
-
141
176
  if hasattr(target_filter, "shape"):
142
177
  target = _apply_fourier_filter(target, target_filter)
143
178
 
@@ -159,6 +194,10 @@ def scan(
159
194
  if template_filter.shape != ():
160
195
  _template_filter_func = _apply_fourier_filter
161
196
 
197
+ _score_mask_func = _identity
198
+ if score_mask.shape != ():
199
+ _score_mask_func = _mask_scores
200
+
162
201
  def _sample_transform(ret, rotation_matrix):
163
202
  state, index = ret
164
203
  template_rot, template_mask_rot = be.rigid_transform(
@@ -185,6 +224,8 @@ def scan(
185
224
  n_observations=n_observations,
186
225
  eps=eps,
187
226
  )
227
+ scores = _score_mask_func(scores, score_mask)
228
+
188
229
  state = analyzer(state, scores, rotation_matrix, rotation_index=index)
189
230
  return (state, index + 1), None
190
231
 
@@ -81,6 +81,17 @@ class CupyBackend(NumpyFFTWBackend):
81
81
  """,
82
82
  "norm_scores",
83
83
  )
84
+
85
+ # Sum of square computation similar to the demeaned variance in pytom
86
+ self.ssum = cp.ReductionKernel(
87
+ f"{ftype} arr",
88
+ f"{ftype} ret",
89
+ "arr * arr",
90
+ "a + b",
91
+ "ret = a",
92
+ "0",
93
+ f"ssum_{ftype}",
94
+ )
84
95
  self.texture_available = find_spec("voltools") is not None
85
96
 
86
97
  def to_backend_array(self, arr: NDArray) -> CupyArray:
@@ -139,17 +150,6 @@ class CupyBackend(NumpyFFTWBackend):
139
150
  peaks = self._array_backend.array(self._array_backend.nonzero(max_filter)).T
140
151
  return peaks
141
152
 
142
- # The default methods in Cupy were oddly slow
143
- def var(self, a, *args, **kwargs):
144
- out = a - self._array_backend.mean(a, *args, **kwargs)
145
- self._array_backend.square(out, out)
146
- out = self._array_backend.mean(out, *args, **kwargs)
147
- return out
148
-
149
- def std(self, a, *args, **kwargs):
150
- out = self.var(a, *args, **kwargs)
151
- return self._array_backend.sqrt(out)
152
-
153
153
  def _get_texture(self, arr: CupyArray, order: int = 3, prefilter: bool = False):
154
154
  key = id(arr)
155
155
  if key in TEXTURE_CACHE:
@@ -197,6 +197,13 @@ class JaxBackend(NumpyFFTWBackend):
197
197
  def _tuple_to_dict(self, data: Tuple) -> Dict:
198
198
  return {x[0]: self._from_hashable(*x[1]) for x in data}
199
199
 
200
+ def _unbatch(self, data, target_ndim, index):
201
+ if not isinstance(data, type(self.zeros(1))):
202
+ return data
203
+ elif data.ndim <= target_ndim:
204
+ return data
205
+ return data[index]
206
+
200
207
  def scan(
201
208
  self,
202
209
  matching_data: type,
@@ -211,12 +218,14 @@ class JaxBackend(NumpyFFTWBackend):
211
218
  Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
212
219
  :py:class:`tme.analyzer.MaxScoreOverRotations`.
213
220
  """
214
- from ._jax_utils import scan as scan_inner
221
+ from ._jax_utils import setup_scan
222
+ from ..analyzer import MaxScoreOverRotations
215
223
 
216
224
  pad_target = True if len(splits) > 1 else False
217
225
  convolution_mode = "valid" if pad_target else "same"
218
226
  target_pad = matching_data.target_padding(pad_target=pad_target)
219
227
 
228
+ score_mask = self.full((1,), fill_value=1, dtype=bool)
220
229
  target_shape = tuple(
221
230
  (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
222
231
  )
@@ -270,8 +279,10 @@ class JaxBackend(NumpyFFTWBackend):
270
279
  cur_args = analyzer_args.copy()
271
280
  cur_args["offset"] = translation_offset
272
281
  cur_args.update(callback_class_args)
282
+ analyzer_kwargs.append(cur_args)
273
283
 
274
- analyzer_kwargs.append(self._dict_to_tuple(cur_args))
284
+ if pad_target:
285
+ score_mask = base._score_mask(fast_shape, shift)
275
286
 
276
287
  _target = self.astype(base._target, self._float_dtype)
277
288
  translation_offsets.append(translation_offset)
@@ -298,7 +309,13 @@ class JaxBackend(NumpyFFTWBackend):
298
309
  create_filter, create_template_filter, create_target_filter = (False,) * 3
299
310
  base, targets = None, self._array_backend.stack(targets)
300
311
 
301
- analyzer_kwargs = tuple(analyzer_kwargs)
312
+ scan_inner = setup_scan(
313
+ analyzer_kwargs=analyzer_kwargs,
314
+ callback_class=callback_class,
315
+ fast_shape=fast_shape,
316
+ rotate_mask=rotate_mask
317
+ )
318
+
302
319
  states = scan_inner(
303
320
  self.astype(targets, self._float_dtype),
304
321
  self.astype(matching_data.template, self._float_dtype),
@@ -306,17 +323,18 @@ class JaxBackend(NumpyFFTWBackend):
306
323
  matching_data.rotations,
307
324
  template_filter,
308
325
  target_filter,
309
- fast_shape,
310
- rotate_mask,
311
- callback_class,
312
- analyzer_kwargs,
326
+ score_mask,
313
327
  )
314
328
 
329
+ ndim = targets.ndim - 1
315
330
  for index in range(targets.shape[0]):
316
- kwargs = self._tuple_to_dict(analyzer_kwargs[index])
331
+ kwargs = analyzer_kwargs[index]
317
332
  analyzer = callback_class(**kwargs)
333
+ state = [self._unbatch(x, ndim, index) for x in states]
334
+
335
+ if isinstance(analyzer, MaxScoreOverRotations):
336
+ state[2] = rotation_mapping
318
337
 
319
- state = (states[0][index], states[1][index], rotation_mapping)
320
338
  ret.append(analyzer.result(state, **kwargs))
321
339
  return ret
322
340
 
@@ -863,6 +863,17 @@ class MatchingBackend(ABC):
863
863
  Indices of ``k`` largest elements in ``arr``.
864
864
  """
865
865
 
866
+ @abstractmethod
867
+ def ssum(self, arr, *args, **kwargs) -> BackendArray:
868
+ """
869
+ Compute the sum of squares of ``arr``.
870
+
871
+ Returns
872
+ -------
873
+ BackendArray
874
+ Sum of squares with shape ().
875
+ """
876
+
866
877
  def indices(self, *args, **kwargs) -> BackendArray:
867
878
  """
868
879
  Creates an array representing the index grid of an input.
@@ -201,6 +201,9 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
201
201
  sorted_indices = self.unravel_index(indices=sorted_indices, shape=arr.shape)
202
202
  return sorted_indices
203
203
 
204
+ def ssum(self, arr, *args, **kwargs):
205
+ return self.sum(self.square(arr), *args, **kwargs)
206
+
204
207
  def indices(self, *args, **kwargs) -> NDArray:
205
208
  return self._array_backend.indices(*args, **kwargs)
206
209
 
tme/density.py CHANGED
@@ -2196,7 +2196,7 @@ class Density:
2196
2196
 
2197
2197
  Parameters
2198
2198
  ----------
2199
- target : Density
2199
+ target : :py:class:`Density`
2200
2200
  The target map for template matching.
2201
2201
  template : Structure
2202
2202
  The template that should be aligned to the target.
@@ -2259,3 +2259,60 @@ class Density:
2259
2259
  coordinates = np.array(np.where(data > 0))
2260
2260
  weights = self.data[tuple(coordinates)]
2261
2261
  return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
2262
+
2263
+ @staticmethod
2264
+ def fourier_shell_correlation(density1: "Density", density2: "Density") -> NDArray:
2265
+ """
2266
+ Computes the Fourier Shell Correlation (FSC) between two instances of `Density`.
2267
+
2268
+ The Fourier transforms of the input maps are divided into shells
2269
+ based on their spatial frequency. The correlation between corresponding shells
2270
+ in the two maps is computed to give the FSC.
2271
+
2272
+ Parameters
2273
+ ----------
2274
+ density1 : :py:class:`Density`
2275
+ Reference for comparison.
2276
+ density2 : :py:class:`Density`
2277
+ Target for comparison.
2278
+
2279
+ Returns
2280
+ -------
2281
+ NDArray
2282
+ An array of shape (N, 2), where N is the number of shells.
2283
+ The first column represents the spatial frequency for each shell
2284
+ and the second column represents the corresponding FSC.
2285
+
2286
+ References
2287
+ ----------
2288
+ .. [1] https://github.com/tdgrant1/denss/blob/master/saxstats/saxstats.py
2289
+ """
2290
+ side = density1.data.shape[0]
2291
+ df = 1.0 / side
2292
+
2293
+ qx_ = np.fft.fftfreq(side) * side * df
2294
+ qx, qy, qz = np.meshgrid(qx_, qx_, qx_, indexing="ij")
2295
+ qr = np.sqrt(qx**2 + qy**2 + qz**2)
2296
+
2297
+ qmax = np.max(qr)
2298
+ qstep = np.min(qr[qr > 0])
2299
+ nbins = int(qmax / qstep)
2300
+ qbins = np.linspace(0, nbins * qstep, nbins + 1)
2301
+ qbin_labels = np.searchsorted(qbins, qr, "right") - 1
2302
+
2303
+ F1 = np.fft.fftn(density1.data)
2304
+ F2 = np.fft.fftn(density2.data)
2305
+
2306
+ qbin_labels = qbin_labels.reshape(-1)
2307
+ numerator = np.bincount(
2308
+ qbin_labels, weights=np.real(F1 * np.conj(F2)).reshape(-1)
2309
+ )
2310
+ term1 = np.bincount(qbin_labels, weights=np.abs(F1).reshape(-1) ** 2)
2311
+ term2 = np.bincount(qbin_labels, weights=np.abs(F2).reshape(-1) ** 2)
2312
+ np.multiply(term1, term2, out=term1)
2313
+ denominator = np.sqrt(term1)
2314
+ FSC = np.divide(numerator, denominator)
2315
+
2316
+ qidx = np.where(qbins < qx.max())
2317
+
2318
+ return np.vstack((qbins[qidx], FSC[qidx])).T
tme/matching_data.py CHANGED
@@ -544,6 +544,30 @@ class MatchingData:
544
544
  batch_mask=be.to_numpy_array(self._batch_mask),
545
545
  )
546
546
 
547
+ def _score_mask(self, fast_shape: Tuple[int], shift: Tuple[int]) -> BackendArray:
548
+ """
549
+ Create a boolean mask to exclude scores derived from padding in template matching.
550
+ """
551
+ padding = self.target_padding(True)
552
+ offset = tuple(x // 2 for x in padding)
553
+ shape = tuple(y - x for x, y in zip(padding, self.target.shape))
554
+
555
+ subset = []
556
+ for i in range(len(offset)):
557
+ if self._batch_mask[i]:
558
+ subset.append(slice(None))
559
+ else:
560
+ subset.append(slice(offset[i], offset[i] + shape[i]))
561
+
562
+ score_mask = np.zeros(fast_shape, dtype=bool)
563
+ score_mask[tuple(subset)] = 1
564
+ score_mask = np.roll(
565
+ score_mask,
566
+ shift=tuple(-x for x in shift),
567
+ axis=tuple(i for i in range(len(shift))),
568
+ )
569
+ return be.to_backend_array(score_mask)
570
+
547
571
  def computation_schedule(
548
572
  self,
549
573
  matching_method: str = "FLCSphericalMask",
@@ -223,6 +223,10 @@ def scan(
223
223
  )
224
224
  conv, fwd, inv, shift = matching_data.fourier_padding()
225
225
 
226
+ score_mask = be.full(shape=(1,), fill_value=1, dtype=bool)
227
+ if pad_target:
228
+ score_mask = matching_data._score_mask(fwd, shift)
229
+
226
230
  template_filter = _setup_template_filter_apply_target_filter(
227
231
  matching_data=matching_data,
228
232
  fast_shape=fwd,
@@ -275,6 +279,7 @@ def scan(
275
279
  callback=callback_classes[index % n_callback_classes],
276
280
  interpolation_order=interpolation_order,
277
281
  template_filter=be.to_sharedarr(template_filter, shm_handler),
282
+ score_mask=be.to_sharedarr(score_mask, shm_handler),
278
283
  **setup,
279
284
  )
280
285
  for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
@@ -420,8 +425,6 @@ def scan_subsets(
420
425
  outer_jobs, inner_jobs = job_schedule
421
426
  if be._backend_name == "jax":
422
427
  func = be.scan
423
- if kwargs.get("projection_matching", False):
424
- func = be.scan_projections
425
428
 
426
429
  corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
427
430
  results = func(
tme/matching_scores.py CHANGED
@@ -356,6 +356,7 @@ def corr_scoring(
356
356
  callback: CallbackClass,
357
357
  interpolation_order: int,
358
358
  template_mask: shm_type = None,
359
+ score_mask: shm_type = None,
359
360
  ) -> CallbackClass:
360
361
  """
361
362
  Calculates a normalized cross-correlation between a target f and a template g.
@@ -394,6 +395,8 @@ def corr_scoring(
394
395
  Spline order for template rotations.
395
396
  template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
396
397
  Template mask data buffer, its shape and datatype, None by default.
398
+ score_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
399
+ Score mask data buffer, its shape and datatype, None by default.
397
400
 
398
401
  Returns
399
402
  -------
@@ -404,6 +407,7 @@ def corr_scoring(
404
407
  inv_denominator = be.from_sharedarr(inv_denominator)
405
408
  numerator = be.from_sharedarr(numerator)
406
409
  template_filter = be.from_sharedarr(template_filter)
410
+ score_mask = be.from_sharedarr(score_mask)
407
411
 
408
412
  n_obs = None
409
413
  if template_mask is not None:
@@ -413,6 +417,7 @@ def corr_scoring(
413
417
  norm_template = conditional_execute(normalize_template, n_obs is not None)
414
418
  norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
415
419
  norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
420
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
416
421
 
417
422
  arr = be.zeros(fast_shape, be._float_dtype)
418
423
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
@@ -447,6 +452,8 @@ def corr_scoring(
447
452
 
448
453
  arr = norm_sub(arr, numerator, out=arr)
449
454
  arr = norm_mul(arr, inv_denominator, out=arr)
455
+ arr = norm_mask(arr, score_mask, out=arr)
456
+
450
457
  callback(arr, rotation_matrix=rotation)
451
458
 
452
459
  return callback
@@ -463,6 +470,7 @@ def flc_scoring(
463
470
  rotations: BackendArray,
464
471
  callback: CallbackClass,
465
472
  interpolation_order: int,
473
+ score_mask: shm_type = None,
466
474
  ) -> CallbackClass:
467
475
  """
468
476
  Computes a normalized cross-correlation between ``target`` (f),
@@ -522,6 +530,7 @@ def flc_scoring(
522
530
  ft_target = be.from_sharedarr(ft_target)
523
531
  ft_target2 = be.from_sharedarr(ft_target2)
524
532
  template_filter = be.from_sharedarr(template_filter)
533
+ score_mask = be.from_sharedarr(score_mask)
525
534
 
526
535
  arr = be.zeros(fast_shape, float_dtype)
527
536
  temp = be.zeros(fast_shape, float_dtype)
@@ -532,6 +541,7 @@ def flc_scoring(
532
541
  template_mask_rot = be.zeros(template.shape, be._float_dtype)
533
542
 
534
543
  tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
544
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
535
545
 
536
546
  eps = be.eps(float_dtype)
537
547
  center = be.divide(be.to_backend_array(template.shape) - 1, 2)
@@ -567,6 +577,8 @@ def flc_scoring(
567
577
  arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
568
578
 
569
579
  arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
580
+ arr = norm_mask(arr, score_mask, out=arr)
581
+
570
582
  callback(arr, rotation_matrix=rotation)
571
583
 
572
584
  return callback
@@ -585,6 +597,7 @@ def mcc_scoring(
585
597
  callback: CallbackClass,
586
598
  interpolation_order: int,
587
599
  overlap_ratio: float = 0.3,
600
+ score_mask: shm_type = None,
588
601
  ) -> CallbackClass:
589
602
  """
590
603
  Computes a normalized cross-correlation score between ``target`` (f),
@@ -755,12 +768,14 @@ def flc_scoring2(
755
768
  rotations: BackendArray,
756
769
  callback: CallbackClass,
757
770
  interpolation_order: int,
771
+ score_mask: shm_type = None,
758
772
  ) -> CallbackClass:
759
773
  template = be.from_sharedarr(template)
760
774
  template_mask = be.from_sharedarr(template_mask)
761
775
  ft_target = be.from_sharedarr(ft_target)
762
776
  ft_target2 = be.from_sharedarr(ft_target2)
763
777
  template_filter = be.from_sharedarr(template_filter)
778
+ score_mask = be.from_sharedarr(score_mask)
764
779
 
765
780
  tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
766
781
 
@@ -785,6 +800,7 @@ def flc_scoring2(
785
800
  filter_shape=template_filter.shape,
786
801
  arr_padded=True,
787
802
  )
803
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
788
804
 
789
805
  eps = be.eps(be._float_dtype)
790
806
  for index in range(rotations.shape[0]):
@@ -816,6 +832,8 @@ def flc_scoring2(
816
832
  arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
817
833
 
818
834
  arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
835
+ arr = norm_mask(arr, score_mask, out=arr)
836
+
819
837
  callback(arr, rotation_matrix=rotation)
820
838
 
821
839
  return callback
@@ -834,12 +852,14 @@ def corr_scoring2(
834
852
  interpolation_order: int,
835
853
  target_filter: shm_type = None,
836
854
  template_mask: shm_type = None,
855
+ score_mask: shm_type = None,
837
856
  ) -> CallbackClass:
838
857
  template = be.from_sharedarr(template)
839
858
  ft_target = be.from_sharedarr(ft_target)
840
859
  inv_denominator = be.from_sharedarr(inv_denominator)
841
860
  numerator = be.from_sharedarr(numerator)
842
861
  template_filter = be.from_sharedarr(template_filter)
862
+ score_mask = be.from_sharedarr(score_mask)
843
863
 
844
864
  tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
845
865
 
@@ -869,6 +889,7 @@ def corr_scoring2(
869
889
  norm_template = conditional_execute(normalize_template, n_obs is not None)
870
890
  norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
871
891
  norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
892
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
872
893
 
873
894
  template_filter_func = _create_filter_func(
874
895
  arr_shape=template.shape,
@@ -896,6 +917,8 @@ def corr_scoring2(
896
917
 
897
918
  arr = norm_sub(arr, numerator, out=arr)
898
919
  arr = norm_mul(arr, inv_denominator, out=arr)
920
+ arr = norm_mask(arr, score_mask, out=arr)
921
+
899
922
  callback(arr, rotation_matrix=rotation)
900
923
 
901
924
  return callback
tme/orientations.py CHANGED
@@ -327,11 +327,18 @@ class Orientations:
327
327
  "_rlnAnglePsi",
328
328
  "_rlnClassNumber",
329
329
  ]
330
+
331
+ target_identifer = "_rlnMicrographName"
332
+ if version == "# version 50001":
333
+ header[3] = "_rlnCenteredCoordinateXAngst"
334
+ header[4] = "_rlnCenteredCoordinateYAngst"
335
+ header[5] = "_rlnCenteredCoordinateZAngst"
336
+ target_identifer = "_rlnTomoName"
337
+
330
338
  if source_path is not None:
331
- header.append("_rlnMicrographName")
339
+ header.append(target_identifer)
332
340
 
333
341
  header.append("_pytmeScore")
334
-
335
342
  header = "\n".join(header)
336
343
  with open(filename, mode="w", encoding="utf-8") as ofile:
337
344
  if version is not None:
@@ -487,16 +494,22 @@ class Orientations:
487
494
 
488
495
  @classmethod
489
496
  def _from_star(
490
- cls, filename: str, delimiter: str = "\t"
497
+ cls, filename: str, delimiter: str = None
491
498
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
492
499
  parser = StarParser(filename, delimiter=delimiter)
493
500
 
494
- ret = parser.get("data_particles", None)
495
- if ret is None:
496
- ret = parser.get("data_", None)
501
+ keyword_order = ("data_particles", "particles", "data")
502
+ for keyword in keyword_order:
503
+ ret = parser.get(keyword, None)
504
+ if ret is None:
505
+ ret = parser.get(f"{keyword}_", None)
506
+ if ret is not None:
507
+ break
497
508
 
498
509
  if ret is None:
499
- raise ValueError(f"No data_particles section found in {filename}.")
510
+ raise ValueError(
511
+ f"Could not find either {keyword_order} section found in {filename}."
512
+ )
500
513
 
501
514
  translation = np.vstack(
502
515
  (ret["_rlnCoordinateX"], ret["_rlnCoordinateY"], ret["_rlnCoordinateZ"])