pytme 0.3b0.post1__cp311-cp311-macosx_15_0_arm64.whl → 0.3.1__cp311-cp311-macosx_15_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54) hide show
  1. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/match_template.py +28 -39
  2. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/postprocess.py +23 -10
  3. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocessor_gui.py +95 -24
  4. pytme-0.3.1.data/scripts/pytme_runner.py +1223 -0
  5. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/METADATA +5 -5
  6. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/RECORD +53 -46
  7. scripts/extract_candidates.py +118 -99
  8. scripts/match_template.py +28 -39
  9. scripts/postprocess.py +23 -10
  10. scripts/preprocessor_gui.py +95 -24
  11. scripts/pytme_runner.py +644 -190
  12. scripts/refine_matches.py +156 -386
  13. tests/data/.DS_Store +0 -0
  14. tests/data/Blurring/.DS_Store +0 -0
  15. tests/data/Maps/.DS_Store +0 -0
  16. tests/data/Raw/.DS_Store +0 -0
  17. tests/data/Structures/.DS_Store +0 -0
  18. tests/preprocessing/test_utils.py +18 -0
  19. tests/test_backends.py +3 -9
  20. tests/test_density.py +0 -1
  21. tests/test_matching_utils.py +10 -60
  22. tests/test_rotations.py +1 -1
  23. tme/__version__.py +1 -1
  24. tme/analyzer/_utils.py +4 -4
  25. tme/analyzer/aggregation.py +13 -3
  26. tme/analyzer/peaks.py +11 -10
  27. tme/backends/_jax_utils.py +15 -13
  28. tme/backends/_numpyfftw_utils.py +270 -0
  29. tme/backends/cupy_backend.py +5 -44
  30. tme/backends/jax_backend.py +58 -37
  31. tme/backends/matching_backend.py +6 -51
  32. tme/backends/mlx_backend.py +1 -27
  33. tme/backends/npfftw_backend.py +68 -65
  34. tme/backends/pytorch_backend.py +1 -26
  35. tme/density.py +2 -6
  36. tme/extensions.cpython-311-darwin.so +0 -0
  37. tme/filters/ctf.py +22 -21
  38. tme/filters/wedge.py +10 -7
  39. tme/mask.py +341 -0
  40. tme/matching_data.py +7 -19
  41. tme/matching_exhaustive.py +34 -47
  42. tme/matching_optimization.py +2 -1
  43. tme/matching_scores.py +206 -411
  44. tme/matching_utils.py +73 -422
  45. tme/memory.py +1 -1
  46. tme/orientations.py +4 -6
  47. tme/rotations.py +1 -1
  48. pytme-0.3b0.post1.data/scripts/pytme_runner.py +0 -769
  49. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/estimate_memory_usage.py +0 -0
  50. {pytme-0.3b0.post1.data → pytme-0.3.1.data}/scripts/preprocess.py +0 -0
  51. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/WHEEL +0 -0
  52. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/entry_points.txt +0 -0
  53. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/licenses/LICENSE +0 -0
  54. {pytme-0.3b0.post1.dist-info → pytme-0.3.1.dist-info}/top_level.txt +0 -0
@@ -40,7 +40,7 @@ def _setup_template_filter_apply_target_filter(
40
40
  matching_data: MatchingData,
41
41
  fast_shape: Tuple[int],
42
42
  fast_ft_shape: Tuple[int],
43
- pad_template_filter: bool = True,
43
+ pad_template_filter: bool = False,
44
44
  ):
45
45
  target_filter = None
46
46
  backend_arr = type(be.zeros((1), dtype=be._float_dtype))
