pytme 0.2.2__cp311-cp311-macosx_14_0_arm64.whl → 0.2.3__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/match_template.py +91 -142
  2. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/postprocess.py +20 -29
  3. pytme-0.2.3.data/scripts/preprocess.py +132 -0
  4. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/preprocessor_gui.py +6 -9
  5. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/METADATA +11 -10
  6. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/RECORD +33 -32
  7. pytme-0.2.2.data/scripts/preprocess.py → scripts/eval.py +1 -1
  8. scripts/match_template.py +91 -142
  9. scripts/postprocess.py +20 -29
  10. scripts/preprocess.py +95 -56
  11. scripts/preprocessor_gui.py +6 -9
  12. tme/__version__.py +1 -1
  13. tme/analyzer.py +9 -6
  14. tme/backends/__init__.py +1 -1
  15. tme/backends/_jax_utils.py +10 -8
  16. tme/backends/cupy_backend.py +2 -7
  17. tme/backends/jax_backend.py +34 -20
  18. tme/backends/npfftw_backend.py +3 -2
  19. tme/backends/pytorch_backend.py +10 -7
  20. tme/density.py +15 -8
  21. tme/extensions.cpython-311-darwin.so +0 -0
  22. tme/matching_data.py +24 -17
  23. tme/matching_exhaustive.py +36 -19
  24. tme/matching_scores.py +5 -2
  25. tme/matching_utils.py +7 -2
  26. tme/orientations.py +26 -9
  27. tme/preprocessing/composable_filter.py +7 -4
  28. tme/preprocessing/tilt_series.py +10 -32
  29. {pytme-0.2.2.data → pytme-0.2.3.data}/scripts/estimate_ram_usage.py +0 -0
  30. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/LICENSE +0 -0
  31. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/WHEEL +0 -0
  32. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/entry_points.txt +0 -0
  33. {pytme-0.2.2.dist-info → pytme-0.2.3.dist-info}/top_level.txt +0 -0
scripts/preprocess.py CHANGED
@@ -1,93 +1,132 @@
1
1
  #!python3
2
- """ Apply tme.preprocessor.Preprocessor methods to an input file based
3
- on a provided yaml configuration obtaiend from preprocessor_gui.py.
2
+ """ Preprocessing routines for template matching.
4
3
 
5
4
  Copyright (c) 2023 European Molecular Biology Laboratory
6
5
 
