pytme-0.2.9-cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. pytme-0.2.9.data/scripts/estimate_ram_usage.py +97 -0
  2. pytme-0.2.9.data/scripts/match_template.py +1135 -0
  3. pytme-0.2.9.data/scripts/postprocess.py +622 -0
  4. pytme-0.2.9.data/scripts/preprocess.py +209 -0
  5. pytme-0.2.9.data/scripts/preprocessor_gui.py +1227 -0
  6. pytme-0.2.9.dist-info/METADATA +95 -0
  7. pytme-0.2.9.dist-info/RECORD +119 -0
  8. pytme-0.2.9.dist-info/WHEEL +5 -0
  9. pytme-0.2.9.dist-info/entry_points.txt +6 -0
  10. pytme-0.2.9.dist-info/licenses/LICENSE +153 -0
  11. pytme-0.2.9.dist-info/top_level.txt +3 -0
  12. scripts/__init__.py +0 -0
  13. scripts/estimate_ram_usage.py +97 -0
  14. scripts/match_template.py +1135 -0
  15. scripts/postprocess.py +622 -0
  16. scripts/preprocess.py +209 -0
  17. scripts/preprocessor_gui.py +1227 -0
  18. tests/__init__.py +0 -0
  19. tests/data/Blurring/blob_width18.npy +0 -0
  20. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  21. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  22. tests/data/Blurring/hamming_width6.npy +0 -0
  23. tests/data/Blurring/kaiserb_width18.npy +0 -0
  24. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  25. tests/data/Blurring/mean_size5.npy +0 -0
  26. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  27. tests/data/Blurring/rank_rank3.npy +0 -0
  28. tests/data/Maps/.DS_Store +0 -0
  29. tests/data/Maps/emd_8621.mrc.gz +0 -0
  30. tests/data/README.md +2 -0
  31. tests/data/Raw/em_map.map +0 -0
  32. tests/data/Structures/.DS_Store +0 -0
  33. tests/data/Structures/1pdj.cif +3339 -0
  34. tests/data/Structures/1pdj.pdb +1429 -0
  35. tests/data/Structures/5khe.cif +3685 -0
  36. tests/data/Structures/5khe.ent +2210 -0
  37. tests/data/Structures/5khe.pdb +2210 -0
  38. tests/data/Structures/5uz4.cif +70548 -0
  39. tests/preprocessing/__init__.py +0 -0
  40. tests/preprocessing/test_compose.py +76 -0
  41. tests/preprocessing/test_frequency_filters.py +178 -0
  42. tests/preprocessing/test_preprocessor.py +136 -0
  43. tests/preprocessing/test_utils.py +79 -0
  44. tests/test_analyzer.py +216 -0
  45. tests/test_backends.py +446 -0
  46. tests/test_density.py +503 -0
  47. tests/test_extensions.py +130 -0
  48. tests/test_matching_cli.py +283 -0
  49. tests/test_matching_data.py +162 -0
  50. tests/test_matching_exhaustive.py +124 -0
  51. tests/test_matching_memory.py +30 -0
  52. tests/test_matching_optimization.py +226 -0
  53. tests/test_matching_utils.py +189 -0
  54. tests/test_orientations.py +175 -0
  55. tests/test_parser.py +33 -0
  56. tests/test_rotations.py +153 -0
  57. tests/test_structure.py +247 -0
  58. tme/__init__.py +6 -0
  59. tme/__version__.py +1 -0
  60. tme/analyzer/__init__.py +2 -0
  61. tme/analyzer/_utils.py +186 -0
  62. tme/analyzer/aggregation.py +577 -0
  63. tme/analyzer/peaks.py +953 -0
  64. tme/backends/__init__.py +171 -0
  65. tme/backends/_cupy_utils.py +734 -0
  66. tme/backends/_jax_utils.py +188 -0
  67. tme/backends/cupy_backend.py +294 -0
  68. tme/backends/jax_backend.py +314 -0
  69. tme/backends/matching_backend.py +1270 -0
  70. tme/backends/mlx_backend.py +241 -0
  71. tme/backends/npfftw_backend.py +583 -0
  72. tme/backends/pytorch_backend.py +430 -0
  73. tme/data/__init__.py +0 -0
  74. tme/data/c48n309.npy +0 -0
  75. tme/data/c48n527.npy +0 -0
  76. tme/data/c48n9.npy +0 -0
  77. tme/data/c48u1.npy +0 -0
  78. tme/data/c48u1153.npy +0 -0
  79. tme/data/c48u1201.npy +0 -0
  80. tme/data/c48u1641.npy +0 -0
  81. tme/data/c48u181.npy +0 -0
  82. tme/data/c48u2219.npy +0 -0
  83. tme/data/c48u27.npy +0 -0
  84. tme/data/c48u2947.npy +0 -0
  85. tme/data/c48u3733.npy +0 -0
  86. tme/data/c48u4749.npy +0 -0
  87. tme/data/c48u5879.npy +0 -0
  88. tme/data/c48u7111.npy +0 -0
  89. tme/data/c48u815.npy +0 -0
  90. tme/data/c48u83.npy +0 -0
  91. tme/data/c48u8649.npy +0 -0
  92. tme/data/c600v.npy +0 -0
  93. tme/data/c600vc.npy +0 -0
  94. tme/data/metadata.yaml +80 -0
  95. tme/data/quat_to_numpy.py +42 -0
  96. tme/data/scattering_factors.pickle +0 -0
  97. tme/density.py +2263 -0
  98. tme/extensions.cpython-311-darwin.so +0 -0
  99. tme/external/bindings.cpp +332 -0
  100. tme/filters/__init__.py +6 -0
  101. tme/filters/_utils.py +311 -0
  102. tme/filters/bandpass.py +230 -0
  103. tme/filters/compose.py +81 -0
  104. tme/filters/ctf.py +393 -0
  105. tme/filters/reconstruction.py +160 -0
  106. tme/filters/wedge.py +542 -0
  107. tme/filters/whitening.py +191 -0
  108. tme/matching_data.py +863 -0
  109. tme/matching_exhaustive.py +497 -0
  110. tme/matching_optimization.py +1311 -0
  111. tme/matching_scores.py +1183 -0
  112. tme/matching_utils.py +1188 -0
  113. tme/memory.py +337 -0
  114. tme/orientations.py +598 -0
  115. tme/parser.py +685 -0
  116. tme/preprocessor.py +1329 -0
  117. tme/rotations.py +350 -0
  118. tme/structure.py +1864 -0
  119. tme/types.py +13 -0
