spacr 0.0.36__py3-none-any.whl → 0.0.62__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -1,13 +1,10 @@
1
- import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap, string
1
+ import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
2
2
 
3
- # image and array processing
4
3
  import numpy as np
5
4
  import pandas as pd
6
5
 
7
6
  from cellpose import train
8
- import cellpose
9
7
  from cellpose import models as cp_models
10
- from cellpose.models import CellposeModel
11
8
 
12
9
  import statsmodels.formula.api as smf
13
10
  import statsmodels.api as sm
@@ -16,31 +13,37 @@ from IPython.display import display
16
13
  from multiprocessing import Pool, cpu_count, Value, Lock
17
14
 
18
15
  import seaborn as sns
19
- import matplotlib.pyplot as plt
16
+
20
17
  from skimage.measure import regionprops, label
21
- import skimage.measure as measure
18
+ from skimage.morphology import square
22
19
  from skimage.transform import resize as resizescikit
23
- from sklearn.model_selection import train_test_split
24
20
  from collections import defaultdict
25
- import multiprocessing
26
21
  from torch.utils.data import DataLoader, random_split
27
- import matplotlib
28
- matplotlib.use('Agg')
22
+ from sklearn.cluster import KMeans
23
+ from sklearn.decomposition import PCA
29
24
 
30
- import torchvision.transforms as transforms
25
+ from skimage import measure
31
26
  from sklearn.model_selection import train_test_split
32
27
  from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
33
- from .logger import log_function_call
34
-
35
28
  from sklearn.linear_model import LogisticRegression
36
29
  from sklearn.inspection import permutation_importance
37
30
  from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
38
- from xgboost import XGBClassifier
31
+ from sklearn.preprocessing import StandardScaler
39
32
 
33
+ from scipy.ndimage import binary_dilation
40
34
  from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
41
- from sklearn.preprocessing import StandardScaler
35
+
36
+ import torchvision.transforms as transforms
37
+ from xgboost import XGBClassifier
42
38
  import shap
43
39
 
40
+ import matplotlib.pyplot as plt
41
+ import matplotlib
42
+ matplotlib.use('Agg')
43
+ #import matplotlib.pyplot as plt
44
+
45
+ from .logger import log_function_call
46
+
44
47
  def analyze_plaques(folder):
45
48
  summary_data = []
46
49
  details_data = []
@@ -77,73 +80,46 @@ def analyze_plaques(folder):
77
80
 
78
81
  print(f"Analysis completed and saved to database '{db_name}'.")
79
82
 
80
- def generate_cp_masks(settings):
81
-
82
- src = settings['src']
83
- model_name = settings['model_name']
84
- channels = settings['channels']
85
- diameter = settings['diameter']
86
- regex = '.tif'
87
- #flow_threshold = 30
88
- cellprob_threshold = settings['cellprob_threshold']
89
- figuresize = 25
90
- cmap = 'inferno'
91
- verbose = settings['verbose']
92
- plot = settings['plot']
93
- save = settings['save']
94
- custom_model = settings['custom_model']
95
- signal_thresholds = 1000
96
- normalize = settings['normalize']
97
- resize = settings['resize']
98
- target_height = settings['width_height'][1]
99
- target_width = settings['width_height'][0]
100
- rescale = settings['rescale']
101
- resample = settings['resample']
102
- net_avg = settings['net_avg']
103
- invert = settings['invert']
104
- circular = settings['circular']
105
- percentiles = settings['percentiles']
106
- overlay = settings['overlay']
107
- grayscale = settings['grayscale']
108
- flow_threshold = settings['flow_threshold']
109
- batch_size = settings['batch_size']
110
-
111
- dst = os.path.join(src,'masks')
112
- os.makedirs(dst, exist_ok=True)
113
-
114
- identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
115
-
116
83
  def train_cellpose(settings):
117
84
 
118
85
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
119
86
  from .utils import resize_images_and_labels
120
87
 
121
88
  img_src = settings['img_src']
122
- mask_src = os.path.join(img_src, 'mask')
89
+ mask_src = os.path.join(img_src, 'masks')
123
90
 
124
- model_name = settings['model_name']
125
- model_type = settings['model_type']
126
- learning_rate = settings['learning_rate']
127
- weight_decay = settings['weight_decay']
128
- batch_size = settings['batch_size']
129
- n_epochs = settings['n_epochs']
130
- from_scratch = settings['from_scratch']
131
- diameter = settings['diameter']
132
- verbose = settings['verbose']
133
-
134
- channels = [0,0]
135
- signal_thresholds = 1000
136
- normalize = True
137
- percentiles = [2,98]
138
- circular = False
139
- invert = False
140
- resize = False
141
- settings['width_height'] = [1000,1000]
142
- target_height = settings['width_height'][1]
143
- target_width = settings['width_height'][0]
144
- rescale = False
145
- grayscale = True
146
- test = False
91
+ model_name = settings.setdefault( 'model_name', '')
92
+
93
+ model_name = settings.setdefault('model_name', 'model_name')
94
+
95
+ model_type = settings.setdefault( 'model_type', 'cyto')
96
+ learning_rate = settings.setdefault( 'learning_rate', 0.01)
97
+ weight_decay = settings.setdefault( 'weight_decay', 1e-05)
98
+ batch_size = settings.setdefault( 'batch_size', 50)
99
+ n_epochs = settings.setdefault( 'n_epochs', 100)
100
+ from_scratch = settings.setdefault( 'from_scratch', False)
101
+ diameter = settings.setdefault( 'diameter', 40)
102
+
103
+ remove_background = settings.setdefault( 'remove_background', False)
104
+ background = settings.setdefault( 'background', 100)
105
+ Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
106
+ verbose = settings.setdefault( 'verbose', False)
107
+
108
+
109
+ channels = settings.setdefault( 'channels', [0,0])
110
+ normalize = settings.setdefault( 'normalize', True)
111
+ percentiles = settings.setdefault( 'percentiles', None)
112
+ circular = settings.setdefault( 'circular', False)
113
+ invert = settings.setdefault( 'invert', False)
114
+ resize = settings.setdefault( 'resize', False)
115
+
116
+ if resize:
117
+ target_height = settings['width_height'][1]
118
+ target_width = settings['width_height'][0]
119
+
120
+ grayscale = settings.setdefault( 'grayscale', True)
121
+ rescale = settings.setdefault( 'channels', False)
122
+ test = settings.setdefault( 'test', False)
147
123
 
148
124
  if test:
149
125
  test_img_src = os.path.join(os.path.dirname(img_src), 'test')
@@ -177,22 +153,21 @@ def train_cellpose(settings):
177
153
 
178
154
  image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
179
155
  label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
180
- images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
156
+ images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
181
157
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
182
158
 
183
159
  if test:
184
160
  test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
185
161
  test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
186
- test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(image_files=test_image_files, label_files=test_label_files, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
162
+ test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
187
163
  test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
188
164
 
189
-
190
165
  else:
191
166
  images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
192
167
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
193
168
 
194
169
  if test:
195
- test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=circular)
170
+ test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
196
171
  test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
197
172
 
198
173
  if resize:
@@ -250,179 +225,6 @@ def train_cellpose(settings):
250
225
 
251
226
  return print(f"Model saved at: {model_save_path}/{model_name}")
252
227
 
253
- def train_cellpose_v1(settings):
254
-
255
- from .io import _load_normalized_images_and_labels, _load_images_and_labels
256
- from .utils import resize_images_and_labels
257
-
258
- img_src = settings['img_src']
259
-
260
- mask_src = os.path.join(img_src, 'mask')
261
-
262
- model_name = settings['model_name']
263
- model_type = settings['model_type']
264
- learning_rate = settings['learning_rate']
265
- weight_decay = settings['weight_decay']
266
- batch_size = settings['batch_size']
267
- n_epochs = settings['n_epochs']
268
- verbose = settings['verbose']
269
-
270
- signal_thresholds = 100 #settings['signal_thresholds']
271
-
272
- channels = settings['channels']
273
- from_scratch = settings['from_scratch']
274
- diameter = settings['diameter']
275
- resize = settings['resize']
276
- rescale = settings['rescale']
277
- normalize = settings['normalize']
278
- target_height = settings['width_height'][1]
279
- target_width = settings['width_height'][0]
280
- circular = settings['circular']
281
- invert = settings['invert']
282
- percentiles = settings['percentiles']
283
- grayscale = settings['grayscale']
284
-
285
- if model_type == 'cyto':
286
- settings['diameter'] = 30
287
- diameter = settings['diameter']
288
- print(f'Cyto model must have diamiter 30. Diameter set the 30')
289
-
290
- if model_type == 'nuclei':
291
- settings['diameter'] = 17
292
- diameter = settings['diameter']
293
- print(f'Nuclei model must have diamiter 17. Diameter set the 17')
294
-
295
- print(settings)
296
-
297
- if from_scratch:
298
- model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
299
- else:
300
- model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
301
-
302
- model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
303
- print(model_save_path)
304
- os.makedirs(model_save_path, exist_ok=True)
305
-
306
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
307
- settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
308
- settings_df.to_csv(settings_csv, index=False)
309
-
310
- if not from_scratch:
311
- model = cp_models.CellposeModel(gpu=True, model_type=model_type)
312
-
313
- else:
314
- model = cp_models.CellposeModel(gpu=True, model_type=model_type, pretrained_model=None)
315
-
316
- if normalize:
317
- image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
318
- label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
319
-
320
- images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
321
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
322
- else:
323
- images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
324
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
325
-
326
- if resize:
327
- images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
328
-
329
- if model_type == 'cyto':
330
- cp_channels = [0,1]
331
- if model_type == 'cyto2':
332
- cp_channels = [0,2]
333
- if model_type == 'nucleus':
334
- cp_channels = [0,0]
335
- if grayscale:
336
- cp_channels = [0,0]
337
- images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
338
-
339
- masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
340
-
341
- print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
342
- save_every = int(n_epochs/10)
343
- if save_every < 10:
344
- save_every = n_epochs
345
-
346
-
347
- #print('cellpose image input dtype', images[0].dtype)
348
- #print('cellpose mask input dtype', masks[0].dtype)
349
-
350
- # Train the model
351
- #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
352
-
353
- #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
354
- # train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
355
- # train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
356
- # channels=cp_channels, #(list of ints (default, None)) – channels to use for training
357
- # normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
358
- # save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
359
- # save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
360
- # learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
361
- # n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
362
- # weight_decay=weight_decay, #(float (default, 0.00001)) –
363
- # SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
364
- # batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
365
- # nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
366
- # rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
367
- # min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
368
- # model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
369
-
370
-
371
- train.train_seg(model.net,
372
- train_data=images,
373
- train_labels=masks,
374
- train_files=image_names,
375
- train_labels_files=None,
376
- train_probs=None,
377
- test_data=None,
378
- test_labels=None,
379
- test_files=None,
380
- test_labels_files=None,
381
- test_probs=None,
382
- load_files=True,
383
- batch_size=batch_size,
384
- learning_rate=learning_rate,
385
- n_epochs=n_epochs,
386
- weight_decay=weight_decay,
387
- momentum=0.9,
388
- SGD=False,
389
- channels=cp_channels,
390
- channel_axis=None,
391
- #rgb=False,
392
- normalize=False,
393
- compute_flows=False,
394
- save_path=model_save_path,
395
- save_every=save_every,
396
- nimg_per_epoch=None,
397
- nimg_test_per_epoch=None,
398
- rescale=rescale,
399
- #scale_range=None,
400
- #bsize=224,
401
- min_train_masks=1,
402
- model_name=model_name)
403
-
404
- #model_save_path = train.train_seg(model.net,
405
- # train_data=images,
406
- # train_files=image_names,
407
- # train_labels=masks,
408
- # channels=cp_channels,
409
- # normalize=False,
410
- # save_every=save_every,
411
- # learning_rate=learning_rate,
412
- # n_epochs=n_epochs,
413
- # #test_data=test_images,
414
- # #test_labels=test_labels,
415
- # weight_decay=weight_decay,
416
- # SGD=True,
417
- # batch_size=batch_size,
418
- # nimg_per_epoch=None,
419
- # rescale=rescale,
420
- # min_train_masks=1,
421
- # model_name=model_name)
422
-
423
-
424
- return print(f"Model saved at: {model_save_path}/{model_name}")
425
-
426
228
  def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