7
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
8
7
  """
9
- import yaml
8
+ import warnings
10
9
  import argparse
11
- import textwrap
12
- from tme import Preprocessor, Density
10
+ import numpy as np
11
+
12
+ from tme import Density, Structure
13
+ from tme.backends import backend as be
14
+ from tme.preprocessing.frequency_filters import BandPassFilter
13
15
 
14
16
 
15
17
  def parse_args():
16
18
  parser = argparse.ArgumentParser(
17
- description=textwrap.dedent(
18
- """
19
- Apply preprocessing to an input file based on a provided YAML configuration.
20
-
21
- Expected YAML file format:
22
- ```yaml
23
- <method_name>:
24
- <parameter1>: <value1>
25
- <parameter2>: <value2>
26
- ...
27
- ```
28
- """
29
- ),
30
- formatter_class=argparse.RawDescriptionHelpFormatter,
19
+ description="Perform template matching preprocessing.",
20
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
31
21
  )
32
- parser.add_argument(
33
- "-i",
34
- "--input_file",
22
+
23
+ io_group = parser.add_argument_group("Input / Output")
24
+ io_group.add_argument(
25
+ "-m",
26
+ "--data",
27
+ dest="data",
35
28
  type=str,
36
29
  required=True,
37
- help="Path to the input data file in CCP4/MRC format.",
30
+ help="Path to a file in PDB/MMCIF, CCP4/MRC, EM, H5 or a format supported by "
31
+ "tme.density.Density.from_file "
32
+ "https://kosinskilab.github.io/pyTME/reference/api/tme.density.Density.from_file.html",
38
33
  )
39
- parser.add_argument(
40
- "-y",
41
- "--yaml_file",
34
+ io_group.add_argument(
35
+ "-o",
36
+ "--output",
37
+ dest="output",
42
38
  type=str,
43
39
  required=True,
44
- help="Path to the YAML configuration file.",
40
+ help="Path the output should be written to.",
45
41
  )
46
- parser.add_argument(
47
- "-o",
48
- "--output_file",
49
- type=str,
42
+
43
+ box_group = parser.add_argument_group("Box")
44
+ box_group.add_argument(
45
+ "--box_size",
46
+ dest="box_size",
47
+ type=int,
50
48
  required=True,
51
- help="Path to output file in CPP4/MRC format..",
49
+ help="Box size of the output",
52
50
  )
53
- parser.add_argument(
54
- "--compress", action="store_true", help="Compress the output file using gzip."
51
+ box_group.add_argument(
52
+ "--sampling_rate",
53
+ dest="sampling_rate",
54
+ type=float,
55
+ required=True,
56
+ help="Sampling rate of the output file.",
55
57
  )
56
58
 
59
+ modulation_group = parser.add_argument_group("Modulation")
60
+ modulation_group.add_argument(
61
+ "--invert_contrast",
62
+ dest="invert_contrast",
63
+ action="store_true",
64
+ required=False,
65
+ help="Inverts the template contrast.",
66
+ )
67
+ modulation_group.add_argument(
68
+ "--lowpass",
69
+ dest="lowpass",
70
+ type=float,
71
+ required=False,
72
+ default=None,
73
+ help="Lowpass filter the template to the given resolution. Nyquist by default. "
74
+ "A value of 0 disables the filter.",
75
+ )
76
+ modulation_group.add_argument(
77
+ "--no_centering",
78
+ dest="no_centering",
79
+ action="store_true",
80
+ help="Assumes the template is already centered and omits centering.",
81
+ )
57
82
  args = parser.parse_args()
58
-
59
83
  return args
60
84
 
61
85
 
62
86
  def main():
63
87
  args = parse_args()
64
- with open(args.yaml_file, "r") as f:
65
- preprocess_settings = yaml.safe_load(f)
66
88
 
67
- if len(preprocess_settings) > 1:
68
- raise NotImplementedError(
69
- "Multiple preprocessing methods specified. "
70
- "The script currently supports one method at a time."
89
+ try:
90
+ data = Structure.from_file(args.data)
91
+ data = Density.from_structure(data, sampling_rate=args.sampling_rate)
92
+ except NotImplementedError:
93
+ data = Density.from_file(args.data)
94
+
95
+ if not args.no_centering:
96
+ data, _ = data.centered(0)
97
+
98
+ recommended_box = be.compute_convolution_shapes([args.box_size], [1])[1][0]
99
+ if recommended_box != args.box_size:
100
+ warnings.warn(
101
+ f"Consider using --box_size {recommended_box} instead of {args.box_size}."
71
102
  )
72
103
 
73
- method_name = list(preprocess_settings.keys())[0]
74
- if not hasattr(Preprocessor, method_name):
75
- raise ValueError(f"Method {method_name} does not exist in Preprocessor.")
104
+ data.pad(
105
+ np.multiply(args.box_size, np.divide(args.sampling_rate, data.sampling_rate)),
106
+ center=True,
107
+ )
108
+
109
+ bpf_mask = 1
110
+ lowpass = 2 * args.sampling_rate if args.lowpass is None else args.lowpass
111
+ if args.lowpass != 0:
112
+ bpf_mask = BandPassFilter(
113
+ lowpass=lowpass,
114
+ highpass=None,
115
+ use_gaussian=True,
116
+ return_real_fourier=True,
117
+ shape_is_real_fourier=False,
118
+ )(shape=data.shape)["data"]
76
119
 
77
- density = Density.from_file(args.input_file)
78
- output = density.empty
120
+ data_ft = np.fft.rfftn(data.data, s=data.shape)
121
+ data_ft = np.multiply(data_ft, bpf_mask, out=data_ft)
122
+ data.data = np.fft.irfftn(data_ft, s=data.shape).real
79
123
 
80
- method_params = preprocess_settings[method_name]
81
- preprocessor = Preprocessor()
82
- method = getattr(preprocessor, method_name, None)
83
- if not method:
84
- raise ValueError(
85
- f"{method} does not exist in dge.preprocessor.Preprocessor class."
86
- )
124
+ data = data.resample(args.sampling_rate, method="spline", order=3)
87
125
 
88
- output.data = method(template=density.data, **method_params)
89
- output.to_file(args.output_file, gzip=args.compress)
126
+ if args.invert_contrast:
127
+ data.data = data.data * -1
90
128
 
129
+ data.to_file(args.output)
91
130
 
92
131
  if __name__ == "__main__":
93
- main()
132
+ main()
@@ -132,14 +132,6 @@ def local_gaussian_filter(
132
132
  )
133
133
 
134
134
 
135
- def ntree(
136
- template: NDArray,
137
- sigma_range: Tuple[float, float],
138
- **kwargs: dict,
139
- ) -> NDArray:
140
- return preprocessor.ntree_filter(template=template, sigma_range=sigma_range)
141
-
142
-
143
135
  def mean(
144
136
  template: NDArray,
145
137
  width: int,
@@ -197,6 +189,10 @@ def compute_power_spectrum(template: NDArray) -> NDArray:
197
189
  return np.fft.fftshift(np.log(np.abs(np.fft.fftn(template))))
198
190
 
199
191
 
192
+ def invert_contrast(template: NDArray) -> NDArray:
193
+ return template * -1
194
+
195
+
200
196
  def widgets_from_function(function: Callable, exclude_params: List = ["self"]):
201
197
  """
