pytme 0.1.6__cp311-cp311-macosx_14_0_arm64.whl → 0.1.8__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,269 @@
1
+ """ Backend Apple's MLX library for template matching.
2
+
3
+ Copyright (c) 2024 European Molecular Biology Laboratory
4
+
5
+ Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
6
+ """
7
+ from typing import Tuple, List, Callable
8
+
9
+ import numpy as np
10
+
11
+ from .npfftw_backend import NumpyFFTWBackend
12
+ from ..types import NDArray, MlxArray, Scalar
13
+
14
+
15
+ class MLXBackend(NumpyFFTWBackend):
16
+ """
17
+ A MLX based backend for template matching.
18
+ """
19
+
20
+ def __init__(
21
+ self,
22
+ device="cpu",
23
+ default_dtype=None,
24
+ complex_dtype=None,
25
+ default_dtype_int=None,
26
+ **kwargs,
27
+ ):
28
+ import mlx.core as mx
29
+
30
+ device = mx.cpu if device == "cpu" else mx.gpu
31
+ default_dtype = mx.float32 if default_dtype is None else default_dtype
32
+ complex_dtype = mx.complex64 if complex_dtype is None else complex_dtype
33
+ default_dtype_int = mx.int32 if default_dtype_int is None else default_dtype_int
34
+
35
+ super().__init__(
36
+ array_backend=mx,
37
+ default_dtype=default_dtype,
38
+ complex_dtype=complex_dtype,
39
+ default_dtype_int=default_dtype_int,
40
+ )
41
+
42
+ self.device = device
43
+
44
+ def to_backend_array(self, arr: NDArray) -> MlxArray:
45
+ return self._array_backend.array(arr)
46
+
47
+ def to_numpy_array(self, arr: MlxArray) -> NDArray:
48
+ return np.array(arr)
49
+
50
+ def to_cpu_array(self, arr: MlxArray) -> NDArray:
51
+ return arr
52
+
53
+ def free_cache(self):
54
+ pass
55
+
56
+ def mod(self, arr1: MlxArray, arr2: MlxArray, out: MlxArray = None) -> MlxArray:
57
+ if out is not None:
58
+ out[:] = arr1 % arr2
59
+ return None
60
+ return arr1 % arr2
61
+
62
+ def add(self, x1, x2, out: MlxArray = None, **kwargs) -> MlxArray:
63
+ x1 = self.to_backend_array(x1)
64
+ x2 = self.to_backend_array(x2)
65
+
66
+ if out is not None:
67
+ out[:] = self._array_backend.add(x1, x2, **kwargs)
68
+ return None
69
+ return self._array_backend.add(x1, x2, **kwargs)
70
+
71
+ def std(self, arr: MlxArray, axis) -> Scalar:
72
+ return self._array_backend.sqrt(arr.var(axis=axis))
73
+
74
+ def unique(self, *args, **kwargs):
75
+ ret = np.unique(*args, **kwargs)
76
+ if isinstance(ret, tuple):
77
+ ret = [self.to_backend_array(x) for x in ret]
78
+ return ret
79
+
80
+ def tobytes(self, arr):
81
+ return self.to_numpy_array(arr).tobytes()
82
+
83
+ def preallocate_array(self, shape: Tuple[int], dtype: type = None) -> NDArray:
84
+ """
85
+ Returns a byte-aligned array of zeros with specified shape and dtype.
86
+
87
+ Parameters
88
+ ----------
89
+ shape : Tuple[int]
90
+ Desired shape for the array.
91
+ dtype : type, optional
92
+ Desired data type for the array.
93
+
94
+ Returns
95
+ -------
96
+ NDArray
97
+ Byte-aligned array of zeros with specified shape and dtype.
98
+ """
99
+ arr = self._array_backend.zeros(shape, dtype=dtype)
100
+ return arr
101
+
102
+ def full(self, shape, fill_value, dtype=None):
103
+ return self._array_backend.full(shape=shape, dtype=dtype, vals=fill_value)
104
+
105
+ def fill(self, arr: MlxArray, value: Scalar) -> None:
106
+ arr[:] = value
107
+
108
+ def zeros(self, shape: Tuple[int], dtype: type = None) -> MlxArray:
109
+ return self._array_backend.zeros(shape=shape, dtype=dtype)
110
+
111
+ def roll(self, a: MlxArray, shift, axis, **kwargs):
112
+ a = self.to_numpy_array(a)
113
+ ret = NumpyFFTWBackend().roll(
114
+ a,
115
+ shift=shift,
116
+ axis=axis,
117
+ **kwargs,
118
+ )
119
+ return self.to_backend_array(ret)
120
+
121
+ def extract_center(self, arr: NDArray, newshape: Tuple[int]) -> NDArray:
122
+ """
123
+ Extract the centered portion of an array based on a new shape.
124
+
125
+ Parameters
126
+ ----------
127
+ arr : NDArray
128
+ Input array.
129
+ newshape : tuple
130
+ Desired shape for the central portion.
131
+
132
+ Returns
133
+ -------
134
+ NDArray
135
+ Central portion of the array with shape `newshape`.
136
+
137
+ References
138
+ ----------
139
+ .. [1] https://github.com/scipy/scipy/blob/v1.11.2/scipy/signal/_signaltools.py
140
+ """
141
+ new_shape = self.to_backend_array(newshape)
142
+ current_shape = self.to_backend_array(arr.shape)
143
+ starts = self.subtract(current_shape, new_shape)
144
+ starts = self.astype(self.divide(starts, 2), self._default_dtype_int)
145
+ stops = self.astype(self.add(starts, newshape), self._default_dtype_int)
146
+ starts, stops = starts.tolist(), stops.tolist()
147
+ box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
148
+ return arr[box]
149
+
150
+ def build_fft(
151
+ self, fast_shape: Tuple[int], fast_ft_shape: Tuple[int], **kwargs
152
+ ) -> Tuple[Callable, Callable]:
153
+ """
154
+ Build fft builder functions.
155
+
156
+ Parameters
157
+ ----------
158
+ fast_shape : tuple
159
+ Tuple of integers corresponding to fast convolution shape
160
+ (see `compute_convolution_shapes`).
161
+ fast_ft_shape : tuple
162
+ Tuple of integers corresponding to the shape of the fourier
163
+ transform array (see `compute_convolution_shapes`).
164
+ **kwargs : dict, optional
165
+ Additional parameters that are not used for now.
166
+
167
+ Returns
168
+ -------
169
+ tuple
170
+ Tupple containing callable rfft and irfft object.
171
+ """
172
+
173
+ # Runs on mlx.core.cpu until Metal support is available
174
+ def rfftn(arr: MlxArray, out: MlxArray, shape: Tuple[int] = fast_shape) -> None:
175
+ out[:] = self._array_backend.fft.rfftn(
176
+ arr, s=shape, stream=self._array_backend.cpu
177
+ )
178
+
179
+ def irfftn(
180
+ arr: MlxArray, out: MlxArray, shape: Tuple[int] = fast_shape
181
+ ) -> None:
182
+ out[:] = self._array_backend.fft.irfftn(
183
+ arr, s=shape, stream=self._array_backend.cpu
184
+ )
185
+
186
+ return rfftn, irfftn
187
+
188
+ def sharedarr_to_arr(
189
+ self, shape: Tuple[int], dtype: str, shm: MlxArray
190
+ ) -> MlxArray:
191
+ return shm
192
+
193
+ @staticmethod
194
+ def arr_to_sharedarr(arr: MlxArray, shared_memory_handler: type = None) -> MlxArray:
195
+ return arr
196
+
197
+ def topk_indices(self, arr: NDArray, k: int):
198
+ arr = self.to_numpy_array(arr)
199
+ ret = NumpyFFTWBackend().topk_indices(arr=arr, k=k)
200
+ ret = [self.to_backend_array(x) for x in ret]
201
+ return ret
202
+
203
+ def rotate_array(
204
+ self,
205
+ arr: NDArray,
206
+ rotation_matrix: NDArray,
207
+ arr_mask: NDArray = None,
208
+ translation: NDArray = None,
209
+ use_geometric_center: bool = False,
210
+ out: NDArray = None,
211
+ out_mask: NDArray = None,
212
+ order: int = 3,
213
+ ) -> None:
214
+ rotate_mask = arr_mask is not None
215
+ return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
216
+
217
+ arr = self.to_numpy_array(arr)
218
+ rotation_matrix = self.to_numpy_array(rotation_matrix)
219
+
220
+ if arr_mask is not None:
221
+ arr_mask = self.to_numpy_array(arr_mask)
222
+
223
+ if translation is not None:
224
+ translation = self.to_numpy_array(translation)
225
+
226
+ out_pass, out_mask_pass = None, None
227
+ if out is not None:
228
+ out_pass = self.to_numpy_array(out)
229
+ if out_mask is not None:
230
+ out_mask_pass = self.to_numpy_array(out_mask)
231
+
232
+ ret = NumpyFFTWBackend().rotate_array(
233
+ arr=arr,
234
+ rotation_matrix=rotation_matrix,
235
+ arr_mask=arr_mask,
236
+ translation=translation,
237
+ use_geometric_center=use_geometric_center,
238
+ out=out_pass,
239
+ out_mask=out_mask_pass,
240
+ order=order,
241
+ )
242
+
243
+ if ret is not None:
244
+ if len(ret) == 1 and out is None:
245
+ out_pass = ret
246
+ elif len(ret) == 1 and out_mask is None:
247
+ out_mask_pass = ret
248
+ else:
249
+ out_pass, out_mask_pass = ret
250
+
251
+ if out is not None:
252
+ out[:] = self.to_backend_array(out_pass)
253
+
254
+ if out_mask is not None:
255
+ out_mask[:] = self.to_backend_array(out_mask_pass)
256
+
257
+ match return_type:
258
+ case 0:
259
+ return None
260
+ case 1:
261
+ return out
262
+ case 2:
263
+ return out_mask
264
+ case 3:
265
+ return out, out_mask
266
+
267
+ def indices(self, arr: List) -> MlxArray:
268
+ ret = NumpyFFTWBackend().indices(arr)
269
+ return self.to_backend_array(ret)
@@ -38,7 +38,7 @@ class NumpyFFTWBackend(MatchingBackend):
38
38
  array_backend=array_backend,
