pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/match_template.py +147 -93
  2. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/postprocess.py +67 -26
  3. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +175 -85
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +20 -13
  8. scripts/match_template.py +147 -93
  9. scripts/match_template_filters.py +154 -95
  10. scripts/postprocess.py +67 -26
  11. scripts/preprocessor_gui.py +175 -85
  12. scripts/refine_matches.py +265 -61
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +451 -809
  16. tme/backends/__init__.py +40 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +111 -223
  19. tme/backends/jax_backend.py +214 -150
  20. tme/backends/matching_backend.py +445 -384
  21. tme/backends/mlx_backend.py +32 -59
  22. tme/backends/npfftw_backend.py +239 -507
  23. tme/backends/pytorch_backend.py +21 -145
  24. tme/density.py +233 -363
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +322 -285
  27. tme/matching_exhaustive.py +172 -1493
  28. tme/matching_optimization.py +143 -106
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +280 -386
  31. tme/memory.py +377 -0
  32. tme/orientations.py +52 -12
  33. tme/parser.py +3 -4
  34. tme/preprocessing/_utils.py +61 -32
  35. tme/preprocessing/compose.py +7 -3
  36. tme/preprocessing/frequency_filters.py +49 -39
  37. tme/preprocessing/tilt_series.py +34 -40
  38. tme/preprocessor.py +560 -526
  39. tme/structure.py +491 -188
  40. tme/types.py +5 -3
  41. pytme-0.2.1.dist-info/METADATA +0 -73
  42. pytme-0.2.1.dist-info/RECORD +0 -73
  43. tme/helpers.py +0 -881
  44. tme/matching_constrained.py +0 -195
  45. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  46. {pytme-0.2.1.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  47. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  48. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  49. {pytme-0.2.1.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/backends/__init__.py CHANGED
@@ -4,14 +4,15 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
-
8
7
  from typing import Dict, List
8
+ from importlib.util import find_spec
9
9
 
10
10
  from .matching_backend import MatchingBackend
11
11
  from .npfftw_backend import NumpyFFTWBackend
12
12
  from .pytorch_backend import PytorchBackend
13
13
  from .cupy_backend import CupyBackend
14
14
  from .mlx_backend import MLXBackend
15
+ from .jax_backend import JaxBackend
15
16
 
16
17
 
17
18
  class BackendManager:
@@ -43,24 +44,26 @@ class BackendManager:
43
44
 
44
45
  >>> backend.change_backend("pytorch")
45
46
  >>> backend.multiply(arr1, arr2)
46
- # This will use the GPUBackend's multiply method
47
+ # This will use the pytorchs multiply method
48
+
49
+ >>> backend.available_backends()
50
+ # Backends available on your system
47
51
 
48
52
  Notes
49
53
  -----
50
- To add custom backends, use the `add_backend` method. To switch
51
- between backends, use the `change_backend` method. Note that the backend
52
- has to be reinitialzed when using fork-based parallelism.
54
+ The backend has to be reinitialzed when using fork-based parallelism.
53
55
  """
54
56
 
55
57
  def __init__(self):
56
58
  self._BACKEND_REGISTRY = {
57
- "cpu_backend": NumpyFFTWBackend,
59
+ "numpyfftw": NumpyFFTWBackend,
58
60
  "pytorch": PytorchBackend,
59
61
  "cupy": CupyBackend,
60
62
  "mlx": MLXBackend,
63
+ "jax": JaxBackend,
61
64
  }
62
65
  self._backend = NumpyFFTWBackend()
63
- self._backend_name = "cpu_backend"
66
+ self._backend_name = "numpyfftw"
64
67
  self._backend_args = {}
65
68
 
66
69
  def __repr__(self):
@@ -99,7 +102,8 @@ class BackendManager:
99
102
  Raises
100
103
  ------
101
104
  ValueError
102
- If the provided backend_instance does not inherit from MatchingBackend.
105
+ If the provided backend_instance does not inherit from
106
+ :py:class:`MatchingBackend`.
103
107
  """
104
108
  if not issubclass(backend_class, MatchingBackend):
105
109
  raise ValueError("backend_class needs to inherit from MatchingBackend.")
@@ -122,9 +126,7 @@ class BackendManager:
122
126
  If no backend is found with the provided name.
123
127
  """
124
128
  if backend_name not in self._BACKEND_REGISTRY:
125
- available_backends = ", ".join(
126
- [str(x) for x in self._BACKEND_REGISTRY.keys()]
127
- )
129
+ available_backends = ", ".join(self.available_backends())
128
130
  raise NotImplementedError(
129
131
  f"Available backends are {available_backends} - not {backend_name}."
130
132
  )
@@ -132,5 +134,32 @@ class BackendManager:
132
134
  self._backend_name = backend_name
133
135
  self._backend_args = backend_kwargs
134
136
 
137
def available_backends(self) -> List[str]:
    """
    Determines importable backends.

    Each registered backend is mapped to the top-level module it depends on.
    That module is looked up with :py:func:`importlib.util.find_spec`, so
    nothing is actually imported. This is an approximation: finding the
    module does not guarantee the backend is fully usable.

    Returns
    -------
    list of str
        Registered backend names whose dependency can be found, in
        registration order. Registered backends without an entry in the
        dependency table are omitted.
    """
    # The PyTorch distribution installs the ``torch`` module; looking up
    # ``pytorch`` would never succeed and hide the backend.
    dependencies = {
        "numpyfftw": "numpy",
        "cupy": "cupy",
        "pytorch": "torch",
        "mlx": "mlx",
        "jax": "jax",
    }
    return [
        name
        for name in self._BACKEND_REGISTRY
        if name in dependencies and find_spec(dependencies[name]) is not None
    ]
163
+
135
164
 
136
165
  backend = BackendManager()
@@ -0,0 +1,185 @@
1
+ """ Utility functions for jax backend.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ from typing import Tuple
8
+ from functools import partial
9
+
10
+ import jax.numpy as jnp
11
+ from jax import pmap, lax
12
+
13
+ from ..types import BackendArray
14
+ from ..backends import backend as be
15
+ from ..matching_utils import normalize_template as _normalize_template
16
+
17
+
18
+ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
19
+ """
20
+ Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
21
+ """
22
+ template_ft = jnp.fft.rfftn(template)
23
+ template_ft = template_ft.at[:].multiply(ft_target)
24
+ correlation = jnp.fft.irfftn(template_ft)
25
+ return correlation
26
+
27
+
28
def _flc_scoring(
    template: BackendArray,
    template_mask: BackendArray,
    ft_target: BackendArray,
    ft_target2: BackendArray,
    n_observations: BackendArray,
    eps: float,
    **kwargs,
) -> BackendArray:
    """
    Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.

    The normalization is recomputed on every call from ``template_mask``.
    This suits the case where the mask changes together with the template.

    Parameters
    ----------
    template : BackendArray
        Real-space template, padded to the target's real-space shape.
    template_mask : BackendArray
        Real-space mask defining the local window, same shape as ``template``.
    ft_target, ft_target2 : BackendArray
        Real Fourier transforms of the target and of the squared target.
    n_observations : BackendArray
        Number of observations under the mask.
    eps : float
        Standard deviations at or below this value yield a zero score.
    **kwargs
        Ignored. This lets callers pass the same keyword set to every scorer.

    Returns
    -------
    BackendArray
        Correlation scaled by the reciprocal local target standard deviation.
    """
    scale = _reciprocal_target_std(
        ft_target=ft_target,
        ft_target2=ft_target2,
        template_mask=template_mask,
        eps=eps,
        n_observations=n_observations,
    )
    raw = _correlate(template=template, ft_target=ft_target)
    return raw.at[:].multiply(scale)
50
+
51
+
52
def _flcSphere_scoring(
    template: BackendArray,
    ft_target: BackendArray,
    inv_denominator: BackendArray,
    **kwargs,
) -> BackendArray:
    """
    Correlation scaled by a precomputed normalization.

    This is the counterpart of the mask-dependent scorer for the case where
    the target normalization does not change between calls. The reciprocal
    standard deviation is supplied once as ``inv_denominator`` instead of
    being recomputed.

    Parameters
    ----------
    template : BackendArray
        Real-space template, padded to the target's real-space shape.
    ft_target : BackendArray
        Real Fourier transform of the target.
    inv_denominator : BackendArray
        Precomputed elementwise scaling, broadcastable to the correlation.
    **kwargs
        Ignored. This lets callers pass the same keyword set to every scorer.

    Returns
    -------
    BackendArray
        Scaled correlation.
    """
    return _correlate(template=template, ft_target=ft_target).at[:].multiply(
        inv_denominator
    )
64
+
65
+
66
+ def _reciprocal_target_std(
67
+ ft_target: BackendArray,
68
+ ft_target2: BackendArray,
69
+ template_mask: BackendArray,
70
+ n_observations: float,
71
+ eps: float,
72
+ ) -> BackendArray:
73
+ """
74
+ Computes reciprocal standard deviation of a target given a mask.
75
+
76
+ See Also
77
+ --------
78
+ :py:meth:`tme.matching_exhaustive.flc_scoring`.
79
+ """
80
+ ft_template_mask = jnp.fft.rfftn(template_mask)
81
+
82
+ # E(X^2)- E(X)^2
83
+ exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask)
84
+ exp_sq = exp_sq.at[:].divide(n_observations)
85
+
86
+ ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
87
+ sq_exp = jnp.fft.irfftn(ft_template_mask)
88
+ sq_exp = sq_exp.at[:].divide(n_observations)
89
+ sq_exp = sq_exp.at[:].power(2)
90
+
91
+ exp_sq = exp_sq.at[:].add(-sq_exp)
92
+ exp_sq = exp_sq.at[:].max(0)
93
+ exp_sq = exp_sq.at[:].power(0.5)
94
+
95
+ exp_sq = exp_sq.at[:].set(
96
+ jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_observations))
97
+ )
98
+ return exp_sq
99
+
100
+
101
+ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
102
+ arr_ft = jnp.fft.rfftn(arr)
103
+ arr_ft = arr_ft.at[:].multiply(arr_filter)
104
+ return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
105
+
106
+
107
+ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
108
+ return arr
109
+
110
@partial(
    pmap,
    in_axes=(0,) + (None,) * 6,
    static_broadcasted_argnums=[6, 7],
)
def scan(
    target: BackendArray,
    template: BackendArray,
    template_mask: BackendArray,
    rotations: BackendArray,
    template_filter: BackendArray,
    target_filter: BackendArray,
    fast_shape: Tuple[int],
    rotate_mask: bool,
) -> Tuple[BackendArray, BackendArray]:
    """
    Score a template against a target for every rotation in ``rotations``.

    The function is mapped with :py:func:`jax.pmap` over the leading axis of
    ``target``. The other array arguments are broadcast, and ``fast_shape``
    and ``rotate_mask`` are static.

    Parameters
    ----------
    target : BackendArray
        Target array; its leading axis is mapped by pmap.
    template : BackendArray
        Template that is rotated for each entry of ``rotations``.
    template_mask : BackendArray
        Mask belonging to ``template``.
    rotations : BackendArray
        Rotation matrices, iterated over with :py:func:`jax.lax.scan`.
    template_filter : BackendArray
        Fourier filter applied to each rotated template. It is skipped when
        its shape is ``()``.
    target_filter : BackendArray
        Fourier filter applied to the target. It is skipped when it has no
        ``shape`` attribute.
    fast_shape : tuple of int
        Shape the rotated template and mask are padded to before scoring.
    rotate_mask : bool
        If False, the target normalization is computed once from the
        unrotated mask and a cheaper scorer is used for every rotation.

    Returns
    -------
    Tuple[BackendArray, BackendArray]
        Best score per position (initialized to 0) and the index of the
        rotation that produced it (initialized to -1). Both are updated via
        ``be.max_score_over_rotations``.
    """
    eps = jnp.finfo(template.dtype).resolution

    if hasattr(target_filter, "shape"):
        target = _apply_fourier_filter(target, target_filter)

    ft_target = jnp.fft.rfftn(target)
    ft_target2 = jnp.fft.rfftn(jnp.square(target))
    # Only the transforms are used from here on.
    target = None

    if rotate_mask:
        scoring_func, inv_denominator = _flc_scoring, None
    else:
        # The mask does not rotate, so the target normalization is
        # rotation-independent and can be computed once up front.
        mask_sum = jnp.sum(template_mask)
        inv_denominator = _reciprocal_target_std(
            ft_target=ft_target,
            ft_target2=ft_target2,
            template_mask=be.topleft_pad(template_mask, fast_shape),
            eps=eps,
            n_observations=mask_sum,
        )
        scoring_func, ft_target2 = _flcSphere_scoring, None

    if template_filter.shape == ():
        filter_template = _identity
    else:
        filter_template = _apply_fourier_filter

    def _step(carry, rotation_matrix):
        best_scores, best_rotations, step = carry
        rotated, rotated_mask = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation_matrix,
            order=1,  # original note: higher orders are not available yet
        )

        rotated_sum = jnp.sum(rotated_mask)
        rotated = filter_template(rotated, template_filter)
        rotated = _normalize_template(rotated, rotated_mask, rotated_sum)

        scores = scoring_func(
            template=be.topleft_pad(rotated, fast_shape),
            template_mask=be.topleft_pad(rotated_mask, fast_shape),
            ft_target=ft_target,
            ft_target2=ft_target2,
            inv_denominator=inv_denominator,
            n_observations=rotated_sum,
            eps=eps,
        )
        best_scores, best_rotations = be.max_score_over_rotations(
            scores, best_scores, best_rotations, step
        )
        return (best_scores, best_rotations, step + 1), None

    initial_carry = (
        jnp.zeros(fast_shape),
        jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1),
        0,
    )
    (score_space, rotation_space, _), _ = lax.scan(
        _step, initial_carry, rotations
    )
    return score_space, rotation_space