spacr 0.4.60__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (72)
  1. spacr/__init__.py +2 -4
  2. spacr/__main__.py +3 -3
  3. spacr/core.py +13 -107
  4. spacr/gui.py +0 -1
  5. spacr/gui_core.py +2 -2
  6. spacr/gui_utils.py +5 -14
  7. spacr/io.py +189 -200
  8. spacr/mediar.py +12 -8
  9. spacr/plot.py +50 -13
  10. spacr/settings.py +71 -14
  11. spacr/submodules.py +21 -14
  12. spacr/timelapse.py +192 -6
  13. spacr/utils.py +180 -56
  14. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/METADATA +64 -62
  15. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/RECORD +20 -72
  16. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/WHEEL +1 -1
  17. spacr/resources/MEDIAR/.gitignore +0 -18
  18. spacr/resources/MEDIAR/LICENSE +0 -21
  19. spacr/resources/MEDIAR/README.md +0 -189
  20. spacr/resources/MEDIAR/SetupDict.py +0 -39
  21. spacr/resources/MEDIAR/config/baseline.json +0 -60
  22. spacr/resources/MEDIAR/config/mediar_example.json +0 -72
  23. spacr/resources/MEDIAR/config/pred/pred_mediar.json +0 -17
  24. spacr/resources/MEDIAR/config/step1_pretraining/phase1.json +0 -55
  25. spacr/resources/MEDIAR/config/step1_pretraining/phase2.json +0 -58
  26. spacr/resources/MEDIAR/config/step2_finetuning/finetuning1.json +0 -66
  27. spacr/resources/MEDIAR/config/step2_finetuning/finetuning2.json +0 -66
  28. spacr/resources/MEDIAR/config/step3_prediction/base_prediction.json +0 -16
  29. spacr/resources/MEDIAR/config/step3_prediction/ensemble_tta.json +0 -23
  30. spacr/resources/MEDIAR/core/BasePredictor.py +0 -120
  31. spacr/resources/MEDIAR/core/BaseTrainer.py +0 -240
  32. spacr/resources/MEDIAR/core/Baseline/Predictor.py +0 -59
  33. spacr/resources/MEDIAR/core/Baseline/Trainer.py +0 -113
  34. spacr/resources/MEDIAR/core/Baseline/__init__.py +0 -2
  35. spacr/resources/MEDIAR/core/Baseline/utils.py +0 -80
  36. spacr/resources/MEDIAR/core/MEDIAR/EnsemblePredictor.py +0 -105
  37. spacr/resources/MEDIAR/core/MEDIAR/Predictor.py +0 -234
  38. spacr/resources/MEDIAR/core/MEDIAR/Trainer.py +0 -172
  39. spacr/resources/MEDIAR/core/MEDIAR/__init__.py +0 -3
  40. spacr/resources/MEDIAR/core/MEDIAR/utils.py +0 -429
  41. spacr/resources/MEDIAR/core/__init__.py +0 -2
  42. spacr/resources/MEDIAR/core/utils.py +0 -40
  43. spacr/resources/MEDIAR/evaluate.py +0 -71
  44. spacr/resources/MEDIAR/generate_mapping.py +0 -121
  45. spacr/resources/MEDIAR/image/examples/img1.tiff +0 -0
  46. spacr/resources/MEDIAR/image/examples/img2.tif +0 -0
  47. spacr/resources/MEDIAR/image/failure_cases.png +0 -0
  48. spacr/resources/MEDIAR/image/mediar_framework.png +0 -0
  49. spacr/resources/MEDIAR/image/mediar_model.PNG +0 -0
  50. spacr/resources/MEDIAR/image/mediar_results.png +0 -0
  51. spacr/resources/MEDIAR/main.py +0 -125
  52. spacr/resources/MEDIAR/predict.py +0 -70
  53. spacr/resources/MEDIAR/requirements.txt +0 -14
  54. spacr/resources/MEDIAR/train_tools/__init__.py +0 -3
  55. spacr/resources/MEDIAR/train_tools/data_utils/__init__.py +0 -1
  56. spacr/resources/MEDIAR/train_tools/data_utils/custom/CellAware.py +0 -88
  57. spacr/resources/MEDIAR/train_tools/data_utils/custom/LoadImage.py +0 -161
  58. spacr/resources/MEDIAR/train_tools/data_utils/custom/NormalizeImage.py +0 -77
  59. spacr/resources/MEDIAR/train_tools/data_utils/custom/__init__.py +0 -3
  60. spacr/resources/MEDIAR/train_tools/data_utils/custom/modalities.pkl +0 -0
  61. spacr/resources/MEDIAR/train_tools/data_utils/datasetter.py +0 -208
  62. spacr/resources/MEDIAR/train_tools/data_utils/transforms.py +0 -148
  63. spacr/resources/MEDIAR/train_tools/data_utils/utils.py +0 -84
  64. spacr/resources/MEDIAR/train_tools/measures.py +0 -200
  65. spacr/resources/MEDIAR/train_tools/models/MEDIARFormer.py +0 -102
  66. spacr/resources/MEDIAR/train_tools/models/__init__.py +0 -1
  67. spacr/resources/MEDIAR/train_tools/utils.py +0 -70
  68. spacr/stats.py +0 -221
  69. /spacr/{cellpose.py → spacr_cellpose.py} +0 -0
  70. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/LICENSE +0 -0
  71. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/entry_points.txt +0 -0
  72. {spacr-0.4.60.dist-info → spacr-0.9.0.dist-info}/top_level.txt +0 -0
spacr/timelapse.py CHANGED
@@ -7,9 +7,9 @@ from IPython.display import display
 from IPython.display import Image as ipyimage
 import trackpy as tp
 from btrack import datasets as btrack_datasets
-from skimage.measure import regionprops
+from skimage.measure import regionprops, regionprops_table
 from scipy.signal import find_peaks
-from scipy.optimize import curve_fit
+from scipy.optimize import curve_fit, linear_sum_assignment
 from scipy.integrate import trapz
 import matplotlib.pyplot as plt
 
@@ -255,7 +255,7 @@ def _relabel_masks_based_on_tracks(masks, tracks, mode='btrack'):
 
     return relabeled_masks
 
-def _prepare_for_tracking(mask_array):
+def _prepare_for_tracking_v1(mask_array):
     """
     Prepare the mask array for object tracking.
 
@@ -286,6 +286,105 @@ def _prepare_for_tracking(mask_array):
         })
     return pd.DataFrame(frames)
 
+def _prepare_for_tracking_v1(mask_array):
+    frames = []
+    for t, frame in enumerate(mask_array):
+        props = regionprops_table(
+            frame,
+            properties=('label', 'centroid-0', 'centroid-1', 'area',
+                        'bbox-0', 'bbox-1', 'bbox-2', 'bbox-3',
+                        'eccentricity')
+        )
+        df = pd.DataFrame(props)
+        df = df.rename(columns={
+            'centroid-0': 'y', 'centroid-1': 'x', 'area': 'mass',
+            'label': 'original_label'
+        })
+        df['frame'] = t
+        frames.append(df[['frame','y','x','mass','original_label',
+                          'bbox-0','bbox-1','bbox-2','bbox-3','eccentricity']])
+    return pd.concat(frames, ignore_index=True)
+
+def _prepare_for_tracking(mask_array):
+    frames = []
+    for t, frame in enumerate(mask_array):
+        props = regionprops_table(
+            frame,
+            properties=('label', 'centroid', 'area', 'bbox', 'eccentricity')
+        )
+        df = pd.DataFrame(props)
+        df = df.rename(columns={
+            'centroid-0': 'y',
+            'centroid-1': 'x',
+            'area': 'mass',
+            'label': 'original_label'
+        })
+        df['frame'] = t
+        frames.append(df[['frame','y','x','mass','original_label',
+                          'bbox-0','bbox-1','bbox-2','bbox-3','eccentricity']])
+    return pd.concat(frames, ignore_index=True)
+
+def _track_by_iou(masks, iou_threshold=0.1):
+    """
+    Build a track table by linking masks frame→frame via IoU.
+    Returns a DataFrame with columns [frame, original_label, track_id].
+    """
+    n_frames = masks.shape[0]
+    # 1) initialize: every label in frame 0 starts its own track
+    labels0 = np.unique(masks[0])[1:]
+    next_track = 1
+    track_map = {}  # (frame,label) -> track_id
+    for L in labels0:
+        track_map[(0, L)] = next_track
+        next_track += 1
+
+    # 2) iterate through frames
+    for t in range(1, n_frames):
+        prev, curr = masks[t-1], masks[t]
+        matches = link_by_iou(prev, curr, iou_threshold=iou_threshold)
+        used_curr = set()
+        # a) assign matched labels to existing tracks
+        for L_prev, L_curr in matches:
+            tid = track_map[(t-1, L_prev)]
+            track_map[(t, L_curr)] = tid
+            used_curr.add(L_curr)
+        # b) any label in curr not matched → new track
+        for L in np.unique(curr)[1:]:
+            if L not in used_curr:
+                track_map[(t, L)] = next_track
+                next_track += 1
+
+    # 3) flatten into DataFrame
+    records = []
+    for (frame, label), tid in track_map.items():
+        records.append({'frame': frame, 'original_label': label, 'track_id': tid})
+    return pd.DataFrame(records)
+
+def link_by_iou(mask_prev, mask_next, iou_threshold=0.1):
+    # Get labels
+    labels_prev = np.unique(mask_prev)[1:]
+    labels_next = np.unique(mask_next)[1:]
+    # Precompute masks as boolean
+    bool_prev = {L: mask_prev==L for L in labels_prev}
+    bool_next = {L: mask_next==L for L in labels_next}
+    # Cost matrix = 1 - IoU
+    cost = np.ones((len(labels_prev), len(labels_next)), dtype=float)
+    for i, L1 in enumerate(labels_prev):
+        m1 = bool_prev[L1]
+        for j, L2 in enumerate(labels_next):
+            m2 = bool_next[L2]
+            inter = np.logical_and(m1, m2).sum()
+            union = np.logical_or(m1, m2).sum()
+            if union > 0:
+                cost[i, j] = 1 - inter/union
+    # Solve assignment
+    row_ind, col_ind = linear_sum_assignment(cost)
+    matches = []
+    for i, j in zip(row_ind, col_ind):
+        if cost[i,j] <= 1 - iou_threshold:
+            matches.append((labels_prev[i], labels_next[j]))
+    return matches
+
 def _find_optimal_search_range(features, initial_search_range=500, increment=10, max_attempts=49, memory=3):
     """
     Find the optimal search range for linking features.
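
The new `link_by_iou` pairs labels across consecutive frames by minimising a 1 − IoU cost matrix with SciPy's Hungarian solver (`linear_sum_assignment`), and `_track_by_iou` stitches those pairings into persistent track IDs. A minimal self-contained sketch of the matching step on invented toy data (not spacr code):

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

# Two toy label frames: object 1 shifts right by one column,
# object 2 vanishes, and a new object 3 appears.
prev = np.array([[1, 1, 0, 0],
                 [1, 1, 0, 0],
                 [0, 0, 2, 2],
                 [0, 0, 2, 2]])
curr = np.array([[0, 1, 1, 0],
                 [0, 1, 1, 0],
                 [3, 0, 0, 0],
                 [3, 0, 0, 0]])

labels_prev = np.unique(prev)[1:]          # skip background 0
labels_next = np.unique(curr)[1:]
cost = np.ones((len(labels_prev), len(labels_next)))
for i, a in enumerate(labels_prev):
    for j, b in enumerate(labels_next):
        inter = np.logical_and(prev == a, curr == b).sum()
        union = np.logical_or(prev == a, curr == b).sum()
        cost[i, j] = 1 - inter / union     # cost = 1 - IoU

rows, cols = linear_sum_assignment(cost)   # one-to-one Hungarian matching
matches = [(int(labels_prev[i]), int(labels_next[j]))
           for i, j in zip(rows, cols) if cost[i, j] <= 1 - 0.1]
print(matches)  # [(1, 1)]: object 2's track ends, object 3 starts a new one
```

Pairs whose best match falls below the IoU threshold are rejected, so a vanished object ends its track and an unmatched new label starts one — the bookkeeping `_track_by_iou` does with `track_map`.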
@@ -336,7 +435,94 @@ def _remove_objects_from_first_frame(masks, percentage=10):
             masks[0][first_frame == label] = 0
     return masks
 
-def _facilitate_trackin_with_adaptive_removal(masks, search_range=500, max_attempts=100, memory=3):
+def _track_by_iou(masks, iou_threshold=0.1):
+    """
+    Build a track table by linking masks frame→frame via IoU.
+    Returns a DataFrame with columns [frame, original_label, track_id].
+    """
+    n_frames = masks.shape[0]
+    # 1) initialize: every label in frame 0 starts its own track
+    labels0 = np.unique(masks[0])[1:]
+    next_track = 1
+    track_map = {}  # (frame,label) -> track_id
+    for L in labels0:
+        track_map[(0, L)] = next_track
+        next_track += 1
+
+    # 2) iterate through frames
+    for t in range(1, n_frames):
+        prev, curr = masks[t-1], masks[t]
+        matches = link_by_iou(prev, curr, iou_threshold=iou_threshold)
+        used_curr = set()
+        # a) assign matched labels to existing tracks
+        for L_prev, L_curr in matches:
+            tid = track_map[(t-1, L_prev)]
+            track_map[(t, L_curr)] = tid
+            used_curr.add(L_curr)
+        # b) any label in curr not matched → new track
+        for L in np.unique(curr)[1:]:
+            if L not in used_curr:
+                track_map[(t, L)] = next_track
+                next_track += 1
+
+    # 3) flatten into DataFrame
+    records = []
+    for (frame, label), tid in track_map.items():
+        records.append({'frame': frame, 'original_label': label, 'track_id': tid})
+    return pd.DataFrame(records)
+
+
+def _facilitate_trackin_with_adaptive_removal(masks, search_range=None, max_attempts=5, memory=3, min_mass=50, track_by_iou=False):
+    """
+    Facilitates object tracking with deterministic initial filtering and
+    trackpy's constant-velocity prediction.
+
+    Args:
+        masks (np.ndarray): integer-labeled masks (frames × H × W).
+        search_range (int|None): max displacement; if None, auto-computed.
+        max_attempts (int): how many times to retry with smaller search_range.
+        memory (int): trackpy memory parameter.
+        min_mass (float): drop any object in frame 0 with area < min_mass.
+
+    Returns:
+        masks, features_df, tracks_df
+
+    Raises:
+        RuntimeError if linking fails after max_attempts.
+    """
+    # 1) initial features & filter frame 0 by area
+    features = _prepare_for_tracking(masks)
+    f0 = features[features['frame'] == 0]
+    valid = f0.loc[f0['mass'] >= min_mass, 'original_label'].unique()
+    masks[0] = np.where(np.isin(masks[0], valid), masks[0], 0)
+
+    # 2) recompute features on filtered masks
+    features = _prepare_for_tracking(masks)
+
+    # 3) default search_range = 2×sqrt(99th-pct area)
+    if search_range is None:
+        a99 = f0['mass'].quantile(0.99)
+        search_range = max(1, int(2 * np.sqrt(a99)))
+
+    # 4) attempt linking, shrinking search_range on failure
+    for attempt in range(1, max_attempts + 1):
+        try:
+            if track_by_iou:
+                tracks_df = _track_by_iou(masks, iou_threshold=0.1)
+            else:
+                tracks_df = tp.link_df(features, search_range=search_range, memory=memory, predict=True)
+            print(f"Linked on attempt {attempt} with search_range={search_range}")
+            return masks, features, tracks_df
+
+        except Exception as e:
+            search_range = max(1, int(search_range * 0.8))
+            print(f"Attempt {attempt} failed ({e}); reducing search_range to {search_range}")
+
+    raise RuntimeError(
+        f"Failed to track after {max_attempts} attempts; last search_range={search_range}"
+    )
+
+def _facilitate_trackin_with_adaptive_removal_v1(masks, search_range=500, max_attempts=100, memory=3):
     """
     Facilitates object tracking with adaptive removal.
 
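The rewritten `_facilitate_trackin_with_adaptive_removal` filters small frame-0 objects, derives a default `search_range` from the 99th-percentile object area, and retries with a 20 % smaller radius after any linking failure. For orientation, a hypothetical trackpy call on a hand-built feature table in the column layout `_prepare_for_tracking` returns (data invented):

```python
import pandas as pd
import trackpy as tp

# One row per labelled object per frame: centroid in (x, y), area as 'mass'.
features = pd.DataFrame({
    'frame':          [0, 0, 1, 1, 2, 2],
    'x':              [10.0, 40.0, 12.0, 41.0, 14.0, 42.0],
    'y':              [10.0, 10.0, 11.0, 10.0, 12.0, 10.0],
    'mass':           [100, 120, 100, 120, 100, 120],
    'original_label': [1, 2, 1, 2, 1, 2],
})

# memory=3 lets an object disappear for up to three frames and keep its ID
tracks = tp.link_df(features, search_range=5, memory=3)
print(tracks[['frame', 'original_label', 'particle']])
```

Note the retry loop treats any exception as a cue to shrink `search_range`, so even a keyword rejected by the installed trackpy version (for example `predict=True` on releases whose `link_df` does not accept it) is caught and retried rather than surfaced immediately.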
@@ -376,7 +562,7 @@ def _facilitate_trackin_with_adaptive_removal(masks, search_range=500, max_attem
     print(f"Failed to track objects after {max_attempts} attempts. Consider adjusting parameters.")
     return None, None, None
 
-def _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse_displacement, timelapse_memory, timelapse_remove_transient, plot, save, mode):
+def _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse_displacement, timelapse_memory, timelapse_remove_transient, plot, save, mode, track_by_iou):
     """
     Track cells using the Trackpy library.
 
@@ -409,7 +595,7 @@ def _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelap
     if timelapse_displacement is None:
         timelapse_displacement = 50
 
-    masks, features, tracks_df = _facilitate_trackin_with_adaptive_removal(masks, search_range=timelapse_displacement, max_attempts=100, memory=timelapse_memory)
+    masks, features, tracks_df = _facilitate_trackin_with_adaptive_removal(masks, search_range=timelapse_displacement, max_attempts=100, memory=timelapse_memory, track_by_iou=track_by_iou)
 
     tracks_df['particle'] += 1
 
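With `track_by_iou` threaded through `_trackpy_track_cells`, the returned table carries one ID per object per frame either way. A hypothetical helper (not part of spacr) showing how a `[frame, original_label, track_id]` table like the one `_track_by_iou()` returns can be applied back onto the mask stack:

```python
import numpy as np
import pandas as pd

def relabel_by_tracks(masks, tracks):
    """Rewrite per-frame labels so each object keeps a single ID across frames."""
    out = np.zeros_like(masks)
    for row in tracks.itertuples(index=False):
        out[row.frame][masks[row.frame] == row.original_label] = row.track_id
    return out

masks = np.array([[[1, 0], [0, 2]],
                  [[2, 0], [0, 1]]])           # labels swap names between frames
tracks = pd.DataFrame({'frame': [0, 0, 1, 1],
                       'original_label': [1, 2, 2, 1],
                       'track_id': [1, 2, 1, 2]})
print(relabel_by_tracks(masks, tracks)[1])     # [[1 0] [0 2]] — IDs stable again
```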
spacr/utils.py CHANGED
@@ -1,5 +1,4 @@
 import os, re, sqlite3, torch, torchvision, random, string, shutil, cv2, tarfile, glob, psutil, platform, gzip, subprocess, time, requests, ast, traceback
-
 import numpy as np
 import pandas as pd
 from cellpose import models as cp_models
@@ -11,7 +10,7 @@ import skimage.measure as measure
 from skimage.transform import resize as resizescikit
 from skimage.morphology import dilation, square
 from skimage.measure import find_contours
-from skimage.segmentation import clear_border
+from skimage.segmentation import clear_border, find_boundaries
 from scipy.stats import pearsonr
 
 from collections import defaultdict, OrderedDict
@@ -432,7 +431,7 @@ def close_multiprocessing_processes():
 
 def check_mask_folder(src,mask_fldr):
 
-    mask_folder = os.path.join(src,'norm_channel_stack',mask_fldr)
+    mask_folder = os.path.join(src,'masks',mask_fldr)
     stack_folder = os.path.join(src,'stack')
 
     if not os.path.exists(mask_folder):
@@ -554,7 +553,7 @@ def _get_cellpose_batch_size():
     except Exception as e:
         return 8
 
-def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
+def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager'):
 
     images_by_key = defaultdict(list)
 
@@ -568,33 +567,38 @@ def _extract_filename_metadata(filenames, src, regular_expression, metadata_type
                 plate = os.path.basename(src)
 
             well = match.group('wellID')
-            field = match.group('fieldID')
-            channel = match.group('chanID')
-            mode = None
-
             if well[0].isdigit():
                 well = str(_safe_int_convert(well))
+
+            field = match.group('fieldID')
             if field[0].isdigit():
                 field = str(_safe_int_convert(field))
+
+            channel = match.group('chanID')
             if channel[0].isdigit():
                 channel = str(_safe_int_convert(channel))
-
+
+            if 'timeID' in match.groupdict():
+                timeID = match.group('timeID')
+                if timeID[0].isdigit():
+                    timeID = str(_safe_int_convert(timeID))
+            else:
+                timeID = None
+
+            if 'sliceID' in match.groupdict():
+                sliceID = match.group('sliceID')
+                if sliceID[0].isdigit():
+                    sliceID = str(_safe_int_convert(sliceID))
+            else:
+                sliceID = None
+
             if metadata_type =='cq1':
                 orig_wellID = wellID
                 wellID = _convert_cq1_well_id(wellID)
                 print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
 
-            if pick_slice:
-                try:
-                    mode = match.group('AID')
-                except IndexError:
-                    sliceid = '00'
-
-                if mode == skip_mode:
-                    continue
-
-            key = (plate, well, field, channel, mode)
-            file_path = os.path.join(src, filename) # Store the full path
+            key = (plate, well, field, channel, timeID, sliceID)
+            file_path = os.path.join(src, filename)
             images_by_key[key].append(file_path)
 
         except IndexError:
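
The `pick_slice`/`skip_mode` logic is gone; `timeID` and `sliceID` are now read only when the regex actually defines those groups, so older patterns keep working and files are keyed by a six-part `(plate, well, field, channel, timeID, sliceID)` tuple. A small sketch of that guard (regexes and filenames invented for illustration):

```python
import re

pat_with_time = r'(?P<wellID>\w+)_T(?P<timeID>\d+)C(?P<chanID>\d+)'
pat_without   = r'(?P<wellID>\w+)_C(?P<chanID>\d+)'

for pat, name in [(pat_with_time, 'B02_T0001C01'), (pat_without, 'B02_C01')]:
    m = re.match(pat, name)
    # mirror of the new code: read the group only if the pattern defines it
    timeID = m.group('timeID') if 'timeID' in m.groupdict() else None
    print(m.group('wellID'), timeID)
# B02 0001
# B02 None
```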
@@ -1262,9 +1266,9 @@ def _pivot_counts_table(db_path):
 
 def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel):
 
-    cell_mask_path = os.path.join(src, 'norm_channel_stack', 'cell_mask_stack')
-    nucleus_mask_path = os.path.join(src, 'norm_channel_stack', 'nucleus_mask_stack')
-    pathogen_mask_path = os.path.join(src, 'norm_channel_stack', 'pathogen_mask_stack')
+    cell_mask_path = os.path.join(src, 'masks', 'cell_mask_stack')
+    nucleus_mask_path = os.path.join(src, 'masks', 'nucleus_mask_stack')
+    pathogen_mask_path = os.path.join(src, 'masks', 'pathogen_mask_stack')
 
 
     if os.path.exists(cell_mask_path) or os.path.exists(nucleus_mask_path) or os.path.exists(pathogen_mask_path):
@@ -3084,17 +3088,19 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
     return df
 
 def _get_regex(metadata_type, img_format, custom_regex=None):
+
+    print(f"Image_format: {img_format}")
 
     if img_format == None:
-        img_format == '.tif'
+        img_format == 'tif'
     if metadata_type == 'cellvoyager':
-        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+        regex = f"(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*).{img_format}"
     elif metadata_type == 'cq1':
-        regex = f'W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*){img_format}'
+        regex = f"W(?P<wellID>.*)F(?P<fieldID>.*)T(?P<timeID>.*)Z(?P<sliceID>.*)C(?P<chanID>.*).{img_format}"
     elif metadata_type == 'auto':
-        regex = f'(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>.*)C(?P<chanID>.*).tif'
+        regex = f"(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>.*)C(?P<chanID>.*).tif"
    elif metadata_type == 'custom':
-        regex = f'({custom_regex}){img_format}'
+        regex = f"({custom_regex}){img_format}"
 
     print(f'regex mode:{metadata_type} regex:{regex}')
     return regex
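
Two quirks carried over from the release as-is: the `.` inserted before `{img_format}` is unescaped, so it matches any character rather than a literal period, and `img_format == 'tif'` is a comparison rather than an assignment, so the `None` fallback never takes effect. A quick check of how the cellvoyager pattern resolves against an invented filename:

```python
import re

img_format = 'tif'
regex = f"(?P<plateID>.*)_(?P<wellID>.*)_T(?P<timeID>.*)F(?P<fieldID>.*)L(?P<laserID>..)A(?P<AID>..)Z(?P<sliceID>.*)C(?P<chanID>.*).{img_format}"
m = re.match(regex, 'exp1_B02_T0001F001L01A01Z01C02.tif')
print({k: m.group(k) for k in ('wellID', 'timeID', 'fieldID', 'sliceID', 'chanID')})
# {'wellID': 'B02', 'timeID': '0001', 'fieldID': '001', 'sliceID': '01', 'chanID': '02'}
```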
@@ -3279,15 +3285,6 @@ class SaliencyMapGenerator:
         return fig
 
     def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
-        """
-        Normalize each channel of the image to the given percentiles.
-        Args:
-            img: Input image as numpy array with shape (H, W, C)
-            lower_percentile: Lower percentile for normalization (default 2)
-            upper_percentile: Upper percentile for normalization (default 98)
-        Returns:
-            img: Normalized image
-        """
         img_normalized = np.zeros_like(img)
 
         for c in range(img.shape[2]):  # Iterate over each channel
@@ -3297,7 +3294,6 @@
         return img_normalized
 
 
-
 class GradCAMGenerator:
     def __init__(self, model, target_layer, cam_type='gradcam'):
         self.model = model
@@ -3402,15 +3398,6 @@ class GradCAMGenerator:
         return fig
 
     def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
-        """
-        Normalize each channel of the image to the given percentiles.
-        Args:
-            img: Input image as numpy array with shape (H, W, C)
-            lower_percentile: Lower percentile for normalization (default 2)
-            upper_percentile: Upper percentile for normalization (default 98)
-        Returns:
-            img: Normalized image
-        """
         img_normalized = np.zeros_like(img)
 
         for c in range(img.shape[2]):  # Iterate over each channel
@@ -4522,6 +4509,76 @@ def _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask
     relabeled_cell_mask, _ = label(cell_mask, return_num=True)
     return relabeled_cell_mask.astype(np.uint16)
 
+def _merge_cells_without_nucleus(adj_cell_mask: np.ndarray, nuclei_mask: np.ndarray):
+    """
+    Relabel any cell that lacks a nucleus to the ID of an adjacent
+    cell that *does* contain a nucleus.
+
+    Parameters
+    ----------
+    adj_cell_mask : np.ndarray
+        Labelled (0 = background) cell mask after all other merging steps.
+    nuclei_mask : np.ndarray
+        Labelled (0 = background) nuclei mask.
+
+    Returns
+    -------
+    np.ndarray
+        Updated cell mask with nucleus-free cells merged into
+        neighbouring nucleus-bearing cells.
+    """
+    out = adj_cell_mask.copy()
+
+    # ----------------------------------------------------------------- #
+    # 1 — Identify which cell IDs contain a nucleus
+    nuc_labels = np.unique(nuclei_mask[nuclei_mask > 0])
+
+    cells_with_nuc = set()
+    for nuc_id in nuc_labels:
+        labels, counts = np.unique(adj_cell_mask[nuclei_mask == nuc_id],
+                                   return_counts=True)
+
+        # drop background (label 0) from *both* arrays
+        keep = labels > 0
+        labels = labels[keep]
+        counts = counts[keep]
+
+        if labels.size:  # at least one non-zero overlap
+            cells_with_nuc.add(labels[np.argmax(counts)])
+
+    # ----------------------------------------------------------------- #
+    # 2 — Build an adjacency map between neighbouring cell IDs
+    # ----------------------------------------------------------------- #
+    boundaries = find_boundaries(adj_cell_mask, mode="thick")
+    adj_map = defaultdict(set)
+
+    ys, xs = np.where(boundaries)
+    h, w = adj_cell_mask.shape
+    for y, x in zip(ys, xs):
+        src = adj_cell_mask[y, x]
+        if src == 0:
+            continue
+        for dy in (-1, 0, 1):
+            for dx in (-1, 0, 1):
+                ny, nx = y + dy, x + dx
+                if 0 <= ny < h and 0 <= nx < w:
+                    dst = adj_cell_mask[ny, nx]
+                    if dst != 0 and dst != src:
+                        adj_map[src].add(dst)
+
+    # ----------------------------------------------------------------- #
+    # 3 — Relabel nucleus-free cells that touch nucleus-bearing neighbours
+    # ----------------------------------------------------------------- #
+    cells_no_nuc = set(np.unique(adj_cell_mask)) - {0} - cells_with_nuc
+    for cell_id in cells_no_nuc:
+        neighbours = adj_map.get(cell_id, set()) & cells_with_nuc
+        if neighbours:
+            # Choose the first nucleus-bearing neighbour deterministically
+            target = sorted(neighbours)[0]
+            out[out == cell_id] = target
+
+    return out.astype(np.uint16)
+
 def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30):
 
     """
@@ -4556,12 +4613,12 @@ def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_thres
         parasite_mask = np.load(parasite_path, allow_pickle=True)
         cell_mask = np.load(cell_path, allow_pickle=True)
         nuclei_mask = np.load(nuclei_path, allow_pickle=True)
+
         # Merge and relabel cells
         merged_cell_mask = _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold, perimeter_threshold)
 
-        # Force 16 bit
-        #merged_cell_mask = merged_cell_mask.astype(np.uint16)
-
+        #merged_cell_mask = _merge_cells_without_nucleus(merged_cell_mask, nuclei_mask)
+
         # Overwrite the original cell mask file with the merged result
        np.save(cell_path, merged_cell_mask)
 
@@ -4698,10 +4755,10 @@ def get_ml_results_paths(src, model_type='xgboost', channel_of_interest=1):
     elif isinstance(channel_of_interest, int):
         feature_string = f"channel_{channel_of_interest}"
 
-    elif channel_of_interest is 'morphology':
+    elif channel_of_interest == 'morphology':
         feature_string = 'morphology'
 
-    elif channel_of_interest is None:
+    elif channel_of_interest == None:
         feature_string = 'all_features'
     else:
         raise ValueError(f"Unsupported channel_of_interest: {channel_of_interest}. Supported values are 'int', 'list', 'None', or 'morphology'.")
@@ -4851,7 +4908,7 @@ def correct_masks(src):
 
     from .io import _load_and_concatenate_arrays
 
-    cell_path = os.path.join(src,'norm_channel_stack', 'cell_mask_stack')
+    cell_path = os.path.join(src,'masks', 'cell_mask_stack')
     convert_and_relabel_masks(cell_path)
     _load_and_concatenate_arrays(src, [0,1,2,3], 1, 0, 2)
 
@@ -5115,13 +5172,46 @@ def correct_metadata_column_names(df):
 
 def control_filelist(folder, mode='columnID', values=['01','02']):
     files = os.listdir(folder)
-    if mode is 'columnID':
+    if mode == 'columnID':
         filtered_files = [file for file in files if file.split('_')[1][1:] in values]
-    if mode is 'rowID':
+    if mode == 'rowID':
         filtered_files = [file for file in files if file.split('_')[1][:1] in values]
     return filtered_files
 
 def rename_columns_in_db(db_path):
+    # map old column names → new names
+    rename_map = {
+        'row': 'rowID',
+        'column': 'columnID',
+        'col': 'columnID',
+        'plate': 'plateID',
+        'field': 'fieldID',
+        'channel': 'chanID',
+    }
+
+    con = sqlite3.connect(db_path)
+    cur = con.cursor()
+
+    # 1) get all user tables
+    cur.execute("SELECT name FROM sqlite_master WHERE type='table';")
+    tables = [row[0] for row in cur.fetchall()]
+
+    for table in tables:
+        # 2) get column names only
+        cur.execute(f"PRAGMA table_info(`{table}`);")
+        cols = [row[1] for row in cur.fetchall()]
+
+        # 3) for each old→new, if the old exists and new does not, rename it
+        for old, new in rename_map.items():
+            if old in cols and new not in cols:
+                sql = f"ALTER TABLE `{table}` RENAME COLUMN `{old}` TO `{new}`;"
+                cur.execute(sql)
+                print(f"Renamed `{table}`.`{old}` → `{new}`")
+
+    con.commit()
+    con.close()
+
+def rename_columns_in_db_v1(db_path):
     with sqlite3.connect(db_path) as conn:
         cursor = conn.cursor()
 
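The new `rename_columns_in_db` migrates legacy column names in every user table, skipping any table that already has the new name. A disposable in-memory sketch of the underlying statement (requires SQLite ≥ 3.25, which introduced `ALTER TABLE … RENAME COLUMN`):

```python
import sqlite3

con = sqlite3.connect(':memory:')
con.execute("CREATE TABLE measurements (plate TEXT, row TEXT, col TEXT, value REAL)")
for old, new in {'row': 'rowID', 'col': 'columnID', 'plate': 'plateID'}.items():
    con.execute(f"ALTER TABLE measurements RENAME COLUMN `{old}` TO `{new}`")
print([r[1] for r in con.execute("PRAGMA table_info(measurements)")])
# ['plateID', 'rowID', 'columnID', 'value']
```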
@@ -5204,7 +5294,7 @@ def delete_intermedeate_files(settings):
     path_orig = os.path.join(settings['src'], 'orig')
     path_stack = os.path.join(settings['src'], 'stack')
     merged_stack = os.path.join(settings['src'], 'merged')
-    path_norm_chan_stack = os.path.join(settings['src'], 'norm_channel_stack')
+    path_norm_chan_stack = os.path.join(settings['src'], 'masks')
     path_1 = os.path.join(settings['src'], '1')
     path_2 = os.path.join(settings['src'], '2')
     path_3 = os.path.join(settings['src'], '3')
@@ -5491,3 +5581,37 @@ def correct_metadata(df):
         df = df.rename(columns={'field_name': 'fieldID'})
 
     return df
+
+def remove_outliers_by_group(df, group_col, value_col, method='iqr', threshold=1.5):
+    """
+    Removes outliers from `value_col` within each group defined by `group_col`.
+
+    Parameters:
+        df (pd.DataFrame): The input DataFrame.
+        group_col (str): Column name to group by.
+        value_col (str): Column containing values to check for outliers.
+        method (str): 'iqr' or 'zscore'.
+        threshold (float): Threshold multiplier for IQR (default 1.5) or z-score.
+
+    Returns:
+        pd.DataFrame: A DataFrame with outliers removed.
+    """
+    def iqr_filter(subdf):
+        q1 = subdf[value_col].quantile(0.25)
+        q3 = subdf[value_col].quantile(0.75)
+        iqr = q3 - q1
+        lower = q1 - threshold * iqr
+        upper = q3 + threshold * iqr
+        return subdf[(subdf[value_col] >= lower) & (subdf[value_col] <= upper)]
+
+    def zscore_filter(subdf):
+        mean = subdf[value_col].mean()
+        std = subdf[value_col].std()
+        return subdf[(subdf[value_col] - mean).abs() <= threshold * std]
+
+    if method == 'iqr':
+        return df.groupby(group_col, group_keys=False).apply(iqr_filter)
+    elif method == 'zscore':
+        return df.groupby(group_col, group_keys=False).apply(zscore_filter)
+    else:
+        raise ValueError("method must be 'iqr' or 'zscore'")
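
A usage sketch for the new helper, assuming the 0.9.0 `spacr.utils` module is importable in your environment (data invented):

```python
import pandas as pd
from spacr.utils import remove_outliers_by_group  # added in this release

df = pd.DataFrame({
    'well': ['A1'] * 5 + ['A2'] * 5,
    'area': [10, 11, 9, 10, 500, 20, 21, 19, 20, 22],  # 500 is a planted outlier
})
clean = remove_outliers_by_group(df, 'well', 'area', method='iqr', threshold=1.5)
print(clean.groupby('well')['area'].max())
# well
# A1     11
# A2     22
```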