39
39
  default_dtype=default_dtype,
40
40
  complex_dtype=complex_dtype,
41
- default_dtype_int=np.int32,
41
+ default_dtype_int=default_dtype_int,
42
42
  )
43
43
  self.affine_transform = affine_transform
44
44
 
@@ -424,8 +424,8 @@ class NumpyFFTWBackend(MatchingBackend):
424
424
  new_shape = self.to_backend_array(newshape)
425
425
  current_shape = self.to_backend_array(arr.shape)
426
426
  starts = self.subtract(current_shape, new_shape)
427
- starts = self.astype(self.divide(starts, 2), int)
428
- stops = self.add(starts, newshape)
427
+ starts = self.astype(self.divide(starts, 2), self._default_dtype_int)
428
+ stops = self.astype(self.add(starts, newshape), self._default_dtype_int)
429
429
  box = tuple(slice(start, stop) for start, stop in zip(starts, stops))
430
430
  return arr[box]
431
431
 
@@ -493,7 +493,9 @@ class NumpyFFTWBackend(MatchingBackend):
493
493
  out_mask : NDArray, optional
494
494
  The output array to write the rotation of `arr_mask` to.
495
495
  order : int, optional
496
- Spline interpolation order. Has to be in the range 0-5.
496
+ Spline interpolation order. Has to be in the range 0-5. Non-zero
497
+ elements will be converted into a point-cloud and rotated according
498
+ to ``rotation_matrix`` if order is None.
497
499
  """
498
500
 
499
501
  if order is None:
@@ -611,13 +613,6 @@ class NumpyFFTWBackend(MatchingBackend):
611
613
  rotate_mask = arr_mask is not None and mask_coordinates is not None
612
614
  return_type = (out is None) + 2 * rotate_mask * (out_mask is None)
613
615
 
614
- # Otherwise array might be slightly shifted by centering
615
- if np.allclose(
616
- rotation_matrix,
617
- np.eye(rotation_matrix.shape[0], dtype=rotation_matrix.dtype),
618
- ):
619
- center_rotation = False
620
-
621
616
  coordinates_rotated = np.empty(coordinates.shape, dtype=rotation_matrix.dtype)
622
617
  mask_rotated = (
623
618
  np.empty(mask_coordinates.shape, dtype=rotation_matrix.dtype)
@@ -6,13 +6,12 @@
6
6
  Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
7
7
  """