427
229
 
428
230
  from .plot import _reg_v_plot
@@ -984,15 +786,6 @@ def merge_pred_mes(src,
984
786
 
985
787
  if verbose:
986
788
  _plot_histograms_and_stats(df=joined_df)
987
-
988
- #dv = joined_df.copy()
989
- #if 'prc' not in dv.columns:
990
- #dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
991
- #dv = dv[['pred']].groupby('prc').mean()
992
- #dv.set_index('prc', inplace=True)
993
-
994
- #loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
995
- #dv.to_csv(loc, index=True, header=True, mode='w')
996
789
 
997
790
  return joined_df
998
791
 
@@ -1282,7 +1075,6 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
1282
1075
  torch.cuda.memory.empty_cache()
1283
1076
  return df
1284
1077
 
1285
-
1286
1078
  def generate_training_data_file_list(src,
1287
1079
  target='protein of interest',
1288
1080
  cell_dim=4,
@@ -1483,158 +1275,27 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1483
1275
  if len(custom_measurement) == 1:
1484
1276
  print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
1485
1277
  df['recruitment'] = df[f'{custom_measurement[0]}']
1486
- else:
1487
- print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
1488
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1489
-
1490
- q25 = df['recruitment'].quantile(0.25)
1491
- q75 = df['recruitment'].quantile(0.75)
1492
- df_lower = df[df['recruitment'] <= q25]
1493
- df_upper = df[df['recruitment'] >= q75]
1494
-
1495
- class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
1496
-
1497
- class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
1498
- class_paths_ls.append(class_paths_lower)
1499
-
1500
- class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
1501
- class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
1502
- class_paths_ls.append(class_paths_upper)
1503
-
1504
- generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
1505
-
1506
- return
1507
-
1508
- def generate_loaders_v1(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
1509
- """
1510
- Generate data loaders for training and validation/test datasets.
1511
-
1512
- Parameters:
1513
- - src (str): The source directory containing the data.
1514
- - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
1515
- - mode (str): The mode of operation. Options are 'train' or 'test'.
1516
- - image_size (int): The size of the input images.
1517
- - batch_size (int): The batch size for the data loaders.
1518
- - classes (list): The list of classes to consider.
1519
- - num_workers (int): The number of worker threads for data loading.
1520
- - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
1521
- - max_show (int): The maximum number of images to show when verbose is True.
1522
- - pin_memory (bool): Whether to pin memory for faster data transfer.
1523
- - normalize (bool): Whether to normalize the input images.
1524
- - verbose (bool): Whether to print additional information and show images.
1525
-
1526
- Returns:
1527
- - train_loaders (list): List of data loaders for training datasets.
1528
- - val_loaders (list): List of data loaders for validation datasets.
1529
- - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
1530
- """
1531
-
1532
- from .io import MyDataset
1533
- from .plot import _imshow
1534
-
1535
- plate_to_filenames = defaultdict(list)
1536
- plate_to_labels = defaultdict(list)
1537
- train_loaders = []
1538
- val_loaders = []
1539
- plate_names = []
1540
-
1541
- if normalize:
1542
- transform = transforms.Compose([
1543
- transforms.ToTensor(),
1544
- transforms.CenterCrop(size=(image_size, image_size)),
1545
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1546
- else:
1547
- transform = transforms.Compose([
1548
- transforms.ToTensor(),
1549
- transforms.CenterCrop(size=(image_size, image_size))])
1550
-
1551
- if mode == 'train':
1552
- data_dir = os.path.join(src, 'train')
1553
- shuffle = True
1554
- print(f'Generating Train and validation datasets')
1555
-
1556
- elif mode == 'test':
1557
- data_dir = os.path.join(src, 'test')
1558
- val_loaders = []
1559
- validation_split=0.0
1560
- shuffle = True
1561
- print(f'Generating test dataset')
1562
-
1563
- else:
1564
- print(f'mode:{mode} is not valid, use mode = train or test')
1565
- return
1566
-
1567
- if train_mode == 'erm':
1568
- data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1569
- #train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1570
- if validation_split > 0:
1571
- train_size = int((1 - validation_split) * len(data))
1572
- val_size = len(data) - train_size
1573
-
1574
- print(f'Train data:{train_size}, Validation data:{val_size}')
1575
-
1576
- train_dataset, val_dataset = random_split(data, [train_size, val_size])
1577
-
1578
- train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1579
- val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1580
- else:
1581
- train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1582
-
1583
- elif train_mode == 'irm':
1584
- data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1585
-
1586
- for filename, label in zip(data.filenames, data.labels):
1587
- plate = data.get_plate(filename)
1588
- plate_to_filenames[plate].append(filename)
1589
- plate_to_labels[plate].append(label)
1590
-
1591
- for plate, filenames in plate_to_filenames.items():
1592
- labels = plate_to_labels[plate]
1593
- plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
1594
- plate_names.append(plate)
1595
-
1596
- if validation_split > 0:
1597
- train_size = int((1 - validation_split) * len(plate_data))
1598
- val_size = len(plate_data) - train_size
1599
-
1600
- print(f'Train data:{train_size}, Validation data:{val_size}')
1601
-
1602
- train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
1603
-
1604
- train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1605
- val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1606
-
1607
- train_loaders.append(train_loader)
1608
- val_loaders.append(val_loader)
1609
- else:
1610
- train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1611
- train_loaders.append(train_loader)
1612
- val_loaders.append(None)
1613
-
1614
- else:
1615
- print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
1616
- return
1617
-
1618
- if verbose:
1619
- if train_mode == 'erm':
1620
- for idx, (images, labels, filenames) in enumerate(train_loaders):
1621
- if idx >= max_show:
1622
- break
1623
- images = images.cpu()
1624
- label_strings = [str(label.item()) for label in labels]
1625
- _imshow(images, label_strings, nrow=20, fontsize=12)
1626
-
1627
- elif train_mode == 'irm':
1628
- for plate_name, train_loader in zip(plate_names, train_loaders):
1629
- print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
1630
- for idx, (images, labels, filenames) in enumerate(train_loader):
1631
- if idx >= max_show:
1632
- break
1633
- images = images.cpu()
1634
- label_strings = [str(label.item()) for label in labels]
1635
- _imshow(images, label_strings, nrow=20, fontsize=12)
1278
+ else:
1279
+ print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
1280
+ df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1281
+
1282
+ q25 = df['recruitment'].quantile(0.25)
1283
+ q75 = df['recruitment'].quantile(0.75)
1284
+ df_lower = df[df['recruitment'] <= q25]
1285
+ df_upper = df[df['recruitment'] >= q75]
1286
+
1287
+ class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
1288
+
1289
+ class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
1290
+ class_paths_ls.append(class_paths_lower)
1291
+
1292
+ class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
1293
+ class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
1294
+ class_paths_ls.append(class_paths_upper)
1636
1295
 
1637
- return train_loaders, val_loaders, plate_names
1296
+ generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=0.1)
1297
+
1298
+ return
1638
1299
 
1639
1300
  def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
1640
1301
 
@@ -1671,6 +1332,7 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1671
1332
  import random
1672
1333
  from PIL import Image
1673
1334
  from torchvision.transforms import ToTensor
1335
+ from .utils import SelectChannels
1674
1336
 
1675
1337
  chans = []
1676
1338
 
@@ -1687,20 +1349,6 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1687
1349
  print(f'Training a network on channels: {channels}')
1688
1350
  print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
1689
1351
 
1690
- class SelectChannels:
1691
- def __init__(self, channels):
1692
- self.channels = channels
1693
-
1694
- def __call__(self, img):
1695
- img = img.clone()
1696
- if 1 not in self.channels:
1697
- img[0, :, :] = 0 # Zero out the red channel
1698
- if 2 not in self.channels:
1699
- img[1, :, :] = 0 # Zero out the green channel
1700
- if 3 not in self.channels:
1701
- img[2, :, :] = 0 # Zero out the blue channel
1702
- return img
1703
-
1704
1352
  plate_to_filenames = defaultdict(list)
1705
1353
  plate_to_labels = defaultdict(list)
1706
1354
  train_loaders = []
@@ -1989,41 +1637,225 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1989
1637
  cells,wells = _results_to_csv(src, df, df_well)
1990
1638
  return [cells,wells]
1991
1639
 
1640
+ def _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold=5, perimeter_threshold=30):
1641
+ """
1642
+ Merge cells in cell_mask if a parasite in parasite_mask overlaps with more than one cell,
1643
+ and if cells share more than a specified perimeter percentage.
1644
+
1645
+ Args:
1646
+ parasite_mask (ndarray): Mask of parasites.
1647
+ cell_mask (ndarray): Mask of cells.
1648
+ nuclei_mask (ndarray): Mask of nuclei.
1649
+ overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
1650
+ perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.
1651
+
1652
+ Returns:
1653
+ ndarray: The modified cell mask (cell_mask) with unique labels.
1654
+ """
1655
+ labeled_cells = label(cell_mask)
1656
+ labeled_parasites = label(parasite_mask)
1657
+ labeled_nuclei = label(nuclei_mask)
1658
+ num_parasites = np.max(labeled_parasites)
1659
+ num_cells = np.max(labeled_cells)
1660
+ num_nuclei = np.max(labeled_nuclei)
1661
+
1662
+ # Merge cells based on parasite overlap
1663
+ for parasite_id in range(1, num_parasites + 1):
1664
+ current_parasite_mask = labeled_parasites == parasite_id
1665
+ overlapping_cell_labels = np.unique(labeled_cells[current_parasite_mask])
1666
+ overlapping_cell_labels = overlapping_cell_labels[overlapping_cell_labels != 0]
1667
+ if len(overlapping_cell_labels) > 1:
1668
+ # Calculate the overlap percentages
1669
+ overlap_percentages = [
1670
+ np.sum(current_parasite_mask & (labeled_cells == cell_label)) / np.sum(current_parasite_mask) * 100
1671
+ for cell_label in overlapping_cell_labels
1672
+ ]
1673
+ # Merge cells if overlap percentage is above the threshold
1674
+ for cell_label, overlap_percentage in zip(overlapping_cell_labels, overlap_percentages):
1675
+ if overlap_percentage > overlap_threshold:
1676
+ first_label = overlapping_cell_labels[0]
1677
+ for other_label in overlapping_cell_labels[1:]:
1678
+ if other_label != first_label:
1679
+ cell_mask[cell_mask == other_label] = first_label
1680
+
1681
+ # Merge cells based on nucleus overlap
1682
+ for nucleus_id in range(1, num_nuclei + 1):
1683
+ current_nucleus_mask = labeled_nuclei == nucleus_id
1684
+ overlapping_cell_labels = np.unique(labeled_cells[current_nucleus_mask])
1685
+ overlapping_cell_labels = overlapping_cell_labels[overlapping_cell_labels != 0]
1686
+ if len(overlapping_cell_labels) > 1:
1687
+ # Calculate the overlap percentages
1688
+ overlap_percentages = [
1689
+ np.sum(current_nucleus_mask & (labeled_cells == cell_label)) / np.sum(current_nucleus_mask) * 100
1690
+ for cell_label in overlapping_cell_labels
1691
+ ]
1692
+ # Merge cells if overlap percentage is above the threshold for each cell
1693
+ if all(overlap_percentage > overlap_threshold for overlap_percentage in overlap_percentages):
1694
+ first_label = overlapping_cell_labels[0]
1695
+ for other_label in overlapping_cell_labels[1:]:
1696
+ if other_label != first_label:
1697
+ cell_mask[cell_mask == other_label] = first_label
1698
+
1699
+ # Check for cells without nuclei and merge based on shared perimeter
1700
+ labeled_cells = label(cell_mask) # Re-label after merging based on overlap
1701
+ cell_regions = regionprops(labeled_cells)
1702
+ for region in cell_regions:
1703
+ cell_label = region.label
1704
+ cell_mask_binary = labeled_cells == cell_label
1705
+ overlapping_nuclei = np.unique(nuclei_mask[cell_mask_binary])
1706
+ overlapping_nuclei = overlapping_nuclei[overlapping_nuclei != 0]
1707
+
1708
+ if len(overlapping_nuclei) == 0:
1709
+ # Cell does not overlap with any nucleus
1710
+ perimeter = region.perimeter
1711
+ # Dilate the cell to find neighbors
1712
+ dilated_cell = binary_dilation(cell_mask_binary, structure=square(3))
1713
+ neighbor_cells = np.unique(labeled_cells[dilated_cell])
1714
+ neighbor_cells = neighbor_cells[(neighbor_cells != 0) & (neighbor_cells != cell_label)]
1715
+ # Calculate shared border length with neighboring cells
1716
+ shared_borders = [
1717
+ np.sum((labeled_cells == neighbor_label) & dilated_cell) for neighbor_label in neighbor_cells
1718
+ ]
1719
+ shared_border_percentages = [shared_border / perimeter * 100 for shared_border in shared_borders]
1720
+ # Merge with the neighbor cell with the largest shared border percentage above the threshold
1721
+ if shared_borders:
1722
+ max_shared_border_index = np.argmax(shared_border_percentages)
1723
+ max_shared_border_percentage = shared_border_percentages[max_shared_border_index]
1724
+ if max_shared_border_percentage > perimeter_threshold:
1725
+ cell_mask[labeled_cells == cell_label] = neighbor_cells[max_shared_border_index]
1726
+
1727
+ # Relabel the merged cell mask
1728
+ relabeled_cell_mask, _ = label(cell_mask, return_num=True)
1729
+ return relabeled_cell_mask
1730
+
1731
+ def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30):
1732
+ """
1733
+ Process all npy files in the given folders. Merge and relabel cells in cell masks
1734
+ based on parasite overlap and cell perimeter sharing conditions.
1735
+
1736
+ Args:
1737
+ parasite_folder (str): Path to the folder containing parasite masks.
1738
+ cell_folder (str): Path to the folder containing cell masks.
1739
+ nuclei_folder (str): Path to the folder containing nuclei masks.
1740
+ overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
1741
+ perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.
1742
+ """
1743
+
1744
+ parasite_files = sorted([f for f in os.listdir(parasite_folder) if f.endswith('.npy')])
1745
+ cell_files = sorted([f for f in os.listdir(cell_folder) if f.endswith('.npy')])
1746
+ nuclei_files = sorted([f for f in os.listdir(nuclei_folder) if f.endswith('.npy')])
1747
+
1748
+ # Ensure there are matching files in all folders
1749
+ if not (len(parasite_files) == len(cell_files) == len(nuclei_files)):
1750
+ raise ValueError("The number of files in the folders do not match.")
1751
+
1752
+ # Match files by name
1753
+ for file_name in parasite_files:
1754
+ parasite_path = os.path.join(parasite_folder, file_name)
1755
+ cell_path = os.path.join(cell_folder, file_name)
1756
+ nuclei_path = os.path.join(nuclei_folder, file_name)
1757
+ # Check if the corresponding cell and nuclei mask files exist
1758
+ if not (os.path.exists(cell_path) and os.path.exists(nuclei_path)):
1759
+ raise ValueError(f"Corresponding cell or nuclei mask file for {file_name} not found.")
1760
+ # Load the masks
1761
+ parasite_mask = np.load(parasite_path)
1762
+ cell_mask = np.load(cell_path)
1763
+ nuclei_mask = np.load(nuclei_path)
1764
+ # Merge and relabel cells
1765
+ merged_cell_mask = _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold, perimeter_threshold)
1766
+ # Overwrite the original cell mask file with the merged result
1767
+ np.save(cell_path, merged_cell_mask)
1768
+
1769
+ def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
1770
+
1771
+ def read_files_in_batches(folder, batch_size=50):
1772
+ files = [f for f in os.listdir(folder) if f.endswith('.npy')]
1773
+ files.sort() # Sort to ensure matching order
1774
+ for i in range(0, len(files), batch_size):
1775
+ yield files[i:i + batch_size]
1776
+
1777
+ def measure_morphology_and_intensity(mask, image):
1778
+ properties = measure.regionprops(mask, intensity_image=image)
1779
+ properties_list = [{'area': p.area, 'mean_intensity': p.mean_intensity, 'perimeter': p.perimeter, 'eccentricity': p.eccentricity} for p in properties]
1780
+ return properties_list
1781
+
1782
+ def cluster_objects(properties, n_clusters=2):
1783
+ data = np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])
1784
+ kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(data)
1785
+ return kmeans
1786
+
1787
+ def remove_objects_not_in_largest_cluster(mask, labels, largest_cluster_label):
1788
+ cleaned_mask = np.zeros_like(mask)
1789
+ for region in measure.regionprops(mask):
1790
+ if labels[region.label - 1] == largest_cluster_label:
1791
+ cleaned_mask[mask == region.label] = region.label
1792
+ return cleaned_mask
1793
+
1794
+ def plot_clusters(properties, labels):
1795
+ data = np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])
1796
+ pca = PCA(n_components=2)
1797
+ data_2d = pca.fit_transform(data)
1798
+ plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis')
1799
+ plt.xlabel('PCA Component 1')
1800
+ plt.ylabel('PCA Component 2')
1801
+ plt.title('Object Clustering')
1802
+ plt.show()
1803
+
1804
+ all_properties = []
1805
+
1806
+ # Step 1: Accumulate properties over all files
1807
+ for batch in read_files_in_batches(mask_folder, batch_size):
1808
+ mask_files = [os.path.join(mask_folder, file) for file in batch]
1809
+ image_files = [os.path.join(image_folder, file) for file in batch]
1810
+
1811
+ masks = [np.load(file) for file in mask_files]
1812
+ images = [np.load(file)[:, :, channel] for file in image_files]
1813
+
1814
+ for i, mask in enumerate(masks):
1815
+ image = images[i]
1816
+ # Measure morphology and intensity
1817
+ properties = measure_morphology_and_intensity(mask, image)
1818
+ all_properties.extend(properties)
1819
+
1820
+ # Step 2: Perform clustering on accumulated properties
1821
+ kmeans = cluster_objects(all_properties, n_clusters)
1822
+ labels = kmeans.labels_
1823
+
1824
+ if plot:
1825
+ # Step 3: Plot clusters using PCA
1826
+ plot_clusters(all_properties, labels)
1827
+
1828
+ # Step 4: Remove objects not in the largest cluster and overwrite files in batches
1829
+ label_index = 0
1830
+ for batch in read_files_in_batches(mask_folder, batch_size):
1831
+ mask_files = [os.path.join(mask_folder, file) for file in batch]
1832
+ masks = [np.load(file) for file in mask_files]
1833
+
1834
+ for i, mask in enumerate(masks):
1835
+ batch_properties = measure_morphology_and_intensity(mask, mask)
1836
+ batch_labels = labels[label_index:label_index + len(batch_properties)]
1837
+ largest_cluster_label = np.bincount(batch_labels).argmax()
1838
+ cleaned_mask = remove_objects_not_in_largest_cluster(mask, batch_labels, largest_cluster_label)
1839
+ np.save(mask_files[i], cleaned_mask)
1840
+ label_index += len(batch_properties)
1841
+
1992
1842
  def preprocess_generate_masks(src, settings={}):
