pytme 0.2.3__cp311-cp311-macosx_14_0_arm64.whl → 0.2.4__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (75) hide show
  1. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/match_template.py +8 -8
  2. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocess.py +22 -6
  3. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/preprocessor_gui.py +9 -14
  4. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/METADATA +1 -1
  5. pytme-0.2.4.dist-info/RECORD +119 -0
  6. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/WHEEL +1 -1
  7. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/top_level.txt +1 -0
  8. scripts/match_template.py +8 -8
  9. scripts/preprocess.py +22 -6
  10. scripts/preprocessor_gui.py +9 -14
  11. tests/__init__.py +0 -0
  12. tests/data/.DS_Store +0 -0
  13. tests/data/Blurring/.DS_Store +0 -0
  14. tests/data/Blurring/blob_width18.npy +0 -0
  15. tests/data/Blurring/edgegaussian_sigma3.npy +0 -0
  16. tests/data/Blurring/gaussian_sigma2.npy +0 -0
  17. tests/data/Blurring/hamming_width6.npy +0 -0
  18. tests/data/Blurring/kaiserb_width18.npy +0 -0
  19. tests/data/Blurring/localgaussian_sigma0510.npy +0 -0
  20. tests/data/Blurring/mean_size5.npy +0 -0
  21. tests/data/Blurring/ntree_sigma0510.npy +0 -0
  22. tests/data/Blurring/rank_rank3.npy +0 -0
  23. tests/data/Maps/.DS_Store +0 -0
  24. tests/data/Maps/emd_8621.mrc.gz +0 -0
  25. tests/data/README.md +2 -0
  26. tests/data/Raw/.DS_Store +0 -0
  27. tests/data/Raw/em_map.map +0 -0
  28. tests/data/Structures/.DS_Store +0 -0
  29. tests/data/Structures/1pdj.cif +3339 -0
  30. tests/data/Structures/1pdj.pdb +1429 -0
  31. tests/data/Structures/5khe.cif +3685 -0
  32. tests/data/Structures/5khe.ent +2210 -0
  33. tests/data/Structures/5khe.pdb +2210 -0
  34. tests/data/Structures/5uz4.cif +70548 -0
  35. tests/preprocessing/__init__.py +0 -0
  36. tests/preprocessing/test_compose.py +76 -0
  37. tests/preprocessing/test_frequency_filters.py +178 -0
  38. tests/preprocessing/test_preprocessor.py +136 -0
  39. tests/preprocessing/test_utils.py +79 -0
  40. tests/test_analyzer.py +310 -0
  41. tests/test_backends.py +375 -0
  42. tests/test_density.py +508 -0
  43. tests/test_extensions.py +130 -0
  44. tests/test_matching_cli.py +283 -0
  45. tests/test_matching_data.py +162 -0
  46. tests/test_matching_exhaustive.py +162 -0
  47. tests/test_matching_memory.py +30 -0
  48. tests/test_matching_optimization.py +276 -0
  49. tests/test_matching_utils.py +326 -0
  50. tests/test_orientations.py +173 -0
  51. tests/test_packaging.py +95 -0
  52. tests/test_parser.py +33 -0
  53. tests/test_structure.py +243 -0
  54. tme/__init__.py +0 -1
  55. tme/__version__.py +1 -1
  56. tme/backends/jax_backend.py +8 -7
  57. tme/data/scattering_factors.pickle +0 -0
  58. tme/density.py +11 -4
  59. tme/external/bindings.cpp +332 -0
  60. tme/matching_data.py +11 -9
  61. tme/matching_exhaustive.py +10 -8
  62. tme/matching_utils.py +1 -0
  63. tme/preprocessing/_utils.py +14 -14
  64. tme/preprocessing/composable_filter.py +0 -2
  65. tme/preprocessing/compose.py +4 -4
  66. tme/preprocessing/frequency_filters.py +32 -35
  67. tme/preprocessing/tilt_series.py +202 -118
  68. tme/preprocessor.py +24 -246
  69. tme/structure.py +14 -14
  70. pytme-0.2.3.dist-info/RECORD +0 -75
  71. tme/matching_memory.py +0 -383
  72. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/estimate_ram_usage.py +0 -0
  73. {pytme-0.2.3.data → pytme-0.2.4.data}/scripts/postprocess.py +0 -0
  74. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/LICENSE +0 -0
  75. {pytme-0.2.3.dist-info → pytme-0.2.4.dist-info}/entry_points.txt +0 -0