202
198
  Creates list of magicui widgets by inspecting function typing ann
@@ -252,13 +248,13 @@ WRAPPED_FUNCTIONS = {
252
248
  "gaussian_filter": gaussian_filter,
253
249
  "bandpass_filter": bandpass_filter,
254
250
  "edge_gaussian_filter": edge_gaussian_filter,
255
- "ntree_filter": ntree,
256
251
  "local_gaussian_filter": local_gaussian_filter,
257
252
  "difference_of_gaussian_filter": difference_of_gaussian_filter,
258
253
  "mean_filter": mean,
259
254
  "wedge_filter": wedge,
260
255
  "power_spectrum": compute_power_spectrum,
261
256
  "ctf": ctf_filter,
257
+ "invert_contrast": invert_contrast,
262
258
  }
263
259
 
264
260
  EXCLUDED_FUNCTIONS = [
@@ -634,6 +630,7 @@ class MaskWidget(widgets.Container):
634
630
 
635
631
  data = active_layer.data.copy()
636
632
  cutoff = np.quantile(data, self.percentile_range_edit.value / 100)
633
+ cutoff = max(cutoff, np.finfo(np.float32).resolution)
637
634
  data[data < cutoff] = 0
638
635
 
639
636
  center_of_mass = Density.center_of_mass(np.abs(data), 0)
tme/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "0.2.2"
1
+ __version__ = "0.2.3"
tme/analyzer.py CHANGED
@@ -459,14 +459,14 @@ class PeakCaller(ABC):
459
459
 
460
460
  final_order = top_scores[
461
461
  filter_points_indices(
462
- coordinates=peak_positions[top_scores],
462
+ coordinates=peak_positions[top_scores, :],
463
463
  min_distance=self.min_distance,
464
464
  batch_dims=self.batch_dims,
465
465
  )
466
466
  ]
467
467
 
468
- self.peak_list[0] = peak_positions[final_order,]
469
- self.peak_list[1] = rotations[final_order,]
468
+ self.peak_list[0] = peak_positions[final_order, :]
469
+ self.peak_list[1] = rotations[final_order, :]
470
470
  self.peak_list[2] = peak_scores[final_order]
471
471
  self.peak_list[3] = peak_details[final_order]
472
472
 
@@ -475,6 +475,7 @@ class PeakCaller(ABC):
475
475
  fast_shape: Tuple[int],
476
476
  targetshape: Tuple[int],
477
477
  templateshape: Tuple[int],
478
+ convolution_shape: Tuple[int] = None,
478
479
  fourier_shift: Tuple[int] = None,
479
480
  convolution_mode: str = None,
480
481
  shared_memory_handler=None,
@@ -488,6 +489,7 @@ class PeakCaller(ABC):
488
489
  return self
489
490
 
490
491
  # Wrap peaks around score space
492
+ convolution_shape = be.to_backend_array(convolution_shape)
491
493
  fast_shape = be.to_backend_array(fast_shape)
492
494
  if fourier_shift is not None:
493
495
  fourier_shift = be.to_backend_array(fourier_shift)
@@ -501,10 +503,9 @@ class PeakCaller(ABC):
501
503
  )