1993
1843
 
1994
1844
  from .io import preprocess_img_data, _load_and_concatenate_arrays
1995
1845
  from .plot import plot_merged, plot_arrays
1996
- from .utils import _pivot_counts_table
1997
-
1998
- settings['plot'] = False
1999
- settings['fps'] = 2
2000
- settings['remove_background'] = True
2001
- settings['lower_quantile'] = 0.02
2002
- settings['merge'] = False
2003
- settings['normalize_plots'] = True
2004
- settings['all_to_mip'] = False
2005
- settings['pick_slice'] = False
2006
- settings['skip_mode'] = src
2007
- settings['workers'] = os.cpu_count()-4
2008
- settings['verbose'] = True
2009
- settings['examples_to_plot'] = 1
2010
- settings['src'] = src
2011
- settings['upscale'] = False
2012
- settings['upscale_factor'] = 2.0
2013
-
2014
- settings['randomize'] = True
2015
- settings['timelapse'] = False
2016
- settings['timelapse_displacement'] = None
2017
- settings['timelapse_memory'] = 3
2018
- settings['timelapse_frame_limits'] = None
2019
- settings['timelapse_remove_transient'] = False
2020
- settings['timelapse_mode'] = 'trackpy'
2021
- settings['timelapse_objects'] = ['cells']
1846
+ from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
1847
+
1848
+ settings = set_default_settings_preprocess_generate_masks(src, settings)
2022
1849
 