tests/test_density.py ADDED
@@ -0,0 +1,508 @@
1
+ from os import remove
2
+ from tempfile import mkstemp
3
+ from itertools import permutations
4
+ from importlib_resources import files
5
+
6
+ import pytest
7
+ import numpy as np
8
+
9
+ from tme import Density, Structure, Preprocessor
10
+ from tme.matching_utils import create_mask, euler_to_rotationmatrix
11
+
12
# Shared module-level fixtures: an "ellipse" mask scaled by 10, passed through
# Preprocessor.gaussian_filter with sigma=2 and cast to float32.
DEFAULT_DATA = Preprocessor().gaussian_filter(
    create_mask(
        mask_type="ellipse",
        center=(20, 20, 20),
        radius=(10, 5, 10),
        shape=(50, 50, 50),
    )
    * 10,
    sigma=2,
).astype(np.float32)
DEFAULT_ORIGIN = np.array([0, 0, 0])
DEFAULT_SAMPLING_RATE = np.array([1, 1, 1])

# Resource root of the packaged test data.
BASEPATH = files("tests.data")
24
+
25
+
26
+ class TestDensity:
27
def setup_method(self):
    """Create the per-test Density fixture, a scratch file and the structure path."""
    from os import close  # local import: only needed to release the descriptor

    self.density = Density(
        data=DEFAULT_DATA,
        origin=DEFAULT_ORIGIN,
        sampling_rate=DEFAULT_SAMPLING_RATE,
        metadata={
            "min": DEFAULT_DATA.min(),
            "max": DEFAULT_DATA.max(),
            "mean": DEFAULT_DATA.mean(),
            "std": DEFAULT_DATA.std(),
        },
    )
    # mkstemp returns an already-open OS-level descriptor together with the
    # path. Only the path is used by the tests, so close the descriptor
    # immediately instead of leaking one per test.
    handle, self.path = mkstemp()
    close(handle)
    self.structure_path = str(BASEPATH.joinpath("Structures/5khe.cif"))
41
+
42
def teardown_method(self):
    """Discard the Density fixture and delete the scratch file."""
    scratch_file = self.path
    del self.density
    remove(scratch_file)
45
+
46
def test_initialization(self):
    """Attributes of a new Density compare equal to the constructor arguments."""
    custom_metadata = {"test_key": "test_value"}

    instance = Density(
        DEFAULT_DATA, DEFAULT_ORIGIN, DEFAULT_SAMPLING_RATE, custom_metadata
    )

    array_attributes = (
        ("data", DEFAULT_DATA),
        ("origin", DEFAULT_ORIGIN),
        ("sampling_rate", DEFAULT_SAMPLING_RATE),
    )
    for attribute, expected in array_attributes:
        assert np.array_equal(getattr(instance, attribute), expected)
    assert instance.metadata == custom_metadata
58
+
59
@pytest.mark.parametrize(
    "data,origin,sampling_rate,metadata",
    [
        (np.random.rand(50, 50, 50), (0, 0, 0), (1, 2), {}),
        (np.random.rand(50, 50, 50), (0, 0, 0), (1, 2, 3, 4), {}),
        (np.random.rand(50, 50, 50), (0, 0, 0), (1, 2, 3), "not_a_dict"),
        (np.random.rand(50, 50, 50), (0, 0), (1, 2, 3), "not_a_dict"),
    ],
)
def test_initialization_errors(self, data, origin, sampling_rate, metadata):
    """Each parametrized argument combination is expected to raise ValueError."""
    constructor_arguments = (data, origin, sampling_rate, metadata)
    with pytest.raises(ValueError):
        Density(*constructor_arguments)
71
+
72
def test_repr(self):
    """repr() of a Density matches the expected two-line summary string."""

    def rounded(values):
        # Same rounding the expected string applies to origin / sampling rate.
        return tuple(round(float(value), 3) for value in values)

    density = Density(DEFAULT_DATA, DEFAULT_ORIGIN, DEFAULT_SAMPLING_RATE)
    actual = repr(density)

    template = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
    expected = template.format(
        hex(id(density)),
        rounded(density.origin),
        rounded(density.sampling_rate),
        density.shape,
    )
    assert expected == actual
88
+
89
@pytest.mark.parametrize("gzip", [(False), (True)])
def test_to_file(self, gzip: bool):
    """Writing the fixture density produces a non-empty file at the scratch path."""
    from os.path import getsize  # local import: not part of the module imports

    self.density.to_file(self.path, gzip=gzip)
    # The scratch file starts out empty (created by mkstemp), so a non-zero
    # size shows that to_file actually wrote something. This replaces the
    # previous `assert True`, which could never fail.
    assert getsize(self.path) > 0