502
504
 
503
505
  # Remove padding to fast Fourier (and potential full convolution) shape
506
+ output_shape = convolution_shape
504
507
  targetshape = be.to_backend_array(targetshape)
505
508
  templateshape = be.to_backend_array(templateshape)
506
- fast_shape = be.minimum(be.add(targetshape, templateshape) - 1, fast_shape)
507
- output_shape = fast_shape
508
509
  if convolution_mode == "same":
509
510
  output_shape = targetshape
510
511
  elif convolution_mode == "valid":
@@ -515,7 +516,7 @@ class PeakCaller(ABC):
515
516
 
516
517
  output_shape = be.to_backend_array(output_shape)
517
518
  starts = be.astype(
518
- be.divide(be.subtract(fast_shape, output_shape), 2),
519
+ be.divide(be.subtract(convolution_shape, output_shape), 2),
519
520
  be._int_dtype,
520
521
  )
521
522
  stops = be.add(starts, output_shape)
@@ -1019,6 +1020,7 @@ class MaxScoreOverRotations:
1019
1020
  self,
1020
1021
  targetshape: Tuple[int],
1021
1022
  templateshape: Tuple[int],
1023
+ convolution_shape: Tuple[int],
1022
1024
  fourier_shift: Tuple[int] = None,
1023
1025
  convolution_mode: str = None,
1024
1026
  shared_memory_handler=None,
@@ -1039,6 +1041,7 @@ class MaxScoreOverRotations:
1039
1041
  "s1": targetshape,
1040
1042
  "s2": templateshape,
1041
1043
  "convolution_mode": convolution_mode,
1044
+ "convolution_shape": convolution_shape,
1042
1045
  }
1043
1046
  if convolution_mode is not None:
1044
1047
  scores = apply_convolution_mode(scores, **convargs)
tme/backends/__init__.py CHANGED
@@ -147,7 +147,7 @@ class BackendManager:
147
147
  _dependencies = {
148
148
  "numpyfftw": "numpy",
149
149
  "cupy": "cupy",
150
- "pytorch": "pytorch",
150
+ "pytorch": "torch",
151
151
  "mlx": "mlx",
152
152
  "jax": "jax",
153
153
  }
@@ -19,9 +19,9 @@ def _correlate(template: BackendArray, ft_target: BackendArray) -> BackendArray:
19
19
  """
20
20
  Computes :py:meth:`tme.matching_exhaustive.cc_setup`.
21
21
  """
22
- template_ft = jnp.fft.rfftn(template)
22
+ template_ft = jnp.fft.rfftn(template, s=template.shape)
23
23
  template_ft = template_ft.at[:].multiply(ft_target)
24
- correlation = jnp.fft.irfftn(template_ft)
24
+ correlation = jnp.fft.irfftn(template_ft, s=template.shape)
25
25
  return correlation
26
26
 
27
27
 
