pytme 0.2.4__cp311-cp311-macosx_14_0_arm64.whl → 0.2.5__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: pytme
3
- Version: 0.2.4
3
+ Version: 0.2.5
4
4
  Summary: Python Template Matching Engine
5
5
  Author: Valentin Maurer
6
6
  Author-email: Valentin Maurer <valentin.maurer@embl-hamburg.de>
@@ -1,8 +1,8 @@
1
- pytme-0.2.4.data/scripts/estimate_ram_usage.py,sha256=R1NDpFajcF-MonJ4a43SfDlA-nxBYwK7D2quzCdsVFM,2767
2
- pytme-0.2.4.data/scripts/match_template.py,sha256=fDxH0yYudh4bWimmum1hCtjasG1EVJp6mKZ8a6zDt0Q,39852
3
- pytme-0.2.4.data/scripts/postprocess.py,sha256=50PwDfOWe2Fdws4J5K-k2SgM55fARlAWCnIsv-l0i-4,24414
4
- pytme-0.2.4.data/scripts/preprocess.py,sha256=A2nQlNr2fvrZ6C89jGsscgWk85KuDQIPKloQGBhExeE,4380
5
- pytme-0.2.4.data/scripts/preprocessor_gui.py,sha256=AHgL8j7nVCH3srsyGYWU7i3mCxeu00H-mR2qObR90GA,39071
1
+ pytme-0.2.5.data/scripts/estimate_ram_usage.py,sha256=R1NDpFajcF-MonJ4a43SfDlA-nxBYwK7D2quzCdsVFM,2767
2
+ pytme-0.2.5.data/scripts/match_template.py,sha256=fDxH0yYudh4bWimmum1hCtjasG1EVJp6mKZ8a6zDt0Q,39852
3
+ pytme-0.2.5.data/scripts/postprocess.py,sha256=50PwDfOWe2Fdws4J5K-k2SgM55fARlAWCnIsv-l0i-4,24414
4
+ pytme-0.2.5.data/scripts/preprocess.py,sha256=A2nQlNr2fvrZ6C89jGsscgWk85KuDQIPKloQGBhExeE,4380
5
+ pytme-0.2.5.data/scripts/preprocessor_gui.py,sha256=AHgL8j7nVCH3srsyGYWU7i3mCxeu00H-mR2qObR90GA,39071
6
6
  scripts/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  scripts/estimate_ram_usage.py,sha256=rN7haobnHg3YcgGJIp81FNiCzy8-saJGeEurQlmQmNQ,2768
8
8
  scripts/eval.py,sha256=ebJVLxbRlB6TI5YHNr0VavZ4lmaRdf8QVafyiDhh_oU,2528
@@ -22,7 +22,7 @@ tests/test_matching_cli.py,sha256=9qLrUM3nuIkY_LaKuzxtTjOqtgC9jUCMZXTWhUxYBGw,93
22
22
  tests/test_matching_data.py,sha256=TyvnSJPzdLPiXYWdz9coQ-m4H1tUS_cbD0WaBdvrevg,6062
23
23
  tests/test_matching_exhaustive.py,sha256=-Xm8jro1YJ3uPcPUCNazWX9Y7CxsTwPeKoy2Vr8TVH8,5531
24
24
  tests/test_matching_memory.py,sha256=jypztjDwTxvixHKteZzcvTzEVyOyJzVNK5JxlzInLcE,1145
25
- tests/test_matching_optimization.py,sha256=nG2R03eWGsZmdLPI0O0yBVKeD2k0C65OMJ7YKV5HLsc,9823
25
+ tests/test_matching_optimization.py,sha256=BOPeI37zK8DWxprXgH5MIDqPWYyirIGoNZaFLvFm2BE,7946
26
26
  tests/test_matching_utils.py,sha256=jdQc8L5RhFFo5T_EeiIrDs6z8G43yAI-iBe5FKr5KZc,11401
27
27
  tests/test_orientations.py,sha256=4ngFOUskcsumaaVCYIxapjdjX0A7KUM6VGkjLV-y9_Y,6776
28
28
  tests/test_packaging.py,sha256=fZhTXqa_1ONSZSVW581QRZaFPw13l0Wpc-p8AgQ4b3E,2783
@@ -57,25 +57,25 @@ tests/preprocessing/test_frequency_filters.py,sha256=zXBRBvuFjiKWiT8yTXWYfYF3kKG
57
57
  tests/preprocessing/test_preprocessor.py,sha256=XyR4xc4YM76PUKuTIiive76Q85DdcyDAvbXNGcoKL8w,4820
58
58
  tests/preprocessing/test_utils.py,sha256=M6rmFl7a3JaBdONvPHhkbkzoDjjOAwrztPTHXqsbN6o,3037
59
59
  tme/__init__.py,sha256=R0cxXFmTvL3p7y6D0zX_rfjChbXNU_-tYI4FTWZ16Ns,177
60
- tme/__version__.py,sha256=SBl2EPFW-ltPvQ7vbVWItyAsz3aKYIpjO7vcfr84GkU,22
60
+ tme/__version__.py,sha256=Xsa3ayOMVkhUWm4t06YeyHE0apjpZefxLH4ylp0CDtU,22
61
61
  tme/analyzer.py,sha256=kYxtzBCp02OJrlMME8Njh-9wIZAvsuhTHNAveNplssY,50466
62
- tme/density.py,sha256=APQgpZ8ILUsGCZTQDfAgcR_d-tm0T5m_AsPWIPr5pcg,84357
62
+ tme/density.py,sha256=SUlLrEeKY7wl-ePhXb6b5yA0HVSEZic-5ICRbaPFPck,84311
63
63
  tme/extensions.cpython-311-darwin.so,sha256=KM0UcTYdq0Gib8rC6yVt7KnXrVvPNcWS-cvuYW2aPso,392496
64
- tme/matching_data.py,sha256=UJnDk5CTDyoytD1CamUSmQR82WvtUUoFqUEadTHTIQ8,25394
65
- tme/matching_exhaustive.py,sha256=g6znkqHcrgPMvKw2wYaHT_N9B_pR6bj1xbtTByJ-QW0,19593
66
- tme/matching_optimization.py,sha256=Y8HfecXiOvAHXM1viBaQ_aXljqqTnGwlOlFe0MJpDRQ,45082
64
+ tme/matching_data.py,sha256=SWcW6qCfO3vImOOlIxxWfyzXkNSdyJox7JKVKHa_PT8,25450
65
+ tme/matching_exhaustive.py,sha256=E_XPl7Z26loDfL9hP8T5t8ID4DxpLUWFN9FhxNbG3fs,19601
66
+ tme/matching_optimization.py,sha256=S7Mmzln-VUbEClcbETxYCVahUWLRhq97hrnVVtH9T_I,45673
67
67
  tme/matching_scores.py,sha256=CECxl2Lh0TMLfZYnoCJXy3euGf8i9J0eHsAD7sqvWGU,30962