93
+
94
def test_from_file(self):
    """A density written with to_file is read back with matching contents."""
    self.test_to_file(gzip=False)
    loaded = Density.from_file(self.path)

    for attribute in ("data", "sampling_rate", "origin"):
        assert np.allclose(
            getattr(loaded, attribute), getattr(self.density, attribute)
        )
    assert loaded.metadata == self.density.metadata
101
+
102
def test_from_file_baseline(self):
    """Origin and sampling rate of the bundled emd_8621 map load within tolerance."""
    # Only the packaged map is read here; the previous call to
    # self.test_to_file() wrote the unrelated scratch file and was dropped.
    density = Density.from_file(str(BASEPATH.joinpath("Maps/emd_8621.mrc.gz")))
    assert np.allclose(density.origin, (-1.45, 2.90, 4.35), rtol=0.1)
    # The scalar is broadcast against every axis of the sampling rate.
    assert np.allclose(density.sampling_rate, 1.45, rtol=0.3)
107
+
108
@pytest.mark.parametrize("extension", ("mrc", "em", "tiff", "h5"))
@pytest.mark.parametrize("gzip", (True, False))
@pytest.mark.parametrize("use_memmap", (True, False))
@pytest.mark.parametrize("subset", (True, False))
def test_file_format_io(self, extension, gzip, subset, use_memmap):
    """Round-trip random data through to_file / from_file for each file format.

    Three-dimensional data is used for "mrc" and "em", two-dimensional data
    for the other extensions. Origin and sampling rate are only compared for
    the MRC format.
    """
    from os import close  # local import: only needed to release the descriptor

    base = Density(
        data=np.random.rand(50, 50, 50), origin=(0, 0, 0), sampling_rate=(1, 1, 1)
    )
    data_subset = (slice(0, 22), slice(31, 46), slice(12, 25))
    if extension not in ("mrc", "em"):
        base = Density(
            data=np.random.rand(50, 50), origin=(0, 0), sampling_rate=(1, 1)
        )
        data_subset = (slice(0, 22), slice(31, 46))
    if gzip:
        # Compressed files are always read without memory mapping here.
        use_memmap = False

    if not subset:
        data_subset = tuple(slice(0, x) for x in base.shape)

    suffix = f".{extension}.gz" if gzip else f".{extension}"
    # Close the descriptor mkstemp hands back and make sure the temporary
    # file is deleted again; previously both were leaked for every
    # parameter combination.
    handle, output_file = mkstemp(suffix=suffix)
    close(handle)
    try:
        base.to_file(output_file, gzip=gzip)
        temp = Density.from_file(
            output_file, use_memmap=use_memmap, subset=data_subset
        )
        assert np.allclose(base.data[data_subset], temp.data)
        if extension.upper() == "MRC":
            assert np.allclose(base.origin, temp.origin)
            assert np.allclose(base.sampling_rate, temp.sampling_rate)
    finally:
        remove(output_file)
136
+
137
def test__read_binary_subset_error(self):
    """Every invalid subset passed to from_file must raise on its own."""
    from os import close  # local import: only needed to release the descriptor

    base = Density(
        data=np.random.rand(50, 50, 50), origin=(0, 0, 0), sampling_rate=(1, 1, 1)
    )
    invalid_subsets = (
        (slice(0, 10),),
        (slice(-1, 10), slice(5, 10), slice(5, 10)),
        (slice(20, 100), slice(5, 10), slice(5, 10)),
    )

    handle, output_file = mkstemp()
    close(handle)
    try:
        base.to_file(output_file)
        # A pytest.raises block ends at the first exception, so the original
        # single block only ever exercised the first subset. Each subset now
        # gets its own block.
        for invalid_subset in invalid_subsets:
            with pytest.raises((ValueError, OSError)):
                Density.from_file(output_file, subset=invalid_subset)
    finally:
        remove(output_file)