8
8
 
9
- from typing import Tuple, Dict, Callable
9
+ from typing import Tuple, Callable
10
10
  from contextlib import contextmanager
11
11
  from multiprocessing import shared_memory
12
12
  from multiprocessing.managers import SharedMemoryManager
13
13
 
14
- from numpy.typing import NDArray
15
-
14
+ import numpy as np
16
15
  from .npfftw_backend import NumpyFFTWBackend
17
16
  from ..types import NDArray, TorchTensor
18
17
 
@@ -47,7 +46,6 @@ class PytorchBackend(NumpyFFTWBackend):
47
46
  )
48
47
  self.device = device
49
48
  self.F = F
50
- self._default_dtype_int = torch.int32
51
49
 
52
50
  def to_backend_array(self, arr: NDArray) -> TorchTensor:
53
51
  if isinstance(arr, self._array_backend.Tensor):
@@ -57,6 +55,8 @@ class PytorchBackend(NumpyFFTWBackend):
57
55
  return self.tensor(arr, device=self.device)
58
56
 
59
57
  def to_numpy_array(self, arr: TorchTensor) -> NDArray:
58
+ if isinstance(arr, np.ndarray):
59
+ return arr
60
60
  return arr.cpu().numpy()
61
61
 
62
62
  def to_cpu_array(self, arr: TorchTensor) -> NDArray:
