spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -1,12 +1,10 @@
1
- import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
1
+ import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
2
2
 
3
- # image and array processing
4
3
  import numpy as np
5
4
  import pandas as pd
6
5
 
7
- import cellpose
6
+ from cellpose import train
8
7
  from cellpose import models as cp_models
9
- from cellpose import denoise
10
8
 
11
9
  import statsmodels.formula.api as smf
12
10
  import statsmodels.api as sm
@@ -15,31 +13,37 @@ from IPython.display import display
15
13
  from multiprocessing import Pool, cpu_count, Value, Lock
16
14
 
17
15
  import seaborn as sns
18
- import matplotlib.pyplot as plt
16
+
19
17
  from skimage.measure import regionprops, label
20
- import skimage.measure as measure
18
+ from skimage.morphology import square
21
19
  from skimage.transform import resize as resizescikit
22
- from sklearn.model_selection import train_test_split
23
20
  from collections import defaultdict
24
- import multiprocessing
25
21
  from torch.utils.data import DataLoader, random_split
26
- import matplotlib
27
- matplotlib.use('Agg')
22
+ from sklearn.cluster import KMeans
23
+ from sklearn.decomposition import PCA
28
24
 
29
- import torchvision.transforms as transforms
25
+ from skimage import measure
30
26
  from sklearn.model_selection import train_test_split
31
- from sklearn.ensemble import IsolationForest
27
+ from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
28
+ from sklearn.linear_model import LogisticRegression
29
+ from sklearn.inspection import permutation_importance
30
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
31
+ from sklearn.preprocessing import StandardScaler
32
32
 
33
- from .logger import log_function_call
33
+ from scipy.ndimage import binary_dilation
34
+ from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
35
+
36
+ import torchvision.transforms as transforms
37
+ from xgboost import XGBClassifier
38
+ import shap
39
+
40
+ import matplotlib.pyplot as plt
41
+ import matplotlib
42
+ matplotlib.use('Agg')
43
+ #import matplotlib.pyplot as plt
34
44
 
35
- #from .io import TarImageDataset, NoClassDataset, MyDataset, read_db, _copy_missclassified, read_mask, load_normalized_images_and_labels, load_images_and_labels
36
- #from .plot import plot_merged, plot_arrays, _plot_controls, _plot_recruitment, _imshow, _plot_histograms_and_stats, _reg_v_plot, visualize_masks, plot_comparison_results
37
- #from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient, _object_filter
38
- #from .utils import resize_images_and_labels, generate_fraction_map, MLR, fishers_odds, lasso_reg, model_metrics, _map_wells_png, check_multicollinearity, init_globals, add_images_to_tar
39
- #from .utils import get_paths_from_db, pick_best_model, test_model_performance, evaluate_model_performance, compute_irm_penalty
40
- #from .utils import _pivot_counts_table, _generate_masks, _get_cellpose_channels, annotate_conditions, _calculate_recruitment, calculate_loss, _group_by_well, choose_model
45
+ from .logger import log_function_call
41
46
 
42
- @log_function_call
43
47
  def analyze_plaques(folder):
44
48
  summary_data = []
45
49
  details_data = []
@@ -76,171 +80,95 @@ def analyze_plaques(folder):
76
80
 
77
81
  print(f"Analysis completed and saved to database '{db_name}'.")
78
82
 