151
+
152
def test_from_structure(self):
    """Call Density.from_structure with several keyword combinations.

    No return value is inspected; the test passes when none of the calls
    raises.
    """
    cube = (30, 30, 30)
    keyword_sets = (
        {
            "origin": self.density.origin,
            "shape": self.density.shape,
            "sampling_rate": self.density.sampling_rate,
        },
        {"shape": cube},
        {"shape": cube, "origin": self.density.origin},
        {"origin": self.density.origin},
        {"shape": cube, "sampling_rate": 6, "origin": self.density.origin},
        {"shape": cube, "sampling_rate": None, "origin": self.density.origin},
        {"weight_type": "atomic_weight", "chain": "A"},
    )
    for keywords in keyword_sets:
        Density.from_structure(self.structure_path, **keywords)
181
+
182
@pytest.mark.parametrize(
    "weight_type",
    [
        "atomic_weight",
        "atomic_number",
        "van_der_waals_radius",
        "scattering_factors",
        "lowpass_scattering_factors",
        "gaussian",
    ],
)
def test_from_structure_weight_types(self, weight_type):
    """Build a density from the fixture structure with the given weight type."""
    options = {"weight_type": weight_type}
    Density.from_structure(self.structure_path, **options)
198
+
199
def test_from_structure_weight_types_error(self):
    """weight_type=None is expected to raise NotImplementedError."""
    expected_failure = pytest.raises(NotImplementedError)
    with expected_failure:
        Density.from_structure(self.structure_path, weight_type=None)
205
+
206
@pytest.mark.parametrize(
    "weight_type", ["scattering_factors", "lowpass_scattering_factors"]
)
@pytest.mark.parametrize(
    "scattering_factors", ["dt1969", "wk1995", "peng1995", "peng1999"]
)
def test_from_structure_scattering(self, scattering_factors, weight_type):
    """Build a density for each scattering-factor source and weight type pair."""
    source_arguments = {"source": scattering_factors}
    Density.from_structure(
        self.structure_path,
        weight_type=weight_type,
        weight_type_args=source_arguments,
    )
218
+
219
def test_from_structure_error(self):
    """An unknown weight type and a two-element sampling rate both raise."""
    failing_calls = (
        (NotImplementedError, {"weight_type": "RAISERROR"}),
        (ValueError, {"sampling_rate": (1, 5)}),
    )
    for expected_exception, keywords in failing_calls:
        with pytest.raises(expected_exception):
            Density.from_structure(self.structure_path, **keywords)
224
+
225
def test_empty(self):
    """The `empty` density is all zeros and keeps sampling rate and origin."""
    blank = self.density.empty
    assert np.allclose(blank.data, np.zeros_like(blank.data))
    for attribute in ("sampling_rate", "origin"):
        assert np.allclose(
            getattr(blank, attribute), getattr(self.density, attribute)
        )
    zero_statistics = {"min": 0, "max": 0, "mean": 0, "std": 0}
    assert blank.metadata == zero_statistics
231
+
232
def test_copy(self):
    """A copy carries the same data, sampling rate, origin and metadata."""
    duplicate = self.density.copy()
    for attribute in ("data", "sampling_rate", "origin"):
        assert np.allclose(
            getattr(duplicate, attribute), getattr(self.density, attribute)
        )
    assert duplicate.metadata == self.density.metadata
238
+
239
def test_to_memmap(self):
    """Data after to_memmap() equals a memmap written by hand to the scratch file."""
    working_copy = self.density.copy()
    layout = {"dtype": working_copy.data.dtype, "shape": working_copy.data.shape}

    # Write the reference memmap manually, then reopen it read-only.
    arr_memmap = np.memmap(self.path, mode="w+", **layout)
    arr_memmap[:] = working_copy.data[:]
    arr_memmap.flush()
    arr_memmap = np.memmap(self.path, mode="r", **layout)

    working_copy.to_memmap()

    assert np.allclose(working_copy.data, arr_memmap)
253
+
254
def test_to_numpy(self):
    """Converting to a memmap and back to numpy leaves the data unchanged."""
    round_tripped = self.density.copy()
    for convert in (round_tripped.to_memmap, round_tripped.to_numpy):
        convert()

    assert np.allclose(round_tripped.data, self.density.data)
260
+
261
@pytest.mark.parametrize("threshold", [(0), (0.5), (1), (1.5)])
def test_to_pointcloud(self, threshold):
    """Point-cloud indices have one row per axis and select values above threshold."""
    point_indices = self.density.to_pointcloud(threshold=threshold)
    volume = self.density.data
    assert point_indices.shape[0] == volume.ndim
    selected_values = volume[tuple(point_indices)]
    assert np.all(selected_values > threshold)