@@ -77,14 +77,15 @@ def _reciprocal_target_std(
77
77
  --------
78
78
  :py:meth:`tme.matching_exhaustive.flc_scoring`.
79
79
  """
80
- ft_template_mask = jnp.fft.rfftn(template_mask)
80
+ ft_shape = template_mask.shape
81
+ ft_template_mask = jnp.fft.rfftn(template_mask, s=ft_shape)
81
82
 
82
83
  # E(X^2)- E(X)^2
83
- exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask)
84
+ exp_sq = jnp.fft.irfftn(ft_target2 * ft_template_mask, s=ft_shape)
84
85
  exp_sq = exp_sq.at[:].divide(n_observations)
85
86
 
86
87
  ft_template_mask = ft_template_mask.at[:].multiply(ft_target)
87
- sq_exp = jnp.fft.irfftn(ft_template_mask)
88
+ sq_exp = jnp.fft.irfftn(ft_template_mask, s=ft_shape)
88
89
  sq_exp = sq_exp.at[:].divide(n_observations)
89
90
  sq_exp = sq_exp.at[:].power(2)
90
91
 
@@ -99,7 +100,7 @@ def _reciprocal_target_std(
99
100
 
100
101
 
101
102
  def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
102
- arr_ft = jnp.fft.rfftn(arr)
103
+ arr_ft = jnp.fft.rfftn(arr, s=arr.shape)
103
104
  arr_ft = arr_ft.at[:].multiply(arr_filter)
104
105
  return arr.at[:].set(jnp.fft.irfftn(arr_ft, s=arr.shape))
105
106
 
@@ -107,6 +108,7 @@ def _apply_fourier_filter(arr: BackendArray, arr_filter: BackendArray) -> Backen
107
108
  def _identity(arr: BackendArray, arr_filter: BackendArray) -> BackendArray:
108
109
  return arr
109
110
 
111
+
110
112
  @partial(
111
113
  pmap,
112
114
  in_axes=(0,) + (None,) * 6,
@@ -127,8 +129,8 @@ def scan(
127
129
  if hasattr(target_filter, "shape"):
128
130
  target = _apply_fourier_filter(target, target_filter)
129
131
 
130
- ft_target = jnp.fft.rfftn(target)
131
- ft_target2 = jnp.fft.rfftn(jnp.square(target))
132
+ ft_target = jnp.fft.rfftn(target, s=fast_shape)
133
+ ft_target2 = jnp.fft.rfftn(jnp.square(target), s=fast_shape)
132
134
  inv_denominator, target, scoring_func = None, None, _flc_scoring
133
135
  if not rotate_mask:
134
136
  n_observations = jnp.sum(template_mask)
@@ -149,10 +149,10 @@ class CupyBackend(NumpyFFTWBackend):
149
149
  cache.clear()
150
150
 
151
151
  def rfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
152
- return cufft.rfftn(arr)
152
+ return cufft.rfftn(arr, s=fast_shape)
153
153
 
154
154
  def irfftn(arr: CupyArray, out: CupyArray) -> CupyArray:
155
- return cufft.irfftn(arr)
155
+ return cufft.irfftn(arr, s=fast_shape)
156
156
 
157
157
  PLAN_CACHE[current_device] = [fast_shape, fast_ft_shape]
158
158
 
@@ -167,11 +167,6 @@ class CupyBackend(NumpyFFTWBackend):
167
167
  fast_shape = [next_fast_len(x, real=True) for x in convolution_shape]
168
168
  fast_ft_shape = list(fast_shape[:-1]) + [fast_shape[-1] // 2 + 1]
169
169
 
170
- # This almost never happens but avoid cuFFT casting errors
171
- is_odd = fast_shape[-1] % 2
172
- fast_shape[-1] += is_odd
173
- fast_ft_shape[-1] += is_odd
174
-
175
170
  return convolution_shape, fast_shape, fast_ft_shape
176
171
 
177
172
  def max_filter_coordinates(self, score_space, min_distance: Tuple[int]):
@@ -119,19 +119,6 @@ class JaxBackend(NumpyFFTWBackend):
119
119
 
120
120
  return rfftn, irfftn
121
121
 
122
- def compute_convolution_shapes(
123
- self, arr1_shape: Tuple[int], arr2_shape: Tuple[int]
124
- ) -> Tuple[List[int], List[int], List[int]]:
125
- conv_shape, fast_shape, fast_ft_shape = super().compute_convolution_shapes(
126
- arr1_shape, arr2_shape
127
- )
128
-
129
- is_odd = fast_shape[-1] % 2
130
- fast_shape[-1] += is_odd
131
- fast_ft_shape[-1] += is_odd
132
-
133
- return conv_shape, fast_shape, fast_ft_shape
134
-
135
122
  def rigid_transform(
136
123
  self,
137
124
  arr: BackendArray,
@@ -144,8 +131,8 @@ class JaxBackend(NumpyFFTWBackend):
144
131
  **kwargs,
145
132
  ) -> Tuple[BackendArray, BackendArray]:
146
133
  rotate_mask = arr_mask is not None
147
- center = self.divide(self.to_backend_array(arr.shape), 2)[:, None]
148
134
 
135
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)[:, None]
149
136
  indices = self._array_backend.indices(arr.shape, dtype=self._float_dtype)
150
137
  indices = indices.reshape((arr.ndim, -1))
151
138
  indices = indices.at[:].add(-center)
@@ -200,7 +187,7 @@ class JaxBackend(NumpyFFTWBackend):
200
187
  target_shape = tuple(
201
188
  (x.stop - x.start + p) for x, p in zip(splits[0][0], target_pad)
202
189
  )
203
- fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
190
+ conv_shape, fast_shape, fast_ft_shape, shift = matching_data._fourier_padding(
204
191
  target_shape=self.to_numpy_array(target_shape),
205
192
  template_shape=self.to_numpy_array(matching_data._template.shape),
206
193
  pad_fourier=False,
@@ -210,13 +197,25 @@ class JaxBackend(NumpyFFTWBackend):
210
197
  "convolution_mode": convolution_mode,
211
198
  "fourier_shift": shift,
212
199
  "targetshape": target_shape,
213
- "templateshape": matching_data._template.shape,
200
+ "templateshape": matching_data.template.shape,
201
+ "convolution_shape": conv_shape,
214
202
  }
215
203
 
216
204
  create_target_filter = matching_data.target_filter is not None
217
205
  create_template_filter = matching_data.template_filter is not None
218
206
  create_filter = create_target_filter or create_template_filter
219
207
 
208
+ # Applying the filter leads to more FFTs
209
+ fastt_shape = matching_data._template.shape
210
+ if create_template_filter:
211
+ _, fastt_shape, _, tshift = matching_data._fourier_padding(
212
+ target_shape=self.to_numpy_array(matching_data._template.shape),
213
+ template_shape=self.to_numpy_array(
214
+ [1 for _ in matching_data._template.shape]
215
+ ),
216
+ pad_fourier=False,
217
+ )
218
+
220
219
  ret, template_filter, target_filter = [], 1, 1
221
220
  rotation_mapping = {
222
221
  self.tobytes(matching_data.rotations[i]): i
@@ -246,12 +245,12 @@ class JaxBackend(NumpyFFTWBackend):
246
245
 
247
246
  if create_template_filter:
248
247
  template_filter = matching_data.template_filter(
249
- shape=matching_data._template.shape, **filter_args
248
+ shape=fastt_shape, **filter_args
250
249
  )["data"]
251
250
  template_filter = template_filter.at[(0,) * template_filter.ndim].set(0)
252
251
 
253
252
  if create_target_filter:
254
- target_filter = matching_data.template_filter(
253
+ target_filter = matching_data.target_filter(
255
254
  shape=fast_shape, **filter_args
256
255
  )["data"]
257
256
  target_filter = target_filter.at[(0,) * target_filter.ndim].set(0)
@@ -260,8 +259,8 @@ class JaxBackend(NumpyFFTWBackend):
260
259
  base, targets = None, self._array_backend.stack(targets)
261
260
  scores, rotations = scan_inner(
262
261
  targets,
263
- matching_data.template,
264
- matching_data.template_mask,
262
+ self.topleft_pad(matching_data.template, fastt_shape),
263
+ self.topleft_pad(matching_data.template_mask, fastt_shape),
265
264
  matching_data.rotations,
266
265
  template_filter,
267
266
  target_filter,
@@ -280,3 +279,18 @@ class JaxBackend(NumpyFFTWBackend):
280
279
  ret.append(tuple(temp._postprocess(**analyzer_args)))
281
280
 
282
281
  return ret
282
+
283
+ def get_available_memory(self) -> int:
284
+ import jax
285
+
286
+ _memory = {"cpu": 0, "gpu": 0}
287
+ for device in jax.devices():
288
+ if device.platform == "cpu":
289
+ _memory["cpu"] = super().get_available_memory()
290
+ else:
291
+ mem_stats = device.memory_stats()
292
+ _memory["gpu"] += mem_stats.get("bytes_limit", 0)
293
+
294
+ if _memory["gpu"] > 0:
295
+ return _memory["gpu"]
296
+ return _memory["cpu"]
@@ -186,7 +186,7 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
186
186
  def to_sharedarr(
187
187
  self, arr: NDArray, shared_memory_handler: type = None
188
188
  ) -> shm_type:
189
- if type(shared_memory_handler) == SharedMemoryManager:
189
+ if isinstance(shared_memory_handler, SharedMemoryManager):
190
190
  shm = shared_memory_handler.SharedMemory(size=arr.nbytes)
191
191
  else:
192
192
  shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
@@ -347,7 +347,8 @@ class NumpyFFTWBackend(_NumpyWrapper, MatchingBackend):
347
347
  cache: bool = False,
348
348
  ) -> Tuple[NDArray, NDArray]:
349
349
  translation = self.zeros(arr.ndim) if translation is None else translation
350
- center = self.divide(self.to_backend_array(arr.shape), 2)
350
+
351
+ center = self.divide(self.to_backend_array(arr.shape) - 1, 2)
351
352
  if not use_geometric_center:
352
353
  center = self.center_of_mass(arr, cutoff=0)
353
354
 
@@ -81,13 +81,13 @@ class PytorchBackend(NumpyFFTWBackend):
81
81
 
82
82
  def max(self, *args, **kwargs) -> NDArray:
83
83
  ret = self._array_backend.amax(*args, **kwargs)
84
- if type(ret) == self._array_backend.Tensor:
84
+ if isinstance(ret, self._array_backend.Tensor):
85
85
  return ret
86
86
  return ret[0]
87
87
 
88
88
  def min(self, *args, **kwargs) -> NDArray:
89
89
  ret = self._array_backend.amin(*args, **kwargs)
90
- if type(ret) == self._array_backend.Tensor:
90
+ if isinstance(ret, self._array_backend.Tensor):
91
91
  return ret
92
92
  return ret[0]
93
93
 
@@ -154,7 +154,7 @@ class PytorchBackend(NumpyFFTWBackend):
154
154
  1, -1
155
155
  )
156
156
  if unraveled_coords.size(0) == 1:
157
- return tuple(unraveled_coords[0, :].tolist())
157
+ return (unraveled_coords[0, :],)
158
158
 
159
159
  else:
160
160
  return tuple(unraveled_coords.T)
@@ -206,7 +206,9 @@ class PytorchBackend(NumpyFFTWBackend):
206
206
  else:
207
207
  raise NotImplementedError("Operation only implemented for 2 and 3D inputs.")
208
208
 
209
- pool = func(kernel_size=min_distance, return_indices=True)
209
+ pool = func(
210
+ kernel_size=min_distance, padding=min_distance // 2, return_indices=True
211
+ )
210
212
  _, indices = pool(score_space.reshape(1, 1, *score_space.shape))
211
213
  coordinates = self.unravel_index(indices.reshape(-1), score_space.shape)
212
214
  coordinates = self.transpose(self.stack(coordinates))
@@ -217,7 +219,7 @@ class PytorchBackend(NumpyFFTWBackend):
217
219
 
218
220
  def from_sharedarr(self, args) -> TorchTensor:
219
221
  if self.device == "cuda":
220
- return args[0]
222
+ return args
221
223
 
222
224
  shm, shape, dtype = args
223
225
  required_size = int(self._array_backend.prod(self.to_backend_array(shape)))
@@ -235,13 +237,12 @@ class PytorchBackend(NumpyFFTWBackend):
235
237
 
236
238
  nbytes = arr.numel() * arr.element_size()
237
239
 
238
- if type(shared_memory_handler) == SharedMemoryManager:
240
+ if isinstance(shared_memory_handler, SharedMemoryManager):
239
241
  shm = shared_memory_handler.SharedMemory(size=nbytes)
240
242
  else:
241
243
  shm = shared_memory.SharedMemory(create=True, size=nbytes)
242
244
 
243
245
  shm.buf[:nbytes] = arr.numpy().tobytes()
244
-
245
246
  return shm, arr.shape, arr.dtype
246
247
 
247
248
  def transpose(self, arr):
@@ -415,6 +416,8 @@ class PytorchBackend(NumpyFFTWBackend):
415
416
  yield None
416
417
 
417
418
  def device_count(self) -> int:
419
+ if self.device == "cpu":
420
+ return 1
418
421
  return self._array_backend.cuda.device_count()
419
422
 
420
423
  def reverse(self, arr: TorchTensor) -> TorchTensor:
tme/density.py CHANGED
@@ -116,8 +116,8 @@ class Density:
116
116
  response = "Density object at {}\nOrigin: {}, Sampling Rate: {}, Shape: {}"
117
117
  return response.format(
118
118
  hex(id(self)),
119
- tuple(np.round(self.origin, 3)),
120
- tuple(np.round(self.sampling_rate, 3)),
119
+ tuple(round(float(x), 3) for x in self.origin),
120
+ tuple(round(float(x), 3) for x in self.sampling_rate),
121
121
  self.shape,
122
122
  )
123
123
 
@@ -306,6 +306,10 @@ class Density:
306
306
  "std": float(mrc.header.rms),
307
307
  }
308
308
 
309
+ non_standard_crs = not np.all(crs_index == (0, 1, 2))
310
+ if non_standard_crs:
311
+ warnings.warn("Non standard MAPC, MAPR, MAPS, adapting data and origin.")
312
+
309
313
  if is_gzipped(filename):
310
314
  if use_memmap:
311
315
  warnings.warn(
@@ -315,6 +319,10 @@ class Density:
315
319
  use_memmap = False
316
320
 
317
321
  if subset is not None:
322
+ subset = tuple(
323
+ subset[i] if i < len(subset) else slice(0, data_shape[i])
324
+ for i in crs_index
325
+ )
318
326
  subset_shape = tuple(x.stop - x.start for x in subset)
319
327
  if np.allclose(subset_shape, data_shape):
320
328
  return cls._load_mrc(
@@ -328,18 +336,16 @@ class Density:
328
336
  dtype=data_type,
329
337
  header_size=1024 + extended_header,
330
338
  )
331
- return data, origin, sampling_rate, metadata
332
-
333
- if not use_memmap:
339
+ elif subset is None and not use_memmap:
334
340
  with mrcfile.open(filename, header_only=False) as mrc:
335
341
  data = mrc.data.astype(np.float32, copy=False)
336
342
  else:
337
343
  with mrcfile.mrcmemmap.MrcMemmap(filename, header_only=False) as mrc:
338
344
  data = mrc.data
339
345
 
340
- if not np.all(crs_index == (0, 1, 2)):
346
+ if non_standard_crs:
341
347
  data = np.transpose(data, crs_index)
342
- start = np.take(start, crs_index)
348
+ origin = np.take(origin, crs_index)
343
349
 
344
350
  return data, origin, sampling_rate, metadata
345
351
 
@@ -873,6 +879,7 @@ class Density:
873
879
  mrc.header.nzstart, mrc.header.nystart, mrc.header.nxstart = np.rint(
874
880
  np.divide(self.origin, self.sampling_rate)
875
881
  )
882
+ mrc.header.origin = tuple(x for x in self.origin)
876
883
  # mrcfile library expects origin to be in xyz format
877
884
  mrc.header.mapc, mrc.header.mapr, mrc.header.maps = (1, 2, 3)
878
885
  mrc.header["origin"] = tuple(self.origin[::-1])
@@ -1594,7 +1601,7 @@ class Density:
1594
1601
  rotation_matrix: NDArray,
1595
1602
  translation: NDArray = None,
1596
1603
  order: int = 3,
1597
- use_geometric_center: bool = False,
1604
+ use_geometric_center: bool = True,
1598
1605
  ) -> "Density":
1599
1606
  """
1600
1607
  Performs a rigid transform of the class instance.
Binary file