79
- @log_function_call
80
- def compare_masks(dir1, dir2, dir3, verbose=False):
81
-
82
- from .io import _read_mask
83
- from .plot import visualize_masks, plot_comparison_results
84
- from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
85
-
86
- filenames = os.listdir(dir1)
87
- results = []
88
- cond_1 = os.path.basename(dir1)
89
- cond_2 = os.path.basename(dir2)
90
- cond_3 = os.path.basename(dir3)
91
- for index, filename in enumerate(filenames):
92
- print(f'Processing image:{index+1}', end='\r', flush=True)
93
- path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
94
- if os.path.exists(path2) and os.path.exists(path3):
95
-
96
- mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
97
- boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
98
-
99
-
100
- true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
101
- true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
102
- average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
103
- ap_scores = [average_precision_0, average_precision_1]
104
-
105
- if verbose:
106
- unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
107
- print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
108
- visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
109
-
110
- boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
111
-
112
- if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
113
- (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
114
- (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
115
- continue
116
-
117
- if verbose:
118
- unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
119
- print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
120
- visualize_masks(mask1, mask2, mask3, title=filename)
121
-
122
- jaccard12 = jaccard_index(mask1, mask2)
123
- dice12 = dice_coefficient(mask1, mask2)
124
- jaccard13 = jaccard_index(mask1, mask3)
125
- dice13 = dice_coefficient(mask1, mask3)
126
- jaccard23 = jaccard_index(mask2, mask3)
127
- dice23 = dice_coefficient(mask2, mask3)
128
-
129
- results.append({
130
- f'filename': filename,
131
- f'jaccard_{cond_1}_{cond_2}': jaccard12,
132
- f'dice_{cond_1}_{cond_2}': dice12,
133
- f'jaccard_{cond_1}_{cond_3}': jaccard13,
134
- f'dice_{cond_1}_{cond_3}': dice13,
135
- f'jaccard_{cond_2}_{cond_3}': jaccard23,
136
- f'dice_{cond_2}_{cond_3}': dice23,
137
- f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
138
- f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
139
- f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
140
- f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
141
- f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
142
- })
143
- else:
144
- print(f'Cannot find {path1} or {path2} or {path3}')
145
- fig = plot_comparison_results(results)
146
- return results, fig
147
-
148
- def generate_cp_masks(settings):
149
-
150
- src = settings['src']
151
- model_name = settings['model_name']
152
- channels = settings['channels']
153
- diameter = settings['diameter']
154
- regex = '.tif'
155
- #flow_threshold = 30
156
- cellprob_threshold = settings['cellprob_threshold']
157
- figuresize = 25
158
- cmap = 'inferno'
159
- verbose = settings['verbose']
160
- plot = settings['plot']
161
- save = settings['save']
162
- custom_model = settings['custom_model']
163
- signal_thresholds = 1000
164
- normalize = settings['normalize']
165
- resize = settings['resize']
166
- target_height = settings['width_height'][1]
167
- target_width = settings['width_height'][0]
168
- rescale = settings['rescale']
169
- resample = settings['resample']
170
- net_avg = settings['net_avg']
171
- invert = settings['invert']
172
- circular = settings['circular']
173
- percentiles = settings['percentiles']
174
- overlay = settings['overlay']
175
- grayscale = settings['grayscale']
176
- flow_threshold = settings['flow_threshold']
177
- batch_size = settings['batch_size']
178
-
179
- dst = os.path.join(src,'masks')
180
- os.makedirs(dst, exist_ok=True)
181
-
182
- identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
183
-
184
- @log_function_call
185
83
  def train_cellpose(settings):
186
84
 
187
85
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
188
86
  from .utils import resize_images_and_labels
189
87
 
190
88
  img_src = settings['img_src']
191
- mask_src= settings['mask_src']
192
- secondary_image_dir = None
193
- model_name = settings['model_name']
194
- model_type = settings['model_type']
195
- learning_rate = settings['learning_rate']
196
- weight_decay = settings['weight_decay']
197
- batch_size = settings['batch_size']
198
- n_epochs = settings['n_epochs']
199
- verbose = settings['verbose']
200
- signal_thresholds = settings['signal_thresholds']
201
- channels = settings['channels']
202
- from_scratch = settings['from_scratch']
203
- diameter = settings['diameter']
204
- resize = settings['resize']
205
- rescale = settings['rescale']
206
- normalize = settings['normalize']
207
- target_height = settings['width_height'][1]
208
- target_width = settings['width_height'][0]
209
- circular = settings['circular']
210
- invert = settings['invert']
211
- percentiles = settings['percentiles']
212
- grayscale = settings['grayscale']
89
+ mask_src = os.path.join(img_src, 'masks')
213
90
 
91
+ model_name = settings.setdefault( 'model_name', '')
92
+
93
+ model_name = settings.setdefault('model_name', 'model_name')
94
+
95
+ model_type = settings.setdefault( 'model_type', 'cyto')
96
+ learning_rate = settings.setdefault( 'learning_rate', 0.01)
97
+ weight_decay = settings.setdefault( 'weight_decay', 1e-05)
98
+ batch_size = settings.setdefault( 'batch_size', 50)
99
+ n_epochs = settings.setdefault( 'n_epochs', 100)
100
+ from_scratch = settings.setdefault( 'from_scratch', False)
101
+ diameter = settings.setdefault( 'diameter', 40)
102
+
103
+ remove_background = settings.setdefault( 'remove_background', False)
104
+ background = settings.setdefault( 'background', 100)
105
+ Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
106
+ verbose = settings.setdefault( 'verbose', False)
107
+
108
+
109
+ channels = settings.setdefault( 'channels', [0,0])
110
+ normalize = settings.setdefault( 'normalize', True)
111
+ percentiles = settings.setdefault( 'percentiles', None)
112
+ circular = settings.setdefault( 'circular', False)
113
+ invert = settings.setdefault( 'invert', False)
114
+ resize = settings.setdefault( 'resize', False)
115
+
116
+ if resize:
117
+ target_height = settings['width_height'][1]
118
+ target_width = settings['width_height'][0]
119
+
120
+ grayscale = settings.setdefault( 'grayscale', True)
121
+ rescale = settings.setdefault( 'channels', False)
122
+ test = settings.setdefault( 'test', False)
123
+
124
+ if test:
125
+ test_img_src = os.path.join(os.path.dirname(img_src), 'test')
126
+ test_mask_src = os.path.join(test_img_src, 'mask')
127
+
128
+ test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
214
129
  print(settings)
215
130
 
216
131
  if from_scratch:
217
132
  model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
218
133
  else:
219
- model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
134
+ if resize:
135
+ model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
136
+ else:
137
+ model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
220
138
 
221
139
  model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
222
- os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
140
+ print(model_save_path)
141
+ os.makedirs(model_save_path, exist_ok=True)
223
142
 
224
143
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
225
144
  settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
226
145
  settings_df.to_csv(settings_csv, index=False)
227
146
 
228
- if model_type =='cyto':
229
- if not from_scratch:
230
- model = cp_models.CellposeModel(gpu=True, model_type=model_type)
231
- else:
232
- model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
233
- if model_type !='cyto':
147
+ if from_scratch:
148
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
149
+ else:
234
150
  model = cp_models.CellposeModel(gpu=True, model_type=model_type)
235
151
 
236
-
237
-
238
- if normalize:
239
- images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_dir=img_src, label_dir=mask_src, secondary_image_dir=secondary_image_dir, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
152
+ if normalize:
153
+
154
+ image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
155
+ label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
156
+ images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
240
157
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
158
+
159
+ if test:
160
+ test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
161
+ test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
162
+ test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
163
+ test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
164
+
241
165
  else:
242
166
  images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
243
167
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
168
+
169
+ if test:
170
+ test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
171
+ test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
244
172
 
245
173
  if resize:
246
174
  images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
@@ -259,29 +187,44 @@ def train_cellpose(settings):
259
187
 
260
188
  print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
261
189
  save_every = int(n_epochs/10)
262
- print('cellpose image input dtype', images[0].dtype)
263
- print('cellpose mask input dtype', masks[0].dtype)
264
- # Train the model
265
- model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
266
- train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
267
- train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
268
- channels=cp_channels, #(list of ints (default, None)) – channels to use for training
269
- normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
270
- save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
271
- save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
272
- learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
273
- n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
274
- weight_decay=weight_decay, #(float (default, 0.00001)) –
275
- SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
276
- batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
277
- nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
278
- rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
279
- min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
280
- model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
190
+ if save_every < 10:
191
+ save_every = n_epochs
192
+
193
+ train.train_seg(model.net,
194
+ train_data=images,
195
+ train_labels=masks,
196
+ train_files=image_names,
197
+ train_labels_files=mask_names,
198
+ train_probs=None,
199
+ test_data=test_images,
200
+ test_labels=test_masks,
201
+ test_files=test_image_names,
202
+ test_labels_files=test_mask_names,
203
+ test_probs=None,
204
+ load_files=True,
205
+ batch_size=batch_size,
206
+ learning_rate=learning_rate,
207
+ n_epochs=n_epochs,
208
+ weight_decay=weight_decay,
209
+ momentum=0.9,
210
+ SGD=False,
211
+ channels=cp_channels,
212
+ channel_axis=None,
213
+ #rgb=False,
214
+ normalize=False,
215
+ compute_flows=False,
216
+ save_path=model_save_path,
217
+ save_every=save_every,
218
+ nimg_per_epoch=None,
219
+ nimg_test_per_epoch=None,
220
+ rescale=rescale,
221
+ #scale_range=None,
222
+ #bsize=224,
223
+ min_train_masks=1,
224
+ model_name=model_name)
281
225
 
282
226
  return print(f"Model saved at: {model_save_path}/{model_name}")
283
227
 
284
- @log_function_call
285
228
  def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
286
229
 
287
230
  from .plot import _reg_v_plot
@@ -430,7 +373,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', t
430
373
 
431
374
  return result
432
375
 
433
- @log_function_call
434
376
  def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
435
377
 
436
378
  from .plot import _reg_v_plot
@@ -609,7 +551,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=5
609
551
 
610
552
  return max_effects, max_effects_pvalues, model, df
611
553
 
612
- @log_function_call
613
554
  def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
614
555
 
615
556
  from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
@@ -777,7 +718,6 @@ def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wel
777
718
 
778
719
  return
779
720
 
780
- @log_function_call
781
721
  def merge_pred_mes(src,
782
722
  pred_loc,
783
723
  target='protein of interest',
@@ -846,15 +786,6 @@ def merge_pred_mes(src,
846
786
 
847
787
  if verbose:
848
788
  _plot_histograms_and_stats(df=joined_df)
849
-
850
- #dv = joined_df.copy()
851
- #if 'prc' not in dv.columns:
852
- #dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
853
- #dv = dv[['pred']].groupby('prc').mean()
854
- #dv.set_index('prc', inplace=True)
855
-
856
- #loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
857
- #dv.to_csv(loc, index=True, header=True, mode='w')
858
789
 
859
790
  return joined_df
860
791
 
@@ -941,30 +872,38 @@ def annotate_results(pred_loc):
941
872
  display(df)
942
873
  return df
943
874
 
944
- def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=None):
875
+ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
945
876
 
946
- from .utils import init_globals, add_images_to_tar
947
-
948
- db_path = os.path.join(src, 'measurements','measurements.db')
877
+ from .utils import initiate_counter, add_images_to_tar
878
+
879
+ db_path = os.path.join(src, 'measurements', 'measurements.db')
949
880
  dst = os.path.join(src, 'datasets')
950
-
951
- global total_images
952
881
  all_paths = []
953
-
882
+
954
883
  # Connect to the database and retrieve the image paths
955
884
  print(f'Reading DataBase: {db_path}')
956
- with sqlite3.connect(db_path) as conn:
957
- cursor = conn.cursor()
958
- if file_type:
959
- cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_type}%",))
960
- else:
961
- cursor.execute("SELECT png_path FROM png_list")
962
- while True:
963
- rows = cursor.fetchmany(1000)
964
- if not rows:
965
- break
966
- all_paths.extend([row[0] for row in rows])
967
-
885
+ try:
886
+ with sqlite3.connect(db_path) as conn:
887
+ cursor = conn.cursor()
888
+ if file_metadata:
889
+ if isinstance(file_metadata, str):
890
+ cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
891
+ else:
892
+ cursor.execute("SELECT png_path FROM png_list")
893
+
894
+ while True:
895
+ rows = cursor.fetchmany(1000)
896
+ if not rows:
897
+ break
898
+ all_paths.extend([row[0] for row in rows])
899
+
900
+ except sqlite3.Error as e:
901
+ print(f"Database error: {e}")
902
+ return
903
+ except Exception as e:
904
+ print(f"Error: {e}")
905
+ return
906
+
968
907
  if isinstance(sample, int):
969
908
  selected_paths = random.sample(all_paths, sample)
970
909
  print(f'Random selection of {len(selected_paths)} paths')
@@ -972,23 +911,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
972
911
  selected_paths = all_paths
973
912
  random.shuffle(selected_paths)
974
913
  print(f'All paths: {len(selected_paths)} paths')
975
-
914
+
976
915
  total_images = len(selected_paths)
977
- print(f'found {total_images} images')
978
-
916
+ print(f'Found {total_images} images')
917
+
979
918
  # Create a temp folder in dst
980
919
  temp_dir = os.path.join(dst, "temp_tars")
981
920
  os.makedirs(temp_dir, exist_ok=True)
982
921
 
983
922
  # Chunking the data
984
- if len(selected_paths) > 10000:
985
- num_procs = cpu_count()-2
986
- chunk_size = len(selected_paths) // num_procs
987
- remainder = len(selected_paths) % num_procs
988
- else:
989
- num_procs = 2
990
- chunk_size = len(selected_paths) // 2
991
- remainder = 0
923
+ num_procs = max(2, cpu_count() - 2)
924
+ chunk_size = len(selected_paths) // num_procs
925
+ remainder = len(selected_paths) % num_procs
992
926
 
993
927
  paths_chunks = []
994
928
  start = 0
@@ -998,45 +932,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
998
932
  start = end
999
933
 
1000
934
  temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
1001
-
1002
- # Initialize the shared objects
1003
- counter_ = Value('i', 0)
1004
- lock_ = Lock()
1005
935
 
1006
- ctx = multiprocessing.get_context('spawn')
1007
-
1008
936
  print(f'Generating temporary tar files in {dst}')
1009
-
937
+
938
+ # Initialize shared counter and lock
939
+ counter = Value('i', 0)
940
+ lock = Lock()
941
+
942
+ with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
943
+ pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
944
+
1010
945
  # Combine the temporary tar files into a final tar
1011
946
  date_name = datetime.date.today().strftime('%y%m%d')
1012
- tar_name = f'{date_name}_{experiment}_{file_type}.tar'
947
+ if not file_metadata is None:
948
+ tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
949
+ else:
950
+ tar_name = f'{date_name}_{experiment}.tar'
951
+ tar_name = os.path.join(dst, tar_name)
1013
952
  if os.path.exists(tar_name):
1014
953
  number = random.randint(1, 100)
1015
- tar_name_2 = f'{date_name}_{experiment}_{file_type}_{number}.tar'
1016
- print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
1017
- tar_name = tar_name_2
1018
-
1019
- # Add the counter and lock to the arguments for pool.map
954
+ tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
955
+ print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
956
+ tar_name = os.path.join(dst, tar_name_2)
957
+
1020
958
  print(f'Merging temporary files')
1021
- #with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1022
- # results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1023
959
 
1024
- with ctx.Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1025
- results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1026
-
1027
- with tarfile.open(os.path.join(dst, tar_name), 'w') as final_tar:
1028
- for tar_path in results:
1029
- with tarfile.open(tar_path, 'r') as t:
1030
- for member in t.getmembers():
1031
- t.extract(member, path=dst)
1032
- final_tar.add(os.path.join(dst, member.name), arcname=member.name)
1033
- os.remove(os.path.join(dst, member.name))
1034
- os.remove(tar_path)
960
+ with tarfile.open(tar_name, 'w') as final_tar:
961
+ for temp_tar_path in temp_tar_files:
962
+ with tarfile.open(temp_tar_path, 'r') as temp_tar:
963
+ for member in temp_tar.getmembers():
964
+ file_obj = temp_tar.extractfile(member)
965
+ final_tar.addfile(member, file_obj)
966
+ os.remove(temp_tar_path)
1035
967
 
1036
968
  # Delete the temp folder
1037
969
  shutil.rmtree(temp_dir)
1038
- print(f"\nSaved {total_images} images to {os.path.join(dst, tar_name)}")
1039
-
970
+ print(f"\nSaved {total_images} images to {tar_name}")
971
+
1040
972
  def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
1041
973
 
1042
974
  from .io import TarImageDataset, DataLoader
@@ -1088,7 +1020,7 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
1088
1020
  batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
1089
1021
  prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
1090
1022
  filenames_list.extend(filenames)
1091
- print(f'\rbatch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
1023
+ print(f'batch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
1092
1024
 
1093
1025
  data = {'path':filenames_list, 'pred':prediction_pos_probs}
1094
1026
  df = pd.DataFrame(data, index=None)
@@ -1143,7 +1075,6 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
1143
1075
  torch.cuda.memory.empty_cache()
1144
1076
  return df
1145
1077
 
1146
-
1147
1078
  def generate_training_data_file_list(src,
1148
1079
  target='protein of interest',
1149
1080
  cell_dim=4,
@@ -1272,7 +1203,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1272
1203
 
1273
1204
  db_path = os.path.join(src, 'measurements','measurements.db')
1274
1205
  dst = os.path.join(src, 'datasets', 'training')
1275
-
1206
+
1207
+ if os.path.exists(dst):
1208
+ for i in range(1, 1000):
1209
+ dst = os.path.join(src, 'datasets', f'training_{i}')
1210
+ if not os.path.exists(dst):
1211
+ print(f'Creating new directory for training: {dst}')
1212
+ break
1213
+
1276
1214
  if mode == 'annotation':
1277
1215
  class_paths_ls_2 = []
1278
1216
  class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
@@ -1283,6 +1221,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1283
1221
 
1284
1222
  elif mode == 'metadata':
1285
1223
  class_paths_ls = []
1224
+ class_len_ls = []
1286
1225
  [df] = _read_db(db_loc=db_path, tables=['png_list'])
1287
1226
  df['metadata_based_class'] = pd.NA
1288
1227
  for i, class_ in enumerate(classes):
@@ -1290,7 +1229,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1290
1229
  df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
1291
1230
 
1292
1231
  for class_ in classes:
1232
+ if size == None:
1233
+ c_s = []
1234
+ for c in classes:
1235
+ c_s_t_df = df[df['metadata_based_class'] == c]
1236
+ c_s.append(len(c_s_t_df))
1237
+ print(f'Found {len(c_s_t_df)} images for class {c}')
1238
+ size = min(c_s)
1239
+ print(f'Using the smallest class size: {size}')
1240
+
1293
1241
  class_temp_df = df[df['metadata_based_class'] == class_]
1242
+ class_len_ls.append(len(class_temp_df))
1243
+ print(f'Found {len(class_temp_df)} images for class {class_}')
1294
1244
  class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
1295
1245
  class_paths_ls.append(class_paths_temp)
1296
1246
 
@@ -1347,7 +1297,8 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1347
1297
 
1348
1298
  return
1349
1299
 
1350
- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
1300
+ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
1301
+
1351
1302
  """
1352
1303
  Generate data loaders for training and validation/test datasets.
1353
1304
 
@@ -1364,16 +1315,40 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1364
1315
  - pin_memory (bool): Whether to pin memory for faster data transfer.
1365
1316
  - normalize (bool): Whether to normalize the input images.
1366
1317
  - verbose (bool): Whether to print additional information and show images.
1318
+ - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
1367
1319
 
1368
1320
  Returns:
1369
1321
  - train_loaders (list): List of data loaders for training datasets.
1370
1322
  - val_loaders (list): List of data loaders for validation datasets.
1371
1323
  - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
1372
1324
  """
1373
-
1325
+
1374
1326
  from .io import MyDataset
1375
1327
  from .plot import _imshow
1376
-
1328
+ from torchvision import transforms
1329
+ from torch.utils.data import DataLoader, random_split
1330
+ from collections import defaultdict
1331
+ import os
1332
+ import random
1333
+ from PIL import Image
1334
+ from torchvision.transforms import ToTensor
1335
+ from .utils import SelectChannels
1336
+
1337
+ chans = []
1338
+
1339
+ if 'r' in channels:
1340
+ chans.append(1)
1341
+ if 'g' in channels:
1342
+ chans.append(2)
1343
+ if 'b' in channels:
1344
+ chans.append(3)
1345
+
1346
+ channels = chans
1347
+
1348
+ if verbose:
1349
+ print(f'Training a network on channels: {channels}')
1350
+ print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
1351
+
1377
1352
  plate_to_filenames = defaultdict(list)
1378
1353
  plate_to_labels = defaultdict(list)
1379
1354
  train_loaders = []
@@ -1384,31 +1359,30 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1384
1359
  transform = transforms.Compose([
1385
1360
  transforms.ToTensor(),
1386
1361
  transforms.CenterCrop(size=(image_size, image_size)),
1362
+ SelectChannels(channels),
1387
1363
  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1388
1364
  else:
1389
1365
  transform = transforms.Compose([
1390
1366
  transforms.ToTensor(),
1391
- transforms.CenterCrop(size=(image_size, image_size))])
1392
-
1367
+ transforms.CenterCrop(size=(image_size, image_size)),
1368
+ SelectChannels(channels)])
1369
+
1393
1370
  if mode == 'train':
1394
1371
  data_dir = os.path.join(src, 'train')
1395
1372
  shuffle = True
1396
- print(f'Generating Train and validation datasets')
1397
-
1373
+ print('Generating Train and validation datasets')
1398
1374
  elif mode == 'test':
1399
1375
  data_dir = os.path.join(src, 'test')
1400
1376
  val_loaders = []
1401
- validation_split=0.0
1377
+ validation_split = 0.0
1402
1378
  shuffle = True
1403
- print(f'Generating test dataset')
1404
-
1379
+ print('Generating test dataset')
1405
1380
  else:
1406
1381
  print(f'mode:{mode} is not valid, use mode = train or test')
1407
1382
  return
1408
-
1383
+
1409
1384
  if train_mode == 'erm':
1410
1385
  data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1411
- #train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1412
1386
  if validation_split > 0:
1413
1387
  train_size = int((1 - validation_split) * len(data))
1414
1388
  val_size = len(data) - train_size
@@ -1465,7 +1439,6 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1465
1439
  images = images.cpu()
1466
1440
  label_strings = [str(label.item()) for label in labels]
1467
1441
  _imshow(images, label_strings, nrow=20, fontsize=12)
1468
-
1469
1442
  elif train_mode == 'irm':
1470
1443
  for plate_name, train_loader in zip(plate_names, train_loaders):
1471
1444
  print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
@@ -1584,15 +1557,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1584
1557
  df = df.dropna(subset=['condition'])
1585
1558
  print(f'After dropping non-annotated wells: {len(df)} rows')
1586
1559
  files = df['file_name'].tolist()
1560
+ print(f'found: {len(files)} files')
1587
1561
  files = [item + '.npy' for item in files]
1588
1562
  random.shuffle(files)
1589
-
1563
+
1564
+ _max = 10**100
1565
+
1566
+ if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
1567
+ filter_min_max = None
1568
+ else:
1569
+ if cell_size_range is None:
1570
+ cell_size_range = [0,_max]
1571
+ if nucleus_size_range is None:
1572
+ nucleus_size_range = [0,_max]
1573
+ if pathogen_size_range is None:
1574
+ pathogen_size_range = [0,_max]
1575
+
1576
+ filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
1577
+
1590
1578
  if plot:
1591
1579
  plot_settings = {'include_noninfected':include_noninfected,
1592
1580
  'include_multiinfected':include_multiinfected,
1593
1581
  'include_multinucleated':include_multinucleated,
1594
1582
  'remove_background':remove_background,
1595
- 'filter_min_max':[[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]],
1583
+ 'filter_min_max':filter_min_max,
1596
1584
  'channel_dims':channel_dims,
1597
1585
  'backgrounds':backgrounds,
1598
1586
  'cell_mask_dim':mask_dims[0],
@@ -1649,19 +1637,225 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1649
1637
  cells,wells = _results_to_csv(src, df, df_well)
1650
1638
  return [cells,wells]
1651
1639
 
1652
- @log_function_call
1653
- def preprocess_generate_masks(src, settings={},advanced_settings={}):
1640
def _find_root(parent, item):
    """Return the representative label of `item` in the union-find array `parent` (with path compression)."""
    root = item
    while parent[root] != root:
        root = parent[root]
    # Second pass: point every visited entry straight at the root.
    while parent[item] != root:
        parent[item], item = root, parent[item]
    return root


def _union_all(parent, labels):
    """Put every label in `labels` into the same union-find set as the first one."""
    if len(labels) < 2:
        return
    anchor = _find_root(parent, labels[0])
    for other in labels[1:]:
        other_root = _find_root(parent, other)
        if other_root != anchor:
            parent[other_root] = anchor


def _apply_merges(parent, labeled):
    """Return a copy of `labeled` in which every label is replaced by its union-find representative."""
    lookup = np.array([_find_root(parent, i) for i in range(parent.size)])
    return lookup[labeled]


def _cell_overlaps(object_pixels, labeled_cells):
    """
    Return the cell labels covered by a boolean object mask and the percentage of the
    object's pixels that falls on each of those cells (background is excluded from the
    labels but still counted in the object's total area).
    """
    covered, counts = np.unique(labeled_cells[object_pixels], return_counts=True)
    is_cell = covered != 0
    percentages = counts[is_cell] / np.count_nonzero(object_pixels) * 100
    return covered[is_cell], percentages


def _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge cells in cell_mask if a parasite in parasite_mask overlaps with more than one cell,
    if a nucleus overlaps with more than one cell, and if a cell without a nucleus shares
    more than a specified perimeter percentage with a neighbouring cell.

    All merging is done on the labels produced by `label(cell_mask)`, so the label values
    used to detect an overlap are the same values that get merged. The input arrays are
    not modified.

    Args:
        parasite_mask (ndarray): Mask of parasites.
        cell_mask (ndarray): Mask of cells.
        nuclei_mask (ndarray): Mask of nuclei.
        overlap_threshold (float): The percentage threshold for merging cells based on parasite
            or nucleus overlap (percentage of the parasite/nucleus area lying on a cell).
        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.

    Returns:
        ndarray: The modified cell mask with unique labels (re-labelled with `label`).
    """
    # Function-scope import so the `structure=` keyword below is guaranteed to match
    # the SciPy implementation of binary_dilation.
    from scipy.ndimage import binary_dilation

    labeled_cells = label(cell_mask)
    labeled_parasites = label(parasite_mask)
    labeled_nuclei = label(nuclei_mask)
    num_parasites = int(labeled_parasites.max()) if labeled_parasites.size else 0
    num_cells = int(labeled_cells.max()) if labeled_cells.size else 0
    num_nuclei = int(labeled_nuclei.max()) if labeled_nuclei.size else 0

    # parent[i] == i means label i is its own representative; merges are recorded here
    # and applied to the image once per phase.
    parent = np.arange(num_cells + 1)

    # Merge cells based on parasite overlap
    for parasite_id in range(1, num_parasites + 1):
        cell_labels, percentages = _cell_overlaps(labeled_parasites == parasite_id, labeled_cells)
        # NOTE(review): as in the original logic, one cell above the threshold is enough
        # to merge *all* cells touched by this parasite — confirm this is intended.
        if cell_labels.size > 1 and np.any(percentages > overlap_threshold):
            _union_all(parent, cell_labels)

    # Merge cells based on nucleus overlap (every touched cell must exceed the threshold)
    for nucleus_id in range(1, num_nuclei + 1):
        cell_labels, percentages = _cell_overlaps(labeled_nuclei == nucleus_id, labeled_cells)
        if cell_labels.size > 1 and np.all(percentages > overlap_threshold):
            _union_all(parent, cell_labels)

    # Re-label after merging based on overlap
    labeled_cells = label(_apply_merges(parent, labeled_cells))
    num_cells = int(labeled_cells.max()) if labeled_cells.size else 0
    parent = np.arange(num_cells + 1)
    neighbourhood = np.ones((3, 3), dtype=bool)

    # Check for cells without nuclei and merge based on shared perimeter
    for region in regionprops(labeled_cells):
        cell_label = region.label
        cell_pixels = labeled_cells == cell_label
        if np.any(nuclei_mask[cell_pixels] != 0):
            # Cell overlaps with at least one nucleus: leave it alone.
            continue

        # Dilate the cell by one pixel to find the labels it touches.
        dilated_cell = binary_dilation(cell_pixels, structure=neighbourhood)
        touched, counts = np.unique(labeled_cells[dilated_cell], return_counts=True)
        is_neighbour = (touched != 0) & (touched != cell_label)
        neighbour_cells = touched[is_neighbour]
        if neighbour_cells.size == 0:
            continue

        # Shared border length = number of neighbour pixels inside the dilated cell.
        shared_borders = counts[is_neighbour]
        perimeter = region.perimeter
        if perimeter > 0:
            shared_border_percentages = shared_borders / perimeter * 100
        else:
            # Degenerate region with zero perimeter: any contact counts as fully shared
            # (avoids a division by zero).
            shared_border_percentages = np.full(shared_borders.shape, np.inf)

        # Merge with the neighbour sharing the largest border percentage above the threshold
        best = int(np.argmax(shared_border_percentages))
        if shared_border_percentages[best] > perimeter_threshold:
            _union_all(parent, [neighbour_cells[best], cell_label])

    # Relabel the merged cell mask
    return label(_apply_merges(parent, labeled_cells))
1730
+
1731
def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge and relabel the cell masks stored as .npy files, one file at a time.

    For every .npy file in parasite_folder, the files with the same name in cell_folder
    and nuclei_folder are loaded, the cells are merged with
    _merge_cells_based_on_parasite_overlap, and the cell mask file is overwritten
    with the result.

    Args:
        parasite_folder (str): Path to the folder containing parasite masks.
        cell_folder (str): Path to the folder containing cell masks.
        nuclei_folder (str): Path to the folder containing nuclei masks.
        overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.

    Raises:
        ValueError: If the folders hold different numbers of .npy files, or if a parasite
            mask has no cell or nuclei mask with the same file name.
    """

    def _npy_names(folder):
        # Sorted names of the .npy files directly inside `folder`.
        return sorted(name for name in os.listdir(folder) if name.endswith('.npy'))

    parasite_files = _npy_names(parasite_folder)
    cell_files = _npy_names(cell_folder)
    nuclei_files = _npy_names(nuclei_folder)

    # All three folders must hold the same number of mask files
    if len(parasite_files) != len(cell_files) or len(cell_files) != len(nuclei_files):
        raise ValueError("The number of files in the folders do not match.")

    # Files are paired by name, driven by the parasite folder
    for file_name in parasite_files:
        parasite_path, cell_path, nuclei_path = (
            os.path.join(folder, file_name)
            for folder in (parasite_folder, cell_folder, nuclei_folder)
        )

        if not os.path.exists(cell_path) or not os.path.exists(nuclei_path):
            raise ValueError(f"Corresponding cell or nuclei mask file for {file_name} not found.")

        parasite_mask = np.load(parasite_path)
        cell_mask = np.load(cell_path)
        nuclei_mask = np.load(nuclei_path)

        merged_cell_mask = _merge_cells_based_on_parasite_overlap(
            parasite_mask, cell_mask, nuclei_mask, overlap_threshold, perimeter_threshold
        )

        # The original cell mask file is replaced by the merged result
        np.save(cell_path, merged_cell_mask)
1768
+
1769
def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
    """
    Cluster the objects of all masks in a folder and remove, from every mask, the objects
    that are not in that mask's most common cluster. The mask files are overwritten.

    Steps:
        1. For every .npy mask in mask_folder, measure area, mean intensity, perimeter and
           eccentricity of each region, using `channel` of the image with the same file
           name in image_folder as the intensity image.
        2. Cluster all measured objects together with KMeans.
        3. Optionally plot the clusters on the first two PCA components.
        4. For every mask, keep only the objects whose cluster is the most frequent
           cluster among that mask's objects, and save the mask back to its file.

    Args:
        mask_folder (str): Folder containing the .npy mask files.
        image_folder (str): Folder containing .npy images with the same file names as the masks.
        channel (int): Index into the last axis of the image arrays used for intensity measurements.
        batch_size (int): Number of files loaded at once.
        n_clusters (int): Number of KMeans clusters.
        plot (bool): Whether to show a PCA scatter plot of the clusters.

    Returns:
        None. Returns early, without touching any file, if no objects are found.
    """

    def read_files_in_batches(folder, batch_size=50):
        # Sorted so that both passes over the folder visit the files in the same order.
        files = sorted(f for f in os.listdir(folder) if f.endswith('.npy'))
        for start in range(0, len(files), batch_size):
            yield files[start:start + batch_size]

    def measure_morphology_and_intensity(mask, image):
        return [{'area': p.area, 'mean_intensity': p.mean_intensity, 'perimeter': p.perimeter, 'eccentricity': p.eccentricity}
                for p in measure.regionprops(mask, intensity_image=image)]

    def feature_matrix(properties):
        return np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])

    def cluster_objects(properties, n_clusters=2):
        return KMeans(n_clusters=n_clusters, random_state=0).fit(feature_matrix(properties))

    def remove_objects_not_in_largest_cluster(mask, regions, cluster_labels, largest_cluster_label):
        # Regions and cluster labels are paired by position, so masks whose label values
        # are not contiguous (e.g. 1, 2, 5) are handled correctly.
        cleaned_mask = np.zeros_like(mask)
        for region, cluster in zip(regions, cluster_labels):
            if cluster == largest_cluster_label:
                cleaned_mask[mask == region.label] = region.label
        return cleaned_mask

    def plot_clusters(properties, labels):
        # NOTE(review): relies on a module-level `plt` (matplotlib.pyplot) — confirm it is
        # still imported by this file.
        data_2d = PCA(n_components=2).fit_transform(feature_matrix(properties))
        plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.title('Object Clustering')
        plt.show()

    all_properties = []

    # Step 1: Accumulate properties over all files
    for batch in read_files_in_batches(mask_folder, batch_size):
        masks = [np.load(os.path.join(mask_folder, file)) for file in batch]
        images = [np.load(os.path.join(image_folder, file))[:, :, channel] for file in batch]

        for mask, image in zip(masks, images):
            all_properties.extend(measure_morphology_and_intensity(mask, image))

    if not all_properties:
        # Nothing to cluster; KMeans cannot be fitted on an empty array.
        return

    # Step 2: Perform clustering on accumulated properties
    kmeans = cluster_objects(all_properties, n_clusters)
    labels = kmeans.labels_

    if plot:
        # Step 3: Plot clusters using PCA
        plot_clusters(all_properties, labels)

    # Step 4: Remove objects not in the largest cluster and overwrite files in batches
    label_index = 0
    for batch in read_files_in_batches(mask_folder, batch_size):
        mask_files = [os.path.join(mask_folder, file) for file in batch]
        masks = [np.load(file) for file in mask_files]

        for mask_file, mask in zip(mask_files, masks):
            # Same regions, in the same order, as measured in step 1.
            regions = measure.regionprops(mask)
            mask_labels = labels[label_index:label_index + len(regions)]
            label_index += len(regions)

            if len(regions) == 0:
                # Mask without objects: nothing to remove (and bincount/argmax would fail).
                continue

            # NOTE(review): the "largest cluster" is determined per mask, as before —
            # confirm whether the globally largest cluster was intended.
            largest_cluster_label = np.bincount(mask_labels).argmax()
            cleaned_mask = remove_objects_not_in_largest_cluster(mask, regions, mask_labels, largest_cluster_label)
            np.save(mask_file, cleaned_mask)
1841
+
1842
+ def preprocess_generate_masks(src, settings={}):
1654
1843
 
1655
1844
  from .io import preprocess_img_data, _load_and_concatenate_arrays
1656
1845
  from .plot import plot_merged, plot_arrays
1657
- from .utils import _pivot_counts_table
1658
-
1659
- settings = {**settings, **advanced_settings}
1660
- settings['src'] = src
1846
+ from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
1847
+
1848
+ settings = set_default_settings_preprocess_generate_masks(src, settings)
1849
+
1661
1850
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
1662
1851
  settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
1663
1852
  os.makedirs(os.path.join(src,'settings'), exist_ok=True)
1664
1853
  settings_df.to_csv(settings_csv, index=False)
1854
+
1855
+ if not settings['pathogen_channel'] is None:
1856
+ custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
1857
+ if settings['pathogen_model'] not in custom_model_ls:
1858
+ ValueError(f'Pathogen model must be {custom_model_ls} or None')
1665
1859
 
1666
1860
  if settings['timelapse']:
1667
1861
  settings['randomize'] = False
@@ -1670,24 +1864,50 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1670
1864
  if not settings['masks']:
1671
1865
  print(f'WARNING: channels for mask generation are defined when preprocess = True')
1672
1866
 
1673
- if isinstance(settings['merge'], bool):
1674
- settings['merge'] = [settings['merge']]*3
1675
1867
  if isinstance(settings['save'], bool):
1676
1868
  settings['save'] = [settings['save']]*3
1677
1869
 
1870
+ if settings['verbose']:
1871
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
1872
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
1873
+ display(settings_df)
1874
+
1875
+ if settings['test_mode']:
1876
+ print(f'Starting Test mode ...')
1877
+
1678
1878
  if settings['preprocess']:
1679
- preprocess_img_data(settings)
1879
+ settings, src = preprocess_img_data(settings)
1680
1880
 
1681
1881
  if settings['masks']:
1682
1882
  mask_src = os.path.join(src, 'norm_channel_stack')
1683
1883
  if settings['cell_channel'] != None:
1684
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='cell')
1884
+ if check_mask_folder(src, 'cell_mask_stack'):
1885
+ generate_cellpose_masks(mask_src, settings, 'cell')
1685
1886
 
1686
1887
  if settings['nucleus_channel'] != None:
1687
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='nucleus')
1888
+ if check_mask_folder(src, 'nucleus_mask_stack'):
1889
+ generate_cellpose_masks(mask_src, settings, 'nucleus')
1688
1890
 
1689
1891
  if settings['pathogen_channel'] != None:
1690
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='pathogen')
1892
+ if check_mask_folder(src, 'pathogen_mask_stack'):
1893
+ generate_cellpose_masks(mask_src, settings, 'pathogen')
1894
+
1895
+ if settings['adjust_cells']:
1896
+ if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
1897
+
1898
+ start = time.time()
1899
+ cell_folder = os.path.join(mask_src, 'cell_mask_stack')
1900
+ nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
1901
+ parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
1902
+ #image_folder = os.path.join(src, 'stack')
1903
+
1904
+ #process_masks(cell_folder, image_folder, settings['cell_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1905
+ #process_masks(nuclei_folder, image_folder, settings['nucleus_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1906
+ #process_masks(parasite_folder, image_folder, settings['pathogen_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1907
+
1908
+ adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
1909
+ stop = time.time()
1910
+ print(f'Cell mask adjustment: {stop-start} seconds')
1691
1911
 
1692
1912
  if os.path.exists(os.path.join(src,'measurements')):
1693
1913
  _pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
@@ -1716,59 +1936,110 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1716
1936
  overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
1717
1937
  overlay_channels = [element for element in overlay_channels if element is not None]
1718
1938
 
1719
- plot_settings = {'include_noninfected':True,
1720
- 'include_multiinfected':True,
1721
- 'include_multinucleated':True,
1722
- 'remove_background':False,
1723
- 'filter_min_max':None,
1724
- 'channel_dims':settings['channels'],
1725
- 'backgrounds':[100,100,100,100],
1726
- 'cell_mask_dim':cell_mask_dim,
1727
- 'nucleus_mask_dim':nucleus_mask_dim,
1728
- 'pathogen_mask_dim':pathogen_mask_dim,
1729
- 'overlay_chans':[0,2,3],
1730
- 'outline_thickness':3,
1731
- 'outline_color':'gbr',
1732
- 'overlay_chans':overlay_channels,
1733
- 'overlay':True,
1734
- 'normalization_percentiles':[1,99],
1735
- 'normalize':True,
1736
- 'print_object_number':True,
1737
- 'nr':settings['examples_to_plot'],
1738
- 'figuresize':20,
1739
- 'cmap':'inferno',
1740
- 'verbose':False}
1939
+ plot_settings = set_default_plot_merge_settings()
1940
+ plot_settings['channel_dims'] = settings['channels']
1941
+ plot_settings['cell_mask_dim'] = cell_mask_dim
1942
+ plot_settings['nucleus_mask_dim'] = nucleus_mask_dim
1943
+ plot_settings['pathogen_mask_dim'] = pathogen_mask_dim
1944
+ plot_settings['overlay_chans'] = overlay_channels
1945
+ plot_settings['nr'] = settings['examples_to_plot']
1946
+
1947
+ if settings['test_mode'] == True:
1948
+ plot_settings['nr'] = len(os.path.join(src,'merged'))
1949
+
1741
1950
  try:
1742
1951
  fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
1743
1952
  except Exception as e:
1744
1953
  print(f'Failed to plot image mask overly. Error: {e}')
1745
1954
  else:
1746
- plot_arrays(src=os.path.join(src,'merged'), figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99)
1955
+ plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
1747
1956
 
1748
1957
  torch.cuda.empty_cache()
1749
1958
  gc.collect()
1959
+ print("Successfully completed run")
1750
1960
  return
1751
1961
 
1752
- def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', verbose=False, plot=False, save=False, custom_model=None, signal_thresholds=1000, normalize=True, resize=False, target_height=None, target_width=None, rescale=True, resample=True, net_avg=False, invert=False, circular=False, percentiles=None, overlay=True, grayscale=False):
1962
+ def identify_masks_finetune(settings):
1753
1963
 
1754
1964
  from .plot import print_mask_and_flows
1755
1965
  from .utils import get_files_from_dir, resize_images_and_labels
1756
1966
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
1757
1967
 
1968
+ #User defined settings
1969
+ src=settings['src']
1970
+ dst=settings['dst']
1971
+
1972
+
1973
+ settings.setdefault('model_name', 'cyto')
1974
+ settings.setdefault('custom_model', None)
1975
+ settings.setdefault('channels', [0,0])
1976
+ settings.setdefault('background', 100)
1977
+ settings.setdefault('remove_background', False)
1978
+ settings.setdefault('Signal_to_noise', 10)
1979
+ settings.setdefault('CP_prob', 0)
1980
+ settings.setdefault('diameter', 30)
1981
+ settings.setdefault('batch_size', 50)
1982
+ settings.setdefault('flow_threshold', 0.4)
1983
+ settings.setdefault('save', False)
1984
+ settings.setdefault('verbose', False)
1985
+ settings.setdefault('normalize', True)
1986
+ settings.setdefault('percentiles', None)
1987
+ settings.setdefault('circular', False)
1988
+ settings.setdefault('invert', False)
1989
+ settings.setdefault('resize', False)
1990
+ settings.setdefault('target_height', None)
1991
+ settings.setdefault('target_width', None)
1992
+ settings.setdefault('rescale', False)
1993
+ settings.setdefault('resample', False)
1994
+ settings.setdefault('grayscale', True)
1995
+
1996
+
1997
+ model_name=settings['model_name']
1998
+ custom_model=settings['custom_model']
1999
+ channels = settings['channels']
2000
+ background = settings['background']
2001
+ remove_background=settings['remove_background']
2002
+ Signal_to_noise = settings['Signal_to_noise']
2003
+ CP_prob = settings['CP_prob']
2004
+ diameter=settings['diameter']
2005
+ batch_size=settings['batch_size']
2006
+ flow_threshold=settings['flow_threshold']
2007
+ save=settings['save']
2008
+ verbose=settings['verbose']
2009
+
2010
+ # static settings
2011
+ normalize = settings['normalize']
2012
+ percentiles = settings['percentiles']
2013
+ circular = settings['circular']
2014
+ invert = settings['invert']
2015
+ resize = settings['resize']
2016
+
2017
+ if resize:
2018
+ target_height = settings['target_height']
2019
+ target_width = settings['target_width']
2020
+
2021
+ rescale = settings['rescale']
2022
+ resample = settings['resample']
2023
+ grayscale = settings['grayscale']
2024
+
2025
+ os.makedirs(dst, exist_ok=True)
2026
+
2027
+ if not custom_model is None:
2028
+ if not os.path.exists(custom_model):
2029
+ print(f'Custom model not found: {custom_model}')
2030
+ return
2031
+
1758
2032
  if not torch.cuda.is_available():
1759
2033
  print(f'Torch CUDA is not available, using CPU')
1760
2034
 
1761
2035
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1762
2036
 
1763
2037
  if custom_model == None:
1764
- if model_name =='cyto':
1765
- model = cp_models.CellposeModel(gpu=True, model_type=model_name, net_avg=False, diam_mean=diameter, pretrained_model=None)
1766
- else:
1767
- model = cp_models.CellposeModel(gpu=True, model_type=model_name)
1768
-
1769
- if custom_model != None:
1770
- model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
1771
- print(f'loaded custom model:{custom_model}')
2038
+ model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
2039
+ print(f'Loaded model: {model_name}')
2040
+ else:
2041
+ model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
2042
+ print("Pretrained Model Loaded:", model.pretrained_model)
1772
2043
 
1773
2044
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
1774
2045
 
@@ -1778,16 +2049,18 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1778
2049
  print(f'Using channels: {chans} for model of type {model_name}')
1779
2050
 
1780
2051
  if verbose == True:
1781
- print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
2052
+ print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
1782
2053
 
1783
- all_image_files = get_files_from_dir(src, file_extension="*.tif")
2054
+ all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
2055
+
1784
2056
  random.shuffle(all_image_files)
1785
2057
 
1786
2058
  time_ls = []
1787
2059
  for i in range(0, len(all_image_files), batch_size):
1788
2060
  image_files = all_image_files[i:i+batch_size]
2061
+
1789
2062
  if normalize:
1790
- images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
2063
+ images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
1791
2064
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
1792
2065
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
1793
2066
  else:
@@ -1805,11 +2078,10 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1805
2078
  channel_axis=3,
1806
2079
  diameter=diameter,
1807
2080
  flow_threshold=flow_threshold,
1808
- cellprob_threshold=cellprob_threshold,
2081
+ cellprob_threshold=CP_prob,
1809
2082
  rescale=rescale,
1810
2083
  resample=resample,
1811
- net_avg=net_avg,
1812
- progress=False)
2084
+ progress=True)
1813
2085
 
