pytme 0.2.0__cp311-cp311-macosx_14_0_arm64.whl → 0.2.1__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/match_template.py +183 -69
  2. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/postprocess.py +107 -49
  3. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocessor_gui.py +4 -1
  4. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/METADATA +1 -1
  5. pytme-0.2.1.dist-info/RECORD +73 -0
  6. scripts/extract_candidates.py +117 -85
  7. scripts/match_template.py +183 -69
  8. scripts/match_template_filters.py +193 -71
  9. scripts/postprocess.py +107 -49
  10. scripts/preprocessor_gui.py +4 -1
  11. scripts/refine_matches.py +364 -160
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +259 -117
  14. tme/backends/__init__.py +1 -0
  15. tme/backends/cupy_backend.py +20 -13
  16. tme/backends/jax_backend.py +218 -0
  17. tme/backends/matching_backend.py +25 -10
  18. tme/backends/mlx_backend.py +13 -9
  19. tme/backends/npfftw_backend.py +20 -8
  20. tme/backends/pytorch_backend.py +20 -9
  21. tme/density.py +79 -60
  22. tme/extensions.cpython-311-darwin.so +0 -0
  23. tme/matching_data.py +85 -61
  24. tme/matching_exhaustive.py +222 -129
  25. tme/matching_optimization.py +117 -76
  26. tme/orientations.py +175 -55
  27. tme/preprocessing/_utils.py +17 -5
  28. tme/preprocessing/composable_filter.py +2 -1
  29. tme/preprocessing/compose.py +1 -2
  30. tme/preprocessing/frequency_filters.py +97 -41
  31. tme/preprocessing/tilt_series.py +137 -87
  32. tme/preprocessor.py +3 -0
  33. tme/structure.py +4 -1
  34. pytme-0.2.0.dist-info/RECORD +0 -72
  35. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/estimate_ram_usage.py +0 -0
  36. {pytme-0.2.0.data → pytme-0.2.1.data}/scripts/preprocess.py +0 -0
  37. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/LICENSE +0 -0
  38. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/WHEEL +0 -0
  39. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/entry_points.txt +0 -0
  40. {pytme-0.2.0.dist-info → pytme-0.2.1.dist-info}/top_level.txt +0 -0
tme/orientations.py CHANGED
@@ -1,5 +1,5 @@
1
1
  #!python3
2
- """ Handle template matching peaks and convert between formats.
2
+ """ Handle template matching orientations and conversion between formats.
3
3
 
4
4
  Copyright (c) 2024 European Molecular Biology Laboratory
5
5
 
@@ -8,6 +8,7 @@
8
8
  import re
9
9
  from collections import deque
10
10
  from dataclasses import dataclass
11
+ from string import ascii_lowercase
11
12
  from typing import List, Tuple, Dict
12
13
 
13
14
  import numpy as np
@@ -17,7 +18,48 @@ from scipy.spatial.transform import Rotation
17
18
  @dataclass
18
19
  class Orientations:
19
20
  """
20
- Handle template matching peaks and convert between formats.
21
+ Handle template matching orientations and conversion between formats.
22
+
23
+ Examples
24
+ --------
25
+ The following achieves the minimal definition of an :py:class:`Orientations` instance
26
+
27
+ >>> import numpy as np
28
+ >>> from tme import Orientations
29
+ >>> translations = np.random.randint(low = 0, high = 100, size = (100,3))
30
+ >>> rotations = np.random.rand(100, 3)
31
+ >>> scores = np.random.rand(100)
32
+ >>> details = np.full((100,), fill_value = -1)
33
+ >>> orientations = Orientations(
34
+ >>> translations = translations,
35
+ >>> rotations = rotations,
36
+ >>> scores = scores,
37
+ >>> details = details,
38
+ >>> )
39
+
40
+ The created ``orientations`` object can be written to disk in a range of formats.
41
+ See :py:meth:`Orientations.to_file` for available formats. The following creates
42
+ a STAR file
43
+
44
+ >>> orientations.to_file("test.star")
45
+
46
+ :py:meth:`Orientations.from_file` can create :py:class:`Orientations` instances
47
+ from a range of formats, to enable conversion between formats
48
+
49
+ >>> orientations_star = Orientations.from_file("test.star")
50
+ >>> np.all(orientations.translations == orientations_star.translations)
51
+ True
52
+
53
+ Parameters
54
+ ----------
55
+ translations: np.ndarray
56
+ Array with translations of each orientations (n, d).
57
+ rotations: np.ndarray
58
+ Array with euler angles of each orientation in zxy convention (n, d).
59
+ scores: np.ndarray
60
+ Array with the score of each orientation (n, ).
61
+ details: np.ndarray
62
+ Array with additional orientation details (n, ).
21
63
  """
