pytme 0.3.1.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.2.dev0__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. pytme-0.3.2.dev0.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/match_template.py +213 -196
  3. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/postprocess.py +40 -78
  4. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocess.py +4 -5
  5. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/preprocessor_gui.py +50 -103
  6. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/pytme_runner.py +46 -69
  7. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/METADATA +2 -1
  8. pytme-0.3.2.dev0.dist-info/RECORD +136 -0
  9. scripts/estimate_ram_usage.py +97 -0
  10. scripts/match_template.py +213 -196
  11. scripts/match_template_devel.py +1339 -0
  12. scripts/postprocess.py +40 -78
  13. scripts/preprocess.py +4 -5
  14. scripts/preprocessor_gui.py +50 -103
  15. scripts/pytme_runner.py +46 -69
  16. scripts/refine_matches.py +5 -7
  17. tests/preprocessing/test_compose.py +31 -30
  18. tests/preprocessing/test_frequency_filters.py +17 -32
  19. tests/preprocessing/test_preprocessor.py +0 -19
  20. tests/preprocessing/test_utils.py +13 -1
  21. tests/test_analyzer.py +2 -10
  22. tests/test_backends.py +47 -18
  23. tests/test_density.py +72 -13
  24. tests/test_extensions.py +1 -0
  25. tests/test_matching_cli.py +23 -9
  26. tests/test_matching_exhaustive.py +5 -5
  27. tests/test_matching_utils.py +3 -3
  28. tests/test_rotations.py +13 -23
  29. tests/test_structure.py +1 -7
  30. tme/__version__.py +1 -1
  31. tme/analyzer/aggregation.py +47 -16
  32. tme/analyzer/base.py +34 -0
  33. tme/analyzer/peaks.py +26 -13
  34. tme/analyzer/proxy.py +14 -0
  35. tme/backends/_jax_utils.py +124 -71
  36. tme/backends/cupy_backend.py +6 -19
  37. tme/backends/jax_backend.py +110 -105
  38. tme/backends/matching_backend.py +0 -17
  39. tme/backends/mlx_backend.py +0 -29
  40. tme/backends/npfftw_backend.py +100 -97
  41. tme/backends/pytorch_backend.py +65 -78
  42. tme/cli.py +2 -2
  43. tme/density.py +102 -58
  44. tme/extensions.cpython-311-darwin.so +0 -0
  45. tme/filters/_utils.py +52 -24
  46. tme/filters/bandpass.py +99 -105
  47. tme/filters/compose.py +133 -39
  48. tme/filters/ctf.py +51 -102
  49. tme/filters/reconstruction.py +67 -122
  50. tme/filters/wedge.py +296 -325
  51. tme/filters/whitening.py +39 -75
  52. tme/mask.py +2 -2
  53. tme/matching_data.py +87 -15
  54. tme/matching_exhaustive.py +70 -120
  55. tme/matching_optimization.py +9 -63
  56. tme/matching_scores.py +261 -100
  57. tme/matching_utils.py +150 -91
  58. tme/memory.py +1 -0
  59. tme/orientations.py +28 -8
  60. tme/preprocessor.py +0 -239
  61. tme/rotations.py +102 -70
  62. tme/structure.py +601 -631
  63. tme/types.py +1 -0
  64. pytme-0.3.1.post1.dist-info/RECORD +0 -133
  65. {pytme-0.3.1.post1.data → pytme-0.3.2.dev0.data}/scripts/estimate_memory_usage.py +0 -0
  66. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/WHEEL +0 -0
  67. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/entry_points.txt +0 -0
  68. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/licenses/LICENSE +0 -0
  69. {pytme-0.3.1.post1.dist-info → pytme-0.3.2.dev0.dist-info}/top_level.txt +0 -0
tests/test_backends.py CHANGED
@@ -331,16 +331,6 @@ class TestBackends:
331
331
  real_arr = backend.irfftn(complex_arr)
332
332
  assert np.allclose(arr, backend.to_numpy_array(real_arr), rtol=0.3)
333
333
 
