pytme 0.3.1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1.dev20250731__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. pytme-0.3.1.dev20250731.data/scripts/estimate_ram_usage.py +97 -0
  2. {pytme-0.3.1.data → pytme-0.3.1.dev20250731.data}/scripts/match_template.py +2 -2
  3. {pytme-0.3.1.data → pytme-0.3.1.dev20250731.data}/scripts/postprocess.py +16 -15
  4. {pytme-0.3.1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocessor_gui.py +1 -0
  5. {pytme-0.3.1.dist-info → pytme-0.3.1.dev20250731.dist-info}/METADATA +2 -4
  6. {pytme-0.3.1.dist-info → pytme-0.3.1.dev20250731.dist-info}/RECORD +33 -30
  7. scripts/estimate_ram_usage.py +97 -0
  8. scripts/match_template.py +2 -2
  9. scripts/match_template_devel.py +1339 -0
  10. scripts/postprocess.py +16 -15
  11. scripts/preprocessor_gui.py +1 -0
  12. scripts/refine_matches.py +5 -7
  13. tests/test_analyzer.py +2 -3
  14. tests/test_extensions.py +0 -1
  15. tests/test_orientations.py +0 -12
  16. tme/analyzer/aggregation.py +22 -12
  17. tme/backends/_jax_utils.py +60 -16
  18. tme/backends/cupy_backend.py +11 -11
  19. tme/backends/jax_backend.py +27 -9
  20. tme/backends/matching_backend.py +11 -0
  21. tme/backends/npfftw_backend.py +3 -0
  22. tme/density.py +58 -1
  23. tme/matching_data.py +24 -0
  24. tme/matching_exhaustive.py +5 -2
  25. tme/matching_scores.py +23 -0
  26. tme/orientations.py +20 -7
  27. {pytme-0.3.1.data → pytme-0.3.1.dev20250731.data}/scripts/estimate_memory_usage.py +0 -0
  28. {pytme-0.3.1.data → pytme-0.3.1.dev20250731.data}/scripts/preprocess.py +0 -0
  29. {pytme-0.3.1.data → pytme-0.3.1.dev20250731.data}/scripts/pytme_runner.py +0 -0
  30. {pytme-0.3.1.dist-info → pytme-0.3.1.dev20250731.dist-info}/WHEEL +0 -0
  31. {pytme-0.3.1.dist-info → pytme-0.3.1.dev20250731.dist-info}/entry_points.txt +0 -0
  32. {pytme-0.3.1.dist-info → pytme-0.3.1.dev20250731.dist-info}/licenses/LICENSE +0 -0
  33. {pytme-0.3.1.dist-info → pytme-0.3.1.dev20250731.dist-info}/top_level.txt +0 -0
tme/density.py CHANGED
@@ -2196,7 +2196,7 @@ class Density:
2196
2196
 
2197
2197
  Parameters
2198
2198
  ----------
2199
- target : Density
2199
+ target : :py:class:`Density`
2200
2200
  The target map for template matching.
2201
2201
  template : Structure
2202
2202
  The template that should be aligned to the target.
@@ -2259,3 +2259,60 @@ class Density:
2259
2259
  coordinates = np.array(np.where(data > 0))
2260
2260
  weights = self.data[tuple(coordinates)]
2261
2261
  return align_to_axis(coordinates.T, weights=weights, axis=axis, flip=flip)
2262
+
2263
+ @staticmethod
2264
+ def fourier_shell_correlation(density1: "Density", density2: "Density") -> NDArray:
2265
+ """
2266
+ Computes the Fourier Shell Correlation (FSC) between two instances of `Density`.
2267
+
2268
+ The Fourier transforms of the input maps are divided into shells
2269
+ based on their spatial frequency. The correlation between corresponding shells
2270
+ in the two maps is computed to give the FSC.
2271
+
2272
+ Parameters
2273
+ ----------
2274
+ density1 : :py:class:`Density`
2275
+ Reference for comparison.
2276
+ density2 : :py:class:`Density`
2277
+ Target for comparison.
2278
+
2279
+ Returns
2280
+ -------
2281
+ NDArray
2282
+ An array of shape (N, 2), where N is the number of shells.
2283
+ The first column represents the spatial frequency for each shell
2284
+ and the second column represents the corresponding FSC.
2285
+
2286
+ References
2287
+ ----------
2288
+ .. [1] https://github.com/tdgrant1/denss/blob/master/saxstats/saxstats.py
2289
+ """
2290
+ side = density1.data.shape[0]
2291
+ df = 1.0 / side
2292
+
2293
+ qx_ = np.fft.fftfreq(side) * side * df
2294
+ qx, qy, qz = np.meshgrid(qx_, qx_, qx_, indexing="ij")
2295
+ qr = np.sqrt(qx**2 + qy**2 + qz**2)
2296
+
2297
+ qmax = np.max(qr)
2298
+ qstep = np.min(qr[qr > 0])
2299
+ nbins = int(qmax / qstep)
2300
+ qbins = np.linspace(0, nbins * qstep, nbins + 1)
2301
+ qbin_labels = np.searchsorted(qbins, qr, "right") - 1
2302
+
2303
+ F1 = np.fft.fftn(density1.data)
2304
+ F2 = np.fft.fftn(density2.data)
2305
+
2306
+ qbin_labels = qbin_labels.reshape(-1)
2307
+ numerator = np.bincount(
2308
+ qbin_labels, weights=np.real(F1 * np.conj(F2)).reshape(-1)
2309
+ )
2310
+ term1 = np.bincount(qbin_labels, weights=np.abs(F1).reshape(-1) ** 2)
2311
+ term2 = np.bincount(qbin_labels, weights=np.abs(F2).reshape(-1) ** 2)
2312
+ np.multiply(term1, term2, out=term1)
2313
+ denominator = np.sqrt(term1)
2314
+ FSC = np.divide(numerator, denominator)
2315
+
2316
+ qidx = np.where(qbins < qx.max())
2317
+
2318
+ return np.vstack((qbins[qidx], FSC[qidx])).T
tme/matching_data.py CHANGED
@@ -544,6 +544,30 @@ class MatchingData:
544
544
  batch_mask=be.to_numpy_array(self._batch_mask),
545
545
  )