266
+
267
def test__pad_slice(self):
    """_pad_slice grows the first axis by the requested overhang on both sides."""
    x, y, z = self.density.shape
    # The first slice addresses the first axis, so it is derived from x. The
    # original used z here, which only worked because the fixture is a cube.
    box = (slice(-5, x + 5), slice(0, y - 5), slice(2, z + 2))
    padded_data = self.density._pad_slice(box)
    # Five voxels of overhang on each side of the first axis; for the
    # 50-voxel fixture this is the original expectation of 60.
    assert padded_data.shape == (x + 10, y, z)
272
+
273
def test_adjust_box(self):
    """adjust_box crops the data to the box and shifts the origin accordingly."""
    start, stop = 10, 40
    box = tuple(slice(start, stop) for _ in range(3))
    self.density.adjust_box(box)

    assert self.density.data.shape == (stop - start,) * 3
    expected_origin = DEFAULT_ORIGIN + start * DEFAULT_SAMPLING_RATE
    np.testing.assert_array_equal(self.density.origin, expected_origin)
280
+
281
def test_trim_box(self):
    """trim_box finds the block of ones, with and without a margin."""
    signal = np.ones((5, 5, 5))
    toy_data = np.zeros((20, 20, 20))
    toy_data[0:5, 5:10, 10:15] = signal

    temp = self.density.empty
    temp.data = toy_data
    tight_box = temp.trim_box(cutoff=0.5, margin=0)
    margin_box = temp.trim_box(cutoff=0.5, margin=2)
    temp.adjust_box(tight_box)

    expected_margin_box = (slice(0, 7), slice(3, 12), slice(8, 17))
    assert np.allclose(temp.data, signal)
    assert margin_box == expected_margin_box
292
+
293
def test_pad(self):
    """Padding to a larger shape yields data of exactly that shape."""
    target_shape = (70,) * 3
    self.density.pad(target_shape)
    assert self.density.data.shape == target_shape
297
+
298
@pytest.mark.parametrize("new_shape", [(70,), (70, 70), (70, 70, 70, 70)])
def test_pad_error(self, new_shape):
    """Shapes with a number of axes other than three are expected to raise ValueError."""
    expected_failure = pytest.raises(ValueError)
    with expected_failure:
        self.density.pad(new_shape)
302
+
303
def test_minimum_enclosing_box(self):
    """The enclosing box has one entry per data axis, with and without geometric centering."""
    # The exact box extents may vary, so only the number of entries is
    # verified here. More precise tests could be added.
    for extra_keywords in ({}, {"use_geometric_center": True}):
        temp = self.density.copy()
        box = temp.minimum_enclosing_box(cutoff=0.5, **extra_keywords)
        assert len(box) == temp.data.ndim
313
+
314
@pytest.mark.parametrize(
    "cutoff", [DEFAULT_DATA.min() - 1, 0, DEFAULT_DATA.max() - 0.1]
)
def test_centered(self, cutoff):
    """After centering, the rounded center of mass lies within one sampling step of the box center."""
    centered_density, _ = self.density.centered(cutoff=cutoff)
    mass_center = centered_density.center_of_mass(centered_density.data, 0)

    rounded_center = np.rint(np.array(mass_center)).astype(int)
    box_center = np.array(centered_density.shape) // 2
    offset = np.abs(rounded_center - box_center)
    assert np.all(offset <= self.density.sampling_rate)
328
+
329
@pytest.mark.parametrize("use_geometric_center", (True, False))
@pytest.mark.parametrize("order", (1, 3))
def test_rigid_transform(self, use_geometric_center: bool, order: int):
    """Axis-permutation transforms keep the summed absolute density within 1%."""
    temp = self.density.copy()
    if not use_geometric_center:
        temp, _ = temp.centered()
    else:
        box = temp.minimum_enclosing_box(cutoff=0, use_geometric_center=True)
        temp.adjust_box(box)

    n_axes = temp.data.ndim
    identity = np.eye(n_axes).astype(np.float32)
    rotation_matrix = np.zeros_like(identity)

    initial_weight = np.sum(np.abs(temp.data))
    for axis_order in set(permutations([0, 1, 2])):
        # Every column is overwritten, giving a permutation matrix.
        for column, source_axis in enumerate(axis_order):
            rotation_matrix[:, column] = identity[:, source_axis]

        transformed = temp.rigid_transform(
            rotation_matrix=rotation_matrix,
            translation=np.zeros(n_axes),
            use_geometric_center=use_geometric_center,
            order=order,
        )
        transformed_weight = np.sum(np.abs(transformed.data))
        assert np.abs(1 - initial_weight / transformed_weight) < 0.01