@@ -146,11 +146,10 @@ def scan(
146
146
  matching_data: MatchingData,
147
147
  matching_setup: Callable,
148
148
  matching_score: Callable,
149
- n_jobs: int = 4,
150
- callback_class: CallbackClass = None,
149
+ callback_class: CallbackClass,
151
150
  callback_class_args: Dict = {},
151
+ n_jobs: int = 4,
152
152
  pad_target: bool = True,
153
- pad_template_filter: bool = True,
154
153
  interpolation_order: int = 3,
155
154
  jobs_per_callback_class: int = 8,
156
155
  shm_handler=None,
@@ -172,20 +171,22 @@ def scan(
172
171
  Function pointer to scoring function.
173
172
  n_jobs : int, optional
174
173
  Number of parallel jobs. Default is 4.
175
- callback_class : type, optional
174
+ callback_class : type
176
175
  Analyzer class pointer to operate on computed scores.
177
176
  callback_class_args : dict, optional
178
177
  Arguments passed to the callback_class. Default is an empty dictionary.
179
178
  pad_target: bool, optional
180
179
  Whether to pad target to the full convolution shape.
181
- pad_template_filter: bool, optional
182
- Whether to pad potential template filters to the full convolution shape.
183
180
  interpolation_order : int, optional
184
181
  Order of spline interpolation for rotations.
185
182
  jobs_per_callback_class : int, optional
186
183
  Number of jobs a callback_class instance is shared between, 8 by default.
187
184
  shm_handler : type, optional
188
185
  Manager for shared memory objects, None by default.
186
+ target_slice : tuple of slice, optional
187
+ Target subset to process.
188
+ template_slice : tuple of slice, optional
189
+ Template subset to process.
189
190
 
190
191
  Returns
191
192
  -------
@@ -220,13 +221,13 @@ def scan(
220
221
  template_shape = matching_data._batch_shape(
221
222
  matching_data.template.shape, matching_data._template_batch
222
223
  )
223
- conv, fwd, inv, shift = matching_data.fourier_padding(pad_target=pad_target)
224
+ conv, fwd, inv, shift = matching_data.fourier_padding()
224
225
 
225
226
  template_filter = _setup_template_filter_apply_target_filter(
226
227
  matching_data=matching_data,
227
228
  fast_shape=fwd,
228
229
  fast_ft_shape=inv,
229
- pad_template_filter=pad_template_filter,
230
+ pad_template_filter=False,
230
231
  )
231
232
 
232
233
  default_callback_args = {
@@ -240,11 +241,10 @@ def scan(
240
241
  "thread_safe": n_jobs > 1,
241
242
  "convolution_mode": "valid" if pad_target else "same",
242
243
  "shm_handler": shm_handler,
243
- "only_unique_rotations": True,
244
244
  "aggregate_axis": matching_data._batch_axis(matching_data._batch_mask),
245
245
  "n_rotations": matching_data.rotations.shape[0],
246
+ "inversion_mapping": n_jobs == 1,
246
247
  }
247
- callback_class_args["inversion_mapping"] = n_jobs == 1
248
248
  default_callback_args.update(callback_class_args)
249
249
 
250
250
  setup = matching_setup(
@@ -254,22 +254,14 @@ def scan(
254
254
  fast_ft_shape=inv,
255
255
  shm_handler=shm_handler,
256
256
  )
257
- setup["interpolation_order"] = interpolation_order
258
- setup["template_filter"] = be.to_sharedarr(template_filter, shm_handler)
259
257
 
260
258
  matching_data._free_data()
261
- be.free_cache()
262
-
263
259
  n_callback_classes = max(n_jobs // jobs_per_callback_class, 1)
264
260
  callback_classes = [
265
- (
266
- SharedAnalyzerProxy(
267
- callback_class,
268
- default_callback_args,
269
- shm_handler=shm_handler if n_jobs > 1 else None,
270
- )
271
- if callback_class
272
- else None
261
+ SharedAnalyzerProxy(
262
+ callback_class,
263
+ default_callback_args,
264
+ shm_handler=shm_handler if n_jobs > 1 else None,
273
265
  )
274
266
  for _ in range(n_callback_classes)
275
267
  ]
@@ -277,35 +269,32 @@ def scan(
277
269
  delayed(_wrap_backend(matching_score))(
278
270
  backend_name=be._backend_name,
279
271
  backend_args=be._backend_args,
272
+ fast_shape=fwd,
273
+ fast_ft_shape=inv,
280
274
  rotations=rotation,
281
275
  callback=callback_classes[index % n_callback_classes],
276
+ interpolation_order=interpolation_order,
277
+ template_filter=be.to_sharedarr(template_filter, shm_handler),
282
278
  **setup,
283
279
  )
284
280
  for index, rotation in enumerate(matching_data._split_rotations_on_jobs(n_jobs))
285
281
  )
286
- callbacks = [
287
- callback.result(**default_callback_args)
288
- for callback in ret[:n_callback_classes]
289
- if callback
290
- ]
291
282
  be.free_cache()
292
283
 
293
- if callback_class:
294
- ret = callback_class.merge(callbacks, **default_callback_args)
295
- return ret
284
+ callbacks = [x.result(**default_callback_args) for x in ret[:n_callback_classes]]
285
+ return callback_class.merge(callbacks, **default_callback_args)
296
286
 
297
287
 
298
288
  def scan_subsets(
299
289
  matching_data: MatchingData,
300
290
  matching_score: Callable,
301
291
  matching_setup: Callable,
302
- callback_class: CallbackClass = None,
292
+ callback_class: CallbackClass,
303
293
  callback_class_args: Dict = {},
304
294
  job_schedule: Tuple[int] = (1, 1),
305
295
  target_splits: Dict = {},
306
296
  template_splits: Dict = {},
307
297
  pad_target_edges: bool = False,
308
- pad_template_filter: bool = True,
309
298
  interpolation_order: int = 3,
310
299
  jobs_per_callback_class: int = 8,
311
300
  backend_name: str = None,
@@ -325,7 +314,7 @@ def scan_subsets(
325
314
  Function pointer to setup function.
326
315
  matching_score : type
327
316
  Function pointer to scoring function.
328
- callback_class : type, optional
317
+ callback_class : type
329
318
  Analyzer class pointer to operate on computed scores.
330
319
  callback_class_args : dict, optional
331
320
  Arguments passed to the callback_class. Default is an empty dictionary.
@@ -341,8 +330,6 @@ def scan_subsets(
341
330
  See :py:meth:`tme.matching_utils.compute_parallelization_schedule`.
342
331
  pad_target_edges : bool, optional
343
332
  Pad the target boundaries to avoid edge effects.
344
- pad_template_filter: bool, optional
345
- Whether to pad potential template filters to the full convolution shape.
346
333
  interpolation_order : int, optional
347
334
  Order of spline interpolation for rotations.
348
335
  jobs_per_callback_class : int, optional
@@ -424,18 +411,24 @@ def scan_subsets(
424
411
  )
425
412
  splits = tuple(product(target_splits, template_splits))
426
413
 
414
+ kwargs = {
415
+ "matching_data": matching_data,
416
+ "callback_class": callback_class,
417
+ "callback_class_args": callback_class_args,
418
+ }
419
+
427
420
  outer_jobs, inner_jobs = job_schedule
428
421
  if be._backend_name == "jax":
429
422
  func = be.scan
423
+ if kwargs.get("projection_matching", False):
424
+ func = be.scan_projections
430
425
 
431
426
  corr_scoring = MATCHING_EXHAUSTIVE_REGISTER.get("CORR", (None, None))[1]
432
427
  results = func(
433
- matching_data=matching_data,
434
428
  splits=splits,
435
429
  n_jobs=outer_jobs,
436
430
  rotate_mask=matching_score != corr_scoring,
437
- callback_class=callback_class,
438
- callback_class_args=callback_class_args,
431
+ **kwargs,
439
432
  )
440
433
  else:
441
434
  results = Parallel(n_jobs=outer_jobs, verbose=verbose)(
@@ -443,26 +436,20 @@ def scan_subsets(
443
436
  delayed(_wrap_backend(scan))(
444
437
  backend_name=be._backend_name,
445
438
  backend_args=be._backend_args,
446
- matching_data=matching_data,
447
439
  matching_score=matching_score,
448
440
  matching_setup=matching_setup,
449
441
  n_jobs=inner_jobs,
450
- callback_class=callback_class,
451
- callback_class_args=callback_class_args,
452
442
  interpolation_order=interpolation_order,
453
443
  pad_target=pad_target_edges,
454
444
  gpu_index=index % outer_jobs,
455
- pad_template_filter=pad_template_filter,
456
445
  target_slice=target_split,
457
446
  template_slice=template_split,
447
+ **kwargs,
458
448
  )
459
449
  for index, (target_split, template_split) in enumerate(splits)
460
450
  ]
461
451
  )
462
- matching_data._free_data()
463
- if callback_class is not None:
464
- return callback_class.merge(results, **callback_class_args)
465
- return None
452
+ return callback_class.merge(results, **callback_class_args)
466
453
 
467
454
 
468
455
  def register_matching_exhaustive(
@@ -1104,7 +1104,8 @@ def create_score_object(score: str, **kwargs) -> object:
1104
1104
  Examples
1105
1105
  --------
1106
1106
  >>> from tme import Density
1107
- >>> from tme.matching_utils import create_mask, euler_to_rotationmatrix
1107
+ >>> from tme.mask import create_mask
1108
+ >>> from tme.matching_utils import euler_to_rotationmatrix
1108
1109
  >>> from tme.matching_optimization import CrossCorrelation, optimize_match
1109
1110
  >>> translation, rotation = (5, -2, 7), (5, -10, 2)
1110
1111
  >>> target = create_mask(