pytme 0.2.0b0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.2__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52) hide show
  1. pytme-0.2.2.data/scripts/match_template.py +1187 -0
  2. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/postprocess.py +170 -71
  3. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocessor_gui.py +179 -86
  4. pytme-0.2.2.dist-info/METADATA +91 -0
  5. pytme-0.2.2.dist-info/RECORD +74 -0
  6. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/WHEEL +1 -1
  7. scripts/extract_candidates.py +126 -87
  8. scripts/match_template.py +596 -209
  9. scripts/match_template_filters.py +571 -223
  10. scripts/postprocess.py +170 -71
  11. scripts/preprocessor_gui.py +179 -86
  12. scripts/refine_matches.py +567 -159
  13. tme/__init__.py +0 -1
  14. tme/__version__.py +1 -1
  15. tme/analyzer.py +627 -855
  16. tme/backends/__init__.py +41 -11
  17. tme/backends/_jax_utils.py +185 -0
  18. tme/backends/cupy_backend.py +120 -225
  19. tme/backends/jax_backend.py +282 -0
  20. tme/backends/matching_backend.py +464 -388
  21. tme/backends/mlx_backend.py +45 -68
  22. tme/backends/npfftw_backend.py +256 -514
  23. tme/backends/pytorch_backend.py +41 -154
  24. tme/density.py +312 -421
  25. tme/extensions.cpython-311-darwin.so +0 -0
  26. tme/matching_data.py +366 -303
  27. tme/matching_exhaustive.py +279 -1521
  28. tme/matching_optimization.py +234 -129
  29. tme/matching_scores.py +884 -0
  30. tme/matching_utils.py +281 -387
  31. tme/memory.py +377 -0
  32. tme/orientations.py +226 -66
  33. tme/parser.py +3 -4
  34. tme/preprocessing/__init__.py +2 -0
  35. tme/preprocessing/_utils.py +217 -0
  36. tme/preprocessing/composable_filter.py +31 -0
  37. tme/preprocessing/compose.py +55 -0
  38. tme/preprocessing/frequency_filters.py +388 -0
  39. tme/preprocessing/tilt_series.py +1011 -0
  40. tme/preprocessor.py +574 -530
  41. tme/structure.py +495 -189
  42. tme/types.py +5 -3
  43. pytme-0.2.0b0.data/scripts/match_template.py +0 -800
  44. pytme-0.2.0b0.dist-info/METADATA +0 -73
  45. pytme-0.2.0b0.dist-info/RECORD +0 -66
  46. tme/helpers.py +0 -881
  47. tme/matching_constrained.py +0 -195
  48. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/estimate_ram_usage.py +0 -0
  49. {pytme-0.2.0b0.data → pytme-0.2.2.data}/scripts/preprocess.py +0 -0
  50. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.0b0.dist-info → pytme-0.2.2.dist-info}/top_level.txt +0 -0
tme/backends/__init__.py CHANGED
@@ -4,14 +4,15 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
-
8
7
  from typing import Dict, List
8
+ from importlib.util import find_spec
9
9
 
10
10
  from .matching_backend import MatchingBackend
11
11
  from .npfftw_backend import NumpyFFTWBackend
12
12
  from .pytorch_backend import PytorchBackend
13
13
  from .cupy_backend import CupyBackend
14
14
  from .mlx_backend import MLXBackend
15
+ from .jax_backend import JaxBackend
15
16
 
16
17
 
17
18
  class BackendManager:
@@ -43,23 +44,26 @@ class BackendManager:
43
44
 
44
45
  >>> backend.change_backend("pytorch")
45
46
  >>> backend.multiply(arr1, arr2)
46
- # This will use the GPUBackend's multiply method
47
+ # This will use the pytorch backend's multiply method
48
+
49
+ >>> backend.available_backends()
50
+ # Backends available on your system
47
51
 
48
52
  Notes
49
53
  -----