357
+
358
@pytest.mark.parametrize(
    "new_sampling_rate,order",
    [(2, 1), (4, 3)],
)
@pytest.mark.parametrize("method", ("spline", "fourier"))
def test_resample(self, new_sampling_rate, order, method):
    """The resampled shape equals the original shape divided by the new rate, truncated."""
    resample_options = {
        "new_sampling_rate": new_sampling_rate,
        "order": order,
        "method": method,
    }
    resampled = self.density.resample(**resample_options)
    expected_shape = np.divide(self.density.shape, new_sampling_rate).astype(int)
    assert np.allclose(resampled.shape, expected_shape)
371
+
372
@pytest.mark.parametrize(
    "fraction_surface,volume_factor",
    [(0.5, 2), (1, 3)],
)
def test_density_boundary(self, fraction_surface, volume_factor):
    """The first boundary value is smaller than the second."""
    # TODO: Pre compute volume boundary on real data
    boundary = self.density.density_boundary(
        weight=1000, fraction_surface=fraction_surface, volume_factor=volume_factor
    )
    lower, upper = boundary[0], boundary[1]
    assert lower < upper
382
+
383
@pytest.mark.parametrize(
    "fraction_surface,volume_factor",
    [(-0.5, 0), (1, -3)],
)
def test_density_boundary_error(self, fraction_surface, volume_factor):
    """The parametrized out-of-range arguments are expected to raise ValueError."""
    invalid_options = {
        "weight": 1000,
        "fraction_surface": fraction_surface,
        "volume_factor": volume_factor,
    }
    with pytest.raises(ValueError):
        self.density.density_boundary(**invalid_options)
394
+
395
@pytest.mark.parametrize(
    "method",
    [("ConvexHull"), ("Weight"), ("Sobel"), ("Laplace"), ("Minimum")],
)
def test_surface_coordinates(self, method):
    """surface_coordinates runs without raising for each supported method name."""
    boundaries = self.density.density_boundary(weight=1000)
    self.density.surface_coordinates(density_boundaries=boundaries, method=method)
404
+
405
def test_surface_coordinates_error(self):
    """method=None is expected to raise ValueError."""
    boundaries = self.density.density_boundary(weight=1000)
    expected_failure = pytest.raises(ValueError)
    with expected_failure:
        self.density.surface_coordinates(density_boundaries=boundaries, method=None)
411
+
412
def test_normal_vectors(self):
    """normal_vectors runs without raising on ConvexHull surface coordinates."""
    boundaries = self.density.density_boundary(weight=1000)
    surface_points = self.density.surface_coordinates(
        density_boundaries=boundaries, method="ConvexHull"
    )
    self.density.normal_vectors(coordinates=surface_points)
418
+
419
def test_normal_vectors_error(self):
    """Coordinate arrays of shape (10, 10, 10) and (10, 4) are expected to raise ValueError."""
    for invalid_shape in ((10, 10, 10), (10, 4)):
        coordinates = np.random.rand(*invalid_shape)
        with pytest.raises(ValueError):
            self.density.normal_vectors(coordinates=coordinates)
427
+
428
def test_core_mask(self):
    """The core mask of the fixture density has a positive sum."""
    selected = self.density.core_mask().sum()
    assert selected > 0
431
+
432
def test_center_of_mass(self):
    """The center of mass of a filled disc equals the disc center."""
    expected_center, grid_shape, disc_radius = (10, 10), (20, 20), 5
    # Reshape the center so it broadcasts against np.indices(grid_shape).
    broadcast_shape = (-1,) + (1,) * len(grid_shape)
    offsets = np.indices(grid_shape) - np.array(expected_center).reshape(
        broadcast_shape
    )
    distances = np.linalg.norm(offsets, axis=0)
    disc = (distances <= disc_radius).astype(np.float32)

    assert np.allclose(expected_center, Density.center_of_mass(disc))