334
- @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
335
- def test_extract_center(self, backend):
336
- new_shape = np.divide(self.x1.shape, 2).astype(int)
337
- base = self.backend.extract_center(arr=self.x1, newshape=new_shape)
338
- other = backend.extract_center(
339
- arr=backend.to_backend_array(self.x1), newshape=new_shape
340
- )
341
-
342
- assert np.allclose(base, backend.to_numpy_array(other), rtol=0.01)
343
-
344
334
  @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
345
335
  def test_compute_convolution_shapes(self, backend):
346
336
  base = self.backend.compute_convolution_shapes(self.x1.shape, self.x2.shape)
@@ -359,32 +349,69 @@ class TestBackends:
359
349
  elif dim == 3:
360
350
  arr[20:25, 21:26, 26:31] = 1
361
351
 
362
- rotation_matrix = np.eye(dim)
363
- rotation_matrix[0, 0] = -1
352
+ from tme.rotations import get_rotation_matrices
353
+
354
+ np.random.seed(42)
355
+ rotation_matrix = get_rotation_matrices(
356
+ dim=dim, angular_sampling=10, use_optimized_set=False
357
+ )[-1]
364
358
 
365
359
  out = np.zeros_like(arr)
360
+ out.setflags(write=True)
366
361
 
367
362
  arr_mask, out_mask = None, None
368
363
  if create_mask:
369
364
  arr_mask = np.multiply(np.random.rand(*arr.shape) > 0.5, 1.0)
370
365
  out_mask = np.zeros_like(arr_mask)
366
+ out_mask.setflags(write=True)
367
+
368
+ out, _ = NumpyFFTWBackend().rigid_transform(
369
+ arr=arr,
370
+ arr_mask=arr_mask,
371
+ rotation_matrix=rotation_matrix,
372
+ out=out,
373
+ out_mask=out_mask,
374
+ order=1,
375
+ use_geometric_center=True,
376
+ )
377
+
378
+ arr = backend.to_backend_array(arr.copy())
379
+ out_be = backend.to_backend_array(out.copy())
380
+ if create_mask:
371
381
  arr_mask = backend.to_backend_array(arr_mask)
372
382
  out_mask = backend.to_backend_array(out_mask)
373
383
 
374
- arr = backend.to_backend_array(arr)
375
- out = backend.to_backend_array(arr)
376
-
377
384
  rotation_matrix = backend.to_backend_array(rotation_matrix)
378
385
 
379
- backend.rigid_transform(
386
+ out_be, _ = backend.rigid_transform(
380
387
  arr=arr,
381
388
  arr_mask=arr_mask,
382
389
  rotation_matrix=rotation_matrix,
383
- out=out,
390
+ out=out_be,
384
391
  out_mask=out_mask,
392
+ order=1,
393
+ use_geometric_center=True,
385
394
  )
395
+ out_be = backend.to_numpy_array(out_be)
396
+ assert np.allclose(out, out_be, atol=0.3)
386
397
 
387
- assert np.round(arr.sum(), 3) == np.round(out.sum(), 3)
398
+ @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
399
+ @pytest.mark.parametrize("create_mask", (False, True))
400
+ def test_rigid_transform_identity(self, backend, create_mask):
401
+ dim = 3
402
+ shape = tuple(50 for _ in range(dim))
403
+
404
+ arr = np.zeros(shape)
405
+ arr[20:25, 21:26, 26:31] = 1
406
+
407
+ rotation_matrix = backend.to_backend_array(np.eye(dim))
408
+ out, _ = backend.rigid_transform(
409
+ arr=backend.to_backend_array(arr),
410
+ rotation_matrix=backend.to_backend_array(rotation_matrix),
411
+ order=1,
412
+ use_geometric_center=True,
413
+ )
414
+ assert np.allclose(out, arr, atol=0.01)
388
415
 
389
416
  @pytest.mark.parametrize("dim", (2, 3))
390
417
  @pytest.mark.parametrize("backend", BACKENDS_TO_TEST)
@@ -418,6 +445,7 @@ class TestBackends:
418
445
  out=out,
419
446
  out_mask=out_mask,
420
447
  batched=True,
448
+ order=1,
421
449
  )