50
- To add custom backends, use the `add_backend` method. To switch
51
- between backends, use the `change_backend` method. Note that the backend
52
- has to be reinitialzed when using fork-based parallelism.
54
+ The backend has to be reinitialized when using fork-based parallelism.
53
55
  """
54
56
 
55
57
  def __init__(self):
56
58
  self._BACKEND_REGISTRY = {
57
- "cpu_backend": NumpyFFTWBackend,
59
+ "numpyfftw": NumpyFFTWBackend,
58
60
  "pytorch": PytorchBackend,
59
61
  "cupy": CupyBackend,
62
+ "mlx": MLXBackend,
63
+ "jax": JaxBackend,
60
64
  }
61
65
  self._backend = NumpyFFTWBackend()
62
- self._backend_name = "cpu_backend"
66
+ self._backend_name = "numpyfftw"
63
67
  self._backend_args = {}
64
68
 
65
69
  def __repr__(self):
@@ -98,7 +102,8 @@ class BackendManager:
98
102
  Raises
99
103
  ------
100
104
  ValueError
101
- If the provided backend_instance does not inherit from MatchingBackend.
105
+ If the provided backend_instance does not inherit from
106
+ :py:class:`MatchingBackend`.
102
107
  """
103
108
  if not issubclass(backend_class, MatchingBackend):
104
109
  raise ValueError("backend_class needs to inherit from MatchingBackend.")
@@ -121,9 +126,7 @@ class BackendManager:
121
126
  If no backend is found with the provided name.
122
127
  """
123
128
  if backend_name not in self._BACKEND_REGISTRY:
124
- available_backends = ", ".join(
125
- [str(x) for x in self._BACKEND_REGISTRY.keys()]
126
- )
129
+ available_backends = ", ".join(self.available_backends())
127
130
  raise NotImplementedError(
128
131
  f"Available backends are {available_backends} - not {backend_name}."
129
132
  )
@@ -131,5 +134,32 @@ class BackendManager:
131
134
  self._backend_name = backend_name
132
135
  self._backend_args = backend_kwargs
133
136
 
137
def available_backends(self) -> List[str]:
    """
    Determines importable backends.

    Returns
    -------
    list of str
        Names of registered backends whose core dependency can be located
        by the import system, in registry order. Registered backends without
        a known dependency (e.g. custom backends) are omitted.
    """
    # Maps backend names to the top-level module they need. This is an
    # approximation, but find_spec only locates the module without importing
    # it, which avoids runtime pollution. Note that PyTorch is importable as
    # ``torch``, not ``pytorch``.
    _dependencies = {
        "numpyfftw": "numpy",
        "cupy": "cupy",
        "pytorch": "torch",
        "mlx": "mlx",
        "jax": "jax",
    }
    return [
        name
        for name in self._BACKEND_REGISTRY
        if name in _dependencies and find_spec(_dependencies[name]) is not None
    ]
163
+
134
164
 
135
165
  # Shared module-level BackendManager instance. Other modules import it
  # (e.g. ``from ..backends import backend as be``) and call array
  # operations on it, which dispatch to the currently selected backend.
  backend = BackendManager()
@@ -0,0 +1,185 @@
1
+ """ Utility functions for jax backend.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ from typing import Tuple
8
+ from functools import partial
9
+
10
+ import jax.numpy as jnp
11
+ from jax import pmap, lax
12
+
13
+ from ..types import BackendArray
14
+ from ..backends import backend as be
15
+ from ..matching_utils import normalize_template as _normalize_template
16
+
17
+
18
+ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
19
+ """
20
+ Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
21
+ """
22
+ template_ft = jnp.fft.rfftn(template)
23
+ template_ft = template_ft.at[:].multiply(ft_target)
24
+ correlation = jnp.fft.irfftn(template_ft)
25
+ return correlation
26
+
27
+
28
def _flc_scoring(
    template: BackendArray,
    template_mask: BackendArray,
    ft_target: BackendArray,
    ft_target2: BackendArray,
    n_observations: BackendArray,
    eps: float,
    **kwargs,
) -> BackendArray:
    """
    Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.

    The correlation of ``template`` with the target is scaled by the
    reciprocal standard deviation of the target under ``template_mask``.
    The denominator is recomputed on every call.

    Extra keyword arguments (such as ``inv_denominator``) are accepted so
    that all scoring functions can be called with one set of keywords.
    They are ignored.
    """
    inv_std = _reciprocal_target_std(
        ft_target=ft_target,
        ft_target2=ft_target2,
        template_mask=template_mask,
        eps=eps,
        n_observations=n_observations,
    )
    scores = _correlate(template=template, ft_target=ft_target)
    return scores.at[:].multiply(inv_std)