441
+
442
@pytest.mark.parametrize(
    "method",
    [
        ("CrossCorrelation"),
        ("NormalizedCrossCorrelation"),
    ],
)
def test_match_densities(self, method: str):
    """match_densities approximately inverts a known rotation and translation."""
    box_data = np.zeros((30, 30, 30))
    box_data[5:10, 15:22, 10:13] = 1

    target = Density(box_data, sampling_rate=(1, 1, 1), origin=(0, 0, 0))
    target, _ = target.centered(cutoff=0)

    initial_translation = np.array([-1, 3, 0])
    initial_rotation = euler_to_rotationmatrix((-10, 2, 5))

    # The template is the target moved by the known transform.
    template = target.copy().rigid_transform(
        rotation_matrix=initial_rotation,
        translation=initial_translation,
        use_geometric_center=False,
    )

    # Replace each sampling rate by its first entry.
    for density in (target, template):
        density.sampling_rate = np.array(density.sampling_rate[0])

    _, translation, rotation = Density.match_densities(
        target=target, template=template, scoring_method=method, maxiter=5
    )
    recovered_rotation = np.linalg.inv(rotation)
    assert np.allclose(-translation, initial_translation, atol=2)
    assert np.allclose(recovered_rotation, initial_rotation, atol=0.2)
475
+
476
def test_match_structure_to_density(self):
    """match_structure_to_density approximately inverts a known rigid transform."""
    # Resolve the bundled data through BASEPATH, as the other tests in this
    # file do, so the test does not depend on the working directory.
    map_path = str(BASEPATH.joinpath("Maps/emd_8621.mrc.gz"))
    structure_path = str(BASEPATH.joinpath("Structures/5uz4.cif"))

    density = Density.from_file(map_path)
    density = density.resample(density.sampling_rate * 4)
    structure = Structure.from_file(structure_path, filter_by_residues=None)

    initial_translation = np.array([-1, 0, 5])
    initial_rotation = euler_to_rotationmatrix((-10, 2, 5))
    structure.rigid_transform(
        translation=initial_translation, rotation_matrix=initial_rotation
    )
    # Fixed seed for the call below.
    np.random.seed(12)
    ret = Density.match_structure_to_density(
        target=density,
        template=structure,
        cutoff_target=0,
        scoring_method="CrossCorrelation",
        maxiter=10,
    )
    structure_aligned, translation, rotation_matrix = ret

    assert np.allclose(
        structure_aligned.atom_coordinate.shape, structure.atom_coordinate.shape
    )
    assert np.allclose(-translation, initial_translation, atol=2)
    assert np.allclose(np.linalg.inv(rotation_matrix), initial_rotation, atol=0.2)
503
+
504
def test_fourier_shell_correlation(self):
    """The FSC of a density with its own copy has two columns."""
    first, second = self.density.copy(), self.density.copy()
    fsc = Density.fourier_shell_correlation(first, second)
    n_columns = fsc.shape[1]
    assert n_columns == 2
@@ -0,0 +1,130 @@
1
+ import pytest
2
+ import numpy as np
3
+
4
+ from scipy.spatial import distance
5
+
6
+ from tme.extensions import (
7
+ absolute_minimum_deviation,
8
+ max_euclidean_distance,
9
+ find_candidate_indices,
10
+ find_candidate_coordinates,
11
+ max_index_by_label,
12
+ online_statistics,
13
+ )
14
+
15
+
16
# Seed before drawing COORDINATES so the coordinate fixtures are reproducible
# between runs; previously the seed was only set after they had been drawn.
np.random.seed(42)

# COORDINATES[i] holds N_COORDINATES integer points with i columns, drawn
# from 0..99.
COORDINATES, N_COORDINATES = {}, 50
for i in range(1, 4):
    COORDINATES[i] = np.random.choice(
        np.arange(100), size=N_COORDINATES * i
    ).reshape(N_COORDINATES, i)

# Re-seed so TEST_DATA keeps exactly the values it had before this change.
np.random.seed(42)
TEST_DATA = 10 * np.random.rand(5000)
24
+
25
+
26
+ class TestExtensions:
27
@pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
@pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
def test_absolute_minimum_deviation(self, dimension, dtype):
    """The extension result matches a cdist-based reference of per-pair minimum absolute differences."""

    def min_abs_difference(u, v):
        return np.min(np.abs(u - v))

    coordinates = COORDINATES[dimension].astype(dtype)
    n_points = coordinates.shape[0]
    output = np.zeros((n_points, n_points), dtype=coordinates.dtype)
    absolute_minimum_deviation(coordinates=coordinates, output=output)

    expected_output = distance.cdist(coordinates, coordinates, min_abs_difference)
    assert np.allclose(output, expected_output)
39
+
40
@pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
@pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
def test_max_euclidean_distance(self, dimension, dtype):
    """The reported maximum distance equals the largest cdist entry."""
    coordinates = COORDINATES[dimension].astype(dtype)
    print(coordinates.shape)

    # Only the distance is checked; the second return value is not inspected.
    max_distance, _ = max_euclidean_distance(coordinates=coordinates)
    pairwise = distance.cdist(coordinates, coordinates, "euclidean")
    assert np.allclose(max_distance, pairwise.max())