22
64
 
23
65
  #: Return a numpy array with translations of each orientation (n x d).
@@ -32,6 +74,29 @@ class Orientations:
32
74
  #: Return a numpy array with additional orientation details (n, ).
33
75
  details: np.ndarray
34
76
 
77
+ def __post_init__(self):
78
+ self.translations = np.array(self.translations).astype(np.float32)
79
+ self.rotations = np.array(self.rotations).astype(np.float32)
80
+ self.scores = np.array(self.scores).astype(np.float32)
81
+ self.details = np.array(self.details).astype(np.float32)
82
+ n_orientations = set(
83
+ [
84
+ self.translations.shape[0],
85
+ self.rotations.shape[0],
86
+ self.scores.shape[0],
87
+ self.details.shape[0],
88
+ ]
89
+ )
90
+ if len(n_orientations) != 1:
91
+ raise ValueError(
92
+ "The first dimension of all parameters needs to be of equal length."
93
+ )
94
+ if self.translations.ndim != 2:
95
+ raise ValueError("Expected two dimensional translations parameter.")
96
+
97
+ if self.rotations.ndim != 2:
98
+ raise ValueError("Expected two dimensional rotations parameter.")
99
+
35
100
  def __iter__(self) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
36
101
  """
37
102
  Iterate over the current class instance. Each iteration returns a orientation
@@ -77,7 +142,17 @@ class Orientations:
77
142
  filename : str
78
143
  The name of the file where the orientations will be saved.
79
144
  file_format : type, optional
80
- The format in which to save the orientations. Supported formats are 'text' and 'relion'.
145
+ The format in which to save the orientations. Defaults to None and infers
146
+ the file_format from the typical extension. Supported formats are
147
+
148
+ +---------------+----------------------------------------------------+
149
+ | text | pyTME's standard tab-separated orientations file |
150
+ +---------------+----------------------------------------------------+
151
+ | relion | Creates a STAR file of orientations |
152
+ +---------------+----------------------------------------------------+
153
+ | dynamo | Creates a dynamo table |
154
+ +---------------+----------------------------------------------------+
155
+
81
156
  **kwargs : dict
82
157
  Additional keyword arguments specific to the file format.
83
158
 
@@ -120,8 +195,14 @@ class Orientations:
120
195
  The file is saved with a header specifying each column: z, y, x, euler_z,
121
196
  euler_y, euler_x, score, detail. Each row in the file corresponds to an orientation.
122
197
  """
198
+ naming = ascii_lowercase[::-1]
123
199
  header = "\t".join(
124
- ["z", "y", "x", "euler_z", "euler_y", "euler_x", "score", "detail"]
200
+ [
201
+ *list(naming[: self.translations.shape[1]]),
202
+ *[f"euler_{x}" for x in naming[: self.rotations.shape[1]]],
203
+ "score",
204
+ "detail",
205
+ ]
125
206
  )
126
207
  with open(filename, mode="w", encoding="utf-8") as ofile:
127
208
  _ = ofile.write(f"{header}\n")
@@ -158,9 +239,6 @@ class Orientations:
158
239
  References
159
240
  ----------
160
241
  .. [1] https://wiki.dynamo.biozentrum.unibas.ch/w/index.php/Table