50
+
51
+
52
def _flcSphere_scoring(
    template: BackendArray,
    ft_target: BackendArray,
    inv_denominator: BackendArray,
    **kwargs,
) -> BackendArray:
    """
    Computes :py:meth:`tme.matching_exhaustive.flc_scoring` with a
    precomputed denominator.

    Unlike :py:func:`_flc_scoring`, this does not recompute the reciprocal
    target standard deviation. The supplied ``inv_denominator`` is used
    directly, which is valid when the template mask is not rotated.

    Remaining keyword arguments are accepted for call compatibility with
    the other scoring functions and are ignored.
    """
    scores = _correlate(template=template, ft_target=ft_target)
    return scores.at[:].multiply(inv_denominator)
64
+
65
+
66
+ def _reciprocal_target_std(
67
+ ft_target: BackendArray,
68
+ ft_target2: BackendArray,
69
+ template_mask: BackendArray,
70
+ n_observations: float,
71
+ eps: float,
72
+ ) -> BackendArray:
73
+ """
74
+ Computes reciprocal standard deviation of a target given a mask.
75
+
76
+ See Also
77
+ --------
78
+ :py:meth:`tme.matching_exhaustive.flc_scoring`.
79
+ """
80
+ ft_template_mask = jnp.fft.rfftn(template_mask)
81
+
82
+ # E(X^2)- E(X)^2
83
+ exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask)
84
+ exp_sq = exp_sq.at[:].divide(n_observations)
85
+
86
+ ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
87
+ sq_exp = jnp.fft.irfftn(ft_template_mask)
88
+ sq_exp = sq_exp.at[:].divide(n_observations)
89
+ sq_exp = sq_exp.at[:].power(2)
90
+
91
+ exp_sq = exp_sq.at[:].add(-sq_exp)
92
+ exp_sq = exp_sq.at[:].max(0)
93
+ exp_sq = exp_sq.at[:].power(0.5)
94
+
95
+ exp_sq = exp_sq.at[:].set(
96
+ jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_observations))
97
+ )
98
+ return exp_sq
99
+
100
+
101
+ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
102
+ arr_ft = jnp.fft.rfftn(arr)
103
+ arr_ft = arr_ft.at[:].multiply(arr_filter)
104
+ return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
105
+
106
+
107
+ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
108
+ return arr
109
+
110
# NOTE(review): in_axes has 7 entries for 8 positional parameters. Positions 6
# and 7 are static, so only entries 0-5 appear to be used; confirm against
# the jax.pmap documentation.
@partial(
    pmap,
    in_axes=(0,) + (None,) * 6,
    static_broadcasted_argnums=[6, 7],
)
def scan(
    target: BackendArray,
    template: BackendArray,
    template_mask: BackendArray,
    rotations: BackendArray,
    template_filter: BackendArray,
    target_filter: BackendArray,
    fast_shape: Tuple[int],
    rotate_mask: bool,
) -> Tuple[BackendArray, BackendArray]:
    """
    Scores ``template`` against ``target`` for every matrix in ``rotations``.

    The function is mapped over devices with ``pmap``. ``target`` is split
    along its leading axis (``in_axes=0``), the other array arguments are
    broadcast, and ``fast_shape`` / ``rotate_mask`` are static arguments.

    Parameters
    ----------
    target : BackendArray
        Per-device target. Presumably already padded to ``fast_shape``;
        TODO confirm against the caller.
    template : BackendArray
        Template that is rotated, optionally filtered, normalized and
        top-left padded to ``fast_shape`` for every rotation.
    template_mask : BackendArray
        Mask of ``template``, rotated alongside it.
    rotations : BackendArray
        Stack of rotation matrices. ``lax.scan`` iterates over its leading
        axis.
    template_filter : BackendArray
        Fourier filter for the rotated template. A 0-d array
        (``shape == ()``) disables template filtering.
    target_filter : BackendArray
        Fourier filter applied once to the target if it has a ``shape``
        attribute.
    fast_shape : tuple of int
        Real-space shape of the score arrays.
    rotate_mask : bool
        If False, the target standard deviation under the unrotated mask is
        computed once and :py:func:`_flcSphere_scoring` is used. Otherwise
        :py:func:`_flc_scoring` recomputes it for every rotation.

    Returns
    -------
    tuple of BackendArray
        Score array (initialized to zeros) and rotation-index array
        (initialized to -1), both of shape ``fast_shape``. Both are updated
        per rotation by ``be.max_score_over_rotations``.
    """
    # Threshold below which the target standard deviation counts as zero.
    eps = jnp.finfo(template.dtype).resolution

    # NOTE(review): target_filter is tested via hasattr, template_filter via
    # shape != () below; confirm the two checks are meant to differ.
    if hasattr(target_filter, "shape"):
        target = _apply_fourier_filter(target, target_filter)

    ft_target = jnp.fft.rfftn(target)
    ft_target2 = jnp.fft.rfftn(jnp.square(target))
    # Only the target spectra are needed from here on, so the real-space
    # target reference is dropped.
    inv_denominator, target, scoring_func = None, None, _flc_scoring
    if not rotate_mask:
        # The mask is not rotated, so the denominator does not depend on the
        # rotation and can be computed once up front.
        n_observations = jnp.sum(template_mask)
        inv_denominator = _reciprocal_target_std(
            ft_target=ft_target,
            ft_target2=ft_target2,
            template_mask=be.topleft_pad(template_mask, fast_shape),
            eps=eps,
            n_observations=n_observations,
        )
        ft_target2, scoring_func = None, _flcSphere_scoring

    _template_filter_func = _identity
    if template_filter.shape != ():
        _template_filter_func = _apply_fourier_filter

    def _sample_transform(ret, rotation_matrix):
        # lax.scan step. The carry is (score array, rotation-index array,
        # index of rotation_matrix within ``rotations``). The local name
        # ``rotations`` shadows the outer argument of the same name.
        max_scores, rotations, index = ret
        template_rot, template_mask_rot = be.rigid_transform(
            arr=template,
            arr_mask=template_mask,
            rotation_matrix=rotation_matrix,
            order=1,  # that's all we get for now (presumably the only supported order)
        )

        n_observations = jnp.sum(template_mask_rot)
        template_rot = _template_filter_func(template_rot, template_filter)
        template_rot = _normalize_template(
            template_rot, template_mask_rot, n_observations
        )
        template_rot = be.topleft_pad(template_rot, fast_shape)
        template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)

        # Keywords not used by the selected scoring function are absorbed by
        # its **kwargs.
        scores = scoring_func(
            template=template_rot,
            template_mask=template_mask_rot,
            ft_target=ft_target,
            ft_target2=ft_target2,
            inv_denominator=inv_denominator,
            n_observations=n_observations,
            eps=eps,
        )
        max_scores, rotations = be.max_score_over_rotations(
            scores, max_scores, rotations, index
        )
        return (max_scores, rotations, index + 1), None

    score_space = jnp.zeros(fast_shape)
    rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
    (score_space, rotation_space, _), _ = lax.scan(
        _sample_transform, (score_space, rotation_space, 0), rotations
    )

    return score_space, rotation_space