pytme 0.1.6__cp311-cp311-macosx_14_0_arm64.whl → 0.1.8__cp311-cp311-macosx_14_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
tme/matching_utils.py CHANGED
@@ -763,10 +763,14 @@ def euler_to_rotationmatrix(angles: Tuple[float]) -> NDArray:
763
763
  NDArray
764
764
  The generated rotation matrix.
765
765
  """
766
- if len(angles) == 1:
766
+ n_angles = len(angles)
767
+ angle_convention = "zyx"[:n_angles]
768
+ if n_angles == 1:
767
769
  angles = (angles, 0, 0)
768
770
  rotation_matrix = (
769
- Rotation.from_euler("zyx", angles, degrees=True).as_matrix().astype(np.float32)
771
+ Rotation.from_euler(angle_convention, angles, degrees=True)
772
+ .as_matrix()
773
+ .astype(np.float32)
770
774
  )
771
775
  return rotation_matrix
772
776
 
@@ -1052,7 +1056,7 @@ def tube_mask(
1052
1056
  symmetry_axis : int
1053
1057
  The axis of symmetry for the tube.
1054
1058
  base_center : tuple
1055
- Center of the base circle of the tube.
1059
+ Center of the tube.
1056
1060
  inner_radius : float
1057
1061
  Inner radius of the tube.
1058
1062
  outer_radius : float
@@ -1068,8 +1072,9 @@ def tube_mask(
1068
1072
  Raises
1069
1073
  ------
1070
1074
  ValueError
1071
- If the inner radius is larger than the outer radius. Or height is larger
1072
- than the symmetry axis shape.
1075
+ If the inner radius is larger than the outer radius, height is larger
1076
+ than the symmetry axis shape, or if base_center and shape do not have the
1077
+ same length.
1073
1078
  """
1074
1079
  if inner_radius > outer_radius:
1075
1080
  raise ValueError("inner_radius should be smaller than outer_radius.")
@@ -1080,40 +1085,52 @@ def tube_mask(
1080
1085
  if symmetry_axis > len(shape):
1081
1086
  raise ValueError(f"symmetry_axis can be not larger than {len(shape)}.")
1082
1087
 
1083
- circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
1084
- base_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)
1088
+ if len(base_center) != len(shape):
1089
+ raise ValueError("shape and base_center need to have the same length.")
1085
1090
 
1086
- inner_circle = create_mask(
1087
- mask_type="ellipse",
1088
- shape=circle_shape,
1089
- radius=inner_radius,
1090
- center=base_center,
1091
- )
1092
- outer_circle = create_mask(
1093
- mask_type="ellipse",
1094
- shape=circle_shape,
1095
- radius=outer_radius,
1096
- center=base_center,
1097
- )
1091
+ circle_shape = tuple(b for ix, b in enumerate(shape) if ix != symmetry_axis)
1092
+ circle_center = tuple(b for ix, b in enumerate(base_center) if ix != symmetry_axis)
1093
+
1094
+ inner_circle = np.zeros(circle_shape)
1095
+ outer_circle = np.zeros_like(inner_circle)
1096
+ if inner_radius > 0:
1097
+ inner_circle = create_mask(
1098
+ mask_type="ellipse",
1099
+ shape=circle_shape,
1100
+ radius=inner_radius,
1101
+ center=circle_center,
1102
+ )
1103
+ if outer_radius > 0:
1104
+ outer_circle = create_mask(
1105
+ mask_type="ellipse",
1106
+ shape=circle_shape,
1107
+ radius=outer_radius,
1108
+ center=circle_center,
1109
+ )
1098
1110
  circle = outer_circle - inner_circle
1099
1111
  circle = np.expand_dims(circle, axis=symmetry_axis)
1100
1112
 