161
-
162
- The file is saved with a standard header used in Dynamo STAR files.
163
- Each row in the file corresponds to an orientation.
164
242
  """
165
243
  with open(filename, mode="w", encoding="utf-8") as ofile:
166
244
  for index, (translation, rotation, score, detail) in enumerate(self):
@@ -316,8 +394,18 @@ class Orientations:
316
394
  filename : str
317
395
  The name of the file from which to read the orientations.
318
396
  file_format : type, optional
319
- The format of the file. Currently, only 'text' format is supported.
320
- **kwargs : dict
397
+ The format of the file. Defaults to None and infers
398
+ the file_format from the typical extension. Supported formats are
399
+
400
+ +---------------+----------------------------------------------------+
401
+ | text | pyTME's standard tab-separated orientations file |
402
+ +---------------+----------------------------------------------------+
403
+ | relion | Creates a STAR file of orientations |
404
+ +---------------+----------------------------------------------------+
405
+ | dynamo | Creates a dynamo table |
406
+ +---------------+----------------------------------------------------+
407
+
408
+ **kwargs
321
409
  Additional keyword arguments specific to the file format.
322
410
 
323
411
  Returns
@@ -330,11 +418,18 @@ class Orientations:
330
418
  ValueError
331
419
  If an unsupported file format is specified.
332
420
  """
333
- mapping = {"text": cls._from_text, "relion": cls._from_relion_star}
421
+ mapping = {
422
+ "text": cls._from_text,
423
+ "relion": cls._from_relion_star,
424
+ "tbl": cls._from_tbl,
425
+ }
334
426
  if file_format is None:
335
427
  file_format = "text"
428
+
336
429
  if filename.lower().endswith(".star"):
337
430
  file_format = "relion"
431
+ elif filename.lower().endswith(".tbl"):
432
+ file_format = "tbl"
338
433
 
339
434
  func = mapping.get(file_format, None)
340
435
  if func is None:
@@ -375,25 +470,28 @@ class Orientations:
375
470
  """
376
471
  with open(filename, mode="r", encoding="utf-8") as infile:
377
472
  data = [x.strip().split("\t") for x in infile.read().split("\n")]
378
- _ = data.pop(0)
379
473
 
474
+ header = data.pop(0)
380
475
  translation, rotation, score, detail = [], [], [], []
381
476
  for candidate in data:
382
477
  if len(candidate) <= 1:
383
478
  continue
384
- if len(candidate) != 8:
385
- candidate.append(-1)
386
479
 
387
- candidate = [float(x) for x in candidate]
388
- translation.append((candidate[0], candidate[1], candidate[2]))
389
- rotation.append((candidate[3], candidate[4], candidate[5]))
390
- score.append(candidate[6])
391
- detail.append(candidate[7])
480
+ translation.append(
481
+ tuple(
482
+ candidate[i] for i, x in enumerate(header) if x in ascii_lowercase
483
+ )
484
+ )
485
+ rotation.append(
486
+ tuple(candidate[i] for i, x in enumerate(header) if "euler" in x)
487
+ )
488
+ score.append(candidate[-2])
489
+ detail.append(candidate[-1])
392
490
 
393
- translation = np.vstack(translation).astype(int)
394
- rotation = np.vstack(rotation).astype(float)
395
- score = np.array(score).astype(float)
396
- detail = np.array(detail).astype(float)
491
+ translation = np.vstack(translation)
492
+ rotation = np.vstack(rotation)
493
+ score = np.array(score)
494
+ detail = np.array(detail)
397
495
 
398
496
  return translation, rotation, score, detail
399
497
 
@@ -448,20 +546,15 @@ class Orientations:
448
546
  ret = cls._parse_star(filename=filename, delimiter=delimiter)
449
547
  ret = ret["data_particles"]
450
548
 
451
- translation = (
452
- np.vstack(
453
- (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
454
- )
455
- .astype(np.float32)
456
- .astype(int)
457
- .T
549
+ translation = np.vstack(
550
+ (ret["_rlnCoordinateZ"], ret["_rlnCoordinateY"], ret["_rlnCoordinateX"])
458
551
  )
552
+ translation = translation.astype(np.float32).T
459
553
 
460
- rotation = (
461
- np.vstack((ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"]))
462
- .astype(np.float32)
463
- .T
554
+ rotation = np.vstack(
555
+ (ret["_rlnAngleRot"], ret["_rlnAngleTilt"], ret["_rlnAnglePsi"])
464
556
  )
557
+ rotation = rotation.astype(np.float32).T
465
558
 
466
559
  rotation = Rotation.from_euler("xyx", rotation, degrees=True)
467
560
  rotation = rotation.as_euler(seq="zyx", degrees=True)
@@ -470,6 +563,33 @@ class Orientations:
470
563
 
471
564
  return translation, rotation, score, detail
472
565
 
566
+ @staticmethod
567
+ def _from_tbl(
568
+ filename: str, **kwargs
569
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
570
+ with open(filename, mode="r", encoding="utf-8") as infile:
571
+ data = infile.read().split("\n")
572
+ data = [x.strip().split(" ") for x in data if len(x.strip())]
573
+
574
+ if len(data[0]) != 38:
575
+ raise ValueError(
576
+ "Expected tbl file to have 38 columns generated by _to_tbl."
577
+ )
578
+
579
+ translations, rotations, scores, details = [], [], [], []
580
+ for peak in data:
581
+ rotation = Rotation.from_euler(
582
+ "xyx", (peak[6], peak[7], peak[8]), degrees=True
583
+ )
584
+ rotations.append(rotation.as_euler(seq="zyx", degrees=True))
585
+ scores.append(peak[9])
586
+ details.append(-1)
587
+ translations.append((peak[25], peak[24], peak[23]))
588
+
589
+ translations, rotations = np.array(translations), np.array(rotations)
590
+ scores, details = np.array(scores), np.array(details)
591
+ return translations, rotations, scores, details
592
+
473
593
  def get_extraction_slices(
474
594
  self,
475
595
  target_shape: Tuple[int],
@@ -504,55 +624,55 @@ class Orientations:
504
624
  SystemExit
505
625
  If no peak remains after filtering, indicating an error.
506
626
  """