1814
2086
  if len(output) == 4:
1815
2087
  mask, flows, _, _ = output
@@ -1827,17 +2099,17 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1827
2099
  time_ls.append(duration)
1828
2100
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1829
2101
  print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
1830
- if plot:
2102
+ if verbose:
1831
2103
  if resize:
1832
2104
  stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
1833
- print_mask_and_flows(stack, mask, flows, overlay=overlay)
2105
+ print_mask_and_flows(stack, mask, flows, overlay=True)
1834
2106
  if save:
2107
+ os.makedirs(dst, exist_ok=True)
1835
2108
  output_filename = os.path.join(dst, image_names[file_index])
1836
2109
  cv2.imwrite(output_filename, mask)
1837
2110
  return
1838
2111
 
1839
- @log_function_call
1840
- def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
2112
+ def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
1841
2113
  """
1842
2114
  Identify masks from the source images.
1843
2115
 
@@ -1885,13 +2157,13 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1885
2157
 
1886
2158
  #Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
1887
2159
  gc.collect()
1888
- #print('========== generating masks ==========')
1889
2160
 
1890
2161
  if not torch.cuda.is_available():
1891
2162
  print(f'Torch CUDA is not available, using CPU')
1892
2163
 
1893
2164
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1894
- model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
2165
+ model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
2166
+
1895
2167
  if file_type == '.npz':
1896
2168
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
1897
2169
  else:
@@ -1918,9 +2190,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1918
2190
 
1919
2191
  average_sizes = []
1920
2192
  time_ls = []
1921
- moving_avg_q1 = 0
1922
- moving_avg_q3 = 0
1923
- moving_count = 0
1924
2193
  for file_index, path in enumerate(paths):
1925
2194
 
1926
2195
  name = os.path.basename(path)
@@ -1961,7 +2230,8 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1961
2230
  if not plot:
1962
2231
  batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
1963
2232
  if batch.size == 0:
1964
- print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
2233
+ print(f'Processing: {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}')
2234
+ #print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
1965
2235
  continue
1966
2236
  if batch.max() > 1:
1967
2237
  batch = batch / batch.max()
@@ -1976,8 +2246,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1976
2246
  stitch_threshold=0.0
1977
2247
 
1978
2248
  cellpose_batch_size = _get_cellpose_batch_size()
1979
-
1980
- model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
1981
2249
 
1982
2250
  masks, flows, _, _ = model.eval(x=batch,
1983
2251
  batch_size=cellpose_batch_size,
@@ -1989,9 +2257,9 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1989
2257
  cellprob_threshold=cellprob_threshold,
1990
2258
  rescale=None,
1991
2259
  resample=resample,
1992
- #net_avg=net_avg,
1993
2260
  stitch_threshold=stitch_threshold,
1994
2261
  progress=None)
2262
+
1995
2263
  print('Masks shape',masks.shape)
1996
2264
  if timelapse:
1997
2265
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
@@ -2015,7 +2283,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2015
2283
 
2016
2284
  else:
2017
2285
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2018
- mask_stack = _filter_cp_masks(masks, flows, refine_masks, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize)
2286
+ mask_stack = _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize)
2019
2287
  _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2020
2288
 
2021
2289
  if not np.any(mask_stack):
@@ -2032,7 +2300,8 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2032
2300
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2033
2301
  time_in_min = average_time/60
2034
2302
  time_per_mask = average_time/batch_size
2035
- print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2303
+ print(f'Processing: {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
2304
+ #print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2036
2305
  if not timelapse:
2037
2306
  if plot:
2038
2307
  plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap=cmap, nr=batch_size, file_type='.npz')
@@ -2046,10 +2315,25 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2046
2315
  gc.collect()
2047
2316
  return
2048
2317
 
2049
- @log_function_call
2318
def all_elements_match(list1, list2):
    """Return True when every element of ``list1`` is also present in ``list2``.

    A bare string passed as ``list1`` is treated as a single element instead of
    being iterated character by character, so ``'cell'`` is looked up as a whole.

    Args:
        list1: Iterable of candidate elements, or a single string.
        list2: Container the candidates are looked up in.

    Returns:
        bool: True if all candidates are found (vacuously True for an empty ``list1``).
    """
    # Iterating a str yields single characters, which never match whole names.
    candidates = [list1] if isinstance(list1, str) else list1
    return all(element in list2 for element in candidates)
2321
+
2322
def prepare_batch_for_cellpose(batch):
    """Return ``batch`` as float32 with every image whose maximum exceeds 1 scaled by that maximum.

    Args:
        batch (numpy.ndarray): Array whose first axis indexes the images.

    Returns:
        numpy.ndarray: The float32 batch. A batch that is already float32 is
        modified in place and returned as the same object; any other dtype is
        converted to a new array first.
    """
    scaled = batch if batch.dtype == np.float32 else batch.astype(np.float32)

    for index in range(scaled.shape[0]):
        peak = scaled[index].max()
        # Images that already lie within [.., 1] are left untouched.
        if peak > 1:
            scaled[index] = scaled[index] / peak

    return scaled
2333
+
2050
2334
  def generate_cellpose_masks(src, settings, object_type):
2051
2335
 
2052
- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels
2336
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
2053
2337
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2054
2338
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2055
2339
  from .plot import plot_masks
@@ -2057,6 +2341,13 @@ def generate_cellpose_masks(src, settings, object_type):
2057
2341
  gc.collect()
2058
2342
  if not torch.cuda.is_available():
2059
2343
  print(f'Torch CUDA is not available, using CPU')
2344
+
2345
+ settings = set_default_settings_preprocess_generate_masks(src, settings)
2346
+
2347
+ if settings['verbose']:
2348
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
2349
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
2350
+ display(settings_df)
2060
2351
 
2061
2352
  figuresize=25
2062
2353
  timelapse = settings['timelapse']
@@ -2071,21 +2362,26 @@ def generate_cellpose_masks(src, settings, object_type):
2071
2362
 
2072
2363
  batch_size = settings['batch_size']
2073
2364
  cellprob_threshold = settings[f'{object_type}_CP_prob']
2074
- flow_threshold = 30
2075
-
2365
+
2366
+ flow_threshold = settings[f'{object_type}_FT']
2367
+
2076
2368
  object_settings = _get_object_settings(object_type, settings)
2077
2369
  model_name = object_settings['model_name']
2078
2370
 
2079
- cellpose_channels = _get_cellpose_channels(settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2371
+ cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2372
+ if settings['verbose']:
2373
+ print(cellpose_channels)
2374
+
2080
2375
  channels = cellpose_channels[object_type]
2081
2376
  cellpose_batch_size = _get_cellpose_batch_size()
2082
-
2083
2377
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2084
- model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
2085
- #dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
2086
2378
 
2379
+ if object_type == 'pathogen' and not settings['pathogen_model'] is None:
2380
+ model_name = settings['pathogen_model']
2381
+
2382
+ model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)
2383
+
2087
2384
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
2088
-
2089
2385
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
2090
2386
 
2091
2387
  count_loc = os.path.dirname(src)+'/measurements/measurements.db'
@@ -2094,10 +2390,6 @@ def generate_cellpose_masks(src, settings, object_type):
2094
2390
 
2095
2391
  average_sizes = []
2096
2392
  time_ls = []
2097
- moving_avg_q1 = 0
2098
- moving_avg_q3 = 0
2099
- moving_count = 0
2100
-
2101
2393
  for file_index, path in enumerate(paths):
2102
2394
  name = os.path.basename(path)
2103
2395
  name, ext = os.path.splitext(name)
@@ -2107,17 +2399,31 @@ def generate_cellpose_masks(src, settings, object_type):
2107
2399
  with np.load(path) as data:
2108
2400
  stack = data['data']
2109
2401
  filenames = data['filenames']
2402
+
2403
+ for i, filename in enumerate(filenames):
2404
+ output_path = os.path.join(output_folder, filename)
2405
+
2406
+ if os.path.exists(output_path):
2407
+ print(f"File {filename} already exists in the output folder. Skipping...")
2408
+ continue
2409
+
2110
2410
  if settings['timelapse']:
2411
+
2412
+ trackable_objects = ['cell','nucleus','pathogen']
2413
+ if not all_elements_match(settings['timelapse_objects'], trackable_objects):
2414
+ print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
2415
+ return
2416
+
2111
2417
  if len(stack) != batch_size:
2112
2418
  print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
2113
- settings['batch_size'] = len(stack)
2419
+ settings['timelapse_batch_size'] = len(stack)
2114
2420
  batch_size = len(stack)
2115
2421
  if isinstance(timelapse_frame_limits, list):
2116
2422
  if len(timelapse_frame_limits) >= 2:
2117
2423
  stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
2118
2424
  filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
2119
2425
  batch_size = len(stack)
2120
- print(f'Cut batch an indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2426
+ print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2121
2427
 
2122
2428
  for i in range(0, stack.shape[0], batch_size):
2123
2429
  mask_stack = []
@@ -2133,37 +2439,53 @@ def generate_cellpose_masks(src, settings, object_type):
2133
2439
  if not settings['plot']:
2134
2440
  batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
2135
2441
  if batch.size == 0:
2136
- print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
2442
+ print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2137
2443
  continue
2138
- if batch.max() > 1:
2139
- batch = batch / batch.max()
2444
+
2445
+ batch = prepare_batch_for_cellpose(batch)
2140
2446
 
2141
2447
  if timelapse:
2142
- stitch_threshold=100.0
2143
2448
  movie_path = os.path.join(os.path.dirname(src), 'movies')
2144
2449
  os.makedirs(movie_path, exist_ok=True)
2145
2450
  save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
2146
2451
  _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2452
+
2453
+ if settings['verbose']:
2454
+ print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2455
+
2456
+ #cellpose_normalize_dict = {'lowhigh':[0.0,1.0], #pass in normalization values for 0.0 and 1.0 as list [low, high] if None all other keys ignored
2457
+ # 'sharpen':object_settings['diameter']/4, #recommended to be 1/4-1/8 diameter of cells in pixels
2458
+ # 'normalize':True, #(if False, all following parameters ignored)
2459
+ # 'percentile':[2,98], #[perc_low, perc_high]
2460
+ # 'tile_norm':224, #normalize by tile set to e.g. 100 for normailize window to be 100 px
2461
+ # 'norm3D':True} #compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
2462
+
2463
+ output = model.eval(x=batch,
2464
+ batch_size=cellpose_batch_size,
2465
+ normalize=False,
2466
+ channels=chans,
2467
+ channel_axis=3,
2468
+ diameter=object_settings['diameter'],
2469
+ flow_threshold=flow_threshold,
2470
+ cellprob_threshold=cellprob_threshold,
2471
+ rescale=None,
2472
+ resample=object_settings['resample'])
2473
+
2474
+ if len(output) == 4:
2475
+ masks, flows, _, _ = output
2476
+ elif len(output) == 3:
2477
+ masks, flows, _ = output
2147
2478
  else:
2148
- stitch_threshold=0.0
2149
- #print(batch.shape)
2150
- #batch, _, _, _ = dn.eval(x=batch, channels=chans, diameter=object_settings['diameter'])
2151
- #batch = np.stack((batch, batch), axis=-1)
2152
- #print(f'object: {object_type} chans : {chans} channels : {channels} model: {model_name}')
2153
- masks, flows, _, _ = model.eval(x=batch,
2154
- batch_size=cellpose_batch_size,
2155
- normalize=False,
2156
- channels=chans,
2157
- channel_axis=3,
2158
- diameter=object_settings['diameter'],
2159
- flow_threshold=flow_threshold,
2160
- cellprob_threshold=cellprob_threshold,
2161
- rescale=None,
2162
- resample=object_settings['resample'],
2163
- stitch_threshold=stitch_threshold)
2164
- #progress=None)
2479
+ raise ValueError(f"Unexpected number of return values from model.eval(). Expected 3 or 4, got {len(output)}")
2165
2480
 
2166
2481
  if timelapse:
2482
+ if settings['plot']:
2483
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2484
+ if idx == 0:
2485
+ num_objects = mask_object_count(mask)
2486
+ print(f'Number of objects: {num_objects}')
2487
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2488
+
2167
2489
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
2168
2490
  if object_type in timelapse_objects:
2169
2491
  if timelapse_mode == 'btrack':
@@ -2192,35 +2514,54 @@ def generate_cellpose_masks(src, settings, object_type):
2192
2514
  name=name,
2193
2515
  batch_filenames=batch_filenames,
2194
2516
  object_type=object_type,
2195
- masks_3D=masks,
2517
+ masks=masks,
2196
2518
  timelapse_displacement=timelapse_displacement,
2197
2519
  timelapse_memory=timelapse_memory,
2198
2520
  timelapse_remove_transient=timelapse_remove_transient,
2199
2521
  plot=settings['plot'],
2200
2522
  save=settings['save'],
2201
- timelapse_mode=timelapse_mode)
2523
+ mode=timelapse_mode)
2202
2524
  else:
2203
2525
  mask_stack = _masks_to_masks_stack(masks)
2204
-
2205
2526
  else:
2206
2527
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2207
- mask_stack = _filter_cp_masks(masks=masks,
2208
- flows=flows,
2209
- filter_size=object_settings['filter_size'],
2210
- minimum_size=object_settings['minimum_size'],
2211
- maximum_size=object_settings['maximum_size'],
2212
- remove_border_objects=object_settings['remove_border_objects'],
2213
- merge=False,
2214
- filter_dimm=object_settings['filter_dimm'],
2215
- batch=batch,
2216
- moving_avg_q1=moving_avg_q1,
2217
- moving_avg_q3=moving_avg_q3,
2218
- moving_count=moving_count,
2219
- plot=settings['plot'],
2220
- figuresize=figuresize)
2221
-
2222
- _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2528
+ if object_settings['merge'] and not settings['filter']:
2529
+ mask_stack = _filter_cp_masks(masks=masks,
2530
+ flows=flows,
2531
+ filter_size=False,
2532
+ filter_intensity=False,
2533
+ minimum_size=object_settings['minimum_size'],
2534
+ maximum_size=object_settings['maximum_size'],
2535
+ remove_border_objects=False,
2536
+ merge=object_settings['merge'],
2537
+ batch=batch,
2538
+ plot=settings['plot'],
2539
+ figuresize=figuresize)
2540
+
2541
+ if settings['filter']:
2542
+ mask_stack = _filter_cp_masks(masks=masks,
2543
+ flows=flows,
2544
+ filter_size=object_settings['filter_size'],
2545
+ filter_intensity=object_settings['filter_intensity'],
2546
+ minimum_size=object_settings['minimum_size'],
2547
+ maximum_size=object_settings['maximum_size'],
2548
+ remove_border_objects=object_settings['remove_border_objects'],
2549
+ merge=object_settings['merge'],
2550
+ batch=batch,
2551
+ plot=settings['plot'],
2552
+ figuresize=figuresize)
2553
+
2554
+ _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2555
+ else:
2556
+ mask_stack = _masks_to_masks_stack(masks)
2223
2557
 
2558
+ if settings['plot']:
2559
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2560
+ if idx == 0:
2561
+ num_objects = mask_object_count(mask)
2562
+ print(f'Number of objects, : {num_objects}')
2563
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2564
+
2224
2565
  if not np.any(mask_stack):
2225
2566
  average_obj_size = 0
2226
2567
  else:
@@ -2235,7 +2576,7 @@ def generate_cellpose_masks(src, settings, object_type):
2235
2576
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2236
2577
  time_in_min = average_time/60
2237
2578
  time_per_mask = average_time/batch_size
2238
- print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2579
+ print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
2239
2580
  if not timelapse:
2240
2581
  if settings['plot']:
2241
2582
  plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
@@ -2247,4 +2588,885 @@ def generate_cellpose_masks(src, settings, object_type):
2247
2588
  batch_filenames = []
2248
2589
  gc.collect()
2249
2590
  torch.cuda.empty_cache()
2591
+ return
2592
+
2593
def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
    """
    Segment every ``.tif`` image found in ``src`` with ``model`` and optionally save or plot the masks.

    Masks are written to ``<src>/<model_name>/`` using the image name returned by the loader.

    Args:
        src (str): Folder scanned for files ending in ``.tif``.
        model: Object exposing ``eval(...)`` that returns 3 or 4 values, masks and flows first.
        model_name (str): Names the output sub-folder and selects the channel pair passed to ``model.eval``.
        batch_size (int): Number of image files loaded per batch.
        diameter, cellprob_threshold, flow_threshold: Forwarded to ``model.eval``.
        grayscale (bool): When truthy the channel pair is forced to ``[0, 0]``.
        save (bool): Write each mask with ``cv2.imwrite``.
        normalize (bool): Use the normalizing loader instead of the plain loader.
        channels, percentiles, remove_background, background, Signal_to_noise: Forwarded to the normalizing loader.
        circular, invert: Forwarded to whichever loader is used.
        plot (bool): Forwarded to the normalizing loader and, when truthy, each result is shown with ``print_mask_and_flows``.
        resize (bool): Resize images to ``target_height`` x ``target_width`` before evaluation and
            resize the resulting mask back to the image's original height and width.
        target_height, target_width: Target size used when ``resize`` is truthy.
        verbose (bool): Print the evaluation settings once when equal to ``True``.

    Raises:
        ValueError: If ``model.eval`` returns neither 3 nor 4 values.
    """

    from .io import _load_images_and_labels, _load_normalized_images_and_labels
    from .utils import resize_images_and_labels, resizescikit
    from .plot import print_mask_and_flows

    dst = os.path.join(src, model_name)
    os.makedirs(dst, exist_ok=True)

    # NOTE(review): the nucleus branch tests for 'nucleus' while check_cellpose_models
    # passes 'nuclei'; that name falls through to [2, 0] unless grayscale is set - confirm intent.
    if grayscale:
        chans = [0, 0]
    elif model_name == 'cyto2':
        chans = [2, 1]
    elif model_name == 'nucleus':
        chans = [0, 0]
    elif model_name == 'cyto':
        chans = [1, 0]
    else:
        chans = [2, 0]

    all_image_files = [os.path.join(src, entry) for entry in os.listdir(src) if entry.endswith('.tif')]
    random.shuffle(all_image_files)

    if verbose == True:
        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')

    time_ls = []
    for batch_start in range(0, len(all_image_files), batch_size):
        image_files = all_image_files[batch_start:batch_start + batch_size]

        if normalize:
            images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
        else:
            images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)

        # Drop a trailing singleton axis and remember the pre-resize size of every image.
        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
        orig_dims = [(image.shape[0], image.shape[1]) for image in images]

        if resize:
            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)

        for file_index, stack in enumerate(images):
            start = time.time()
            output = model.eval(x=stack,
                                normalize=False,
                                channels=chans,
                                channel_axis=3,
                                diameter=diameter,
                                flow_threshold=flow_threshold,
                                cellprob_threshold=cellprob_threshold,
                                rescale=False,
                                resample=False,
                                progress=False)

            # eval() may return 3 or 4 values; only the first two are used.
            if len(output) not in (3, 4):
                raise ValueError("Unexpected number of return values from model.eval()")
            mask, flows = output[0], output[1]

            if resize:
                dims = orig_dims[file_index]
                resized_mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False)
                mask = resized_mask.astype(mask.dtype)

            stop = time.time()
            time_ls.append(stop - start)
            average_time = np.mean(time_ls)
            print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)

            if plot:
                if resize:
                    resized_stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False)
                    stack = resized_stack.astype(stack.dtype)
                print_mask_and_flows(stack, mask, flows, overlay=True)

            if save:
                output_filename = os.path.join(dst, image_names[file_index])
                cv2.imwrite(output_filename, mask)
2664
+
2665
+
2666
def check_cellpose_models(settings):
    """
    Run ``generate_masks_from_imgs`` on ``settings['src']`` once per built-in Cellpose model.

    Missing keys in ``settings`` are filled in place with the defaults listed below;
    ``settings['src']`` must be supplied by the caller.

    Args:
        settings (dict): Run configuration; it is mutated by the default filling.

    Returns:
        None
    """

    src = settings['src']

    fallback_values = {
        'batch_size': 10,
        'CP_prob': 0,
        'flow_threshold': 0.4,
        'save': True,
        'normalize': True,
        'channels': [0,0],
        'percentiles': None,
        'circular': False,
        'invert': False,
        'plot': True,
        'diameter': 40,
        'grayscale': True,
        'remove_background': False,
        'background': 100,
        'Signal_to_noise': 5,
        'verbose': False,
        'resize': False,
        'target_height': None,
        'target_width': None,
    }
    for key, value in fallback_values.items():
        settings.setdefault(key, value)

    if settings['verbose']:
        # Show the effective settings as a two-column table.
        settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
        settings_df['setting_value'] = settings_df['setting_value'].apply(str)
        display(settings_df)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for model_name in ['cyto', 'nuclei', 'cyto2', 'cyto3']:

        model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
        print(f'Using {model_name}')
        generate_masks_from_imgs(src=src,
                                 model=model,
                                 model_name=model_name,
                                 batch_size=settings['batch_size'],
                                 diameter=settings['diameter'],
                                 cellprob_threshold=settings['CP_prob'],
                                 flow_threshold=settings['flow_threshold'],
                                 grayscale=settings['grayscale'],
                                 save=settings['save'],
                                 normalize=settings['normalize'],
                                 channels=settings['channels'],
                                 percentiles=settings['percentiles'],
                                 circular=settings['circular'],
                                 invert=settings['invert'],
                                 plot=settings['plot'],
                                 resize=settings['resize'],
                                 target_height=settings['target_height'],
                                 target_width=settings['target_width'],
                                 remove_background=settings['remove_background'],
                                 background=settings['background'],
                                 Signal_to_noise=settings['Signal_to_noise'],
                                 verbose=settings['verbose'])

    return
2704
+
2705
def save_results_and_figure(src, fig, results):
    """
    Write ``results`` as CSV and ``fig`` as PDF into ``<src>/results``.

    Args:
        src (str): Parent folder; the ``results`` sub-folder is created if missing.
        fig: Object exposing ``savefig(path, format=...)``.
        results: A pandas DataFrame, or anything ``pandas.DataFrame`` can be built from.
    """

    table = results if isinstance(results, pd.DataFrame) else pd.DataFrame(results)

    results_dir = os.path.join(src, 'results')
    os.makedirs(results_dir, exist_ok=True)

    results_path = os.path.join(results_dir, 'results.csv')
    fig_path = os.path.join(results_dir, 'model_comparison_plot.pdf')

    table.to_csv(results_path, index=False)
    fig.savefig(fig_path, format='pdf')
    print(f'Saved figure to {fig_path} and results to {results_path}')
2717
+
2718
def compare_mask(args):
    """
    Compare one mask file across several folders; written as a ``Pool.map`` worker.

    Args:
        args (tuple): ``(src, filename, dirs, conditions)``. ``filename`` is looked up in
            every folder of ``dirs``; ``conditions`` holds one label per folder and is
            used to name the result keys. ``src`` is accepted but not used here.

    Returns:
        dict or None: ``{'filename': filename, 'jaccard_<a>_<b>': ..., 'boundary_f1_<a>_<b>': ...,
        'ap_<a>_<b>': ...}`` for every pair of folders, or None when the file is missing
        from at least one folder.
    """
    _src, filename, dirs, conditions = args
    paths = [os.path.join(d, filename) for d in dirs]

    if not all(os.path.exists(path) for path in paths):
        return None

    # Imported inside the function to avoid issues in multiprocessing workers;
    # only the names actually used are imported so workers skip the plotting module.
    from itertools import combinations
    from .io import _read_mask
    from .utils import boundary_f1_score, compute_segmentation_ap, jaccard_index

    masks = [_read_mask(path) for path in paths]
    file_results = {'filename': filename}

    # Every unordered pair of folders (i < j) is scored once.
    for i, j in combinations(range(len(masks)), 2):
        mask_i, mask_j = masks[i], masks[j]
        f1_score = boundary_f1_score(mask_i, mask_j)
        jac_index = jaccard_index(mask_i, mask_j)
        ap_score = compute_segmentation_ap(mask_i, mask_j)

        file_results.update({
            f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
            f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
            f'ap_{conditions[i]}_{conditions[j]}': ap_score
        })

    return file_results
2746
+
2747
def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
    """
    Compare masks that share a file name across the sub-folders of ``src``.

    Every sub-folder of ``src`` except ``results`` is treated as one condition. Files
    present in all of them are scored pairwise by ``compare_mask`` in a process pool,
    then the scores are plotted and written out through ``save_results_and_figure``.

    Args:
        src (str): Folder whose sub-folders hold the masks to compare.
        verbose (bool): When truthy, re-read each compared file and pass it to
            ``visualize_cellpose_masks``.
        processes (int or None): Forwarded to ``multiprocessing.Pool``.
        save (bool): Forwarded to ``visualize_cellpose_masks``; the CSV and figure are
            written regardless of this flag.

    Returns:
        None
    """
    from .plot import visualize_cellpose_masks, plot_comparison_results
    from .io import _read_mask

    dirs = sorted(
        os.path.join(src, d)
        for d in os.listdir(src)
        if os.path.isdir(os.path.join(src, d)) and d != 'results'
    )
    if not dirs:
        # Without any condition folder there is nothing to index or compare.
        print(f'No mask folders found in {src}, nothing to compare')
        return
    conditions = [os.path.basename(d) for d in dirs]

    # Keep only the file names that exist in every condition folder.
    common_files = set(os.listdir(dirs[0]))
    for d in dirs[1:]:
        common_files.intersection_update(os.listdir(d))
    # Sorted so the rows of the results table come out in a reproducible order.
    common_files = sorted(common_files)

    with Pool(processes=processes) as pool:
        args = [(src, filename, dirs, conditions) for filename in common_files]
        results = pool.map(compare_mask, args)

    # compare_mask returns None for files it skipped.
    results = [res for res in results if res is not None]

    if verbose:
        for result in results:
            filename = result['filename']
            masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
            visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)

    fig = plot_comparison_results(results)
    save_results_and_figure(src, fig, results)
    return
2778
+
2779
def _calculate_similarity(df, features, col_to_compare, val1, val2):
    """
    Calculate the distance of each row to the positive and negative control means using various metrics.

    The result columns are named ``similarity_to_<pos|neg>_<metric>`` and are added to ``df``
    in place. The Mahalanobis distance is computed entirely in standardized feature space
    (rows, control means and covariance alike), so it does not depend on feature units.

    Args:
        df (pandas.DataFrame): DataFrame containing the data; it is modified in place.
        features (list): List of feature columns to use for similarity calculation.
        col_to_compare (str): Column name to use for comparing groups.
        val1, val2 (str): Values in col_to_compare selecting the positive (val1) and
            negative (val2) control rows.

    Returns:
        pandas.DataFrame: ``df`` with the similarity columns added.
    """
    feature_frame = df[features]

    # Mean feature vector of the positive and negative control rows.
    pos_control = df[df[col_to_compare] == val1][features].mean()
    neg_control = df[df[col_to_compare] == val2][features].mean()
    controls = (('pos', pos_control), ('neg', neg_control))

    # Standardize features for the Mahalanobis distance (zero mean, unit population
    # standard deviation; constant features keep a scale of 1).
    values = feature_frame.to_numpy(dtype=float)
    center = values.mean(axis=0)
    scale = values.std(axis=0)
    scale[scale == 0] = 1.0
    scaled_values = (values - center) / scale
    # The control means must go through the same transform as the rows they are compared with.
    scaled_controls = {
        label: (control.to_numpy(dtype=float) - center) / scale
        for label, control in controls
    }

    # np.cov returns a 0-d array for a single feature; keep it a matrix so it can be inverted.
    cov_matrix = np.atleast_2d(np.cov(scaled_values, rowvar=False))
    try:
        inv_cov_matrix = np.linalg.inv(cov_matrix)
    except np.linalg.LinAlgError:
        # Add a small value to the diagonal elements for regularization
        epsilon = 1e-5
        inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)

    # (column suffix, metric); None marks the Mahalanobis case handled on the scaled data.
    metrics = (
        ('euclidean', euclidean),
        ('cosine', cosine),
        ('mahalanobis', None),
        ('manhattan', cityblock),
        ('minkowski', lambda u, v: minkowski(u, v, p=3)),
        ('chebyshev', chebyshev),
        ('hamming', hamming),
        ('jaccard', jaccard),
        ('braycurtis', braycurtis),
    )

    for metric_name, metric in metrics:
        for label, control in controls:
            column = f'similarity_to_{label}_{metric_name}'
            if metric is None:
                target = scaled_controls[label]
                df[column] = [mahalanobis(row, target, inv_cov_matrix) for row in scaled_values]
            else:
                df[column] = feature_frame.apply(lambda row: metric(row, control), axis=1)

    return df
2831
+
2832
+ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
2833
+
2834
+ """
2835
+ Calculates permutation importance for numerical features in the dataframe,
2836
+ comparing groups based on specified column values and uses the model to predict
2837
+ the class for all other rows in the dataframe.
2838
+
2839
+ Args:
2840
+ df (pandas.DataFrame): The DataFrame containing the data.
2841
+ feature_string (str): String to filter features that contain this substring.
2842
+ col_to_compare (str): Column name to use for comparing groups.
2843
+ pos, neg (str): Values in col_to_compare to create subsets for comparison.
2844
+ exclude (list or str, optional): Columns to exclude from features.
2845
+ n_repeats (int): Number of repeats for permutation importance.
2846
+ clean (bool): Whether to remove columns with a single value.
2847
+ nr_to_plot (int): Number of top features to plot based on permutation importance.
2848
+ n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
2849
+ test_size (float): Proportion of the dataset to include in the test split.
2850
+ random_state (int): Random seed for reproducibility.
2851
+ model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
2852
+ n_jobs (int): Number of jobs to run in parallel for applicable models.
2853
+
2854
+ Returns:
2855
+ pandas.DataFrame: The original dataframe with added prediction and data usage columns.
2856
+ pandas.DataFrame: DataFrame containing the importances and standard deviations.
2857
+ """
2858
+
2859
+ from .utils import filter_dataframe_features
2860
+
2861
+ if 'cells_per_well' in df.columns:
2862
+ df = df.drop(columns=['cells_per_well'])
2863
+
2864
+ # Subset the dataframe based on specified column values
2865
+ df1 = df[df[col_to_compare] == pos].copy()
2866
+ df2 = df[df[col_to_compare] == neg].copy()
2867
+
2868
+ # Create target variable
2869
+ df1['target'] = 0
2870
+ df2['target'] = 1
2871
+
2872
+ # Combine the subsets for analysis
2873
+ combined_df = pd.concat([df1, df2])
2874
+
2875
+ if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
2876
+ channel_of_interest = int(feature_string.split('_')[-1])
2877
+ elif not feature_string is 'morphology':
2878
+ channel_of_interest = 'morphology'
2879
+
2880
+ _, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)
2881
+
2882
+ X = combined_df[features]
2883
+ y = combined_df['target']
2884
+
2885
+ # Split the data into training and testing sets
2886
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
2887
+
2888
+ # Label the data in the original dataframe
2889
+ combined_df['data_usage'] = 'train'
2890
+ combined_df.loc[X_test.index, 'data_usage'] = 'test'
2891
+
2892
+ # Initialize the model based on model_type
2893
+ if model_type == 'random_forest':
2894
+ model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
2895
+ elif model_type == 'logistic_regression':
2896
+ model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
2897
+ elif model_type == 'gradient_boosting':
2898
+ model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
2899
+ elif model_type == 'xgboost':
2900
+ model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
2901
+ else:
2902
+ raise ValueError(f"Unsupported model_type: {model_type}")
2903
+
2904
+ model.fit(X_train, y_train)
2905
+
2906
+ perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
2907
+
2908
+ # Create a DataFrame for permutation importances
2909
+ permutation_df = pd.DataFrame({
2910
+ 'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
2911
+ 'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
2912
+ 'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
2913
+ }).tail(nr_to_plot)
2914
+
2915
+ # Plotting
2916
+ fig, ax = plt.subplots()
2917
+ ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
2918
+ ax.set_xlabel('Permutation Importance')
2919
+ plt.tight_layout()
2920
+ plt.show()
2921
+
2922
+ # Feature importance for models that support it
2923
+ if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
2924
+ feature_importances = model.feature_importances_
2925
+ feature_importance_df = pd.DataFrame({
2926
+ 'feature': features,
2927
+ 'importance': feature_importances
2928
+ }).sort_values(by='importance', ascending=False).head(nr_to_plot)
2929
+
2930
+ # Plotting feature importance
2931
+ fig, ax = plt.subplots()
2932
+ ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
2933
+ ax.set_xlabel('Feature Importance')
2934
+ plt.tight_layout()
2935
+ plt.show()
2936
+ else:
2937
+ feature_importance_df = pd.DataFrame()
2938
+
2939
+ # Predicting the target variable for the test set
2940
+ predictions_test = model.predict(X_test)
2941
+ combined_df.loc[X_test.index, 'predictions'] = predictions_test
2942
+
2943
+ # Predicting the target variable for the training set
2944
+ predictions_train = model.predict(X_train)
2945
+ combined_df.loc[X_train.index, 'predictions'] = predictions_train
2946
+
2947
+ # Predicting the target variable for all other rows in the dataframe
2948
+ X_all = df[features]
2949
+ all_predictions = model.predict(X_all)
2950
+ df['predictions'] = all_predictions
2951
+
2952
+ # Combine data usage labels back to the original dataframe
2953
+ combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
2954
+ df = df.join(combined_data_usage, how='left', rsuffix='_model')
2955
+
2956
+ # Calculating and printing the accuracy metrics
2957
+ accuracy = accuracy_score(y_test, predictions_test)
2958
+ precision = precision_score(y_test, predictions_test)
2959
+ recall = recall_score(y_test, predictions_test)
2960
+ f1 = f1_score(y_test, predictions_test)
2961
+ print(f"Accuracy: {accuracy}")
2962
+ print(f"Precision: {precision}")
2963
+ print(f"Recall: {recall}")
2964
+ print(f"F1 Score: {f1}")
2965
+
2966
+ # Printing class-specific accuracy metrics
2967
+ print("\nClassification Report:")
2968
+ print(classification_report(y_test, predictions_test))
2969
+
2970
+ df = _calculate_similarity(df, features, col_to_compare, pos, neg)
2971
+
2972
+ return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
2973
+
2974
def _shap_analysis(model, X_train, X_test):
    """
    Explain a fitted model with SHAP and show the summary plot.

    Args:
        model: The trained model to explain.
        X_train (pandas.DataFrame): Training feature set, passed to
            ``shap.Explainer`` together with the model.
        X_test (pandas.DataFrame): Testing feature set for which the SHAP
            values are computed and plotted.
    """
    # Build the explainer from the model and the training features, then
    # evaluate it directly on the test features.
    test_shap_values = shap.Explainer(model, X_train)(X_test)

    # Summary plot of the SHAP values for the test set
    shap.summary_plot(test_shap_values, X_test)
2990
+
2991
def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
    """
    Fit a classifier on two groups of wells, explain it and plot a plate heatmap.

    Measurements are read from ``src + '/measurements/measurements.db'``,
    handed to ``_permutation_importance`` and ``_shap_analysis``, and the
    column named by ``variable`` is then drawn with ``_plot_plates``.

    Args:
        src (str): Experiment folder containing ``measurements/measurements.db``.
        model_type (str): Forwarded to ``_permutation_importance``.
        variable (str): Numeric column of the resulting dataframe to plot.
        grouping, min_max, cmap, min_count: Forwarded to ``_plot_plates``.
        channel_of_interest (int or None): When not None, a ``recruitment``
            column (pathogen / cytoplasm mean intensity of that channel) is
            added and ``'channel_{n}'`` is used as the feature string;
            otherwise the feature string is None.
        n_estimators, col_to_compare, pos, neg, exclude, n_repeats, clean,
        nr_to_plot, n_jobs: Forwarded to ``_permutation_importance``.
        verbose (bool): Forwarded to ``_read_and_merge_data``.

    Returns:
        list: ``[output, heatmap]`` where ``output`` is the list returned by
        ``_permutation_importance`` and ``heatmap`` is the value returned by
        ``_plot_plates``.

    Raises:
        ValueError: If ``variable`` is not a numeric column of the resulting
            dataframe.
    """
    from .io import _read_and_merge_data
    from .plot import _plot_plates

    measurement_dbs = [src+'/measurements/measurements.db']
    object_tables = ['cell', 'nucleus', 'pathogen','cytoplasm']

    df, _ = _read_and_merge_data(measurement_dbs,
                                 object_tables,
                                 verbose=verbose,
                                 include_multinucleated=True,
                                 include_multiinfected=2.0,
                                 include_noninfected=True)

    if channel_of_interest is None:
        feature_string = None
    else:
        pathogen_col = f'pathogen_channel_{channel_of_interest}_mean_intensity'
        cytoplasm_col = f'cytoplasm_channel_{channel_of_interest}_mean_intensity'
        df['recruitment'] = df[pathogen_col]/df[cytoplasm_col]
        feature_string = f'channel_{channel_of_interest}'

    output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)

    # Positions in the result list that are used here:
    # 0 = dataframe, 3 = fitted model, 4 = X_train, 5 = X_test.
    scored_df, fitted_model, X_train, X_test = output[0], output[3], output[4], output[5]

    _shap_analysis(fitted_model, X_train, X_test)

    numeric_columns = scored_df.select_dtypes(include=[np.number]).columns.tolist()

    if variable not in numeric_columns:
        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {numeric_columns}")

    heatmap = _plot_plates(scored_df, variable, grouping, min_max, cmap, min_count)
    return [output, heatmap]
3023
+
3024
def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
    """
    Merge the object measurements with the ``png_list`` table of one experiment.

    Args:
        src (str): Experiment folder containing ``measurements/measurements.db``.
        tables (list): Measurement tables passed to ``_read_and_merge_data``.

    Returns:
        pandas.DataFrame: Measurements left-joined with the ``png_list``
        table on the ``prcfo`` column.
    """
    from .io import _read_and_merge_data, _read_db

    db_file = src+'/measurements/measurements.db'

    measurements, _ = _read_and_merge_data([db_file],
                                           tables,
                                           verbose=True,
                                           include_multinucleated=True,
                                           include_multiinfected=True,
                                           include_noninfected=True)

    # Only 'png_list' is requested, so the first element of the result is used.
    png_tables = _read_db(db_file, tables=['png_list'])

    return pd.merge(measurements, png_tables[0], on='prcfo', how='left')
3042
+
3043
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Create a jitter plot of one measurement column grouped by an annotation column.

    The data is loaded with ``join_measurments_and_annotation(src)``. Missing
    values in ``x_column`` are labelled ``'NaN'``, the optional filters are
    applied, rows are restricted to wells (``plate_x``, ``row_x``, ``col_x``)
    that contain at least one annotated row, and every group in ``x_column``
    is down-sampled to the size of the smallest group before plotting.

    Args:
        src (str): Path to the source data.
        x_column (str): Name of the column to be used for the x-axis (grouping).
        y_column (str): Name of the column to be used for the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
        filter_column (str or list, optional): Column name, or list of column
            names, to filter on. Default is None (no filtering).
        filter_values (list, optional): For a single ``filter_column``, the
            values to keep. For a list of columns, one list of values per
            column, in the same order.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        KeyError: If the merged data lacks ``plate_x``, ``row_x`` or ``col_x``.
    """
    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])

    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Replace NaN values with a specific label in x_column
    df[x_column] = df[x_column].fillna('NaN')

    # Filter the DataFrame if filter_column and filter_values are provided
    if filter_column is not None:
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            for i, column in enumerate(filter_column):
                df = df[df[column].isin(filter_values[i])]

    # The merge adds the '_x' suffix to the well identifier columns
    well_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in well_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {well_columns}")

    # Keep every row that belongs to a well holding at least one annotated row
    annotated = df[df[x_column] != 'NaN']
    well_keys = df[well_columns].apply(tuple, axis=1)
    annotated_keys = annotated[well_columns].apply(tuple, axis=1)
    retained_rows = df[well_keys.isin(annotated_keys)]

    # Determine the minimum count of examples across all groups in x_column
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')

    # Randomly sample min_count examples from each group in x_column
    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)

    # Create the jitter plot
    plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotate the x-axis labels and centre them below the data
    ax = plt.gca()
    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center')

    # Save the plot to a file or display it
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
    else:
        plt.show()

    return balanced_df