@@ -131,7 +131,7 @@ class PytorchBackend(NumpyFFTWBackend):
131
131
 
132
132
  def full(self, shape, fill_value, dtype=None):
133
133
  return self._array_backend.full(
134
- size = shape, dtype=dtype, fill_value=fill_value, device=self.device
134
+ size=shape, dtype=dtype, fill_value=fill_value, device=self.device
135
135
  )
136
136
 
137
137
  def datatype_bytes(self, dtype: type) -> int:
tme/density.py CHANGED
@@ -186,8 +186,8 @@ class Density:
186
186
  Notes
187
187
  -----
188
188
  If ``filename`` ends with ".em" or ".em.gz" the method will parse it as EM file.
189
- Otherwise it defaults to the CCP4/MRC format and on failure, defaults to
190
- skimage.io.imread regardless of the extension. Currently, the later does not
189
+ Otherwise it defaults to the CCP4/MRC format and on failure, switches to
190
+ :obj:`skimage.io.imread` regardless of the extension. Currently, the later does not
191
191
  extract origin or sampling_rate information from the file.
192
192
 
193
193
  See Also
@@ -199,16 +199,16 @@ class Density:
199
199
  func = cls._load_mrc
200
200
  if filename.endswith(".em") or filename.endswith(".em.gz"):
201
201
  func = cls._load_em