68
- tme/matching_utils.py,sha256=C4x4lxJq0_e1R-c0-JkGYM2MoqECAgJY5D1w4qaac5k,40046
68
+ tme/matching_utils.py,sha256=dGAfwA3VMkaOwTG8_36oPxSNg2ihZJyQ42LORKJbUnQ,40044
69
69
  tme/memory.py,sha256=6xeIMAncQkgYDi6w-PIYgFEWRTUPu0_OTCeRO0p9r9Q,11029
70
70
  tme/orientations.py,sha256=KsYXJuLRLYXRHsDjP9_Tn1jXxIVPSaYkw1wRrWH3nUQ,26027
71
71
  tme/parser.py,sha256=fNiCAdsWI4ql92F1Ob4suiVzpjUOBlh2lad1iNY_FP8,13772
72
72
  tme/preprocessor.py,sha256=8UgPuNb0GwZ7JQoBZQisgp0r-wFKwvo0Jxb0u9kb2fg,40412
73
- tme/structure.py,sha256=9cG9I4muinstujpj79ZJgVQEABly8OEt9Uha26FXJLM,65800
73
+ tme/structure.py,sha256=9kuOwLgOXkPLo4JCBFPB5fUAsR5JqOb7uFXGyURuyiM,65849
74
74
  tme/types.py,sha256=NAY7C4qxE6yz-DXVtClMvFfoOV-spWGLNfpLATZ1LcU,442
75
75
  tme/backends/__init__.py,sha256=4S68W2WJNZ9t33QSrRs6aL3OIyEVFo_zVsqXjS1iWYA,5185
76
76
  tme/backends/_jax_utils.py,sha256=YuNJHCYnSqOESMV-9LPr-ZxBg6Zvax2euBjsZM-j-64,5906
77
77
  tme/backends/cupy_backend.py,sha256=1nnCJ4nT7tJsXu1mrJGCy7x0Yg1wWVRg4SdzsQ2qiiw,9284
78
- tme/backends/jax_backend.py,sha256=femPIcppQVPKMyVIowgoOFFWOArvAX16aLvGho8qBNQ,10414
78
+ tme/backends/jax_backend.py,sha256=xGVsdEjvuNYACuAxncJF07yo4-dZOEmeomGEeHek8b4,10002
79
79
  tme/backends/matching_backend.py,sha256=KfCOKD_rA9el3Y7BeH17KJ1apCUIIhhvn-vmbkb3CB0,33750
80
80
  tme/backends/mlx_backend.py,sha256=FJhqmCzgjXAjWGX1HhHFrCy_We4YwQQBkKFNG05ctzM,7788
81
81
  tme/backends/npfftw_backend.py,sha256=JDTc_1QcVi9jU3yLQF7jkgwQz_Po60OhkKuV2V3g5v8,16997
@@ -110,10 +110,10 @@ tme/preprocessing/_utils.py,sha256=1K8xPquM0v1MASwsMpIc3ZWxxpUFt69LezVZY5QcJnY,6
110
110
  tme/preprocessing/composable_filter.py,sha256=zmXN_NcuvvtstFdU6yYQ09z-XJFE4Y-kkMCL4vHy-jc,778
111
111
  tme/preprocessing/compose.py,sha256=NFB6oUQOwn8foy82i3Lm5DeZUd_5dmcKdhuwX8E6wpo,1454
112
112
  tme/preprocessing/frequency_filters.py,sha256=XPG6zRF_VSPH4CWFj1BLICm3_jNrzmiHaln0JZR7CrU,12755
113
- tme/preprocessing/tilt_series.py,sha256=6OptAfqISxzZOtHIx5MdSaJf7VGFeDntz2jWekpZMus,37307
114
- pytme-0.2.4.dist-info/LICENSE,sha256=K1IUNSVAz8BXbpH5EA8y5FpaHdvFXnAF2zeK95Lr2bY,18467
115
- pytme-0.2.4.dist-info/METADATA,sha256=fq2oJEwpG-N_8XYERPQwaBD-IbXySwuJ75BG8NP0WtQ,5278
116
- pytme-0.2.4.dist-info/WHEEL,sha256=uY16WuvBs6SVLr1w0jr9fTUdSkt0n_9cWxlDSGwcm3o,109
117
- pytme-0.2.4.dist-info/entry_points.txt,sha256=ff3LQL3FCWfCYOwFiP9zatm7laUbnwCkuPELkQVyUO4,241
118
- pytme-0.2.4.dist-info/top_level.txt,sha256=ovCUR7UXXouH3zYt_fJLoqr_vtjp1wudFgjVAnztQLE,18
119
- pytme-0.2.4.dist-info/RECORD,,
113
+ tme/preprocessing/tilt_series.py,sha256=cCbUHufhPs_fniqXPm_JkqqeADAhgw3hxB0AX1Vhwz0,37273
114
+ pytme-0.2.5.dist-info/LICENSE,sha256=K1IUNSVAz8BXbpH5EA8y5FpaHdvFXnAF2zeK95Lr2bY,18467
115
+ pytme-0.2.5.dist-info/METADATA,sha256=h5QwF5Hkrt1YakNH9pPvzy3eB_wjbLZQApqykRJPboQ,5278
116
+ pytme-0.2.5.dist-info/WHEEL,sha256=iyHOmzQZtL_L_KJL9-PDLQa8YEVkAYN5-On75avjx-8,109
117
+ pytme-0.2.5.dist-info/entry_points.txt,sha256=ff3LQL3FCWfCYOwFiP9zatm7laUbnwCkuPELkQVyUO4,241
118
+ pytme-0.2.5.dist-info/top_level.txt,sha256=ovCUR7UXXouH3zYt_fJLoqr_vtjp1wudFgjVAnztQLE,18
119
+ pytme-0.2.5.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (73.0.1)
2
+ Generator: setuptools (74.1.2)
3
3
  Root-Is-Purelib: false
4
4
  Tag: cp311-cp311-macosx_14_0_arm64
5
5
 
@@ -101,56 +101,6 @@ class TestMatchDensityToCoordinates:
101
101
  score = instance()
102
102
  assert isinstance(score, float)
103
103
 
