pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.post2__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -359,8 +359,8 @@ def parse_args():
359
359
  "--invert-target-contrast",
360
360
  action="store_true",
361
361
  default=False,
362
- help="Invert the target's contrast for cases where templates to-be-matched have "
363
- "negative values, e.g. tomograms.",
362
+ help="Invert the target contrast. Useful for matching on tomograms if the "
363
+ "template has not been inverted.",
364
364
  )
365
365
  io_group.add_argument(
366
366
  "--scramble-phases",
@@ -474,6 +474,7 @@ def membrane_mask(
474
474
  **kwargs,
475
475
  ) -> NDArray:
476
476
  return create_mask(
477
+ center=(center_x, center_y, center_z),
477
478
  mask_type="membrane",
478
479
  shape=template.shape,
479
480
  radius=radius,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: pytme
3
- Version: 0.3.1.post1
3
+ Version: 0.3.1.post2
4
4
  Summary: Python Template Matching Engine
5
5
  Author: Valentin Maurer
6
6
  Author-email: Valentin Maurer <valentin.maurer@embl-hamburg.de>
@@ -1,21 +1,21 @@
1
- pytme-0.3.1.post1.data/scripts/estimate_memory_usage.py,sha256=Ry46LXUv3SZ0g41g3RDUg9UH6hiSnnG3mHTyaGletXE,2114
2
- pytme-0.3.1.post1.data/scripts/match_template.py,sha256=X0Sa4dBuE3fvKer--JgY7kPSyZ7H6fa5PKi2lxRoyzM,37563
3
- pytme-0.3.1.post1.data/scripts/postprocess.py,sha256=GVlavaLWPdbTmMeEczsrG0cGhHXW4wZOZ5zLIV0sx4o,27708
4
- pytme-0.3.1.post1.data/scripts/preprocess.py,sha256=eq-67cuj3WFx5jAdS56yQSVv_sCt7SRW-dqzQRbLnVE,6328
5
- pytme-0.3.1.post1.data/scripts/preprocessor_gui.py,sha256=R73N-UTSiAplTqyRw-SLGcGusnJVv2dlAFLuqmUrl-8,44153
6
- pytme-0.3.1.post1.data/scripts/pytme_runner.py,sha256=dqCd60puAOOOSvuCxrJC1MbfdsRS-ctMso5YIwM-JkI,40356
7
- pytme-0.3.1.post1.dist-info/licenses/LICENSE,sha256=gXf5dRMhNSbfLPYYTY_5hsZ1r7UU1OaKQEAQUhuIBkM,18092
1
+ pytme-0.3.1.post2.data/scripts/estimate_memory_usage.py,sha256=Ry46LXUv3SZ0g41g3RDUg9UH6hiSnnG3mHTyaGletXE,2114
2
+ pytme-0.3.1.post2.data/scripts/match_template.py,sha256=Px443VNaxto_GPuL16d2TcFp7zPrU-camVdGeW4dHNA,37556
3
+ pytme-0.3.1.post2.data/scripts/postprocess.py,sha256=-n7WFjw-AoVe1BLFCgYtfnjdSpZE65o2qeYVmxO_Jh4,27708
4
+ pytme-0.3.1.post2.data/scripts/preprocess.py,sha256=eq-67cuj3WFx5jAdS56yQSVv_sCt7SRW-dqzQRbLnVE,6328
5
+ pytme-0.3.1.post2.data/scripts/preprocessor_gui.py,sha256=fw1Q0G11Eit7TRt9BNqqn5vYthOWt4GdaFS-WfeznbE,44200
6
+ pytme-0.3.1.post2.data/scripts/pytme_runner.py,sha256=dqCd60puAOOOSvuCxrJC1MbfdsRS-ctMso5YIwM-JkI,40356
7
+ pytme-0.3.1.post2.dist-info/licenses/LICENSE,sha256=gXf5dRMhNSbfLPYYTY_5hsZ1r7UU1OaKQEAQUhuIBkM,18092
8
8
  scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  scripts/estimate_memory_usage.py,sha256=UiaX30o_H59vzTQZaXISxgzWj5jwgFbMoH1E5ydVHcw,2115
10
10
  scripts/eval.py,sha256=ebJVLxbRlB6TI5YHNr0VavZ4lmaRdf8QVafyiDhh_oU,2528
11
11
  scripts/extract_candidates.py,sha256=B2O4Xm0eVJzBZOblfkH5za8fTxLIdBRDl89Qwkq4Kjk,8097
12
- scripts/match_template.py,sha256=LsoS7gPh6tx3c5uyF3v5XbATjb8jJ_jh91zmO2IruZ0,37564
12
+ scripts/match_template.py,sha256=akKSFalnhsk2OdqELCKFvrBhJ49U_UoWX4xs-42NKAY,37557
13
13
  scripts/match_template_filters.py,sha256=Gj4a1b_S5NWp_dfFEPFn0D7jGf-qYgBbnTvZZ4bwqOQ,42036
14
- scripts/postprocess.py,sha256=ysI1hBFqRly8noeOgL2ZONm27hSkPoVsbeG1kwI53G4,27709
14
+ scripts/postprocess.py,sha256=_9mFpFQww0TbMQeFaDRxR_AJh2wdGfXrNbqXVUZhtZY,27709
15
15
  scripts/preprocess.py,sha256=PrtO0aWGbTvSMqCdNYuzW02FoGvSsvQzJab52DTssQ4,6329
16
- scripts/preprocessor_gui.py,sha256=eozUE6D_vJOO11Prardwm7zWrjvTXLHFaAMmY7yfNIc,44154
16
+ scripts/preprocessor_gui.py,sha256=GUtq8_jU0NGqTEq5ZyjMaeBskJIxao4rznuraNv_02s,44201
17
17
  scripts/pytme_runner.py,sha256=BD5u45vtzJo_d6ZAM2N_CKnGpfjtNg4v7C-31vbTGnU,40357
18
- scripts/refine_matches.py,sha256=1AG-E2PuTauflR_qVC1CaZKg7ev9xxGu3CIYeyJ2-WQ,12777
18
+ scripts/refine_matches.py,sha256=dLu3aW0-iAn0Qn-GoQBaZfgEGUcTWMQ2J2tA8aNm5Yo,12725
19
19
  tests/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
20
20
  tests/test_analyzer.py,sha256=9TVJacQF44AnMYOWfLhqQgz1V2TeKjijVT1rJb50NMw,8360
21
21
  tests/test_backends.py,sha256=65jCvPfB2YjUHkhmjkC49Rpd1SyPRQ7_lnLXHI5QC-U,17361
@@ -27,7 +27,7 @@ tests/test_matching_exhaustive.py,sha256=bRPCN0pyZk3DmXMrWRcEGAksYoch8P7fRiwE3g0
27
27
  tests/test_matching_memory.py,sha256=XrBGRi4cIp3-4nN6s7tj0pBbFJrvJaG3vQdtK-6uzY0,1151
28
28
  tests/test_matching_optimization.py,sha256=bYt3d_nHewTRxKCXgGaMLhObkqBufGwrYL4gfinr_jE,7732
29
29
  tests/test_matching_utils.py,sha256=m6mkjo-rgA0q67MYQwjeyy_y6GBErh68A-lqxvQ8NT4,4948
30
- tests/test_orientations.py,sha256=ydqYZc3mOWSweFEGWMrMaygxX7-7in7forUtiNWfQKc,6767
30
+ tests/test_orientations.py,sha256=BnTWHXplucO9_f1TognyZNVQdzGgr1ckt-YGDrlHDcU,6250
31
31
  tests/test_parser.py,sha256=57oaksWrKNB4Z_22IxfW0nXMyQWLJFVsuvnJQPhMn10,993
32
32
  tests/test_rotations.py,sha256=Kll51PZm_5MWP6TcW_II6oJLgORM4piZtWbzhn2fiA8,5374
33
33
  tests/test_structure.py,sha256=Qgrex3wYl9TpFLcqMeUOkq-7w93bzEsMNStKdJdsbnw,8881
@@ -62,7 +62,7 @@ tests/preprocessing/test_utils.py,sha256=RnqlA3vHcVguWp0gmBOwBgUZK85ylNLqr_oFkf8
62
62
  tme/__init__.py,sha256=GXfIuzU4fh2KZfgUOIXnNeKlJ8kSp3EqoBr6bSik848,217
63
63
  tme/__version__.py,sha256=r4xAFihOf72W9TD-lpMi6ntWSTKTP2SlzKP1ytkjRbI,22
64
64
  tme/cli.py,sha256=48Q7QGIuCwNSms5wOnGmWjMSlZrXjAQUz_9zsIFT9tY,3652
65
- tme/density.py,sha256=iupy3gL7gwE2Q9FHcZfHY_AcbOSUAo3i9tSbC-KF-DA,82405
65
+ tme/density.py,sha256=qYwLi6zuGlSYXP3mQwfahaAVrSolq718sujeK6ZSE1k,84484
66
66
  tme/extensions.cpython-311-darwin.so,sha256=xUCjh1eMK3iKWMC04oFG3LzxBRAqHpl7idIIV4ezx3Q,415840
67
67
  tme/mask.py,sha256=4xHqWtOabgdYDDCXHpXflaZmLlmE2_9-bxjs4AaCIYM,10472
68
68
  tme/matching_data.py,sha256=y6UBUykYpZHsl5NEDyKthKztuwcKFIaOIlB0eEh-Sww,31255
@@ -71,7 +71,7 @@ tme/matching_optimization.py,sha256=r2zMyOKyvoP8OX9xValJE0DLWbLpNdyY1MLWcOi-H1U,
71
71
  tme/matching_scores.py,sha256=sX4hMki9WjwgenNWJ1wNu04J808wOgGuVqbRYSA2TB8,35860
72
72
  tme/matching_utils.py,sha256=Qn65B2F8v1yHLetqpiirvUbj-A9-J1giq3lE9cSb-1E,27426
73
73
  tme/memory.py,sha256=lOD_XJtAWo3IFG5fJIdYgLSpn8o0ORO6Vua_tlm5YEk,9683
74
- tme/orientations.py,sha256=xMVC5FWDNYSoiIpRXqTi3K1QiSsH2ilUoDEUqEJLYzk,21642
74
+ tme/orientations.py,sha256=23HQ2jmW74cAY7RBHe-fB-T1sb8lJ7OfbnP4W5uYZ_M,21857
75
75
  tme/parser.py,sha256=d59A7meIUJ8OzfB6eaVqd7CwLj9oYnOfubqlvygbu1U,24210
76
76
  tme/preprocessor.py,sha256=7DAGRfCPc9ge-hEuYsNA_KCeajVnpWl-w4NzGQS53GM,40377
77
77
  tme/rotations.py,sha256=wVnvZrRCHb1wd5E6yDsLxiP2G0Oin66A4W14vsz4i50,10641
@@ -85,10 +85,10 @@ tme/analyzer/peaks.py,sha256=y40YVa2zJFLWD3aoP0xtuhK8yOrQV0AVH6ElWyVjXkM,32587
85
85
  tme/analyzer/proxy.py,sha256=NH0J93ESl5NAVd_22rLZGDWjCx1ov8I995Fr3mYqUSM,4162
86
86
  tme/backends/__init__.py,sha256=VUvEnpfjmZQVn_9QHkLPBoa6y8M3gDyNNL9kq4OQNGE,5249
87
87
  tme/backends/_cupy_utils.py,sha256=scxCSK9BcsDaIbHR2mFxAJRIxdc0r5JTC1gCJmVJa64,24558
88
- tme/backends/_jax_utils.py,sha256=UrC0RwwxwnNdz6QusPOgC8eISTX_C1KbZmt7Karl_Nk,6096
88
+ tme/backends/_jax_utils.py,sha256=pbKtaNDZEf7mAxKoQwcUXTYIkVzpnzjhF_wcbxO4AC4,6799
89
89
  tme/backends/_numpyfftw_utils.py,sha256=RnYgbSs_sL-NlwYGRqU9kziq0U5qzl9GecVx9uUeSJs,7904
90
90
  tme/backends/cupy_backend.py,sha256=Ms2sObxr0xc_tdHctcL659G8YWOTqPG59F3665yZH4c,8163
91
- tme/backends/jax_backend.py,sha256=d14u__H9leb5aXCz383T1xjTrGOzxgfWxXjgRrRWZMI,12676
91
+ tme/backends/jax_backend.py,sha256=Al1J60NfZTawmDWuCxxXLoEVu7HuP1sFdQlyMXbF1DE,12676
92
92
  tme/backends/matching_backend.py,sha256=B9LXxKpfNYdJBIksyGm7g7V7SC9G3mrluscz3aqISsU,33199
93
93
  tme/backends/mlx_backend.py,sha256=3aAKv7VkK_jfxwDMPQZsLp_lRU8MBJ6zNAzrs6VZd3s,6718
94
94
  tme/backends/npfftw_backend.py,sha256=8-Hs3slqnWvsHpqICxKxb-fbmg3IbSok3sThXew-P7M,18800
@@ -126,8 +126,8 @@ tme/filters/ctf.py,sha256=HNB03Pw3gn2Y8pIM2391iRHQzp2NcHZWFDC61v3K0oM,24520
126
126
  tme/filters/reconstruction.py,sha256=4gCGpuWuM-yCDQvuzyex_Q8bmsFqUI_8IoZnv3B_3UA,7564
127
127
  tme/filters/wedge.py,sha256=JWiN9IwRXnc7A7syq11kSuin6QjhlVmi38nwGHu4yxE,19541
128
128
  tme/filters/whitening.py,sha256=Zwk-0jMVqy_F8TdQ0ht5unMz2JBOePOC3okUhQpU0bo,6348
129
- pytme-0.3.1.post1.dist-info/METADATA,sha256=2EXMjS9Dx7TEo0gDEgGjvWD7rKnhJ0eTTb4AalGW9kE,4985
130
- pytme-0.3.1.post1.dist-info/WHEEL,sha256=SPbiHAOPLnBtml4sk5MIwvWF6YnGfOfy9C__W6Bpeg4,109
131
- pytme-0.3.1.post1.dist-info/entry_points.txt,sha256=pbUSmB0J4Ghlg0w7jHfaFSvPMuvRWzeSuvdjdRPisAU,288
132
- pytme-0.3.1.post1.dist-info/top_level.txt,sha256=ovCUR7UXXouH3zYt_fJLoqr_vtjp1wudFgjVAnztQLE,18
133
- pytme-0.3.1.post1.dist-info/RECORD,,
129
+ pytme-0.3.1.post2.dist-info/METADATA,sha256=2l8XEWvLoIRz-UdlbcOMZHTrRiRAmQIPBODnOGcYZUI,4985
130
+ pytme-0.3.1.post2.dist-info/WHEEL,sha256=SPbiHAOPLnBtml4sk5MIwvWF6YnGfOfy9C__W6Bpeg4,109
131
+ pytme-0.3.1.post2.dist-info/entry_points.txt,sha256=pbUSmB0J4Ghlg0w7jHfaFSvPMuvRWzeSuvdjdRPisAU,288
132
+ pytme-0.3.1.post2.dist-info/top_level.txt,sha256=ovCUR7UXXouH3zYt_fJLoqr_vtjp1wudFgjVAnztQLE,18
133
+ pytme-0.3.1.post2.dist-info/RECORD,,
scripts/match_template.py CHANGED
@@ -359,8 +359,8 @@ def parse_args():
359
359
  "--invert-target-contrast",
360
360
  action="store_true",
361
361
  default=False,
362
- help="Invert the target's contrast for cases where templates to-be-matched have "
363
- "negative values, e.g. tomograms.",
362
+ help="Invert the target contrast. Useful for matching on tomograms if the "
363
+ "template has not been inverted.",
364
364
  )
365
365
  io_group.add_argument(
366
366
  "--scramble-phases",
scripts/postprocess.py CHANGED
@@ -375,10 +375,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
375
375
  update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
376
376
  scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
377
377
  scores_out[update] = data[0][update] - scores_norm[update]
378
+ scores_out = np.fmax(scores_out, 0, out=scores_out)
378
379
  scores_out[update] += scores_norm[update].mean()
379
380
 
380
381
  # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
381
- scores_out = np.fmax(scores_out, 0, out=scores_out)
382
382
  data[0] = scores_out
383
383
 
384
384
  fg, bg = simple_stats(data[0]), simple_stats(scores_norm)
@@ -474,6 +474,7 @@ def membrane_mask(
474
474
  **kwargs,
475
475
  ) -> NDArray:
476
476
  return create_mask(
477
+ center=(center_x, center_y, center_z),
477
478
  mask_type="membrane",
478
479
  shape=template.shape,
479
480
  radius=radius,
scripts/refine_matches.py CHANGED
@@ -10,11 +10,9 @@ import subprocess
10
10
  from sys import exit
11
11
  from os import unlink
12
12
  from time import time
13
- from os.path import join
14
13
  from typing import Tuple, List, Dict
15
14
 
16
15
  import numpy as np
17
- from scipy import optimize
18
16
  from sklearn.metrics import roc_auc_score
19
17
 
20
18
  from tme import Orientations, Density
@@ -66,7 +64,6 @@ def parse_args():
66
64
  matching_group.add_argument(
67
65
  "-i",
68
66
  "--template",
69
- dest="template",
70
67
  type=str,
71
68
  required=True,
72
69
  help="Path to a template in PDB/MMCIF or other supported formats (see target).",
@@ -102,7 +99,7 @@ def parse_args():
102
99
  )
103
100
  matching_group.add_argument(
104
101
  "-s",
105
- dest="score",
102
+ "--score",
106
103
  type=str,
107
104
  default="batchFLCSphericalMask",
108
105
  choices=list(MATCHING_EXHAUSTIVE_REGISTER.keys()),
@@ -197,6 +194,7 @@ def create_matching_argdict(args) -> Dict:
197
194
  "-n": args.cores,
198
195
  "--ctf-file": args.ctf_file,
199
196
  "--invert-target-contrast": args.invert_target_contrast,
197
+ "--backend" : args.backend,
200
198
  }
201
199
  return arg_dict
202
200
 
@@ -252,7 +250,7 @@ class DeepMatcher:
252
250
  if args.lowpass_range:
253
251
  self.filter_parameters["--lowpass"] = 0
254
252
  if args.highpass_range:
255
- self.filter_parameters["--highpass"] = 200
253
+ self.filter_parameters["--highpass"] = 0
256
254
 
257
255
  self.postprocess_args = create_postprocessing_argdict(args)
258
256
  self.log_file = f"{args.output_prefix}_optimization_log.txt"
@@ -309,14 +307,14 @@ class DeepMatcher:
309
307
 
310
308
  match_template = argdict_to_command(
311
309
  self.match_template_args,
312
- executable="match_template.py",
310
+ executable="match_template",
313
311
  )
314
312
  run_command(match_template)
315
313
 
316
314
  # Assume we get a new peak for each input in the same order
317
315
  postprocess = argdict_to_command(
318
316
  self.postprocess_args,
319
- executable="postprocess.py",
317
+ executable="postprocess",
320
318
  )
321
319
  run_command(postprocess)
322
320
 
@@ -95,18 +95,6 @@ class TestDensity:
95
95
  self.orientations.rotations, orientations_new.rotations, atol=1e-3
96
96
  )
97
97
 
98
- @pytest.mark.parametrize("input_format", ("text", "star", "tbl"))
99
- @pytest.mark.parametrize("output_format", ("text", "star", "tbl"))
100
- def test_file_format_io(self, input_format: str, output_format: str):
101
- _, output_file = mkstemp(suffix=f".{input_format}")
102
- _, output_file2 = mkstemp(suffix=f".{output_format}")
103
-
104
- self.orientations.to_file(output_file)
105
- orientations_new = Orientations.from_file(output_file)
106
- orientations_new.to_file(output_file2)
107
-
108
- assert True
109
-
110
98
  @pytest.mark.parametrize("drop_oob", (True, False))
111
99
  @pytest.mark.parametrize("shape", (10, 40, 80))
112
100
  @pytest.mark.parametrize("odd", (True, False))
@@ -17,7 +17,7 @@ from ..backends import backend as be
17
17
  from ..matching_utils import normalize_template as _normalize_template
18
18
 
19
19
 
20
- __all__ = ["scan"]
20
+ __all__ = ["scan", "setup_scan"]
21
21
 
22
22
 
23
23
  def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
@@ -116,12 +116,49 @@ def _mask_scores(arr, mask):
116
116
  return arr.at[:].multiply(mask)
117
117
 
118
118
 
119
- @partial(
120
- pmap,
121
- in_axes=(0,) + (None,) * 7,
122
- static_broadcasted_argnums=[7, 8, 9, 10],
123
- axis_name="batch",
124
- )
119
+ def _select_config(analyzer_kwargs, device_idx):
120
+ return analyzer_kwargs[device_idx]
121
+
122
+
123
+ def setup_scan(analyzer_kwargs, callback_class, fast_shape, rotate_mask):
124
+ """Create separate scan function with initialized analyzer for each device"""
125
+ device_scans = [
126
+ partial(
127
+ scan,
128
+ fast_shape=fast_shape,
129
+ rotate_mask=rotate_mask,
130
+ analyzer=callback_class(**device_config),
131
+ ) for device_config in analyzer_kwargs
132
+ ]
133
+
134
+ @partial(
135
+ pmap,
136
+ in_axes=(0,) + (None,) * 6,
137
+ axis_name="batch",
138
+ )
139
+ def scan_combined(
140
+ target,
141
+ template,
142
+ template_mask,
143
+ rotations,
144
+ template_filter,
145
+ target_filter,
146
+ score_mask,
147
+ ):
148
+ return lax.switch(
149
+ lax.axis_index("batch"),
150
+ device_scans,
151
+ target,
152
+ template,
153
+ template_mask,
154
+ rotations,
155
+ template_filter,
156
+ target_filter,
157
+ score_mask,
158
+ )
159
+ return scan_combined
160
+
161
+
125
162
  def scan(
126
163
  target: BackendArray,
127
164
  template: BackendArray,
@@ -132,17 +169,10 @@ def scan(
132
169
  score_mask: BackendArray,
133
170
  fast_shape: Tuple[int],
134
171
  rotate_mask: bool,
135
- analyzer_class: object,
136
- analyzer_kwargs: Tuple[Tuple],
172
+ analyzer: object,
137
173
  ) -> Tuple[BackendArray, BackendArray]:
138
174
  eps = jnp.finfo(template.dtype).resolution
139
175
 
140
- kwargs = lax.switch(
141
- lax.axis_index("batch"),
142
- [lambda: analyzer_kwargs[i] for i in range(len(analyzer_kwargs))],
143
- )
144
- analyzer = analyzer_class(**be._tuple_to_dict(kwargs))
145
-
146
176
  if hasattr(target_filter, "shape"):
147
177
  target = _apply_fourier_filter(target, target_filter)
148
178
 
@@ -218,7 +218,7 @@ class JaxBackend(NumpyFFTWBackend):
218
218
  Emulates output of :py:meth:`tme.matching_exhaustive.scan` using
219
219
  :py:class:`tme.analyzer.MaxScoreOverRotations`.
220
220
  """
221
- from ._jax_utils import scan as scan_inner
221
+ from ._jax_utils import setup_scan
222
222
  from ..analyzer import MaxScoreOverRotations
223
223
 
224
224
  pad_target = True if len(splits) > 1 else False
@@ -279,8 +279,7 @@ class JaxBackend(NumpyFFTWBackend):
279
279
  cur_args = analyzer_args.copy()
280
280
  cur_args["offset"] = translation_offset
281
281
  cur_args.update(callback_class_args)
282
-
283
- analyzer_kwargs.append(self._dict_to_tuple(cur_args))
282
+ analyzer_kwargs.append(cur_args)
284
283
 
285
284
  if pad_target:
286
285
  score_mask = base._score_mask(fast_shape, shift)
@@ -310,7 +309,13 @@ class JaxBackend(NumpyFFTWBackend):
310
309
  create_filter, create_template_filter, create_target_filter = (False,) * 3
311
310
  base, targets = None, self._array_backend.stack(targets)
312
311
 
313
- analyzer_kwargs = tuple(analyzer_kwargs)
312
+ scan_inner = setup_scan(
313
+ analyzer_kwargs=analyzer_kwargs,
314
+ callback_class=callback_class,
315
+ fast_shape=fast_shape,
316
+ rotate_mask=rotate_mask
317
+ )
318
+
314
319
  states = scan_inner(
315
320
  self.astype(targets, self._float_dtype),
316
321
  self.astype(matching_data.template, self._float_dtype),
@@ -319,17 +324,12 @@ class JaxBackend(NumpyFFTWBackend):
319
324
  template_filter,
320
325
  target_filter,
321
326
  score_mask,
322
- fast_shape,
323
- rotate_mask,
324
- callback_class,
325
- analyzer_kwargs,
326
327
  )
327
328
 
328
329
  ndim = targets.ndim - 1
329
330
  for index in range(targets.shape[0]):
330
- kwargs = self._tuple_to_dict(analyzer_kwargs[index])
331
+ kwargs = analyzer_kwargs[index]
331
332
  analyzer = callback_class(**kwargs)
332
-
333
333
  state = [self._unbatch(x, ndim, index) for x in states]
334
334
 
335
335
  if isinstance(analyzer, MaxScoreOverRotations):
tme/density.py CHANGED
@@ -2196,7 +2196,7 @@ class Density:
2196
2196
 
2197
2197
  Parameters
2198
2198
  ----------
2199
- target : Density
2199
+ target : :py:class:`Density`
2200
2200
  The target map for template matching.
2201
2201
  template : Structure
2202
2202
  The template that should be aligned to the target.
@@ -2259,3 +2259,60 @@ class Density:
2259
2259
  coordinates = np.array(np.where(data > 0))
2260
2260
  weights = self.data[tuple(coordinates)]
2261
2261
  return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
2262
+
2263
+ @staticmethod
2264
+ def fourier_shell_correlation(density1: "Density", density2: "Density") -> NDArray:
2265
+ """
2266
+ Computes the Fourier Shell Correlation (FSC) between two instances of `Density`.
2267
+
2268
+ The Fourier transforms of the input maps are divided into shells
2269
+ based on their spatial frequency. The correlation between corresponding shells
2270
+ in the two maps is computed to give the FSC.
2271
+
2272
+ Parameters
2273
+ ----------
2274
+ density1 : :py:class:`Density`
2275
+ Reference for comparison.
2276
+ density2 : :py:class:`Density`
2277
+ Target for comparison.
2278
+
2279
+ Returns
2280
+ -------
2281
+ NDArray
2282
+ An array of shape (N, 2), where N is the number of shells.
2283
+ The first column represents the spatial frequency for each shell
2284
+ and the second column represents the corresponding FSC.
2285
+
2286
+ References
2287
+ ----------
2288
+ .. [1] https://github.com/tdgrant1/denss/blob/master/saxstats/saxstats.py
2289
+ """
2290
+ side = density1.data.shape[0]
2291
+ df = 1.0 / side
2292
+
2293
+ qx_ = np.fft.fftfreq(side) * side * df
2294
+ qx, qy, qz = np.meshgrid(qx_, qx_, qx_, indexing="ij")
2295
+ qr = np.sqrt(qx**2 + qy**2 + qz**2)
2296
+
2297
+ qmax = np.max(qr)
2298
+ qstep = np.min(qr[qr > 0])
2299
+ nbins = int(qmax / qstep)
2300
+ qbins = np.linspace(0, nbins * qstep, nbins + 1)
2301
+ qbin_labels = np.searchsorted(qbins, qr, "right") - 1
2302
+
2303
+ F1 = np.fft.fftn(density1.data)
2304
+ F2 = np.fft.fftn(density2.data)
2305
+
2306
+ qbin_labels = qbin_labels.reshape(-1)
2307
+ numerator = np.bincount(
2308
+ qbin_labels, weights=np.real(F1 * np.conj(F2)).reshape(-1)
2309
+ )
2310
+ term1 = np.bincount(qbin_labels, weights=np.abs(F1).reshape(-1) ** 2)
2311
+ term2 = np.bincount(qbin_labels, weights=np.abs(F2).reshape(-1) ** 2)
2312
+ np.multiply(term1, term2, out=term1)
2313
+ denominator = np.sqrt(term1)
2314
+ FSC = np.divide(numerator, denominator)
2315
+
2316
+ qidx = np.where(qbins < qx.max())
2317
+
2318
+ return np.vstack((qbins[qidx], FSC[qidx])).T
tme/orientations.py CHANGED
@@ -494,16 +494,22 @@ class Orientations:
494
494
 
495
495
  @classmethod
496
496
  def _from_star(
497
- cls, filename: str, delimiter: str = "\t"
497
+ cls, filename: str, delimiter: str = None
498
498
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
499
499
  parser = StarParser(filename, delimiter=delimiter)
500
500
 
501
- ret = parser.get("data_particles", None)
502
- if ret is None:
503
- ret = parser.get("data_", None)
501
+ keyword_order = ("data_particles", "particles", "data")
502
+ for keyword in keyword_order:
503
+ ret = parser.get(keyword, None)
504
+ if ret is None:
505
+ ret = parser.get(f"{keyword}_", None)
506
+ if ret is not None:
507
+ break
504
508
 
505
509
  if ret is None:
506
- raise ValueError(f"No data_particles section found in {filename}.")
510
+ raise ValueError(
511
+ f"Could not find either {keyword_order} section found in {filename}."
512
+ )
507
513
 
508
514
  translation = np.vstack(
509
515
  (ret["_rlnCoordinateX"], ret["_rlnCoordinateY"], ret["_rlnCoordinateZ"])
@@ -375,10 +375,10 @@ def normalize_input(foregrounds: Tuple[str], backgrounds: Tuple[str]) -> Tuple:
375
375
  update = tuple(slice(0, int(x)) for x in np.minimum(out_shape, scores.shape))
376
376
  scores_out = np.full(out_shape, fill_value=0, dtype=np.float32)
377
377
  scores_out[update] = data[0][update] - scores_norm[update]
378
+ scores_out = np.fmax(scores_out, 0, out=scores_out)
378
379
  scores_out[update] += scores_norm[update].mean()
379
380
 
380
381
  # scores_out[update] = np.divide(scores_out[update], 1 - scores_norm[update])
381
- scores_out = np.fmax(scores_out, 0, out=scores_out)
382
382
  data[0] = scores_out
383
383
 
384
384
  fg, bg = simple_stats(data[0]), simple_stats(scores_norm)