pytme 0.1.6__cp311-cp311-macosx_14_0_arm64.whl → 0.1.8__cp311-cp311-macosx_14_0_arm64.whl

This diff compares the contents of publicly available package versions as published to their respective public registries and is provided for informational purposes only.
scripts/postprocess.py CHANGED
@@ -5,6 +5,8 @@
 
 Author: Valentin Maurer <valentin.maurer@embl-hamburg.de>
 """
+ from os import getcwd
+ from os.path import join
 import argparse
 from sys import exit
 from typing import List, Tuple
@@ -13,6 +15,7 @@ from dataclasses import dataclass
 
 import numpy as np
 from scipy.spatial.transform import Rotation
+ from numpy.typing import NDArray
 
 from tme import Density, Structure
 from tme.analyzer import (
@@ -26,6 +29,7 @@ from tme.matching_utils import (
 load_pickle,
 euler_to_rotationmatrix,
 euler_from_rotationmatrix,
+ centered_mask,
 )
 
 PEAK_CALLERS = {
@@ -52,16 +56,33 @@ def parse_args():
 help="Prefix for the output file name. Extension depends on output_format.",
 )
 parser.add_argument(
- "--number_of_peaks", type=int, default=1000, help="Number of peaks to consider."
+ "--number_of_peaks",
+ type=int,
+ default=1000,
+ help="Number of peaks to consider. Note, this is the number of called peaks "
+ ", subject to min_distance and min_boundary_distance filtering. Therefore, the "
+ "returned number of peaks will be at most equal to number_of_peaks. "
+ "Ignored when --orientations is provided.",
 )
 parser.add_argument(
- "--min_distance", type=int, default=5, help="Minimum distance between peaks."
+ "--min_distance",
+ type=int,
+ default=5,
+ help="Minimum distance between peaks. Ignored when --orientations is provided.",
 )
 parser.add_argument(
 "--min_boundary_distance",
 type=int,
 default=0,
- help="Minimum distance from target boundaries.",
+ help="Minimum distance from target boundaries. Ignored when --orientations "
+ "is provided.",
+ )
+ parser.add_argument(
+ "--mask_edges",
+ action="store_true",
+ default=False,
+ help="Whether to mask edges of the input score array according to the template shape. "
+ "Uses twice the value of --min_boundary_distance if both are provided.",
 )
 parser.add_argument(
 "--wedge_mask",
@@ -73,7 +94,8 @@ def parse_args():
 "--peak_caller",
 choices=list(PEAK_CALLERS.keys()),
 default="PeakCallerScipy",
- help="Peak caller to use for analysis. Ignored if input_file contains peaks.",
+ help="Peak caller to use for analysis. Ignored if input_file contains peaks or when "
+ "--orientations is provided.",
 )
 parser.add_argument(
 "--orientations",
@@ -205,8 +227,12 @@ class Orientations:
 return None
 
 def _to_relion_star(
- self, filename: str, name_prefix: str = None, ctf_image: str = None,
- sampling_rate : float = None, subtomogram_size : int = None
+ self,
+ filename: str,
+ name_prefix: str = None,
+ ctf_image: str = None,
+ sampling_rate: float = 1.0,
+ subtomogram_size: int = 0,
 ) -> None:
 """
 Save orientations in RELION's STAR file format.
@@ -249,12 +275,11 @@ class Orientations:
 "300.000000",
 str(int(subtomogram_size)),
 "3",
- str(float(sampling_rate))
+ str(float(sampling_rate)),
 ]
 optics_header = "\n".join(optics_header)
 optics_data = "\t".join(optics_data)
 