2023
1850
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
2024
1851
  settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
2025
1852
  os.makedirs(os.path.join(src,'settings'), exist_ok=True)
2026
1853
  settings_df.to_csv(settings_csv, index=False)
1854
+
1855
+ if not settings['pathogen_channel'] is None:
1856
+ custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
1857
+ if settings['pathogen_model'] not in custom_model_ls:
1858
+ ValueError(f'Pathogen model must be {custom_model_ls} or None')
2027
1859
 
2028
1860
  if settings['timelapse']:
2029
1861
  settings['randomize'] = False
@@ -2032,24 +1864,50 @@ def preprocess_generate_masks(src, settings={}):
2032
1864
  if not settings['masks']:
2033
1865
  print(f'WARNING: channels for mask generation are defined when preprocess = True')
2034
1866
 
2035
- if isinstance(settings['merge'], bool):
2036
- settings['merge'] = [settings['merge']]*3
2037
1867
  if isinstance(settings['save'], bool):
2038
1868
  settings['save'] = [settings['save']]*3
2039
1869
 
1870
+ if settings['verbose']:
1871
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
1872
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
1873
+ display(settings_df)
1874
+
1875
+ if settings['test_mode']:
1876
+ print(f'Starting Test mode ...')
1877
+
2040
1878
  if settings['preprocess']:
2041
1879
  settings, src = preprocess_img_data(settings)
2042
1880
 
2043
1881
  if settings['masks']:
2044
1882
  mask_src = os.path.join(src, 'norm_channel_stack')
2045
1883
  if settings['cell_channel'] != None:
2046
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='cell')
1884
+ if check_mask_folder(src, 'cell_mask_stack'):
1885
+ generate_cellpose_masks(mask_src, settings, 'cell')
2047
1886
 
2048
1887
  if settings['nucleus_channel'] != None:
2049
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='nucleus')
1888
+ if check_mask_folder(src, 'nucleus_mask_stack'):
1889
+ generate_cellpose_masks(mask_src, settings, 'nucleus')
2050
1890
 
2051
1891
  if settings['pathogen_channel'] != None:
2052
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='pathogen')
1892
+ if check_mask_folder(src, 'pathogen_mask_stack'):
1893
+ generate_cellpose_masks(mask_src, settings, 'pathogen')
1894
+
1895
+ if settings['adjust_cells']:
1896
+ if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
1897
+
1898
+ start = time.time()
1899
+ cell_folder = os.path.join(mask_src, 'cell_mask_stack')
1900
+ nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
1901
+ parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
1902
+ #image_folder = os.path.join(src, 'stack')
1903
+
1904
+ #process_masks(cell_folder, image_folder, settings['cell_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1905
+ #process_masks(nuclei_folder, image_folder, settings['nucleus_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1906
+ #process_masks(parasite_folder, image_folder, settings['pathogen_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1907
+
1908
+ adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
1909
+ stop = time.time()
1910
+ print(f'Cell mask adjustment: {stop-start} seconds')
2053
1911
 
2054
1912
  if os.path.exists(os.path.join(src,'measurements')):
2055
1913
  _pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
@@ -2078,28 +1936,14 @@ def preprocess_generate_masks(src, settings={}):
2078
1936
  overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
2079
1937
  overlay_channels = [element for element in overlay_channels if element is not None]
2080
1938
 
2081
- plot_settings = {'include_noninfected':True,
2082
- 'include_multiinfected':True,
2083
- 'include_multinucleated':True,
2084
- 'remove_background':False,
2085
- 'filter_min_max':None,
2086
- 'channel_dims':settings['channels'],
2087
- 'backgrounds':[100,100,100,100],
2088
- 'cell_mask_dim':cell_mask_dim,
2089
- 'nucleus_mask_dim':nucleus_mask_dim,
2090
- 'pathogen_mask_dim':pathogen_mask_dim,
2091
- 'outline_thickness':3,
2092
- 'outline_color':'gbr',
2093
- 'overlay_chans':overlay_channels,
2094
- 'overlay':True,
2095
- 'normalization_percentiles':[1,99],
2096
- 'normalize':True,
2097
- 'print_object_number':True,
2098
- 'nr':settings['examples_to_plot'],
2099
- 'figuresize':20,
2100
- 'cmap':'inferno',
2101
- 'verbose':False}
2102
-
1939
+ plot_settings = set_default_plot_merge_settings()
1940
+ plot_settings['channel_dims'] = settings['channels']
1941
+ plot_settings['cell_mask_dim'] = cell_mask_dim
1942
+ plot_settings['nucleus_mask_dim'] = nucleus_mask_dim
1943
+ plot_settings['pathogen_mask_dim'] = pathogen_mask_dim
1944
+ plot_settings['overlay_chans'] = overlay_channels
1945
+ plot_settings['nr'] = settings['examples_to_plot']
1946
+
2103
1947
  if settings['test_mode'] == True:
2104
1948
  plot_settings['nr'] = len(os.path.join(src,'merged'))
2105
1949
 
@@ -2108,7 +1952,7 @@ def preprocess_generate_masks(src, settings={}):
2108
1952
  except Exception as e:
2109
1953
  print(f'Failed to plot image mask overly. Error: {e}')
2110
1954
  else:
2111
- plot_arrays(src=os.path.join(src,'merged'), figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99)
1955
+ plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
2112
1956
 
2113
1957
  torch.cuda.empty_cache()
2114
1958
  gc.collect()
@@ -2121,36 +1965,62 @@ def identify_masks_finetune(settings):
2121
1965
  from .utils import get_files_from_dir, resize_images_and_labels
2122
1966
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
2123
1967
 
1968
+ #User defined settings
2124
1969
  src=settings['src']
2125
1970
  dst=settings['dst']
1971
+
1972
+
1973
+ settings.setdefault('model_name', 'cyto')
1974
+ settings.setdefault('custom_model', None)
1975
+ settings.setdefault('channels', [0,0])
1976
+ settings.setdefault('background', 100)
1977
+ settings.setdefault('remove_background', False)
1978
+ settings.setdefault('Signal_to_noise', 10)
1979
+ settings.setdefault('CP_prob', 0)
1980
+ settings.setdefault('diameter', 30)
1981
+ settings.setdefault('batch_size', 50)
1982
+ settings.setdefault('flow_threshold', 0.4)
1983
+ settings.setdefault('save', False)
1984
+ settings.setdefault('verbose', False)
1985
+ settings.setdefault('normalize', True)
1986
+ settings.setdefault('percentiles', None)
1987
+ settings.setdefault('circular', False)
1988
+ settings.setdefault('invert', False)
1989
+ settings.setdefault('resize', False)
1990
+ settings.setdefault('target_height', None)
1991
+ settings.setdefault('target_width', None)
1992
+ settings.setdefault('rescale', False)
1993
+ settings.setdefault('resample', False)
1994
+ settings.setdefault('grayscale', True)
1995
+
1996
+
2126
1997
  model_name=settings['model_name']