422
450
 
423
451
  arr_b = backend.to_backend_array(arr_b)
@@ -430,6 +458,7 @@ class TestBackends:
430
458
  rotation_matrix=rotation_matrix,
431
459
  out=out_b[i],
432
460
  out_mask=out_mask if out_mask is None else out_mask[i],
461
+ order=1,
433
462
  )
434
463
 
435
464
  assert np.allclose(arr, arr_b)
tests/test_density.py CHANGED
@@ -101,9 +101,10 @@ class TestDensity:
101
101
 
102
102
  def test_from_file_baseline(self):
103
103
  self.test_to_file(gzip=False)
104
- density = Density.from_file(str(BASEPATH.joinpath("Maps/emd_8621.mrc.gz")))
105
- assert np.allclose(density.origin, (4.35, 2.90, -1.45), rtol=0.1)
106
- assert np.allclose(density.sampling_rate, (1.45), rtol=0.3)
104
+ density = Density.from_file(str(BASEPATH.joinpath("Raw/em_map.map")))
105
+ assert density.shape == (19, 14, 20)
106
+ assert np.allclose(density.origin, (-52.8, -10.56, -52.8), rtol=0.1)
107
+ assert np.allclose(density.sampling_rate, (5.28), rtol=0.1)
107
108
 
108
109
  @pytest.mark.parametrize("extension", ("mrc", "em", "tiff", "h5"))
109
110
  @pytest.mark.parametrize("gzip", (True, False))
@@ -315,7 +316,7 @@ class TestDensity:
315
316
  "cutoff", [DEFAULT_DATA.min() - 1, 0, DEFAULT_DATA.max() - 0.1]
316
317
  )
317
318
  def test_centered(self, cutoff):
318
- centered_density, translation = self.density.centered(cutoff=cutoff)
319
+ centered_density = self.density.centered(cutoff=cutoff)
319
320
  com = centered_density.center_of_mass(centered_density.data, 0)
320
321
 