507
- left_pad = np.divide(extraction_shape, 2).astype(int)
508
- right_pad = np.add(left_pad, np.mod(extraction_shape, 2)).astype(int)
627
+ right_pad = np.divide(extraction_shape, 2).astype(int)
628
+ left_pad = np.add(right_pad, np.mod(extraction_shape, 2)).astype(int)
629
+
630
+ peaks = self.translations.astype(int)
631
+ obs_beg = np.subtract(peaks, left_pad)
632
+ obs_end = np.add(peaks, right_pad)
509
633
 
510
- obs_start = np.subtract(self.translations, left_pad)
511
- obs_stop = np.add(self.translations, right_pad)
634
+ obs_beg = np.maximum(obs_beg, 0)
635
+ obs_end = np.minimum(obs_end, target_shape)
512
636
 
513
- cand_start = np.subtract(np.maximum(obs_start, 0), obs_start)
514
- cand_stop = np.subtract(obs_stop, np.minimum(obs_stop, target_shape))
515
- cand_stop = np.subtract(extraction_shape, cand_stop)
516
- obs_start = np.maximum(obs_start, 0)
517
- obs_stop = np.minimum(obs_stop, target_shape)
637
+ cand_beg = left_pad - np.subtract(peaks, obs_beg)
638
+ cand_end = left_pad + np.subtract(obs_end, peaks)
518
639
 
519
640
  subset = self
520
641
  if drop_out_of_box:
521
- stops = np.subtract(cand_stop, extraction_shape)
642
+ stops = np.subtract(cand_end, extraction_shape)
522
643
  keep_peaks = (
523
644
  np.sum(
524
- np.multiply(cand_start == 0, stops == 0),
645
+ np.multiply(cand_beg == 0, stops == 0),
525
646
  axis=1,
526
647
  )
527
- == self.translations.shape[1]
648
+ == peaks.shape[1]
528
649
  )
529
650
  n_remaining = keep_peaks.sum()
530
651
  if n_remaining == 0:
531
652
  print(
532
653
  "No peak remaining after filtering. Started with"
533
- f" {self.translations.shape[0]} filtered to {n_remaining}."
654
+ f" {peaks.shape[0]} filtered to {n_remaining}."
534
655
  " Consider reducing min_distance, increase num_peaks or use"
535
656
  " a different peak caller."
536
657
  )
537
- exit(-1)
538
658
 
539
- cand_start = cand_start[keep_peaks,]
540
- cand_stop = cand_stop[keep_peaks,]
541
- obs_start = obs_start[keep_peaks,]
542
- obs_stop = obs_stop[keep_peaks,]
659
+ cand_beg = cand_beg[keep_peaks,]
660
+ cand_end = cand_end[keep_peaks,]
661
+ obs_beg = obs_beg[keep_peaks,]
662
+ obs_end = obs_end[keep_peaks,]
543
663
  subset = self[keep_peaks]