1998
+ custom_model=settings['custom_model']
1999
+ channels = settings['channels']
2000
+ background = settings['background']
2001
+ remove_background=settings['remove_background']
2002
+ Signal_to_noise = settings['Signal_to_noise']
2003
+ CP_prob = settings['CP_prob']
2127
2004
  diameter=settings['diameter']
2128
2005
  batch_size=settings['batch_size']
2129
2006
  flow_threshold=settings['flow_threshold']
2130
- cellprob_threshold=settings['cellprob_threshold']
2131
-
2132
- verbose=settings['verbose']
2133
- plot=settings['plot']
2134
2007
  save=settings['save']
2135
- custom_model=settings['custom_model']
2136
- overlay=settings['overlay']
2008
+ verbose=settings['verbose']
2137
2009
 
2138
- figuresize=25
2139
- cmap='inferno'
2140
- channels = [0,0]
2141
- signal_thresholds = 1000
2142
- normalize = True
2143
- percentiles = [2,98]
2144
- circular = False
2145
- invert = False
2146
- resize = False
2147
- settings['width_height'] = [1000,1000]
2148
- target_height = settings['width_height'][1]
2149
- target_width = settings['width_height'][0]
2150
- rescale = False
2151
- resample = False
2152
- grayscale = True
2153
- test = False
2010
+ # static settings
2011
+ normalize = settings['normalize']
2012
+ percentiles = settings['percentiles']
2013
+ circular = settings['circular']
2014
+ invert = settings['invert']
2015
+ resize = settings['resize']
2016
+
2017
+ if resize:
2018
+ target_height = settings['target_height']
2019
+ target_width = settings['target_width']
2020
+
2021
+ rescale = settings['rescale']
2022
+ resample = settings['resample']
2023
+ grayscale = settings['grayscale']
2154
2024
 
2155
2025
  os.makedirs(dst, exist_ok=True)
2156
2026
 
@@ -2179,7 +2049,7 @@ def identify_masks_finetune(settings):
2179
2049
  print(f'Using channels: {chans} for model of type {model_name}')
2180
2050
 
2181
2051
  if verbose == True:
2182
- print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
2052
+ print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
2183
2053
 
2184
2054
  all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
2185
2055
 
@@ -2188,9 +2058,9 @@ def identify_masks_finetune(settings):
2188
2058
  time_ls = []
2189
2059
  for i in range(0, len(all_image_files), batch_size):
2190
2060
  image_files = all_image_files[i:i+batch_size]
2191
-
2061
+
2192
2062
  if normalize:
2193
- images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
2063
+ images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
2194
2064
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2195
2065
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2196
2066
  else:
@@ -2208,7 +2078,7 @@ def identify_masks_finetune(settings):
2208
2078
  channel_axis=3,
2209
2079
  diameter=diameter,
2210
2080
  flow_threshold=flow_threshold,
2211
- cellprob_threshold=cellprob_threshold,
2081
+ cellprob_threshold=CP_prob,
2212
2082
  rescale=rescale,
2213
2083
  resample=resample,
2214
2084
  progress=True)
@@ -2229,11 +2099,12 @@ def identify_masks_finetune(settings):
2229
2099
  time_ls.append(duration)
2230
2100
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2231
2101
  print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
2232
- if plot:
2102
+ if verbose:
2233
2103
  if resize:
2234
2104
  stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
2235
- print_mask_and_flows(stack, mask, flows, overlay=overlay)
2105
+ print_mask_and_flows(stack, mask, flows, overlay=True)
2236
2106
  if save:
2107
+ os.makedirs(dst, exist_ok=True)
2237
2108
  output_filename = os.path.join(dst, image_names[file_index])
2238
2109
  cv2.imwrite(output_filename, mask)
2239
2110
  return
@@ -2375,8 +2246,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2375
2246
  stitch_threshold=0.0
2376
2247
 
2377
2248
  cellpose_batch_size = _get_cellpose_batch_size()
2378
-
2379
- #model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
2380
2249
 
2381
2250
  masks, flows, _, _ = model.eval(x=batch,
2382
2251
  batch_size=cellpose_batch_size,
@@ -2450,9 +2319,21 @@ def all_elements_match(list1, list2):
2450
2319
  # Check if all elements in list1 are in list2
2451
2320
  return all(element in list2 for element in list1)
2452
2321
 
2322
+ def prepare_batch_for_cellpose(batch):
2323
+ # Ensure the batch is of dtype float32
2324
+ if batch.dtype != np.float32:
2325
+ batch = batch.astype(np.float32)
2326
+
2327
+ # Normalize each image in the batch
2328
+ for i in range(batch.shape[0]):
2329
+ if batch[i].max() > 1:
2330
+ batch[i] = batch[i] / batch[i].max()
2331
+
2332
+ return batch
2333
+
2453
2334
  def generate_cellpose_masks(src, settings, object_type):
2454
2335
 
2455
- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
2336
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
2456
2337
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2457
2338
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2458
2339
  from .plot import plot_masks
@@ -2460,6 +2341,13 @@ def generate_cellpose_masks(src, settings, object_type):
2460
2341
  gc.collect()
2461
2342
  if not torch.cuda.is_available():
2462
2343
  print(f'Torch CUDA is not available, using CPU')
2344
+
2345
+ settings = set_default_settings_preprocess_generate_masks(src, settings)
2346
+
2347
+ if settings['verbose']:
2348
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
2349
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
2350
+ display(settings_df)
2463
2351
 
2464
2352
  figuresize=25
2465
2353
  timelapse = settings['timelapse']
@@ -2474,8 +2362,9 @@ def generate_cellpose_masks(src, settings, object_type):
2474
2362
 
2475
2363
  batch_size = settings['batch_size']
2476
2364
  cellprob_threshold = settings[f'{object_type}_CP_prob']
2477
- flow_threshold = 30
2478
-
2365
+
2366
+ flow_threshold = settings[f'{object_type}_FT']
2367
+
2479
2368
  object_settings = _get_object_settings(object_type, settings)
2480
2369
  model_name = object_settings['model_name']
2481
2370
 
@@ -2486,7 +2375,12 @@ def generate_cellpose_masks(src, settings, object_type):
2486
2375
  channels = cellpose_channels[object_type]
2487
2376
  cellpose_batch_size = _get_cellpose_batch_size()
2488
2377
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2489
- model = _choose_model(model_name, device, object_type='cell', restore_type=None)
2378
+
2379
+ if object_type == 'pathogen' and not settings['pathogen_model'] is None:
2380
+ model_name = settings['pathogen_model']
2381
+
2382
+ model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)
2383
+
2490
2384
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
2491
2385
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
2492
2386
 
@@ -2505,6 +2399,14 @@ def generate_cellpose_masks(src, settings, object_type):
2505
2399
  with np.load(path) as data:
2506
2400
  stack = data['data']
2507
2401
  filenames = data['filenames']
2402
+
2403
+ for i, filename in enumerate(filenames):
2404
+ output_path = os.path.join(output_folder, filename)
2405
+
2406
+ if os.path.exists(output_path):
2407
+ print(f"File {filename} already exists in the output folder. Skipping...")
2408
+ continue
2409
+
2508
2410
  if settings['timelapse']:
2509
2411
 
2510
2412
  trackable_objects = ['cell','nucleus','pathogen']
@@ -2539,31 +2441,43 @@ def generate_cellpose_masks(src, settings, object_type):
2539
2441
  if batch.size == 0:
2540
2442
  print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2541
2443
  continue
2542
- if batch.max() > 1:
2543
- batch = batch / batch.max()
2444
+
2445
+ batch = prepare_batch_for_cellpose(batch)
2544
2446
 
2545
2447
  if timelapse:
2546
- stitch_threshold=100.0
2547
2448
  movie_path = os.path.join(os.path.dirname(src), 'movies')
2548
2449
  os.makedirs(movie_path, exist_ok=True)
2549
2450
  save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
2550
2451
  _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2551
- else:
2552
- stitch_threshold=0.0
2553
-
2554
- print('batch.shape',batch.shape)
2555
- masks, flows, _, _ = model.eval(x=batch,
2556
- batch_size=cellpose_batch_size,
2557
- normalize=False,
2558
- channels=chans,
2559
- channel_axis=3,
2560
- diameter=object_settings['diameter'],
2561
- flow_threshold=flow_threshold,
2562
- cellprob_threshold=cellprob_threshold,
2563
- rescale=None,
2564
- resample=object_settings['resample'],
2565
- stitch_threshold=stitch_threshold)
2566
2452
 
2453
+ if settings['verbose']:
2454
+ print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2455
+
2456
+ #cellpose_normalize_dict = {'lowhigh':[0.0,1.0], #pass in normalization values for 0.0 and 1.0 as list [low, high] if None all other keys ignored
2457
+ # 'sharpen':object_settings['diameter']/4, #recommended to be 1/4-1/8 diameter of cells in pixels
2458
+ # 'normalize':True, #(if False, all following parameters ignored)
2459
+ # 'percentile':[2,98], #[perc_low, perc_high]
2460
+ # 'tile_norm':224, #normalize by tile set to e.g. 100 for normailize window to be 100 px
2461
+ # 'norm3D':True} #compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
2462
+
2463
+ output = model.eval(x=batch,
2464
+ batch_size=cellpose_batch_size,
2465
+ normalize=False,
2466
+ channels=chans,
2467
+ channel_axis=3,
2468
+ diameter=object_settings['diameter'],
2469
+ flow_threshold=flow_threshold,
2470
+ cellprob_threshold=cellprob_threshold,
2471
+ rescale=None,
2472
+ resample=object_settings['resample'])
2473
+
2474
+ if len(output) == 4:
2475
+ masks, flows, _, _ = output
2476
+ elif len(output) == 3:
2477
+ masks, flows, _ = output
2478
+ else:
2479
+ raise ValueError(f"Unexpected number of return values from model.eval(). Expected 3 or 4, got {len(output)}")
2480
+
2567
2481
  if timelapse:
2568
2482
  if settings['plot']:
2569
2483
  for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