104
- def test_map_coordinates_to_array(self):
105
- ret = _MatchCoordinatesToDensity.map_coordinates_to_array(
106
- coordinates=self.coordinates.astype(np.float32),
107
- array_shape=self.target.shape,
108
- array_origin=np.zeros(self.target.ndim),
109
- sampling_rate=np.ones(self.target.ndim),
110
- )
111
- assert len(ret) == 2
112
-
113
- in_vol, in_vol_mask = ret
114
-
115
- assert in_vol_mask is None
116
- assert np.allclose(in_vol.shape, self.coordinates.shape[1])
117
-
118
- def test_map_coordinates_to_array_mask(self):
119
- ret = _MatchCoordinatesToDensity.map_coordinates_to_array(
120
- coordinates=self.coordinates.astype(np.float32),
121
- array_shape=self.target.shape,
122
- array_origin=self.origin,
123
- sampling_rate=self.sampling_rate,
124
- coordinates_mask=self.coordinates.astype(np.float32),
125
- )
126
- assert len(ret) == 2
127
-
128
- in_vol, in_vol_mask = ret
129
- assert np.allclose(in_vol, in_vol_mask)
130
-
131
- def test_array_from_coordinates(self):
132
- ret = _MatchCoordinatesToDensity.array_from_coordinates(
133
- coordinates=self.coordinates,
134
- weights=self.coordinates_weights,
135
- sampling_rate=self.sampling_rate,
136
- )
137
- assert len(ret) == 3
138
- arr, positions, origin = ret
139
- assert arr.ndim == self.coordinates.shape[0]
140
- assert positions.shape == self.coordinates.shape
141
- assert origin.shape == (self.coordinates.shape[0],)
142
-
143
- assert np.allclose(origin, self.coordinates.min(axis=1))
144
-
145
- ret = _MatchCoordinatesToDensity.array_from_coordinates(
146
- coordinates=self.coordinates,
147
- weights=self.coordinates_weights,
148
- sampling_rate=self.sampling_rate,
149
- origin=self.origin,
150
- )
151
- arr, positions, origin = ret
152
- assert np.allclose(origin, self.origin)
153
-
154
104
 
155
105
  class TestMatchCoordinateToCoordinates:
156
106
  def setup_method(self):
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.2.4"
1
+ __version__ = "0.2.5"
@@ -208,13 +208,6 @@ class JaxBackend(NumpyFFTWBackend):
208
208
  # Applying the filter leads to more FFTs
209
209
  fastt_shape = matching_data._template.shape
210
210
  if create_template_filter:
211
- # _, fastt_shape, _, tshift = matching_data._fourier_padding(
212
- # target_shape=self.to_numpy_array(matching_data._template.shape),
213
- # template_shape=self.to_numpy_array(
214
- # [1 for _ in matching_data._template.shape]
215
- # ),
216
- # pad_fourier=False,
217
- # )
218
211
  fastt_shape = matching_data._template.shape
219
212
 
220
213
  ret, template_filter, target_filter = [], 1, 1
@@ -260,8 +253,8 @@ class JaxBackend(NumpyFFTWBackend):
260
253
  base, targets = None, self._array_backend.stack(targets)
261
254
  scores, rotations = scan_inner(
262
255
  targets,
263
- self.topleft_pad(matching_data.template, fastt_shape),
264
- self.topleft_pad(matching_data.template_mask, fastt_shape),
256
+ matching_data.template,
257
+ matching_data.template_mask,
265
258
  matching_data.rotations,
266
259
  template_filter,
267
260
  target_filter,
tme/density.py CHANGED
@@ -697,8 +697,7 @@ class Density:
697
697
  References
698
698
  ----------
699
699
  .. [1] Sorzano, Carlos et al (Mar. 2015). Fast and accurate conversion
700
- of atomic models into electron density maps. AIMS Biophysics
701
- 2, 8–20.
700
+ of atomic models into electron density maps. AIMS Biophysics 2, 8–20.
702
701
 
703
702
  Examples
704
703
  --------
@@ -743,17 +742,16 @@ class Density:
743
742
  >>> filter_by_residues = {"SER", "THR", "CYS", "ASN", "GLN", "TYR"}
744
743
  >>> )
745
744
 
746
- :py:meth:`Density.from_structure` supports a variety of methods to convert
747
- atoms into densities
745
+ In addtion, :py:meth:`Density.from_structure` supports a variety of methods
746
+ to convert atoms into densities, such as Gaussians
748
747
 
749
748
  >>> density = Density.from_structure(
750
749
  >>> filename_or_structure = path_to_structure,
751
750
  >>> weight_type = "gaussian",
752
- >>> weight_type_args={"resolution": "20"}
751
+ >>> weight_type_args={"resolution": 20}
753
752
  >>> )
754
753
 
755
- In addition its possible to use experimentally determined scattering factors
756
- from various sources:
754
+ experimentally determined scattering factors
757
755
 
758
756
  >>> density = Density.from_structure(
759
757
  >>> filename_or_structure = path_to_structure,
@@ -1688,11 +1686,10 @@ class Density:
1688
1686
  Resampling method to use, defaults to `spline`. Availabe options are:
1689
1687
 
1690
1688
  +---------+----------------------------------------------------------+
1691
- | spline | Smooth spline interpolation via :obj:`scipy.ndimage.zoom`|
1689
+ | spline | Spline interpolation using :obj:`scipy.ndimage.zoom` |
1692
1690
  +---------+----------------------------------------------------------+
1693
- | fourier | Frequency preserving Fourier cropping |
1691
+ | fourier | Fourier cropping |
1694
1692
  +---------+----------------------------------------------------------+
1695
-
1696
1693
  order : int, optional
1697
1694
  Order of spline used for interpolation, by default 1. Ignored when
1698
1695
  ``method`` is `fourier`.
tme/matching_data.py CHANGED
@@ -478,7 +478,8 @@ class MatchingData:
478
478
  shape_diff = np.multiply(
479
479
  np.subtract(target_shape, template_shape), 1 - batch_mask
480
480
  )
481
- if np.sum(shape_diff < 0):
481
+ shape_mask = shape_diff < 0
482
+ if np.sum(shape_mask):
482
483
  shape_shift = np.divide(shape_diff, 2)
483
484
  offset = np.mod(shape_diff, 2)
484
485
  if pad_fourier:
@@ -491,8 +492,7 @@ class MatchingData:
491
492
  "Template is larger than target and padding is turned off. Consider "
492
493
  "swapping them or activate padding. Correcting the shift for now."
493
494
  )
494
-
495
- shape_shift = np.add(shape_shift, offset)
495
+ shape_shift = np.multiply(np.add(shape_shift, offset), shape_mask)
496
496
  fourier_shift = np.subtract(fourier_shift, shape_shift).astype(int)
497
497
 
498
498
  fourier_shift = tuple(fourier_shift.astype(int))