3114
+
3115
def generate_image_umap(settings=None):
    """
    Generate UMAP or tSNE embedding and visualize the data with clustering.

    The settings are completed by ``get_umap_image_settings``, written to
    ``settings/embedding_settings.csv`` under the first source folder, and
    the results are written to ``results/embedding_results.csv`` (plus
    optional PDF figures and ``results/cluster_results.csv``).

    Parameters:
        settings (dict): Dictionary containing the following keys:
            src (str or list): Source directory (or directories) containing the data.
            row_limit (int): Limit the number of rows to process.
            tables (list): List of table names to read from the database.
            visualize (str): Visualization type.
            image_nr (int): Number of images to display.
            dot_size (int): Size of dots in the scatter plot.
            n_neighbors (int): Number of neighbors for UMAP.
            figuresize (int): Size of the figure.
            black_background (bool): Whether to use a black background.
            remove_image_canvas (bool): Whether to remove the image canvas.
            plot_outlines (bool): Whether to plot outlines.
            plot_points (bool): Whether to plot points.
            smooth_lines (bool): Whether to smooth lines.
            verbose (bool): Whether to print verbose output.
            embedding_by_controls (bool): Whether to use embedding from controls.
            col_to_compare (str): Column to compare for control-based embedding.
            pos (str): Positive control value.
            neg (str): Negative control value.
            clustering (str): Clustering method ('DBSCAN' or 'KMeans').
            exclude (list): List of columns to exclude from the analysis.
            plot_images (bool): Whether to plot images.
            reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
            save_figure (bool): Whether to save the figure as a PDF.
            Further keys read by this function: color_by, mix,
            exclude_conditions, filter_by, remove_highly_correlated,
            log_data, min_dist, metric, eps, min_samples, n_jobs,
            resnet_features, remove_cluster_noise, img_zoom,
            plot_by_cluster, plot_cluster_grids, analyze_clusters.

    Returns:
        pd.DataFrame: DataFrame with the original data and the additional
        columns 'cond' (mapped condition) and 'cluster' (cluster identity, or
        the ``color_by`` values when ``color_by`` is set).
    """

    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings
    from .alpha import cluster_feature_analysis, generate_umap_from_images

    # A fresh dict per call: the settings are modified in place below, so a
    # shared mutable default would leak state between calls.
    if settings is None:
        settings = {}

    settings = get_umap_image_settings(settings)

    if isinstance(settings['src'], str):
        settings['src'] = [settings['src']]

    if settings['plot_images'] is False:
        settings['black_background'] = False

    if settings['color_by']:
        settings['remove_cluster_noise'] = False
        settings['plot_outlines'] = False
        settings['smooth_lines'] = False

    # Persist the effective settings next to the data
    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_dir = os.path.join(settings['src'][0],'settings')
    settings_csv = os.path.join(settings_dir,'embedding_settings.csv')
    os.makedirs(settings_dir, exist_ok=True)
    settings_df.to_csv(settings_csv, index=False)
    display(settings_df)

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables'] + ['png_list']

    # Collect the per-database frames and concatenate once
    frames = []
    for i, db_path in enumerate(db_paths):
        df = _read_and_join_tables(db_path, table_names=tables)
        df, _ = correct_paths(df, settings['src'][i])
        frames.append(df)
    all_df = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        # Clamp so that a limit larger than the data does not raise
        all_df = all_df.sample(n=min(settings['row_limit'], len(all_df)), random_state=42)

    image_paths = all_df['png_path'].to_list()

    if settings['embedding_by_controls']:

        # Extract and reset the index for the column to compare
        group_values = all_df[settings['col_to_compare']].reset_index(drop=True)

        # Preprocess the data to obtain numeric data
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

        # Wrap in a DataFrame so the control rows can be selected by group value
        numeric_data_df = pd.DataFrame(numeric_data).reset_index(drop=True)
        numeric_data_df[settings['col_to_compare']] = group_values

        # Subset the dataframe based on specified column values for controls
        positive_control_df = numeric_data_df[numeric_data_df[settings['col_to_compare']] == settings['pos']].copy()
        negative_control_df = numeric_data_df[numeric_data_df[settings['col_to_compare']] == settings['neg']].copy()
        control_numeric_data_df = pd.concat([positive_control_df, negative_control_df])

        # Drop the comparison column again and convert to a numpy array
        control_numeric_data = control_numeric_data_df.drop(columns=[settings['col_to_compare']]).values

        # Train the reducer on control data
        _, _, reducer = reduction_and_clustering(control_numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode='fit', model=False)

        # Apply the trained reducer to the entire dataset
        # NOTE(review): the preprocessing is repeated here as in the original
        # flow; it could reuse the first result if preprocess_data is pure.
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode=None, model=reducer)

    elif settings['resnet_features']:
        numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])

    else:
        # Fit the reducer on the entire dataset
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'])

    if settings['remove_cluster_noise']:
        # Remove noise from the clusters (removes -1 labels from DBSCAN)
        # NOTE(review): if this drops points, embedding/labels may no longer
        # line up with image_paths and all_df below - confirm.
        embedding, labels = remove_noise(embedding, labels)

    # Colour by a dataframe column instead of the cluster labels when requested
    if settings['color_by']:
        labels = all_df[settings['color_by']]

    # Generate colors for the clusters
    colors = generate_colors(len(np.unique(labels)), settings['black_background'])

    # Plot the embedding
    umap_plt = plot_embedding(embedding, image_paths, labels, settings['image_nr'], settings['img_zoom'], colors, settings['plot_by_cluster'], settings['plot_outlines'], settings['plot_points'], settings['plot_images'], settings['smooth_lines'], settings['black_background'], settings['figuresize'], settings['dot_size'], settings['remove_image_canvas'], settings['verbose'])

    plot_grid = settings['plot_cluster_grids'] and settings['plot_images']
    if plot_grid:
        grid_plt = plot_clusters_grid(embedding, labels, settings['image_nr'], image_paths, colors, settings['figuresize'], settings['black_background'], settings['verbose'])

    results_dir = os.path.join(settings['src'][0], 'results')

    # Save figure as PDF if required
    if settings['save_figure']:
        os.makedirs(results_dir, exist_ok=True)
        reduction_method = settings['reduction_method'].upper()
        embedding_path = os.path.join(results_dir, f'{reduction_method}_embedding.pdf')
        umap_plt.savefig(embedding_path, format='pdf')
        print(f'Saved {reduction_method} embedding to {embedding_path}')
        if plot_grid:
            grid_path = os.path.join(results_dir, f'{reduction_method}_grid.pdf')
            grid_plt.savefig(grid_path, format='pdf')
            print(f'Saved {reduction_method} grid to {grid_path}')

    # Add cluster labels to the dataframe
    all_df['cluster'] = labels

    # Save the results to a CSV file
    results_csv = os.path.join(results_dir,'embedding_results.csv')
    os.makedirs(results_dir, exist_ok=True)
    all_df.to_csv(results_csv, index=False)
    print(f'Results saved to {results_csv}')

    if settings['analyze_clusters']:
        combined_results = cluster_feature_analysis(all_df)
        cluster_results_csv = os.path.join(results_dir,'cluster_results.csv')
        os.makedirs(results_dir, exist_ok=True)
        combined_results.to_csv(cluster_results_csv, index=False)
        print(f'Cluster results saved to {cluster_results_csv}')

    return all_df
