spacr 0.0.2__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -1,10 +1,9 @@
1
- import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
1
+ import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
2
2
 
3
- # image and array processing
4
3
  import numpy as np
5
4
  import pandas as pd
6
5
 
7
- import cellpose
6
+ from cellpose import train
8
7
  from cellpose import models as cp_models
9
8
 
10
9
  import statsmodels.formula.api as smf
@@ -14,23 +13,37 @@ from IPython.display import display
14
13
  from multiprocessing import Pool, cpu_count, Value, Lock
15
14
 
16
15
  import seaborn as sns
17
- import matplotlib.pyplot as plt
16
+
18
17
  from skimage.measure import regionprops, label
19
- import skimage.measure as measure
18
+ from skimage.morphology import square
20
19
  from skimage.transform import resize as resizescikit
21
- from sklearn.model_selection import train_test_split
22
20
  from collections import defaultdict
23
- import multiprocessing
24
21
  from torch.utils.data import DataLoader, random_split
22
+ from sklearn.cluster import KMeans
23
+ from sklearn.decomposition import PCA
24
+
25
+ from skimage import measure
26
+ from sklearn.model_selection import train_test_split
27
+ from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
28
+ from sklearn.linear_model import LogisticRegression
29
+ from sklearn.inspection import permutation_importance
30
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
31
+ from sklearn.preprocessing import StandardScaler
32
+
33
+ from scipy.ndimage import binary_dilation
34
+ from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
35
+
36
+ import torchvision.transforms as transforms
37
+ from xgboost import XGBClassifier
38
+ import shap
39
+
40
+ import matplotlib.pyplot as plt
25
41
  import matplotlib
26
42
  matplotlib.use('Agg')
43
+ #import matplotlib.pyplot as plt
27
44
 
28
- import torchvision.transforms as transforms
29
- from sklearn.model_selection import train_test_split
30
- from sklearn.ensemble import IsolationForest
31
45
  from .logger import log_function_call
32
46
 
33
-
34
47
  def analyze_plaques(folder):
35
48
  summary_data = []
36
49
  details_data = []
@@ -67,169 +80,95 @@ def analyze_plaques(folder):
67
80
 
68
81
  print(f"Analysis completed and saved to database '{db_name}'.")
69
82
 