544
664
 
545
- cand_start, cand_stop = cand_start.astype(int), cand_stop.astype(int)
546
- obs_start, obs_stop = obs_start.astype(int), obs_stop.astype(int)
665
+ cand_beg, cand_end = cand_beg.astype(int), cand_end.astype(int)
666
+ obs_beg, obs_end = obs_beg.astype(int), obs_end.astype(int)
547
667
 
548
668
  candidate_slices = [
549
669
  tuple(slice(s, e) for s, e in zip(start_row, stop_row))
550
- for start_row, stop_row in zip(cand_start, cand_stop)
670
+ for start_row, stop_row in zip(cand_beg, cand_end)
551
671
  ]
552
672
 
553
673
  observation_slices = [
554
674
  tuple(slice(s, e) for s, e in zip(start_row, stop_row))
555
- for start_row, stop_row in zip(obs_start, obs_stop)
675
+ for start_row, stop_row in zip(obs_beg, obs_end)
556
676
  ]
557
677
 
558
678
  if return_orientations:
tme/preprocessing/_utils.py CHANGED
@@ -133,11 +133,11 @@ def fftfreqn(
133
133
  NDArray
134
134
  The sample frequencies.
135
135
  """
136
- center = backend.astype(backend.divide(shape, 2), backend._default_dtype_int)
136
+ center = backend.astype(backend.divide(shape, 2), backend._int_dtype)
137
137
 
138
- norm = np.ones(3)
138
+ norm = np.ones(len(shape))
139
139
  if sampling_rate is not None:
140
- norm = backend.multiply(shape, sampling_rate).astype(int)
140
+ norm = backend.astype(backend.multiply(shape, sampling_rate), int)
141
141
 
142
142
  if shape_is_real_fourier:
143
143
  center[-1] = 0
@@ -151,9 +151,9 @@ def fftfreqn(
151
151
  indices = backend.transpose(indices)
152
152
 
153
153
  if compute_euclidean_norm:
154
- backend.square(indices, indices)
154
+ indices = backend.square(indices)
155
155
  indices = backend.sum(indices, axis=0)
156
- indices = backend.sqrt(indices)
156
+ backend.sqrt(indices, out=indices)
157
157
 
158
158
  return indices
159
159
 
@@ -174,3 +174,15 @@ def crop_real_fourier(data: NDArray) -> NDArray:
174
174
  """
175
175
  stop = 1 + (data.shape[-1] // 2)
176
176
  return data[..., :stop]
177
+
178
+
179
+ def shift_fourier(data: NDArray, shape_is_real_fourier: bool = False):
180
+ shift = backend.add(
181
+ backend.astype(backend.divide(data.shape, 2), int),
182
+ backend.mod(data.shape, 2),
183
+ )
184
+ if shape_is_real_fourier:
185
+ shift[-1] = 0
186
+
187
+ data = backend.roll(data, shift, tuple(i for i in range(len(shift))))
188
+ return data
tme/preprocessing/composable_filter.py CHANGED
@@ -8,6 +8,7 @@
8
8
  from typing import Dict
9
9
  from abc import ABC, abstractmethod
10
10
 
11
+
11
12
  class ComposableFilter(ABC):
12
13
  """
13
14
  Strategy class for composable filters.
@@ -27,4 +28,4 @@ class ComposableFilter(ABC):
27
28
  --------
28
29
  Dict
29
30
  A dictionary representing the result of the filtering operation.
30
- """
31
+ """
tme/preprocessing/compose.py CHANGED
@@ -39,12 +39,11 @@ class Compose:
39
39
 
40
40
  meta = self.transforms[0](**kwargs)
41
41
  for transform in self.transforms[1:]:
42
-
43
42
  kwargs.update(meta)
44
43
  ret = transform(**kwargs)
45
44
 
46
45
  if ret.get("is_multiplicative_filter", False):
47
- backend.multiply(ret["data"], meta["data"], ret["data"])
46
+ backend.multiply(ret["data"], meta["data"], out=ret["data"])
48
47
  ret["merge"] = None
49
48
 
50
49
  meta = ret