1101
- center = shape[symmetry_axis] // 2
1102
- start_idx = center - height // 2
1103
- stop_idx = center + height // 2 + height % 2
1113
+ center = base_center[symmetry_axis]
1114
+ start_idx = int(center - height // 2)
1115
+ stop_idx = int(center + height // 2 + height % 2)
1116
+
1117
+ start_idx, stop_idx = max(start_idx, 0), min(stop_idx, shape[symmetry_axis])
1104
1118
 
1105
1119
  slice_indices = tuple(
1106
1120
  slice(None) if i != symmetry_axis else slice(start_idx, stop_idx)
1107
1121
  for i in range(len(shape))
1108
1122
  )
1109
1123
  tube = np.zeros(shape)
1110
- tube[slice_indices] = np.repeat(circle, height, axis=symmetry_axis)
1124
+ tube[slice_indices] = circle
1111
1125
 
1112
1126
  return tube
1113
1127
 
1114
1128
 
1115
1129
  def scramble_phases(
1116
- arr: NDArray, noise_proportion: float = 0.5, seed: int = 42
1130
+ arr: NDArray,
1131
+ noise_proportion: float = 0.5,
1132
+ seed: int = 42,
1133
+ normalize_power: bool = True,
1117
1134
  ) -> NDArray:
1118
1135
  """
1119
1136
  Applies random phase scrambling to a given array.
@@ -1131,6 +1148,8 @@ def scramble_phases(
1131
1148
  The proportion of noise in the phase scrambling, by default 0.5.
1132
1149
  seed : int, optional
1133
1150
  The seed for the random phase scrambling, by default 42.
1151
+ normalize_power : bool, optional
1152
+ Whether the returned template should have the same sum of squares as arr.
1134
1153
 
1135
1154
  Returns
1136
1155
  -------
@@ -1154,6 +1173,17 @@ def scramble_phases(
1154
1173
  ph_noise = np.random.permutation(ph)
1155
1174
  ph_new = ph * (1 - noise_proportion) + ph_noise * noise_proportion
1156
1175
  ret = np.real(np.fft.ifftn(amp * np.exp(1j * ph_new)))
1176
+
1177
+ if normalize_power:
1178
+ np.divide(
1179
+ np.subtract(ret, ret.min()), np.subtract(ret.max(), ret.min()), out=ret
1180
+ )
1181
+ np.multiply(ret, np.subtract(arr.max(), arr.min()), out=ret)
1182
+ np.add(ret, arr.min(), out=ret)
1183
+
1184
+ scaling = np.divide(np.abs(arr).sum(), np.abs(ret).sum())
1185
+ np.multiply(ret, scaling, out=ret)
1186
+
1157
1187
  return ret
1158
1188
 
1159
1189
 
tme/parser.py CHANGED
@@ -229,7 +229,7 @@ class PDBParser(Parser):
229
229
  class MMCIFParser(Parser):
230
230
  """
231
231
  A Parser subclass for converting MMCIF file data into a dictionary representation.
232
- This implementation heavily relies on the atomium library:
232
+ This implementation heavily relies on the atomium library [1]_.
233
233
 
234
234
  References
235
235
  ----------
@@ -308,7 +308,7 @@ class MMCIFParser(Parser):
308
308
  -------
309
309
  list of dict
310
310
  A list of dictionaries where each dictionary represents a block
311
- of data from the MMCIF file.
311
+ of data from the MMCIF file.
312
312
  """
313
313
  category = None
314
314
  block, blocks = [], []
tme/preprocessor.py CHANGED
@@ -1055,8 +1055,8 @@ class Preprocessor:
1055
1055
  be equivalent to the following
1056
1056
 
1057
1057
  >>> wedge = Preprocessor().continuous_wedge_mask(
1058
- >>> shape = (50,50,50),
1059
- >>> start_tilt = 50,
1058
+ >>> shape=(50,50,50),
1059
+ >>> start_tilt=50,
1060
1060
  >>> stop_tilt=55,
1061
1061
  >>> tilt_axis=1,
1062
1062
  >>> omit_negative_frequencies=False,
@@ -1074,6 +1074,7 @@ class Preprocessor:
1074
1074
 
1075
1075
  See Also
1076
1076
  --------
1077
+ :py:meth:`Preprocessor.step_wedge_mask`
1077
1078
  :py:meth:`Preprocessor.continuous_wedge_mask`
1078
1079
  """
1079
1080
  opening_axes = np.asarray(opening_axes)
@@ -1095,7 +1096,7 @@ class Preprocessor:
1095
1096
 
1096
1097
  opening_axis = opening_axes[index]
1097
1098
  rotation_matrix = euler_to_rotationmatrix(
1098
- np.roll(tilt_angles[:, index], 1 - opening_axes)
1099
+ np.roll(tilt_angles[:, index], opening_axis - 1)
1099
1100
  )
1100
1101
 
1101
1102
  subset = tuple(
@@ -1124,6 +1125,112 @@ class Preprocessor:
1124
1125
 
1125
1126
  return wedge_volume
1126
1127
 
1128
+ def step_wedge_mask(
1129
+ self,
1130
+ start_tilt: float,
1131
+ stop_tilt: float,
1132
+ tilt_step: float,
1133
+ shape: Tuple[int],
1134
+ opening_axis: int = 0,
1135
+ tilt_axis: int = 2,
1136
+ sigma: float = 0,
1137
+ omit_negative_frequencies: bool = True,
1138
+ ) -> NDArray:
1139
+ """
1140
+ Create a wedge mask with the same shape as template by rotating a
1141
+ plane according to tilt angles. The DC component of the filter is at the origin.
1142
+
1143
+ Parameters
1144
+ ----------
1145
+ start_tilt : float
1146
+ Starting tilt angle in degrees, e.g. a stage tilt of 70 degrees
1147
+ would yield a start_tilt value of 70.
1148
+ stop_tilt : float
1149
+ Ending tilt angle in degrees, , e.g. a stage tilt of -70 degrees
1150
+ would yield a stop_tilt value of 70.
1151
+ tilt_step : float
1152
+ Angle between the different tilt planes.
1153
+ shape : Tuple of ints
1154
+ Shape of the output wedge array.
1155
+ tilt_axis : int, optional
1156
+ Axis that the plane is tilted over.
1157
+ - 0 for Z-axis
1158
+ - 1 for Y-axis
1159
+ - 2 for X-axis
1160
+ opening_axis : int, optional
1161
+ Axis running through the void defined by the wedge.
1162
+ - 0 for Z-axis
1163
+ - 1 for Y-axis
1164
+ - 2 for X-axis
1165
+ sigma : float, optional
1166
+ Standard deviation for Gaussian kernel used for smoothing the wedge.
1167
+ omit_negative_frequencies : bool, optional
1168
+ Whether the wedge mask should omit negative frequencies, i.e. be
1169
+ applicable to symmetric Fourier transforms (see :obj:`numpy.fft.fftn`)
1170
+
1171
+ Returns
1172
+ -------
1173
+ NDArray
1174
+ A numpy array containing the wedge mask.
1175
+
1176
+ Notes
1177
+ -----
1178
+ This function is equivalent to :py:meth:`Preprocessor.wedge_mask`, but much faster
1179
+ for large shapes because it only considers a single tilt angle rather than the rotation
1180
+ of an N-1 dimensional hyperplane in N dimensions.
1181
+
1182
+ See Also
1183
+ --------
1184
+ :py:meth:`Preprocessor.wedge_mask`
1185
+ :py:meth:`Preprocessor.continuous_wedge_mask`
1186
+ """
1187
+ tilt_angles = np.arange(-start_tilt, stop_tilt + tilt_step, tilt_step)
1188
+ plane = np.zeros((shape[opening_axis], shape[tilt_axis]), dtype=np.float32)
1189
+ subset = tuple(
1190
+ slice(None) if i != 0 else slice(x // 2, x // 2 + 1)
1191
+ for i, x in enumerate(plane.shape)
1192
+ )
1193
+ plane[subset] = 1
1194
+ plane_rotated, wedge_volume = np.zeros_like(plane), np.zeros_like(plane)
1195
+ for index in range(tilt_angles.shape[0]):
1196
+ plane_rotated.fill(0)
1197
+
1198
+ rotation_matrix = euler_to_rotationmatrix((tilt_angles[index], 0))
1199
+ rotation_matrix = rotation_matrix[np.ix_((0, 1), (0, 1))]
1200
+
1201
+ Density.rotate_array(
1202
+ arr=plane,
1203
+ rotation_matrix=rotation_matrix,
1204
+ out=plane_rotated,
1205
+ use_geometric_center=True,
1206
+ order=1,
1207
+ )
1208
+ wedge_volume += plane_rotated
1209
+
1210
+ wedge_volume = self.gaussian_filter(
1211
+ template=wedge_volume, sigma=sigma, fourier=False
1212
+ )
1213
+ wedge_volume = np.where(wedge_volume > np.exp(-2), 1, 0)
1214
+
1215
+ if opening_axis > tilt_axis:
1216
+ wedge_volume = np.moveaxis(wedge_volume, 1, 0)
1217
+
1218
+ reshape_dimensions = tuple(
1219
+ x if i in (opening_axis, tilt_axis) else 1 for i, x in enumerate(shape)
1220
+ )
1221
+
1222
+ wedge_volume = wedge_volume.reshape(reshape_dimensions)
1223
+ tile_dimensions = np.divide(shape, reshape_dimensions).astype(int)
1224
+ wedge_volume = np.tile(wedge_volume, tile_dimensions)
1225
+
1226
+ wedge_volume = np.fft.ifftshift(wedge_volume)
1227
+
1228
+ if omit_negative_frequencies:
1229
+ stop = 1 + (wedge_volume.shape[-1] // 2)
1230
+ wedge_volume = wedge_volume[..., :stop]
1231
+
1232
+ return wedge_volume
1233
+
1127
1234
  def continuous_wedge_mask(
1128
1235
  self,
1129
1236
  start_tilt: float,
@@ -1192,6 +1299,7 @@ class Preprocessor:
1192
1299
  See Also
1193
1300
  --------
1194
1301
  :py:meth:`Preprocessor.wedge_mask`
1302
+ :py:meth:`Preprocessor.step_wedge_mask`
1195
1303
  """
1196
1304
  shape_center = np.divide(shape, 2).astype(int)
1197
1305
 
@@ -1218,7 +1326,7 @@ class Preprocessor:
1218
1326
  distances = np.linalg.norm(grid, axis=0)
1219
1327
 
1220
1328
  if not infinite_plane:
1221
- np.multiply(wedge, distances <= shape[opening_axis] // 2, out=wedge)
1329
+ np.multiply(wedge, distances <= shape[tilt_axis] // 2, out=wedge)
1222
1330
 
1223
1331
  wedge = self.gaussian_filter(template=wedge, sigma=sigma, fourier=False)
1224
1332
  wedge = np.fft.ifftshift(wedge > np.exp(-2))
@@ -1348,7 +1456,7 @@ class LinearWhiteningFilter:
1348
1456
  def filter(
1349
1457
  self, template: NDArray, n_bins: int = None
1350
1458
  ) -> Tuple[NDArray, NDArray, NDArray]:
1351
- max_bins = int(np.linalg.norm(template.shape) // 2 + 1)
1459
+ max_bins = np.max(template.shape) // 2 + 1
1352
1460
  n_bins = max_bins if n_bins is None else n_bins
1353
1461
  n_bins = int(min(n_bins, max_bins))
1354
1462
 
@@ -1360,31 +1468,36 @@ class LinearWhiteningFilter:
1360
1468
  _, bin_edges = np.histogram(frequency_grid, bins=n_bins - 1)
1361
1469
  bins = np.digitize(frequency_grid, bins=bin_edges, right=True)
1362
1470
 
1363
- fourier_transform = np.fft.fftshift(np.fft.rfftn(template))
1471
+ fft_shift_axes = tuple(range(template.ndim - 1))
1472
+ fourier_transform = np.fft.fftshift(np.fft.rfftn(template), axes=fft_shift_axes)
1364
1473
  fourier_spectrum = np.abs(fourier_transform)
1365
1474
 
1366
1475
  radial_averages = ndimean(fourier_spectrum, labels=bins, index=np.unique(bins))
1367
1476
  np.reciprocal(radial_averages, out=radial_averages)
1368
1477
  np.divide(radial_averages, radial_averages.max(), out=radial_averages)
1369
1478
 
1370
- bins = radial_averages[bins]
1371
- np.multiply(fourier_transform, bins, out=fourier_transform)
1372
- center_indices = tuple([dim_size // 2 for dim_size in fourier_transform.shape])
1373
- fourier_transform[center_indices] = 0
1374
- ret = np.fft.irfftn(np.fft.ifftshift(fourier_transform)).real
1479
+ np.multiply(fourier_transform, radial_averages[bins], out=fourier_transform)
1375
1480
 
1481
+ ret = np.fft.irfftn(
1482
+ np.fft.ifftshift(fourier_transform, axes=fft_shift_axes), s=template.shape
1483
+ ).real
1376
1484
  return ret, bin_edges, radial_averages
1377
1485
 
1378
1486
  def apply(
1379
1487
  self, template: NDArray, bin_edges: NDArray, radial_averages: NDArray
1380
1488
  ) -> NDArray:
1381
- fourier_transform = np.fft.fftshift(np.fft.fftn(template))
1382
-
1383
- grid = self._fftfreqn(shape=fourier_transform.shape, sampling_rate=1)
1489
+ grid = self._fftfreqn(
1490
+ shape=template.shape, sampling_rate=1, omit_negative_frequencies=True
1491
+ )
1384
1492
  frequency_grid = np.linalg.norm(grid, axis=0)
1385
1493
 
1494
+ fft_shift_axes = tuple(range(template.ndim - 1))
1495
+ fourier_transform = np.fft.fftshift(np.fft.rfftn(template), axes=fft_shift_axes)
1496
+
1386
1497
  bins = np.digitize(frequency_grid, bins=bin_edges, right=True)
1387
1498
  np.multiply(fourier_transform, radial_averages[bins], out=fourier_transform)
1388
- ret = np.fft.ifftn(np.fft.ifftshift(fourier_transform)).real
1499
+ ret = np.fft.irfftn(
1500
+ np.fft.ifftshift(fourier_transform, axes=fft_shift_axes), s=template.shape
1501
+ ).real
1389
1502
 
1390
1503
  return ret
tme/types.py CHANGED
@@ -3,6 +3,7 @@ from typing import Union, TypeVar
3
3
  NDArray = TypeVar("numpy.ndarray")
4
4
  CupyArray = TypeVar("cupy.ndarray")
5
5
  TorchTensor = TypeVar("torch.Tensor")
6
+ MlxArray = TypeVar("mlx.core.array")
6
7
 
7
8
  Scalar = Union[int, float, complex]
8
9
 
File without changes
File without changes