@@ -90,12 +90,12 @@ def _setup_template_filter_apply_target_filter(
90
90
  # pad_fourier=False,
91
91
  # )
92
92
  fastt_shape = matching_data._template.shape
93
- matching_data.template = be.reverse(
94
- be.topleft_pad(matching_data.template, fastt_shape)
95
- )
96
- matching_data.template_mask = be.reverse(
97
- be.topleft_pad(matching_data.template_mask, fastt_shape)
98
- )
93
+ # matching_data.template = be.reverse(
94
+ # be.topleft_pad(matching_data.template, fastt_shape)
95
+ # )
96
+ # matching_data.template_mask = be.reverse(
97
+ # be.topleft_pad(matching_data.template_mask, fastt_shape)
98
+ # )
99
99
  matching_data._set_matching_dimension(
100
100
  target_dims=matching_data._target_dims,
101
101
  template_dims=matching_data._template_dims,
@@ -208,7 +208,7 @@ def scan(
208
208
 
209
209
  Examples
210
210
  --------
211
- Schematically, using :py:meth:`scan` is similar to :py:meth:`scan_subsets`,
211
+ Schematically, :py:meth:`scan` is identical to :py:meth:`scan_subsets`,
212
212
  with the distinction that the objects contained in ``matching_data`` are not
213
213
  split and the search is only parallelized over angles.
214
214
  Assuming you have followed the example in :py:meth:`scan_subsets`, :py:meth:`scan`
@@ -1,17 +1,16 @@
1
- """ Implements various methods for non-exhaustive template matching
2
- based on numerical optimization.
1
+ """ Implements methods for non-exhaustive template matching.
3
2
 
4
3
  Copyright (c) 2023 European Molecular Biology Laboratory
5
4
 
6
5
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
6
  """
8
-
9
- from typing import Tuple, Dict
7
+ import warnings
8
+ from typing import Tuple, List, Dict
10
9
  from abc import ABC, abstractmethod
11
10
 
12
11
  import numpy as np
13
12
  from scipy.spatial import KDTree
14
- from scipy.ndimage import laplace, map_coordinates
13
+ from scipy.ndimage import laplace, map_coordinates, sobel
15
14
  from scipy.optimize import (
16
15
  minimize,
17
16
  basinhopping,
@@ -157,15 +156,13 @@ class _MatchDensityToDensity(ABC):
157
156
  if out is None:
158
157
  out = np.zeros_like(arr)
159
158
 
160
- map_coordinates(arr, self.grid_out, order=order, output=out.ravel())
159
+ self._interpolate(arr, self.grid_out, order=order, out=out.ravel())
161
160
 
162
161
  if out_mask is None and arr_mask is not None:
163
162
  out_mask = np.zeros_like(arr_mask)
164
163
 
165
164
  if arr_mask is not None:
166
- map_coordinates(
167
- arr_mask, self.grid_out, order=order, output=out_mask.ravel()
168
- )
165
+ self._interpolate(arr_mask, self.grid_out, order=order, out=out.ravel())
169
166
 
170
167
  match return_type:
171
168
  case 0:
@@ -177,6 +174,12 @@ class _MatchDensityToDensity(ABC):
177
174
  case 3:
178
175
  return out, out_mask
179
176
 
177
+ @staticmethod
178
+ def _interpolate(data, positions, order: int = 1, out=None):
179
+ return map_coordinates(
180
+ data, positions, order=order, mode="constant", output=out
181
+ )
182
+
180
183
  def score_translation(self, x: Tuple[float]) -> float:
181
184
  """
182
185
  Computes the score after a given translation.
@@ -291,6 +294,8 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
291
294
  A d-dimensional mask to be applied to the target.
292
295
  negate_score : bool, optional
293
296
  Whether the final score should be multiplied by negative one. Default is True.
297
+ return_gradient : bool, optional
298
+ Invoking __call_ returns a tuple of score and parameter gradient. Default is False.
294
299
  **kwargs : Dict, optional
295
300
  Keyword arguments propagated to downstream functions.
296
301
  """
@@ -303,40 +308,43 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
303
308
  template_mask_coordinates: NDArray = None,
304
309
  target_mask: NDArray = None,
305
310
  negate_score: bool = True,
311
+ return_gradient: bool = False,
312
+ interpolation_order: int = 1,
306
313
  **kwargs: Dict,
307
314
  ):
308
- self.eps = be.eps(target.dtype)
309
- self.target_density = target
310
- self.target_mask_density = target_mask
315
+ self.target = target.astype(np.float32)
316
+ self.target_mask = None
317
+ if target_mask is not None:
318
+ self.target_mask = target_mask.astype(np.float32)
311
319
 
312
- self.template_weights = template_weights
313
- self.template_coordinates = template_coordinates
314
- self.template_coordinates_rotated = np.copy(self.template_coordinates).astype(
315
- np.float32
320
+ self.eps = be.eps(self.target.dtype)
321
+
322
+ self.target_grad = np.stack(
323
+ [sobel(self.target, axis=i) for i in range(self.target.ndim)]
316
324
  )
317
- if template_mask_coordinates is None:
318
- template_mask_coordinates = template_coordinates.copy()
319
325
 
320
- self.template_mask_coordinates = template_mask_coordinates
321
- self.template_mask_coordinates_rotated = template_mask_coordinates
326
+ self.n_points = template_coordinates.shape[1]
327
+ self.template = template_coordinates.astype(np.float32)
328
+ self.template_rotated = np.zeros_like(self.template)
329
+ self.template_weights = template_weights.astype(np.float32)
330
+ self.template_center = np.mean(self.template, axis=1)[:, None]
331
+
332
+ self.template_mask, self.template_mask_rotated = None, None
322
333
  if template_mask_coordinates is not None:
323
- self.template_mask_coordinates_rotated = np.copy(
324
- self.template_mask_coordinates
325
- ).astype(np.float32)
334
+ self.template_mask = template_mask_coordinates.astype(np.float32)
335
+ self.template_mask_rotated = np.empty_like(self.template_mask)
326
336
 
327
337
  self.denominator = 1
328
338
  self.score_sign = -1 if negate_score else 1
339
+ self.interpolation_order = interpolation_order
329
340
 
330
- self.in_volume, self.in_volume_mask = self.map_coordinates_to_array(
331
- coordinates=self.template_coordinates_rotated,
332
- coordinates_mask=self.template_mask_coordinates_rotated,
333
- array_origin=be.zeros(target.ndim),
334
- array_shape=self.target_density.shape,
335
- sampling_rate=be.full(target.ndim, fill_value=1),
341
+ self._target_values = self._interpolate(
342
+ self.target, self.template, order=self.interpolation_order
336
343
  )
337
344
 
338
- if hasattr(self, "_post_init"):
339
- self._post_init(**kwargs)
345
+ if return_gradient and not hasattr(self, "grad"):
346
+ raise NotImplementedError(f"{type(self)} does not have grad method.")
347
+ self.return_gradient = return_gradient
340
348
 
341
349
  def score(self, x: Tuple[float]):
342
350
  """
@@ -356,119 +364,39 @@ class _MatchCoordinatesToDensity(_MatchDensityToDensity):
356
364
  translation, rotation_matrix = _format_rigid_transform(x)
357
365
 
358
366
  rigid_transform(
359
- coordinates=self.template_coordinates,
360
- coordinates_mask=self.template_mask_coordinates,
367
+ coordinates=self.template,
368
+ coordinates_mask=self.template_mask,
361
369
  rotation_matrix=rotation_matrix,
362
370
  translation=translation,
363
- out=self.template_coordinates_rotated,
364
- out_mask=self.template_mask_coordinates_rotated,
371
+ out=self.template_rotated,
372
+ out_mask=self.template_mask_rotated,
365
373
  use_geometric_center=False,
366
374
  )
367
375
 
368
- self.in_volume, self.in_volume_mask = self.map_coordinates_to_array(
369
- coordinates=self.template_coordinates_rotated,
370
- coordinates_mask=self.template_mask_coordinates_rotated,
371
- array_origin=be.zeros(rotation_matrix.shape[0]),
372
- array_shape=self.target_density.shape,
373
- sampling_rate=be.full(rotation_matrix.shape[0], fill_value=1),
376
+ self._target_values = self._interpolate(
377
+ self.target, self.template_rotated, order=self.interpolation_order
374
378
  )
375
379
 
376
- return self()
380
+ score = self()
381
+ if not self.return_gradient:
382
+ return score
377
383
 
378
- @staticmethod
379
- def array_from_coordinates(
380
- coordinates: NDArray,
381
- weights: NDArray,
382
- sampling_rate: NDArray,
383
- origin: NDArray = None,
384
- shape: NDArray = None,
385
- ) -> Tuple[NDArray, NDArray, NDArray]:
386
- """
387
- Create a volume from coordinates, using given weights and voxel size.
384
+ return score, self.grad()
388
385
 
389
- Parameters
390
- ----------
391
- coordinates : NDArray
392
- An array representing the coordinates [d x N].
393
- weights : NDArray
394
- An array representing the weights for each coordinate [N].
395
- sampling_rate : NDArray
396
- The size of a voxel in the volume.
397
- origin : NDArray, optional
398
- The origin of the volume.
399
- shape : NDArray, optional
400
- The shape of the volume.
386
+ def _interpolate_gradient(self, positions):
387
+ ret = be.zeros(positions.shape, dtype=positions.dtype)
401
388
 
402
- Returns
403
- -------
404
- tuple
405
- Returns the generated volume, positions of coordinates, and origin.
406
- """
407
- if origin is None:
408
- origin = coordinates.min(axis=1)
409
-
410
- positions = np.divide(coordinates - origin[:, None], sampling_rate[:, None])
411
- positions = positions.astype(int)
412
-
413
- if shape is None:
414
- shape = positions.max(axis=1) + 1
389
+ for k in range(self.target_grad.shape[0]):
390
+ ret[k, :] = self._interpolate(
391
+ self.target_grad[k], positions, order=self.interpolation_order
392
+ )
415
393
 
416
- arr = np.zeros(shape, dtype=np.float32)
417
- np.add.at(arr, tuple(positions), weights)
418
- return arr, positions, origin
394
+ return ret
419
395
 
420
396
  @staticmethod
421
- def map_coordinates_to_array(
422
- coordinates: NDArray,
423
- array_shape: NDArray,
424
- array_origin: NDArray,
425
- sampling_rate: NDArray,
426
- coordinates_mask: NDArray = None,
427
- ) -> Tuple[NDArray, NDArray]:
428
- """
429
- Map coordinates to a volume based on given voxel size and origin.
430
-
431
- Parameters
432
- ----------
433
- coordinates : NDArray
434
- An array representing the coordinates to be mapped [d x N].
435
- array_shape : NDArray
436
- The shape of the array to which the coordinates are mapped.
437
- array_origin : NDArray
438
- The origin of the array to which the coordinates are mapped.
439
- sampling_rate : NDArray
440
- The size of a voxel in the array.
441
- coordinates_mask : NDArray, optional
442
- An array representing the mask for the coordinates [d x T].
443
-
444
- Returns
445
- -------
446
- tuple
447
- Returns transformed coordinates, transformed coordinates mask,
448
- mask for in_volume points, and mask for in_volume points in mask.
449
- """
450
- np.divide(
451
- coordinates - array_origin[:, None], sampling_rate[:, None], out=coordinates
452
- )
453
-
454
- in_volume = np.logical_and(
455
- coordinates < np.array(array_shape)[:, None],
456
- coordinates >= 0,
457
- ).min(axis=0)
458
-
459
- in_volume_mask = None
460
- if coordinates_mask is not None:
461
- np.divide(
462
- coordinates_mask - array_origin[:, None],
463
- sampling_rate[:, None],
464
- out=coordinates_mask,
465
- )
466
- in_volume_mask = np.logical_and(
467
- coordinates_mask < np.array(array_shape)[:, None],
468
- coordinates_mask >= 0,
469
- ).min(axis=0)
470
-
471
- return in_volume, in_volume_mask
397
+ def _torques(positions, center, gradients):
398
+ positions_center = (positions - center).T
399
+ return be.cross(positions_center, gradients.T).T
472
400
 
473
401
 
474
402
  class _MatchCoordinatesToCoordinates(_MatchDensityToDensity):
@@ -635,14 +563,43 @@ class CrossCorrelation(_MatchCoordinatesToDensity):
635
563
 
636
564
  def __call__(self) -> float:
637
565
  """Returns the score of the current configuration."""
638
- score = np.dot(
639
- self.target_density[
640
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
641
- ],
642
- self.template_weights[self.in_volume],
566
+ score = be.dot(self._target_values, self.template_weights)
567
+ score /= self.denominator * self.score_sign
568
+ return score
569
+
570
+ def grad(self):
571
+ """
572
+ Calculate the gradient of the cost function w.r.t. translation and rotation.
573
+
574
+ .. math::
575
+
576
+ \\nabla f = -\\frac{1}{N} \\begin{bmatrix}
577
+ \\sum_i w_i \\nabla v(x_i) \\\\
578
+ \\sum_i w_i (r_i \\times \\nabla v(x_i))
579
+ \\end{bmatrix}
580
+
581
+ where :math:`N` is the number of points, :math:`w_i` are weights,
582
+ :math:`x_i` are rotated template positions, and :math:`r_i` are
583
+ positions relative to the template center.
584
+
585
+ Returns
586
+ -------
587
+ np.ndarray
588
+ Negative gradient of the cost function: [dx, dy, dz, dRx, dRy, dRz].
589
+
590
+ """
591
+ grad = self._interpolate_gradient(positions=self.template_rotated)
592
+ torque = self._torques(
593
+ positions=self.template_rotated, gradients=grad, center=self.template_center
643
594
  )
644
- score /= self.denominator
645
- return score * self.score_sign
595
+
596
+ translation_grad = be.sum(grad * self.template_weights, axis=1)
597
+ torque_grad = be.sum(torque * self.template_weights, axis=1)
598
+
599
+ # <u, dv/dx> / <u, r x dv/dx>
600
+ total_grad = be.concatenate([translation_grad, torque_grad])
601
+ total_grad = be.divide(total_grad, self.n_points, out=total_grad)
602
+ return -total_grad
646
603
 
647
604
 
648
605
  class LaplaceCrossCorrelation(CrossCorrelation):
@@ -658,15 +615,18 @@ class LaplaceCrossCorrelation(CrossCorrelation):
658
615
 
659
616
  __doc__ += _MatchCoordinatesToDensity.__doc__
660
617
 
661
- def _post_init(self, **kwargs):
662
- self.target_density = laplace(self.target_density)
618
+ def __init__(self, **kwargs):
619
+ kwargs["target"] = laplace(kwargs["target"])
663
620
 
664
- arr, positions, _ = self.array_from_coordinates(
665
- self.template_coordinates,
666
- self.template_weights,
667
- np.ones(self.template_coordinates.shape[0]),
668
- )
669
- self.template_weights = laplace(arr)[tuple(positions)]
621
+ coordinates = kwargs["template_coordinates"]
622
+ origin = coordinates.min(axis=1)
623
+ positions = (coordinates - origin[:, None]).astype(int)
624
+ shape = positions.max(axis=1) + 1
625
+ arr = np.zeros(shape, dtype=np.float32)
626
+ np.add.at(arr, tuple(positions), kwargs["template_weights"])
627
+
628
+ kwargs["template_weights"] = laplace(arr)[tuple(positions)]
629
+ super().__init__(**kwargs)
670
630
 
671
631
 
672
632
  class NormalizedCrossCorrelation(CrossCorrelation):
@@ -696,24 +656,76 @@ class NormalizedCrossCorrelation(CrossCorrelation):
696
656
  __doc__ += _MatchCoordinatesToDensity.__doc__
697
657
 
698
658
  def __call__(self) -> float:
699
- n_observations = be.sum(self.in_volume_mask)
700
- target_coordinates = be.astype(
701
- self.template_mask_coordinates_rotated[:, self.in_volume_mask], int
659
+ denominator = be.multiply(
660
+ np.linalg.norm(self.template_weights), np.linalg.norm(self._target_values)
702
661
  )
703
- target_weight = self.target_density[tuple(target_coordinates)]
704
- ex2 = be.divide(be.sum(be.square(target_weight)), n_observations)
705
- e2x = be.square(be.divide(be.sum(target_weight), n_observations))
706
-
707
- denominator = be.maximum(be.subtract(ex2, e2x), 0.0)
708
- denominator = be.sqrt(denominator)
709
- denominator = be.multiply(denominator, n_observations)
710
662
 
711
- if denominator <= self.eps:
663
+ if denominator <= 0:
712
664
  return 0.0
713
665
 
714
666
  self.denominator = denominator
715
667
  return super().__call__()
716
668
 
669
+ def grad(self):
670
+ """
671
+ Calculate the normalized gradient of the cost function w.r.t. translation and rotation.
672
+
673
+ .. math::
674
+
675
+ \\nabla f = -\\frac{1}{N|w||v|^3} \\begin{bmatrix}
676
+ (\\sum_i w_i \\nabla v(x_i))|v|^2 - (\\sum_i v(x_i)
677
+ \\nabla v(x_i))(w \\cdot v) \\\\
678
+ (\\sum_i w_i (r_i \\times \\nabla v(x_i)))|v|^2 - (\\sum_i v(x_i)
679
+ (r_i \\times \\nabla v(x_i)))(w \\cdot v)
680
+ \\end{bmatrix}
681
+
682
+ where :math:`N` is the number of points, :math:`w` are weights,
683
+ :math:`v` are target values, :math:`x_i` are rotated template positions,
684
+ and :math:`r_i` are positions relative to the template center.
685
+
686
+ Returns
687
+ -------
688
+ np.ndarray
689
+ Negative normalized gradient: [dx, dy, dz, dRx, dRy, dRz].
690
+
691
+ """
692
+ grad = self._interpolate_gradient(positions=self.template_rotated)
693
+ torque = self._torques(
694
+ positions=self.template_rotated, gradients=grad, center=self.template_center
695
+ )
696
+
697
+ norm = be.multiply(
698
+ be.power(be.sqrt(be.sum(be.square(self._target_values))), 3),
699
+ be.sqrt(be.sum(be.square(self.template_weights))),
700
+ )
701
+
702
+ # (<u,dv/dx> * |v|**2 - <u,v> * <v,dv/dx>)/(|w|*|v|**3)
703
+ translation_grad = be.multiply(
704
+ be.sum(be.multiply(grad, self.template_weights), axis=1),
705
+ be.sum(be.square(self._target_values)),
706
+ )
707
+ translation_grad -= be.multiply(
708
+ be.sum(be.multiply(grad, self._target_values), axis=1),
709
+ be.sum(be.multiply(self._target_values, self.template_weights)),
710
+ )
711
+
712
+ # (<u,r x dv/dx> * |v|**2 - <u,v> * <v,r x dv/dx>)/(|w|*|v|**3)
713
+ torque_grad = be.multiply(
714
+ be.sum(be.multiply(torque, self.template_weights), axis=1),
715
+ be.sum(be.square(self._target_values)),
716
+ )
717
+ torque_grad -= be.multiply(
718
+ be.sum(be.multiply(torque, self._target_values), axis=1),
719
+ be.sum(be.multiply(self._target_values, self.template_weights)),
720
+ )
721
+
722
+ total_grad = be.concatenate([translation_grad, torque_grad])
723
+ if norm > 0:
724
+ total_grad = be.divide(total_grad, norm, out=total_grad)
725
+
726
+ total_grad = be.divide(total_grad, self.n_points, out=total_grad)
727
+ return -total_grad
728
+
717
729
 
718
730
  class NormalizedCrossCorrelationMean(NormalizedCrossCorrelation):
719
731
  """
@@ -802,33 +814,33 @@ class MaskedCrossCorrelation(_MatchCoordinatesToDensity):
802
814
 
803
815
  def __call__(self) -> float:
804
816
  """Returns the score of the current configuration."""
817
+
818
+ in_volume = np.logical_and(
819
+ self.template_rotated < np.array(self.target.shape)[:, None],
820
+ self.template_rotated >= 0,
821
+ ).min(axis=0)
822
+ in_volume_mask = np.logical_and(
823
+ self.template_mask_rotated < np.array(self.target.shape)[:, None],
824
+ self.template_mask_rotated >= 0,
825
+ ).min(axis=0)
826
+
805
827
  mask_overlap = np.sum(
806
- self.target_mask_density[
807
- tuple(
808
- self.template_mask_coordinates_rotated[
809
- :, self.in_volume_mask
810
- ].astype(int)
811
- )
828
+ self.target_mask[
829
+ tuple(self.template_mask_rotated[:, in_volume_mask].astype(int))
812
830
  ],
813
831
  )
814
832
  mask_overlap = np.fmax(mask_overlap, np.finfo(float).eps)
815
833
 
816
- mask_target = self.target_density[
817
- tuple(
818
- self.template_mask_coordinates_rotated[:, self.in_volume_mask].astype(
819
- int
820
- )
821
- )
834
+ mask_target = self.target[
835
+ tuple(self.template_mask_rotated[:, in_volume_mask].astype(int))
822
836
  ]
823
837
  denominator1 = np.subtract(
824
838
  np.sum(mask_target**2),
825
839
  np.divide(np.square(np.sum(mask_target)), mask_overlap),
826
840
  )
827
841
  mask_template = np.multiply(
828
- self.template_weights[self.in_volume],
829
- self.target_mask_density[
830
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
831
- ],
842
+ self.template_weights[in_volume],
843
+ self.target_mask[tuple(self.template_rotated[:, in_volume].astype(int))],
832
844
  )
833
845
  denominator2 = np.subtract(
834
846
  np.sum(mask_template**2),
@@ -840,10 +852,8 @@ class MaskedCrossCorrelation(_MatchCoordinatesToDensity):
840
852
  denominator = np.sqrt(np.multiply(denominator1, denominator2))
841
853
 
842
854
  numerator = np.dot(
843
- self.target_density[
844
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
845
- ],
846
- self.template_weights[self.in_volume],
855
+ self.target[tuple(self.template_rotated[:, in_volume].astype(int))],
856
+ self.template_weights[in_volume],
847
857
  )
848
858
 
849
859
  numerator -= np.divide(
@@ -877,21 +887,9 @@ class PartialLeastSquareDifference(_MatchCoordinatesToDensity):
877
887
 
878
888
  def __call__(self) -> float:
879
889
  """Returns the score of the current configuration."""
880
- score = np.sum(
881
- np.square(
882
- np.subtract(
883
- self.target_density[
884
- tuple(
885
- self.template_coordinates_rotated[:, self.in_volume].astype(
886
- int
887
- )
888
- )
889
- ],
890
- self.template_weights[self.in_volume],
891
- )
892
- )
890
+ score = be.sum(
891
+ be.square(be.subtract(self._target_values, self.template_weights))
893
892
  )
894
- score += np.sum(np.square(self.template_weights[np.invert(self.in_volume)]))
895
893
  return score * self.score_sign
896
894
 
897
895
 
@@ -917,10 +915,7 @@ class MutualInformation(_MatchCoordinatesToDensity):
917
915
  def __call__(self) -> float:
918
916
  """Returns the score of the current configuration."""
919
917
  p_xy, target, template = np.histogram2d(
920
- self.target_density[
921
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
922
- ],
923
- self.template_weights[self.in_volume],
918
+ self._target_values, self.template_weights
924
919
  )
925
920
  p_x, p_y = np.sum(p_xy, axis=1), np.sum(p_xy, axis=0)
926
921
 
@@ -947,7 +942,7 @@ class Envelope(_MatchCoordinatesToDensity):
947
942
  References
948
943
  ----------
949
944
  .. [1] Daven Vasishtan and Maya Topf, "Scoring functions for cryoEM density
950
- fitting", Journal of Structural Biology, vol. 174, no. 2,
945
+ fitting", Journal of Structural Biology, vol. 174, no. 2,
951
946
  pp. 333--343, 2011. DOI: https://doi.org/10.1016/j.jsb.2011.01.012
952
947
  """
953
948
 
@@ -956,22 +951,22 @@ class Envelope(_MatchCoordinatesToDensity):
956
951
  def __init__(self, target_threshold: float = None, **kwargs):
957
952
  super().__init__(**kwargs)
958
953
  if target_threshold is None:
959
- target_threshold = np.mean(self.target_density)
960
- self.target_density = np.where(self.target_density > target_threshold, -1, 1)
961
- self.target_density_present = np.sum(self.target_density == -1)
962
- self.target_density_absent = np.sum(self.target_density == 1)
954
+ target_threshold = np.mean(self.target)
955
+ self.target = np.where(self.target > target_threshold, -1, 1)
956
+ self.target_present = np.sum(self.target == -1)
957
+ self.target_absent = np.sum(self.target == 1)
963
958
  self.template_weights = np.ones_like(self.template_weights)
964
959
 
965
960
  def __call__(self) -> float:
966
961
  """Returns the score of the current configuration."""
967
- score = self.target_density[
968
- tuple(self.template_coordinates_rotated[:, self.in_volume].astype(int))
969
- ]
970
- unassigned_density = self.target_density_present - (score == -1).sum()
962
+ score = self._target_values
963
+ unassigned_density = self.target_present - (score == -1).sum()
971
964
 
972
- score = score.sum() - unassigned_density - 2 * np.sum(np.invert(self.in_volume))
973
- min_score = -self.target_density_present - 2 * self.target_density_absent
974
- score = (score - 2 * min_score) / (2 * self.target_density_present - min_score)
965
+ # Out of volume values will be set to 0
966
+ score = score.sum() - unassigned_density
967
+ score -= 2 * np.sum(np.invert(np.abs(self._target_values) > 0))
968
+ min_score = -self.target_present - 2 * self.target_absent
969
+ score = (score - 2 * min_score) / (2 * self.target_present - min_score)
975
970
 
976
971
  return score * self.score_sign
977
972
 
@@ -1271,7 +1266,15 @@ def optimize_match(
1271
1266
 
1272
1267
  x0 = np.zeros(2 * ndim) if x0 is None else x0
1273
1268
 
1269
+ return_gradient = getattr(score_object, "return_gradient", False)
1270
+ if optimization_method != "minimize" and return_gradient:
1271
+ warnings.warn("Gradient only considered for optimization_method='minimize'.")
1272
+ score_object.return_gradient = False
1273
+
1274
1274
  initial_score = score_object.score(x=x0)
1275
+ if isinstance(initial_score, (List, Tuple)):
1276
+ initial_score = initial_score[0]
1277
+
1275
1278
  if optimization_method == "basinhopping":
1276
1279
  result = basinhopping(
1277
1280
  x0=x0,
@@ -1287,10 +1290,14 @@ def optimize_match(
1287
1290
  maxiter=maxiter,
1288
1291
  )
1289
1292
  elif optimization_method == "minimize":
1290
- print(maxiter)
1293
+ if hasattr(score_object, "grad") and not return_gradient:
1294
+ warnings.warn(
1295
+ "Consider initializing score object with return_gradient=True."
1296
+ )
1291
1297
  result = minimize(
1292
1298
  x0=x0,
1293
1299
  fun=score_object.score,
1300
+ jac=return_gradient,
1294
1301
  bounds=bounds,
1295
1302
  constraints=linear_constraint,
1296
1303
  options={"maxiter": maxiter},
tme/matching_utils.py CHANGED
@@ -645,7 +645,7 @@ def get_rotation_matrices(
645
645
  dets = np.linalg.det(ret)
646
646
  neg_dets = dets < 0
647
647
  ret[neg_dets, :, -1] *= -1
648
- ret[0] = np.eye(dim, dtype = ret.dtype)
648
+ ret[0] = np.eye(dim, dtype=ret.dtype)
649
649
  return ret
650
650
 
651
651
 
@@ -91,9 +91,9 @@ def create_reconstruction_filter(
91
91
  if tilt_angles is False:
92
92
  raise ValueError("'ramp' filter requires specifying tilt angles.")
93
93
  size = filter_shape[0]
94
- ret = fftfreqn((size,), sampling_rate = 1, compute_euclidean_norm = True)
94
+ ret = fftfreqn((size,), sampling_rate=1, compute_euclidean_norm=True)
95
95
  min_increment = np.radians(np.min(np.abs(np.diff(np.sort(tilt_angles)))))
96
- ret *= (min_increment * size)
96
+ ret *= min_increment * size
97
97
  np.fmin(ret, 1, out=ret)
98
98
 
99
99
  ret = np.tile(ret[:, np.newaxis], (1, filter_shape[1]))
@@ -193,7 +193,7 @@ class ReconstructFromTilt:
193
193
  volume_temp_rotated = be.zeros(shape, dtype=be._float_dtype)
194
194
  volume = be.zeros(shape, dtype=be._float_dtype)
195
195
 
196
- slices = tuple(slice(a//2, (a//2) + 1) for a in shape)
196
+ slices = tuple(slice(a // 2, (a // 2) + 1) for a in shape)
197
197
  subset = tuple(
198
198
  slice(None) if i != opening_axis else slices[opening_axis]
199
199
  for i in range(len(shape))
@@ -423,12 +423,9 @@ class Wedge:
423
423
 
424
424
  return wedges
425
425
 
426
- def weight_relion(self,
427
- shape: Tuple[int],
428
- opening_axis: int,
429
- tilt_axis: int,
430
- **kwargs
431
- ) -> NDArray:
426
+ def weight_relion(
427
+ self, shape: Tuple[int], opening_axis: int, tilt_axis: int, **kwargs
428
+ ) -> NDArray:
432
429
  """
433
430
  Generate weighted wedges based on the RELION 1.4 formalism, weighting each
434
431
  angle using the cosine of the angle and a Gaussian lowpass filter computed
@@ -545,7 +542,7 @@ class WedgeReconstructed:
545
542
  angles: Tuple[float] = None,
546
543
  opening_axis: int = 0,
547
544
  tilt_axis: int = 2,
548
- weights : Tuple[float] = None,
545
+ weights: Tuple[float] = None,
549
546
  weight_wedge: bool = False,
550
547
  create_continuous_wedge: bool = False,
551
548
  frequency_cutoff: float = 0.5,
@@ -675,7 +672,7 @@ class WedgeReconstructed:
675
672
  opening_axis: int,
676
673
  tilt_axis: int,
677
674
  weights: Tuple[float] = None,
678
- reconstruction_filter : str = None,
675
+ reconstruction_filter: str = None,
679
676
  **kwargs: Dict,
680
677
  ) -> NDArray:
681
678
  """
@@ -715,7 +712,7 @@ class WedgeReconstructed:
715
712
  weights = np.repeat(weights, angles.size // weights.size)
716
713
  plane = np.zeros(
717
714
  (shape[opening_axis], shape[tilt_axis] + (1 - shape[tilt_axis] % 2)),
718
- dtype=np.float32
715
+ dtype=np.float32,
719
716
  )
720
717
 
721
718
  # plane = np.zeros((shape[opening_axis], int(2 * np.max(shape)) + 1), dtype=np.float32)
@@ -723,7 +720,7 @@ class WedgeReconstructed:
723
720
  rec_filter = 1
724
721
  if reconstruction_filter is not None:
725
722
  rec_filter = create_reconstruction_filter(
726
- plane.shape[::-1], filter_type = reconstruction_filter, tilt_angles = angles
723
+ plane.shape[::-1], filter_type=reconstruction_filter, tilt_angles=angles
727
724
  ).T
728
725
 
729
726
  subset = tuple(
tme/structure.py CHANGED
@@ -1117,16 +1117,16 @@ class Structure:
1117
1117
  ----------
1118
1118
  positions : NDArray
1119
1119
  Array containing atomic positions in z,y,x format (n,d).
1120
- weights : [float]
1120
+ weights : tuple of float
1121
1121
  The weights to use for the entries in positions.
1122
- resolution : float
1122
+ resolution : float, optional
1123
1123
  The product of resolution and sigma_factor gives the sigma used to
1124
1124
  compute the discretized Gaussian.
1125
- sigma_factor : float
1125
+ sigma_factor : float, optional
1126
1126
  The factor used with resolution to compute sigma. Default is 1 / (π√2).
1127
- cutoff_value : float
1127
+ cutoff_value : float, optional
1128
1128
  The cutoff value for the Gaussian kernel. Default is 4.0.
1129
- sampling_rate : float
1129
+ sampling_rate : float, optional
1130
1130
  Sampling rate along each dimension. One third of resolution by default.
1131
1131
 
1132
1132
  References
@@ -1160,8 +1160,8 @@ class Structure:
1160
1160
  positions = positions[:, ::-1]
1161
1161
  origin = positions.min(axis=0) - pad * sampling_rate
1162
1162
  positions = np.rint(np.divide((positions - origin), sampling_rate)).astype(int)
1163
- shape = positions.max(axis=0).astype(int) + pad + 1
1164
1163
 
1164
+ shape = positions.max(axis=0).astype(int) + pad + 1
1165
1165
  out = np.zeros(shape, dtype=np.float32)
1166
1166
  np.add.at(out, tuple(positions.T), weights)
1167
1167
 
@@ -1299,10 +1299,10 @@ class Structure:
1299
1299
  )
1300
1300
 
1301
1301
  temp = self.subset_by_chain(chain=chain)
1302
- positions, atoms, shape, sampling_rate, origin = temp._coordinate_to_position(
1302
+ positions, atoms, _shape, sampling_rate, origin = temp._coordinate_to_position(
1303
1303
  shape=shape, sampling_rate=sampling_rate, origin=origin
1304
1304
  )
1305
- volume = np.zeros(shape, dtype=np.float32)
1305
+ volume = np.zeros(_shape, dtype=np.float32)
1306
1306
  if weight_type in ("atomic_weight", "atomic_number"):
1307
1307
  weights = temp._get_atom_weights(atoms=atoms, weight_type=weight_type)
1308
1308
  np.add.at(volume, tuple(positions.T), weights)
File without changes