tests/test_density.py ADDED
@@ -0,0 +1,503 @@
1
from itertools import permutations
from os import close, remove
from tempfile import mkstemp

import numpy as np
import pytest
from importlib_resources import files

from tme import Density, Preprocessor, Structure
from tme.matching_utils import create_mask
from tme.rotations import euler_to_rotationmatrix
12
+
13
# Shared test density: an ellipsoidal mask (radii 10, 5, 10, centred at voxel
# (20, 20, 20) in a 50^3 box), scaled by 10, Gaussian-smoothed with sigma=2
# and stored as float32.
DEFAULT_DATA = (
    Preprocessor()
    .gaussian_filter(
        create_mask(
            mask_type="ellipse",
            center=(20, 20, 20),
            radius=(10, 5, 10),
            shape=(50, 50, 50),
        )
        * 10,
        sigma=2,
    )
    .astype(np.float32)
)
# Integer origin at zero and unit sampling rate for all three axes.
DEFAULT_ORIGIN = np.zeros(3, dtype=int)
DEFAULT_SAMPLING_RATE = np.ones(3, dtype=int)

# Root of the packaged test fixtures (maps, structures, ...).
BASEPATH = files("tests.data")
25
+
26
+
27
class TestDensity:
    """Tests for :class:`tme.Density`.

    Covers construction and validation, file I/O round trips, box
    manipulation, resampling and rigid transforms, surface utilities and
    density/structure alignment. ``setup_method`` builds a fresh density from
    the module-level ``DEFAULT_*`` constants and reserves a temporary file
    path that ``teardown_method`` removes.
    """

    @staticmethod
    def _make_temp_path(suffix: str = "") -> str:
        """Reserve a temporary file path and close the descriptor mkstemp opens.

        ``tempfile.mkstemp`` returns an already open OS-level file descriptor;
        discarding it without closing leaks one descriptor per call.
        """
        fd, path = mkstemp(suffix=suffix)
        close(fd)
        return path

    @staticmethod
    def _remove_quietly(path: str) -> None:
        """Best-effort cleanup of a temporary file created by a test."""
        try:
            remove(path)
        except FileNotFoundError:
            pass

    def setup_method(self):
        self.density = Density(
            data=DEFAULT_DATA,
            origin=DEFAULT_ORIGIN,
            sampling_rate=DEFAULT_SAMPLING_RATE,
            metadata={
                "min": DEFAULT_DATA.min(),
                "max": DEFAULT_DATA.max(),
                "mean": DEFAULT_DATA.mean(),
                "std": DEFAULT_DATA.std(),
            },
        )
        self.path = self._make_temp_path()
        self.structure_path = str(BASEPATH.joinpath("Structures/5khe.cif"))

    def teardown_method(self):
        del self.density
        remove(self.path)

    def test_initialization(self):
        data = DEFAULT_DATA
        origin = DEFAULT_ORIGIN
        sampling_rate = DEFAULT_SAMPLING_RATE
        metadata = {"test_key": "test_value"}

        density = Density(data, origin, sampling_rate, metadata)

        assert np.array_equal(density.data, data)
        assert np.array_equal(density.origin, origin)
        assert np.array_equal(density.sampling_rate, sampling_rate)
        assert density.metadata == metadata

    @pytest.mark.parametrize(
        "data,origin,sampling_rate,metadata",
        [
            (np.random.rand(50, 50, 50), (0, 0, 0), (1, 2), {}),
            (np.random.rand(50, 50, 50), (0, 0, 0), (1, 2, 3, 4), {}),
            (np.random.rand(50, 50, 50), (0, 0, 0), (1, 2, 3), "not_a_dict"),
            (np.random.rand(50, 50, 50), (0, 0), (1, 2, 3), "not_a_dict"),
        ],
    )
    def test_initialization_errors(self, data, origin, sampling_rate, metadata):
        with pytest.raises(ValueError):
            Density(data, origin, sampling_rate, metadata)

    def test_repr(self):
        data = DEFAULT_DATA
        origin = DEFAULT_ORIGIN
        sampling_rate = DEFAULT_SAMPLING_RATE

        density = Density(data, origin, sampling_rate)
        repr_str = density.__repr__()

        response = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
        response = response.format(
            hex(id(density)),
            tuple(round(float(x), 3) for x in density.origin),
            tuple(round(float(x), 3) for x in density.sampling_rate),
            density.shape,
        )
        assert response == repr_str

    @pytest.mark.parametrize("gzip", [False, True])
    def test_to_file(self, gzip: bool):
        # Passes when writing does not raise; test_from_file checks the
        # round trip of the written content.
        self.density.to_file(self.path, gzip=gzip)

    def test_from_file(self):
        self.test_to_file(gzip=False)
        density = Density.from_file(self.path)
        assert np.allclose(density.data, self.density.data)
        assert np.allclose(density.sampling_rate, self.density.sampling_rate)
        assert np.allclose(density.origin, self.density.origin)
        assert density.metadata == self.density.metadata

    def test_from_file_baseline(self):
        density = Density.from_file(str(BASEPATH.joinpath("Maps/emd_8621.mrc.gz")))
        assert np.allclose(density.origin, (4.35, 2.90, -1.45), rtol=0.1)
        assert np.allclose(density.sampling_rate, (1.45), rtol=0.3)

    @pytest.mark.parametrize("extension", ("mrc", "em", "tiff", "h5"))
    @pytest.mark.parametrize("gzip", (True, False))
    @pytest.mark.parametrize("use_memmap", (True, False))
    @pytest.mark.parametrize("subset", (True, False))
    def test_file_format_io(self, extension, gzip, subset, use_memmap):
        base = Density(
            data=np.random.rand(50, 50, 50), origin=(0, 0, 0), sampling_rate=(1, 1, 1)
        )
        data_subset = (slice(0, 22), slice(31, 46), slice(12, 25))
        if extension not in ("mrc", "em"):
            # tiff and h5 round trips are exercised with 2D data.
            base = Density(
                data=np.random.rand(50, 50), origin=(0, 0), sampling_rate=(1, 1)
            )
            data_subset = (slice(0, 22), slice(31, 46))
        if gzip:
            # Memory mapping is not requested for gzip-compressed files.
            use_memmap = False

        if not subset:
            data_subset = tuple(slice(0, x) for x in base.shape)

        suffix = f".{extension}.gz" if gzip else f".{extension}"
        output_file = self._make_temp_path(suffix=suffix)
        try:
            base.to_file(output_file, gzip=gzip)
            temp = Density.from_file(
                output_file, use_memmap=use_memmap, subset=data_subset
            )
            assert np.allclose(base.data[data_subset], temp.data)
            if extension.upper() == "MRC":
                assert np.allclose(base.origin, temp.origin)
                assert np.allclose(base.sampling_rate, temp.sampling_rate)
        finally:
            self._remove_quietly(output_file)

    def test__read_binary_subset_error(self):
        base = Density(
            data=np.random.rand(50, 50, 50), origin=(0, 0, 0), sampling_rate=(1, 1, 1)
        )
        output_file = self._make_temp_path()
        try:
            base.to_file(output_file)
            with pytest.raises((ValueError, OSError)):
                # A subset with fewer slices than data dimensions must raise.
                Density.from_file(output_file, subset=(slice(0, 10),))
                # NOTE(review): pytest.raises stops at the first exception, so
                # the two calls below never run once the call above raises.
                # Give each its own pytest.raises block once it is confirmed
                # that negative and out-of-bounds subsets are rejected.
                Density.from_file(
                    output_file, subset=(slice(-1, 10), slice(5, 10), slice(5, 10))
                )
                Density.from_file(
                    output_file, subset=(slice(20, 100), slice(5, 10), slice(5, 10))
                )
        finally:
            self._remove_quietly(output_file)

    def test_from_structure(self):
        _ = Density.from_structure(
            self.structure_path,
            origin=self.density.origin,
            shape=self.density.shape,
            sampling_rate=self.density.sampling_rate,
        )
        _ = Density.from_structure(self.structure_path, shape=(30, 30, 30))
        _ = Density.from_structure(
            self.structure_path,
            shape=(30, 30, 30),
            origin=self.density.origin,
        )
        _ = Density.from_structure(self.structure_path, origin=self.density.origin)
        _ = Density.from_structure(
            self.structure_path,
            shape=(30, 30, 30),
            sampling_rate=6,
            origin=self.density.origin,
        )
        _ = Density.from_structure(
            self.structure_path,
            shape=(30, 30, 30),
            sampling_rate=None,
            origin=self.density.origin,
        )
        _ = Density.from_structure(
            self.structure_path, weight_type="atomic_weight", chain="A"
        )

    @pytest.mark.parametrize(
        "weight_type",
        [
            "atomic_weight",
            "atomic_number",
            "van_der_waals_radius",
            "scattering_factors",
            "lowpass_scattering_factors",
            "gaussian",
        ],
    )
    def test_from_structure_weight_types(self, weight_type):
        _ = Density.from_structure(
            self.structure_path,
            weight_type=weight_type,
        )

    def test_from_structure_weight_types_error(self):
        with pytest.raises(NotImplementedError):
            _ = Density.from_structure(
                self.structure_path,
                weight_type=None,
            )

    @pytest.mark.parametrize(
        "weight_type", ["scattering_factors", "lowpass_scattering_factors"]
    )
    @pytest.mark.parametrize(
        "scattering_factors", ["dt1969", "wk1995", "peng1995", "peng1999"]
    )
    def test_from_structure_scattering(self, scattering_factors, weight_type):
        _ = Density.from_structure(
            self.structure_path,
            weight_type=weight_type,
            weight_type_args={"source": scattering_factors},
        )

    def test_from_structure_error(self):
        with pytest.raises(NotImplementedError):
            _ = Density.from_structure(self.structure_path, weight_type="RAISERROR")
        with pytest.raises(ValueError):
            _ = Density.from_structure(self.structure_path, sampling_rate=(1, 5))

    def test_empty(self):
        empty_density = self.density.empty
        assert np.allclose(empty_density.data, np.zeros_like(empty_density.data))
        assert np.allclose(empty_density.sampling_rate, self.density.sampling_rate)
        assert np.allclose(empty_density.origin, self.density.origin)
        assert empty_density.metadata == {"min": 0, "max": 0, "mean": 0, "std": 0}

    def test_copy(self):
        copied_density = self.density.copy()
        assert np.allclose(copied_density.data, self.density.data)
        assert np.allclose(copied_density.sampling_rate, self.density.sampling_rate)
        assert np.allclose(copied_density.origin, self.density.origin)
        assert copied_density.metadata == self.density.metadata

    def test_to_memmap(self):
        filename = self.path

        temp = self.density.copy()
        shape, dtype = temp.data.shape, temp.data.dtype

        # Write a reference copy to disk and reopen it read-only.
        arr_memmap = np.memmap(filename, mode="w+", dtype=dtype, shape=shape)
        arr_memmap[:] = temp.data[:]
        arr_memmap.flush()
        arr_memmap = np.memmap(filename, mode="r", dtype=dtype, shape=shape)

        temp.to_memmap()

        assert np.allclose(temp.data, arr_memmap)

    def test_to_numpy(self):
        temp = self.density.copy()
        temp.to_memmap()
        temp.to_numpy()

        assert np.allclose(temp.data, self.density.data)

    @pytest.mark.parametrize("threshold", [0, 0.5, 1, 1.5])
    def test_to_pointcloud(self, threshold):
        indices = self.density.to_pointcloud(threshold=threshold)
        assert indices.shape[0] == self.density.data.ndim
        assert np.all(self.density.data[tuple(indices)] > threshold)

    def test__pad_slice(self):
        x, y, z = self.density.shape
        # The first slice extends 5 voxels past both ends of the first axis.
        box = (slice(-5, x + 5), slice(0, y - 5), slice(2, z + 2))
        padded_data = self.density._pad_slice(box)
        assert padded_data.shape == (60, y, z)

    def test_adjust_box(self):
        box = (slice(10, 40), slice(10, 40), slice(10, 40))
        self.density.adjust_box(box)
        assert self.density.data.shape == (30, 30, 30)
        np.testing.assert_array_equal(
            self.density.origin, DEFAULT_ORIGIN + 10 * DEFAULT_SAMPLING_RATE
        )

    def test_trim_box(self):
        toy_data = np.zeros((20, 20, 20))
        signal = np.ones((5, 5, 5))
        toy_data[0:5, 5:10, 10:15] = signal
        temp = self.density.empty
        temp.data = toy_data
        trim_box = temp.trim_box(cutoff=0.5, margin=0)
        trim_box_margin = temp.trim_box(cutoff=0.5, margin=2)
        temp.adjust_box(trim_box)
        assert np.allclose(temp.data, signal)
        assert trim_box_margin == tuple((slice(0, 7), slice(3, 12), slice(8, 17)))

    def test_pad(self):
        new_shape = (70, 70, 70)
        self.density.pad(new_shape)
        assert self.density.data.shape == new_shape

    @pytest.mark.parametrize("new_shape", [(70,), (70, 70), (70, 70, 70, 70)])
    def test_pad_error(self, new_shape):
        with pytest.raises(ValueError):
            self.density.pad(new_shape)

    def test_minimum_enclosing_box(self):
        # The exact shape may vary, so we will mainly ensure that
        # the data is correctly adapted.
        # Further, more precise tests could be added.
        temp = self.density.copy()
        box = temp.minimum_enclosing_box(cutoff=0.5)
        assert len(box) == temp.data.ndim
        temp = self.density.copy()
        box = temp.minimum_enclosing_box(cutoff=0.5, use_geometric_center=True)
        assert len(box) == temp.data.ndim

    @pytest.mark.parametrize(
        "cutoff", [DEFAULT_DATA.min() - 1, 0, DEFAULT_DATA.max() - 0.1]
    )
    def test_centered(self, cutoff):
        centered_density, _ = self.density.centered(cutoff=cutoff)
        com = centered_density.center_of_mass(centered_density.data, 0)

        difference = np.abs(
            np.subtract(
                np.rint(np.array(com)).astype(int),
                np.array(centered_density.shape) // 2,
            )
        )
        assert np.all(difference <= self.density.sampling_rate)

    @pytest.mark.parametrize("use_geometric_center", (True, False))
    @pytest.mark.parametrize("order", (1, 3))
    def test_rigid_transform(self, use_geometric_center: bool, order: int):
        temp = self.density.copy()
        if use_geometric_center:
            box = temp.minimum_enclosing_box(cutoff=0, use_geometric_center=True)
            temp.adjust_box(box)
        else:
            temp, _ = temp.centered()

        # Axis permutation matrices: pure rotations/reflections that should
        # preserve the total absolute density.
        swaps = set(permutations([0, 1, 2]))
        temp_matrix = np.eye(temp.data.ndim).astype(np.float32)
        rotation_matrix = np.zeros_like(temp_matrix)

        initial_weight = np.sum(np.abs(temp.data))
        for z, y, x in swaps:
            rotation_matrix[:, 0] = temp_matrix[:, z]
            rotation_matrix[:, 1] = temp_matrix[:, y]
            rotation_matrix[:, 2] = temp_matrix[:, x]

            transformed = temp.rigid_transform(
                rotation_matrix=rotation_matrix,
                translation=np.zeros(temp.data.ndim),
                use_geometric_center=use_geometric_center,
                order=order,
            )
            transformed_weight = np.sum(np.abs(transformed.data))
            assert np.abs(1 - initial_weight / transformed_weight) < 0.01

    @pytest.mark.parametrize(
        "new_sampling_rate,order",
        [(2, 1), (4, 3)],
    )
    @pytest.mark.parametrize("method", ("spline", "fourier"))
    def test_resample(self, new_sampling_rate, order, method):
        resampled = self.density.resample(
            new_sampling_rate=new_sampling_rate, order=order, method=method
        )
        assert np.allclose(
            resampled.shape,
            np.divide(self.density.shape, new_sampling_rate).astype(int),
        )

    @pytest.mark.parametrize(
        "fraction_surface,volume_factor",
        [(0.5, 2), (1, 3)],
    )
    def test_density_boundary(self, fraction_surface, volume_factor):
        # TODO: Pre compute volume boundary on real data
        boundary = self.density.density_boundary(
            weight=1000, fraction_surface=fraction_surface, volume_factor=volume_factor
        )
        assert boundary[0] < boundary[1]

    @pytest.mark.parametrize(
        "fraction_surface,volume_factor",
        [(-0.5, 0), (1, -3)],
    )
    def test_density_boundary_error(self, fraction_surface, volume_factor):
        with pytest.raises(ValueError):
            _ = self.density.density_boundary(
                weight=1000,
                fraction_surface=fraction_surface,
                volume_factor=volume_factor,
            )

    @pytest.mark.parametrize(
        "method",
        ["ConvexHull", "Weight", "Sobel", "Laplace", "Minimum"],
    )
    def test_surface_coordinates(self, method):
        density_boundaries = self.density.density_boundary(weight=1000)
        self.density.surface_coordinates(
            density_boundaries=density_boundaries, method=method
        )

    def test_surface_coordinates_error(self):
        density_boundaries = self.density.density_boundary(weight=1000)
        with pytest.raises(ValueError):
            self.density.surface_coordinates(
                density_boundaries=density_boundaries, method=None
            )

    def test_normal_vectors(self):
        density_boundaries = self.density.density_boundary(weight=1000)
        coordinates = self.density.surface_coordinates(
            density_boundaries=density_boundaries, method="ConvexHull"
        )
        self.density.normal_vectors(coordinates=coordinates)

    def test_normal_vectors_error(self):
        coordinates = np.random.rand(10, 10, 10)
        with pytest.raises(ValueError):
            self.density.normal_vectors(coordinates=coordinates)

        coordinates = np.random.rand(10, 4)
        with pytest.raises(ValueError):
            self.density.normal_vectors(coordinates=coordinates)

    def test_core_mask(self):
        mask = self.density.core_mask()
        assert mask.sum() > 0

    def test_center_of_mass(self):
        # A filled disc of radius 5 centred at (10, 10) has that point as its
        # center of mass.
        center, shape, radius = (10, 10), (20, 20), 5
        n = len(shape)
        position = np.array(center).reshape((-1,) + (1,) * n)
        arr = np.linalg.norm(np.indices(shape) - position, axis=0)
        arr = (arr <= radius).astype(np.float32)

        center_of_mass = Density.center_of_mass(arr)
        assert np.allclose(center, center_of_mass)

    @pytest.mark.parametrize(
        "method",
        [
            "CrossCorrelation",
            "NormalizedCrossCorrelation",
        ],
    )
    def test_match_densities(self, method: str):
        target = np.zeros((30, 30, 30))
        target[5:10, 15:22, 10:13] = 1

        target = Density(target, sampling_rate=(1, 1, 1), origin=(0, 0, 0))
        target, _ = target.centered(cutoff=0)

        template = target.copy()

        initial_translation = np.array([-1, 3, 0])
        initial_rotation = euler_to_rotationmatrix((-10, 2, 5))

        template = template.rigid_transform(
            rotation_matrix=initial_rotation,
            translation=initial_translation,
            use_geometric_center=False,
        )

        target.sampling_rate = np.array(target.sampling_rate[0])
        template.sampling_rate = np.array(template.sampling_rate[0])

        aligned, translation, rotation = Density.match_densities(
            target=target, template=template, scoring_method=method, maxiter=5
        )
        # The recovered transform should undo the applied one.
        assert np.allclose(-translation, initial_translation, atol=2)
        assert np.allclose(np.linalg.inv(rotation), initial_rotation, atol=0.2)

    def test_match_structure_to_density(self):
        # Resolve fixtures through BASEPATH like the other tests so the test
        # does not depend on the current working directory.
        density = Density.from_file(str(BASEPATH.joinpath("Maps/emd_8621.mrc.gz")))
        density = density.resample(density.sampling_rate * 4)
        structure = Structure.from_file(
            str(BASEPATH.joinpath("Structures/5uz4.cif")), filter_by_residues=None
        )

        initial_translation = np.array([-1, 0, 5])
        initial_rotation = euler_to_rotationmatrix((-10, 2, 5))
        structure.rigid_transform(
            translation=initial_translation, rotation_matrix=initial_rotation
        )
        np.random.seed(12)
        ret = Density.match_structure_to_density(
            target=density,
            template=structure,
            cutoff_target=0,
            scoring_method="CrossCorrelation",
            maxiter=10,
        )
        structure_aligned, translation, rotation_matrix = ret

        assert np.allclose(
            structure_aligned.atom_coordinate.shape, structure.atom_coordinate.shape
        )
        assert np.allclose(-translation, initial_translation, atol=2)
        assert np.allclose(np.linalg.inv(rotation_matrix), initial_rotation, atol=0.2)
