spacr 0.0.81__py3-none-any.whl → 0.1.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the supported public registry. It is provided for informational purposes only.
spacr/core.py CHANGED
@@ -13,7 +13,7 @@ from IPython.display import display
  from multiprocessing import Pool, cpu_count, Value, Lock

  import seaborn as sns
-
+ import cellpose
  from skimage.measure import regionprops, label
  from skimage.transform import resize as resizescikit
  from torch.utils.data import DataLoader
@@ -25,6 +25,7 @@ from sklearn.linear_model import LogisticRegression
  from sklearn.inspection import permutation_importance
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
  from sklearn.preprocessing import StandardScaler
+ from sklearn.metrics import precision_recall_curve, f1_score

  from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis

@@ -41,43 +42,51 @@ from .logger import log_function_call
  def analyze_plaques(folder):
  summary_data = []
  details_data = []
+ stats_data = []

  for filename in os.listdir(folder):
  filepath = os.path.join(folder, filename)
  if os.path.isfile(filepath):
  # Assuming each file is a NumPy array file (.npy) containing a 16-bit labeled image
- image = np.load(filepath)
-
+ #image = np.load(filepath)
+ image = cellpose.io.imread(filepath)
  labeled_image = label(image)
  regions = regionprops(labeled_image)

  object_count = len(regions)
  sizes = [region.area for region in regions]
  average_size = np.mean(sizes) if sizes else 0
+ std_dev_size = np.std(sizes) if sizes else 0

  summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
+ stats_data.append({'file': filename, 'plaque_count': object_count, 'average_size': average_size, 'std_dev_size': std_dev_size})
  for size in sizes:
  details_data.append({'file': filename, 'plaque_size': size})

  # Convert lists to pandas DataFrames
  summary_df = pd.DataFrame(summary_data)
  details_df = pd.DataFrame(details_data)
+ stats_df = pd.DataFrame(stats_data)

  # Save DataFrames to a SQLite database
- db_name = 'plaques_analysis.db'
+ db_name = os.path.join(folder, 'plaques_analysis.db')
  conn = sqlite3.connect(db_name)

  summary_df.to_sql('summary', conn, if_exists='replace', index=False)
  details_df.to_sql('details', conn, if_exists='replace', index=False)
+ stats_df.to_sql('stats', conn, if_exists='replace', index=False)

  conn.close()

  print(f"Analysis completed and saved to database '{db_name}'.")

+
  def train_cellpose(settings):

  from .io import _load_normalized_images_and_labels, _load_images_and_labels
- from .utils import resize_images_and_labels
+ from .settings import get_train_cellpose_default_settings#, resize_images_and_labels
+
+ settings = get_train_cellpose_default_settings()

  img_src = settings['img_src']
  mask_src = os.path.join(img_src, 'masks')
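A hedged sketch of reading back the new 'stats' table that analyze_plaques() now writes to plaques_analysis.db inside the analyzed folder; the folder path below is hypothetical and the column names follow the stats_data records added in this release.

```python
import os
import sqlite3
import pandas as pd

folder = '/data/plaque_images'  # hypothetical folder previously passed to analyze_plaques()
conn = sqlite3.connect(os.path.join(folder, 'plaques_analysis.db'))
# Read the per-file plaque statistics written by analyze_plaques() in 0.1.0.
stats_df = pd.read_sql('SELECT file, plaque_count, average_size, std_dev_size FROM stats', conn)
conn.close()
print(stats_df.head())
```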
@@ -99,7 +108,6 @@ def train_cellpose(settings):
  Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
  verbose = settings.setdefault( 'verbose', False)

-
  channels = settings.setdefault( 'channels', [0,0])
  normalize = settings.setdefault( 'normalize', True)
  percentiles = settings.setdefault( 'percentiles', None)
@@ -119,7 +127,7 @@ def train_cellpose(settings):
  test_img_src = os.path.join(os.path.dirname(img_src), 'test')
  test_mask_src = os.path.join(test_img_src, 'mask')

- test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
+ test_images, test_masks, test_image_names, test_mask_names = None,None,None,None
  print(settings)

  if from_scratch:
@@ -147,13 +155,13 @@ def train_cellpose(settings):

  image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
  label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
- images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
+ images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]

  if test:
  test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
  test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
- test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
+ test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
  test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]

  else:
@@ -164,8 +172,8 @@ def train_cellpose(settings):
  test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
  test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]

- if resize:
- images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
+ #if resize:
+ # images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)

  if model_type == 'cyto':
  cp_channels = [0,1]
@@ -963,9 +971,10 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
  shutil.rmtree(temp_dir)
  print(f"\nSaved {total_images} images to {tar_name}")

- def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
+ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, threshold=0.5, verbose=False):

- from .io import TarImageDataset, DataLoader
+ from .io import TarImageDataset
+ from .utils import process_vision_results

  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  if normalize:
@@ -1018,6 +1027,8 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22

  data = {'path':filenames_list, 'pred':prediction_pos_probs}
  df = pd.DataFrame(data, index=None)
+ df = process_vision_results(df, threshold)
+
  df.to_csv(result_loc, index=True, header=True, mode='w')
  torch.cuda.empty_cache()
  torch.cuda.memory.empty_cache()
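The updated apply_model_to_tar() accepts a threshold argument that is forwarded to process_vision_results() before the predictions are written to CSV (result_loc in the hunk above). A hedged usage sketch; the paths are hypothetical.

```python
apply_model_to_tar(
    tar_path='/data/screen1_cells.tar',        # hypothetical tar of single-cell PNGs
    model_path='/models/cell_classifier.pth',  # hypothetical trained vision model
    file_type='cell_png',
    image_size=224,
    batch_size=64,
    threshold=0.5,   # new in 0.1.0: passed to process_vision_results()
    verbose=True,
)
```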
@@ -1291,7 +1302,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',

  return

- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
+ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):

  """
  Generate data loaders for training and validation/test datasets.
@@ -1326,7 +1337,7 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  import random
  from PIL import Image
  from torchvision.transforms import ToTensor
- from .utils import SelectChannels
+ from .utils import SelectChannels, augment_dataset

  chans = []

@@ -1376,14 +1387,22 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  return

  if train_mode == 'erm':
+
  data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+
  if validation_split > 0:
  train_size = int((1 - validation_split) * len(data))
  val_size = len(data) - train_size
+ if not augment:
+ print(f'Train data:{train_size}, Validation data:{val_size}')
+ train_dataset, val_dataset = random_split(data, [train_size, val_size])

- print(f'Train data:{train_size}, Validation data:{val_size}')
+ if augment:

- train_dataset, val_dataset = random_split(data, [train_size, val_size])
+ print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+ train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+ #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
+ print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')

  train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
  val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
@@ -1406,10 +1425,16 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  if validation_split > 0:
  train_size = int((1 - validation_split) * len(plate_data))
  val_size = len(plate_data) - train_size
+ if not augment:
+ print(f'Train data:{train_size}, Validation data:{val_size}')
+ train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])

- print(f'Train data:{train_size}, Validation data:{val_size}')
+ if augment:

- train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
+ print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{val_dataset}')
+ train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+ #val_dataset = augment_dataset(val_dataset, is_grayscale=(len(channels) == 1))
+ print(f'Data after augmentation: Train: {len(train_dataset)}')#, Validataion:{len(val_dataset)}')

  train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
  val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
@@ -1424,28 +1449,33 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
  else:
  print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
  return
+
+
+ if train_mode == 'erm':
+ for idx, (images, labels, filenames) in enumerate(train_loaders):
+ if idx >= max_show:
+ break
+ images = images.cpu()
+ label_strings = [str(label.item()) for label in labels]
+ train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
+ if verbose:
+ plt.show()

- if verbose:
- if train_mode == 'erm':
- for idx, (images, labels, filenames) in enumerate(train_loaders):
+ elif train_mode == 'irm':
+ for plate_name, train_loader in zip(plate_names, train_loaders):
+ print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
+ for idx, (images, labels, filenames) in enumerate(train_loader):
  if idx >= max_show:
  break
  images = images.cpu()
  label_strings = [str(label.item()) for label in labels]
- _imshow(images, label_strings, nrow=20, fontsize=12)
- elif train_mode == 'irm':
- for plate_name, train_loader in zip(plate_names, train_loaders):
- print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
- for idx, (images, labels, filenames) in enumerate(train_loader):
- if idx >= max_show:
- break
- images = images.cpu()
- label_strings = [str(label.item()) for label in labels]
- _imshow(images, label_strings, nrow=20, fontsize=12)
-
- return train_loaders, val_loaders, plate_names
-
- def analyze_recruitment(src, metadata_settings, advanced_settings):
+ train_fig = _imshow(images, label_strings, nrow=20, fontsize=12)
+ if verbose:
+ plt.show()
+
+ return train_loaders, val_loaders, plate_names, train_fig
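generate_loaders() gains an augment flag and, as the hunk above shows, now also returns the preview figure. A hedged usage sketch; the source folder and split values are hypothetical.

```python
train_loaders, val_loaders, plate_names, train_fig = generate_loaders(
    src='/data/training_set',  # hypothetical dataset folder
    train_mode='erm',
    mode='train',
    image_size=224,
    batch_size=32,
    classes=['nc', 'pc'],
    validation_split=0.1,
    channels=[1, 2, 3],
    augment=True,   # new in 0.1.0: augment_dataset() is applied to the training split
    verbose=True,
)
```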
+
+ def analyze_recruitment(src, metadata_settings={}, advanced_settings={}):
  """
  Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.

@@ -1461,6 +1491,9 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
  from .io import _read_and_merge_data, _results_to_csv
  from .plot import plot_merged, _plot_controls, _plot_recruitment
  from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+ from .settings import get_analyze_recruitment_default_settings
+
+ settings = get_analyze_recruitment_default_settings(settings)

  settings_dict = {**metadata_settings, **advanced_settings}
  settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
@@ -1635,8 +1668,8 @@ def preprocess_generate_masks(src, settings={}):

  from .io import preprocess_img_data, _load_and_concatenate_arrays
  from .plot import plot_merged, plot_arrays
- from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
- from .utils import adjust_cell_masks, _merge_cells_based_on_parasite_overlap, process_masks
+ from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, _merge_cells_based_on_parasite_overlap, process_masks
+ from .settings import set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings

  settings = set_default_settings_preprocess_generate_masks(src, settings)
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
@@ -1757,36 +1790,14 @@ def identify_masks_finetune(settings):
  from .plot import print_mask_and_flows
  from .utils import get_files_from_dir, resize_images_and_labels
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
-
+ from .settings import get_identify_masks_finetune_default_settings
+
+ settings = get_identify_masks_finetune_default_settings(settings)
+
  #User defined settings
  src=settings['src']
  dst=settings['dst']

-
- settings.setdefault('model_name', 'cyto')
- settings.setdefault('custom_model', None)
- settings.setdefault('channels', [0,0])
- settings.setdefault('background', 100)
- settings.setdefault('remove_background', False)
- settings.setdefault('Signal_to_noise', 10)
- settings.setdefault('CP_prob', 0)
- settings.setdefault('diameter', 30)
- settings.setdefault('batch_size', 50)
- settings.setdefault('flow_threshold', 0.4)
- settings.setdefault('save', False)
- settings.setdefault('verbose', False)
- settings.setdefault('normalize', True)
- settings.setdefault('percentiles', None)
- settings.setdefault('circular', False)
- settings.setdefault('invert', False)
- settings.setdefault('resize', False)
- settings.setdefault('target_height', None)
- settings.setdefault('target_width', None)
- settings.setdefault('rescale', False)
- settings.setdefault('resample', False)
- settings.setdefault('grayscale', True)
-
-
  model_name=settings['model_name']
  custom_model=settings['custom_model']
  channels = settings['channels']
@@ -1845,23 +1856,25 @@
  print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')

  all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
-
+ mask_files = set(os.listdir(os.path.join(src, 'masks')))
+ all_image_files = [f for f in all_image_files if os.path.basename(f) not in mask_files]
  random.shuffle(all_image_files)

  time_ls = []
  for i in range(0, len(all_image_files), batch_size):
+ gc.collect()
  image_files = all_image_files[i:i+batch_size]

  if normalize:
- images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
+ images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise, target_height=target_height, target_width=target_width)
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
- orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+ #orig_dims = [(image.shape[0], image.shape[1]) for image in images]
  else:
  images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
- if resize:
- images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
+ if resize:
+ images, _ = resize_images_and_labels(images, None, target_height, target_width, True)

  for file_index, stack in enumerate(images):
  start = time.time()
@@ -1900,6 +1913,8 @@
  os.makedirs(dst, exist_ok=True)
  output_filename = os.path.join(dst, image_names[file_index])
  cv2.imwrite(output_filename, mask)
+ del images, output, mask, flows
+ gc.collect()
  return

  def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
@@ -2126,10 +2141,11 @@ def prepare_batch_for_cellpose(batch):

  def generate_cellpose_masks(src, settings, object_type):

- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, mask_object_count
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
  from .plot import plot_masks
+ from .settings import set_default_settings_preprocess_generate_masks, _get_object_settings

  gc.collect()
  if not torch.cuda.is_available():
@@ -2458,32 +2474,15 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp


  def check_cellpose_models(settings):
+
+ from .settings import get_check_cellpose_models_default_settings

+ settings = get_check_cellpose_models_default_settings(settings)
  src = settings['src']
- settings.setdefault('batch_size', 10)
- settings.setdefault('CP_prob', 0)
- settings.setdefault('flow_threshold', 0.4)
- settings.setdefault('save', True)
- settings.setdefault('normalize', True)
- settings.setdefault('channels', [0,0])
- settings.setdefault('percentiles', None)
- settings.setdefault('circular', False)
- settings.setdefault('invert', False)
- settings.setdefault('plot', True)
- settings.setdefault('diameter', 40)
- settings.setdefault('grayscale', True)
- settings.setdefault('remove_background', False)
- settings.setdefault('background', 100)
- settings.setdefault('Signal_to_noise', 5)
- settings.setdefault('verbose', False)
- settings.setdefault('resize', False)
- settings.setdefault('target_height', None)
- settings.setdefault('target_width', None)

- if settings['verbose']:
- settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
- settings_df['setting_value'] = settings_df['setting_value'].apply(str)
- display(settings_df)
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
+ display(settings_df)

  cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
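With the defaults moved into the settings module, check_cellpose_models() only strictly needs the source folder in its settings dict. A hedged sketch; the path is hypothetical and the remaining keys are assumed to be filled by get_check_cellpose_models_default_settings().

```python
from spacr.core import check_cellpose_models  # assumed import path

# Defaults (batch_size, CP_prob, flow_threshold, diameter, ...) are assumed to
# come from spacr.settings.get_check_cellpose_models_default_settings().
check_cellpose_models({'src': '/data/tif_images'})
```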
@@ -2623,8 +2622,24 @@ def _calculate_similarity(df, features, col_to_compare, val1, val2):

  return df

- def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
-
+ def find_optimal_threshold(y_true, y_pred_proba):
+ """
+ Find the optimal threshold for binary classification based on the F1-score.
+
+ Args:
+ y_true (array-like): True binary labels.
+ y_pred_proba (array-like): Predicted probabilities for the positive class.
+
+ Returns:
+ float: The optimal threshold.
+ """
+ precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
+ f1_scores = 2 * (precision * recall) / (precision + recall)
+ optimal_idx = np.argmax(f1_scores)
+ optimal_threshold = thresholds[optimal_idx]
+ return optimal_threshold
+
+ def ml_analysis(df, channel_of_interest=3, location_column='col', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
  """
  Calculates permutation importance for numerical features in the dataframe,
  comparing groups based on specified column values and uses the model to predict
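The find_optimal_threshold() helper added in the hunk above picks the probability cutoff that maximizes the F1-score along the precision-recall curve. A minimal sketch on toy data, assuming the helper is importable from spacr.core:

```python
import numpy as np
from spacr.core import find_optimal_threshold  # assumed import path

y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_pred_proba = np.array([0.10, 0.40, 0.35, 0.80, 0.65, 0.20, 0.90, 0.55])

# Returns the threshold whose precision/recall trade-off maximizes F1.
print(find_optimal_threshold(y_true, y_pred_proba))
```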
@@ -2633,12 +2648,11 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
  Args:
  df (pandas.DataFrame): The DataFrame containing the data.
  feature_string (str): String to filter features that contain this substring.
- col_to_compare (str): Column name to use for comparing groups.
- pos, neg (str): Values in col_to_compare to create subsets for comparison.
+ location_column (str): Column name to use for comparing groups.
+ positive_control, negative_control (str): Values in location_column to create subsets for comparison.
  exclude (list or str, optional): Columns to exclude from features.
  n_repeats (int): Number of repeats for permutation importance.
- clean (bool): Whether to remove columns with a single value.
- nr_to_plot (int): Number of top features to plot based on permutation importance.
+ top_features (int): Number of top features to plot based on permutation importance.
  n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
  test_size (float): Proportion of the dataset to include in the test split.
  random_state (int): Random seed for reproducibility.
@@ -2651,38 +2665,48 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
  """

  from .utils import filter_dataframe_features
+ from .plot import plot_permutation, plot_feature_importance

+ random_state = 42
+
  if 'cells_per_well' in df.columns:
  df = df.drop(columns=['cells_per_well'])

+ df_metadata = df[[location_column]].copy()
+ df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
+
+
+ if verbose:
+ print(f'Found {len(features)} numerical features in the dataframe')
+ print(f'Features used in training: {features}')
+ df = pd.concat([df, df_metadata[location_column]], axis=1)
+
  # Subset the dataframe based on specified column values
- df1 = df[df[col_to_compare] == pos].copy()
- df2 = df[df[col_to_compare] == neg].copy()
+ df1 = df[df[location_column] == negative_control].copy()
+ df2 = df[df[location_column] == positive_control].copy()

  # Create target variable
- df1['target'] = 0
- df2['target'] = 1
+ df1['target'] = 0 # Negative control
+ df2['target'] = 1 # Positive control

  # Combine the subsets for analysis
  combined_df = pd.concat([df1, df2])
-
- if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
- channel_of_interest = int(feature_string.split('_')[-1])
- elif not feature_string is 'morphology':
- channel_of_interest = 'morphology'
-
- _, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)
+ combined_df = combined_df.drop(columns=[location_column])
+ if verbose:
+ print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')

  X = combined_df[features]
  y = combined_df['target']

  # Split the data into training and testing sets
  X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
-
- # Label the data in the original dataframe
+
+ # Add data usage labels
  combined_df['data_usage'] = 'train'
  combined_df.loc[X_test.index, 'data_usage'] = 'test'
-
+ df['data_usage'] = 'not_used'
+ df.loc[combined_df.index, 'data_usage'] = combined_df['data_usage']
+
  # Initialize the model based on model_type
  if model_type == 'random_forest':
  model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
@@ -2704,29 +2728,24 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
  'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
  'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
  'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
- }).tail(nr_to_plot)
+ }).tail(top_features)
+
+ permutation_fig = plot_permutation(permutation_df)
+ if verbose:
+ permutation_fig.show()

- # Plotting
- fig, ax = plt.subplots()
- ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
- ax.set_xlabel('Permutation Importance')
- plt.tight_layout()
- plt.show()
-
  # Feature importance for models that support it
  if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
  feature_importances = model.feature_importances_
  feature_importance_df = pd.DataFrame({
  'feature': features,
  'importance': feature_importances
- }).sort_values(by='importance', ascending=False).head(nr_to_plot)
+ }).sort_values(by='importance', ascending=False).head(top_features)

- # Plotting feature importance
- fig, ax = plt.subplots()
- ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
- ax.set_xlabel('Feature Importance')
- plt.tight_layout()
- plt.show()
+ feature_importance_fig = plot_feature_importance(feature_importance_df)
+ if verbose:
+ feature_importance_fig.show()
+
  else:
  feature_importance_df = pd.DataFrame()

@@ -2734,38 +2753,38 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
  predictions_test = model.predict(X_test)
  combined_df.loc[X_test.index, 'predictions'] = predictions_test

- # Predicting the target variable for the training set
- predictions_train = model.predict(X_train)
- combined_df.loc[X_train.index, 'predictions'] = predictions_train
+ # Get prediction probabilities for the test set
+ prediction_probabilities_test = model.predict_proba(X_test)
+
+ # Find the optimal threshold
+ optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
+ if verbose:
+ print(f'Optimal threshold: {optimal_threshold}')

  # Predicting the target variable for all other rows in the dataframe
  X_all = df[features]
  all_predictions = model.predict(X_all)
  df['predictions'] = all_predictions

- # Combine data usage labels back to the original dataframe
- combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
- df = df.join(combined_data_usage, how='left', rsuffix='_model')
-
- # Calculating and printing the accuracy metrics
- accuracy = accuracy_score(y_test, predictions_test)
- precision = precision_score(y_test, predictions_test)
- recall = recall_score(y_test, predictions_test)
- f1 = f1_score(y_test, predictions_test)
- print(f"Accuracy: {accuracy}")
- print(f"Precision: {precision}")
- print(f"Recall: {recall}")
- print(f"F1 Score: {f1}")
-
- # Printing class-specific accuracy metrics
- print("\nClassification Report:")
- print(classification_report(y_test, predictions_test))
+ # Get prediction probabilities for all rows in the dataframe
+ prediction_probabilities = model.predict_proba(X_all)
+ for i in range(prediction_probabilities.shape[1]):
+ df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
+ if verbose:
+ print("\nClassification Report:")
+ print(classification_report(y_test, predictions_test))
+ report_dict = classification_report(y_test, predictions_test, output_dict=True)
+ metrics_df = pd.DataFrame(report_dict).transpose()

- df = _calculate_similarity(df, features, col_to_compare, pos, neg)
+ df = _calculate_similarity(df, features, location_column, positive_control, negative_control)

- return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
+ df['prcfo'] = df.index.astype(str)
+ df[['plate', 'row', 'col', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
+ df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
+
+ return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]

- def _shap_analysis(model, X_train, X_test):
+ def shap_analysis(model, X_train, X_test):

  """
  Performs SHAP analysis on the given model and data.
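The refactored ml_analysis() (formerly _permutation_importance()) now returns a list of results plus a pair of figures, as the return statement in the preceding hunk shows. A hedged sketch of unpacking it; df stands for a hypothetical measurements DataFrame indexed by plate_row_col_field_object.

```python
results, figures = ml_analysis(
    df,
    channel_of_interest=3,
    location_column='col',
    positive_control='c2',
    negative_control='c1',
    model_type='xgboost',
    verbose=True,
)
(scored_df, permutation_df, feature_importance_df, model,
 X_train, X_test, y_train, y_test, metrics_df) = results
permutation_fig, feature_importance_fig = figures
```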
@@ -2774,17 +2793,45 @@ def _shap_analysis(model, X_train, X_test):
  model: The trained model.
  X_train (pandas.DataFrame): Training feature set.
  X_test (pandas.DataFrame): Testing feature set.
+ Returns:
+ fig: Matplotlib figure object containing the SHAP summary plot.
  """
-
+
  explainer = shap.Explainer(model, X_train)
  shap_values = explainer(X_test)
-
+ # Create a new figure
+ fig, ax = plt.subplots()
  # Summary plot
- shap.summary_plot(shap_values, X_test)
-
- def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+ shap.summary_plot(shap_values, X_test, show=False)
+ # Save the current figure (the one that SHAP just created)
+ fig = plt.gcf()
+ plt.close(fig) # Close the figure to prevent it from displaying immediately
+ return fig
+
+ def check_index(df, elements=5, split_char='_'):
+ problematic_indices = []
+ for idx in df.index:
+ parts = str(idx).split(split_char)
+ if len(parts) != elements:
+ problematic_indices.append(idx)
+ if problematic_indices:
+ print("Indices that cannot be separated into 5 parts:")
+ for idx in problematic_indices:
+ print(idx)
+ raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
+
+ #def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c2', neg='c1', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+ def generate_ml_scores(src, settings):
+
  from .io import _read_and_merge_data
- from .plot import _plot_plates
+ from .plot import plot_plates
+ from .utils import get_ml_results_paths
+ from .settings import set_default_analyze_screen
+
+ settings = set_default_analyze_screen(settings)
+
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+ display(settings_df)

  db_loc = [src+'/measurements/measurements.db']
  tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
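The new check_index() guard, added in the hunk above, verifies that every DataFrame index splits into the expected number of parts (five for the plate_row_col_field_object convention used elsewhere in this file). A hedged example:

```python
import pandas as pd
from spacr.core import check_index  # assumed import path

df = pd.DataFrame(index=['plate1_r1_c1_f1_o1', 'plate1_r1_c1_f1_o2', 'bad_index'])
check_index(df, elements=5, split_char='_')  # raises ValueError for 'bad_index'
```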
@@ -2792,27 +2839,60 @@ def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='m

  df, _ = _read_and_merge_data(db_loc,
  tables,
- verbose=verbose,
- include_multinucleated=include_multinucleated,
- include_multiinfected=include_multiinfected,
- include_noninfected=include_noninfected)
-
- if not channel_of_interest is None:
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
- feature_string = f'channel_{channel_of_interest}'
- else:
- feature_string = None
-
- output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
-
- _shap_analysis(output[3], output[4], output[5])
+ settings['verbose'],
+ include_multinucleated,
+ include_multiinfected,
+ include_noninfected)
+
+ if settings['channel_of_interest'] in [0,1,2,3]:
+
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
+
+ output, figs = ml_analysis(df,
+ settings['channel_of_interest'],
+ settings['location_column'],
+ settings['positive_control'],
+ settings['negative_control'],
+ settings['exclude'],
+ settings['n_repeats'],
+ settings['top_features'],
+ settings['n_estimators'],
+ settings['test_size'],
+ settings['model_type'],
+ settings['n_jobs'],
+ settings['remove_low_variance_features'],
+ settings['remove_highly_correlated_features'],
+ settings['verbose'])
+
+ shap_fig = shap_analysis(output[3], output[4], output[5])

  features = output[0].select_dtypes(include=[np.number]).columns.tolist()

- if not variable in features:
- raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
+ if not settings['heatmap_feature'] in features:
+ raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")

- plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
+ plate_heatmap = plot_plates(df=output[0],
+ variable=settings['heatmap_feature'],
+ grouping=settings['grouping'],
+ min_max=settings['min_max'],
+ cmap=settings['cmap'],
+ min_count=settings['minimum_cell_count'],
+ verbose=settings['verbose'])
+
+ data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type'], settings['channel_of_interest'])
+ df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
+
+ settings_df.to_csv(settings_csv, index=False)
+ df.to_csv(data_path, mode='w', encoding='utf-8')
+ permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
+ feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
+ metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')
+
+ plate_heatmap.savefig(plate_heatmap_path, format='pdf')
+ figs[0].savefig(permutation_fig_path, format='pdf')
+ figs[1].savefig(feature_importance_fig_path, format='pdf')
+ shap_fig.savefig(shap_fig_path, format='pdf')
+
  return [output, plate_heatmap]

  def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
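generate_ml_scores() replaces the old plate_heatmap() entry point: it reads measurements.db under src, runs ml_analysis() and shap_analysis(), and saves the resulting tables and figures to the paths returned by get_ml_results_paths(). A hedged usage sketch; the source path is hypothetical, and any keys omitted here are assumed to be filled in by set_default_analyze_screen().

```python
settings = {
    'channel_of_interest': 3,
    'location_column': 'col',
    'positive_control': 'c2',
    'negative_control': 'c1',
    'model_type': 'xgboost',
    'heatmap_feature': 'predictions',
    'grouping': 'mean',
    'min_max': 'allq',
    'cmap': 'viridis',
    'minimum_cell_count': 25,
    'n_jobs': -1,
    'verbose': True,
}
output, plate_heatmap = generate_ml_scores('/data/screen1', settings)
```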
@@ -2941,8 +3021,8 @@ def generate_image_umap(settings={}):
  """

  from .io import _read_and_join_tables
- from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings, cluster_feature_analysis, generate_umap_from_images
-
+ from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis, generate_umap_from_images
+ from .settings import get_umap_image_settings
  settings = get_umap_image_settings(settings)

  if isinstance(settings['src'], str):
@@ -3124,7 +3204,8 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
  """

  from .io import _read_and_join_tables
- from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings
+ from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
+ from .settings import get_umap_image_settings

  settings = get_umap_image_settings(settings)
  pointsize = settings['dot_size']