@@ -2676,15 +2590,15 @@ def generate_cellpose_masks(src, settings, object_type):
2676
2590
  torch.cuda.empty_cache()
2677
2591
  return
2678
2592
 
2679
- def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose):
2593
+ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
2594
+
2680
2595
  from .io import _load_images_and_labels, _load_normalized_images_and_labels
2681
2596
  from .utils import resize_images_and_labels, resizescikit
2682
2597
  from .plot import print_mask_and_flows
2683
2598
 
2684
2599
  dst = os.path.join(src, model_name)
2685
2600
  os.makedirs(dst, exist_ok=True)
2686
-
2687
- flow_threshold = 30
2601
+
2688
2602
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
2689
2603
 
2690
2604
  if grayscale:
@@ -2692,7 +2606,6 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
2692
2606
 
2693
2607
  all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
2694
2608
  random.shuffle(all_image_files)
2695
-
2696
2609
 
2697
2610
  if verbose == True:
2698
2611
  print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
@@ -2702,11 +2615,11 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
2702
2615
  image_files = all_image_files[i:i+batch_size]
2703
2616
 
2704
2617
  if normalize:
2705
- images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=100, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
2618
+ images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
2706
2619
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2707
2620
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2708
2621
  else:
2709
- images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
2622
+ images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)
2710
2623
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2711
2624
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2712
2625
  if resize:
@@ -2723,7 +2636,7 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
2723
2636
  cellprob_threshold=cellprob_threshold,
2724
2637
  rescale=False,
2725
2638
  resample=False,
2726
- progress=True)
2639
+ progress=False)
2727
2640
 
2728
2641
  if len(output) == 4:
2729
2642
  mask, flows, _, _ = output
@@ -2753,22 +2666,31 @@ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellp
2753
2666
  def check_cellpose_models(settings):
2754
2667
 
2755
2668
  src = settings['src']
2756
- batch_size = settings['batch_size']
2757
- cellprob_threshold = settings['cellprob_threshold']
2758
- save = settings['save']
2759
- normalize = settings['normalize']
2760
- channels = settings['channels']
2761
- percentiles = settings['percentiles']
2762
- circular = settings['circular']
2763
- invert = settings['invert']
2764
- plot = settings['plot']
2765
- diameter = settings['diameter']
2766
- resize = settings['resize']
2767
- grayscale = settings['grayscale']
2768
- verbose = settings['verbose']
2769
- target_height = settings['width_height'][0]
2770
- target_width = settings['width_height'][1]
2771
-
2669
+ settings.setdefault('batch_size', 10)
2670
+ settings.setdefault('CP_prob', 0)
2671
+ settings.setdefault('flow_threshold', 0.4)
2672
+ settings.setdefault('save', True)
2673
+ settings.setdefault('normalize', True)
2674
+ settings.setdefault('channels', [0,0])
2675
+ settings.setdefault('percentiles', None)
2676
+ settings.setdefault('circular', False)
2677
+ settings.setdefault('invert', False)
2678
+ settings.setdefault('plot', True)
2679
+ settings.setdefault('diameter', 40)
2680
+ settings.setdefault('grayscale', True)
2681
+ settings.setdefault('remove_background', False)
2682
+ settings.setdefault('background', 100)
2683
+ settings.setdefault('Signal_to_noise', 5)
2684
+ settings.setdefault('verbose', False)
2685
+ settings.setdefault('resize', False)
2686
+ settings.setdefault('target_height', None)
2687
+ settings.setdefault('target_width', None)
2688
+
2689
+ if settings['verbose']:
2690
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
2691
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
2692
+ display(settings_df)
2693
+
2772
2694
  cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
2773
2695
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2774
2696
 
@@ -2776,149 +2698,22 @@ def check_cellpose_models(settings):
2776
2698
 
2777
2699
  model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
2778
2700
  print(f'Using {model_name}')
2779
- generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose)
2780
-
2781
- return
2782
-
2783
- def compare_masks_v1(dir1, dir2, dir3, verbose=False):
2784
-
2785
- from .io import _read_mask
2786
- from .plot import visualize_masks, plot_comparison_results
2787
- from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
2788
-
2789
- filenames = os.listdir(dir1)
2790
- results = []
2791
- cond_1 = os.path.basename(dir1)
2792
- cond_2 = os.path.basename(dir2)
2793
- cond_3 = os.path.basename(dir3)
2794
-
2795
- for index, filename in enumerate(filenames):
2796
- print(f'Processing image:{index+1}', end='\r', flush=True)
2797
- path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
2798
-
2799
- print(path1)
2800
- print(path2)
2801
- print(path3)
2802
-
2803
- if os.path.exists(path2) and os.path.exists(path3):
2804
-
2805
- mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
2806
- boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
2807
-
2808
-
2809
- true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
2810
- true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
2811
- average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
2812
- ap_scores = [average_precision_0, average_precision_1]
2813
-
2814
- if verbose:
2815
- #unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
2816
- #print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
2817
- visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
2818
-
2819
- boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
2820
-
2821
- if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
2822
- (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
2823
- (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
2824
- continue
2825
-
2826
- if verbose:
2827
- #unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
2828
- #print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
2829
- visualize_masks(mask1, mask2, mask3, title=filename)
2830
-
2831
- jaccard12 = jaccard_index(mask1, mask2)
2832
- dice12 = dice_coefficient(mask1, mask2)
2833
-
2834
- jaccard13 = jaccard_index(mask1, mask3)
2835
- dice13 = dice_coefficient(mask1, mask3)
2836
-
2837
- jaccard23 = jaccard_index(mask2, mask3)
2838
- dice23 = dice_coefficient(mask2, mask3)
2839
-
2840
- results.append({
2841
- f'filename': filename,
2842
- f'jaccard_{cond_1}_{cond_2}': jaccard12,
2843
- f'dice_{cond_1}_{cond_2}': dice12,
2844
- f'jaccard_{cond_1}_{cond_3}': jaccard13,
2845
- f'dice_{cond_1}_{cond_3}': dice13,
2846
- f'jaccard_{cond_2}_{cond_3}': jaccard23,
2847
- f'dice_{cond_2}_{cond_3}': dice23,
2848
- f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
2849
- f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
2850
- f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
2851
- f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
2852
- f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
2853
- })
2854
- else:
2855
- print(f'Cannot find {path1} or {path2} or {path3}')
2856
- fig = plot_comparison_results(results)
2857
- return results, fig
2858
-
2859
- def compare_cellpose_masks_v1(src, verbose=False):
2860
- from .io import _read_mask
2861
- from .plot import visualize_masks, plot_comparison_results, visualize_cellpose_masks
2862
- from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
2863
-
2864
- import os
2865
- import numpy as np
2866
- from skimage.measure import label
2867
-
2868
- # Collect all subdirectories in src
2869
- dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
2870
-
2871
- dirs.sort() # Optional: sort directories if needed
2872
-
2873
- # Get common files in all directories
2874
- common_files = set(os.listdir(dirs[0]))
2875
- for d in dirs[1:]:
2876
- common_files.intersection_update(os.listdir(d))
2877
- common_files = list(common_files)
2878
-
2879
- results = []
2880
- conditions = [os.path.basename(d) for d in dirs]
2881
-
2882
- for index, filename in enumerate(common_files):
2883
- print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
2884
- paths = [os.path.join(d, filename) for d in dirs]
2701
+ generate_masks_from_imgs(src, model, model_name, settings['batch_size'], settings['diameter'], settings['CP_prob'], settings['flow_threshold'], settings['grayscale'], settings['save'], settings['normalize'], settings['channels'], settings['percentiles'], settings['circular'], settings['invert'], settings['plot'], settings['resize'], settings['target_height'], settings['target_width'], settings['remove_background'], settings['background'], settings['Signal_to_noise'], settings['verbose'])
2885
2702
 
2886
- # Check if file exists in all directories
2887
- if not all(os.path.exists(path) for path in paths):
2888
- print(f'Skipping {filename} as it is not present in all directories.')
2889
- continue
2890
-
2891
- masks = [_read_mask(path) for path in paths]
2892
- boundaries = [extract_boundaries(mask) for mask in masks]
2893
-
2894
- if verbose:
2895
- visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
2896
-
2897
- # Initialize data structure for results
2898
- file_results = {'filename': filename}
2899
-
2900
- # Compare each mask with each other
2901
- for i in range(len(masks)):
2902
- for j in range(i + 1, len(masks)):
2903
- condition_i = conditions[i]
2904
- condition_j = conditions[j]
2905
- mask_i = masks[i]
2906
- mask_j = masks[j]
2907
-
2908
- # Compute metrics
2909
- boundary_f1 = boundary_f1_score(mask_i, mask_j)
2910
- jaccard = jaccard_index(mask_i, mask_j)
2911
- average_precision = compute_segmentation_ap(mask_i, mask_j)
2703
+ return
2912
2704
 
2913
- # Store results
2914
- file_results[f'jaccard_{condition_i}_{condition_j}'] = jaccard
2915
- file_results[f'boundary_f1_{condition_i}_{condition_j}'] = boundary_f1
2916
- file_results[f'average_precision_{condition_i}_{condition_j}'] = average_precision
2705
def save_results_and_figure(src, fig, results):
    """
    Write a comparison results table and its figure into ``<src>/results``.

    Parameters:
    src (str): Directory under which the ``results`` sub-directory is created (if missing).
    fig: Figure-like object exposing ``savefig(path, format=...)``.
    results (pd.DataFrame or list of dict): Results table; anything that is not already a
        DataFrame is converted with ``pd.DataFrame``.

    Writes ``results.csv`` (without the index) and ``model_comparison_plot.pdf`` into the
    results directory, then prints both paths.
    """
    table = results if isinstance(results, pd.DataFrame) else pd.DataFrame(results)

    out_dir = os.path.join(src, 'results')
    os.makedirs(out_dir, exist_ok=True)

    csv_file = os.path.join(out_dir, 'results.csv')
    pdf_file = os.path.join(out_dir, 'model_comparison_plot.pdf')

    # CSV first, then the figure -- same order as before.
    table.to_csv(csv_file, index=False)
    fig.savefig(pdf_file, format='pdf')
    print(f'Saved figure to {pdf_file} and results to {csv_file}')
2922
2717
 
2923
2718
  def compare_mask(args):
2924
2719
  src, filename, dirs, conditions = args
@@ -2949,10 +2744,11 @@ def compare_mask(args):
2949
2744
 
2950
2745
  return file_results
2951
2746
 
2952
- def compare_cellpose_masks(src, verbose=False, processes=None):
2747
+ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
2953
2748
  from .plot import visualize_cellpose_masks, plot_comparison_results
2954
2749
  from .io import _read_mask
2955
- dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
2750
+
2751
+ dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
2956
2752
  dirs.sort() # Optional: sort directories if needed
2957
2753
  conditions = [os.path.basename(d) for d in dirs]
2958
2754
 
@@ -2969,16 +2765,16 @@ def compare_cellpose_masks(src, verbose=False, processes=None):
2969
2765
 
2970
2766
  # Filter out None results (from skipped files)
2971
2767
  results = [res for res in results if res is not None]
2972
-
2768
+ #print(results)
2973
2769
  if verbose:
2974
2770
  for result in results:
2975
2771
  filename = result['filename']
2976
2772
  masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
2977
- visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
2773
+ visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)
2978
2774
 
2979
2775
  fig = plot_comparison_results(results)
2980
- return results, fig
2981
-
2776
+ save_results_and_figure(src, fig, results)
2777
+ return
2982
2778
 
2983
2779
  def _calculate_similarity(df, features, col_to_compare, val1, val2):
2984
2780
  """
@@ -3060,6 +2856,8 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
3060
2856
  pandas.DataFrame: DataFrame containing the importances and standard deviations.
3061
2857
  """
3062
2858
 
2859
+ from .utils import filter_dataframe_features
2860
+
3063
2861
  if 'cells_per_well' in df.columns:
3064
2862
  df = df.drop(columns=['cells_per_well'])
3065
2863
 
@@ -3074,33 +2872,12 @@ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col'
3074
2872
  # Combine the subsets for analysis
3075
2873
  combined_df = pd.concat([df1, df2])
3076
2874
 
3077
- # Automatically select numerical features
3078
- features = combined_df.select_dtypes(include=[np.number]).columns.tolist()
3079
- features.remove('target')
3080
-
3081
- if clean:
3082
- combined_df = combined_df.loc[:, combined_df.nunique() > 1]
3083
- features = [feature for feature in features if feature in combined_df.columns]
3084
-
3085
- if feature_string is not None:
3086
- feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
3087
-
3088
- # Remove feature_string from the list if it exists
3089
- if feature_string in feature_list:
3090
- feature_list.remove(feature_string)
3091
-
3092
- features = [feature for feature in features if feature_string in feature]
2875
+ if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
2876
+ channel_of_interest = int(feature_string.split('_')[-1])
2877
+ elif not feature_string is 'morphology':
2878
+ channel_of_interest = 'morphology'
3093
2879
 
3094
- # Iterate through the list and remove columns from df
3095
- for feature_ in feature_list:
3096
- features = [feature for feature in features if feature_ not in feature]
3097
- print(f'After removing {feature_} features: {len(features)}')
3098
-
3099
- if exclude:
3100
- if isinstance(exclude, list):
3101
- features = [feature for feature in features if feature not in exclude]
3102
- else:
3103
- features.remove(exclude)
2880
+ _, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)
3104
2881
 