-
 header = [
 "data_particles",
 "",
@@ -283,13 +308,14 @@ class Orientations:
 _ = ofile.write("\n# version 30001\n")
 _ = ofile.write(f"{header}\n")
 
+ # pyTME uses a zyx data layout
 for index, (translation, rotation, score, detail) in enumerate(self):
 rotation = Rotation.from_euler("zyx", rotation, degrees=True)
- rotation = rotation.as_euler(seq="zyz", degrees=True)
+ rotation = rotation.as_euler(seq="xyx", degrees=True)
 
 translation_string = "\t".join([str(x) for x in translation][::-1])
- angle_string = "\t".join([str(x) for x in rotation[::-1]])
- name = f"{name_prefix}{index}.mrc"
+ angle_string = "\t".join([str(x) for x in rotation])
+ name = f"{name_prefix}_{index}.mrc"
 _ = ofile.write(
 f"{translation_string}\t{name}\t{angle_string}\t1{ctf_image}\n"
 )
@@ -386,11 +412,25 @@ class Orientations:
 return translation, rotation, score, detail
 
 
+ def load_template(filepath: str, sampling_rate: NDArray) -> "Density":
+ try:
+ template = Density.from_file(filepath)
+ template, _ = template.centered(0)
+ center_of_mass = template.center_of_mass(template.data)
+ except ValueError:
+ template = Structure.from_file(filepath)
+ center_of_mass = template.center_of_mass()[::-1]
+ template = Density.from_structure(template, sampling_rate=sampling_rate)
+
+ return template, center_of_mass
+
+
 def main():
 args = parse_args()
 data = load_pickle(args.input_file)
 
 meta = data[-1]
+ target_origin, _, sampling_rate, cli_args = meta
 
 if args.orientations is not None:
 orientations = Orientations.from_file(
@@ -402,13 +442,29 @@ def main():
 # Output is MaxScoreOverRotations
 if data[0].ndim == data[2].ndim:
 scores, offset, rotation_array, rotation_mapping, meta = data
+ if args.mask_edges:
+ template, center_of_mass = load_template(
+ cli_args.template, sampling_rate=sampling_rate
+ )
+ if not cli_args.no_centering:
+ template, *_ = template.centered(0)
+ mask_size = template.shape
+ if args.min_boundary_distance > 0:
+ mask_size = 2 * args.min_boundary_distance
+ scores = centered_mask(scores, np.subtract(scores.shape, mask_size) + 1)
+
 peak_caller = PEAK_CALLERS[args.peak_caller](
 number_of_peaks=args.number_of_peaks,
 min_distance=args.min_distance,
 min_boundary_distance=args.min_boundary_distance,
 )
 peak_caller(scores, rotation_matrix=np.eye(3))
- candidates = peak_caller.merge([tuple(peak_caller)])
+ candidates = peak_caller.merge(
+ candidates=[tuple(peak_caller)],
+ number_of_peaks=args.number_of_peaks,
+ min_distance=args.min_distance,
+ min_boundary_distance=args.min_boundary_distance,
+ )
 if len(candidates) == 0:
 exit(
 "Found no peaks. Try reducing min_distance or min_boundary_distance."
@@ -436,18 +492,11 @@ def main():
 orientations.to_file(filename=f"{args.output_prefix}.tsv", file_format="text")
 exit(0)
 
- target_origin, _, sampling_rate, cli_args = meta
-
- template_is_density, index = True, 0
 _, template_extension = splitext(cli_args.template)
- try:
- template = Density.from_file(cli_args.template)
- template, _ = template.centered(0)
- center_of_mass = template.center_of_mass(template.data)
- except ValueError:
- template_is_density = False
- template = Structure.from_file(cli_args.template)
- center_of_mass = template.center_of_mass()[::-1]
+ template, center_of_mass = load_template(
+ filepath=cli_args.template, sampling_rate=sampling_rate
+ )
+ template_is_density, index = isinstance(template, Density), 0
 
 if args.output_format == "relion":
 new_shape = np.add(template.shape, np.mod(template.shape, 2))
@@ -457,10 +506,6 @@ def main():
 
 if args.output_format in ("extraction", "relion"):
 target = Density.from_file(cli_args.target)
- if isinstance(template, Structure):
- template = Density.from_structure(
- template, sampling_rate=target.sampling_rate
- )
 
 if not np.all(np.divide(target.shape, template.shape) > 2):
 print(
@@ -489,14 +534,15 @@ def main():
 )
 
 orientations = orientations[keep_peaks]
+ working_directory = getcwd()
 if args.output_format == "relion":
 orientations.to_file(
 filename=f"{args.output_prefix}.star",
 file_format="relion",
- name_prefix=args.output_prefix,
+ name_prefix=join(working_directory, args.output_prefix),
 ctf_image=args.wedge_mask,
- sampling_rate = target.sampling_rate.max(),
- subtomogram_size = template.shape[0]
+ sampling_rate=target.sampling_rate.max(),
+ subtomogram_size=template.shape[0],
 )
 
 peaks = peaks[keep_peaks,]
@@ -543,7 +589,9 @@ def main():
 origin=candidate_starts[index] * sampling_rate,
 )
 # out_density.data = out_density.data * template_mask.data
- out_density.to_file(f"{args.output_prefix}{index}.mrc")
+ out_density.to_file(
+ join(working_directory, f"{args.output_prefix}_{index}.mrc")
+ )
 
 exit(0)
 
@@ -566,7 +614,10 @@ def main():
 translation=translation[::-1],
 rotation_matrix=rotation_matrix[::-1, ::-1],
 )
- transformed_template.to_file(f"{args.output_prefix}{index}{template_extension}")
+ # template_extension should contain the extension '.'
+ transformed_template.to_file(
+ f"{args.output_prefix}_{index}{template_extension}"
+ )
 index += 1
 
 
@@ -29,17 +29,17 @@ preprocessor = Preprocessor()
 SLIDER_MIN, SLIDER_MAX = 0, 25
 
 
- def gaussian_filter(template, sigma: float, **kwargs: dict):
+ def gaussian_filter(template: NDArray, sigma: float, **kwargs: dict) -> NDArray:
 return preprocessor.gaussian_filter(template=template, sigma=sigma, **kwargs)
 
 
 def bandpass_filter(
- template,
+ template: NDArray,
 minimum_frequency: float,
 maximum_frequency: float,
 gaussian_sigma: float,
 **kwargs: dict,
- ):
+ ) -> NDArray:
 return preprocessor.bandpass_filter(
 template=template,
 minimum_frequency=minimum_frequency,
@@ -51,8 +51,8 @@ def bandpass_filter(
 
 
 def difference_of_gaussian_filter(
- template, sigmas: Tuple[float, float], **kwargs: dict
- ):
+ template: NDArray, sigmas: Tuple[float, float], **kwargs: dict
+ ) -> NDArray:
 low_sigma, high_sigma = sigmas
 return preprocessor.difference_of_gaussian_filter(
 template=template, low_sigma=low_sigma, high_sigma=high_sigma, **kwargs
@@ -60,7 +60,7 @@ def difference_of_gaussian_filter(
 
 
 def edge_gaussian_filter(
- template,
+ template: NDArray,
 sigma: float,
 edge_algorithm: Annotated[
 str,
@@ -68,7 +68,7 @@ def edge_gaussian_filter(
 ],
 reverse: bool = False,
 **kwargs: dict,
- ):
+ ) -> NDArray:
 return preprocessor.edge_gaussian_filter(
 template=template,
 sigma=sigma,
@@ -78,13 +78,13 @@ def edge_gaussian_filter(
 
 
 def local_gaussian_filter(
- template,
+ template: NDArray,
 lbd: float,
 sigma_range: Tuple[float, float],
 gaussian_sigma: float,
 reverse: bool = False,
 **kwargs: dict,
- ):
+ ) -> NDArray:
 return preprocessor.local_gaussian_filter(
 template=template,
 lbd=lbd,
@@ -94,21 +94,76 @@ def local_gaussian_filter(
 
 
 def ntree(
- template,
+ template: NDArray,
 sigma_range: Tuple[float, float],
 **kwargs: dict,
- ):
+ ) -> NDArray:
 return preprocessor.ntree_filter(template=template, sigma_range=sigma_range)
 
 
 def mean(
- template,
+ template: NDArray,
 width: int,
 **kwargs: dict,
- ):
+ ) -> NDArray:
 return preprocessor.mean_filter(template=template, width=width)
 
 
+ def resolution_sphere(
+ template: NDArray,
+ cutoff_angstrom: float,
+ highpass: bool = False,
+ sampling_rate=None,
+ ) -> NDArray:
+ if cutoff_angstrom == 0:
+ return template
+
+ cutoff_frequency = np.max(2 * sampling_rate / cutoff_angstrom)
+
+ min_freq, max_freq = 0, cutoff_frequency
+ if highpass:
+ min_freq, max_freq = cutoff_frequency, 1e10
+
+ mask = preprocessor.bandpass_mask(
+ shape=template.shape,
+ minimum_frequency=min_freq,
+ maximum_frequency=max_freq,
+ omit_negative_frequencies=False,
+ )
+
+ template_ft = np.fft.fftn(template)
+ np.multiply(template_ft, mask, out=template_ft)
+ return np.fft.ifftn(template_ft).real
+
+
+ def resolution_gaussian(
+ template: NDArray,
+ cutoff_angstrom: float,
+ highpass: bool = False,
+ sampling_rate=None,
+ ) -> NDArray:
+ if cutoff_angstrom == 0:
+ return template
+
+ grid = preprocessor.fftfreqn(
+ shape=template.shape, sampling_rate=sampling_rate / sampling_rate.max()
+ )
+
+ sigma_fourier = np.divide(
+ np.max(2 * sampling_rate / cutoff_angstrom), np.sqrt(2 * np.log(2))
+ )
+
+ mask = np.exp(-(grid**2) / (2 * sigma_fourier**2))
+ if highpass:
+ mask = 1 - mask
+
+ mask = np.fft.ifftshift(mask)
+
+ template_ft = np.fft.fftn(template)
+ np.multiply(template_ft, mask, out=template_ft)
+ return np.fft.ifftn(template_ft).real
+
+
 def wedge(
 template: NDArray,
 tilt_start: float,
@@ -120,7 +175,7 @@ def wedge(
 omit_negative_frequencies: bool = True,
 extrude_plane: bool = True,
 infinite_plane: bool = True,
- ):
+ ) -> NDArray:
 template_ft = np.fft.rfftn(template)
 
 if tilt_step <= 0:
@@ -139,15 +194,14 @@ def wedge(
 template = np.real(np.fft.irfftn(template_ft))
 return template
 
- tilt_angles = np.arange(-tilt_start, tilt_stop, tilt_step)
- angles = np.zeros((template.ndim, tilt_angles.size))
- angles[tilt_axis, :] = tilt_angles
-
- wedge_mask = preprocessor.wedge_mask(
- tilt_angles=angles,
+ wedge_mask = preprocessor.step_wedge_mask(
+ start_tilt=tilt_start,
+ stop_tilt=tilt_stop,
+ tilt_axis=tilt_axis,
+ tilt_step=tilt_step,
+ opening_axis=opening_axis,
 shape=template.shape,
 sigma=gaussian_sigma,
- opening_axes=opening_axis,
 omit_negative_frequencies=omit_negative_frequencies,
 )
 np.multiply(template_ft, wedge_mask, out=template_ft)
@@ -155,7 +209,7 @@ def wedge(
 return template
 
 
- def compute_power_spectrum(template: NDArray):
+ def compute_power_spectrum(template: NDArray) -> NDArray:
 return np.fft.fftshift(np.log(np.abs(np.fft.fftn(template))))
 
 
@@ -220,6 +274,8 @@ WRAPPED_FUNCTIONS = {
 "mean_filter": mean,
 "wedge_filter": wedge,
 "power_spectrum": compute_power_spectrum,
+ "resolution_gaussian": resolution_gaussian,
+ "resolution_sphere": resolution_sphere,
 }
 
 EXCLUDED_FUNCTIONS = [
@@ -332,6 +388,9 @@ class FilterWidget(widgets.Container):
 function_name = self.name_mapping.get(self.method_dropdown.value)
 function = self._get_function(function_name)
 
+ if "sampling_rate" in inspect.getfullargspec(function).args:
+ kwargs["sampling_rate"] = selected_layer_metadata["sampling_rate"]
+
 processed_data = function(selected_layer.data, **kwargs)
 
 new_layer_name = f"{selected_layer.name} ({self.method_dropdown.value})"
@@ -358,7 +417,7 @@ class FilterWidget(widgets.Container):
 
 def sphere_mask(
 template: NDArray, center_x: float, center_y: float, center_z: float, radius: float
- ):
+ ) -> NDArray:
 return create_mask(
 mask_type="ellipse",
 shape=template.shape,
@@ -375,7 +434,7 @@ def ellipsod_mask(
 radius_x: float,
 radius_y: float,
 radius_z: float,
- ):
+ ) -> NDArray:
 return create_mask(
 mask_type="ellipse",
 shape=template.shape,
@@ -392,7 +451,7 @@ def box_mask(
 height_x: int,
 height_y: int,
 height_z: int,
- ):
+ ) -> NDArray:
 return create_mask(
 mask_type="box",
 shape=template.shape,
@@ -410,7 +469,7 @@ def tube_mask(
 inner_radius: float,
 outer_radius: float,
 height: int,
- ):
+ ) -> NDArray:
 return create_mask(
 mask_type="tube",
 shape=template.shape,
@@ -433,7 +492,7 @@ def wedge_mask(
 omit_negative_frequencies: bool = False,
 extrude_plane: bool = True,
 infinite_plane: bool = True,
- ):
+ ) -> NDArray:
 if tilt_step <= 0:
 wedge_mask = preprocessor.continuous_wedge_mask(
 start_tilt=tilt_start,
@@ -449,25 +508,24 @@ def wedge_mask(
 wedge_mask = np.fft.fftshift(wedge_mask)
 return wedge_mask
 
- tilt_angles = np.arange(-tilt_start, tilt_stop + tilt_step, tilt_step)
- angles = np.zeros((template.ndim, tilt_angles.size))
-
- angles[tilt_axis, :] = tilt_angles
-
- wedge_mask = preprocessor.wedge_mask(
- tilt_angles=angles,
+ wedge_mask = preprocessor.step_wedge_mask(
+ start_tilt=tilt_start,
+ stop_tilt=tilt_stop,
+ tilt_axis=tilt_axis,
+ tilt_step=tilt_step,
+ opening_axis=opening_axis,
 shape=template.shape,
 sigma=gaussian_sigma,
- opening_axes=opening_axis,
 omit_negative_frequencies=omit_negative_frequencies,
 )
+
 wedge_mask = np.fft.fftshift(wedge_mask)
 return wedge_mask
 
 
 def threshold_mask(
 template: NDArray, standard_deviation: float = 5.0, invert: bool = False
- ):
+ ) -> NDArray:
 template_mean = template.mean()
 template_deviation = standard_deviation * template.std()
 upper = template_mean + template_deviation
@@ -479,6 +537,15 @@ def threshold_mask(
 return mask
 
 
+ def lowpass_mask(template: NDArray, sigma: float = 1.0):
+ template = template / template.max()
+ template = (template > np.exp(-2)) * 128.0
+ template = preprocessor.gaussian_filter(template=template, sigma=sigma)
+ mask = template > np.exp(-2)
+
+ return mask
+
+
 class MaskWidget(widgets.Container):
 def __init__(self, viewer):
 super().__init__(layout="vertical")
@@ -496,6 +563,7 @@ class MaskWidget(widgets.Container):
 "Box": box_mask,
 "Wedge": wedge_mask,
 "Threshold": threshold_mask,
+ "Lowpass": lowpass_mask,
 }
 
 self.method_dropdown = widgets.ComboBox(
@@ -509,7 +577,6 @@ class MaskWidget(widgets.Container):
 
 self.adapt_button = widgets.PushButton(text="Adapt to layer", enabled=False)
 self.adapt_button.changed.connect(self._update_initial_values)
-
 self.viewer.layers.selection.events.active.connect(
 self._update_action_button_state
 )
@@ -520,8 +587,9 @@ class MaskWidget(widgets.Container):
 # self.density_field.value = f"Positive Density in Mask: {0:.2f}%"
 
 self.append(self.method_dropdown)
- self.append(self.percentile_range_edit)
 self.append(self.adapt_button)
+ self.append(self.percentile_range_edit)
+
 self.append(self.align_button)
 self.append(self.action_button)
 self.append(self.density_field)
@@ -559,6 +627,7 @@ class MaskWidget(widgets.Container):
 arr=active_layer.data,
 rotation_matrix=rotation_matrix,
 use_geometric_center=False,
+ order=1,
 )
 eps = np.finfo(rotated_data.dtype).eps
 rotated_data[rotated_data < eps] = 0
@@ -589,10 +658,10 @@ class MaskWidget(widgets.Container):
 dict(zip(["height_x", "height_y", "height_z"], coordinates_heights))
 )
 
- defaults["radius"] = np.min(coordinate_radius)
+ defaults["radius"] = np.max(coordinate_radius)
 defaults["inner_radius"] = np.min(coordinate_radius)
 defaults["outer_radius"] = np.max(coordinate_radius)
- defaults["height"] = defaults["radius"]
+ defaults["height"] = np.max(coordinates_heights)
 
 for widget in self.action_widgets:
 if widget.name in defaults:
@@ -804,6 +873,7 @@ def parse_args():
 args = parser.parse_args()
 return args
 
+
 if __name__ == "__main__":
 parse_args()
 main()
tme/__version__.py CHANGED
@@ -1 +1 @@
- __version__ = "0.1.6"
+ __version__ = "0.1.8"
tme/analyzer.py CHANGED
@@ -136,13 +136,16 @@ class PeakCaller(ABC):
 upper_limit = backend.subtract(
 score_space.shape, self.min_boundary_distance
 )
- valid_peaks = backend.sum(
- backend.multiply(
- peak_positions < upper_limit,
- peak_positions >= self.min_boundary_distance,
- ),
- axis=1,
- ) == peak_positions.shape[1]
+ valid_peaks = (
+ backend.sum(
+ backend.multiply(
+ peak_positions < upper_limit,
+ peak_positions >= self.min_boundary_distance,
+ ),
+ axis=1,
+ )
+ == peak_positions.shape[1]
+ )
 if backend.sum(valid_peaks) == 0:
 return None
 
@@ -261,7 +264,9 @@ class PeakCaller(ABC):
 backend.add(peak_positions, translation_offset, out=peak_positions)
 if not len(self.peak_list):
 self.peak_list = [peak_positions, rotations, peak_scores, peak_details]
- peak_scores, peak_details, dim = (), (), peak_positions.shape[1]
+ dim = peak_positions.shape[1]
+ peak_scores = backend.zeros((0,), peak_scores.dtype)
+ peak_details = backend.zeros((0,), peak_details.dtype)
 rotations = backend.zeros((0, dim, dim), rotations.dtype)
 peak_positions = backend.zeros((0, dim), peak_positions.dtype)
 
tme/backends/__init__.py CHANGED
@@ -11,6 +11,7 @@ from .matching_backend import MatchingBackend
 from .npfftw_backend import NumpyFFTWBackend
 from .pytorch_backend import PytorchBackend
 from .cupy_backend import CupyBackend
+ from .mlx_backend import MLXBackend
 
 
 class BackendManager:
@@ -45,7 +45,11 @@ class CupyBackend(NumpyFFTWBackend):
 self.maximum_filter = maximum_filter
 
 def to_backend_array(self, arr: NDArray) -> CupyArray:
- if isinstance(arr, self._array_backend.ndarray):
+ current_device = self._array_backend.cuda.device.get_device_id()
+ if (
+ isinstance(arr, self._array_backend.ndarray)
+ and arr.device.id == current_device
+ ):
 return arr
 return self._array_backend.asarray(arr)
 
@@ -373,7 +373,7 @@ class MatchingBackend(ABC):
 """
 
 @abstractmethod
- def std(self, arr: ArrayLike) -> Scalar:
+ def std(self, arr: ArrayLike, axis: Scalar) -> Scalar:
 """
 Compute the standard deviation of array elements.
 
@@ -381,6 +381,8 @@ class MatchingBackend(ABC):
 ----------
 arr : Scalar
 The array whose standard deviation should be computed.
+ axis : Scalar
+ Axis to perform the operation on.
 
 Returns
 -------