202
- data, origin, sampling_rate = func(
202
+ data, origin, sampling_rate, meta = func(
203
203
  filename=filename, subset=subset, use_memmap=use_memmap
204
204
  )
205
205
  except ValueError:
206
- data, origin, sampling_rate = cls._load_skio(filename=filename)
206
+ data, origin, sampling_rate, meta = cls._load_skio(filename=filename)
207
207
  if subset is not None:
208
208
  cls._validate_slices(slices=subset, shape=data.shape)
209
209
  data = data[subset].copy()
210
210
 
211
- return cls(data=data, origin=origin, sampling_rate=sampling_rate)
211
+ return cls(data=data, origin=origin, sampling_rate=sampling_rate, metadata=meta)
212
212
 
213
213
  @classmethod
214
214
  def _load_mrc(
@@ -305,6 +305,15 @@ class Density:
305
305
  ) and not np.all(start == 0):
306
306
  origin = np.multiply(start, sampling_rate)
307
307
 
308
+ extended_header = mrc.header.nsymbt
309
+
310
+ metadata = {
311
+ "min": float(mrc.header.dmin),
312
+ "max": float(mrc.header.dmax),
313
+ "mean": float(mrc.header.dmean),
314
+ "std": float(mrc.header.rms),
315
+ }
316
+
308
317
  if is_gzipped(filename):
309
318
  if use_memmap:
310
319
  warnings.warn(
@@ -325,9 +334,9 @@ class Density:
325
334
  slices=subset,
326
335
  data_shape=data_shape,
327
336
  dtype=data_type,
328
- header_size=1024,
337
+ header_size=1024 + extended_header,
329
338
  )
330
- return data, origin, sampling_rate
339
+ return data, origin, sampling_rate, metadata
331
340
 
332
341
  if not use_memmap:
333
342
  with mrcfile.open(filename, header_only=False) as mrc:
@@ -341,7 +350,7 @@ class Density:
341
350
  data = np.transpose(data, crs_index)
342
351
  start = np.take(start, crs_index)
343
352
 
344
- return data, origin, sampling_rate
353
+ return data, origin, sampling_rate, metadata
345
354
 
346
355
  @classmethod
347
356
  def _load_em(
@@ -446,7 +455,7 @@ class Density:
446
455
  pixel_size = 1
447
456
  sampling_rate = np.repeat(pixel_size, data.ndim).astype(data.dtype)
448
457
 
449
- return data, origin, sampling_rate
458
+ return data, origin, sampling_rate, {}
450
459
 
451
460
  @staticmethod
452
461
  def _validate_slices(slices: Tuple[slice], shape: Tuple[int]):
@@ -553,12 +562,12 @@ class Density:
553
562
  @staticmethod
554
563
  def _load_skio(filename: str) -> Tuple[NDArray]:
555
564
  """
556
- Uses skimage.io.imread to extract data from filename.
565
+ Uses :obj:`skimage.io.imread` to extract data from filename [1]_.
557
566
 
558
567
  Parameters
559
568
  ----------
560
569
  filename : str
561
- Path to a file whose format is supported by skimage.io.imread.
570
+ Path to a file whose format is supported by :obj:`skimage.io.imread`.
562
571
 
563
572
  Returns
564
573
  -------
@@ -590,7 +599,7 @@ class Density:
590
599
  warnings.warn(
591
600
  "origin and sampling_rate are not yet extracted from non CCP4/MRC files."
592
601
  )
593
- return data, np.zeros(data.ndim), np.ones(data.ndim)
602
+ return data, np.zeros(data.ndim), np.ones(data.ndim), {}
594
603
 
595
604
  @classmethod
596
605
  def from_structure(
@@ -720,7 +729,7 @@ class Density:
720
729
  :py:meth:`tme.structure.Structure.to_volume`
721
730
  """
722
731
  structure = filename_or_structure
723
- if type(filename_or_structure) == str:
732
+ if isinstance(filename_or_structure, str):
724
733
  structure = Structure.from_file(
725
734
  filename=filename_or_structure,
726
735
  filter_by_elements=filter_by_elements,
@@ -790,8 +799,8 @@ class Density:
790
799
  Notes
791
800
  -----
792
801
  If ``filename`` ends with ".em" or ".em.gz", the method will create an EM file.
793
- Otherwise, it defaults to the CCP4/MRC format, and on failure, it falls back
794
- to `skimage.io.imsave`.
802
+ Otherwise, it defaults to the CCP4/MRC format, and on failure, falls back
803
+ to :obj:`skimage.io.imsave`.
795
804
 
796
805
  See Also
797
806
  --------
@@ -879,12 +888,12 @@ class Density:
879
888
 
880
889
  def _save_skio(self, filename: str, gzip: bool) -> None:
881
890
  """
882
- Uses skimage.io.imsave to write data to filename.
891
+ Uses :obj:`skimage.io.imsave` to write data to filename [1]_.
883
892
 
884
893
  Parameters
885
894
  ----------
886
895
  filename : str
887
- Path to write to with a format supported by skimage.io.imsave.
896
+ Path to write to with a format supported by :obj:`skimage.io.imsave`.
888
897
  gzip : bool, optional
889
898
  If True, the output will be gzip compressed.
890
899
 
@@ -907,7 +916,8 @@ class Density:
907
916
  Returns a copy of the current class instance with all elements in
908
917
  :py:attr:`Density.data` set to zero. :py:attr:`Density.origin` and
909
918
  :py:attr:`Density.sampling_rate` will be copied, while
910
- :py:attr:`Density.metadata` will be initialized to an empty dictionary.
919
+ :py:attr:`Density.metadata` will be initialized to contain min, max,
920
+ mean and standard deviation of :py:attr:`Density.data`.
911
921
 
912
922
  Examples
913
923
  --------
@@ -922,6 +932,7 @@ class Density:
922
932
  data=np.zeros_like(self.data),
923
933
  origin=deepcopy(self.origin),
924
934
  sampling_rate=deepcopy(self.sampling_rate),
935
+ metadata={"min": 0, "max": 0, "mean": 0, "std": 0},
925
936
  )
926
937
 
927
938
  def copy(self) -> "Density":
@@ -1407,9 +1418,9 @@ class Density:
1407
1418
 
1408
1419
  def centered(self, cutoff: float = 0) -> Tuple["Density", NDArray]:
1409
1420
  """
1410
- Shifts the data center of mass to the center of the data array. The box size
1411
- of the return Density object is at least equal to the box size of the class
1412
- instance.
1421
+ Shifts the data center of mass to the center of the data array using linear
1422
+ interpolation. The box size of the returned :py:class:`Density` object is at
1423
+ least equal to the box size of the class instance.
1413
1424
 
1414
1425
  Parameters
1415
1426
  ----------
@@ -1442,8 +1453,8 @@ class Density:
1442
1453
  --------
1443
1454
  :py:meth:`Density.centered` returns a tuple containing a centered version
1444
1455
  of the current :py:class:`Density` instance, as well as an array with
1445
- translations. The translation corresponds to the shift that was used to
1446
- center the current :py:class:`Density` instance.
1456
+ translations. The translation corresponds to the shift that between the
1457
+ center of mass and the center of the internal :py:attr:`Density.data` attribute.
1447
1458
 
1448
1459
  >>> import numpy as np
1449
1460
  >>> from tme import Density
@@ -1463,12 +1474,13 @@ class Density:
1463
1474
  internal :py:attr:`Density.data` attribute:
1464
1475
 
1465
1476
  >>> centered_dens.data
1466
- array([[0., 0., 0., 0., 0., 0.],
1467
- [0., 1., 1., 1., 1., 1.],
1468
- [0., 1., 1., 1., 1., 1.],
1469
- [0., 1., 1., 1., 1., 1.],
1470
- [0., 1., 1., 1., 1., 1.],
1471
- [0., 1., 1., 1., 1., 1.]])
1477
+ array([[0., 0., 0., 0., 0., 0., 0.],
1478
+ [0., 1., 1., 1., 1., 1., 0.],
1479
+ [0., 1., 1., 1., 1., 1., 0.],
1480
+ [0., 1., 1., 1., 1., 1., 0.],
1481
+ [0., 1., 1., 1., 1., 1., 0.],
1482
+ [0., 1., 1., 1., 1., 1., 0.],
1483
+ [0., 0., 0., 0., 0., 0., 0.]])
1472
1484
 
1473
1485
  `centered_dens` is sufficiently large to represent all rotations that
1474
1486
  could be applied to the :py:attr:`Density.data` attribute. Lets look
@@ -1494,14 +1506,15 @@ class Density:
1494
1506
  ret.pad(new_shape)
1495
1507
 
1496
1508
  center = self.center_of_mass(ret.data, cutoff)
1497
- shift = np.subtract(np.divide(ret.shape, 2), center).astype(int)
1509
+ shift = np.subtract(np.divide(ret.shape, 2), center)
1498
1510
 
1499
1511
  ret = ret.rigid_transform(
1500
1512
  translation=shift,
1501
1513
  rotation_matrix=np.eye(ret.data.ndim),
1502
1514
  use_geometric_center=False,
1515
+ order=1,
1503
1516
  )
1504
- offset = np.subtract(center, self.center_of_mass(ret.data))
1517
+ offset = np.subtract(center, self.center_of_mass(ret.data, cutoff))
1505
1518
 
1506
1519
  return ret, offset
1507
1520