3105
2882
  X = combined_df[features]
3106
2883
  y = combined_df['target']
@@ -3333,4 +3110,363 @@ def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot',
3333
3110
  else:
3334
3111
  plt.show()
3335
3112
 
3336
- return balanced_df
3113
+ return balanced_df
3114
+
3115
def generate_image_umap(settings=None):
    """
    Generate a UMAP or tSNE embedding of object measurements and visualize it with clustering.

    The tables named in ``settings['tables']`` (plus ``png_list``) are read from every database
    returned by ``get_db_paths(settings['src'])`` and concatenated. A ``cond`` column is derived
    from ``col`` with ``map_condition``. Rows may then be excluded or subsampled before the
    dimensionality reduction and clustering.

    Parameters:
    settings (dict, optional): Settings dictionary. It is passed through
        ``get_umap_image_settings`` (presumably to fill in defaults; verify) and is modified in
        place. ``None`` (the default) starts from a fresh empty dictionary. Keys used here:
        src (str or list): Source directory/directories containing the data. The first one also
            receives the ``settings`` and ``results`` output folders.
        tables (list): List of table names to read from the database.
        neg, pos, mix: ``col`` values of the negative-control, positive-control and mixed wells.
        exclude_conditions (str or list): ``cond`` values ('neg', 'pos', 'mix', 'screen') to drop.
        row_limit (int): Randomly sample at most this many rows (``None`` keeps all rows).
        embedding_by_controls (bool): Fit the reducer on control rows only, then apply it to all rows.
        col_to_compare (str): Column used to select the control rows for control-based embedding.
        filter_by, remove_highly_correlated, log_data, exclude: Forwarded to ``preprocess_data``.
        n_neighbors, min_dist, metric, eps, min_samples, n_jobs: Reduction/clustering parameters.
        clustering (str): Clustering method ('DBSCAN' or 'KMeans').
        reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
        resnet_features (bool): Build the embedding from image features via
            ``generate_umap_from_images`` instead of the measurement tables.
        remove_cluster_noise (bool): Pass embedding and labels through ``remove_noise``.
        color_by (str): Column whose values replace the cluster labels for colouring. This also
            disables remove_cluster_noise, plot_outlines and smooth_lines.
        image_nr, img_zoom, plot_by_cluster, plot_outlines, plot_points, plot_images,
        smooth_lines, black_background, figuresize, dot_size, remove_image_canvas:
            Forwarded to ``plot_embedding``. black_background is forced off when plot_images is False.
        plot_cluster_grids (bool): Also plot a cluster image grid (only when plot_images is set).
        save_figure (bool): Save the figure(s) as PDF in ``<src[0]>/results``.
        analyze_clusters (bool): Run ``cluster_feature_analysis`` and save its table.
        verbose (bool): Whether to print verbose output.

    Returns:
    pd.DataFrame: The joined data with an added 'cond' column and a 'cluster' column. 'cluster'
        holds the cluster labels, or the ``color_by`` values when ``color_by`` is set.

    Raises:
    ValueError: If no database is found for ``settings['src']``.
    """

    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings
    from .alpha import cluster_feature_analysis, generate_umap_from_images

    # A fresh dict per call: settings are mutated below, so a shared mutable default
    # would leak values from one call into the next.
    if settings is None:
        settings = {}

    settings = get_umap_image_settings(settings)

    if isinstance(settings['src'], str):
        settings['src'] = [settings['src']]

    if settings['plot_images'] is False:
        settings['black_background'] = False

    if settings['color_by']:
        # Colouring by a metadata column replaces cluster colours, so cluster decorations are disabled.
        settings['remove_cluster_noise'] = False
        settings['plot_outlines'] = False
        settings['smooth_lines'] = False

    # Persist the effective settings next to the data for reproducibility.
    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_dir = os.path.join(settings['src'][0], 'settings')
    settings_csv = os.path.join(settings_dir, 'embedding_settings.csv')
    os.makedirs(settings_dir, exist_ok=True)
    settings_df.to_csv(settings_csv, index=False)
    display(settings_df)

    results_dir = os.path.join(settings['src'][0], 'results')

    # Load and concatenate the tables of every database (one concat instead of one per database).
    db_paths = get_db_paths(settings['src'])
    tables = settings['tables'] + ['png_list']
    frames = []
    for i, db_path in enumerate(db_paths):
        df = _read_and_join_tables(db_path, table_names=tables)
        df, _ = correct_paths(df, settings['src'][i])
        frames.append(df)
    if not frames:
        raise ValueError(f"No databases found for src: {settings['src']}")
    all_df = pd.concat(frames, axis=0)

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        # .copy() so that adding the 'cluster' column later does not act on a slice.
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])].copy()
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        # DataFrame.sample raises when n exceeds the number of rows; cap it instead.
        all_df = all_df.sample(n=min(settings['row_limit'], len(all_df)), random_state=42)

    image_paths = all_df['png_path'].to_list()

    if settings['embedding_by_controls']:
        compare_col = settings['col_to_compare']

        # Reset the index so the comparison column aligns positionally with the preprocessed data.
        col_to_compare = all_df[compare_col].reset_index(drop=True)

        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

        # Work on a DataFrame copy so numeric_data itself stays untouched for the final projection.
        numeric_data_df = pd.DataFrame(numeric_data).reset_index(drop=True)
        numeric_data_df[compare_col] = col_to_compare

        # Control rows: positive controls first, then negative controls.
        positive_control_df = numeric_data_df[numeric_data_df[compare_col] == settings['pos']]
        negative_control_df = numeric_data_df[numeric_data_df[compare_col] == settings['neg']]
        control_numeric_data = pd.concat([positive_control_df, negative_control_df]).drop(columns=[compare_col]).values

        # Train the reducer on control data only ...
        _, _, reducer = reduction_and_clustering(control_numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode='fit', model=False)

        # ... then apply the trained reducer to the entire (already preprocessed) dataset.
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode=None, model=reducer)

    elif settings['resnet_features']:
        numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])

    else:
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'])

    if settings['remove_cluster_noise']:
        # Remove noise from the clusters (-1 labels from DBSCAN).
        # NOTE(review): if remove_noise drops noise points rather than relabelling them,
        # embedding/labels become shorter than all_df and image_paths, and the
        # all_df['cluster'] assignment below will raise -- confirm remove_noise's contract.
        embedding, labels = remove_noise(embedding, labels)

    if settings['color_by']:
        labels = all_df[settings['color_by']]

    # Plot the embedding, one colour per distinct label.
    colors = generate_colors(len(np.unique(labels)), settings['black_background'])
    umap_plt = plot_embedding(embedding, image_paths, labels, settings['image_nr'], settings['img_zoom'], colors, settings['plot_by_cluster'], settings['plot_outlines'], settings['plot_points'], settings['plot_images'], settings['smooth_lines'], settings['black_background'], settings['figuresize'], settings['dot_size'], settings['remove_image_canvas'], settings['verbose'])

    plot_grid = settings['plot_cluster_grids'] and settings['plot_images']
    if plot_grid:
        grid_plt = plot_clusters_grid(embedding, labels, settings['image_nr'], image_paths, colors, settings['figuresize'], settings['black_background'], settings['verbose'])

    if settings['save_figure']:
        os.makedirs(results_dir, exist_ok=True)
        reduction_method = settings['reduction_method'].upper()
        embedding_path = os.path.join(results_dir, f'{reduction_method}_embedding.pdf')
        umap_plt.savefig(embedding_path, format='pdf')
        print(f'Saved {reduction_method} embedding to {embedding_path}')
        if plot_grid:
            grid_path = os.path.join(results_dir, f'{reduction_method}_grid.pdf')
            grid_plt.savefig(grid_path, format='pdf')
            print(f'Saved {reduction_method} grid to {grid_path}')

    # Add cluster labels (or the color_by values) to the dataframe and save it.
    all_df['cluster'] = labels

    os.makedirs(results_dir, exist_ok=True)
    results_csv = os.path.join(results_dir, 'embedding_results.csv')
    all_df.to_csv(results_csv, index=False)
    print(f'Results saved to {results_csv}')

    if settings['analyze_clusters']:
        combined_results = cluster_feature_analysis(all_df)
        cluster_results_csv = os.path.join(results_dir, 'cluster_results.csv')
        combined_results.to_csv(cluster_results_csv, index=False)
        print(f'Cluster results saved to {cluster_results_csv}')

    return all_df