321
322
  difference = np.abs(
@@ -334,7 +335,7 @@ class TestDensity:
334
335
  box = temp.minimum_enclosing_box(cutoff=0, use_geometric_center=True)
335
336
  temp.adjust_box(box)
336
337
  else:
337
- temp, translation = temp.centered()
338
+ temp = temp.centered()
338
339
 
339
340
  swaps = set(permutations([0, 1, 2]))
340
341
  temp_matrix = np.eye(temp.data.ndim).astype(np.float32)
@@ -436,7 +437,7 @@ class TestDensity:
436
437
  arr = np.linalg.norm(np.indices(shape) - position, axis=0)
437
438
  arr = (arr <= radius).astype(np.float32)
438
439
 
439
- center_of_mass = Density.center_of_mass(arr)
440
+ center_of_mass = Density(arr).center_of_mass()
440
441
  assert np.allclose(center, center_of_mass)
441
442
 
442
443
  @pytest.mark.parametrize(
@@ -451,7 +452,7 @@ class TestDensity:
451
452
  target[5:10, 15:22, 10:13] = 1
452
453
 
453
454
  target = Density(target, sampling_rate=(1, 1, 1), origin=(0, 0, 0))
454
- target, translation = target.centered(cutoff=0)
455
+ target = target.centered(cutoff=0)
455
456
 
456
457
  template = target.copy()
457
458
 
@@ -474,21 +475,79 @@ class TestDensity:
474
475
  assert np.allclose(np.linalg.inv(rotation), initial_rotation, atol=0.2)
475
476
 
476
477
  def test_match_structure_to_density(self):
477
- density = Density.from_file("tests/data/Maps/emd_8621.mrc.gz")
478
- density = density.resample(density.sampling_rate * 4)
479
- structure = Structure.from_file(
480
- "tests/data/Structures/5uz4.cif", filter_by_residues=None
478
+ mask = create_mask(
479
+ mask_type="ellipse",
480
+ center=(20, 20, 20),
481
+ radius=(10, 5, 10),
482
+ shape=(50, 50, 50),
481
483
  )
482
484
 
485
+ coordinates = np.array(np.where(mask > 0), dtype=np.float32).T
486
+ structure = Structure(
487
+ record_type=[
488
+ "ATOM",
489
+ ]
490
+ * coordinates.shape[0],
491
+ atom_serial_number=list(range(coordinates.shape[0])),
492
+ atom_name=[
493
+ "C",
494
+ ]
495
+ * coordinates.shape[0],
496
+ atom_coordinate=coordinates,
497
+ alternate_location_indicator=[
498
+ ".",
499
+ ]
500
+ * coordinates.shape[0],
501
+ residue_name=[
502
+ "GLY",
503
+ ]
504
+ * coordinates.shape[0],
505
+ chain_identifier=[
506
+ "A",
507
+ ]
508
+ * coordinates.shape[0],
509
+ residue_sequence_number=[
510
+ 0,
511
+ ]
512
+ * coordinates.shape[0],
513
+ code_for_residue_insertion=[
514
+ "?",
515
+ ]
516
+ * coordinates.shape[0],
517
+ occupancy=[
518
+ 0,
519
+ ]
520
+ * coordinates.shape[0],
521
+ temperature_factor=[
522
+ 0,
523
+ ]
524
+ * coordinates.shape[0],
525
+ segment_identifier=[
526
+ "1",
527
+ ]
528
+ * coordinates.shape[0],
529
+ element_symbol=[
530
+ "C",
531
+ ]
532
+ * coordinates.shape[0],
533
+ charge=[
534
+ "?",
535
+ ]
536
+ * coordinates.shape[0],
537
+ metadata={},
538
+ )
539
+ density = Density(mask, sampling_rate=1.0)
540
+ density = density.resample(density.sampling_rate * 2)
541
+
483
542
  initial_translation = np.array([-1, 0, 5])
484
543
  initial_rotation = euler_to_rotationmatrix((-10, 2, 5))
485
- structure.rigid_transform(
544
+ structure_mod = structure.rigid_transform(
486
545
  translation=initial_translation, rotation_matrix=initial_rotation
487
546
  )
488
547
  np.random.seed(12)
489
548
  ret = Density.match_structure_to_density(
490
549
  target=density,
491
- template=structure,
550
+ template=structure_mod,
492
551
  cutoff_target=0,
493
552
  scoring_method="CrossCorrelation",
494
553
  maxiter=10,
tests/test_extensions.py CHANGED
@@ -53,6 +53,7 @@ class TestExtensions:
53
53
  @pytest.mark.parametrize("min_distance", [0, 5, 10])
54
54
  def test_find_candidate_indices(self, dimension, dtype, min_distance):
55
55
  coordinates = COORDINATES[dimension].astype(dtype)
56
+ print(coordinates.shape)
56
57
 
57
58
  min_distance = np.array([min_distance]).astype(dtype)[0]
58
59
 
@@ -10,7 +10,7 @@ from tme import Density, Orientations
10
10
  from tme.backends import backend as be
11
11
 
12
12
  np.random.seed(42)
13
- available_backends = (x for x in be.available_backends() if x != "mlx")
13
+ available_backends = tuple(x for x in be.available_backends() if x != "mlx")
14
14
 
15
15
 
16
16
  def argdict_to_command(input_args, executable: str):
@@ -29,7 +29,7 @@ def argdict_to_command(input_args, executable: str):
29
29
  return " ".join(ret)
30
30
 
31
31
 
32
- class TestMatchTemplate:
32
+ class TestSetup:
33
33
  @classmethod
34
34
  def setup_class(cls):
35
35
  target = np.random.rand(20, 20, 20)
@@ -98,6 +98,7 @@ class TestMatchTemplate:
98
98
  use_target_mask: bool = False,
99
99
  backend: str = "numpyfftw",
100
100
  test_rejection_sampling: bool = False,
101
+ background_correction: str = None,
101
102
  ):
102
103
  output_path = tempfile.NamedTemporaryFile(delete=False, suffix="pickle").name
103
104
 
@@ -120,6 +121,9 @@ class TestMatchTemplate:
120
121
  if test_rejection_sampling:
121
122
  argdict["--orientations"] = self.orientations_path
122
123
 
124
+ if background_correction is not None:
125
+ argdict["--background-correction"] = background_correction
126
+
123
127
  if test_filter:
124
128
  argdict["--lowpass"] = 30
125
129
  argdict["--defocus"] = 3000
@@ -136,11 +140,15 @@ class TestMatchTemplate:
136
140
  assert ret.returncode == 0
137
141
  return output_path
138
142
 
143
+
144
+ class TestMatchTemplate(TestSetup):
145
+
139
146
  @pytest.mark.parametrize("backend", available_backends)
140
147
  @pytest.mark.parametrize("call_peaks", (False, True))
141
148
  @pytest.mark.parametrize("use_template_mask", (False, True))
142
149
  @pytest.mark.parametrize("test_filter", (False, True))
143
150
  @pytest.mark.parametrize("test_rejection_sampling", (False, True))
151
+ @pytest.mark.parametrize("background_correction", (None, "phase-scrambling"))
144
152
  def test_match_template(
145
153
  self,
146
154
  backend: bool,
@@ -148,8 +156,14 @@ class TestMatchTemplate:
148
156
  use_template_mask: bool,
149
157
  test_filter: bool,
150
158
  test_rejection_sampling: bool,
159
+ background_correction: str,
151
160
  ):
152
- if backend == "jax" and (call_peaks or test_rejection_sampling):
161
+ # Jax does not support peak calling yet
162
+ if backend == "jax" and call_peaks:
163
+ return None
164
+
165
+ # These use different analyzers raising an error in the interface
166
+ if call_peaks and test_rejection_sampling:
153
167
  return None
154
168
 
155
169
  self.run_matching(
@@ -163,10 +177,11 @@ class TestMatchTemplate:
163
177
  template_mask_path=self.template_mask_path,
164
178
  target_mask_path=self.target_mask_path,
165
179
  test_rejection_sampling=test_rejection_sampling,
180
+ background_correction=background_correction,
166
181
  )
167
182
 
168
183
 
169
- class TestPostprocessing(TestMatchTemplate):
184
+ class TestPostprocessing(TestSetup):
170
185
  @classmethod
171
186
  def setup_class(cls):
172
187
  super().setup_class()
@@ -256,7 +271,6 @@ class TestPostprocessing(TestMatchTemplate):
256
271
  }
257
272
  cmd = argdict_to_command(argdict, executable="postprocess.py")
258
273
  ret = subprocess.run(cmd, capture_output=True, shell=True)
259
- print(ret)
260
274
 
261
275
  match output_format:
262
276
  case "orientations":
@@ -289,14 +303,14 @@ class TestPostprocessing(TestMatchTemplate):
289
303
  assert ret.returncode == 0
290
304
 
291
305
 
292
- class TestEstimateMemoryUsage(TestMatchTemplate):
306
+ class TestEstimateMemoryUsage(TestSetup):
293
307
  @classmethod
294
308
  def setup_class(cls):
295
309
  super().setup_class()
296
310
 
297
311
  @pytest.mark.parametrize("ncores", (1, 4, 8))
298
312
  @pytest.mark.parametrize("pad_edges", (False, True))
299
- def test_estimation(self, ncores, pad_edges):
313
+ def test_estimation_cli(self, ncores, pad_edges):
300
314
 
301
315
  argdict = {
302
316
  "-m": self.target_path,
@@ -311,7 +325,7 @@ class TestEstimateMemoryUsage(TestMatchTemplate):
311
325
  assert ret.returncode == 0
312
326
 
313
327
 
314
- class TestPreprocess(TestMatchTemplate):
328
+ class TestPreprocess(TestSetup):
315
329
  @classmethod
316
330
  def setup_class(cls):
317
331
  super().setup_class()
@@ -319,7 +333,7 @@ class TestPreprocess(TestMatchTemplate):
319
333
  @pytest.mark.parametrize("backend", available_backends)
320
334
  @pytest.mark.parametrize("align_axis", (False, True))
321
335
  @pytest.mark.parametrize("invert_contrast", (False, True))
322
- def test_estimation(self, backend, align_axis, invert_contrast):
336
+ def test_preprocess_cli(self, backend, align_axis, invert_contrast):
323
337
 
324
338
  argdict = {
325
339
  "-m": self.target_path,
@@ -1,5 +1,5 @@
1
- import numpy as np
2
1
  import pytest
2
+ import numpy as np
3
3
 
4
4
  from scipy.ndimage import laplace
5
5
 
@@ -7,7 +7,7 @@ from tme.matching_data import MatchingData
7
7
  from tme.memory import MATCHING_MEMORY_REGISTRY
8
8
  from tme.analyzer import MaxScoreOverRotations, PeakCallerSort
9
9
  from tme.matching_exhaustive import (
10
- scan_subsets,
10
+ match_exhaustive,
11
11
  MATCHING_EXHAUSTIVE_REGISTER,
12
12
  register_matching_exhaustive,
13
13
  )
@@ -35,11 +35,11 @@ class TestMatchExhaustive:
35
35
  self.coordinates_weights = None
36
36
  self.rotations = None
37
37
 
38
- @pytest.mark.parametrize("evaluate_peak", (True,))
38
+ @pytest.mark.parametrize("evaluate_peak", (True, False))
39
39
  @pytest.mark.parametrize("score", tuple(MATCHING_EXHAUSTIVE_REGISTER.keys()))
40
40
  @pytest.mark.parametrize("job_schedule", ((2, 1),))
41
41
  @pytest.mark.parametrize("pad_edge", (False, True))
42
- def test_scan_subset(
42
+ def test_match_exhaustive(
43
43
  self,
44
44
  score: str,
45
45
  job_schedule: int,
@@ -64,7 +64,7 @@ class TestMatchExhaustive:
64
64
  if evaluate_peak:
65
65
  callback_class = PeakCallerSort
66
66
 
67
- ret = scan_subsets(
67
+ ret = match_exhaustive(
68
68
  matching_data=matching_data,
69
69
  matching_setup=setup,
70
70
  matching_score=process,
@@ -15,7 +15,7 @@ from tme.matching_utils import (
15
15
  apply_convolution_mode,
16
16
  write_pickle,
17
17
  load_pickle,
18
- _normalize_template_overflow_safe,
18
+ _standardize_safe,
19
19
  )
20
20
 
21
21
  BASEPATH = files("tests.data")
@@ -127,12 +127,12 @@ class TestMatchingUtils:
127
127
  loaded_data = load_pickle(filename)
128
128
  assert np.array_equal(loaded_data, data)
129
129
 
130
- def test_normalize_template_overflow_safe(self):
130
+ def test_standardize_safe(self):
131
131
  template = be.random.random((10, 10)).astype(be.float32)
132
132
  mask = be.ones_like(template)
133
133
  n_observations = 100.0
134
134
 
135
- result = _normalize_template_overflow_safe(template, mask, n_observations)
135
+ result = _standardize_safe(template, mask, n_observations)
136
136
  assert result.shape == template.shape
137
137
  assert result.dtype == template.dtype
138
138
  assert np.allclose(result.mean(), 0, atol=0.1)
tests/test_rotations.py CHANGED
@@ -35,24 +35,19 @@ class TestRotations:
35
35
  )
36
36
 
37
37
  @pytest.mark.parametrize(
38
- "initial_vector, target_vector, convention",
38
+ "initial_vector, target_vector",
39
39
  [
40
- ([1, 0, 0], [0, 1, 0], None),
41
- ([0, 1, 0], [0, 0, 1], "zyx"),
42
- ([1, 1, 1], [1, 0, 0], "xyz"),
40
+ ([1, 0, 0], [0, 1, 0]),
41
+ ([0, 1, 0], [0, 0, 1]),
42
+ ([1, 1, 1], [1, 0, 0]),
43
43
  ],
44
44
  )
45
- def test_align_vectors(self, initial_vector, target_vector, convention):
46
- result = align_vectors(initial_vector, target_vector, convention)
45
+ def test_align_vectors(self, initial_vector, target_vector):
46
+ result = align_vectors(initial_vector, target_vector)
47
47
 
48
48
  assert isinstance(result, np.ndarray)
49
- if convention is None:
50
- assert result.shape == (3, 3)
51
- assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
52
- else:
53
- assert len(result) == 3
54
- result = Rotation.from_euler(convention, result, degrees=True).as_matrix()
55
- assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
49
+ assert result.shape == (3, 3)
50
+ assert np.allclose(np.dot(result, result.T), np.eye(3), atol=1e-6)
56
51
 
57
52
  rotated = np.dot(Rotation.from_matrix(result).as_matrix(), initial_vector)
58
53
  assert np.allclose(
@@ -62,11 +57,11 @@ class TestRotations:
62
57
  )
63
58
 
64
59
  @pytest.mark.parametrize(
65
- "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry, convention",
60
+ "cone_angle, cone_sampling, axis_angle, axis_sampling, vector, n_symmetry",
66
61
  [
67
- (30, 5, 360, None, (1, 0, 0), 1, None),
68
- (45, 10, 180, 15, (0, 1, 0), 2, "zyx"),
69
- (60, 15, 90, 30, (0, 0, 1), 4, "xyz"),
62
+ (30, 5, 360, None, (1, 0, 0), 1),
63
+ (45, 10, 180, 15, (0, 1, 0), 2),
64
+ (60, 15, 90, 30, (0, 0, 1), 4),
70
65
  ],
71
66
  )
72
67
  def test_get_cone_rotations(
@@ -77,7 +72,6 @@ class TestRotations:
77
72
  axis_sampling,
78
73
  vector,
79
74
  n_symmetry,
80
- convention,
81
75
  ):
82
76
  result = get_cone_rotations(
83
77
  cone_angle=cone_angle,
@@ -86,14 +80,10 @@ class TestRotations:
86
80
  axis_sampling=axis_sampling,
87
81
  reference=vector,
88
82
  n_symmetry=n_symmetry,
89
- seq=convention,
90
83
  )
91
84
 
92
85
  assert isinstance(result, np.ndarray)
93
- if convention is None:
94
- assert result.shape[1:] == (3, 3)
95
- else:
96
- assert result.shape[1] == 3
86
+ assert result.shape[1:] == (3, 3)
97
87
 
98
88
  def test_euler_conversion(self):
99
89
  rotation_matrix_initial = np.array(
tests/test_structure.py CHANGED
@@ -6,7 +6,6 @@ import pytest
6
6
  import numpy as np
7
7
 
8
8
  from tme import Structure
9
- from tme.matching_utils import minimum_enclosing_box
10
9
  from tme.rotations import euler_to_rotationmatrix
11
10
 
12
11
 
@@ -196,11 +195,6 @@ class TestStructure:
196
195
  assert center_of_mass.shape[0] == self.structure.atom_coordinate.shape[1]
197
196
  assert np.allclose(center_of_mass, [-0.89391639, 29.94908928, -2.64736741])
198
197
 
199
- def test_centered(self):
200
- ret, translation = self.structure.centered()
201
- box = minimum_enclosing_box(coordinates=self.structure.atom_coordinate.T)
202
- assert np.allclose(ret.center_of_mass(), np.divide(box, 2), atol=1)
203
-
204
198
  def test__get_atom_weights_error(self):
205
199
  with pytest.raises(NotImplementedError):
206
200
  self.structure._get_atom_weights(
@@ -242,6 +236,6 @@ class TestStructure:
242
236
  assert final_rmsd <= 0.1
243
237
 
244
238
  aligned, final_rmsd = Structure.align_structures(
245
- self.structure, structure_transform, sampling_rate=1
239
+ self.structure, structure_transform
246
240
  )
247
241
  assert final_rmsd <= 1
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.3.1"
1
+ __version__ = "0.3.2"