70
- def compare_masks(dir1, dir2, dir3, verbose=False):
71
-
72
- from .io import _read_mask
73
- from .plot import visualize_masks, plot_comparison_results
74
- from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
75
-
76
- filenames = os.listdir(dir1)
77
- results = []
78
- cond_1 = os.path.basename(dir1)
79
- cond_2 = os.path.basename(dir2)
80
- cond_3 = os.path.basename(dir3)
81
- for index, filename in enumerate(filenames):
82
- print(f'Processing image:{index+1}', end='\r', flush=True)
83
- path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
84
- if os.path.exists(path2) and os.path.exists(path3):
85
-
86
- mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
87
- boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
88
-
89
-
90
- true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
91
- true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
92
- average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
93
- ap_scores = [average_precision_0, average_precision_1]
94
-
95
- if verbose:
96
- unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
97
- print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
98
- visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
99
-
100
- boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
101
-
102
- if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
103
- (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
104
- (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
105
- continue
106
-
107
- if verbose:
108
- unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
109
- print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
110
- visualize_masks(mask1, mask2, mask3, title=filename)
111
-
112
- jaccard12 = jaccard_index(mask1, mask2)
113
- dice12 = dice_coefficient(mask1, mask2)
114
- jaccard13 = jaccard_index(mask1, mask3)
115
- dice13 = dice_coefficient(mask1, mask3)
116
- jaccard23 = jaccard_index(mask2, mask3)
117
- dice23 = dice_coefficient(mask2, mask3)
118
-
119
- results.append({
120
- f'filename': filename,
121
- f'jaccard_{cond_1}_{cond_2}': jaccard12,
122
- f'dice_{cond_1}_{cond_2}': dice12,
123
- f'jaccard_{cond_1}_{cond_3}': jaccard13,
124
- f'dice_{cond_1}_{cond_3}': dice13,
125
- f'jaccard_{cond_2}_{cond_3}': jaccard23,
126
- f'dice_{cond_2}_{cond_3}': dice23,
127
- f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
128
- f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
129
- f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
130
- f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
131
- f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
132
- })
133
- else:
134
- print(f'Cannot find {path1} or {path2} or {path3}')
135
- fig = plot_comparison_results(results)
136
- return results, fig
137
-
138
- def generate_cp_masks(settings):
139
-
140
- src = settings['src']
141
- model_name = settings['model_name']
142
- channels = settings['channels']
143
- diameter = settings['diameter']
144
- regex = '.tif'
145
- #flow_threshold = 30
146
- cellprob_threshold = settings['cellprob_threshold']
147
- figuresize = 25
148
- cmap = 'inferno'
149
- verbose = settings['verbose']
150
- plot = settings['plot']
151
- save = settings['save']
152
- custom_model = settings['custom_model']
153
- signal_thresholds = 1000
154
- normalize = settings['normalize']
155
- resize = settings['resize']
156
- target_height = settings['width_height'][1]
157
- target_width = settings['width_height'][0]
158
- rescale = settings['rescale']
159
- resample = settings['resample']
160
- net_avg = settings['net_avg']
161
- invert = settings['invert']
162
- circular = settings['circular']
163
- percentiles = settings['percentiles']
164
- overlay = settings['overlay']
165
- grayscale = settings['grayscale']
166
- flow_threshold = settings['flow_threshold']
167
- batch_size = settings['batch_size']
168
-
169
- dst = os.path.join(src,'masks')
170
- os.makedirs(dst, exist_ok=True)
171
-
172
- identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
173
-
174
83
  def train_cellpose(settings):
175
84
 
176
85
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
177
86
  from .utils import resize_images_and_labels
178
87
 
179
88
  img_src = settings['img_src']
180
- mask_src= settings['mask_src']
181
- secondary_image_dir = None
182
- model_name = settings['model_name']
183
- model_type = settings['model_type']
184
- learning_rate = settings['learning_rate']
185
- weight_decay = settings['weight_decay']
186
- batch_size = settings['batch_size']
187
- n_epochs = settings['n_epochs']
188
- verbose = settings['verbose']
189
- signal_thresholds = settings['signal_thresholds']
190
- channels = settings['channels']
191
- from_scratch = settings['from_scratch']
192
- diameter = settings['diameter']
193
- resize = settings['resize']
194
- rescale = settings['rescale']
195
- normalize = settings['normalize']
196
- target_height = settings['width_height'][1]
197
- target_width = settings['width_height'][0]
198
- circular = settings['circular']
199
- invert = settings['invert']
200
- percentiles = settings['percentiles']
201
- grayscale = settings['grayscale']
89
+ mask_src = os.path.join(img_src, 'masks')
90
+
91
+ model_name = settings.setdefault( 'model_name', '')
92
+
93
+ model_name = settings.setdefault('model_name', 'model_name')
94
+
95
+ model_type = settings.setdefault( 'model_type', 'cyto')
96
+ learning_rate = settings.setdefault( 'learning_rate', 0.01)
97
+ weight_decay = settings.setdefault( 'weight_decay', 1e-05)
98
+ batch_size = settings.setdefault( 'batch_size', 50)
99
+ n_epochs = settings.setdefault( 'n_epochs', 100)
100
+ from_scratch = settings.setdefault( 'from_scratch', False)
101
+ diameter = settings.setdefault( 'diameter', 40)
102
+
103
+ remove_background = settings.setdefault( 'remove_background', False)
104
+ background = settings.setdefault( 'background', 100)
105
+ Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
106
+ verbose = settings.setdefault( 'verbose', False)
107
+
202
108
 
109
+ channels = settings.setdefault( 'channels', [0,0])
110
+ normalize = settings.setdefault( 'normalize', True)
111
+ percentiles = settings.setdefault( 'percentiles', None)
112
+ circular = settings.setdefault( 'circular', False)
113
+ invert = settings.setdefault( 'invert', False)
114
+ resize = settings.setdefault( 'resize', False)
115
+
116
+ if resize:
117
+ target_height = settings['width_height'][1]
118
+ target_width = settings['width_height'][0]
119
+
120
+ grayscale = settings.setdefault( 'grayscale', True)
121
+ rescale = settings.setdefault( 'channels', False)
122
+ test = settings.setdefault( 'test', False)
123
+
124
+ if test:
125
+ test_img_src = os.path.join(os.path.dirname(img_src), 'test')
126
+ test_mask_src = os.path.join(test_img_src, 'mask')
127
+
128
+ test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
203
129
  print(settings)
204
130
 
205
131
  if from_scratch:
206
132
  model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
207
133
  else:
208
- model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
134
+ if resize:
135
+ model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
136
+ else:
137
+ model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
209
138
 
210
139
  model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
211
- os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
140
+ print(model_save_path)
141
+ os.makedirs(model_save_path, exist_ok=True)
212
142
 
213
143
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
214
144
  settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
215
145
  settings_df.to_csv(settings_csv, index=False)
216
146
 
217
- if model_type =='cyto':
218
- if not from_scratch:
219
- model = cp_models.CellposeModel(gpu=True, model_type=model_type)
220
- else:
221
- model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
222
- if model_type !='cyto':
147
+ if from_scratch:
148
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
149
+ else:
223
150
  model = cp_models.CellposeModel(gpu=True, model_type=model_type)
224
151
 
225
-
226
-
227
- if normalize:
228
- images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_dir=img_src, label_dir=mask_src, secondary_image_dir=secondary_image_dir, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
152
+ if normalize:
153
+
154
+ image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
155
+ label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
156
+ images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
229
157
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
158
+
159
+ if test:
160
+ test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
161
+ test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
162
+ test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
163
+ test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
164
+
230
165
  else:
231
166
  images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
232
167
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
168
+
169
+ if test:
170
+ test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
171
+ test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
233
172
 
234
173
  if resize:
235
174
  images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
@@ -248,25 +187,41 @@ def train_cellpose(settings):
248
187
 
249
188
  print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
250
189
  save_every = int(n_epochs/10)
251
- print('cellpose image input dtype', images[0].dtype)
252
- print('cellpose mask input dtype', masks[0].dtype)
253
- # Train the model
254
- model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
255
- train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
256
- train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
257
- channels=cp_channels, #(list of ints (default, None)) – channels to use for training
258
- normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
259
- save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
260
- save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
261
- learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
262
- n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
263
- weight_decay=weight_decay, #(float (default, 0.00001)) –
264
- SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
265
- batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
266
- nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
267
- rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
268
- min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
269
- model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
190
+ if save_every < 10:
191
+ save_every = n_epochs
192
+
193
+ train.train_seg(model.net,
194
+ train_data=images,
195
+ train_labels=masks,
196
+ train_files=image_names,
197
+ train_labels_files=mask_names,
198
+ train_probs=None,
199
+ test_data=test_images,
200
+ test_labels=test_masks,
201
+ test_files=test_image_names,
202
+ test_labels_files=test_mask_names,
203
+ test_probs=None,
204
+ load_files=True,
205
+ batch_size=batch_size,
206
+ learning_rate=learning_rate,
207
+ n_epochs=n_epochs,
208
+ weight_decay=weight_decay,
209
+ momentum=0.9,
210
+ SGD=False,
211
+ channels=cp_channels,
212
+ channel_axis=None,
213
+ #rgb=False,
214
+ normalize=False,
215
+ compute_flows=False,
216
+ save_path=model_save_path,
217
+ save_every=save_every,
218
+ nimg_per_epoch=None,
219
+ nimg_test_per_epoch=None,
220
+ rescale=rescale,
221
+ #scale_range=None,
222
+ #bsize=224,
223
+ min_train_masks=1,
224
+ model_name=model_name)
270
225
 
271
226
  return print(f"Model saved at: {model_save_path}/{model_name}")
272
227
 
@@ -831,15 +786,6 @@ def merge_pred_mes(src,
831
786
 
832
787
  if verbose:
833
788
  _plot_histograms_and_stats(df=joined_df)
834
-
835
- #dv = joined_df.copy()
836
- #if 'prc' not in dv.columns:
837
- #dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
838
- #dv = dv[['pred']].groupby('prc').mean()
839
- #dv.set_index('prc', inplace=True)
840
-
841
- #loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
842
- #dv.to_csv(loc, index=True, header=True, mode='w')
843
789
 
844
790
  return joined_df
845
791
 
@@ -926,30 +872,38 @@ def annotate_results(pred_loc):
926
872
  display(df)
927
873
  return df
928
874
 
929
- def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=None):
875
+ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
930
876
 
931
- from .utils import init_globals, add_images_to_tar
932
-
933
- db_path = os.path.join(src, 'measurements','measurements.db')
877
+ from .utils import initiate_counter, add_images_to_tar
878
+
879
+ db_path = os.path.join(src, 'measurements', 'measurements.db')
934
880
  dst = os.path.join(src, 'datasets')
935
-
936
- global total_images
937
881
  all_paths = []
938
-
882
+
939
883
  # Connect to the database and retrieve the image paths
940
884
  print(f'Reading DataBase: {db_path}')
941
- with sqlite3.connect(db_path) as conn:
942
- cursor = conn.cursor()
943
- if file_type:
944
- cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_type}%",))
945
- else:
946
- cursor.execute("SELECT png_path FROM png_list")
947
- while True:
948
- rows = cursor.fetchmany(1000)
949
- if not rows:
950
- break
951
- all_paths.extend([row[0] for row in rows])
952
-
885
+ try:
886
+ with sqlite3.connect(db_path) as conn:
887
+ cursor = conn.cursor()
888
+ if file_metadata:
889
+ if isinstance(file_metadata, str):
890
+ cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
891
+ else:
892
+ cursor.execute("SELECT png_path FROM png_list")
893
+
894
+ while True:
895
+ rows = cursor.fetchmany(1000)
896
+ if not rows:
897
+ break
898
+ all_paths.extend([row[0] for row in rows])
899
+
900
+ except sqlite3.Error as e:
901
+ print(f"Database error: {e}")
902
+ return
903
+ except Exception as e:
904
+ print(f"Error: {e}")
905
+ return
906
+
953
907
  if isinstance(sample, int):
954
908
  selected_paths = random.sample(all_paths, sample)
955
909
  print(f'Random selection of {len(selected_paths)} paths')
@@ -957,23 +911,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
957
911
  selected_paths = all_paths
958
912
  random.shuffle(selected_paths)
959
913
  print(f'All paths: {len(selected_paths)} paths')
960
-
914
+
961
915
  total_images = len(selected_paths)
962
- print(f'found {total_images} images')
963
-
916
+ print(f'Found {total_images} images')
917
+
964
918
  # Create a temp folder in dst
965
919
  temp_dir = os.path.join(dst, "temp_tars")
966
920
  os.makedirs(temp_dir, exist_ok=True)
967
921
 
968
922
  # Chunking the data
969
- if len(selected_paths) > 10000:
970
- num_procs = cpu_count()-2
971
- chunk_size = len(selected_paths) // num_procs
972
- remainder = len(selected_paths) % num_procs
973
- else:
974
- num_procs = 2
975
- chunk_size = len(selected_paths) // 2
976
- remainder = 0
923
+ num_procs = max(2, cpu_count() - 2)
924
+ chunk_size = len(selected_paths) // num_procs
925
+ remainder = len(selected_paths) % num_procs
977
926
 
978
927
  paths_chunks = []
979
928
  start = 0
@@ -983,45 +932,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
983
932
  start = end
984
933
 
985
934
  temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
986
-
987
- # Initialize the shared objects
988
- counter_ = Value('i', 0)
989
- lock_ = Lock()
990
935
 
991
- ctx = multiprocessing.get_context('spawn')
992
-
993
936
  print(f'Generating temporary tar files in {dst}')
994
-
937
+
938
+ # Initialize shared counter and lock
939
+ counter = Value('i', 0)
940
+ lock = Lock()
941
+
942
+ with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
943
+ pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
944
+
995
945
  # Combine the temporary tar files into a final tar
996
946
  date_name = datetime.date.today().strftime('%y%m%d')
997
- tar_name = f'{date_name}_{experiment}_{file_type}.tar'
947
+ if not file_metadata is None:
948
+ tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
949
+ else:
950
+ tar_name = f'{date_name}_{experiment}.tar'
951
+ tar_name = os.path.join(dst, tar_name)
998
952
  if os.path.exists(tar_name):
999
953
  number = random.randint(1, 100)
1000
- tar_name_2 = f'{date_name}_{experiment}_{file_type}_{number}.tar'
1001
- print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
1002
- tar_name = tar_name_2
1003
-
1004
- # Add the counter and lock to the arguments for pool.map
954
+ tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
955
+ print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
956
+ tar_name = os.path.join(dst, tar_name_2)
957
+
1005
958
  print(f'Merging temporary files')
1006
- #with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1007
- # results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1008
959
 
1009
- with ctx.Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1010
- results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1011
-
1012
- with tarfile.open(os.path.join(dst, tar_name), 'w') as final_tar:
1013
- for tar_path in results:
1014
- with tarfile.open(tar_path, 'r') as t:
1015
- for member in t.getmembers():
1016
- t.extract(member, path=dst)
1017
- final_tar.add(os.path.join(dst, member.name), arcname=member.name)
1018
- os.remove(os.path.join(dst, member.name))
1019
- os.remove(tar_path)
960
+ with tarfile.open(tar_name, 'w') as final_tar:
961
+ for temp_tar_path in temp_tar_files:
962
+ with tarfile.open(temp_tar_path, 'r') as temp_tar:
963
+ for member in temp_tar.getmembers():
964
+ file_obj = temp_tar.extractfile(member)
965
+ final_tar.addfile(member, file_obj)
966
+ os.remove(temp_tar_path)
1020
967
 
1021
968
  # Delete the temp folder
1022
969
  shutil.rmtree(temp_dir)
1023
- print(f"\nSaved {total_images} images to {os.path.join(dst, tar_name)}")
1024
-
970
+ print(f"\nSaved {total_images} images to {tar_name}")
971
+
1025
972
  def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
1026
973
 
1027
974
  from .io import TarImageDataset, DataLoader
@@ -1128,7 +1075,6 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
1128
1075
  torch.cuda.memory.empty_cache()
1129
1076
  return df
1130
1077
 
1131
-
1132
1078
  def generate_training_data_file_list(src,
1133
1079
  target='protein of interest',
1134
1080
  cell_dim=4,
@@ -1257,7 +1203,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1257
1203
 
1258
1204
  db_path = os.path.join(src, 'measurements','measurements.db')
1259
1205
  dst = os.path.join(src, 'datasets', 'training')
1260
-
1206
+
1207
+ if os.path.exists(dst):
1208
+ for i in range(1, 1000):
1209
+ dst = os.path.join(src, 'datasets', f'training_{i}')
1210
+ if not os.path.exists(dst):
1211
+ print(f'Creating new directory for training: {dst}')
1212
+ break
1213
+
1261
1214
  if mode == 'annotation':
1262
1215
  class_paths_ls_2 = []
1263
1216
  class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
@@ -1268,6 +1221,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1268
1221
 
1269
1222
  elif mode == 'metadata':
1270
1223
  class_paths_ls = []
1224
+ class_len_ls = []
1271
1225
  [df] = _read_db(db_loc=db_path, tables=['png_list'])
1272
1226
  df['metadata_based_class'] = pd.NA
1273
1227
  for i, class_ in enumerate(classes):
@@ -1275,7 +1229,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1275
1229
  df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
1276
1230
 
1277
1231
  for class_ in classes:
1232
+ if size == None:
1233
+ c_s = []
1234
+ for c in classes:
1235
+ c_s_t_df = df[df['metadata_based_class'] == c]
1236
+ c_s.append(len(c_s_t_df))
1237
+ print(f'Found {len(c_s_t_df)} images for class {c}')
1238
+ size = min(c_s)
1239
+ print(f'Using the smallest class size: {size}')
1240
+
1278
1241
  class_temp_df = df[df['metadata_based_class'] == class_]
1242
+ class_len_ls.append(len(class_temp_df))
1243
+ print(f'Found {len(class_temp_df)} images for class {class_}')
1279
1244
  class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
1280
1245
  class_paths_ls.append(class_paths_temp)
1281
1246
 
@@ -1332,7 +1297,8 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1332
1297
 
1333
1298
  return
1334
1299
 
1335
- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
1300
+ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
1301
+
1336
1302
  """
1337
1303
  Generate data loaders for training and validation/test datasets.
1338
1304
 
@@ -1349,16 +1315,40 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1349
1315
  - pin_memory (bool): Whether to pin memory for faster data transfer.
1350
1316
  - normalize (bool): Whether to normalize the input images.
1351
1317
  - verbose (bool): Whether to print additional information and show images.
1318
+ - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
1352
1319
 
1353
1320
  Returns:
1354
1321
  - train_loaders (list): List of data loaders for training datasets.
1355
1322
  - val_loaders (list): List of data loaders for validation datasets.
1356
1323
  - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
1357
1324
  """
1358
-
1325
+
1359
1326
  from .io import MyDataset
1360
1327
  from .plot import _imshow
1361
-
1328
+ from torchvision import transforms
1329
+ from torch.utils.data import DataLoader, random_split
1330
+ from collections import defaultdict
1331
+ import os
1332
+ import random
1333
+ from PIL import Image
1334
+ from torchvision.transforms import ToTensor
1335
+ from .utils import SelectChannels
1336
+
1337
+ chans = []
1338
+
1339
+ if 'r' in channels:
1340
+ chans.append(1)
1341
+ if 'g' in channels:
1342
+ chans.append(2)
1343
+ if 'b' in channels:
1344
+ chans.append(3)
1345
+
1346
+ channels = chans
1347
+
1348
+ if verbose:
1349
+ print(f'Training a network on channels: {channels}')
1350
+ print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
1351
+
1362
1352
  plate_to_filenames = defaultdict(list)
1363
1353
  plate_to_labels = defaultdict(list)
1364
1354
  train_loaders = []
@@ -1369,31 +1359,30 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1369
1359
  transform = transforms.Compose([
1370
1360
  transforms.ToTensor(),
1371
1361
  transforms.CenterCrop(size=(image_size, image_size)),
1362
+ SelectChannels(channels),
1372
1363
  transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1373
1364
  else:
1374
1365
  transform = transforms.Compose([
1375
1366
  transforms.ToTensor(),
1376
- transforms.CenterCrop(size=(image_size, image_size))])
1377
-
1367
+ transforms.CenterCrop(size=(image_size, image_size)),
1368
+ SelectChannels(channels)])
1369
+
1378
1370
  if mode == 'train':
1379
1371
  data_dir = os.path.join(src, 'train')
1380
1372
  shuffle = True
1381
- print(f'Generating Train and validation datasets')
1382
-
1373
+ print('Generating Train and validation datasets')
1383
1374
  elif mode == 'test':
1384
1375
  data_dir = os.path.join(src, 'test')
1385
1376
  val_loaders = []
1386
- validation_split=0.0
1377
+ validation_split = 0.0
1387
1378
  shuffle = True
1388
- print(f'Generating test dataset')
1389
-
1379
+ print('Generating test dataset')
1390
1380
  else:
1391
1381
  print(f'mode:{mode} is not valid, use mode = train or test')
1392
1382
  return
1393
-
1383
+
1394
1384
  if train_mode == 'erm':
1395
1385
  data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1396
- #train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1397
1386
  if validation_split > 0:
1398
1387
  train_size = int((1 - validation_split) * len(data))
1399
1388
  val_size = len(data) - train_size
@@ -1450,7 +1439,6 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1450
1439
  images = images.cpu()
1451
1440
  label_strings = [str(label.item()) for label in labels]
1452
1441
  _imshow(images, label_strings, nrow=20, fontsize=12)
1453
-
1454
1442
  elif train_mode == 'irm':
1455
1443
  for plate_name, train_loader in zip(plate_names, train_loaders):
1456
1444
  print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
@@ -1569,15 +1557,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1569
1557
  df = df.dropna(subset=['condition'])
1570
1558
  print(f'After dropping non-annotated wells: {len(df)} rows')
1571
1559
  files = df['file_name'].tolist()
1560
+ print(f'found: {len(files)} files')
1572
1561
  files = [item + '.npy' for item in files]
1573
1562
  random.shuffle(files)
1574
-
1563
+
1564
+ _max = 10**100
1565
+
1566
+ if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
1567
+ filter_min_max = None
1568
+ else:
1569
+ if cell_size_range is None:
1570
+ cell_size_range = [0,_max]
1571
+ if nucleus_size_range is None:
1572
+ nucleus_size_range = [0,_max]
1573
+ if pathogen_size_range is None:
1574
+ pathogen_size_range = [0,_max]
1575
+
1576
+ filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
1577
+
1575
1578
  if plot:
1576
1579
  plot_settings = {'include_noninfected':include_noninfected,
1577
1580
  'include_multiinfected':include_multiinfected,
1578
1581
  'include_multinucleated':include_multinucleated,
1579
1582
  'remove_background':remove_background,
1580
- 'filter_min_max':[[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]],
1583
+ 'filter_min_max':filter_min_max,
1581
1584
  'channel_dims':channel_dims,
1582
1585
  'backgrounds':backgrounds,
1583
1586
  'cell_mask_dim':mask_dims[0],
@@ -1634,31 +1637,225 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1634
1637
  cells,wells = _results_to_csv(src, df, df_well)
1635
1638
  return [cells,wells]
1636
1639
 
1640
def _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge cells that share a parasite or a nucleus, then absorb nucleus-free cells into neighbors.

    Three passes are applied, all on one working label image derived from ``cell_mask``:

    1. Parasite pass: for every parasite object (as labeled by ``label``) that overlaps more
       than one cell, all overlapping cells are merged into the first one if at least one of
       them covers more than ``overlap_threshold`` percent of the parasite's pixels.
    2. Nucleus pass: for every nucleus object that overlaps more than one cell, the cells are
       merged if *each* of them covers more than ``overlap_threshold`` percent of the
       nucleus' pixels.
    3. Perimeter pass: every cell that overlaps no nonzero nucleus pixel is merged into the
       neighboring cell with the largest shared border, provided that border exceeds
       ``perimeter_threshold`` percent of the cell's perimeter.

    Args:
        parasite_mask (ndarray): Mask of parasites; zero is background.
        cell_mask (ndarray): Mask of cells; zero is background. This array is not modified.
        nuclei_mask (ndarray): Mask of nuclei; zero is background.
        overlap_threshold (float): Percentage threshold used by the parasite and nucleus passes.
        perimeter_threshold (float): Percentage threshold used by the perimeter pass.

    Returns:
        ndarray: A relabeled cell mask in which each connected cell has a unique label.
    """
    # label() renumbers connected regions, so its ids generally differ from the raw values in
    # cell_mask. Every merge below is done on this single working array; writing these ids
    # back into cell_mask (a different label space) would merge the wrong cells.
    labeled_cells = label(cell_mask)
    labeled_parasites = label(parasite_mask)
    labeled_nuclei = label(nuclei_mask)

    def _overlapping_cells(object_binary):
        """Return the cell labels under object_binary and each one's share (%) of the object."""
        cell_labels = np.unique(labeled_cells[object_binary])
        cell_labels = cell_labels[cell_labels != 0]
        object_area = np.sum(object_binary)
        percentages = [
            np.sum(object_binary & (labeled_cells == cell_label)) / object_area * 100
            for cell_label in cell_labels
        ]
        return cell_labels, percentages

    def _merge_into_first(cell_labels):
        """Relabel every cell in cell_labels[1:] with cell_labels[0] in the working array."""
        target_label = cell_labels[0]
        for other_label in cell_labels[1:]:
            labeled_cells[labeled_cells == other_label] = target_label

    # Pass 1: one sufficiently overlapping cell is enough to merge all cells under a parasite.
    for parasite_id in range(1, int(labeled_parasites.max(initial=0)) + 1):
        cell_labels, percentages = _overlapping_cells(labeled_parasites == parasite_id)
        if len(cell_labels) > 1 and any(p > overlap_threshold for p in percentages):
            _merge_into_first(cell_labels)

    # Pass 2: every cell must overlap the nucleus sufficiently before they are merged.
    for nucleus_id in range(1, int(labeled_nuclei.max(initial=0)) + 1):
        cell_labels, percentages = _overlapping_cells(labeled_nuclei == nucleus_id)
        if len(cell_labels) > 1 and all(p > overlap_threshold for p in percentages):
            _merge_into_first(cell_labels)

    # Pass 3: renumber so merged cells (touching, same value) become single regions, then
    # absorb nucleus-free cells. Geometry is read from the live array so merges chain.
    labeled_cells = label(labeled_cells)
    for region in regionprops(labeled_cells):
        cell_label = region.label
        cell_binary = labeled_cells == cell_label
        if np.any(nuclei_mask[cell_binary] != 0):
            continue  # cells containing a nucleus are kept as they are

        dilated_cell = binary_dilation(cell_binary, structure=square(3))
        neighbor_labels = np.unique(labeled_cells[dilated_cell])
        neighbor_labels = neighbor_labels[(neighbor_labels != 0) & (neighbor_labels != cell_label)]
        if neighbor_labels.size == 0:
            continue

        # Shared border is approximated by neighbor pixels inside the 1-pixel dilation ring.
        shared_borders = [np.sum((labeled_cells == neighbor) & dilated_cell) for neighbor in neighbor_labels]
        best = int(np.argmax(shared_borders))
        perimeter = regionprops(cell_binary.astype(np.uint8))[0].perimeter
        # A zero perimeter (e.g. a single-pixel cell) makes the percentage unbounded; such
        # cells are always merged, as the original inf-producing division did.
        if perimeter == 0 or shared_borders[best] / perimeter * 100 > perimeter_threshold:
            labeled_cells[cell_binary] = neighbor_labels[best]

    # Final relabel gives every connected cell a unique, contiguous id.
    return label(labeled_cells)
1730
+
1731
def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge and relabel the cell masks in ``cell_folder``, overwriting them in place.

    For every ``.npy`` file in ``parasite_folder``, the files with the same name in
    ``cell_folder`` and ``nuclei_folder`` are loaded. The three masks are passed to
    ``_merge_cells_based_on_parasite_overlap``, and the result overwrites the cell mask file.

    Every file triplet is validated before any file is written. A missing counterpart
    therefore aborts the run without leaving the cell folder half-processed.

    Args:
        parasite_folder (str): Path to the folder containing parasite masks.
        cell_folder (str): Path to the folder containing cell masks (overwritten).
        nuclei_folder (str): Path to the folder containing nuclei masks.
        overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.

    Raises:
        ValueError: If the folders contain different numbers of ``.npy`` files, or if a
            parasite mask has no cell or nuclei mask with the same file name.
    """

    def _list_npy(folder):
        # Sorted for a deterministic processing order.
        return sorted(f for f in os.listdir(folder) if f.endswith('.npy'))

    parasite_files = _list_npy(parasite_folder)
    cell_files = _list_npy(cell_folder)
    nuclei_files = _list_npy(nuclei_folder)

    # Ensure there are matching files in all folders
    if not (len(parasite_files) == len(cell_files) == len(nuclei_files)):
        raise ValueError("The number of files in the folders do not match.")

    # Validate every triplet up front so no cell mask is overwritten before a failure.
    triplets = []
    for file_name in parasite_files:
        cell_path = os.path.join(cell_folder, file_name)
        nuclei_path = os.path.join(nuclei_folder, file_name)
        if not (os.path.exists(cell_path) and os.path.exists(nuclei_path)):
            raise ValueError(f"Corresponding cell or nuclei mask file for {file_name} not found.")
        triplets.append((os.path.join(parasite_folder, file_name), cell_path, nuclei_path))

    for parasite_path, cell_path, nuclei_path in triplets:
        merged_cell_mask = _merge_cells_based_on_parasite_overlap(
            np.load(parasite_path),
            np.load(cell_path),
            np.load(nuclei_path),
            overlap_threshold,
            perimeter_threshold,
        )
        # cell_path already ends in '.npy', so np.save overwrites the original file.
        np.save(cell_path, merged_cell_mask)
1768
+
1769
def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
    """
    Remove outlier objects from label masks by clustering their morphology and intensity.

    Every ``.npy`` mask in ``mask_folder`` is paired with the ``.npy`` image of the same name
    in ``image_folder``. For every object, four features are collected across all files:
    area, mean intensity (measured on ``image[:, :, channel]``), perimeter and eccentricity.
    The features are clustered with KMeans. Objects that are not in the most populated
    cluster over all objects are then removed, and each mask file is overwritten.

    Args:
        mask_folder (str): Folder of label masks (``.npy``); files are overwritten.
        image_folder (str): Folder of images (``.npy``) with the same file names.
        channel (int): Index of the last image axis used for the intensity measurement.
        batch_size (int): Number of files loaded at once.
        n_clusters (int): Number of KMeans clusters.
        plot (bool): If True, show a 2-D PCA scatter plot of the clustered objects.

    Returns:
        None. If fewer objects than ``n_clusters`` are found, nothing is clustered or written.
    """

    def read_files_in_batches(folder, batch_size=50):
        # Sorted so the measuring pass and the rewriting pass visit files in the same order.
        files = sorted(f for f in os.listdir(folder) if f.endswith('.npy'))
        for i in range(0, len(files), batch_size):
            yield files[i:i + batch_size]

    def measure_morphology_and_intensity(mask, image):
        properties = measure.regionprops(mask, intensity_image=image)
        return [{'area': p.area, 'mean_intensity': p.mean_intensity, 'perimeter': p.perimeter, 'eccentricity': p.eccentricity} for p in properties]

    def feature_matrix(properties):
        return np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])

    def plot_clusters(properties, labels):
        import matplotlib.pyplot as plt  # local: only needed when plotting
        data_2d = PCA(n_components=2).fit_transform(feature_matrix(properties))
        plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.title('Object Clustering')
        plt.show()

    # Step 1: Accumulate properties over all files (regionprops order = ascending label).
    all_properties = []
    for batch in read_files_in_batches(mask_folder, batch_size):
        masks = [np.load(os.path.join(mask_folder, file_name)) for file_name in batch]
        images = [np.load(os.path.join(image_folder, file_name))[:, :, channel] for file_name in batch]
        for mask, image in zip(masks, images):
            all_properties.extend(measure_morphology_and_intensity(mask, image))

    # KMeans needs at least n_clusters samples; leave the masks untouched otherwise.
    if len(all_properties) < n_clusters:
        print(f'Found {len(all_properties)} objects, fewer than n_clusters={n_clusters}; masks left unchanged.')
        return

    # Step 2: Perform clustering on accumulated properties
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(feature_matrix(all_properties))
    labels = kmeans.labels_

    if plot:
        # Step 3: Plot clusters using PCA
        plot_clusters(all_properties, labels)

    # Step 4: Keep only objects in the globally largest cluster. A per-mask majority would
    # crash on masks without objects and keep outliers wherever they dominate one image.
    largest_cluster_label = np.bincount(labels).argmax()
    label_index = 0
    for batch in read_files_in_batches(mask_folder, batch_size):
        for file_name in batch:
            mask_path = os.path.join(mask_folder, file_name)
            mask = np.load(mask_path)
            regions = measure.regionprops(mask)
            mask_labels = labels[label_index:label_index + len(regions)]
            label_index += len(regions)
            # Match objects to cluster labels by position, not by label value, so
            # non-contiguous mask labels are handled correctly.
            keep = [region.label for region, cluster in zip(regions, mask_labels) if cluster == largest_cluster_label]
            cleaned_mask = mask.copy()
            cleaned_mask[~np.isin(mask, keep)] = 0
            np.save(mask_path, cleaned_mask)
1841
+
1637
1842
  def preprocess_generate_masks(src, settings={}):
1638
1843
 
1639
1844
  from .io import preprocess_img_data, _load_and_concatenate_arrays
1640
1845
  from .plot import plot_merged, plot_arrays
1641
- from .utils import _pivot_counts_table
1642
-
1643
- settings['fps'] = 2
1644
- settings['remove_background'] = True
1645
- settings['lower_quantile'] = 0.02
1646
- settings['merge'] = False
1647
- settings['normalize_plots'] = True
1648
- settings['all_to_mip'] = False
1649
- settings['pick_slice'] = False
1650
- settings['skip_mode'] = src
1651
- settings['workers'] = os.cpu_count()-4
1652
- settings['verbose'] = True
1653
- settings['examples_to_plot'] = 1
1654
- settings['src'] = src
1655
- settings['upscale'] = False
1656
- settings['upscale_factor'] = 2.0
1846
+ from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
1847
+
1848
+ settings = set_default_settings_preprocess_generate_masks(src, settings)
1657
1849
 
1658
1850
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
1659
1851
  settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
1660
1852
  os.makedirs(os.path.join(src,'settings'), exist_ok=True)
1661
1853
  settings_df.to_csv(settings_csv, index=False)
1854
+
1855
+ if not settings['pathogen_channel'] is None:
1856
+ custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
1857
+ if settings['pathogen_model'] not in custom_model_ls:
1858
+ ValueError(f'Pathogen model must be {custom_model_ls} or None')
1662
1859
 
1663
1860
  if settings['timelapse']:
1664
1861
  settings['randomize'] = False
@@ -1667,24 +1864,50 @@ def preprocess_generate_masks(src, settings={}):
1667
1864
  if not settings['masks']:
1668
1865
  print(f'WARNING: channels for mask generation are defined when preprocess = True')
1669
1866
 
1670
- if isinstance(settings['merge'], bool):
1671
- settings['merge'] = [settings['merge']]*3
1672
1867
  if isinstance(settings['save'], bool):
1673
1868
  settings['save'] = [settings['save']]*3
1674
1869
 
1870
+ if settings['verbose']:
1871
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
1872
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
1873
+ display(settings_df)
1874
+
1875
+ if settings['test_mode']:
1876
+ print(f'Starting Test mode ...')
1877
+
1675
1878
  if settings['preprocess']:
1676
1879
  settings, src = preprocess_img_data(settings)
1677
1880
 
1678
1881
  if settings['masks']:
1679
1882
  mask_src = os.path.join(src, 'norm_channel_stack')
1680
1883
  if settings['cell_channel'] != None:
1681
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='cell')
1884
+ if check_mask_folder(src, 'cell_mask_stack'):
1885
+ generate_cellpose_masks(mask_src, settings, 'cell')
1682
1886
 
1683
1887
  if settings['nucleus_channel'] != None:
1684
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='nucleus')
1888
+ if check_mask_folder(src, 'nucleus_mask_stack'):
1889
+ generate_cellpose_masks(mask_src, settings, 'nucleus')
1685
1890
 
1686
1891
  if settings['pathogen_channel'] != None:
1687
- generate_cellpose_masks(src=mask_src, settings=settings, object_type='pathogen')
1892
+ if check_mask_folder(src, 'pathogen_mask_stack'):
1893
+ generate_cellpose_masks(mask_src, settings, 'pathogen')
1894
+
1895
+ if settings['adjust_cells']:
1896
+ if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
1897
+
1898
+ start = time.time()
1899
+ cell_folder = os.path.join(mask_src, 'cell_mask_stack')
1900
+ nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
1901
+ parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
1902
+ #image_folder = os.path.join(src, 'stack')
1903
+
1904
+ #process_masks(cell_folder, image_folder, settings['cell_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1905
+ #process_masks(nuclei_folder, image_folder, settings['nucleus_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1906
+ #process_masks(parasite_folder, image_folder, settings['pathogen_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
1907
+
1908
+ adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
1909
+ stop = time.time()
1910
+ print(f'Cell mask adjustment: {stop-start} seconds')
1688
1911
 
1689
1912
  if os.path.exists(os.path.join(src,'measurements')):
1690
1913
  _pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
@@ -1713,60 +1936,110 @@ def preprocess_generate_masks(src, settings={}):
1713
1936
  overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
1714
1937
  overlay_channels = [element for element in overlay_channels if element is not None]
1715
1938
 
1716
- plot_settings = {'include_noninfected':True,
1717
- 'include_multiinfected':True,
1718
- 'include_multinucleated':True,
1719
- 'remove_background':False,
1720
- 'filter_min_max':None,
1721
- 'channel_dims':settings['channels'],
1722
- 'backgrounds':[100,100,100,100],
1723
- 'cell_mask_dim':cell_mask_dim,
1724
- 'nucleus_mask_dim':nucleus_mask_dim,
1725
- 'pathogen_mask_dim':pathogen_mask_dim,
1726
- 'overlay_chans':[0,2,3],
1727
- 'outline_thickness':3,
1728
- 'outline_color':'gbr',
1729
- 'overlay_chans':overlay_channels,
1730
- 'overlay':True,
1731
- 'normalization_percentiles':[1,99],
1732
- 'normalize':True,
1733
- 'print_object_number':True,
1734
- 'nr':settings['examples_to_plot'],
1735
- 'figuresize':20,
1736
- 'cmap':'inferno',
1737
- 'verbose':False}
1939
+ plot_settings = set_default_plot_merge_settings()
1940
+ plot_settings['channel_dims'] = settings['channels']
1941
+ plot_settings['cell_mask_dim'] = cell_mask_dim
1942
+ plot_settings['nucleus_mask_dim'] = nucleus_mask_dim
1943
+ plot_settings['pathogen_mask_dim'] = pathogen_mask_dim
1944
+ plot_settings['overlay_chans'] = overlay_channels
1945
+ plot_settings['nr'] = settings['examples_to_plot']
1946
+
1947
+ if settings['test_mode'] == True:
1948
+ plot_settings['nr'] = len(os.path.join(src,'merged'))
1949
+
1738
1950
  try:
1739
1951
  fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
1740
1952
  except Exception as e:
1741
1953
  print(f'Failed to plot image mask overly. Error: {e}')
1742
1954
  else:
1743
- plot_arrays(src=os.path.join(src,'merged'), figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99)
1955
+ plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
1744
1956
 
1745
1957
  torch.cuda.empty_cache()
1746
1958
  gc.collect()
1747
1959
  print("Successfully completed run")
1748
1960
  return
1749
1961
 
1750
- def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', verbose=False, plot=False, save=False, custom_model=None, signal_thresholds=1000, normalize=True, resize=False, target_height=None, target_width=None, rescale=True, resample=True, net_avg=False, invert=False, circular=False, percentiles=None, overlay=True, grayscale=False):
1962
+ def identify_masks_finetune(settings):
1751
1963
 
1752
1964
  from .plot import print_mask_and_flows
1753
1965
  from .utils import get_files_from_dir, resize_images_and_labels
1754
1966
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
1755
1967
 
1968
+ #User defined settings
1969
+ src=settings['src']
1970
+ dst=settings['dst']
1971
+
1972
+
1973
+ settings.setdefault('model_name', 'cyto')
1974
+ settings.setdefault('custom_model', None)
1975
+ settings.setdefault('channels', [0,0])
1976
+ settings.setdefault('background', 100)
1977
+ settings.setdefault('remove_background', False)
1978
+ settings.setdefault('Signal_to_noise', 10)
1979
+ settings.setdefault('CP_prob', 0)
1980
+ settings.setdefault('diameter', 30)
1981
+ settings.setdefault('batch_size', 50)
1982
+ settings.setdefault('flow_threshold', 0.4)
1983
+ settings.setdefault('save', False)
1984
+ settings.setdefault('verbose', False)
1985
+ settings.setdefault('normalize', True)
1986
+ settings.setdefault('percentiles', None)
1987
+ settings.setdefault('circular', False)
1988
+ settings.setdefault('invert', False)
1989
+ settings.setdefault('resize', False)
1990
+ settings.setdefault('target_height', None)
1991
+ settings.setdefault('target_width', None)
1992
+ settings.setdefault('rescale', False)
1993
+ settings.setdefault('resample', False)
1994
+ settings.setdefault('grayscale', True)
1995
+
1996
+
1997
+ model_name=settings['model_name']
1998
+ custom_model=settings['custom_model']
1999
+ channels = settings['channels']
2000
+ background = settings['background']
2001
+ remove_background=settings['remove_background']
2002
+ Signal_to_noise = settings['Signal_to_noise']
2003
+ CP_prob = settings['CP_prob']
2004
+ diameter=settings['diameter']
2005
+ batch_size=settings['batch_size']
2006
+ flow_threshold=settings['flow_threshold']
2007
+ save=settings['save']
2008
+ verbose=settings['verbose']
2009
+
2010
+ # static settings
2011
+ normalize = settings['normalize']
2012
+ percentiles = settings['percentiles']
2013
+ circular = settings['circular']
2014
+ invert = settings['invert']
2015
+ resize = settings['resize']
2016
+
2017
+ if resize:
2018
+ target_height = settings['target_height']
2019
+ target_width = settings['target_width']
2020
+
2021
+ rescale = settings['rescale']
2022
+ resample = settings['resample']
2023
+ grayscale = settings['grayscale']
2024
+
2025
+ os.makedirs(dst, exist_ok=True)
2026
+
2027
+ if not custom_model is None:
2028
+ if not os.path.exists(custom_model):
2029
+ print(f'Custom model not found: {custom_model}')
2030
+ return
2031
+
1756
2032
  if not torch.cuda.is_available():
1757
2033
  print(f'Torch CUDA is not available, using CPU')
1758
2034
 
1759
2035
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1760
2036
 
1761
2037
  if custom_model == None:
1762
- if model_name =='cyto':
1763
- model = cp_models.CellposeModel(gpu=True, model_type=model_name, net_avg=False, diam_mean=diameter, pretrained_model=None)
1764
- else:
1765
- model = cp_models.CellposeModel(gpu=True, model_type=model_name)
1766
-
1767
- if custom_model != None:
1768
- model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
1769
- print(f'loaded custom model:{custom_model}')
2038
+ model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
2039
+ print(f'Loaded model: {model_name}')
2040
+ else:
2041
+ model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
2042
+ print("Pretrained Model Loaded:", model.pretrained_model)
1770
2043
 
1771
2044
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
1772
2045
 
@@ -1776,16 +2049,18 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1776
2049
  print(f'Using channels: {chans} for model of type {model_name}')
1777
2050
 
1778
2051
  if verbose == True:
1779
- print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
2052
+ print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
1780
2053
 
1781
- all_image_files = get_files_from_dir(src, file_extension="*.tif")
2054
+ all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
2055
+
1782
2056
  random.shuffle(all_image_files)
1783
2057
 
1784
2058
  time_ls = []
1785
2059
  for i in range(0, len(all_image_files), batch_size):
1786
2060
  image_files = all_image_files[i:i+batch_size]
2061
+
1787
2062
  if normalize:
1788
- images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
2063
+ images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
1789
2064
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
1790
2065
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
1791
2066
  else:
@@ -1803,11 +2078,10 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1803
2078
  channel_axis=3,
1804
2079
  diameter=diameter,
1805
2080
  flow_threshold=flow_threshold,
1806
- cellprob_threshold=cellprob_threshold,
2081
+ cellprob_threshold=CP_prob,
1807
2082
  rescale=rescale,
1808
2083
  resample=resample,
1809
- net_avg=net_avg,
1810
- progress=False)
2084
+ progress=True)
1811
2085
 
1812
2086
  if len(output) == 4:
1813
2087
  mask, flows, _, _ = output
@@ -1825,11 +2099,12 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1825
2099
  time_ls.append(duration)
1826
2100
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1827
2101
  print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
1828
- if plot:
2102
+ if verbose:
1829
2103
  if resize:
1830
2104
  stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
1831
- print_mask_and_flows(stack, mask, flows, overlay=overlay)
2105
+ print_mask_and_flows(stack, mask, flows, overlay=True)
1832
2106
  if save:
2107
+ os.makedirs(dst, exist_ok=True)
1833
2108
  output_filename = os.path.join(dst, image_names[file_index])
1834
2109
  cv2.imwrite(output_filename, mask)
1835
2110
  return
@@ -1882,7 +2157,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1882
2157
 
1883
2158
  #Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
1884
2159
  gc.collect()
1885
- #print('========== generating masks ==========')
1886
2160
 
1887
2161
  if not torch.cuda.is_available():
1888
2162
  print(f'Torch CUDA is not available, using CPU')
@@ -1972,8 +2246,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1972
2246
  stitch_threshold=0.0
1973
2247
 
1974
2248
  cellpose_batch_size = _get_cellpose_batch_size()
1975
-
1976
- #model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
1977
2249
 
1978
2250
  masks, flows, _, _ = model.eval(x=batch,
1979
2251
  batch_size=cellpose_batch_size,
@@ -2047,9 +2319,21 @@ def all_elements_match(list1, list2):
2047
2319
  # Check if all elements in list1 are in list2
2048
2320
  return all(element in list2 for element in list1)
2049
2321
 
2050
- def generate_cellpose_masks_v1(src, settings, object_type):
2322
def prepare_batch_for_cellpose(batch):
    """
    Return a float32 copy of ``batch`` with each image scaled to a maximum of 1.

    Images along the first axis whose maximum exceeds 1 are divided by that maximum. Images
    whose maximum is already <= 1 (or NaN) are left unscaled. The input array is never
    modified: a float32 input used to be normalized in place, mutating the caller's data.

    Args:
        batch (ndarray): Array whose first axis indexes images.

    Returns:
        ndarray: A new float32 array with the same shape as ``batch``.
    """
    # Always copy so the caller's array (possibly a view of a loaded stack) is untouched.
    batch = np.array(batch, dtype=np.float32, copy=True)
    if batch.size == 0:
        return batch

    # Per-image maxima, then divide only the images that exceed 1 (float32 division, as before).
    maxima = batch.reshape(batch.shape[0], -1).max(axis=1)
    needs_scaling = maxima > 1
    if needs_scaling.any():
        broadcast = (slice(None),) + (None,) * (batch.ndim - 1)
        batch[needs_scaling] /= maxima[needs_scaling][broadcast]
    return batch
2333
+
2334
+ def generate_cellpose_masks(src, settings, object_type):
2051
2335
 
2052
- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, mask_object_count
2336
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
2053
2337
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2054
2338
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2055
2339
  from .plot import plot_masks
@@ -2057,6 +2341,13 @@ def generate_cellpose_masks_v1(src, settings, object_type):
2057
2341
  gc.collect()
2058
2342
  if not torch.cuda.is_available():
2059
2343
  print(f'Torch CUDA is not available, using CPU')
2344
+
2345
+ settings = set_default_settings_preprocess_generate_masks(src, settings)
2346
+
2347
+ if settings['verbose']:
2348
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
2349
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
2350
+ display(settings_df)
2060
2351
 
2061
2352
  figuresize=25
2062
2353
  timelapse = settings['timelapse']
@@ -2071,23 +2362,26 @@ def generate_cellpose_masks_v1(src, settings, object_type):
2071
2362
 
2072
2363
  batch_size = settings['batch_size']
2073
2364
  cellprob_threshold = settings[f'{object_type}_CP_prob']
2074
- flow_threshold = 30
2075
-
2365
+
2366
+ flow_threshold = settings[f'{object_type}_FT']
2367
+
2076
2368
  object_settings = _get_object_settings(object_type, settings)
2077
2369
  model_name = object_settings['model_name']
2078
2370
 
2079
2371
  cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2080
2372
  if settings['verbose']:
2081
2373
  print(cellpose_channels)
2374
+
2082
2375
  channels = cellpose_channels[object_type]
2083
2376
  cellpose_batch_size = _get_cellpose_batch_size()
2084
-
2085
2377
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2086
- model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
2087
- #dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
2088
2378
 
2379
+ if object_type == 'pathogen' and not settings['pathogen_model'] is None:
2380
+ model_name = settings['pathogen_model']
2381
+
2382
+ model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)
2383
+
2089
2384
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
2090
-
2091
2385
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
2092
2386
 
2093
2387
  count_loc = os.path.dirname(src)+'/measurements/measurements.db'
@@ -2096,7 +2390,6 @@ def generate_cellpose_masks_v1(src, settings, object_type):
2096
2390
 
2097
2391
  average_sizes = []
2098
2392
  time_ls = []
2099
-
2100
2393
  for file_index, path in enumerate(paths):
2101
2394
  name = os.path.basename(path)
2102
2395
  name, ext = os.path.splitext(name)
@@ -2106,6 +2399,14 @@ def generate_cellpose_masks_v1(src, settings, object_type):
2106
2399
  with np.load(path) as data:
2107
2400
  stack = data['data']
2108
2401
  filenames = data['filenames']
2402
+
2403
+ for i, filename in enumerate(filenames):
2404
+ output_path = os.path.join(output_folder, filename)
2405
+
2406
+ if os.path.exists(output_path):
2407
+ print(f"File {filename} already exists in the output folder. Skipping...")
2408
+ continue
2409
+
2109
2410
  if settings['timelapse']:
2110
2411
 
2111
2412
  trackable_objects = ['cell','nucleus','pathogen']
@@ -2140,31 +2441,43 @@ def generate_cellpose_masks_v1(src, settings, object_type):
2140
2441
  if batch.size == 0:
2141
2442
  print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2142
2443
  continue
2143
- if batch.max() > 1:
2144
- batch = batch / batch.max()
2444
+
2445
+ batch = prepare_batch_for_cellpose(batch)
2145
2446
 
2146
2447
  if timelapse:
2147
- stitch_threshold=100.0
2148
2448
  movie_path = os.path.join(os.path.dirname(src), 'movies')
2149
2449
  os.makedirs(movie_path, exist_ok=True)
2150
2450
  save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
2151
2451
  _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2152
- else:
2153
- stitch_threshold=0.0
2154
-
2155
- print('batch.shape',batch.shape)
2156
- masks, flows, _, _ = model.eval(x=batch,
2157
- batch_size=cellpose_batch_size,
2158
- normalize=False,
2159
- channels=chans,
2160
- channel_axis=3,
2161
- diameter=object_settings['diameter'],
2162
- flow_threshold=flow_threshold,
2163
- cellprob_threshold=cellprob_threshold,
2164
- rescale=None,
2165
- resample=object_settings['resample'],
2166
- stitch_threshold=stitch_threshold)
2167
2452
 
2453
+ if settings['verbose']:
2454
+ print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2455
+
2456
+ #cellpose_normalize_dict = {'lowhigh':[0.0,1.0], #pass in normalization values for 0.0 and 1.0 as list [low, high] if None all other keys ignored
2457
+ # 'sharpen':object_settings['diameter']/4, #recommended to be 1/4-1/8 diameter of cells in pixels
2458
+ # 'normalize':True, #(if False, all following parameters ignored)
2459
+ # 'percentile':[2,98], #[perc_low, perc_high]
2460
+ # 'tile_norm':224, #normalize by tile set to e.g. 100 for normailize window to be 100 px
2461
+ # 'norm3D':True} #compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
2462
+
2463
+ output = model.eval(x=batch,
2464
+ batch_size=cellpose_batch_size,
2465
+ normalize=False,
2466
+ channels=chans,
2467
+ channel_axis=3,
2468
+ diameter=object_settings['diameter'],
2469
+ flow_threshold=flow_threshold,
2470
+ cellprob_threshold=cellprob_threshold,
2471
+ rescale=None,
2472
+ resample=object_settings['resample'])
2473
+
2474
+ if len(output) == 4:
2475
+ masks, flows, _, _ = output
2476
+ elif len(output) == 3:
2477
+ masks, flows, _ = output
2478
+ else:
2479
+ raise ValueError(f"Unexpected number of return values from model.eval(). Expected 3 or 4, got {len(output)}")
2480
+
2168
2481
  if timelapse:
2169
2482
  if settings['plot']:
2170
2483
  for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
@@ -2210,23 +2523,45 @@ def generate_cellpose_masks_v1(src, settings, object_type):
2210
2523
  mode=timelapse_mode)
2211
2524
  else:
2212
2525
  mask_stack = _masks_to_masks_stack(masks)
2213
-
2214
2526
  else:
2215
2527
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2216
- mask_stack = _filter_cp_masks(masks=masks,
2217
- flows=flows,
2218
- filter_size=object_settings['filter_size'],
2219
- filter_intensity=object_settings['filter_intensity'],
2220
- minimum_size=object_settings['minimum_size'],
2221
- maximum_size=object_settings['maximum_size'],
2222
- remove_border_objects=object_settings['remove_border_objects'],
2223
- merge=False,
2224
- batch=batch,
2225
- plot=settings['plot'],
2226
- figuresize=figuresize)
2227
-
2228
- _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2528
+ if object_settings['merge'] and not settings['filter']:
2529
+ mask_stack = _filter_cp_masks(masks=masks,
2530
+ flows=flows,
2531
+ filter_size=False,
2532
+ filter_intensity=False,
2533
+ minimum_size=object_settings['minimum_size'],
2534
+ maximum_size=object_settings['maximum_size'],
2535
+ remove_border_objects=False,
2536
+ merge=object_settings['merge'],
2537
+ batch=batch,
2538
+ plot=settings['plot'],
2539
+ figuresize=figuresize)
2540
+
2541
+ if settings['filter']:
2542
+ mask_stack = _filter_cp_masks(masks=masks,
2543
+ flows=flows,
2544
+ filter_size=object_settings['filter_size'],
2545
+ filter_intensity=object_settings['filter_intensity'],
2546
+ minimum_size=object_settings['minimum_size'],
2547
+ maximum_size=object_settings['maximum_size'],
2548
+ remove_border_objects=object_settings['remove_border_objects'],
2549
+ merge=object_settings['merge'],
2550
+ batch=batch,
2551
+ plot=settings['plot'],
2552
+ figuresize=figuresize)
2553
+
2554
+ _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2555
+ else:
2556
+ mask_stack = _masks_to_masks_stack(masks)
2229
2557
 
2558
+ if settings['plot']:
2559
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2560
+ if idx == 0:
2561
+ num_objects = mask_object_count(mask)
2562
+ print(f'Number of objects, : {num_objects}')
2563
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2564
+
2230
2565
  if not np.any(mask_stack):
2231
2566
  average_obj_size = 0
2232
2567
  else:
@@ -2255,207 +2590,883 @@ def generate_cellpose_masks_v1(src, settings, object_type):
2255
2590
  torch.cuda.empty_cache()
2256
2591
  return
2257
2592
 
2258
- def generate_cellpose_masks(src, settings, object_type):
2593
+ def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
2259
2594
 
2260
- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
2261
- from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2262
- from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2263
- from .plot import plot_masks
2595
+ from .io import _load_images_and_labels, _load_normalized_images_and_labels
2596
+ from .utils import resize_images_and_labels, resizescikit
2597
+ from .plot import print_mask_and_flows
2598
+
2599
+ dst = os.path.join(src, model_name)
2600
+ os.makedirs(dst, exist_ok=True)
2601
+
2602
+ chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
2603
+
2604
+ if grayscale:
2605
+ chans=[0, 0]
2264
2606
 
2265
- gc.collect()
2266
- if not torch.cuda.is_available():
2267
- print(f'Torch CUDA is not available, using CPU')
2607
+ all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
2608
+ random.shuffle(all_image_files)
2268
2609
 
2269
- figuresize=25
2270
- timelapse = settings['timelapse']
2271
-
2272
- if timelapse:
2273
- timelapse_displacement = settings['timelapse_displacement']
2274
- timelapse_frame_limits = settings['timelapse_frame_limits']
2275
- timelapse_memory = settings['timelapse_memory']
2276
- timelapse_remove_transient = settings['timelapse_remove_transient']
2277
- timelapse_mode = settings['timelapse_mode']
2278
- timelapse_objects = settings['timelapse_objects']
2279
-
2280
- batch_size = settings['batch_size']
2281
- cellprob_threshold = settings[f'{object_type}_CP_prob']
2282
- flow_threshold = 30
2610
+ if verbose == True:
2611
+ print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
2283
2612
 
2284
- object_settings = _get_object_settings(object_type, settings)
2285
- model_name = object_settings['model_name']
2613
+ time_ls = []
2614
+ for i in range(0, len(all_image_files), batch_size):
2615
+ image_files = all_image_files[i:i+batch_size]
2616
+
2617
+ if normalize:
2618
+ images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
2619
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2620
+ orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2621
+ else:
2622
+ images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)
2623
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2624
+ orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2625
+ if resize:
2626
+ images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
2627
+
2628
+ for file_index, stack in enumerate(images):
2629
+ start = time.time()
2630
+ output = model.eval(x=stack,
2631
+ normalize=False,
2632
+ channels=chans,
2633
+ channel_axis=3,
2634
+ diameter=diameter,
2635
+ flow_threshold=flow_threshold,
2636
+ cellprob_threshold=cellprob_threshold,
2637
+ rescale=False,
2638
+ resample=False,
2639
+ progress=False)
2640
+
2641
+ if len(output) == 4:
2642
+ mask, flows, _, _ = output
2643
+ elif len(output) == 3:
2644
+ mask, flows, _ = output
2645
+ else:
2646
+ raise ValueError("Unexpected number of return values from model.eval()")
2647
+
2648
+ if resize:
2649
+ dims = orig_dims[file_index]
2650
+ mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
2651
+
2652
+ stop = time.time()
2653
+ duration = (stop - start)
2654
+ time_ls.append(duration)
2655
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2656
+ print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
2657
+ if plot:
2658
+ if resize:
2659
+ stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
2660
+ print_mask_and_flows(stack, mask, flows, overlay=True)
2661
+ if save:
2662
+ output_filename = os.path.join(dst, image_names[file_index])
2663
+ cv2.imwrite(output_filename, mask)
2664
+
2665
+
2666
+ def check_cellpose_models(settings):
2286
2667
 
2287
- cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2668
+ src = settings['src']
2669
+ settings.setdefault('batch_size', 10)
2670
+ settings.setdefault('CP_prob', 0)
2671
+ settings.setdefault('flow_threshold', 0.4)
2672
+ settings.setdefault('save', True)
2673
+ settings.setdefault('normalize', True)
2674
+ settings.setdefault('channels', [0,0])
2675
+ settings.setdefault('percentiles', None)
2676
+ settings.setdefault('circular', False)
2677
+ settings.setdefault('invert', False)
2678
+ settings.setdefault('plot', True)
2679
+ settings.setdefault('diameter', 40)
2680
+ settings.setdefault('grayscale', True)
2681
+ settings.setdefault('remove_background', False)
2682
+ settings.setdefault('background', 100)
2683
+ settings.setdefault('Signal_to_noise', 5)
2684
+ settings.setdefault('verbose', False)
2685
+ settings.setdefault('resize', False)
2686
+ settings.setdefault('target_height', None)
2687
+ settings.setdefault('target_width', None)
2688
+
2288
2689
  if settings['verbose']:
2289
- print(cellpose_channels)
2690
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
2691
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
2692
+ display(settings_df)
2290
2693
 
2291
- channels = cellpose_channels[object_type]
2292
- cellpose_batch_size = _get_cellpose_batch_size()
2694
+ cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
2293
2695
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2294
- model = _choose_model(model_name, device, object_type='cell', restore_type=None)
2295
- chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
2296
- paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
2297
2696
 
2298
- count_loc = os.path.dirname(src)+'/measurements/measurements.db'
2299
- os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
2300
- _create_database(count_loc)
2697
+ for model_name in cellpose_models:
2698
+
2699
+ model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
2700
+ print(f'Using {model_name}')
2701
+ generate_masks_from_imgs(src, model, model_name, settings['batch_size'], settings['diameter'], settings['CP_prob'], settings['flow_threshold'], settings['grayscale'], settings['save'], settings['normalize'], settings['channels'], settings['percentiles'], settings['circular'], settings['invert'], settings['plot'], settings['resize'], settings['target_height'], settings['target_width'], settings['remove_background'], settings['background'], settings['Signal_to_noise'], settings['verbose'])
2702
+
2703
+ return
2704
+
2705
+ def save_results_and_figure(src, fig, results):
2706
+
2707
+ if not isinstance(results, pd.DataFrame):
2708
+ results = pd.DataFrame(results)
2709
+
2710
+ results_dir = os.path.join(src, 'results')
2711
+ os.makedirs(results_dir, exist_ok=True)
2712
+ results_path = os.path.join(results_dir,f'results.csv')
2713
+ fig_path = os.path.join(results_dir, f'model_comparison_plot.pdf')
2714
+ results.to_csv(results_path, index=False)
2715
+ fig.savefig(fig_path, format='pdf')
2716
+ print(f'Saved figure to {fig_path} and results to {results_path}')
2717
+
2718
+ def compare_mask(args):
2719
+ src, filename, dirs, conditions = args
2720
+ paths = [os.path.join(d, filename) for d in dirs]
2721
+
2722
+ if not all(os.path.exists(path) for path in paths):
2723
+ return None
2724
+
2725
+ from .io import _read_mask # Import here to avoid issues in multiprocessing
2726
+ from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
2727
+ from .plot import plot_comparison_results
2728
+
2729
+ masks = [_read_mask(path) for path in paths]
2730
+ file_results = {'filename': filename}
2731
+
2732
+ for i in range(len(masks)):
2733
+ for j in range(i + 1, len(masks)):
2734
+ mask_i, mask_j = masks[i], masks[j]
2735
+ f1_score = boundary_f1_score(mask_i, mask_j)
2736
+ jac_index = jaccard_index(mask_i, mask_j)
2737
+ ap_score = compute_segmentation_ap(mask_i, mask_j)
2738
+
2739
+ file_results.update({
2740
+ f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
2741
+ f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
2742
+ f'ap_{conditions[i]}_{conditions[j]}': ap_score
2743
+ })
2301
2744
 
2302
- average_sizes = []
2303
- time_ls = []
2304
- for file_index, path in enumerate(paths):
2305
- name = os.path.basename(path)
2306
- name, ext = os.path.splitext(name)
2307
- output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
2308
- os.makedirs(output_folder, exist_ok=True)
2309
- overall_average_size = 0
2310
- with np.load(path) as data:
2311
- stack = data['data']
2312
- filenames = data['filenames']
2313
- if settings['timelapse']:
2745
+ return file_results
2314
2746
 
2315
- trackable_objects = ['cell','nucleus','pathogen']
2316
- if not all_elements_match(settings['timelapse_objects'], trackable_objects):
2317
- print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
2318
- return
2747
+ def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
2748
+ from .plot import visualize_cellpose_masks, plot_comparison_results
2749
+ from .io import _read_mask
2319
2750
 
2320
- if len(stack) != batch_size:
2321
- print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
2322
- settings['timelapse_batch_size'] = len(stack)
2323
- batch_size = len(stack)
2324
- if isinstance(timelapse_frame_limits, list):
2325
- if len(timelapse_frame_limits) >= 2:
2326
- stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
2327
- filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
2328
- batch_size = len(stack)
2329
- print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2751
+ dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
2752
+ dirs.sort() # Optional: sort directories if needed
2753
+ conditions = [os.path.basename(d) for d in dirs]
2330
2754
 
2331
- for i in range(0, stack.shape[0], batch_size):
2332
- mask_stack = []
2333
- start = time.time()
2755
+ # Get common files in all directories
2756
+ common_files = set(os.listdir(dirs[0]))
2757
+ for d in dirs[1:]:
2758
+ common_files.intersection_update(os.listdir(d))
2759
+ common_files = list(common_files)
2334
2760
 
2335
- if stack.shape[3] == 1:
2336
- batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
2337
- else:
2338
- batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
2761
+ # Create a pool of workers
2762
+ with Pool(processes=processes) as pool:
2763
+ args = [(src, filename, dirs, conditions) for filename in common_files]
2764
+ results = pool.map(compare_mask, args)
2339
2765
 
2340
- batch_filenames = filenames[i: i+batch_size].tolist()
2766
+ # Filter out None results (from skipped files)
2767
+ results = [res for res in results if res is not None]
2768
+ #print(results)
2769
+ if verbose:
2770
+ for result in results:
2771
+ filename = result['filename']
2772
+ masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
2773
+ visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)
2341
2774
 
2342
- if not settings['plot']:
2343
- batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
2344
- if batch.size == 0:
2345
- print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2346
- continue
2347
- if batch.max() > 1:
2348
- batch = batch / batch.max()
2775
+ fig = plot_comparison_results(results)
2776
+ save_results_and_figure(src, fig, results)
2777
+ return
2349
2778
 
2350
- if timelapse:
2351
- stitch_threshold=100.0
2352
- movie_path = os.path.join(os.path.dirname(src), 'movies')
2353
- os.makedirs(movie_path, exist_ok=True)
2354
- save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
2355
- _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2356
- else:
2357
- stitch_threshold=0.0
2358
-
2359
- print('batch.shape',batch.shape)
2360
- masks, flows, _, _ = model.eval(x=batch,
2361
- batch_size=cellpose_batch_size,
2362
- normalize=False,
2363
- channels=chans,
2364
- channel_axis=3,
2365
- diameter=object_settings['diameter'],
2366
- flow_threshold=flow_threshold,
2367
- cellprob_threshold=cellprob_threshold,
2368
- rescale=None,
2369
- resample=object_settings['resample'],
2370
- stitch_threshold=stitch_threshold)
2371
-
2372
- if timelapse:
2779
+ def _calculate_similarity(df, features, col_to_compare, val1, val2):
2780
+ """
2781
+ Calculate similarity scores of each well to the positive and negative controls using various metrics.
2782
+
2783
+ Args:
2784
+ df (pandas.DataFrame): DataFrame containing the data.
2785
+ features (list): List of feature columns to use for similarity calculation.
2786
+ col_to_compare (str): Column name to use for comparing groups.
2787
+ val1, val2 (str): Values in col_to_compare to create subsets for comparison.
2373
2788
 
2374
- if settings['plot']:
2375
- for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2376
- if idx == 0:
2377
- num_objects = mask_object_count(mask)
2378
- print(f'Number of objects: {num_objects}')
2379
- plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2789
+ Returns:
2790
+ pandas.DataFrame: DataFrame with similarity scores.
2791
+ """
2792
+ # Separate positive and negative control wells
2793
+ pos_control = df[df[col_to_compare] == val1][features].mean()
2794
+ neg_control = df[df[col_to_compare] == val2][features].mean()
2795
+
2796
+ # Standardize features for Mahalanobis distance
2797
+ scaler = StandardScaler()
2798
+ scaled_features = scaler.fit_transform(df[features])
2799
+
2800
+ # Regularize the covariance matrix to avoid singularity
2801
+ cov_matrix = np.cov(scaled_features, rowvar=False)
2802
+ inv_cov_matrix = None
2803
+ try:
2804
+ inv_cov_matrix = np.linalg.inv(cov_matrix)
2805
+ except np.linalg.LinAlgError:
2806
+ # Add a small value to the diagonal elements for regularization
2807
+ epsilon = 1e-5
2808
+ inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
2809
+
2810
+ # Calculate similarity scores
2811
+ df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
2812
+ df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
2813
+ df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
2814
+ df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
2815
+ df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
2816
+ df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
2817
+ df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
2818
+ df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
2819
+ df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
2820
+ df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
2821
+ df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
2822
+ df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
2823
+ df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
2824
+ df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
2825
+ df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
2826
+ df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
2827
+ df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
2828
+ df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
2829
+
2830
+ return df
2380
2831
 
2381
- _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
2382
- if object_type in timelapse_objects:
2383
- if timelapse_mode == 'btrack':
2384
- if not timelapse_displacement is None:
2385
- radius = timelapse_displacement
2386
- else:
2387
- radius = 100
2832
+ def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
2833
+
2834
+ """
2835
+ Calculates permutation importance for numerical features in the dataframe,
2836
+ comparing groups based on specified column values and uses the model to predict
2837
+ the class for all other rows in the dataframe.
2388
2838
 
2389
- workers = os.cpu_count()-2
2390
- if workers < 1:
2391
- workers = 1
2839
+ Args:
2840
+ df (pandas.DataFrame): The DataFrame containing the data.
2841
+ feature_string (str): String to filter features that contain this substring.
2842
+ col_to_compare (str): Column name to use for comparing groups.
2843
+ pos, neg (str): Values in col_to_compare to create subsets for comparison.
2844
+ exclude (list or str, optional): Columns to exclude from features.
2845
+ n_repeats (int): Number of repeats for permutation importance.
2846
+ clean (bool): Whether to remove columns with a single value.
2847
+ nr_to_plot (int): Number of top features to plot based on permutation importance.
2848
+ n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
2849
+ test_size (float): Proportion of the dataset to include in the test split.
2850
+ random_state (int): Random seed for reproducibility.
2851
+ model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
2852
+ n_jobs (int): Number of jobs to run in parallel for applicable models.
2392
2853
 
2393
- mask_stack = _btrack_track_cells(src=src,
2394
- name=name,
2395
- batch_filenames=batch_filenames,
2396
- object_type=object_type,
2397
- plot=settings['plot'],
2398
- save=settings['save'],
2399
- masks_3D=masks,
2400
- mode=timelapse_mode,
2401
- timelapse_remove_transient=timelapse_remove_transient,
2402
- radius=radius,
2403
- workers=workers)
2404
- if timelapse_mode == 'trackpy':
2405
- mask_stack = _trackpy_track_cells(src=src,
2406
- name=name,
2407
- batch_filenames=batch_filenames,
2408
- object_type=object_type,
2409
- masks=masks,
2410
- timelapse_displacement=timelapse_displacement,
2411
- timelapse_memory=timelapse_memory,
2412
- timelapse_remove_transient=timelapse_remove_transient,
2413
- plot=settings['plot'],
2414
- save=settings['save'],
2415
- mode=timelapse_mode)
2416
- else:
2417
- mask_stack = _masks_to_masks_stack(masks)
2854
+ Returns:
2855
+ pandas.DataFrame: The original dataframe with added prediction and data usage columns.
2856
+ pandas.DataFrame: DataFrame containing the importances and standard deviations.
2857
+ """
2858
+
2859
+ from .utils import filter_dataframe_features
2860
+
2861
+ if 'cells_per_well' in df.columns:
2862
+ df = df.drop(columns=['cells_per_well'])
2863
+
2864
+ # Subset the dataframe based on specified column values
2865
+ df1 = df[df[col_to_compare] == pos].copy()
2866
+ df2 = df[df[col_to_compare] == neg].copy()
2867
+
2868
+ # Create target variable
2869
+ df1['target'] = 0
2870
+ df2['target'] = 1
2871
+
2872
+ # Combine the subsets for analysis
2873
+ combined_df = pd.concat([df1, df2])
2874
+
2875
+ if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
2876
+ channel_of_interest = int(feature_string.split('_')[-1])
2877
+ elif not feature_string is 'morphology':
2878
+ channel_of_interest = 'morphology'
2879
+
2880
+ _, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)
2881
+
2882
+ X = combined_df[features]
2883
+ y = combined_df['target']
2884
+
2885
+ # Split the data into training and testing sets
2886
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
2887
+
2888
+ # Label the data in the original dataframe
2889
+ combined_df['data_usage'] = 'train'
2890
+ combined_df.loc[X_test.index, 'data_usage'] = 'test'
2891
+
2892
+ # Initialize the model based on model_type
2893
+ if model_type == 'random_forest':
2894
+ model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
2895
+ elif model_type == 'logistic_regression':
2896
+ model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
2897
+ elif model_type == 'gradient_boosting':
2898
+ model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
2899
+ elif model_type == 'xgboost':
2900
+ model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
2901
+ else:
2902
+ raise ValueError(f"Unsupported model_type: {model_type}")
2903
+
2904
+ model.fit(X_train, y_train)
2905
+
2906
+ perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
2907
+
2908
+ # Create a DataFrame for permutation importances
2909
+ permutation_df = pd.DataFrame({
2910
+ 'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
2911
+ 'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
2912
+ 'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
2913
+ }).tail(nr_to_plot)
2914
+
2915
+ # Plotting
2916
+ fig, ax = plt.subplots()
2917
+ ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
2918
+ ax.set_xlabel('Permutation Importance')
2919
+ plt.tight_layout()
2920
+ plt.show()
2921
+
2922
+ # Feature importance for models that support it
2923
+ if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
2924
+ feature_importances = model.feature_importances_
2925
+ feature_importance_df = pd.DataFrame({
2926
+ 'feature': features,
2927
+ 'importance': feature_importances
2928
+ }).sort_values(by='importance', ascending=False).head(nr_to_plot)
2929
+
2930
+ # Plotting feature importance
2931
+ fig, ax = plt.subplots()
2932
+ ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
2933
+ ax.set_xlabel('Feature Importance')
2934
+ plt.tight_layout()
2935
+ plt.show()
2936
+ else:
2937
+ feature_importance_df = pd.DataFrame()
2938
+
2939
+ # Predicting the target variable for the test set
2940
+ predictions_test = model.predict(X_test)
2941
+ combined_df.loc[X_test.index, 'predictions'] = predictions_test
2942
+
2943
+ # Predicting the target variable for the training set
2944
+ predictions_train = model.predict(X_train)
2945
+ combined_df.loc[X_train.index, 'predictions'] = predictions_train
2946
+
2947
+ # Predicting the target variable for all other rows in the dataframe
2948
+ X_all = df[features]
2949
+ all_predictions = model.predict(X_all)
2950
+ df['predictions'] = all_predictions
2951
+
2952
+ # Combine data usage labels back to the original dataframe
2953
+ combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
2954
+ df = df.join(combined_data_usage, how='left', rsuffix='_model')
2955
+
2956
+ # Calculating and printing the accuracy metrics
2957
+ accuracy = accuracy_score(y_test, predictions_test)
2958
+ precision = precision_score(y_test, predictions_test)
2959
+ recall = recall_score(y_test, predictions_test)
2960
+ f1 = f1_score(y_test, predictions_test)
2961
+ print(f"Accuracy: {accuracy}")
2962
+ print(f"Precision: {precision}")
2963
+ print(f"Recall: {recall}")
2964
+ print(f"F1 Score: {f1}")
2965
+
2966
+ # Printing class-specific accuracy metrics
2967
+ print("\nClassification Report:")
2968
+ print(classification_report(y_test, predictions_test))
2969
+
2970
+ df = _calculate_similarity(df, features, col_to_compare, pos, neg)
2971
+
2972
+ return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
2973
+
2974
+ def _shap_analysis(model, X_train, X_test):
2975
+
2976
+ """
2977
+ Performs SHAP analysis on the given model and data.
2978
+
2979
+ Args:
2980
+ model: The trained model.
2981
+ X_train (pandas.DataFrame): Training feature set.
2982
+ X_test (pandas.DataFrame): Testing feature set.
2983
+ """
2984
+
2985
+ explainer = shap.Explainer(model, X_train)
2986
+ shap_values = explainer(X_test)
2987
+
2988
+ # Summary plot
2989
+ shap.summary_plot(shap_values, X_test)
2990
+
2991
def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
    """
    Train a classifier on the measurements of ``src`` and plot a per-plate heatmap of one variable.

    The cell, nucleus, pathogen and cytoplasm tables of ``src/measurements/measurements.db`` are
    merged. If ``channel_of_interest`` is given, a ``recruitment`` column (pathogen / cytoplasm mean
    intensity in that channel) is added. ``_permutation_importance`` then trains a model comparing
    ``pos`` and ``neg`` in ``col_to_compare``, a SHAP summary is drawn, and ``variable`` is plotted
    with ``_plot_plates``.

    Args:
        src (str): Source folder containing ``measurements/measurements.db``.
        model_type (str): Model type forwarded to ``_permutation_importance``.
        variable (str): Numeric column of the scored dataframe to plot.
        grouping, min_max, cmap, min_count: Forwarded to ``_plot_plates``.
        channel_of_interest (int or None): Channel used for the recruitment feature, or None to skip it.
        n_estimators, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_jobs:
            Forwarded to ``_permutation_importance``.
        verbose (bool): Forwarded to ``_read_and_merge_data``.

    Returns:
        list: ``[output, heatmap]``. ``output`` is the list returned by ``_permutation_importance``.
        ``heatmap`` is whatever ``_plot_plates`` returns (presumably a figure; confirm).

    Raises:
        ValueError: If ``variable`` is not a numeric column of the scored dataframe.
    """
    from .io import _read_and_merge_data
    from .plot import _plot_plates

    measurement_db = src + '/measurements/measurements.db'
    object_tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']

    # NOTE(review): include_multiinfected is passed as 2.0 rather than a bool; its meaning is
    # defined by _read_and_merge_data. Confirm there.
    merged_df, _ = _read_and_merge_data(
        [measurement_db],
        object_tables,
        verbose=verbose,
        include_multinucleated=True,
        include_multiinfected=2.0,
        include_noninfected=True,
    )

    if channel_of_interest is None:
        feature_string = None
    else:
        channel_tag = f'channel_{channel_of_interest}'
        merged_df['recruitment'] = (
            merged_df[f'pathogen_{channel_tag}_mean_intensity']
            / merged_df[f'cytoplasm_{channel_tag}_mean_intensity']
        )
        feature_string = channel_tag

    results = _permutation_importance(merged_df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)

    # results layout: [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
    scored_df = results[0]
    trained_model, train_features, test_features = results[3], results[4], results[5]
    _shap_analysis(trained_model, train_features, test_features)

    numeric_columns = scored_df.select_dtypes(include=[np.number]).columns.tolist()
    if variable not in numeric_columns:
        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {numeric_columns}")

    heatmap_fig = _plot_plates(scored_df, variable, grouping, min_max, cmap, min_count)
    return [results, heatmap_fig]
3023
+
3024
def join_measurments_and_annotation(src, tables=None):
    """
    Merge per-object measurements with the ``png_list`` annotation table of a source folder.

    Both tables are read from ``src/measurements/measurements.db``. The measurement tables are
    merged with ``_read_and_merge_data`` (all object categories included). The result is
    left-joined with the ``png_list`` table on the ``prcfo`` column.

    Args:
        src (str): Source folder containing ``measurements/measurements.db``.
        tables (list of str, optional): Measurement tables to merge. Defaults to
            ``['cell', 'nucleus', 'pathogen', 'cytoplasm']``.

    Returns:
        pd.DataFrame: Measurements with the ``png_list`` columns attached. Rows without a matching
        ``prcfo`` keep NaN in those columns.
    """
    from .io import _read_and_merge_data, _read_db

    # A fresh list per call: a shared list default could be mutated by the helper below.
    if tables is None:
        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']

    db_path = src + '/measurements/measurements.db'
    df, _ = _read_and_merge_data([db_path],
                                 list(tables),
                                 verbose=True,
                                 include_multinucleated=True,
                                 include_multiinfected=True,
                                 include_noninfected=True)

    # _read_db is indexed with [0] here; presumably it returns one DataFrame per requested table.
    paths_df = _read_db(db_path, tables=['png_list'])

    merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
    return merged_df
3042
+
3043
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Create a jitter (strip) plot of one column grouped by an annotation column.

    Measurements and annotations are loaded and joined with ``join_measurments_and_annotation``.
    Rows are then optionally filtered and restricted to wells (plate/row/col) that contain at
    least one annotated row. Finally, rows are down-sampled so that every ``x_column`` group has
    the same number of rows.

    Args:
        src (str): Path to the source folder (expects ``measurements/measurements.db``).
        x_column (str): Annotation column used for the x-axis. Missing values become the group 'NaN'.
        y_column (str): Name of the column to be used for the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
        filter_column (str or list of str, optional): Column(s) to filter on before plotting.
        filter_values (list or list of lists, optional): Allowed values. When ``filter_column`` is a
            list, ``filter_values[i]`` holds the allowed values for ``filter_column[i]``.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        KeyError: If the joined data lacks the 'plate_x', 'row_x' or 'col_x' columns.
    """
    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Missing annotations become an explicit 'NaN' category so they can be grouped and filtered below.
    df[x_column] = df[x_column].fillna('NaN')

    if filter_column is not None:
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            for i, column in enumerate(filter_column):
                df = df[df[column].isin(filter_values[i])]

    required_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in required_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")

    # Keep every row (annotated or not) that lives in a well with at least one annotated row.
    non_nan_df = df[df[x_column] != 'NaN']
    well_keys = df[required_columns].apply(tuple, axis=1)
    annotated_well_keys = non_nan_df[required_columns].apply(tuple, axis=1)
    retained_rows = df[well_keys.isin(annotated_well_keys)]

    # Down-sample every x_column group to the size of the smallest group.
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')
    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)

    plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotate the category labels and keep them centred under their strips. Setting the text
    # properties directly avoids re-assigning the tick label strings.
    plt.xticks(rotation=45, ha='center')

    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
    else:
        plt.show()

    return balanced_df
3114
+
3115
def generate_image_umap(settings=None):
    """
    Generate a UMAP or tSNE embedding and visualize the data with clustering.

    The effective settings are written to ``<src[0]>/settings/embedding_settings.csv``. Per-object
    results (with a ``cluster`` column) go to ``<src[0]>/results/embedding_results.csv``.

    Parameters:
        settings (dict, optional): Options, normalized by ``get_umap_image_settings``
            (presumably fills in defaults; confirm). Keys read here include:
            src (str or list of str): Source directory/directories containing the data.
            row_limit (int): Limit the number of rows to process (random sample).
            tables (list): List of table names to read from the database.
            exclude_conditions (str or list): Conditions ('neg', 'pos', 'mix', 'screen') to drop.
            neg, pos, mix (str): 'col' values of the negative, positive and mixed controls.
            image_nr (int): Number of images to display.
            dot_size (int): Size of dots in the scatter plot.
            n_neighbors (int): Number of neighbors for UMAP.
            figuresize (int): Size of the figure.
            black_background (bool): Whether to use a black background.
            remove_image_canvas (bool): Whether to remove the image canvas.
            plot_outlines (bool): Whether to plot outlines.
            plot_points (bool): Whether to plot points.
            smooth_lines (bool): Whether to smooth lines.
            verbose (bool): Whether to print verbose output.
            embedding_by_controls (bool): Fit the reducer on control rows only, then apply it to all rows.
            col_to_compare (str): Column to compare for control-based embedding.
            clustering (str): Clustering method ('DBSCAN' or 'KMeans').
            exclude (list): List of columns to exclude from the analysis.
            resnet_features (bool): Embed image features instead of measurements.
            color_by (str or None): Column used to colour points instead of cluster labels.
            plot_images (bool): Whether to plot images.
            plot_cluster_grids (bool): Whether to also plot per-cluster image grids.
            reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
            save_figure (bool): Whether to save the figure(s) as PDF.
            analyze_clusters (bool): Whether to run ``cluster_feature_analysis`` and save its results.

    Returns:
        pd.DataFrame: DataFrame with the original data and an additional column 'cluster' containing the cluster identity.
    """
    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings
    from .alpha import cluster_feature_analysis, generate_umap_from_images

    # Use a fresh dict per call instead of a shared mutable default.
    settings = get_umap_image_settings({} if settings is None else settings)

    if isinstance(settings['src'], str):
        settings['src'] = [settings['src']]

    if settings['plot_images'] is False:
        settings['black_background'] = False

    if settings['color_by']:
        # Points are coloured by a metadata column (see below), so cluster-specific options are disabled.
        settings['remove_cluster_noise'] = False
        settings['plot_outlines'] = False
        settings['smooth_lines'] = False

    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_dir = os.path.join(settings['src'][0], 'settings')
    settings_csv = os.path.join(settings_dir, 'embedding_settings.csv')
    os.makedirs(settings_dir, exist_ok=True)
    settings_df.to_csv(settings_csv, index=False)
    display(settings_df)

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables'] + ['png_list']
    all_df = pd.DataFrame()
    for i, db_path in enumerate(db_paths):
        df = _read_and_join_tables(db_path, table_names=tables)
        df, _ = correct_paths(df, settings['src'][i])
        all_df = pd.concat([all_df, df], axis=0)

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        all_df = all_df.sample(n=settings['row_limit'], random_state=42)

    image_paths = all_df['png_path'].to_list()

    if settings['embedding_by_controls']:
        col_to_compare = all_df[settings['col_to_compare']].reset_index(drop=True)

        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

        # Re-attach the comparison column (positionally aligned) so control rows can be selected.
        numeric_data_df = pd.DataFrame(numeric_data).reset_index(drop=True)
        numeric_data_df[settings['col_to_compare']] = col_to_compare

        positive_control_df = numeric_data_df[numeric_data_df[settings['col_to_compare']] == settings['pos']].copy()
        negative_control_df = numeric_data_df[numeric_data_df[settings['col_to_compare']] == settings['neg']].copy()
        control_numeric_data_df = pd.concat([positive_control_df, negative_control_df])
        control_numeric_data = control_numeric_data_df.drop(columns=[settings['col_to_compare']]).values

        # Fit the reducer on control rows only.
        _, _, reducer = reduction_and_clustering(control_numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode='fit', model=False)

        # Apply the trained reducer to the entire dataset.
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode=None, model=reducer)

    elif settings['resnet_features']:
        numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])
    else:
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'])

    if settings['remove_cluster_noise']:
        # Removes DBSCAN noise (-1 labels).
        # NOTE(review): if points are dropped, embedding/labels no longer align with all_df and
        # image_paths below. Confirm remove_noise's contract.
        embedding, labels = remove_noise(embedding, labels)

    if settings['color_by']:
        labels = all_df[settings['color_by']]

    colors = generate_colors(len(np.unique(labels)), settings['black_background'])

    umap_plt = plot_embedding(embedding, image_paths, labels, settings['image_nr'], settings['img_zoom'], colors, settings['plot_by_cluster'], settings['plot_outlines'], settings['plot_points'], settings['plot_images'], settings['smooth_lines'], settings['black_background'], settings['figuresize'], settings['dot_size'], settings['remove_image_canvas'], settings['verbose'])
    plot_grids = settings['plot_cluster_grids'] and settings['plot_images']
    if plot_grids:
        grid_plt = plot_clusters_grid(embedding, labels, settings['image_nr'], image_paths, colors, settings['figuresize'], settings['black_background'], settings['verbose'])

    if settings['save_figure']:
        results_dir = os.path.join(settings['src'][0], 'results')
        os.makedirs(results_dir, exist_ok=True)
        reduction_method = settings['reduction_method'].upper()
        embedding_path = os.path.join(results_dir, f'{reduction_method}_embedding.pdf')
        umap_plt.savefig(embedding_path, format='pdf')
        print(f'Saved {reduction_method} embedding to {embedding_path}')
        if plot_grids:
            grid_path = os.path.join(results_dir, f'{reduction_method}_grid.pdf')
            grid_plt.savefig(grid_path, format='pdf')
            print(f'Saved {reduction_method} grid to {grid_path}')

    all_df['cluster'] = labels

    results_dir = os.path.join(settings['src'][0], 'results')
    results_csv = os.path.join(results_dir, 'embedding_results.csv')
    os.makedirs(results_dir, exist_ok=True)
    all_df.to_csv(results_csv, index=False)
    print(f'Results saved to {results_csv}')

    if settings['analyze_clusters']:
        combined_results = cluster_feature_analysis(all_df)
        cluster_results_csv = os.path.join(results_dir, 'cluster_results.csv')
        combined_results.to_csv(cluster_results_csv, index=False)
        print(f'Cluster results saved to {cluster_results_csv}')

    return all_df