3296
+
3297
+ # Define the mapping function
3298
def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
    """
    Translate a column value into a condition name.

    Args:
        col_value: The value to classify.
        neg: Value that maps to 'neg'.
        pos: Value that maps to 'pos'.
        mix: Value that maps to 'mix'.

    Returns:
        str: 'neg', 'pos' or 'mix' for the first matching value (checked in
        that order), otherwise 'screen'.
    """
    # Ordered pairs keep the neg -> pos -> mix precedence when values coincide.
    for reference, condition in ((neg, 'neg'), (pos, 'pos'), (mix, 'mix')):
        if col_value == reference:
            return condition
    return 'screen'
3307
+
3308
def reducer_hyperparameter_search(settings=None, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
    """
    Perform a hyperparameter search for UMAP or tSNE on the given data.

    Every combination of a reduction parameter set and a clustering parameter
    set is embedded, clustered and drawn into one panel of a grid figure
    (rows = reduction parameter sets, columns = clustering parameter sets).

    Parameters:
        settings (dict): Dictionary containing the following keys:
            src (str or list): Source directory (or directories) containing the data.
            row_limit (int): Limit the number of rows to process.
            tables (list): List of table names to read from the database.
            filter_by (str): Column to filter the data.
            remove_highly_correlated (bool): Whether to remove highly correlated columns.
            log_data (bool): Whether to log transform the data.
            verbose (bool): Whether to print verbose output.
            reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
                Overridden when the reduction parameters contain
                'n_neighbors' (UMAP) or 'perplexity' (tSNE).
            dot_size (int): Size of the points in the scatter plot.
            Further keys read by this function: neg, pos, mix,
            exclude_conditions, exclude, metric, n_jobs, color_by.
        reduction_params (dict or list): Dictionary, or list of dictionaries,
            containing hyperparameters to test for the reduction method. Float
            values for 'n_neighbors' / 'perplexity' are treated as a fraction
            of the number of rows.
        dbscan_params (dict or list): Dictionary, or list of dictionaries, containing DBSCAN hyperparameters to test.
        kmeans_params (dict or list): Dictionary, or list of dictionaries, containing KMeans hyperparameters to test.
        save (bool): Whether to save the resulting plot as a file.

    Returns:
        None

    Raises:
        ValueError: If no reduction or clustering parameters are given, if the
            reduction parameters mix 'n_neighbors' and 'perplexity', or if the
            reduction method is neither UMAP nor tSNE.
    """

    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings

    # A fresh dict per call: the settings may be modified below.
    if settings is None:
        settings = {}

    settings = get_umap_image_settings(settings)
    pointsize = settings['dot_size']

    if isinstance(dbscan_params, dict):
        dbscan_params = [dbscan_params]

    if isinstance(kmeans_params, dict):
        kmeans_params = [kmeans_params]

    if isinstance(reduction_params, dict):
        reduction_params = [reduction_params]

    if not reduction_params:
        raise ValueError("reduction_params must be a non-empty dict or list of dicts.")

    # Determine reduction method based on the keys in reduction_params.
    # The mixed case has to be tested first, otherwise it can never trigger.
    has_umap_key = any('n_neighbors' in param for param in reduction_params)
    has_tsne_key = any('perplexity' in param for param in reduction_params)
    if has_umap_key and has_tsne_key:
        raise ValueError("Reduction parameters must include 'n_neighbors' for UMAP or 'perplexity' for tSNE, not both.")
    if has_umap_key:
        reduction_method = 'umap'
    elif has_tsne_key:
        reduction_method = 'tsne'
    else:
        # Neither key given: keep the configured method and rely on defaults
        reduction_method = settings['reduction_method'].lower()

    if reduction_method not in ('umap', 'tsne'):
        raise ValueError(f"Unsupported reduction method: {settings['reduction_method']}. Supported methods are 'UMAP' and 'tSNE'")

    if settings['reduction_method'].lower() != reduction_method:
        settings['reduction_method'] = reduction_method
        print(f'Changed reduction method to {reduction_method} based on the provided parameters.')

    if settings['verbose']:
        display(pd.DataFrame(list(settings.items()), columns=['Key', 'Value']))

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables']

    # Collect the per-database frames and concatenate once
    frames = [_read_and_join_tables(db_path, table_names=tables) for db_path in db_paths]
    all_df = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        # Clamp so that a limit larger than the data does not raise
        all_df = all_df.sample(n=min(settings['row_limit'], len(all_df)), random_state=42)

    numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

    # Combine DBSCAN and KMeans parameters. Copies are tagged with the method
    # so the dictionaries supplied by the caller are left untouched.
    clustering_params = []
    if dbscan_params:
        clustering_params.extend({**param, 'method': 'dbscan'} for param in dbscan_params)
    if kmeans_params:
        clustering_params.extend({**param, 'method': 'kmeans'} for param in kmeans_params)

    if not clustering_params:
        raise ValueError("At least one of dbscan_params or kmeans_params must be provided.")

    print('Testing paramiters:', reduction_params)
    print('Testing clustering paramiters:', clustering_params)

    # Calculate the grid size
    grid_rows = len(reduction_params)
    grid_cols = len(clustering_params)

    fig_width = grid_cols*10
    fig_height = grid_rows*10

    # squeeze=False keeps axs two-dimensional for every grid shape
    fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(fig_width, fig_height), squeeze=False)

    # As before, panels of a single-row or single-column grid get no axes
    hide_axes = grid_rows <= 1 or grid_cols <= 1

    # Iterate through the Cartesian product of reduction and clustering hyperparameters
    for i, reduction_param in enumerate(reduction_params):
        for j, clustering_param in enumerate(clustering_params):
            ax = axs[i, j]
            if hide_axes:
                ax.axis('off')

            # First reduction argument: n_neighbors for UMAP, perplexity for tSNE
            if reduction_method == 'umap':
                neighbourhood = reduction_param.get('n_neighbors', 15)
                min_dist = reduction_param.get('min_dist', 0.1)
            else:
                neighbourhood = reduction_param.get('perplexity', 30)
                min_dist = 0.1

            # A float is interpreted as a fraction of the number of rows
            if isinstance(neighbourhood, float):
                neighbourhood = int(neighbourhood * len(numeric_data))

            # Perform dimensionality reduction and clustering
            embedding, labels = search_reduction_and_clustering(numeric_data, neighbourhood, min_dist, settings['metric'],
                                                                clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
                                                                clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])

            # Plot the results
            if settings['color_by']:
                group_values = all_df[settings['color_by']]
                unique_groups = group_values.unique()
                colors = generate_colors(len(unique_groups), False)
                for group, color in zip(unique_groups, colors):
                    mask = (group_values == group).to_numpy()
                    ax.scatter(embedding[mask, 0], embedding[mask, 1], s=pointsize, label=f"{group}", color=color)
            else:
                unique_labels = np.unique(labels)
                colors = generate_colors(len(unique_labels), False)
                for cluster_label, color in zip(unique_labels, colors):
                    mask = labels == cluster_label
                    ax.scatter(embedding[mask, 0], embedding[mask, 1], s=pointsize, label=f"Cluster {cluster_label}", color=color)

            ax.set_title(f"{settings['reduction_method']} {reduction_param}\n{clustering_param['method']} {clustering_param}")
            ax.legend()

    plt.tight_layout()
    if save:
        # src may be a single folder or a list of folders; use the first one
        src = settings['src']
        first_src = src[0] if isinstance(src, (list, tuple)) else src
        results_dir = os.path.join(first_src, 'results')
        os.makedirs(results_dir, exist_ok=True)
        plt.savefig(os.path.join(results_dir, 'hyperparameter_search.pdf'))
    else:
        plt.show()

    return