546
546
 
547
+ def _score_mask(self, fast_shape: Tuple[int], shift: Tuple[int]) -> BackendArray:
548
+ """
549
+ Create a boolean mask to exclude scores derived from padding in template matching.
550
+ """
551
+ padding = self.target_padding(True)
552
+ offset = tuple(x // 2 for x in padding)
553
+ shape = tuple(y - x for x, y in zip(padding, self.target.shape))
554
+
555
+ subset = []
556
+ for i in range(len(offset)):
557
+ if self._batch_mask[i]:
558
+ subset.append(slice(None))
559
+ else:
560
+ subset.append(slice(offset[i], offset[i] + shape[i]))
561
+
562
+ score_mask = np.zeros(fast_shape, dtype=bool)
563
+ score_mask[tuple(subset)] = 1
564
+ score_mask = np.roll(
565
+ score_mask,
566
+ shift=tuple(-x for x in shift),
567
+ axis=tuple(i for i in range(len(shift))),
568
+ )
569
+ return be.to_backend_array(score_mask)
570
+
547
571
  def computation_schedule(
548
572
  self,
549
573
  matching_method: str = "FLCSphericalMask",
@@ -223,6 +223,10 @@ def scan(
223
223
  )
224
224
  conv, fwd, inv, shift = matching_data.fourier_padding()
225
225
 
226
+ score_mask = be.full(shape=(1,), fill_value=1, dtype=bool)
227
+ if pad_target:
228
+ score_mask = matching_data._score_mask(fwd, shift)
229
+
226
230
  template_filter = _setup_template_filter_apply_target_filter(
227
231
  matching_data=matching_data,
228
232
  fast_shape=fwd,
@@ -275,6 +279,7 @@ def scan(
275
279
  callback=callback_classes[index % n_callback_classes],
276
280
  interpolation_order=interpolation_order,
277
281
  template_filter=be.to_sharedarr(template_filter, shm_handler),
282
+ score_mask=be.to_sharedarr(score_mask, shm_handler),
278
283
  **setup,
279
284
  )
280
285
  for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
@@ -420,8 +425,6 @@ def scan_subsets(
420
425
  outer_jobs, inner_jobs = job_schedule
421
426
  if be._backend_name == "jax":
422
427
  func = be.scan
423
- if kwargs.get("projection_matching", False):
424
- func = be.scan_projections
425
428
 
426
429
  corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
427
430
  results = func(
tme/matching_scores.py CHANGED
@@ -356,6 +356,7 @@ def corr_scoring(
356
356
  callback: CallbackClass,
357
357
  interpolation_order: int,
358
358
  template_mask: shm_type = None,
359
+ score_mask: shm_type = None,
359
360
  ) -> CallbackClass:
360
361
  """
361
362
  Calculates a normalized cross-correlation between a target f and a template g.
@@ -394,6 +395,8 @@ def corr_scoring(
394
395
  Spline order for template rotations.
395
396
  template_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
396
397
  Template mask data buffer, its shape and datatype, None by default.
398
+ score_mask : Union[Tuple[type, tuple of ints, type], BackendArray], optional
399
+ Score mask data buffer, its shape and datatype, None by default.
397
400
 
398
401
  Returns
399
402
  -------
@@ -404,6 +407,7 @@ def corr_scoring(
404
407
  inv_denominator = be.from_sharedarr(inv_denominator)
405
408
  numerator = be.from_sharedarr(numerator)
406
409
  template_filter = be.from_sharedarr(template_filter)
410
+ score_mask = be.from_sharedarr(score_mask)
407
411
 
408
412
  n_obs = None
409
413
  if template_mask is not None:
@@ -413,6 +417,7 @@ def corr_scoring(
413
417
  norm_template = conditional_execute(normalize_template, n_obs is not None)
414
418
  norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
415
419
  norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1))
420
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
416
421
 
417
422
  arr = be.zeros(fast_shape, be._float_dtype)
418
423
  ft_temp = be.zeros(fast_ft_shape, be._complex_dtype)
@@ -447,6 +452,8 @@ def corr_scoring(
447
452
 
448
453
  arr = norm_sub(arr, numerator, out=arr)
449
454
  arr = norm_mul(arr, inv_denominator, out=arr)
455
+ arr = norm_mask(arr, score_mask, out=arr)
456
+
450
457
  callback(arr, rotation_matrix=rotation)
451
458
 
452
459
  return callback
@@ -463,6 +470,7 @@ def flc_scoring(
463
470
  rotations: BackendArray,
464
471
  callback: CallbackClass,
465
472
  interpolation_order: int,
473
+ score_mask: shm_type = None,
466
474
  ) -> CallbackClass:
467
475
  """
468
476
  Computes a normalized cross-correlation between ``target`` (f),
@@ -522,6 +530,7 @@ def flc_scoring(
522
530
  ft_target = be.from_sharedarr(ft_target)
523
531
  ft_target2 = be.from_sharedarr(ft_target2)
524
532
  template_filter = be.from_sharedarr(template_filter)
533
+ score_mask = be.from_sharedarr(score_mask)
525
534
 
526
535
  arr = be.zeros(fast_shape, float_dtype)
527
536
  temp = be.zeros(fast_shape, float_dtype)
@@ -532,6 +541,7 @@ def flc_scoring(
532
541
  template_mask_rot = be.zeros(template.shape, be._float_dtype)
533
542
 
534
543
  tmpl_filter_func = _create_filter_func(template.shape, template_filter.shape)
544
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
535
545
 
536
546
  eps = be.eps(float_dtype)
537
547
  center = be.divide(be.to_backend_array(template.shape) - 1, 2)
@@ -567,6 +577,8 @@ def flc_scoring(
567
577
  arr = _correlate_fts(ft_target, ft_temp, ft_temp, arr, fast_shape)
568
578
 
569
579
  arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
580
+ arr = norm_mask(arr, score_mask, out=arr)
581
+
570
582
  callback(arr, rotation_matrix=rotation)
571
583
 
572
584
  return callback
@@ -585,6 +597,7 @@ def mcc_scoring(
585
597
  callback: CallbackClass,
586
598
  interpolation_order: int,
587
599
  overlap_ratio: float = 0.3,
600
+ score_mask: shm_type = None,
588
601
  ) -> CallbackClass:
589
602
  """
590
603
  Computes a normalized cross-correlation score between ``target`` (f),
@@ -755,12 +768,14 @@ def flc_scoring2(
755
768
  rotations: BackendArray,
756
769
  callback: CallbackClass,
757
770
  interpolation_order: int,
771
+ score_mask: shm_type = None,
758
772
  ) -> CallbackClass:
759
773
  template = be.from_sharedarr(template)
760
774
  template_mask = be.from_sharedarr(template_mask)
761
775
  ft_target = be.from_sharedarr(ft_target)
762
776
  ft_target2 = be.from_sharedarr(ft_target2)
763
777
  template_filter = be.from_sharedarr(template_filter)
778
+ score_mask = be.from_sharedarr(score_mask)
764
779
 
765
780
  tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
766
781
 
@@ -785,6 +800,7 @@ def flc_scoring2(
785
800
  filter_shape=template_filter.shape,
786
801
  arr_padded=True,
787
802
  )
803
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
788
804
 
789
805
  eps = be.eps(be._float_dtype)
790
806
  for index in range(rotations.shape[0]):
@@ -816,6 +832,8 @@ def flc_scoring2(
816
832
  arr = _correlate_fts(ft_target, ft_temp, ft_denom, arr, shape, axes)
817
833
 
818
834
  arr = be.norm_scores(arr, temp2, temp, n_obs, eps, arr)
835
+ arr = norm_mask(arr, score_mask, out=arr)
836
+
819
837
  callback(arr, rotation_matrix=rotation)
820
838
 
821
839
  return callback
@@ -834,12 +852,14 @@ def corr_scoring2(
834
852
  interpolation_order: int,
835
853
  target_filter: shm_type = None,
836
854
  template_mask: shm_type = None,
855
+ score_mask: shm_type = None,
837
856
  ) -> CallbackClass:
838
857
  template = be.from_sharedarr(template)
839
858
  ft_target = be.from_sharedarr(ft_target)
840
859
  inv_denominator = be.from_sharedarr(inv_denominator)
841
860
  numerator = be.from_sharedarr(numerator)
842
861
  template_filter = be.from_sharedarr(template_filter)
862
+ score_mask = be.from_sharedarr(score_mask)
843
863
 
844
864
  tar_batch, tmpl_batch = _get_batch_dim(ft_target, template)
845
865
 
@@ -869,6 +889,7 @@ def corr_scoring2(
869
889
  norm_template = conditional_execute(normalize_template, n_obs is not None)
870
890
  norm_sub = conditional_execute(be.subtract, numerator.shape != (1,))
871
891
  norm_mul = conditional_execute(be.multiply, inv_denominator.shape != (1,))
892
+ norm_mask = conditional_execute(be.multiply, score_mask.shape != (1,))
872
893
 
873
894
  template_filter_func = _create_filter_func(
874
895
  arr_shape=template.shape,
@@ -896,6 +917,8 @@ def corr_scoring2(
896
917
 
897
918
  arr = norm_sub(arr, numerator, out=arr)
898
919
  arr = norm_mul(arr, inv_denominator, out=arr)
920
+ arr = norm_mask(arr, score_mask, out=arr)
921
+
899
922
  callback(arr, rotation_matrix=rotation)
900
923
 
901
924
  return callback
tme/orientations.py CHANGED
@@ -327,11 +327,18 @@ class Orientations:
327
327
  "_rlnAnglePsi",
328
328
  "_rlnClassNumber",
329
329
  ]
330
+
331
+ target_identifer = "_rlnMicrographName"
332
+ if version == "# version 50001":
333
+ header[3] = "_rlnCenteredCoordinateXAngst"
334
+ header[4] = "_rlnCenteredCoordinateYAngst"
335
+ header[5] = "_rlnCenteredCoordinateZAngst"
336
+ target_identifer = "_rlnTomoName"
337
+
330
338
  if source_path is not None:
331
- header.append("_rlnMicrographName")
339
+ header.append(target_identifer)
332
340
 
333
341
  header.append("_pytmeScore")
334
-
335
342
  header = "\n".join(header)
336
343
  with open(filename, mode="w", encoding="utf-8") as ofile:
337
344
  if version is not None:
@@ -487,16 +494,22 @@ class Orientations:
487
494
 
488
495
  @classmethod
489
496
  def _from_star(
490
- cls, filename: str, delimiter: str = "\t"
497
+ cls, filename: str, delimiter: str = None
491
498
  ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
492
499
  parser = StarParser(filename, delimiter=delimiter)
493
500
 
494
- ret = parser.get("data_particles", None)
495
- if ret is None:
496
- ret = parser.get("data_", None)
501
+ keyword_order = ("data_particles", "particles", "data")
502
+ for keyword in keyword_order:
503
+ ret = parser.get(keyword, None)
504
+ if ret is None:
505
+ ret = parser.get(f"{keyword}_", None)
506
+ if ret is not None:
507
+ break
497
508
 
498
509
  if ret is None:
499
- raise ValueError(f"No data_particles section found in {filename}.")
510
+ raise ValueError(
511
+ f"Could not find either {keyword_order} section found in {filename}."
512
+ )
500
513
 
501
514
  translation = np.vstack(
502
515
  (ret["_rlnCoordinateX"], ret["_rlnCoordinateY"], ret["_rlnCoordinateZ"])