50
+
51
@pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
@pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
@pytest.mark.parametrize("min_distance", [0, 5, 10])
def test_find_candidate_indices(self, dimension, dtype, min_distance):
    """Points selected by index are pairwise at least min_distance apart."""
    coordinates = COORDINATES[dimension].astype(dtype)
    print(coordinates.shape)

    # Express the threshold as a scalar of the coordinate dtype.
    min_distance = np.array([min_distance]).astype(dtype)[0]

    candidates = find_candidate_indices(
        coordinates=coordinates, min_distance=min_distance
    )
    selected = coordinates[candidates]

    pairwise = distance.cdist(selected, selected)
    # Ignore the zero self-distances on the diagonal.
    np.fill_diagonal(pairwise, np.inf)
    assert np.all(pairwise >= min_distance)
67
+
68
@pytest.mark.parametrize("dimension", list(COORDINATES.keys()))
@pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
@pytest.mark.parametrize("min_distance", [0, 5, 10])
def test_find_candidate_coordinates(self, dimension, dtype, min_distance):
    """Returned coordinates are pairwise at least min_distance apart."""
    coordinates = COORDINATES[dimension].astype(dtype)
    print(coordinates.shape)

    # Express the threshold as a scalar of the coordinate dtype.
    min_distance = np.array([min_distance]).astype(dtype)[0]

    kept = find_candidate_coordinates(
        coordinates=coordinates, min_distance=min_distance
    )

    pairwise = distance.cdist(kept, kept)
    # Ignore the zero self-distances on the diagonal.
    np.fill_diagonal(pairwise, np.inf)
    assert np.all(pairwise >= min_distance)
84
+
85
@pytest.mark.parametrize("dtype_labels", [np.int32, int, np.float32, np.float64])
@pytest.mark.parametrize("dtype_scores", [np.int32, int, np.float32, np.float64])
def test_max_index_by_label(self, dtype_labels, dtype_scores):
    """For every label, the returned index matches the argmax of the masked scores."""
    labels = np.array([1, 1, 2, 3, 3, 3, 2], dtype=dtype_labels)
    raw_scores = np.array([0.5, 0.8, 0.7, 0.2, 0.9, 0.6, 0.5])
    scores = (10 * raw_scores).astype(dtype_scores)

    # Reference: zero out scores of other labels, then take the argmax.
    expected_result = {
        label: np.argmax(scores * (labels == label)) for label in np.unique(labels)
    }

    ret = max_index_by_label(labels=labels, scores=scores)
    print(ret)
    for label, expected_index in expected_result.items():
        assert np.allclose(expected_index, ret[label])
100
+
101
@pytest.mark.parametrize("dtype", [np.int32, int, np.float32, np.float64])
@pytest.mark.parametrize("splits", [1, 5, 10])
def test_online_statistics(self, dtype, splits):
    """Statistics accumulated over chunks agree with numpy computed on the full data.

    NOTE(review): `dtype` comes from the parametrization but is not used by
    the body; the data is always passed as generated.
    """
    reference = 0.0
    n, rmean, ssqd = 0, 0, 0
    better_or_equal = 0
    for part in np.array_split(TEST_DATA, splits):
        # Feed the running n / rmean / ssqd back in for the next chunk.
        n, rmean, ssqd, nbetter_or_equal, max_value, min_value = online_statistics(
            arr=part, n=n, rmean=rmean, ssqd=ssqd
        )
        better_or_equal += nbetter_or_equal

    data = TEST_DATA
    dn = len(data)
    drmean = np.mean(data)
    # Use the independently computed mean, not the value under test, so the
    # reference does not depend on the result it is meant to verify.
    dssqd = np.sum((data - drmean) ** 2)
    dnbetter_or_equal = np.sum(data >= reference)
    dmax_value = np.max(data)
    dmin_value = np.min(data)

    assert n == dn
    assert np.isclose(rmean, drmean, atol=1e-8)
    assert np.isclose(ssqd, dssqd, atol=1e-8)
    assert better_or_equal == dnbetter_or_equal
    # max_value / min_value are the ones returned for the final chunk.
    assert np.isclose(max_value, dmax_value, atol=1e-1)
    assert np.isclose(min_value, dmin_value, atol=1e-1)