tests/test_extensions.py ADDED
@@ -0,0 +1,130 @@
1
+ import pytest
2
+ import numpy as np
3
+
4
+ from scipy.spatial import distance
5
+
6
+ from tme.extensions import (
7
+ absolute_minimum_deviation,
8
+ max_euclidean_distance,
9
+ find_candidate_indices,
10
+ find_candidate_coordinates,
11
+ max_index_by_label,
12
+ online_statistics,
13
+ )
14
+
15
+
16
# Seed once, before any random draw, so both the coordinate fixtures and
# TEST_DATA are reproducible across runs. TEST_DATA is drawn first so it keeps
# the exact values it had when the seed was set directly before it.
np.random.seed(42)
TEST_DATA = 10 * np.random.rand(5000)

# Random integer coordinates in [0, 100), keyed by dimension (1, 2, 3); each
# entry has shape (N_COORDINATES, dimension).
N_COORDINATES = 50
COORDINATES = {
    dim: np.random.choice(np.arange(100), size=N_COORDINATES * dim).reshape(
        N_COORDINATES, dim
    )
    for dim in range(1, 4)
}
24
+
25
+
26
class TestExtensions:
    """Check the compiled ``tme.extensions`` routines against NumPy/SciPy references."""

    @pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
    @pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
    def test_absolute_minimum_deviation(self, dimension, dtype):
        coordinates = COORDINATES[dimension].astype(dtype)
        output = np.zeros(
            (coordinates.shape[0], coordinates.shape[0]), dtype=coordinates.dtype
        )
        absolute_minimum_deviation(coordinates=coordinates, output=output)
        # Reference: for every pair of points, the smallest absolute
        # difference across their coordinate axes.
        expected_output = distance.cdist(
            coordinates, coordinates, lambda u, v: np.min(np.abs(u - v))
        )
        assert np.allclose(output, expected_output)

    @pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
    @pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
    def test_max_euclidean_distance(self, dimension, dtype):
        coordinates = COORDINATES[dimension].astype(dtype)

        max_distance, _ = max_euclidean_distance(coordinates=coordinates)
        distances = distance.cdist(coordinates, coordinates, "euclidean")
        assert np.allclose(max_distance, distances.max())

    @pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
    @pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
    @pytest.mark.parametrize("min_distance", [0, 5, 10])
    def test_find_candidate_indices(self, dimension, dtype, min_distance):
        coordinates = COORDINATES[dimension].astype(dtype)
        # Pass the threshold with the same dtype as the coordinates.
        min_distance = np.array([min_distance]).astype(dtype)[0]

        candidates = find_candidate_indices(
            coordinates=coordinates, min_distance=min_distance
        )

        # Every pair of selected points must be at least min_distance apart;
        # the diagonal (distance of a point to itself) is excluded.
        distances = distance.cdist(coordinates[candidates], coordinates[candidates])
        np.fill_diagonal(distances, np.inf)
        assert np.all(distances >= min_distance)

    @pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
    @pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
    @pytest.mark.parametrize("min_distance", [0, 5, 10])
    def test_find_candidate_coordinates(self, dimension, dtype, min_distance):
        coordinates = COORDINATES[dimension].astype(dtype)
        # Pass the threshold with the same dtype as the coordinates.
        min_distance = np.array([min_distance]).astype(dtype)[0]

        filtered_coordinates = find_candidate_coordinates(
            coordinates=coordinates, min_distance=min_distance
        )

        distances = distance.cdist(filtered_coordinates, filtered_coordinates)
        np.fill_diagonal(distances, np.inf)
        assert np.all(distances >= min_distance)

    @pytest.mark.parametrize("dtype_labels", [np.int32, int, np.float32, np.float64])
    @pytest.mark.parametrize("dtype_scores", [np.int32, int, np.float32, np.float64])
    def test_max_index_by_label(self, dtype_labels, dtype_scores):
        labels = np.array([1, 1, 2, 3, 3, 3, 2], dtype=dtype_labels)
        scores = np.array([0.5, 0.8, 0.7, 0.2, 0.9, 0.6, 0.5])
        scores = (10 * scores).astype(dtype_scores)

        # Reference: index of the highest score within each label. All scores
        # are positive, so zeroing other labels via the mask is safe.
        expected_result = {}
        for label in np.unique(labels):
            mask = labels == label
            expected_result[label] = np.argmax(scores * mask)

        ret = max_index_by_label(labels=labels, scores=scores)
        for k in expected_result:
            assert np.allclose(expected_result[k], ret[k])

    @pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
    @pytest.mark.parametrize("splits", [1, 5, 10])
    def test_online_statistics(self, dtype, splits):
        # NOTE(review): `dtype` is not applied to TEST_DATA, so every dtype
        # case currently runs on the same float64 data.
        parts = np.array_split(TEST_DATA, splits)
        # `reference` is not passed to online_statistics; the count check
        # assumes the extension's threshold defaults to 0.0 -- TODO confirm.
        n, rmean, ssqd, reference = 0, 0, 0, 0.0
        better_or_equal = 0
        # No previous max/min is passed to online_statistics, so the extremes
        # it returns can only describe the current chunk; track running ones.
        running_max, running_min = -np.inf, np.inf
        end_idx = 0
        for part in parts:
            end_idx += len(part)

            n, rmean, ssqd, nbetter_or_equal, max_value, min_value = online_statistics(
                arr=part, n=n, rmean=rmean, ssqd=ssqd
            )
            better_or_equal += nbetter_or_equal
            running_max = max(running_max, max_value)
            running_min = min(running_min, min_value)

            # Reference statistics over everything consumed so far, computed
            # independently of the values returned by the extension.
            data = TEST_DATA[:end_idx]
            drmean = np.mean(data)
            dssqd = np.sum((data - drmean) ** 2)

            assert n == len(data)
            assert np.isclose(rmean, drmean, atol=1e-8)
            assert np.isclose(ssqd, dssqd, atol=1e-8)
            assert better_or_equal == np.sum(data >= reference)
            assert np.isclose(running_max, np.max(data))
            assert np.isclose(running_min, np.min(data))