spacr 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -1,11 +1,13 @@
- import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
+ import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap, string

  # image and array processing
  import numpy as np
  import pandas as pd

+ from cellpose import train
  import cellpose
  from cellpose import models as cp_models
+ from cellpose.models import CellposeModel

  import statsmodels.formula.api as smf
  import statsmodels.api as sm
@@ -27,9 +29,17 @@ matplotlib.use('Agg')

  import torchvision.transforms as transforms
  from sklearn.model_selection import train_test_split
- from sklearn.ensemble import IsolationForest
+ from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
  from .logger import log_function_call

+ from sklearn.linear_model import LogisticRegression
+ from sklearn.inspection import permutation_importance
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
+ from xgboost import XGBClassifier
+
+ from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
+ from sklearn.preprocessing import StandardScaler
+ import shap

  def analyze_plaques(folder):
      summary_data = []
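The expanded scikit-learn/XGBoost/SHAP imports above back the classifier and feature-attribution work added in this release. As orientation, this is how such pieces are conventionally combined; the snippet below is a self-contained sketch on synthetic data, not spacr's own API:

    import numpy as np
    import shap
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.inspection import permutation_importance
    from sklearn.preprocessing import StandardScaler

    rng = np.random.default_rng(0)
    X = StandardScaler().fit_transform(rng.normal(size=(200, 8)))  # 200 cells x 8 features
    y = (X[:, 0] + 0.5 * X[:, 3] > 0).astype(int)                  # synthetic binary phenotype

    model = RandomForestClassifier(n_estimators=100, random_state=0).fit(X, y)

    # Model-agnostic importance: accuracy drop when each feature is shuffled
    perm = permutation_importance(model, X, y, n_repeats=5, random_state=0)

    # SHAP attributions: per-sample, per-feature contributions from the fitted trees
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X)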
@@ -67,74 +77,6 @@ def analyze_plaques(folder):

      print(f"Analysis completed and saved to database '{db_name}'.")

- def compare_masks(dir1, dir2, dir3, verbose=False):
-
-     from .io import _read_mask
-     from .plot import visualize_masks, plot_comparison_results
-     from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
-
-     filenames = os.listdir(dir1)
-     results = []
-     cond_1 = os.path.basename(dir1)
-     cond_2 = os.path.basename(dir2)
-     cond_3 = os.path.basename(dir3)
-     for index, filename in enumerate(filenames):
-         print(f'Processing image:{index+1}', end='\r', flush=True)
-         path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
-         if os.path.exists(path2) and os.path.exists(path3):
-
-             mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
-             boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
-
-             true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
-             true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
-             average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
-             ap_scores = [average_precision_0, average_precision_1]
-
-             if verbose:
-                 unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
-                 print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
-                 visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
-
-             boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
-
-             if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
-                (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
-                (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
-                 continue
-
-             if verbose:
-                 unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
-                 print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
-                 visualize_masks(mask1, mask2, mask3, title=filename)
-
-             jaccard12 = jaccard_index(mask1, mask2)
-             dice12 = dice_coefficient(mask1, mask2)
-             jaccard13 = jaccard_index(mask1, mask3)
-             dice13 = dice_coefficient(mask1, mask3)
-             jaccard23 = jaccard_index(mask2, mask3)
-             dice23 = dice_coefficient(mask2, mask3)
-
-             results.append({
-                 'filename': filename,
-                 f'jaccard_{cond_1}_{cond_2}': jaccard12,
-                 f'dice_{cond_1}_{cond_2}': dice12,
-                 f'jaccard_{cond_1}_{cond_3}': jaccard13,
-                 f'dice_{cond_1}_{cond_3}': dice13,
-                 f'jaccard_{cond_2}_{cond_3}': jaccard23,
-                 f'dice_{cond_2}_{cond_3}': dice23,
-                 f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
-                 f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
-                 f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
-                 f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
-                 f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
-             })
-         else:
-             print(f'Cannot find {path1} or {path2} or {path3}')
-     fig = plot_comparison_results(results)
-     return results, fig
-
  def generate_cp_masks(settings):

      src = settings['src']
@@ -177,8 +119,146 @@ def train_cellpose(settings):
      from .utils import resize_images_and_labels

      img_src = settings['img_src']
-     mask_src= settings['mask_src']
-     secondary_image_dir = None
+     mask_src = os.path.join(img_src, 'mask')
+
+     model_name = settings['model_name']
+     model_type = settings['model_type']
+     learning_rate = settings['learning_rate']
+     weight_decay = settings['weight_decay']
+     batch_size = settings['batch_size']
+     n_epochs = settings['n_epochs']
+     from_scratch = settings['from_scratch']
+     diameter = settings['diameter']
+     verbose = settings['verbose']
+
+     channels = [0,0]
+     signal_thresholds = 1000
+     normalize = True
+     percentiles = [2,98]
+     circular = False
+     invert = False
+     resize = False
+     settings['width_height'] = [1000,1000]
+     target_height = settings['width_height'][1]
+     target_width = settings['width_height'][0]
+     rescale = False
+     grayscale = True
+     test = False
+
+     if test:
+         test_img_src = os.path.join(os.path.dirname(img_src), 'test')
+         test_mask_src = os.path.join(test_img_src, 'mask')
+
+     test_images, test_masks, test_image_names, test_mask_names = None, None, None, None
+     print(settings)
+
+     if from_scratch:
+         model_name = f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
+     else:
+         if resize:
+             model_name = f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
+         else:
+             model_name = f'{model_name}_{model_type}_e{n_epochs}.CP_model'
+
+     model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
+     print(model_save_path)
+     os.makedirs(model_save_path, exist_ok=True)
+
+     settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
+     settings_csv = os.path.join(model_save_path, f'{model_name}_settings.csv')
+     settings_df.to_csv(settings_csv, index=False)
+
+     if from_scratch:
+         model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
+     else:
+         model = cp_models.CellposeModel(gpu=True, model_type=model_type)
+
+     if normalize:
+
+         image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
+         label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
+         images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+         if test:
+             test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
+             test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
+             test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(image_files=test_image_files, label_files=test_label_files, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+             test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+     else:
+         images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
+         images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+
+         if test:
+             test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
+             test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
+
+     if resize:
+         images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
+
+     if model_type == 'cyto':
+         cp_channels = [0,1]
+     if model_type == 'cyto2':
+         cp_channels = [0,2]
+     if model_type == 'nucleus':
+         cp_channels = [0,0]
+     if grayscale:
+         cp_channels = [0,0]
+         images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
+
+     masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
+
+     print(f'image shape: {images[0].shape}, mask shape: {masks[0].shape}')
+     save_every = int(n_epochs/10)
+     if save_every < 10:
+         save_every = n_epochs
+
+     train.train_seg(model.net,
+                     train_data=images,
+                     train_labels=masks,
+                     train_files=image_names,
+                     train_labels_files=mask_names,
+                     train_probs=None,
+                     test_data=test_images,
+                     test_labels=test_masks,
+                     test_files=test_image_names,
+                     test_labels_files=test_mask_names,
+                     test_probs=None,
+                     load_files=True,
+                     batch_size=batch_size,
+                     learning_rate=learning_rate,
+                     n_epochs=n_epochs,
+                     weight_decay=weight_decay,
+                     momentum=0.9,
+                     SGD=False,
+                     channels=cp_channels,
+                     channel_axis=None,
+                     #rgb=False,
+                     normalize=False,
+                     compute_flows=False,
+                     save_path=model_save_path,
+                     save_every=save_every,
+                     nimg_per_epoch=None,
+                     nimg_test_per_epoch=None,
+                     rescale=rescale,
+                     #scale_range=None,
+                     #bsize=224,
+                     min_train_masks=1,
+                     model_name=model_name)
+
+     return print(f"Model saved at: {model_save_path}/{model_name}")
+
+ def train_cellpose_v1(settings):
+
+     from .io import _load_normalized_images_and_labels, _load_images_and_labels
+     from .utils import resize_images_and_labels
+
+     img_src = settings['img_src']
+
+     mask_src = os.path.join(img_src, 'mask')
+
      model_name = settings['model_name']
      model_type = settings['model_type']
      learning_rate = settings['learning_rate']
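For reference, the rewritten train_cellpose() reads everything from a single settings dict. A call shaped like the following would exercise the keys it accesses; the values are illustrative placeholders, not recommendations:

    from spacr.core import train_cellpose

    settings = {
        'img_src': '/path/to/train_images',  # masks are expected in img_src/mask
        'model_name': 'my_model',            # hypothetical name
        'model_type': 'cyto',
        'learning_rate': 0.2,
        'weight_decay': 1e-5,
        'batch_size': 8,
        'n_epochs': 100,
        'from_scratch': False,
        'diameter': 30,
        'verbose': True,
    }
    train_cellpose(settings)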
@@ -186,7 +266,9 @@ def train_cellpose(settings):
      batch_size = settings['batch_size']
      n_epochs = settings['n_epochs']
      verbose = settings['verbose']
-     signal_thresholds = settings['signal_thresholds']
+
+     signal_thresholds = 100 #settings['signal_thresholds']
+
      channels = settings['channels']
      from_scratch = settings['from_scratch']
      diameter = settings['diameter']
@@ -199,7 +281,17 @@ def train_cellpose(settings):
      invert = settings['invert']
      percentiles = settings['percentiles']
      grayscale = settings['grayscale']
-
+
+     if model_type == 'cyto':
+         settings['diameter'] = 30
+         diameter = settings['diameter']
+         print(f'Cyto model must have diameter 30. Diameter set to 30')
+
+     if model_type == 'nuclei':
+         settings['diameter'] = 17
+         diameter = settings['diameter']
+         print(f'Nuclei model must have diameter 17. Diameter set to 17')
+
      print(settings)

      if from_scratch:
@@ -208,24 +300,24 @@ def train_cellpose(settings):
          model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'

      model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
-     os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
+     print(model_save_path)
+     os.makedirs(model_save_path, exist_ok=True)

      settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
      settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
      settings_df.to_csv(settings_csv, index=False)

-     if model_type =='cyto':
-         if not from_scratch:
-             model = cp_models.CellposeModel(gpu=True, model_type=model_type)
-         else:
-             model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
-     if model_type !='cyto':
+     if not from_scratch:
          model = cp_models.CellposeModel(gpu=True, model_type=model_type)
-
-
-
-     if normalize:
-         images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_dir=img_src, label_dir=mask_src, secondary_image_dir=secondary_image_dir, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+
+     else:
+         model = cp_models.CellposeModel(gpu=True, model_type=model_type, pretrained_model=None)
+
+     if normalize:
+         image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
+         label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
+
+         images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
          images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
      else:
          images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
@@ -248,25 +340,86 @@ def train_cellpose(settings):

      print(f'image shape: {images[0].shape}, mask shape: {masks[0].shape}')
      save_every = int(n_epochs/10)
-     print('cellpose image input dtype', images[0].dtype)
-     print('cellpose mask input dtype', masks[0].dtype)
+     if save_every < 10:
+         save_every = n_epochs
+
+
+     #print('cellpose image input dtype', images[0].dtype)
+     #print('cellpose mask input dtype', masks[0].dtype)
+
      # Train the model
-     model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
-                 train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
-                 train_files=image_names, #(list of strings) file names for images in train_data (to save flows for future runs)
-                 channels=cp_channels, #(list of ints (default, None)) – channels to use for training
-                 normalize=False, #(bool (default, True))normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
-                 save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
-                 save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
-                 learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
-                 n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
-                 weight_decay=weight_decay, #(float (default, 0.00001)) –
-                 SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
-                 batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
-                 nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
-                 rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
-                 min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
-                 model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
+     #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
+
+     #model.train(train_data=images, #(list of arrays (2D or 3D)) images for training
+     #            train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
+     #            train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
+     #            channels=cp_channels, #(list of ints (default, None)) – channels to use for training
+     #            normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
+     #            save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
+     #            save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
+     #            learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
+     #            n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
+     #            weight_decay=weight_decay, #(float (default, 0.00001)) –
+     #            SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
+     #            batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
+     #            nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
+     #            rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
+     #            min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
+     #            model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
+
+
+     train.train_seg(model.net,
+                     train_data=images,
+                     train_labels=masks,
+                     train_files=image_names,
+                     train_labels_files=None,
+                     train_probs=None,
+                     test_data=None,
+                     test_labels=None,
+                     test_files=None,
+                     test_labels_files=None,
+                     test_probs=None,
+                     load_files=True,
+                     batch_size=batch_size,
+                     learning_rate=learning_rate,
+                     n_epochs=n_epochs,
+                     weight_decay=weight_decay,
+                     momentum=0.9,
+                     SGD=False,
+                     channels=cp_channels,
+                     channel_axis=None,
+                     #rgb=False,
+                     normalize=False,
+                     compute_flows=False,
+                     save_path=model_save_path,
+                     save_every=save_every,
+                     nimg_per_epoch=None,
+                     nimg_test_per_epoch=None,
+                     rescale=rescale,
+                     #scale_range=None,
+                     #bsize=224,
+                     min_train_masks=1,
+                     model_name=model_name)
+
+     #model_save_path = train.train_seg(model.net,
+     #                                  train_data=images,
+     #                                  train_files=image_names,
+     #                                  train_labels=masks,
+     #                                  channels=cp_channels,
+     #                                  normalize=False,
+     #                                  save_every=save_every,
+     #                                  learning_rate=learning_rate,
+     #                                  n_epochs=n_epochs,
+     #                                  #test_data=test_images,
+     #                                  #test_labels=test_labels,
+     #                                  weight_decay=weight_decay,
+     #                                  SGD=True,
+     #                                  batch_size=batch_size,
+     #                                  nimg_per_epoch=None,
+     #                                  rescale=rescale,
+     #                                  min_train_masks=1,
+     #                                  model_name=model_name)
+

      return print(f"Model saved at: {model_save_path}/{model_name}")

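The commented-out model.train(...) blocks above document the migration this hunk performs: Cellpose 3.x dropped CellposeModel.train in favour of cellpose.train.train_seg, which operates on the bare network (model.net). A minimal, self-contained sketch of the new entry point; the synthetic data exists only to keep the example runnable:

    import numpy as np
    from cellpose import models, train

    images = [np.random.rand(128, 128).astype('float32') for _ in range(4)]
    masks = [np.zeros((128, 128), dtype='uint16') for _ in range(4)]
    for m in masks:
        m[40:80, 40:80] = 1  # one dummy object per image

    model = models.CellposeModel(gpu=False, model_type='cyto')
    train.train_seg(model.net,
                    train_data=images,
                    train_labels=masks,
                    channels=[0, 0],    # grayscale
                    normalize=False,
                    n_epochs=1,
                    min_train_masks=1,  # the default of 5 would drop these sparse images
                    save_path='.',
                    model_name='demo.CP_model')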
@@ -926,30 +1079,38 @@ def annotate_results(pred_loc):
      display(df)
      return df

- def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=None):
+ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):

-     from .utils import init_globals, add_images_to_tar
-
-     db_path = os.path.join(src, 'measurements','measurements.db')
+     from .utils import initiate_counter, add_images_to_tar
+
+     db_path = os.path.join(src, 'measurements', 'measurements.db')
      dst = os.path.join(src, 'datasets')
-
-     global total_images
      all_paths = []
-
+
      # Connect to the database and retrieve the image paths
      print(f'Reading DataBase: {db_path}')
-     with sqlite3.connect(db_path) as conn:
-         cursor = conn.cursor()
-         if file_type:
-             cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_type}%",))
-         else:
-             cursor.execute("SELECT png_path FROM png_list")
-         while True:
-             rows = cursor.fetchmany(1000)
-             if not rows:
-                 break
-             all_paths.extend([row[0] for row in rows])
-
+     try:
+         with sqlite3.connect(db_path) as conn:
+             cursor = conn.cursor()
+             if file_metadata:
+                 if isinstance(file_metadata, str):
+                     cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
+             else:
+                 cursor.execute("SELECT png_path FROM png_list")
+
+             while True:
+                 rows = cursor.fetchmany(1000)
+                 if not rows:
+                     break
+                 all_paths.extend([row[0] for row in rows])
+
+     except sqlite3.Error as e:
+         print(f"Database error: {e}")
+         return
+     except Exception as e:
+         print(f"Error: {e}")
+         return
+
      if isinstance(sample, int):
          selected_paths = random.sample(all_paths, sample)
          print(f'Random selection of {len(selected_paths)} paths')
@@ -957,23 +1118,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
          selected_paths = all_paths
          random.shuffle(selected_paths)
          print(f'All paths: {len(selected_paths)} paths')
-
+
      total_images = len(selected_paths)
-     print(f'found {total_images} images')
-
+     print(f'Found {total_images} images')
+
      # Create a temp folder in dst
      temp_dir = os.path.join(dst, "temp_tars")
      os.makedirs(temp_dir, exist_ok=True)

      # Chunking the data
-     if len(selected_paths) > 10000:
-         num_procs = cpu_count()-2
-         chunk_size = len(selected_paths) // num_procs
-         remainder = len(selected_paths) % num_procs
-     else:
-         num_procs = 2
-         chunk_size = len(selected_paths) // 2
-         remainder = 0
+     num_procs = max(2, cpu_count() - 2)
+     chunk_size = len(selected_paths) // num_procs
+     remainder = len(selected_paths) % num_procs

      paths_chunks = []
      start = 0
@@ -983,45 +1139,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
          start = end

      temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
-
-     # Initialize the shared objects
-     counter_ = Value('i', 0)
-     lock_ = Lock()

-     ctx = multiprocessing.get_context('spawn')
-
      print(f'Generating temporary tar files in {dst}')
-
+
+     # Initialize shared counter and lock
+     counter = Value('i', 0)
+     lock = Lock()
+
+     with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
+         pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
+
      # Combine the temporary tar files into a final tar
      date_name = datetime.date.today().strftime('%y%m%d')
-     tar_name = f'{date_name}_{experiment}_{file_type}.tar'
+     if not file_metadata is None:
+         tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
+     else:
+         tar_name = f'{date_name}_{experiment}.tar'
+     tar_name = os.path.join(dst, tar_name)
      if os.path.exists(tar_name):
          number = random.randint(1, 100)
-         tar_name_2 = f'{date_name}_{experiment}_{file_type}_{number}.tar'
-         print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
-         tar_name = tar_name_2
-
-     # Add the counter and lock to the arguments for pool.map
+         tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
+         print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
+         tar_name = os.path.join(dst, tar_name_2)
+
      print(f'Merging temporary files')
-     #with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
-     #    results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))

-     with ctx.Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
-         results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
-
-     with tarfile.open(os.path.join(dst, tar_name), 'w') as final_tar:
-         for tar_path in results:
-             with tarfile.open(tar_path, 'r') as t:
-                 for member in t.getmembers():
-                     t.extract(member, path=dst)
-                     final_tar.add(os.path.join(dst, member.name), arcname=member.name)
-                     os.remove(os.path.join(dst, member.name))
-             os.remove(tar_path)
+     with tarfile.open(tar_name, 'w') as final_tar:
+         for temp_tar_path in temp_tar_files:
+             with tarfile.open(temp_tar_path, 'r') as temp_tar:
+                 for member in temp_tar.getmembers():
+                     file_obj = temp_tar.extractfile(member)
+                     final_tar.addfile(member, file_obj)
+             os.remove(temp_tar_path)

      # Delete the temp folder
      shutil.rmtree(temp_dir)
-     print(f"\nSaved {total_images} images to {os.path.join(dst, tar_name)}")
-
+     print(f"\nSaved {total_images} images to {tar_name}")
+
  def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):

      from .io import TarImageDataset, DataLoader
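The rewritten merge step above streams members directly between archives (extractfile/addfile) rather than extracting each file to disk and re-adding it. The pattern in isolation, with hypothetical file names:

    import tarfile

    def merge_tars(part_paths, out_path):
        # Copy every member of each part archive into the final tar
        # without touching the filesystem in between.
        with tarfile.open(out_path, 'w') as final_tar:
            for part in part_paths:
                with tarfile.open(part, 'r') as t:
                    for member in t.getmembers():
                        final_tar.addfile(member, t.extractfile(member))

    merge_tars(['temp_0.tar', 'temp_1.tar'], 'dataset.tar')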
@@ -1257,7 +1411,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',

      db_path = os.path.join(src, 'measurements','measurements.db')
      dst = os.path.join(src, 'datasets', 'training')
-
+
+     if os.path.exists(dst):
+         for i in range(1, 1000):
+             dst = os.path.join(src, 'datasets', f'training_{i}')
+             if not os.path.exists(dst):
+                 print(f'Creating new directory for training: {dst}')
+                 break
+
      if mode == 'annotation':
          class_paths_ls_2 = []
          class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
@@ -1268,6 +1429,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',

      elif mode == 'metadata':
          class_paths_ls = []
+         class_len_ls = []
          [df] = _read_db(db_loc=db_path, tables=['png_list'])
          df['metadata_based_class'] = pd.NA
          for i, class_ in enumerate(classes):
@@ -1275,7 +1437,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
              df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_

          for class_ in classes:
+             if size == None:
+                 c_s = []
+                 for c in classes:
+                     c_s_t_df = df[df['metadata_based_class'] == c]
+                     c_s.append(len(c_s_t_df))
+                     print(f'Found {len(c_s_t_df)} images for class {c}')
+                 size = min(c_s)
+                 print(f'Using the smallest class size: {size}')
+
              class_temp_df = df[df['metadata_based_class'] == class_]
+             class_len_ls.append(len(class_temp_df))
+             print(f'Found {len(class_temp_df)} images for class {class_}')
              class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
              class_paths_ls.append(class_paths_temp)

@@ -1332,7 +1505,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',

      return

- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
+ def generate_loaders_v1(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
      """
      Generate data loaders for training and validation/test datasets.

@@ -1463,56 +1636,223 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_

      return train_loaders, val_loaders, plate_names

- def analyze_recruitment(src, metadata_settings, advanced_settings):
+ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
+
      """
-     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
+     Generate data loaders for training and validation/test datasets.

      Parameters:
-     src (str): The source of the recruitment data.
-     metadata_settings (dict): The settings for metadata.
-     advanced_settings (dict): The advanced settings for recruitment analysis.
+     - src (str): The source directory containing the data.
+     - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
+     - mode (str): The mode of operation. Options are 'train' or 'test'.
+     - image_size (int): The size of the input images.
+     - batch_size (int): The batch size for the data loaders.
+     - classes (list): The list of classes to consider.
+     - num_workers (int): The number of worker threads for data loading.
+     - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
+     - max_show (int): The maximum number of images to show when verbose is True.
+     - pin_memory (bool): Whether to pin memory for faster data transfer.
+     - normalize (bool): Whether to normalize the input images.
+     - verbose (bool): Whether to print additional information and show images.
+     - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for red and green, etc.

      Returns:
-     None
+     - train_loaders (list): List of data loaders for training datasets.
+     - val_loaders (list): List of data loaders for validation datasets.
+     - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
      """
-
-     from .io import _read_and_merge_data, _results_to_csv
-     from .plot import plot_merged, _plot_controls, _plot_recruitment
-     from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
-
-     settings_dict = {**metadata_settings, **advanced_settings}
-     settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
-     settings_csv = os.path.join(src,'settings','analyze_settings.csv')
-     os.makedirs(os.path.join(src,'settings'), exist_ok=True)
-     settings_df.to_csv(settings_csv, index=False)

-     # metadata settings
-     target = metadata_settings['target']
-     cell_types = metadata_settings['cell_types']
-     cell_plate_metadata = metadata_settings['cell_plate_metadata']
-     pathogen_types = metadata_settings['pathogen_types']
-     pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
-     treatments = metadata_settings['treatments']
-     treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
-     metadata_types = metadata_settings['metadata_types']
-     channel_dims = metadata_settings['channel_dims']
-     cell_chann_dim = metadata_settings['cell_chann_dim']
-     cell_mask_dim = metadata_settings['cell_mask_dim']
-     nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
-     nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
-     pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
-     pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
-     channel_of_interest = metadata_settings['channel_of_interest']
-
-     # Advanced settings
-     plot = advanced_settings['plot']
-     plot_nr = advanced_settings['plot_nr']
-     plot_control = advanced_settings['plot_control']
-     figuresize = advanced_settings['figuresize']
-     remove_background = advanced_settings['remove_background']
-     backgrounds = advanced_settings['backgrounds']
-     include_noninfected = advanced_settings['include_noninfected']
-     include_multiinfected = advanced_settings['include_multiinfected']
+     from .io import MyDataset
+     from .plot import _imshow
+     from torchvision import transforms
+     from torch.utils.data import DataLoader, random_split
+     from collections import defaultdict
+     import os
+     import random
+     from PIL import Image
+     from torchvision.transforms import ToTensor
+
+     chans = []
+
+     if 'r' in channels:
+         chans.append(1)
+     if 'g' in channels:
+         chans.append(2)
+     if 'b' in channels:
+         chans.append(3)
+
+     channels = chans
+
+     if verbose:
+         print(f'Training a network on channels: {channels}')
+         print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
+
+     class SelectChannels:
+         def __init__(self, channels):
+             self.channels = channels
+
+         def __call__(self, img):
+             img = img.clone()
+             if 1 not in self.channels:
+                 img[0, :, :] = 0  # Zero out the red channel
+             if 2 not in self.channels:
+                 img[1, :, :] = 0  # Zero out the green channel
+             if 3 not in self.channels:
+                 img[2, :, :] = 0  # Zero out the blue channel
+             return img
+
+     plate_to_filenames = defaultdict(list)
+     plate_to_labels = defaultdict(list)
+     train_loaders = []
+     val_loaders = []
+     plate_names = []
+
+     if normalize:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size)),
+             SelectChannels(channels),
+             transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+     else:
+         transform = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.CenterCrop(size=(image_size, image_size)),
+             SelectChannels(channels)])
+
+     if mode == 'train':
+         data_dir = os.path.join(src, 'train')
+         shuffle = True
+         print('Generating Train and validation datasets')
+     elif mode == 'test':
+         data_dir = os.path.join(src, 'test')
+         val_loaders = []
+         validation_split = 0.0
+         shuffle = True
+         print('Generating test dataset')
+     else:
+         print(f'mode:{mode} is not valid, use mode = train or test')
+         return
+
+     if train_mode == 'erm':
+         data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+         if validation_split > 0:
+             train_size = int((1 - validation_split) * len(data))
+             val_size = len(data) - train_size
+
+             print(f'Train data:{train_size}, Validation data:{val_size}')
+
+             train_dataset, val_dataset = random_split(data, [train_size, val_size])
+
+             train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+             val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+         else:
+             train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+
+     elif train_mode == 'irm':
+         data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+
+         for filename, label in zip(data.filenames, data.labels):
+             plate = data.get_plate(filename)
+             plate_to_filenames[plate].append(filename)
+             plate_to_labels[plate].append(label)
+
+         for plate, filenames in plate_to_filenames.items():
+             labels = plate_to_labels[plate]
+             plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
+             plate_names.append(plate)
+
+             if validation_split > 0:
+                 train_size = int((1 - validation_split) * len(plate_data))
+                 val_size = len(plate_data) - train_size
+
+                 print(f'Train data:{train_size}, Validation data:{val_size}')
+
+                 train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
+
+                 train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+                 val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+
+                 train_loaders.append(train_loader)
+                 val_loaders.append(val_loader)
+             else:
+                 train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
+                 train_loaders.append(train_loader)
+                 val_loaders.append(None)
+
+     else:
+         print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
+         return
+
+     if verbose:
+         if train_mode == 'erm':
+             for idx, (images, labels, filenames) in enumerate(train_loaders):
+                 if idx >= max_show:
+                     break
+                 images = images.cpu()
+                 label_strings = [str(label.item()) for label in labels]
+                 _imshow(images, label_strings, nrow=20, fontsize=12)
+         elif train_mode == 'irm':
+             for plate_name, train_loader in zip(plate_names, train_loaders):
+                 print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
+                 for idx, (images, labels, filenames) in enumerate(train_loader):
+                     if idx >= max_show:
+                         break
+                     images = images.cpu()
+                     label_strings = [str(label.item()) for label in labels]
+                     _imshow(images, label_strings, nrow=20, fontsize=12)
+
+     return train_loaders, val_loaders, plate_names
+
+ def analyze_recruitment(src, metadata_settings, advanced_settings):
+     """
+     Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
+
+     Parameters:
+     src (str): The source of the recruitment data.
+     metadata_settings (dict): The settings for metadata.
+     advanced_settings (dict): The advanced settings for recruitment analysis.
+
+     Returns:
+     None
+     """
+
+     from .io import _read_and_merge_data, _results_to_csv
+     from .plot import plot_merged, _plot_controls, _plot_recruitment
+     from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
+
+     settings_dict = {**metadata_settings, **advanced_settings}
+     settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
+     settings_csv = os.path.join(src,'settings','analyze_settings.csv')
+     os.makedirs(os.path.join(src,'settings'), exist_ok=True)
+     settings_df.to_csv(settings_csv, index=False)
+
+     # metadata settings
+     target = metadata_settings['target']
+     cell_types = metadata_settings['cell_types']
+     cell_plate_metadata = metadata_settings['cell_plate_metadata']
+     pathogen_types = metadata_settings['pathogen_types']
+     pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
+     treatments = metadata_settings['treatments']
+     treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
+     metadata_types = metadata_settings['metadata_types']
+     channel_dims = metadata_settings['channel_dims']
+     cell_chann_dim = metadata_settings['cell_chann_dim']
+     cell_mask_dim = metadata_settings['cell_mask_dim']
+     nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
+     nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
+     pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
+     pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
+     channel_of_interest = metadata_settings['channel_of_interest']
+
+     # Advanced settings
+     plot = advanced_settings['plot']
+     plot_nr = advanced_settings['plot_nr']
+     plot_control = advanced_settings['plot_control']
+     figuresize = advanced_settings['figuresize']
+     remove_background = advanced_settings['remove_background']
+     backgrounds = advanced_settings['backgrounds']
+     include_noninfected = advanced_settings['include_noninfected']
+     include_multiinfected = advanced_settings['include_multiinfected']
      include_multinucleated = advanced_settings['include_multinucleated']
      cells_per_well = advanced_settings['cells_per_well']
      pathogen_size_range = advanced_settings['pathogen_size_range']
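The new generate_loaders() introduces a SelectChannels transform that zeroes out the RGB planes not requested, applied after ToTensor() has produced a (3, H, W) tensor; the surrounding code first maps 'r'/'g'/'b' letters onto the indices 1-3. A standalone sketch of the idea:

    import torch

    class SelectChannels:
        def __init__(self, channels):   # channels: subset of {1, 2, 3} = R, G, B
            self.channels = channels

        def __call__(self, img):        # img: (3, H, W) tensor from ToTensor()
            img = img.clone()
            for ch in (1, 2, 3):
                if ch not in self.channels:
                    img[ch - 1].zero_()  # blank the unused plane in place
            return img

    x = torch.rand(3, 8, 8)
    y = SelectChannels([1, 2])(x)       # keep red and green; blue becomes zero
    assert torch.all(y[2] == 0) and torch.equal(y[0], x[0])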
@@ -1569,15 +1909,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
      df = df.dropna(subset=['condition'])
      print(f'After dropping non-annotated wells: {len(df)} rows')
      files = df['file_name'].tolist()
+     print(f'found: {len(files)} files')
      files = [item + '.npy' for item in files]
      random.shuffle(files)
-
+
+     _max = 10**100
+
+     if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
+         filter_min_max = None
+     else:
+         if cell_size_range is None:
+             cell_size_range = [0,_max]
+         if nucleus_size_range is None:
+             nucleus_size_range = [0,_max]
+         if pathogen_size_range is None:
+             pathogen_size_range = [0,_max]
+
+         filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
+
      if plot:
          plot_settings = {'include_noninfected':include_noninfected,
                           'include_multiinfected':include_multiinfected,
                           'include_multinucleated':include_multinucleated,
                           'remove_background':remove_background,
-                          'filter_min_max':[[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]],
+                          'filter_min_max':filter_min_max,
                           'channel_dims':channel_dims,
                           'backgrounds':backgrounds,
                           'cell_mask_dim':mask_dims[0],
@@ -1640,6 +1995,7 @@ def preprocess_generate_masks(src, settings={}):
      from .plot import plot_merged, plot_arrays
      from .utils import _pivot_counts_table

+     settings['plot'] = False
      settings['fps'] = 2
      settings['remove_background'] = True
      settings['lower_quantile'] = 0.02
@@ -1655,6 +2011,15 @@ def preprocess_generate_masks(src, settings={}):
      settings['upscale'] = False
      settings['upscale_factor'] = 2.0

+     settings['randomize'] = True
+     settings['timelapse'] = False
+     settings['timelapse_displacement'] = None
+     settings['timelapse_memory'] = 3
+     settings['timelapse_frame_limits'] = None
+     settings['timelapse_remove_transient'] = False
+     settings['timelapse_mode'] = 'trackpy'
+     settings['timelapse_objects'] = ['cells']
+
      settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
      settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
      os.makedirs(os.path.join(src,'settings'), exist_ok=True)
@@ -1723,7 +2088,6 @@ def preprocess_generate_masks(src, settings={}):
                       'cell_mask_dim':cell_mask_dim,
                       'nucleus_mask_dim':nucleus_mask_dim,
                       'pathogen_mask_dim':pathogen_mask_dim,
-                      'overlay_chans':[0,2,3],
                       'outline_thickness':3,
                       'outline_color':'gbr',
                       'overlay_chans':overlay_channels,
@@ -1735,6 +2099,10 @@ def preprocess_generate_masks(src, settings={}):
                       'figuresize':20,
                       'cmap':'inferno',
                       'verbose':False}
+
+     if settings['test_mode'] == True:
+         plot_settings['nr'] = len(os.listdir(os.path.join(src,'merged')))
+
      try:
          fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
      except Exception as e:
@@ -1747,26 +2115,61 @@ def preprocess_generate_masks(src, settings={}):
      print("Successfully completed run")
      return

- def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', verbose=False, plot=False, save=False, custom_model=None, signal_thresholds=1000, normalize=True, resize=False, target_height=None, target_width=None, rescale=True, resample=True, net_avg=False, invert=False, circular=False, percentiles=None, overlay=True, grayscale=False):
+ def identify_masks_finetune(settings):

      from .plot import print_mask_and_flows
      from .utils import get_files_from_dir, resize_images_and_labels
      from .io import _load_normalized_images_and_labels, _load_images_and_labels

+     src = settings['src']
+     dst = settings['dst']
+     model_name = settings['model_name']
+     diameter = settings['diameter']
+     batch_size = settings['batch_size']
+     flow_threshold = settings['flow_threshold']
+     cellprob_threshold = settings['cellprob_threshold']
+
+     verbose = settings['verbose']
+     plot = settings['plot']
+     save = settings['save']
+     custom_model = settings['custom_model']
+     overlay = settings['overlay']
+
+     figuresize = 25
+     cmap = 'inferno'
+     channels = [0,0]
+     signal_thresholds = 1000
+     normalize = True
+     percentiles = [2,98]
+     circular = False
+     invert = False
+     resize = False
+     settings['width_height'] = [1000,1000]
+     target_height = settings['width_height'][1]
+     target_width = settings['width_height'][0]
+     rescale = False
+     resample = False
+     grayscale = True
+     test = False
+
+     os.makedirs(dst, exist_ok=True)
+
+     if not custom_model is None:
+         if not os.path.exists(custom_model):
+             print(f'Custom model not found: {custom_model}')
+             return
+
      if not torch.cuda.is_available():
          print(f'Torch CUDA is not available, using CPU')

      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

      if custom_model == None:
-         if model_name =='cyto':
-             model = cp_models.CellposeModel(gpu=True, model_type=model_name, net_avg=False, diam_mean=diameter, pretrained_model=None)
-         else:
-             model = cp_models.CellposeModel(gpu=True, model_type=model_name)
-
-     if custom_model != None:
-         model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
-         print(f'loaded custom model:{custom_model}')
+         model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
+         print(f'Loaded model: {model_name}')
+     else:
+         model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
+         print("Pretrained Model Loaded:", model.pretrained_model)

      chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]

@@ -1778,14 +2181,16 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
      if verbose == True:
          print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')

-     all_image_files = get_files_from_dir(src, file_extension="*.tif")
+     all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
+
      random.shuffle(all_image_files)

      time_ls = []
      for i in range(0, len(all_image_files), batch_size):
          image_files = all_image_files[i:i+batch_size]
+
          if normalize:
-             images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
+             images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
              images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
              orig_dims = [(image.shape[0], image.shape[1]) for image in images]
          else:
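identify_masks_finetune() likewise now takes a single settings dict; the keys it reads are shown in the hunk above. A call might look like this, with placeholder paths and illustrative values:

    from spacr.core import identify_masks_finetune

    settings = {
        'src': '/path/to/tif_images',
        'dst': '/path/to/output_masks',
        'model_name': 'cyto',
        'diameter': 30,
        'batch_size': 8,
        'flow_threshold': 0.4,
        'cellprob_threshold': 0,
        'verbose': False,
        'plot': False,
        'save': True,
        'custom_model': None,  # or a path to a fine-tuned .CP_model file
        'overlay': True,
    }
    identify_masks_finetune(settings)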
@@ -1806,8 +2211,7 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
                               cellprob_threshold=cellprob_threshold,
                               rescale=rescale,
                               resample=resample,
-                              net_avg=net_avg,
-                              progress=False)
+                              progress=True)

          if len(output) == 4:
              mask, flows, _, _ = output
@@ -1882,7 +2286,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,

      #Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
      gc.collect()
-     #print('========== generating masks ==========')

      if not torch.cuda.is_available():
          print(f'Torch CUDA is not available, using CPU')
@@ -2047,9 +2450,9 @@
      # Check if all elements in list1 are in list2
      return all(element in list2 for element in list1)

- def generate_cellpose_masks_v1(src, settings, object_type):
+ def generate_cellpose_masks(src, settings, object_type):

-     from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, mask_object_count
+     from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
      from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
      from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
      from .plot import plot_masks
@@ -2079,15 +2482,12 @@ def generate_cellpose_masks_v1(src, settings, object_type):
      cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
      if settings['verbose']:
          print(cellpose_channels)
+
      channels = cellpose_channels[object_type]
      cellpose_batch_size = _get_cellpose_batch_size()
-
      device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-     model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
-     #dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
-
+     model = _choose_model(model_name, device, object_type='cell', restore_type=None)
      chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
-
      paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]

      count_loc = os.path.dirname(src)+'/measurements/measurements.db'
@@ -2096,7 +2496,6 @@ def generate_cellpose_masks_v1(src, settings, object_type):

      average_sizes = []
      time_ls = []
-
      for file_index, path in enumerate(paths):
          name = os.path.basename(path)
          name, ext = os.path.splitext(name)
@@ -2210,23 +2609,45 @@ def generate_cellpose_masks_v1(src, settings, object_type):
                                                           mode=timelapse_mode)
                else:
                    mask_stack = _masks_to_masks_stack(masks)
-
            else:
                _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
-                mask_stack = _filter_cp_masks(masks=masks,
-                                              flows=flows,
-                                              filter_size=object_settings['filter_size'],
-                                              filter_intensity=object_settings['filter_intensity'],
-                                              minimum_size=object_settings['minimum_size'],
-                                              maximum_size=object_settings['maximum_size'],
-                                              remove_border_objects=object_settings['remove_border_objects'],
-                                              merge=False,
-                                              batch=batch,
-                                              plot=settings['plot'],
-                                              figuresize=figuresize)
-
-                _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
+                if object_settings['merge'] and not settings['filter']:
+                    mask_stack = _filter_cp_masks(masks=masks,
+                                                  flows=flows,
+                                                  filter_size=False,
+                                                  filter_intensity=False,
+                                                  minimum_size=object_settings['minimum_size'],
+                                                  maximum_size=object_settings['maximum_size'],
+                                                  remove_border_objects=False,
+                                                  merge=object_settings['merge'],
+                                                  batch=batch,
+                                                  plot=settings['plot'],
+                                                  figuresize=figuresize)
+
+                if settings['filter']:
+                    mask_stack = _filter_cp_masks(masks=masks,
+                                                  flows=flows,
+                                                  filter_size=object_settings['filter_size'],
+                                                  filter_intensity=object_settings['filter_intensity'],
+                                                  minimum_size=object_settings['minimum_size'],
+                                                  maximum_size=object_settings['maximum_size'],
+                                                  remove_border_objects=object_settings['remove_border_objects'],
+                                                  merge=object_settings['merge'],
+                                                  batch=batch,
+                                                  plot=settings['plot'],
+                                                  figuresize=figuresize)
+
+                    _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
+                else:
+                    mask_stack = _masks_to_masks_stack(masks)
 
+            if settings['plot']:
+                for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
+                    if idx == 0:
+                        num_objects = mask_object_count(mask)
+                        print(f'Number of objects: {num_objects}')
+                        plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+
            if not np.any(mask_stack):
                average_obj_size = 0
            else:
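The added branching above replaces the old unconditional filtering pass: when `object_settings['merge']` is set but `settings['filter']` is off, `_filter_cp_masks` is called with every filter disabled so that only merging happens; when `settings['filter']` is on, the full size/intensity/border filtering runs and post-filtration counts are written to the database; otherwise the raw masks are stacked unchanged. Illustrative settings (key names taken from the code above, values made up):

    # Illustrative only: these key names are read by the branch above.
    settings = {'filter': True, 'plot': False}
    object_settings = {'merge': False,
                       'filter_size': True, 'filter_intensity': True,
                       'minimum_size': 100, 'maximum_size': 100000,
                       'remove_border_objects': True}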
@@ -2255,207 +2676,661 @@ def generate_cellpose_masks_v1(src, settings, object_type):
         torch.cuda.empty_cache()
     return
 
-def generate_cellpose_masks(src, settings, object_type):
-
-    from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
-    from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
-    from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
-    from .plot import plot_masks
-
-    gc.collect()
-    if not torch.cuda.is_available():
-        print(f'Torch CUDA is not available, using CPU')
-
-    figuresize=25
-    timelapse = settings['timelapse']
-
-    if timelapse:
-        timelapse_displacement = settings['timelapse_displacement']
-        timelapse_frame_limits = settings['timelapse_frame_limits']
-        timelapse_memory = settings['timelapse_memory']
-        timelapse_remove_transient = settings['timelapse_remove_transient']
-        timelapse_mode = settings['timelapse_mode']
-        timelapse_objects = settings['timelapse_objects']
+def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose):
+    from .io import _load_images_and_labels, _load_normalized_images_and_labels
+    from .utils import resize_images_and_labels, resizescikit
+    from .plot import print_mask_and_flows
+
+    dst = os.path.join(src, model_name)
+    os.makedirs(dst, exist_ok=True)
 
-    batch_size = settings['batch_size']
-    cellprob_threshold = settings[f'{object_type}_CP_prob']
     flow_threshold = 30
-
-    object_settings = _get_object_settings(object_type, settings)
-    model_name = object_settings['model_name']
-
-    cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
-    if settings['verbose']:
-        print(cellpose_channels)
+    chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
 
-    channels = cellpose_channels[object_type]
-    cellpose_batch_size = _get_cellpose_batch_size()
-    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
-    model = _choose_model(model_name, device, object_type='cell', restore_type=None)
-    chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
-    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
+    if grayscale:
+        chans=[0, 0]
 
-    count_loc = os.path.dirname(src)+'/measurements/measurements.db'
-    os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
-    _create_database(count_loc)
+    all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
+    random.shuffle(all_image_files)
+
+
+    if verbose == True:
+        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
 
-    average_sizes = []
     time_ls = []
-    for file_index, path in enumerate(paths):
-        name = os.path.basename(path)
-        name, ext = os.path.splitext(name)
-        output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
-        os.makedirs(output_folder, exist_ok=True)
-        overall_average_size = 0
-        with np.load(path) as data:
-            stack = data['data']
-            filenames = data['filenames']
-        if settings['timelapse']:
-
-            trackable_objects = ['cell','nucleus','pathogen']
-            if not all_elements_match(settings['timelapse_objects'], trackable_objects):
-                print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
-                return
+    for i in range(0, len(all_image_files), batch_size):
+        image_files = all_image_files[i:i+batch_size]
 
-            if len(stack) != batch_size:
-                print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
-                settings['timelapse_batch_size'] = len(stack)
-                batch_size = len(stack)
-                if isinstance(timelapse_frame_limits, list):
-                    if len(timelapse_frame_limits) >= 2:
-                        stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
-                        filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
-                        batch_size = len(stack)
-                        print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
+        if normalize:
+            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=100, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
+            images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+        else:
+            images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
+            images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
+            orig_dims = [(image.shape[0], image.shape[1]) for image in images]
+        if resize:
+            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
 
-        for i in range(0, stack.shape[0], batch_size):
-            mask_stack = []
+        for file_index, stack in enumerate(images):
             start = time.time()
+            output = model.eval(x=stack,
+                                normalize=False,
+                                channels=chans,
+                                channel_axis=3,
+                                diameter=diameter,
+                                flow_threshold=flow_threshold,
+                                cellprob_threshold=cellprob_threshold,
+                                rescale=False,
+                                resample=False,
+                                progress=True)
 
-            if stack.shape[3] == 1:
-                batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
+            if len(output) == 4:
+                mask, flows, _, _ = output
+            elif len(output) == 3:
+                mask, flows, _ = output
             else:
-                batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
+                raise ValueError("Unexpected number of return values from model.eval()")
 
-            batch_filenames = filenames[i: i+batch_size].tolist()
+            if resize:
+                dims = orig_dims[file_index]
+                mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
 
-            if not settings['plot']:
-                batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
-            if batch.size == 0:
-                print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
-                continue
-            if batch.max() > 1:
-                batch = batch / batch.max()
+            stop = time.time()
+            duration = (stop - start)
+            time_ls.append(duration)
+            average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
+            print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
+            if plot:
+                if resize:
+                    stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
+                print_mask_and_flows(stack, mask, flows, overlay=True)
+            if save:
+                output_filename = os.path.join(dst, image_names[file_index])
+                cv2.imwrite(output_filename, mask)
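The 4-versus-3 branch in `generate_masks_from_imgs` accommodates both Cellpose APIs: `Cellpose.eval` returns `(masks, flows, styles, diams)` while `CellposeModel.eval` returns `(masks, flows, styles)` (true of cellpose 2.x/3.x; check your installed version). An equivalent, arity-agnostic unpacking, as a sketch:

    # Sketch: take only the leading two values so both 3- and 4-tuple
    # returns from model.eval() are handled the same way.
    output = model.eval(x=stack, channels=chans, diameter=diameter)
    mask, flows = output[0], output[1]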
 
-            if timelapse:
-                stitch_threshold=100.0
-                movie_path = os.path.join(os.path.dirname(src), 'movies')
-                os.makedirs(movie_path, exist_ok=True)
-                save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
-                _npz_to_movie(batch, batch_filenames, save_path, fps=2)
-            else:
-                stitch_threshold=0.0
 
-            print('batch.shape',batch.shape)
-            masks, flows, _, _ = model.eval(x=batch,
-                                            batch_size=cellpose_batch_size,
-                                            normalize=False,
-                                            channels=chans,
-                                            channel_axis=3,
-                                            diameter=object_settings['diameter'],
-                                            flow_threshold=flow_threshold,
-                                            cellprob_threshold=cellprob_threshold,
-                                            rescale=None,
-                                            resample=object_settings['resample'],
-                                            stitch_threshold=stitch_threshold)
+def check_cellpose_models(settings):
+
+    src = settings['src']
+    batch_size = settings['batch_size']
+    cellprob_threshold = settings['cellprob_threshold']
+    save = settings['save']
+    normalize = settings['normalize']
+    channels = settings['channels']
+    percentiles = settings['percentiles']
+    circular = settings['circular']
+    invert = settings['invert']
+    plot = settings['plot']
+    diameter = settings['diameter']
+    resize = settings['resize']
+    grayscale = settings['grayscale']
+    verbose = settings['verbose']
+    target_height = settings['width_height'][0]
+    target_width = settings['width_height'][1]
+
+    cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
+    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+
+    for model_name in cellpose_models:
+
+        model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
+        print(f'Using {model_name}')
+        generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose)
+
+    return
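`check_cellpose_models` reads a flat settings dict whose keys are exactly those unpacked above. A hedged example call (paths and values are illustrative only):

    settings = {
        'src': '/path/to/tif_folder',     # folder containing .tif images
        'batch_size': 8,
        'cellprob_threshold': 0.0,
        'save': True,
        'normalize': True,
        'channels': [0, 1],
        'percentiles': None,
        'circular': False,
        'invert': False,
        'plot': False,
        'diameter': 30,
        'resize': False,
        'grayscale': True,
        'verbose': True,
        'width_height': [1024, 1024],     # [target_height, target_width]
    }
    check_cellpose_models(settings)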
+
+def compare_masks_v1(dir1, dir2, dir3, verbose=False):
+
+    from .io import _read_mask
+    from .plot import visualize_masks, plot_comparison_results
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
+
+    filenames = os.listdir(dir1)
+    results = []
+    cond_1 = os.path.basename(dir1)
+    cond_2 = os.path.basename(dir2)
+    cond_3 = os.path.basename(dir3)
+
+    for index, filename in enumerate(filenames):
+        print(f'Processing image:{index+1}', end='\r', flush=True)
+        path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
+
+        print(path1)
+        print(path2)
+        print(path3)
+
+        if os.path.exists(path2) and os.path.exists(path3):
 
-            if timelapse:
+            mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
+            boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
+
+
+            true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
+            true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
+            average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
+            ap_scores = [average_precision_0, average_precision_1]
 
-                if settings['plot']:
-                    for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
-                        if idx == 0:
-                            num_objects = mask_object_count(mask)
-                            print(f'Number of objects: {num_objects}')
-                            plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
+            if verbose:
+                #unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
+                #print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
+                visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
+
+            boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
 
-                _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
-                if object_type in timelapse_objects:
-                    if timelapse_mode == 'btrack':
-                        if not timelapse_displacement is None:
-                            radius = timelapse_displacement
-                        else:
-                            radius = 100
+            if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
+               (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
+               (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
+                continue
+
+            if verbose:
+                #unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
+                #print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
+                visualize_masks(mask1, mask2, mask3, title=filename)
+
+            jaccard12 = jaccard_index(mask1, mask2)
+            dice12 = dice_coefficient(mask1, mask2)
+
+            jaccard13 = jaccard_index(mask1, mask3)
+            dice13 = dice_coefficient(mask1, mask3)
+
+            jaccard23 = jaccard_index(mask2, mask3)
+            dice23 = dice_coefficient(mask2, mask3)
 
-                    workers = os.cpu_count()-2
-                    if workers < 1:
-                        workers = 1
+            results.append({
+                f'filename': filename,
+                f'jaccard_{cond_1}_{cond_2}': jaccard12,
+                f'dice_{cond_1}_{cond_2}': dice12,
+                f'jaccard_{cond_1}_{cond_3}': jaccard13,
+                f'dice_{cond_1}_{cond_3}': dice13,
+                f'jaccard_{cond_2}_{cond_3}': jaccard23,
+                f'dice_{cond_2}_{cond_3}': dice23,
+                f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
+                f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
+                f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
+                f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
+                f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
+            })
+        else:
+            print(f'Cannot find {path1} or {path2} or {path3}')
+    fig = plot_comparison_results(results)
+    return results, fig
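`jaccard_index` and `dice_coefficient` come from `spacr.utils`, which is outside this diff; presumably they are the standard overlap scores. For binary masks A and B, Jaccard = |A∩B| / |A∪B| and Dice = 2|A∩B| / (|A| + |B|). A NumPy sketch of the usual forms, under that assumption:

    import numpy as np

    def jaccard_sketch(mask_a, mask_b):
        # Binarize labelled masks, then compute intersection over union.
        a, b = mask_a > 0, mask_b > 0
        union = np.logical_or(a, b).sum()
        return np.logical_and(a, b).sum() / union if union else 0.0

    def dice_sketch(mask_a, mask_b):
        # Dice weights the intersection twice relative to the combined areas.
        a, b = mask_a > 0, mask_b > 0
        total = a.sum() + b.sum()
        return 2 * np.logical_and(a, b).sum() / total if total else 0.0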
 
-                    mask_stack = _btrack_track_cells(src=src,
-                                                     name=name,
-                                                     batch_filenames=batch_filenames,
-                                                     object_type=object_type,
-                                                     plot=settings['plot'],
-                                                     save=settings['save'],
-                                                     masks_3D=masks,
-                                                     mode=timelapse_mode,
-                                                     timelapse_remove_transient=timelapse_remove_transient,
-                                                     radius=radius,
-                                                     workers=workers)
-                    if timelapse_mode == 'trackpy':
-                        mask_stack = _trackpy_track_cells(src=src,
-                                                          name=name,
-                                                          batch_filenames=batch_filenames,
-                                                          object_type=object_type,
-                                                          masks=masks,
-                                                          timelapse_displacement=timelapse_displacement,
-                                                          timelapse_memory=timelapse_memory,
-                                                          timelapse_remove_transient=timelapse_remove_transient,
-                                                          plot=settings['plot'],
-                                                          save=settings['save'],
-                                                          mode=timelapse_mode)
-                else:
-                    mask_stack = _masks_to_masks_stack(masks)
+def compare_cellpose_masks_v1(src, verbose=False):
+    from .io import _read_mask
+    from .plot import visualize_masks, plot_comparison_results, visualize_cellpose_masks
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
 
-            else:
-                _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
-                mask_stack = _filter_cp_masks(masks=masks,
-                                              flows=flows,
-                                              filter_size=object_settings['filter_size'],
-                                              filter_intensity=object_settings['filter_intensity'],
-                                              minimum_size=object_settings['minimum_size'],
-                                              maximum_size=object_settings['maximum_size'],
-                                              remove_border_objects=object_settings['remove_border_objects'],
-                                              merge=False,
-                                              batch=batch,
-                                              plot=settings['plot'],
-                                              figuresize=figuresize)
-
-                _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
+    import os
+    import numpy as np
+    from skimage.measure import label
 
-            if not np.any(mask_stack):
-                average_obj_size = 0
-            else:
-                average_obj_size = _get_avg_object_size(mask_stack)
+    # Collect all subdirectories in src
+    dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
 
-            average_sizes.append(average_obj_size)
-            overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
+    dirs.sort() # Optional: sort directories if needed
 
-            stop = time.time()
-            duration = (stop - start)
-            time_ls.append(duration)
-            average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
-            time_in_min = average_time/60
-            time_per_mask = average_time/batch_size
-            print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
-            if not timelapse:
-                if settings['plot']:
-                    plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
-            if settings['save']:
-                for mask_index, mask in enumerate(mask_stack):
-                    output_filename = os.path.join(output_folder, batch_filenames[mask_index])
-                    np.save(output_filename, mask)
-                mask_stack = []
-                batch_filenames = []
-            gc.collect()
-            torch.cuda.empty_cache()
-    return
+    # Get common files in all directories
+    common_files = set(os.listdir(dirs[0]))
+    for d in dirs[1:]:
+        common_files.intersection_update(os.listdir(d))
+    common_files = list(common_files)
+
+    results = []
+    conditions = [os.path.basename(d) for d in dirs]
+
+    for index, filename in enumerate(common_files):
+        print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
+        paths = [os.path.join(d, filename) for d in dirs]
+
+        # Check if file exists in all directories
+        if not all(os.path.exists(path) for path in paths):
+            print(f'Skipping {filename} as it is not present in all directories.')
+            continue
+
+        masks = [_read_mask(path) for path in paths]
+        boundaries = [extract_boundaries(mask) for mask in masks]
+
+        if verbose:
+            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
+
+        # Initialize data structure for results
+        file_results = {'filename': filename}
+
+        # Compare each mask with each other
+        for i in range(len(masks)):
+            for j in range(i + 1, len(masks)):
+                condition_i = conditions[i]
+                condition_j = conditions[j]
+                mask_i = masks[i]
+                mask_j = masks[j]
+
+                # Compute metrics
+                boundary_f1 = boundary_f1_score(mask_i, mask_j)
+                jaccard = jaccard_index(mask_i, mask_j)
+                average_precision = compute_segmentation_ap(mask_i, mask_j)
+
+                # Store results
+                file_results[f'jaccard_{condition_i}_{condition_j}'] = jaccard
+                file_results[f'boundary_f1_{condition_i}_{condition_j}'] = boundary_f1
+                file_results[f'average_precision_{condition_i}_{condition_j}'] = average_precision
+
+        results.append(file_results)
+
+    fig = plot_comparison_results(results)
+    return results, fig
+
+def compare_mask(args):
+    src, filename, dirs, conditions = args
+    paths = [os.path.join(d, filename) for d in dirs]
+
+    if not all(os.path.exists(path) for path in paths):
+        return None
+
+    from .io import _read_mask # Import here to avoid issues in multiprocessing
+    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
+    from .plot import plot_comparison_results
+
+    masks = [_read_mask(path) for path in paths]
+    file_results = {'filename': filename}
+
+    for i in range(len(masks)):
+        for j in range(i + 1, len(masks)):
+            mask_i, mask_j = masks[i], masks[j]
+            f1_score = boundary_f1_score(mask_i, mask_j)
+            jac_index = jaccard_index(mask_i, mask_j)
+            ap_score = compute_segmentation_ap(mask_i, mask_j)
+
+            file_results.update({
+                f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
+                f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
+                f'ap_{conditions[i]}_{conditions[j]}': ap_score
+            })
+
+    return file_results
+
+def compare_cellpose_masks(src, verbose=False, processes=None):
+    from .plot import visualize_cellpose_masks, plot_comparison_results
+    from .io import _read_mask
+    dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))]
+    dirs.sort() # Optional: sort directories if needed
+    conditions = [os.path.basename(d) for d in dirs]
+
+    # Get common files in all directories
+    common_files = set(os.listdir(dirs[0]))
+    for d in dirs[1:]:
+        common_files.intersection_update(os.listdir(d))
+    common_files = list(common_files)
+
+    # Create a pool of workers
+    with Pool(processes=processes) as pool:
+        args = [(src, filename, dirs, conditions) for filename in common_files]
+        results = pool.map(compare_mask, args)
+
+    # Filter out None results (from skipped files)
+    results = [res for res in results if res is not None]
+
+    if verbose:
+        for result in results:
+            filename = result['filename']
+            masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
+            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")
+
+    fig = plot_comparison_results(results)
+    return results, fig
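`compare_cellpose_masks` parallelizes the per-file comparison across a process pool; `Pool` is assumed to be imported from `multiprocessing` elsewhere in core.py, since no import for it appears in this hunk. Usage sketch, with an illustrative path:

    # Each subdirectory of src is treated as one condition; processes=None
    # lets multiprocessing default to os.cpu_count().
    results, fig = compare_cellpose_masks('/path/to/mask_dirs', verbose=False, processes=4)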
+
+
+def _calculate_similarity(df, features, col_to_compare, val1, val2):
+    """
+    Calculate similarity scores of each well to the positive and negative controls using various metrics.
+
+    Args:
+        df (pandas.DataFrame): DataFrame containing the data.
+        features (list): List of feature columns to use for similarity calculation.
+        col_to_compare (str): Column name to use for comparing groups.
+        val1, val2 (str): Values in col_to_compare to create subsets for comparison.
+
+    Returns:
+        pandas.DataFrame: DataFrame with similarity scores.
+    """
+    # Separate positive and negative control wells
+    pos_control = df[df[col_to_compare] == val1][features].mean()
+    neg_control = df[df[col_to_compare] == val2][features].mean()
+
+    # Standardize features for Mahalanobis distance
+    scaler = StandardScaler()
+    scaled_features = scaler.fit_transform(df[features])
+
+    # Regularize the covariance matrix to avoid singularity
+    cov_matrix = np.cov(scaled_features, rowvar=False)
+    inv_cov_matrix = None
+    try:
+        inv_cov_matrix = np.linalg.inv(cov_matrix)
+    except np.linalg.LinAlgError:
+        # Add a small value to the diagonal elements for regularization
+        epsilon = 1e-5
+        inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
+
+    # Calculate similarity scores
+    df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
+    df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
+    df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
+    df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
+    df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
+    df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
+    df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
+    df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
+    df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
+    df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
+    df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
+    df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
+    df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
+    df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
+    df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
+    df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
+    df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
+    df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
+
+    return df
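Two caveats worth noting when reading this function: the Mahalanobis inverse covariance appears to be computed from the *standardized* features while the rows and control vectors passed to `mahalanobis` are unstandardized, and SciPy's `hamming`/`jaccard` are defined for boolean vectors, so on continuous features they effectively score exact element matches. A usage sketch (column and value names illustrative):

    # Illustrative: score every row against 'c1' (positive) and 'c2' (negative) controls.
    features = [c for c in df.columns if 'channel_3' in c]   # illustrative feature subset
    df = _calculate_similarity(df, features, col_to_compare='col', val1='c1', val2='c2')
    print(df[['similarity_to_pos_euclidean', 'similarity_to_neg_euclidean']].head())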
+
+def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
+    """
+    Calculates permutation importance for numerical features in the dataframe,
+    comparing groups based on specified column values, and uses the model to predict
+    the class for all other rows in the dataframe.
+
+    Args:
+        df (pandas.DataFrame): The DataFrame containing the data.
+        feature_string (str): String to filter features that contain this substring.
+        col_to_compare (str): Column name to use for comparing groups.
+        pos, neg (str): Values in col_to_compare to create subsets for comparison.
+        exclude (list or str, optional): Columns to exclude from features.
+        n_repeats (int): Number of repeats for permutation importance.
+        clean (bool): Whether to remove columns with a single value.
+        nr_to_plot (int): Number of top features to plot based on permutation importance.
+        n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
+        test_size (float): Proportion of the dataset to include in the test split.
+        random_state (int): Random seed for reproducibility.
+        model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
+        n_jobs (int): Number of jobs to run in parallel for applicable models.
+
+    Returns:
+        list: [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test] -
+            the dataframe with prediction and data-usage columns, the permutation importances,
+            the model feature importances, the fitted model, and the train/test splits.
+    """
+    if 'cells_per_well' in df.columns:
+        df = df.drop(columns=['cells_per_well'])
+
+    # Subset the dataframe based on specified column values
+    df1 = df[df[col_to_compare] == pos].copy()
+    df2 = df[df[col_to_compare] == neg].copy()
+
+    # Create target variable
+    df1['target'] = 0
+    df2['target'] = 1
+
+    # Combine the subsets for analysis
+    combined_df = pd.concat([df1, df2])
+
+    # Automatically select numerical features
+    features = combined_df.select_dtypes(include=[np.number]).columns.tolist()
+    features.remove('target')
+
+    if clean:
+        combined_df = combined_df.loc[:, combined_df.nunique() > 1]
+        features = [feature for feature in features if feature in combined_df.columns]
+
+    if feature_string is not None:
+        feature_list = ['channel_0', 'channel_1', 'channel_2', 'channel_3']
+
+        # Remove feature_string from the list if it exists
+        if feature_string in feature_list:
+            feature_list.remove(feature_string)
+
+        features = [feature for feature in features if feature_string in feature]
+
+        # Iterate through the list and remove columns from df
+        for feature_ in feature_list:
+            features = [feature for feature in features if feature_ not in feature]
+            print(f'After removing {feature_} features: {len(features)}')
+
+    if exclude:
+        if isinstance(exclude, list):
+            features = [feature for feature in features if feature not in exclude]
+        else:
+            features.remove(exclude)
+
+    X = combined_df[features]
+    y = combined_df['target']
+
+    # Split the data into training and testing sets
+    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
+
+    # Label the data in the original dataframe
+    combined_df['data_usage'] = 'train'
+    combined_df.loc[X_test.index, 'data_usage'] = 'test'
+
+    # Initialize the model based on model_type
+    if model_type == 'random_forest':
+        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
+    elif model_type == 'logistic_regression':
+        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
+    elif model_type == 'gradient_boosting':
+        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
+    elif model_type == 'xgboost':
+        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
+    else:
+        raise ValueError(f"Unsupported model_type: {model_type}")
+
+    model.fit(X_train, y_train)
+
+    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
+
+    # Create a DataFrame for permutation importances
+    permutation_df = pd.DataFrame({
+        'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
+        'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
+        'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
+    }).tail(nr_to_plot)
+
+    # Plotting
+    fig, ax = plt.subplots()
+    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
+    ax.set_xlabel('Permutation Importance')
+    plt.tight_layout()
+    plt.show()
+
+    # Feature importance for models that support it
+    if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
+        feature_importances = model.feature_importances_
+        feature_importance_df = pd.DataFrame({
+            'feature': features,
+            'importance': feature_importances
+        }).sort_values(by='importance', ascending=False).head(nr_to_plot)
+
+        # Plotting feature importance
+        fig, ax = plt.subplots()
+        ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
+        ax.set_xlabel('Feature Importance')
+        plt.tight_layout()
+        plt.show()
+    else:
+        feature_importance_df = pd.DataFrame()
+
+    # Predicting the target variable for the test set
+    predictions_test = model.predict(X_test)
+    combined_df.loc[X_test.index, 'predictions'] = predictions_test
+
+    # Predicting the target variable for the training set
+    predictions_train = model.predict(X_train)
+    combined_df.loc[X_train.index, 'predictions'] = predictions_train
+
+    # Predicting the target variable for all other rows in the dataframe
+    X_all = df[features]
+    all_predictions = model.predict(X_all)
+    df['predictions'] = all_predictions
+
+    # Combine data usage labels back to the original dataframe
+    combined_data_usage = pd.concat([combined_df[['data_usage']], df[['predictions']]], axis=0)
+    df = df.join(combined_data_usage, how='left', rsuffix='_model')
+
+    # Calculating and printing the accuracy metrics
+    accuracy = accuracy_score(y_test, predictions_test)
+    precision = precision_score(y_test, predictions_test)
+    recall = recall_score(y_test, predictions_test)
+    f1 = f1_score(y_test, predictions_test)
+    print(f"Accuracy: {accuracy}")
+    print(f"Precision: {precision}")
+    print(f"Recall: {recall}")
+    print(f"F1 Score: {f1}")
+
+    # Printing class-specific accuracy metrics
+    print("\nClassification Report:")
+    print(classification_report(y_test, predictions_test))
+
+    df = _calculate_similarity(df, features, col_to_compare, pos, neg)
+
+    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
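A hedged usage sketch of `_permutation_importance`; the column name `col` and the values `c1`/`c2` mirror the defaults above and stand in for real plate-layout metadata:

    # Train on channel-3 features only and unpack the 8-item return list.
    df_out, perm_df, fi_df, model, X_train, X_test, y_train, y_test = _permutation_importance(
        df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2',
        model_type='xgboost', n_estimators=100, n_jobs=-1)
    print(perm_df.tail())   # strongest features by permutation importance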
+
+def _shap_analysis(model, X_train, X_test):
+    """
+    Performs SHAP analysis on the given model and data.
+
+    Args:
+        model: The trained model.
+        X_train (pandas.DataFrame): Training feature set.
+        X_test (pandas.DataFrame): Testing feature set.
+    """
+    explainer = shap.Explainer(model, X_train)
+    shap_values = explainer(X_test)
+
+    # Summary plot
+    shap.summary_plot(shap_values, X_test)
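`shap.Explainer` dispatches on the model type (tree ensembles get the fast tree explainer; other models fall back to a masker-based explainer), so the same call should work for all four `model_type` options above. SHAP evaluation scales with the number of rows; a common pattern, sketched with an illustrative subsample size:

    # Explain a subsample of the test set to keep runtime manageable.
    _shap_analysis(model, X_train, X_test.sample(min(len(X_test), 500), random_state=0))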
+
+def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+    from .io import _read_and_merge_data
+    from .plot import _plot_plates
+
+    db_loc = [src+'/measurements/measurements.db']
+    tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
+    include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
+
+    df, _ = _read_and_merge_data(db_loc,
+                                 tables,
+                                 verbose=verbose,
+                                 include_multinucleated=include_multinucleated,
+                                 include_multiinfected=include_multiinfected,
+                                 include_noninfected=include_noninfected)
+
+    if not channel_of_interest is None:
+        df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
+        feature_string = f'channel_{channel_of_interest}'
+    else:
+        feature_string = None
+
+    output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
+
+    _shap_analysis(output[3], output[4], output[5])
+
+    features = output[0].select_dtypes(include=[np.number]).columns.tolist()
+
+    if not variable in features:
+        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
+
+    plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
+    return [output, plate_heatmap]
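Usage sketch for `plate_heatmap`; the path is illustrative and assumes the spacr layout with `measurements/measurements.db` under `src`:

    output, heatmap_fig = plate_heatmap('/path/to/experiment',
                                        model_type='xgboost',
                                        variable='predictions',
                                        channel_of_interest=3,
                                        pos='c1', neg='c2')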
+
+def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
+
+    from .io import _read_and_merge_data, _read_db
+
+    db_loc = [src+'/measurements/measurements.db']
+    loc = src+'/measurements/measurements.db'
+    df, _ = _read_and_merge_data(db_loc,
+                                 tables,
+                                 verbose=True,
+                                 include_multinucleated=True,
+                                 include_multiinfected=True,
+                                 include_noninfected=True)
+
+    paths_df = _read_db(loc, tables=['png_list'])
+
+    merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
+
+    return merged_df
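The merge key `prcfo` is the per-object identifier shared by the measurement tables and `png_list` (presumably plate/row/column/field/object); the left join keeps every measured object even when no PNG annotation exists for it. Usage sketch (path illustrative):

    merged_df = join_measurments_and_annotation('/path/to/experiment')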
+
+def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
+    """
+    Reads measurement and annotation data from the measurements database and creates a jitter plot of one column grouped by another column.
+
+    Args:
+        src (str): Path to the source data.
+        x_column (str): Name of the column to be used for the x-axis.
+        y_column (str): Name of the column to be used for the y-axis.
+        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
+        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
+        filter_column (str or list): Column(s) used to filter rows before plotting. Default is None.
+        filter_values (list): Allowed value(s) for each entry in filter_column. Default is None.
+
+    Returns:
+        pd.DataFrame: The filtered and balanced DataFrame.
+    """
+    # Load and merge the measurement and annotation tables into a DataFrame
+    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
+
+    # Print column names for debugging
+    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")
+    #print("Columns in DataFrame:", df.columns.tolist())
+
+    # Replace NaN values with a specific label in x_column
+    df[x_column] = df[x_column].fillna('NaN')
+
+    # Filter the DataFrame if filter_column and filter_values are provided
+    if not filter_column is None:
+        if isinstance(filter_column, str):
+            df = df[df[filter_column].isin(filter_values)]
+        if isinstance(filter_column, list):
+            for i,val in enumerate(filter_column):
+                print(f'Rows before filtering on {val}: {len(df)}')
+                df = df[df[val].isin(filter_values[i])]
+
+    # Use the correct column names based on your DataFrame
+    required_columns = ['plate_x', 'row_x', 'col_x']
+    if not all(column in df.columns for column in required_columns):
+        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")
+
+    # Filter to retain rows with non-NaN values in x_column and with matching plate, row, col values
+    non_nan_df = df[df[x_column] != 'NaN']
+    retained_rows = df[df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1).isin(non_nan_df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1))]
+
+    # Determine the minimum count of examples across all groups in x_column
+    min_count = retained_rows[x_column].value_counts().min()
+    print(f'Found {min_count} annotated images')
+
+    # Randomly sample min_count examples from each group in x_column
+    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
+
+    # Create the jitter plot
+    plt.figure(figsize=(10, 6))
+    jitter_plot = sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
+    plt.title(plot_title)
+    plt.xlabel(x_column)
+    plt.ylabel(y_column)
+
+    # Customize the x-axis labels
+    plt.xticks(rotation=45, ha='right')
+
+    # Adjust the position of the x-axis labels to be centered below the data
+    ax = plt.gca()
+    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center')
+
+    # Save the plot to a file or display it
+    if output_path:
+        plt.savefig(output_path, bbox_inches='tight')
+        print(f"Jitter plot saved to {output_path}")
+    else:
+        plt.show()
+
+    return balanced_df
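Usage sketch for `jitterplot_by_annotation`; column names are illustrative and depend on the measurement schema and annotations present:

    balanced = jitterplot_by_annotation('/path/to/experiment',
                                        x_column='annotation', y_column='recruitment',
                                        plot_title='Recruitment by annotation class',
                                        output_path='jitter.png')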