pytme 0.2.1__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (52)
  1. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/match_template.py +219 -216
  2. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/postprocess.py +86 -54
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +181 -94
  5. pytme-0.2.3.dist-info/METADATA +92 -0
  6. pytme-0.2.3.dist-info/RECORD +75 -0
  7. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/WHEEL +1 -1
  8. pytme-0.2.1.data/scripts/preprocess.py → scripts/eval.py +1 -1
  9. scripts/extract_candidates.py +20 -13
  10. scripts/match_template.py +219 -216
  11. scripts/match_template_filters.py +154 -95
  12. scripts/postprocess.py +86 -54
  13. scripts/preprocess.py +95 -56
  14. scripts/preprocessor_gui.py +181 -94
  15. scripts/refine_matches.py +265 -61
  16. tme/__init__.py +0 -1
  17. tme/__version__.py +1 -1
  18. tme/analyzer.py +458 -813
  19. tme/backends/__init__.py +40 -11
  20. tme/backends/_jax_utils.py +187 -0
  21. tme/backends/cupy_backend.py +109 -226
  22. tme/backends/jax_backend.py +230 -152
  23. tme/backends/matching_backend.py +445 -384
  24. tme/backends/mlx_backend.py +32 -59
  25. tme/backends/npfftw_backend.py +240 -507
  26. tme/backends/pytorch_backend.py +30 -151
  27. tme/density.py +248 -371
  28. tme/extensions.cpython-311-darwin.so +0 -0
  29. tme/matching_data.py +328 -284
  30. tme/matching_exhaustive.py +195 -1499
  31. tme/matching_optimization.py +143 -106
  32. tme/matching_scores.py +887 -0
  33. tme/matching_utils.py +287 -388
  34. tme/memory.py +377 -0
  35. tme/orientations.py +78 -21
  36. tme/parser.py +3 -4
  37. tme/preprocessing/_utils.py +61 -32
  38. tme/preprocessing/composable_filter.py +7 -4
  39. tme/preprocessing/compose.py +7 -3
  40. tme/preprocessing/frequency_filters.py +49 -39
  41. tme/preprocessing/tilt_series.py +44 -72
  42. tme/preprocessor.py +560 -526
  43. tme/structure.py +491 -188
  44. tme/types.py +5 -3
  45. pytme-0.2.1.dist-info/METADATA +0 -73
  46. pytme-0.2.1.dist-info/RECORD +0 -73
  47. tme/helpers.py +0 -881
  48. tme/matching_constrained.py +0 -195
  49. {pytme-0.2.1.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  50. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  51. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  52. {pytme-0.2.1.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
tme/backends/__init__.py CHANGED
@@ -4,14 +4,15 @@
4
4
 
5
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
6
  """
7
-
8
7
  from typing import Dict, List
8
+ from importlib.util import find_spec
9
9
 
10
10
  from .matching_backend import MatchingBackend
11
11
  from .npfftw_backend import NumpyFFTWBackend
12
12
  from .pytorch_backend import PytorchBackend
13
13
  from .cupy_backend import CupyBackend
14
14
  from .mlx_backend import MLXBackend
15
+ from .jax_backend import JaxBackend
15
16
 
16
17
 
17
18
  class BackendManager:
@@ -43,24 +44,26 @@ class BackendManager:
43
44
 
44
45
  >>> backend.change_backend("pytorch")
45
46
  >>> backend.multiply(arr1, arr2)
46
- # This will use the GPUBackend's multiply method
47
+ # This will use the PytorchBackend's multiply method
48
+
49
+ >>> backend.available_backends()
50
+ # Backends available on your system
47
51
 
48
52
  Notes
49
53
  -----
50
- To add custom backends, use the `add_backend` method. To switch
51
- between backends, use the `change_backend` method. Note that the backend
52
- has to be reinitialzed when using fork-based parallelism.
54
+ The backend has to be reinitialized when using fork-based parallelism.
53
55
  """
54
56
 
55
57
  def __init__(self):
56
58
  self._BACKEND_REGISTRY = {
57
- "cpu_backend": NumpyFFTWBackend,
59
+ "numpyfftw": NumpyFFTWBackend,
58
60
  "pytorch": PytorchBackend,
59
61
  "cupy": CupyBackend,
60
62
  "mlx": MLXBackend,
63
+ "jax": JaxBackend,
61
64
  }
62
65
  self._backend = NumpyFFTWBackend()
63
- self._backend_name = "cpu_backend"
66
+ self._backend_name = "numpyfftw"
64
67
  self._backend_args = {}
65
68
 
66
69
  def __repr__(self):
@@ -99,7 +102,8 @@ class BackendManager:
99
102
  Raises
100
103
  ------
101
104
  ValueError
102
- If the provided backend_instance does not inherit from MatchingBackend.
105
+ If the provided backend_instance does not inherit from
106
+ :py:class:`MatchingBackend`.
103
107
  """
104
108
  if not issubclass(backend_class, MatchingBackend):
105
109
  raise ValueError("backend_class needs to inherit from MatchingBackend.")
@@ -122,9 +126,7 @@ class BackendManager:
122
126
  If no backend is found with the provided name.
123
127
  """
124
128
  if backend_name not in self._BACKEND_REGISTRY:
125
- available_backends = ", ".join(
126
- [str(x) for x in self._BACKEND_REGISTRY.keys()]
127
- )
129
+ available_backends = ", ".join(self.available_backends())
128
130
  raise NotImplementedError(
129
131
  f"Available backends are {available_backends} - not {backend_name}."
130
132
  )
@@ -132,5 +134,32 @@ class BackendManager:
132
134
  self._backend_name = backend_name
133
135
  self._backend_args = backend_kwargs
134
136
 
137
+ def available_backends(self) -> List[str]:
138
+ """
139
+ Determines importable backends.
140
+
141
+ Returns
142
+ -------
143
+ list of str
144
+ Backends that are available for template matching.
145
+ """
146
+ # This is an approximation but avoids runtime pollution
147
+ _dependencies = {
148
+ "numpyfftw": "numpy",
149
+ "cupy": "cupy",
150
+ "pytorch": "torch",
151
+ "mlx": "mlx",
152
+ "jax": "jax",
153
+ }
154
+ available_backends = []
155
+ for name, backend in self._BACKEND_REGISTRY.items():
156
+ if name not in _dependencies:
157
+ continue
158
+
159
+ if find_spec(_dependencies[name]) is not None:
160
+ available_backends.append(name)
161
+
162
+ return available_backends
163
+
135
164
 
136
165
  backend = BackendManager()
@@ -0,0 +1,187 @@
1
+ """ Utility functions for jax backend.
2
+
3
+ Copyright (c) 2023-2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ from typing import Tuple
8
+ from functools import partial
9
+
10
+ import jax.numpy as jnp
11
+ from jax import pmap, lax
12
+
13
+ from ..types import BackendArray
14
+ from ..backends import backend as be
15
+ from ..matching_utils import normalize_template as _normalize_template
16
+
17
+
18
+ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
19
+ """
20
+ Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
21
+ """
22
+ template_ft = jnp.fft.rfftn(template, s=template.shape)
23
+ template_ft = template_ft.at[:].multiply(ft_target)
24
+ correlation = jnp.fft.irfftn(template_ft, s=template.shape)
25
+ return correlation
26
+
27
+
28
+ def _flc_scoring(
29
+ template: BackendArray,
30
+ template_mask: BackendArray,
31
+ ft_target: BackendArray,
32
+ ft_target2: BackendArray,
33
+ n_observations: BackendArray,
34
+ eps: float,
35
+ **kwargs,
36
+ ) -> BackendArray:
37
+ """
38
+ Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
39
+ """
40
+ correlation = _correlate(template=template, ft_target=ft_target)
41
+ inv_denominator = _reciprocal_target_std(
42
+ ft_target=ft_target,
43
+ ft_target2=ft_target2,
44
+ template_mask=template_mask,
45
+ eps=eps,
46
+ n_observations=n_observations,
47
+ )
48
+ correlation = correlation.at[:].multiply(inv_denominator)
49
+ return correlation
50
+
51
+
52
+ def _flcSphere_scoring(
53
+ template: BackendArray,
54
+ ft_target: BackendArray,
55
+ inv_denominator: BackendArray,
56
+ **kwargs,
57
+ ) -> BackendArray:
58
+ """
59
+ Computes :py:meth:`tme.matching_exhaustive.flc_scoring`.
60
+ """
61
+ correlation = _correlate(template=template, ft_target=ft_target)
62
+ correlation = correlation.at[:].multiply(inv_denominator)
63
+ return correlation
64
+
65
+
66
+ def _reciprocal_target_std(
67
+ ft_target: BackendArray,
68
+ ft_target2: BackendArray,
69
+ template_mask: BackendArray,
70
+ n_observations: float,
71
+ eps: float,
72
+ ) -> BackendArray:
73
+ """
74
+ Computes reciprocal standard deviation of a target given a mask.
75
+
76
+ See Also
77
+ --------
78
+ :py:meth:`tme.matching_exhaustive.flc_scoring`.
79
+ """
80
+ ft_shape = template_mask.shape
81
+ ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
82
+
83
+ # E(X^2)- E(X)^2
84
+ exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=ft_shape)
85
+ exp_sq = exp_sq.at[:].divide(n_observations)
86
+
87
+ ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
88
+ sq_exp = jnp.fft.irfftn(ft_template_mask, s=ft_shape)
89
+ sq_exp = sq_exp.at[:].divide(n_observations)
90
+ sq_exp = sq_exp.at[:].power(2)
91
+
92
+ exp_sq = exp_sq.at[:].add(-sq_exp)
93
+ exp_sq = exp_sq.at[:].max(0)
94
+ exp_sq = exp_sq.at[:].power(0.5)
95
+
96
+ exp_sq = exp_sq.at[:].set(
97
+ jnp.where(exp_sq <= eps, 0, jnp.reciprocal(exp_sq * n_observations))
98
+ )
99
+ return exp_sq
100
+
101
+
102
+ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
103
+ arr_ft = jnp.fft.rfftn(arr, s=arr.shape)
104
+ arr_ft = arr_ft.at[:].multiply(arr_filter)
105
+ return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
106
+
107
+
108
+ def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
109
+ return arr
110
+
111
+
112
+ @partial(
113
+ pmap,
114
+ in_axes=(0,) + (None,) * 6,
115
+ static_broadcasted_argnums=[6, 7],
116
+ )
117
+ def scan(
118
+ target: BackendArray,
119
+ template: BackendArray,
120
+ template_mask: BackendArray,
121
+ rotations: BackendArray,
122
+ template_filter: BackendArray,
123
+ target_filter: BackendArray,
124
+ fast_shape: Tuple[int],
125
+ rotate_mask: bool,
126
+ ) -> Tuple[BackendArray, BackendArray]:
127
+ eps = jnp.finfo(template.dtype).resolution
128
+
129
+ if hasattr(target_filter, "shape"):
130
+ target = _apply_fourier_filter(target, target_filter)
131
+
132
+ ft_target = jnp.fft.rfftn(target, s=fast_shape)
133
+ ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
134
+ inv_denominator, target, scoring_func = None, None, _flc_scoring
135
+ if not rotate_mask:
136
+ n_observations = jnp.sum(template_mask)
137
+ inv_denominator = _reciprocal_target_std(
138
+ ft_target=ft_target,
139
+ ft_target2=ft_target2,
140
+ template_mask=be.topleft_pad(template_mask, fast_shape),
141
+ eps=eps,
142
+ n_observations=n_observations,
143
+ )
144
+ ft_target2, scoring_func = None, _flcSphere_scoring
145
+
146
+ _template_filter_func = _identity
147
+ if template_filter.shape != ():
148
+ _template_filter_func = _apply_fourier_filter
149
+
150
+ def _sample_transform(ret, rotation_matrix):
151
+ max_scores, rotations, index = ret
152
+ template_rot, template_mask_rot = be.rigid_transform(
153
+ arr=template,
154
+ arr_mask=template_mask,
155
+ rotation_matrix=rotation_matrix,
156
+ order=1, # thats all we get for now
157
+ )
158
+
159
+ n_observations = jnp.sum(template_mask_rot)
160
+ template_rot = _template_filter_func(template_rot, template_filter)
161
+ template_rot = _normalize_template(
162
+ template_rot, template_mask_rot, n_observations
163
+ )
164
+ template_rot = be.topleft_pad(template_rot, fast_shape)
165
+ template_mask_rot = be.topleft_pad(template_mask_rot, fast_shape)
166
+
167
+ scores = scoring_func(
168
+ template=template_rot,
169
+ template_mask=template_mask_rot,
170
+ ft_target=ft_target,
171
+ ft_target2=ft_target2,
172
+ inv_denominator=inv_denominator,
173
+ n_observations=n_observations,
174
+ eps=eps,
175
+ )
176
+ max_scores, rotations = be.max_score_over_rotations(
177
+ scores, max_scores, rotations, index
178
+ )
179
+ return (max_scores, rotations, index + 1), None
180
+
181
+ score_space = jnp.zeros(fast_shape)
182
+ rotation_space = jnp.full(shape=fast_shape, dtype=jnp.int32, fill_value=-1)
183
+ (score_space, rotation_space, _), _ = lax.scan(
184
+ _sample_transform, (score_space, rotation_space, 0), rotations
185
+ )
186
+
187
+ return score_space, rotation_space