spacr 0.0.82__py3-none-any.whl → 0.1.1__py3-none-any.whl

This diff represents the changes between two package versions that have been publicly released to a supported registry. It is provided for informational purposes only and reflects the package contents as they appear in that registry.
spacr/core.py CHANGED
@@ -13,7 +13,7 @@ from IPython.display import display
 from multiprocessing import Pool, cpu_count, Value, Lock
 
 import seaborn as sns
-
+import cellpose
 from skimage.measure import regionprops, label
 from skimage.transform import resize as resizescikit
 from torch.utils.data import DataLoader
@@ -25,6 +25,7 @@ from sklearn.linear_model import LogisticRegression
 from sklearn.inspection import permutation_importance
 from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
 from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import precision_recall_curve, f1_score
 
 from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
 
@@ -41,43 +42,51 @@ from .logger import log_function_call
 def analyze_plaques(folder):
     summary_data = []
     details_data = []
+    stats_data = []
 
     for filename in os.listdir(folder):
         filepath = os.path.join(folder, filename)
         if os.path.isfile(filepath):
             # Assuming each file is a NumPy array file (.npy) containing a 16-bit labeled image
-            image = np.load(filepath)
-
+            #image = np.load(filepath)
+            image = cellpose.io.imread(filepath)
             labeled_image = label(image)
             regions = regionprops(labeled_image)
 
             object_count = len(regions)
             sizes = [region.area for region in regions]
             average_size = np.mean(sizes) if sizes else 0
+            std_dev_size = np.std(sizes) if sizes else 0
 
             summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
+            stats_data.append({'file': filename, 'plaque_count': object_count, 'average_size': average_size, 'std_dev_size': std_dev_size})
             for size in sizes:
                 details_data.append({'file': filename, 'plaque_size': size})
 
     # Convert lists to pandas DataFrames
     summary_df = pd.DataFrame(summary_data)
     details_df = pd.DataFrame(details_data)
+    stats_df = pd.DataFrame(stats_data)
 
     # Save DataFrames to a SQLite database
-    db_name = 'plaques_analysis.db'
+    db_name = os.path.join(folder, 'plaques_analysis.db')
    conn = sqlite3.connect(db_name)
 
     summary_df.to_sql('summary', conn, if_exists='replace', index=False)
     details_df.to_sql('details', conn, if_exists='replace', index=False)
+    stats_df.to_sql('stats', conn, if_exists='replace', index=False)
 
     conn.close()
 
     print(f"Analysis completed and saved to database '{db_name}'.")
 
+
 def train_cellpose(settings):
 
     from .io import _load_normalized_images_and_labels, _load_images_and_labels
-    #from .utils import resize_images_and_labels
+    from .settings import get_train_cellpose_default_settings#, resize_images_and_labels
+
+    settings = get_train_cellpose_default_settings()
 
     img_src = settings['img_src']
     mask_src = os.path.join(img_src, 'masks')
@@ -146,7 +155,7 @@ def train_cellpose(settings):
 
     image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
     label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
-    images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
+    images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
     images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
 
     if test:
@@ -962,9 +971,10 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
     shutil.rmtree(temp_dir)
     print(f"\nSaved {total_images} images to {tar_name}")
 
-def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
+def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, threshold=0.5, verbose=False):
 
-    from .io import TarImageDataset, DataLoader
+    from .io import TarImageDataset
+    from .utils import process_vision_results
 
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
     if normalize:
@@ -1017,6 +1027,8 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
 
     data = {'path':filenames_list, 'pred':prediction_pos_probs}
     df = pd.DataFrame(data, index=None)
+    df = process_vision_results(df, threshold)
+
     df.to_csv(result_loc, index=True, header=True, mode='w')
     torch.cuda.empty_cache()
     torch.cuda.memory.empty_cache()
@@ -1290,7 +1302,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
 
     return
 
-def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
+def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
 
     """
     Generate data loaders for training and validation/test datasets.
@@ -1325,7 +1337,7 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     import random
     from PIL import Image
     from torchvision.transforms import ToTensor
-    from .utils import SelectChannels
+    from .utils import SelectChannels, augment_dataset
 
     chans = []
 
@@ -1375,14 +1387,22 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
         return
 
     if train_mode == 'erm':
+
         data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+
         if validation_split > 0:
             train_size = int((1 - validation_split) * len(data))
             val_size = len(data) - train_size
+            if not augment:
+                print(f'Train data:{train_size}, Validation data:{val_size}')
+                train_dataset, val_dataset = random_split(data, [train_size, val_size])
 
-            print(f'Train data:{train_size}, Validation data:{val_size}')
+            if augment:
 
-            train_dataset, val_dataset = random_split(data, [train_size, val_size])
+                print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+                train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+                #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
+                print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
 
             train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
             val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
@@ -1405,10 +1425,16 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
             if validation_split > 0:
                 train_size = int((1 - validation_split) * len(plate_data))
                 val_size = len(plate_data) - train_size
+                if not augment:
+                    print(f'Train data:{train_size}, Validation data:{val_size}')
+                    train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
 
-                print(f'Train data:{train_size}, Validation data:{val_size}')
+                if augment:
 
-                train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
+                    print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{val_dataset}')
+                    train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+                    #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
+                    print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')
 
                 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
                 val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
@@ -1423,28 +1449,33 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
     else:
         print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
         return
+
+
+    if train_mode == 'erm':
+        for idx, (images, labels, filenames) in enumerate(train_loaders):
+            if idx >= max_show:
+                break
+            images = images.cpu()
+            label_strings = [str(label.item()) for label in labels]
+            train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
+            if verbose:
+                plt.show()
 
-    if verbose:
-        if train_mode == 'erm':
-            for idx, (images, labels, filenames) in enumerate(train_loaders):
+    elif train_mode == 'irm':
+        for plate_name, train_loader in zip(plate_names, train_loaders):
+            print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
+            for idx, (images, labels, filenames) in enumerate(train_loader):
                 if idx >= max_show:
                     break
                 images = images.cpu()
                 label_strings = [str(label.item()) for label in labels]
-                _imshow(images, label_strings, nrow=20, fontsize=12)
-        elif train_mode == 'irm':
-            for plate_name, train_loader in zip(plate_names, train_loaders):
-                print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
-                for idx, (images, labels, filenames) in enumerate(train_loader):
-                    if idx >= max_show:
-                        break
-                    images = images.cpu()
-                    label_strings = [str(label.item()) for label in labels]
-                    _imshow(images, label_strings, nrow=20, fontsize=12)
-
-    return train_loaders, val_loaders, plate_names
-
-def analyze_recruitment(src, metadata_settings, advanced_settings):
+                train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
+                if verbose:
+                    plt.show()
+
+    return train_loaders, val_loaders, plate_names, train_fig
+
+def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
     """
     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
 
@@ -1460,6 +1491,9 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
     from .io import _read_and_merge_data, _results_to_csv
     from .plot import plot_merged, _plot_controls, _plot_recruitment
     from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+    from .settings import get_analyze_recruitment_default_settings
+
+    settings = get_analyze_recruitment_default_settings(settings)
 
     settings_dict = {**metadata_settings, **advanced_settings}
     settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
@@ -1634,8 +1668,8 @@ def preprocess_generate_masks(src, settings={}):
 
     from .io import preprocess_img_data, _load_and_concatenate_arrays
     from .plot import plot_merged, plot_arrays
-    from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
-    from .utils import adjust_cell_masks, _merge_cells_based_on_parasite_overlap, process_masks
+    from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, _merge_cells_based_on_parasite_overlap, process_masks
+    from .settings import set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings
 
     settings = set_default_settings_preprocess_generate_masks(src, settings)
     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
@@ -1756,36 +1790,14 @@ def identify_masks_finetune(settings):
     from .plot import print_mask_and_flows
     from .utils import get_files_from_dir, resize_images_and_labels
     from .io import _load_normalized_images_and_labels, _load_images_and_labels
-
+    from .settings import get_identify_masks_finetune_default_settings
+
+    settings = get_identify_masks_finetune_default_settings(settings)
+
     #User defined settings
     src=settings['src']
     dst=settings['dst']
 
-
-    settings.setdefault('model_name', 'cyto')
-    settings.setdefault('custom_model', None)
-    settings.setdefault('channels', [0,0])
-    settings.setdefault('background', 100)
-    settings.setdefault('remove_background', False)
-    settings.setdefault('Signal_to_noise', 10)
-    settings.setdefault('CP_prob', 0)
-    settings.setdefault('diameter', 30)
-    settings.setdefault('batch_size', 50)
-    settings.setdefault('flow_threshold', 0.4)
-    settings.setdefault('save', False)
-    settings.setdefault('verbose', False)
-    settings.setdefault('normalize', True)
-    settings.setdefault('percentiles', None)
-    settings.setdefault('circular', False)
-    settings.setdefault('invert', False)
-    settings.setdefault('resize', False)
-    settings.setdefault('target_height', None)
-    settings.setdefault('target_width', None)
-    settings.setdefault('rescale', False)
-    settings.setdefault('resample', False)
-    settings.setdefault('grayscale', True)
-
-
     model_name=settings['model_name']
     custom_model=settings['custom_model']
     channels = settings['channels']
@@ -1844,23 +1856,25 @@ def identify_masks_finetune(settings):
     print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
 
     all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
-
+    mask_files = set(os.listdir(os.path.join(src, 'masks')))
+    all_image_files = [f for f in all_image_files if os.path.basename(f) not in mask_files]
     random.shuffle(all_image_files)
 
     time_ls = []
     for i in range(0, len(all_image_files), batch_size):
+        gc.collect()
         image_files = all_image_files[i:i+batch_size]
 
         if normalize:
-            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
+            images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise, target_height=target_height, target_width=target_width)
             images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
-            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+            #orig_dims = [(image.shape[0], image.shape[1]) for image in images]
         else:
             images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
             images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
             orig_dims = [(image.shape[0], image.shape[1]) for image in images]
-        if resize:
-            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
+            if resize:
+                images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
 
         for file_index, stack in enumerate(images):
             start = time.time()
@@ -1899,6 +1913,8 @@ def identify_masks_finetune(settings):
                 os.makedirs(dst, exist_ok=True)
                 output_filename = os.path.join(dst, image_names[file_index])
                 cv2.imwrite(output_filename, mask)
+        del images, output, mask, flows
+        gc.collect()
     return
 
 def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
@@ -2125,10 +2141,11 @@ def prepare_batch_for_cellpose(batch):
 
 def generate_cellpose_masks(src, settings, object_type):
 
-    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
+    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, mask_object_count
     from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
     from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
     from .plot import plot_masks
+    from .settings import set_default_settings_preprocess_generate_masks, _get_object_settings
 
     gc.collect()
     if not torch.cuda.is_available():
@@ -2457,32 +2474,15 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
 
 
 def check_cellpose_models(settings):
+
+    from .settings import get_check_cellpose_models_default_settings
 
+    settings = get_check_cellpose_models_default_settings(settings)
     src = settings['src']
-    settings.setdefault('batch_size', 10)
-    settings.setdefault('CP_prob', 0)
-    settings.setdefault('flow_threshold', 0.4)
-    settings.setdefault('save', True)
-    settings.setdefault('normalize', True)
-    settings.setdefault('channels', [0,0])
-    settings.setdefault('percentiles', None)
-    settings.setdefault('circular', False)
-    settings.setdefault('invert', False)
-    settings.setdefault('plot', True)
-    settings.setdefault('diameter', 40)
-    settings.setdefault('grayscale', True)
-    settings.setdefault('remove_background', False)
-    settings.setdefault('background', 100)
-    settings.setdefault('Signal_to_noise', 5)
-    settings.setdefault('verbose', False)
-    settings.setdefault('resize', False)
-    settings.setdefault('target_height', None)
-    settings.setdefault('target_width', None)
 
-    if settings['verbose']:
-        settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
-        settings_df['setting_value'] = settings_df['setting_value'].apply(str)
-        display(settings_df)
+    settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
+    settings_df['setting_value'] = settings_df['setting_value'].apply(str)
+    display(settings_df)
 
     cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
     device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
@@ -2622,8 +2622,24 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):
 
     return df
 
-def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
-
+def find_optimal_threshold(y_true, y_pred_proba):
+    """
+    Find the optimal threshold for binary classification based on the F1-score.
+
+    Args:
+        y_true (array-like): True binary labels.
+        y_pred_proba (array-like): Predicted probabilities for the positive class.
+
+    Returns:
+        float: The optimal threshold.
+    """
+    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
+    f1_scores = 2 * (precision * recall) / (precision + recall)
+    optimal_idx = np.argmax(f1_scores)
+    optimal_threshold = thresholds[optimal_idx]
+    return optimal_threshold
+
+def ml_analysis(df, channel_of_interest=3, location_column='col', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
     """
     Calculates permutation importance for numerical features in the dataframe,
     comparing groups based on specified column values and uses the model to predict
@@ -2632,12 +2648,11 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
     Args:
         df (pandas.DataFrame): The DataFrame containing the data.
         feature_string (str): String to filter features that contain this substring.
-        col_to_compare (str): Column name to use for comparing groups.
-        pos, neg (str): Values in col_to_compare to create subsets for comparison.
+        location_column (str): Column name to use for comparing groups.
+        positive_control, negative_control (str): Values in location_column to create subsets for comparison.
         exclude (list or str, optional): Columns to exclude from features.
         n_repeats (int): Number of repeats for permutation importance.
-        clean (bool): Whether to remove columns with a single value.
-        nr_to_plot (int): Number of top features to plot based on permutation importance.
+        top_features (int): Number of top features to plot based on permutation importance.
         n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
        test_size (float): Proportion of the dataset to include in the test split.
        random_state (int): Random seed for reproducibility.
@@ -2650,38 +2665,48 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
     """
 
     from .utils import filter_dataframe_features
+    from .plot import plot_permutation, plot_feature_importance
 
+    random_state = 42
+
     if 'cells_per_well' in df.columns:
         df = df.drop(columns=['cells_per_well'])
 
+    df_metadata = df[[location_column]].copy()
+    df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
+
+
+    if verbose:
+        print(f'Found {len(features)} numerical features in the dataframe')
+        print(f'Features used in training: {features}')
+    df = pd.concat([df, df_metadata[location_column]], axis=1)
+
     # Subset the dataframe based on specified column values
-    df1 = df[df[col_to_compare] == pos].copy()
-    df2 = df[df[col_to_compare] == neg].copy()
+    df1 = df[df[location_column] == negative_control].copy()
+    df2 = df[df[location_column] == positive_control].copy()
 
     # Create target variable
-    df1['target'] = 0
-    df2['target'] = 1
+    df1['target'] = 0 # Negative control
+    df2['target'] = 1 # Positive control
 
     # Combine the subsets for analysis
     combined_df = pd.concat([df1, df2])
-
-    if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
-        channel_of_interest = int(feature_string.split('_')[-1])
-    elif not feature_string is 'morphology':
-        channel_of_interest = 'morphology'
-
-    _, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)
+    combined_df = combined_df.drop(columns=[location_column])
+    if verbose:
+        print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')
 
     X = combined_df[features]
     y = combined_df['target']
 
     # Split the data into training and testing sets
     X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
-
-    # Label the data in the original dataframe
+
+    # Add data usage labels
     combined_df['data_usage'] = 'train'
     combined_df.loc[X_test.index, 'data_usage'] = 'test'
-
+    df['data_usage'] = 'not_used'
+    df.loc[combined_df.index, 'data_usage'] = combined_df['data_usage']
+
 
     # Initialize the model based on model_type
     if model_type == 'random_forest':
@@ -2703,29 +2728,24 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
         'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
         'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
         'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
-    }).tail(nr_to_plot)
+    }).tail(top_features)
+
+    permutation_fig = plot_permutation(permutation_df)
+    if verbose:
+        permutation_fig.show()
 
-    # Plotting
-    fig, ax = plt.subplots()
-    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
-    ax.set_xlabel('Permutation Importance')
-    plt.tight_layout()
-    plt.show()
-
     # Feature importance for models that support it
     if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
         feature_importances = model.feature_importances_
         feature_importance_df = pd.DataFrame({
             'feature': features,
             'importance': feature_importances
-        }).sort_values(by='importance', ascending=False).head(nr_to_plot)
+        }).sort_values(by='importance', ascending=False).head(top_features)
 
-        # Plotting feature importance
-        fig, ax = plt.subplots()
-        ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
-        ax.set_xlabel('Feature Importance')
-        plt.tight_layout()
-        plt.show()
+        feature_importance_fig = plot_feature_importance(feature_importance_df)
+        if verbose:
+            feature_importance_fig.show()
+
     else:
         feature_importance_df = pd.DataFrame()
 
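Note on the threshold selection introduced in this version: the new find_optimal_threshold helper (added a few hunks above) picks the probability cutoff that maximizes the F1-score along a precision-recall curve. The short, self-contained sketch below is not part of the released code; it illustrates the same idea with scikit-learn and differs from the helper only in dropping the final precision/recall point, since sklearn.metrics.precision_recall_curve returns one more precision/recall value than thresholds. The toy labels and probabilities are invented purely for illustration.

import numpy as np
from sklearn.metrics import precision_recall_curve

def f1_optimal_threshold(y_true, y_pred_proba):
    # precision and recall have len(thresholds) + 1 entries; the last point has no
    # associated threshold, so it is dropped before computing F1 per threshold.
    precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
    f1_scores = 2 * precision[:-1] * recall[:-1] / (precision[:-1] + recall[:-1] + 1e-12)
    return thresholds[np.argmax(f1_scores)]

# Illustrative toy data only
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_prob = np.array([0.10, 0.40, 0.35, 0.80, 0.65, 0.20, 0.90, 0.55])
print(f'F1-optimal threshold: {f1_optimal_threshold(y_true, y_prob):.2f}')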
@@ -2733,38 +2753,38 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
     predictions_test = model.predict(X_test)
     combined_df.loc[X_test.index, 'predictions'] = predictions_test
 
-    # Predicting the target variable for the training set
-    predictions_train = model.predict(X_train)
-    combined_df.loc[X_train.index, 'predictions'] = predictions_train
+    # Get prediction probabilities for the test set
+    prediction_probabilities_test = model.predict_proba(X_test)
+
+    # Find the optimal threshold
+    optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
+    if verbose:
+        print(f'Optimal threshold: {optimal_threshold}')
 
     # Predicting the target variable for all other rows in the dataframe
     X_all = df[features]
     all_predictions = model.predict(X_all)
     df['predictions'] = all_predictions
 
-    # Combine data usage labels back to the original dataframe
-    combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
-    df = df.join(combined_data_usage, how='left', rsuffix='_model')
-
-    # Calculating and printing the accuracy metrics
-    accuracy = accuracy_score(y_test, predictions_test)
-    precision = precision_score(y_test, predictions_test)
-    recall = recall_score(y_test, predictions_test)
-    f1 = f1_score(y_test, predictions_test)
-    print(f"Accuracy: {accuracy}")
-    print(f"Precision: {precision}")
-    print(f"Recall: {recall}")
-    print(f"F1 Score: {f1}")
-
-    # Printing class-specific accuracy metrics
-    print("\nClassification Report:")
-    print(classification_report(y_test, predictions_test))
+    # Get prediction probabilities for all rows in the dataframe
+    prediction_probabilities = model.predict_proba(X_all)
+    for i in range(prediction_probabilities.shape[1]):
+        df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
+    if verbose:
+        print("\nClassification Report:")
+        print(classification_report(y_test, predictions_test))
+    report_dict = classification_report(y_test, predictions_test, output_dict=True)
+    metrics_df = pd.DataFrame(report_dict).transpose()
 
-    df = _calculate_similarity(df, features, col_to_compare, pos, neg)
+    df = _calculate_similarity(df, features, location_column, positive_control, negative_control)
 
-    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
+    df['prcfo'] = df.index.astype(str)
+    df[['plate', 'row', 'col', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
+    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
+
+    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]
 
-def _shap_analysis(model, X_train, X_test):
+def shap_analysis(model, X_train, X_test):
 
     """
     Performs SHAP analysis on the given model and data.
@@ -2773,17 +2793,45 @@ def _shap_analysis(model, X_train, X_test):
         model: The trained model.
         X_train (pandas.DataFrame): Training feature set.
         X_test (pandas.DataFrame): Testing feature set.
+    Returns:
+        fig: Matplotlib figure object containing the SHAP summary plot.
     """
-
+
     explainer = shap.Explainer(model, X_train)
     shap_values = explainer(X_test)
-
+    # Create a new figure
+    fig, ax = plt.subplots()
     # Summary plot
-    shap.summary_plot(shap_values, X_test)
-
-def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+    shap.summary_plot(shap_values, X_test, show=False)
+    # Save the current figure (the one that SHAP just created)
+    fig = plt.gcf()
+    plt.close(fig) # Close the figure to prevent it from displaying immediately
+    return fig
+
+def check_index(df, elements=5, split_char='_'):
+    problematic_indices = []
+    for idx in df.index:
+        parts = str(idx).split(split_char)
+        if len(parts) != elements:
+            problematic_indices.append(idx)
+    if problematic_indices:
+        print("Indices that cannot be separated into 5 parts:")
+        for idx in problematic_indices:
+            print(idx)
+        raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
+
+#def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c2', neg='c1', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+def generate_ml_scores(src, settings):
+
     from .io import _read_and_merge_data
-    from .plot import _plot_plates
+    from .plot import plot_plates
+    from .utils import get_ml_results_paths
+    from .settings import set_default_analyze_screen
+
+    settings = set_default_analyze_screen(settings)
+
+    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+    display(settings_df)
 
     db_loc = [src+'/measurements/measurements.db']
     tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
@@ -2791,27 +2839,60 @@ def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='m
 
     df, _ = _read_and_merge_data(db_loc,
                                  tables,
-                                 verbose=verbose,
-                                 include_multinucleated=include_multinucleated,
-                                 include_multiinfected=include_multiinfected,
-                                 include_noninfected=include_noninfected)
-
-    if not channel_of_interest is None:
-        df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
-        feature_string = f'channel_{channel_of_interest}'
-    else:
-        feature_string = None
-
-    output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
-
-    _shap_analysis(output[3], output[4], output[5])
+                                 settings['verbose'],
+                                 include_multinucleated,
+                                 include_multiinfected,
+                                 include_noninfected)
+
+    if settings['channel_of_interest'] in [0,1,2,3]:
+
+        df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
+
+    output, figs = ml_analysis(df,
+                               settings['channel_of_interest'],
+                               settings['location_column'],
+                               settings['positive_control'],
+                               settings['negative_control'],
+                               settings['exclude'],
+                               settings['n_repeats'],
+                               settings['top_features'],
+                               settings['n_estimators'],
+                               settings['test_size'],
+                               settings['model_type'],
+                               settings['n_jobs'],
+                               settings['remove_low_variance_features'],
+                               settings['remove_highly_correlated_features'],
+                               settings['verbose'])
+
+    shap_fig = shap_analysis(output[3], output[4], output[5])
 
     features = output[0].select_dtypes(include=[np.number]).columns.tolist()
 
-    if not variable in features:
-        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
+    if not settings['heatmap_feature'] in features:
+        raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")
 
-    plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
+    plate_heatmap = plot_plates(df=output[0],
+                                variable=settings['heatmap_feature'],
+                                grouping=settings['grouping'],
+                                min_max=settings['min_max'],
+                                cmap=settings['cmap'],
+                                min_count=settings['minimum_cell_count'],
+                                verbose=settings['verbose'])
+
+    data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type'], settings['channel_of_interest'])
+    df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
+
+    settings_df.to_csv(settings_csv, index=False)
+    df.to_csv(data_path, mode='w', encoding='utf-8')
+    permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
+    feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
+    metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')
+
+    plate_heatmap.savefig(plate_heatmap_path, format='pdf')
+    figs[0].savefig(permutation_fig_path, format='pdf')
+    figs[1].savefig(feature_importance_fig_path, format='pdf')
+    shap_fig.savefig(shap_fig_path, format='pdf')
+
     return [output, plate_heatmap]
 
 def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
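Note on the new generate_ml_scores entry point: it wraps ml_analysis, shap_analysis and plot_plates behind a single settings dictionary and writes the resulting tables and figures to paths returned by get_ml_results_paths. A hypothetical call might look like the sketch below; only the keys referenced in the hunk above are known to be read, the values shown are assumptions (the real defaults come from spacr.settings.set_default_analyze_screen), and the source path is a placeholder.

from spacr.core import generate_ml_scores

# Assumed settings; keys mirror those read in the hunk above, values are illustrative.
settings = {
    'channel_of_interest': 3,
    'location_column': 'col',
    'positive_control': 'c2',
    'negative_control': 'c1',
    'exclude': None,
    'n_repeats': 10,
    'top_features': 30,
    'n_estimators': 100,
    'test_size': 0.2,
    'model_type': 'xgboost',
    'n_jobs': -1,
    'remove_low_variance_features': True,
    'remove_highly_correlated_features': True,
    'heatmap_feature': 'predictions',
    'grouping': 'mean',
    'min_max': 'allq',
    'cmap': 'viridis',
    'minimum_cell_count': 25,
    'verbose': True,
}

# src points at an experiment folder containing measurements/measurements.db
output, plate_heatmap = generate_ml_scores('/path/to/experiment', settings)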
@@ -2940,8 +3021,8 @@ def generate_image_umap(settings={}):
     """
 
     from .io import _read_and_join_tables
-    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings, cluster_feature_analysis, generate_umap_from_images
-
+    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis, generate_umap_from_images
+    from .settings import get_umap_image_settings
     settings = get_umap_image_settings(settings)
 
     if isinstance(settings['src'], str):
@@ -3123,7 +3204,8 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
     """
 
     from .io import _read_and_join_tables
-    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings
+    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
+    from .settings import get_umap_image_settings
 
     settings = get_umap_image_settings(settings)
     pointsize = settings['dot_size']