3296
+
3297
def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
    """
    Map a column identifier value to a condition label.

    Parameters:
    col_value: Value to classify (compared with ``==``).
    neg: Value identifying negative-control wells (default 'c1').
    pos: Value identifying positive-control wells (default 'c2').
    mix: Value identifying mixed wells (default 'c3').

    Returns:
    str: 'neg', 'pos' or 'mix' for the first matching control value, checked in that order.
        Returns 'screen' when nothing matches.
    """
    # Ordered pairs keep the original precedence when control values coincide.
    for control_value, condition in ((neg, 'neg'), (pos, 'pos'), (mix, 'mix')):
        if col_value == control_value:
            return condition
    return 'screen'
3307
+
3308
+ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
3309
+ """
3310
+ Perform a hyperparameter search for UMAP or tSNE on the given data.
3311
+
3312
+ Parameters:
3313
+ settings (dict): Dictionary containing the following keys:
3314
+ src (str): Source directory containing the data.
3315
+ row_limit (int): Limit the number of rows to process.
3316
+ tables (list): List of table names to read from the database.
3317
+ filter_by (str): Column to filter the data.
3318
+ sample_size (int): Number of samples to use for the hyperparameter search.
3319
+ remove_highly_correlated (bool): Whether to remove highly correlated columns.
3320
+ log_data (bool): Whether to log transform the data.
3321
+ verbose (bool): Whether to print verbose output.
3322
+ reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
3323
+ reduction_params (list): List of dictionaries containing hyperparameters to test for the reduction method.
3324
+ dbscan_params (list): List of dictionaries containing DBSCAN hyperparameters to test.
3325
+ kmeans_params (list): List of dictionaries containing KMeans hyperparameters to test.
3326
+ pointsize (int): Size of the points in the scatter plot.
3327
+ save (bool): Whether to save the resulting plot as a file.
3328
+
3329
+ Returns:
3330
+ None
3331
+ """
3332
+
3333
+ from .io import _read_and_join_tables
3334
+ from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings
3335
+
3336
+ settings = get_umap_image_settings(settings)
3337
+ pointsize = settings['dot_size']
3338
+ if isinstance(dbscan_params, dict):
3339
+ dbscan_params = [dbscan_params]
3340
+
3341
+ if isinstance(kmeans_params, dict):
3342
+ kmeans_params = [kmeans_params]
3343
+
3344
+ if isinstance(reduction_params, dict):
3345
+ reduction_params = [reduction_params]
3346
+
3347
+ # Determine reduction method based on the keys in reduction_param
3348
+ if any('n_neighbors' in param for param in reduction_params):
3349
+ reduction_method = 'umap'
3350
+ elif any('perplexity' in param for param in reduction_params):
3351
+ reduction_method = 'tsne'
3352
+ elif any('perplexity' in param for param in reduction_params) and any('n_neighbors' in param for param in reduction_params):
3353
+ raise ValueError("Reduction parameters must include 'n_neighbors' for UMAP or 'perplexity' for tSNE, not both.")
3354
+
3355
+ if settings['reduction_method'].lower() != reduction_method:
3356
+ settings['reduction_method'] = reduction_method
3357
+ print(f'Changed reduction method to {reduction_method} based on the provided parameters.')
3358
+
3359
+ if settings['verbose']:
3360
+ display(pd.DataFrame(list(settings.items()), columns=['Key', 'Value']))
3361
+
3362
+ db_paths = get_db_paths(settings['src'])
3363
+
3364
+ tables = settings['tables']
3365
+ all_df = pd.DataFrame()
3366
+ for db_path in db_paths:
3367
+ df = _read_and_join_tables(db_path, table_names=tables)
3368
+ all_df = pd.concat([all_df, df], axis=0)
3369
+
3370
+ all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])
3371
+
3372
+ if settings['exclude_conditions']:
3373
+ if isinstance(settings['exclude_conditions'], str):
3374
+ settings['exclude_conditions'] = [settings['exclude_conditions']]
3375
+ row_count_before = len(all_df)
3376
+ all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
3377
+ if settings['verbose']:
3378
+ print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')
3379
+
3380
+ if settings['row_limit'] is not None:
3381
+ all_df = all_df.sample(n=settings['row_limit'], random_state=42)
3382
+
3383
+ numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
3384
+
3385
+ # Combine DBSCAN and KMeans parameters
3386
+ clustering_params = []
3387
+ if dbscan_params:
3388
+ for param in dbscan_params:
3389
+ param['method'] = 'dbscan'
3390
+ clustering_params.append(param)
3391
+ if kmeans_params:
3392
+ for param in kmeans_params:
3393
+ param['method'] = 'kmeans'
3394
+ clustering_params.append(param)
3395
+
3396
+ print('Testing paramiters:', reduction_params)
3397
+ print('Testing clustering paramiters:', clustering_params)
3398
+
3399
+ # Calculate the grid size
3400
+ grid_rows = len(reduction_params)
3401
+ grid_cols = len(clustering_params)
3402
+
3403
+ fig_width = grid_cols*10
3404
+ fig_height = grid_rows*10
3405
+
3406
+ fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(fig_width, fig_height))
3407
+
3408
+ # Make sure axs is always an array of axes
3409
+ axs = np.atleast_1d(axs)
3410
+
3411
+ # Iterate through the Cartesian product of reduction and clustering hyperparameters
3412
+ for i, reduction_param in enumerate(reduction_params):
3413
+ for j, clustering_param in enumerate(clustering_params):
3414
+ if len(clustering_params) <= 1:
3415
+ axs[i].axis('off')
3416
+ ax = axs[i]
3417
+ elif len(reduction_params) <= 1:
3418
+ axs[j].axis('off')
3419
+ ax = axs[j]
3420
+ else:
3421
+ ax = axs[i, j]
3422
+
3423
+ # Perform dimensionality reduction and clustering
3424
+ if settings['reduction_method'].lower() == 'umap':
3425
+ n_neighbors = reduction_param.get('n_neighbors', 15)
3426
+
3427
+ if isinstance(n_neighbors, float):
3428
+ n_neighbors = int(n_neighbors * len(numeric_data))
3429
+
3430
+ min_dist = reduction_param.get('min_dist', 0.1)
3431
+ embedding, labels = search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, settings['metric'],
3432
+ clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
3433
+ clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])
3434
+
3435
+ elif settings['reduction_method'].lower() == 'tsne':
3436
+ perplexity = reduction_param.get('perplexity', 30)
3437
+
3438
+ if isinstance(perplexity, float):
3439
+ perplexity = int(perplexity * len(numeric_data))
3440
+
3441
+ embedding, labels = search_reduction_and_clustering(numeric_data, perplexity, 0.1, settings['metric'],
3442
+ clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
3443
+ clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])
3444
+
3445
+ else:
3446
+ raise ValueError(f"Unsupported reduction method: {settings['reduction_method']}. Supported methods are 'UMAP' and 'tSNE'")
3447
+
3448
+ # Plot the results
3449
+ if settings['color_by']:
3450
+ unique_groups = all_df[settings['color_by']].unique()
3451
+ colors = generate_colors(len(unique_groups), False)
3452
+ for group, color in zip(unique_groups, colors):
3453
+ indices = all_df[settings['color_by']] == group
3454
+ ax.scatter(embedding[indices, 0], embedding[indices, 1], s=pointsize, label=f"{group}", color=color)
3455
+ else:
3456
+ unique_labels = np.unique(labels)
3457
+ colors = generate_colors(len(unique_labels), False)
3458
+ for label, color in zip(unique_labels, colors):
3459
+ ax.scatter(embedding[labels == label, 0], embedding[labels == label, 1], s=pointsize, label=f"Cluster {label}", color=color)
3460
+
3461
+ ax.set_title(f"{settings['reduction_method']} {reduction_param}\n{clustering_param['method']} {clustering_param}")
3462
+ ax.legend()
3463
+
3464
+ plt.tight_layout()
3465
+ if save:
3466
+ results_dir = os.path.join(settings['src'], 'results')
3467
+ os.makedirs(results_dir, exist_ok=True)
3468
+ plt.savefig(os.path.join(results_dir, 'hyperparameter_search.pdf'))
3469
+ else:
3470
+ plt.show()
3471
+
3472
+ return