3296
+
3297
def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
    """
    Classify a well-column value as a control or screen condition.

    The value is compared with ``==`` against ``neg``, ``pos`` and ``mix``, in that order.
    The first match wins.

    Args:
        col_value: Value to classify (used on entries of the ``col`` column).
        neg: Value marking negative-control wells. Default 'c1'.
        pos: Value marking positive-control wells. Default 'c2'.
        mix: Value marking mixed-control wells. Default 'c3'.

    Returns:
        str: 'neg', 'pos' or 'mix' for a matching control value, otherwise 'screen'.
    """
    for control_value, tag in ((neg, 'neg'), (pos, 'pos'), (mix, 'mix')):
        if col_value == control_value:
            return tag
    return 'screen'
3307
+
3308
def reducer_hyperparameter_search(settings=None, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
    """
    Perform a hyperparameter search for UMAP or tSNE on the given data.

    Every reduction parameter set is combined with every clustering parameter set. Each
    combination is drawn in its own subplot of a ``len(reduction_params)`` x
    ``len(clustering_params)`` grid.

    Parameters:
        settings (dict, optional): Options, normalized by ``get_umap_image_settings``. Keys read here include:
            src (str or list of str): Source directory/directories containing the data.
            row_limit (int): Limit the number of rows to process.
            tables (list): List of table names to read from the database.
            filter_by (str): Column to filter the data.
            remove_highly_correlated (bool): Whether to remove highly correlated columns.
            log_data (bool): Whether to log transform the data.
            exclude (list): Columns to exclude.
            exclude_conditions (str or list): Conditions ('neg', 'pos', 'mix', 'screen') to drop.
            neg, pos, mix (str): 'col' values of the control wells.
            metric (str): Distance metric.
            color_by (str or None): Column used to colour points instead of cluster labels.
            dot_size (int): Size of the points in the scatter plot.
            n_jobs (int): Parallel jobs.
            verbose (bool): Whether to print verbose output.
            reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE'); overridden
                when ``reduction_params`` identify the method.
        reduction_params (dict or list of dict, optional): Hyperparameters to test for the reduction
            method. 'n_neighbors' selects UMAP and 'perplexity' selects tSNE; float values are taken
            as a fraction of the number of rows. If omitted, the method's defaults are used once.
        dbscan_params (dict or list of dict, optional): DBSCAN hyperparameters to test.
        kmeans_params (dict or list of dict, optional): KMeans hyperparameters to test.
        save (bool): Save the grid to ``<src>/results/hyperparameter_search.pdf`` instead of showing it.

    Returns:
        None

    Raises:
        ValueError: If UMAP and tSNE keys are mixed, the reduction method is unsupported, or no
            clustering parameters are given.
    """
    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings

    # Use a fresh dict per call instead of a shared mutable default.
    settings = get_umap_image_settings({} if settings is None else settings)
    pointsize = settings['dot_size']

    if isinstance(dbscan_params, dict):
        dbscan_params = [dbscan_params]
    if isinstance(kmeans_params, dict):
        kmeans_params = [kmeans_params]
    if isinstance(reduction_params, dict):
        reduction_params = [reduction_params]
    if not reduction_params:
        # No reduction grid given: run once with the configured method's defaults.
        reduction_params = [{}]

    # Determine the reduction method from the parameter keys. Mixing both is ambiguous.
    has_umap_keys = any('n_neighbors' in param for param in reduction_params)
    has_tsne_keys = any('perplexity' in param for param in reduction_params)
    if has_umap_keys and has_tsne_keys:
        raise ValueError("Reduction parameters must include 'n_neighbors' for UMAP or 'perplexity' for tSNE, not both.")
    if has_umap_keys:
        reduction_method = 'umap'
    elif has_tsne_keys:
        reduction_method = 'tsne'
    else:
        reduction_method = settings['reduction_method'].lower()

    if reduction_method not in ('umap', 'tsne'):
        raise ValueError(f"Unsupported reduction method: {settings['reduction_method']}. Supported methods are 'UMAP' and 'tSNE'")

    if settings['reduction_method'].lower() != reduction_method:
        settings['reduction_method'] = reduction_method
        print(f'Changed reduction method to {reduction_method} based on the provided parameters.')

    if settings['verbose']:
        display(pd.DataFrame(list(settings.items()), columns=['Key', 'Value']))

    # Copy clustering params so the caller's dicts are not mutated by the added 'method' key.
    clustering_params = [{**param, 'method': 'dbscan'} for param in (dbscan_params or [])]
    clustering_params += [{**param, 'method': 'kmeans'} for param in (kmeans_params or [])]
    if not clustering_params:
        raise ValueError("At least one of dbscan_params or kmeans_params must be provided.")

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables']
    all_df = pd.DataFrame()
    for db_path in db_paths:
        df = _read_and_join_tables(db_path, table_names=tables)
        all_df = pd.concat([all_df, df], axis=0)

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        all_df = all_df.sample(n=settings['row_limit'], random_state=42)

    numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

    print('Testing paramiters:', reduction_params)
    print('Testing clustering paramiters:', clustering_params)

    grid_rows = len(reduction_params)
    grid_cols = len(clustering_params)
    # squeeze=False always yields a 2D axes array, so every cell is axs[i, j].
    fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(grid_cols * 10, grid_rows * 10), squeeze=False)
    # Single-row or single-column grids are drawn without axis decorations.
    hide_axes = grid_rows == 1 or grid_cols == 1

    for i, reduction_param in enumerate(reduction_params):
        for j, clustering_param in enumerate(clustering_params):
            ax = axs[i, j]
            if hide_axes:
                ax.axis('off')

            if reduction_method == 'umap':
                n_neighbors = reduction_param.get('n_neighbors', 15)
                if isinstance(n_neighbors, float):
                    # Fractions are interpreted relative to the number of rows.
                    n_neighbors = int(n_neighbors * len(numeric_data))
                first_param, second_param = n_neighbors, reduction_param.get('min_dist', 0.1)
            else:
                perplexity = reduction_param.get('perplexity', 30)
                if isinstance(perplexity, float):
                    perplexity = int(perplexity * len(numeric_data))
                first_param, second_param = perplexity, 0.1

            embedding, labels = search_reduction_and_clustering(numeric_data, first_param, second_param, settings['metric'],
                                                                clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
                                                                clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])

            if settings['color_by']:
                group_values = all_df[settings['color_by']]
                unique_groups = group_values.unique()
                colors = generate_colors(len(unique_groups), False)
                for group, color in zip(unique_groups, colors):
                    mask = (group_values == group).to_numpy()
                    ax.scatter(embedding[mask, 0], embedding[mask, 1], s=pointsize, label=f"{group}", color=color)
            else:
                unique_labels = np.unique(labels)
                colors = generate_colors(len(unique_labels), False)
                for label, color in zip(unique_labels, colors):
                    ax.scatter(embedding[labels == label, 0], embedding[labels == label, 1], s=pointsize, label=f"Cluster {label}", color=color)

            ax.set_title(f"{settings['reduction_method']} {reduction_param}\n{clustering_param['method']} {clustering_param}")
            ax.legend()

    fig.tight_layout()
    if save:
        # settings['src'] may be a single path or a list of paths; results go next to the first one.
        src = settings['src']
        base_dir = src[0] if isinstance(src, (list, tuple)) else src
        results_dir = os.path.join(base_dir, 'results')
        os.makedirs(results_dir, exist_ok=True)
        fig.savefig(os.path.join(results_dir, 'hyperparameter_search.pdf'))
    else:
        plt.show()

    return