spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py CHANGED
@@ -1,12 +1,13 @@
1
- import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
1
+ import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap, string
2
2
 
3
3
  # image and array processing
4
4
  import numpy as np
5
5
  import pandas as pd
6
6
 
7
+ from cellpose import train
7
8
  import cellpose
8
9
  from cellpose import models as cp_models
9
- from cellpose import denoise
10
+ from cellpose.models import CellposeModel
10
11
 
11
12
  import statsmodels.formula.api as smf
12
13
  import statsmodels.api as sm
@@ -28,18 +29,18 @@ matplotlib.use('Agg')
28
29
 
29
30
  import torchvision.transforms as transforms
30
31
  from sklearn.model_selection import train_test_split
31
- from sklearn.ensemble import IsolationForest
32
-
32
+ from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
33
33
  from .logger import log_function_call
34
34
 
35
- #from .io import TarImageDataset, NoClassDataset, MyDataset, read_db, _copy_missclassified, read_mask, load_normalized_images_and_labels, load_images_and_labels
36
- #from .plot import plot_merged, plot_arrays, _plot_controls, _plot_recruitment, _imshow, _plot_histograms_and_stats, _reg_v_plot, visualize_masks, plot_comparison_results
37
- #from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient, _object_filter
38
- #from .utils import resize_images_and_labels, generate_fraction_map, MLR, fishers_odds, lasso_reg, model_metrics, _map_wells_png, check_multicollinearity, init_globals, add_images_to_tar
39
- #from .utils import get_paths_from_db, pick_best_model, test_model_performance, evaluate_model_performance, compute_irm_penalty
40
- #from .utils import _pivot_counts_table, _generate_masks, _get_cellpose_channels, annotate_conditions, _calculate_recruitment, calculate_loss, _group_by_well, choose_model
35
+ from sklearn.linear_model import LogisticRegression
36
+ from sklearn.inspection import permutation_importance
37
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
38
+ from xgboost import XGBClassifier
39
+
40
+ from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
41
+ from sklearn.preprocessing import StandardScaler
42
+ import shap
41
43
 
42
- @log_function_call
43
44
  def analyze_plaques(folder):
44
45
  summary_data = []
45
46
  details_data = []
@@ -76,75 +77,6 @@ def analyze_plaques(folder):
76
77
 
77
78
  print(f"Analysis completed and saved to database '{db_name}'.")
78
79
 
79
- @log_function_call
80
- def compare_masks(dir1, dir2, dir3, verbose=False):
81
-
82
- from .io import _read_mask
83
- from .plot import visualize_masks, plot_comparison_results
84
- from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
85
-
86
- filenames = os.listdir(dir1)
87
- results = []
88
- cond_1 = os.path.basename(dir1)
89
- cond_2 = os.path.basename(dir2)
90
- cond_3 = os.path.basename(dir3)
91
- for index, filename in enumerate(filenames):
92
- print(f'Processing image:{index+1}', end='\r', flush=True)
93
- path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
94
- if os.path.exists(path2) and os.path.exists(path3):
95
-
96
- mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
97
- boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
98
-
99
-
100
- true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
101
- true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
102
- average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
103
- ap_scores = [average_precision_0, average_precision_1]
104
-
105
- if verbose:
106
- unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
107
- print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
108
- visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
109
-
110
- boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
111
-
112
- if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
113
- (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
114
- (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
115
- continue
116
-
117
- if verbose:
118
- unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
119
- print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
120
- visualize_masks(mask1, mask2, mask3, title=filename)
121
-
122
- jaccard12 = jaccard_index(mask1, mask2)
123
- dice12 = dice_coefficient(mask1, mask2)
124
- jaccard13 = jaccard_index(mask1, mask3)
125
- dice13 = dice_coefficient(mask1, mask3)
126
- jaccard23 = jaccard_index(mask2, mask3)
127
- dice23 = dice_coefficient(mask2, mask3)
128
-
129
- results.append({
130
- f'filename': filename,
131
- f'jaccard_{cond_1}_{cond_2}': jaccard12,
132
- f'dice_{cond_1}_{cond_2}': dice12,
133
- f'jaccard_{cond_1}_{cond_3}': jaccard13,
134
- f'dice_{cond_1}_{cond_3}': dice13,
135
- f'jaccard_{cond_2}_{cond_3}': jaccard23,
136
- f'dice_{cond_2}_{cond_3}': dice23,
137
- f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
138
- f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
139
- f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
140
- f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
141
- f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
142
- })
143
- else:
144
- print(f'Cannot find {path1} or {path2} or {path3}')
145
- fig = plot_comparison_results(results)
146
- return results, fig
147
-
148
80
  def generate_cp_masks(settings):
149
81
 
150
82
  src = settings['src']
@@ -178,18 +110,155 @@ def generate_cp_masks(settings):
178
110
 
179
111
  dst = os.path.join(src,'masks')
180
112
  os.makedirs(dst, exist_ok=True)
181
-
113
+
182
114
  identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
183
115
 
184
- @log_function_call
185
116
  def train_cellpose(settings):
186
117
 
187
118
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
188
119
  from .utils import resize_images_and_labels
189
120
 
190
121
  img_src = settings['img_src']
191
- mask_src= settings['mask_src']
192
- secondary_image_dir = None
122
+ mask_src = os.path.join(img_src, 'mask')
123
+
124
+ model_name = settings['model_name']
125
+ model_type = settings['model_type']
126
+ learning_rate = settings['learning_rate']
127
+ weight_decay = settings['weight_decay']
128
+ batch_size = settings['batch_size']
129
+ n_epochs = settings['n_epochs']
130
+ from_scratch = settings['from_scratch']
131
+ diameter = settings['diameter']
132
+ verbose = settings['verbose']
133
+
134
+ channels = [0,0]
135
+ signal_thresholds = 1000
136
+ normalize = True
137
+ percentiles = [2,98]
138
+ circular = False
139
+ invert = False
140
+ resize = False
141
+ settings['width_height'] = [1000,1000]
142
+ target_height = settings['width_height'][1]
143
+ target_width = settings['width_height'][0]
144
+ rescale = False
145
+ grayscale = True
146
+ test = False
147
+
148
+ if test:
149
+ test_img_src = os.path.join(os.path.dirname(img_src), 'test')
150
+ test_mask_src = os.path.join(test_img_src, 'mask')
151
+
152
+ test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
153
+ print(settings)
154
+
155
+ if from_scratch:
156
+ model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
157
+ else:
158
+ if resize:
159
+ model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
160
+ else:
161
+ model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
162
+
163
+ model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
164
+ print(model_save_path)
165
+ os.makedirs(model_save_path, exist_ok=True)
166
+
167
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
168
+ settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
169
+ settings_df.to_csv(settings_csv, index=False)
170
+
171
+ if from_scratch:
172
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
173
+ else:
174
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type)
175
+
176
+ if normalize:
177
+
178
+ image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
179
+ label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
180
+ images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
181
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
182
+
183
+ if test:
184
+ test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
185
+ test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
186
+ test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(image_files=test_image_files, label_files=test_label_files, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
187
+ test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
188
+
189
+
190
+ else:
191
+ images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
192
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
193
+
194
+ if test:
195
+ test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=circular)
196
+ test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
197
+
198
+ if resize:
199
+ images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
200
+
201
+ if model_type == 'cyto':
202
+ cp_channels = [0,1]
203
+ if model_type == 'cyto2':
204
+ cp_channels = [0,2]
205
+ if model_type == 'nucleus':
206
+ cp_channels = [0,0]
207
+ if grayscale:
208
+ cp_channels = [0,0]
209
+ images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
210
+
211
+ masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
212
+
213
+ print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
214
+ save_every = int(n_epochs/10)
215
+ if save_every < 10:
216
+ save_every = n_epochs
217
+
218
+ train.train_seg(model.net,
219
+ train_data=images,
220
+ train_labels=masks,
221
+ train_files=image_names,
222
+ train_labels_files=mask_names,
223
+ train_probs=None,
224
+ test_data=test_images,
225
+ test_labels=test_masks,
226
+ test_files=test_image_names,
227
+ test_labels_files=test_mask_names,
228
+ test_probs=None,
229
+ load_files=True,
230
+ batch_size=batch_size,
231
+ learning_rate=learning_rate,
232
+ n_epochs=n_epochs,
233
+ weight_decay=weight_decay,
234
+ momentum=0.9,
235
+ SGD=False,
236
+ channels=cp_channels,
237
+ channel_axis=None,
238
+ #rgb=False,
239
+ normalize=False,
240
+ compute_flows=False,
241
+ save_path=model_save_path,
242
+ save_every=save_every,
243
+ nimg_per_epoch=None,
244
+ nimg_test_per_epoch=None,
245
+ rescale=rescale,
246
+ #scale_range=None,
247
+ #bsize=224,
248
+ min_train_masks=1,
249
+ model_name=model_name)
250
+
251
+ return print(f"Model saved at: {model_save_path}/{model_name}")
252
+
253
+ def train_cellpose_v1(settings):
254
+
255
+ from .io import _load_normalized_images_and_labels, _load_images_and_labels
256
+ from .utils import resize_images_and_labels
257
+
258
+ img_src = settings['img_src']
259
+
260
+ mask_src = os.path.join(img_src, 'mask')
261
+
193
262
  model_name = settings['model_name']
194
263
  model_type = settings['model_type']
195
264
  learning_rate = settings['learning_rate']
@@ -197,7 +266,9 @@ def train_cellpose(settings):
197
266
  batch_size = settings['batch_size']
198
267
  n_epochs = settings['n_epochs']
199
268
  verbose = settings['verbose']
200
- signal_thresholds = settings['signal_thresholds']
269
+
270
+ signal_thresholds = 100 #settings['signal_thresholds']
271
+
201
272
  channels = settings['channels']
202
273
  from_scratch = settings['from_scratch']
203
274
  diameter = settings['diameter']
@@ -210,7 +281,17 @@ def train_cellpose(settings):
210
281
  invert = settings['invert']
211
282
  percentiles = settings['percentiles']
212
283
  grayscale = settings['grayscale']
213
-
284
+
285
+ if model_type == 'cyto':
286
+ settings['diameter'] = 30
287
+ diameter = settings['diameter']
288
+ print(f'Cyto model must have diamiter 30. Diameter set the 30')
289
+
290
+ if model_type == 'nuclei':
291
+ settings['diameter'] = 17
292
+ diameter = settings['diameter']
293
+ print(f'Nuclei model must have diamiter 17. Diameter set the 17')
294
+
214
295
  print(settings)
215
296
 
216
297
  if from_scratch:
@@ -219,24 +300,24 @@ def train_cellpose(settings):
219
300
  model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
220
301
 
221
302
  model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
222
- os.makedirs(os.path.dirname(model_save_path), exist_ok=True)
303
+ print(model_save_path)
304
+ os.makedirs(model_save_path, exist_ok=True)
223
305
 
224
306
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
225
307
  settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
226
308
  settings_df.to_csv(settings_csv, index=False)
227
309
 
228
- if model_type =='cyto':
229
- if not from_scratch:
230
- model = cp_models.CellposeModel(gpu=True, model_type=model_type)
231
- else:
232
- model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
233
- if model_type !='cyto':
310
+ if not from_scratch:
234
311
  model = cp_models.CellposeModel(gpu=True, model_type=model_type)
235
-
236
-
237
-
238
- if normalize:
239
- images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_dir=img_src, label_dir=mask_src, secondary_image_dir=secondary_image_dir, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
312
+
313
+ else:
314
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type, pretrained_model=None)
315
+
316
+ if normalize:
317
+ image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
318
+ label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
319
+
320
+ images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
240
321
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
241
322
  else:
242
323
  images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
@@ -259,29 +340,89 @@ def train_cellpose(settings):
259
340
 
260
341
  print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
261
342
  save_every = int(n_epochs/10)
262
- print('cellpose image input dtype', images[0].dtype)
263
- print('cellpose mask input dtype', masks[0].dtype)
343
+ if save_every < 10:
344
+ save_every = n_epochs
345
+
346
+
347
+ #print('cellpose image input dtype', images[0].dtype)
348
+ #print('cellpose mask input dtype', masks[0].dtype)
349
+
264
350
  # Train the model
265
- model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
266
- train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
267
- train_files=image_names, #(list of strings) file names for images in train_data (to save flows for future runs)
268
- channels=cp_channels, #(list of ints (default, None)) – channels to use for training
269
- normalize=False, #(bool (default, True))normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
270
- save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
271
- save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
272
- learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
273
- n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
274
- weight_decay=weight_decay, #(float (default, 0.00001)) –
275
- SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
276
- batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
277
- nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
278
- rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
279
- min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
280
- model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
351
+ #model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
352
+
353
+ #model.train(train_data=images, #(list of arrays (2D or 3D)) images for training
354
+ # train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
355
+ # train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
356
+ # channels=cp_channels, #(list of ints (default, None)) – channels to use for training
357
+ # normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
358
+ # save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
359
+ # save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
360
+ # learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
361
+ # n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
362
+ # weight_decay=weight_decay, #(float (default, 0.00001)) –
363
+ # SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
364
+ # batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
365
+ # nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
366
+ # rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
367
+ # min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
368
+ # model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
369
+
370
+
371
+ train.train_seg(model.net,
372
+ train_data=images,
373
+ train_labels=masks,
374
+ train_files=image_names,
375
+ train_labels_files=None,
376
+ train_probs=None,
377
+ test_data=None,
378
+ test_labels=None,
379
+ test_files=None,
380
+ test_labels_files=None,
381
+ test_probs=None,
382
+ load_files=True,
383
+ batch_size=batch_size,
384
+ learning_rate=learning_rate,
385
+ n_epochs=n_epochs,
386
+ weight_decay=weight_decay,
387
+ momentum=0.9,
388
+ SGD=False,
389
+ channels=cp_channels,
390
+ channel_axis=None,
391
+ #rgb=False,
392
+ normalize=False,
393
+ compute_flows=False,
394
+ save_path=model_save_path,
395
+ save_every=save_every,
396
+ nimg_per_epoch=None,
397
+ nimg_test_per_epoch=None,
398
+ rescale=rescale,
399
+ #scale_range=None,
400
+ #bsize=224,
401
+ min_train_masks=1,
402
+ model_name=model_name)
403
+
404
+ #model_save_path = train.train_seg(model.net,
405
+ # train_data=images,
406
+ # train_files=image_names,
407
+ # train_labels=masks,
408
+ # channels=cp_channels,
409
+ # normalize=False,
410
+ # save_every=save_every,
411
+ # learning_rate=learning_rate,
412
+ # n_epochs=n_epochs,
413
+ # #test_data=test_images,
414
+ # #test_labels=test_labels,
415
+ # weight_decay=weight_decay,
416
+ # SGD=True,
417
+ # batch_size=batch_size,
418
+ # nimg_per_epoch=None,
419
+ # rescale=rescale,
420
+ # min_train_masks=1,
421
+ # model_name=model_name)
422
+
281
423
 
282
424
  return print(f"Model saved at: {model_save_path}/{model_name}")
283
425
 
284
- @log_function_call
285
426
  def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
286
427
 
287
428
  from .plot import _reg_v_plot
@@ -430,7 +571,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', t
430
571
 
431
572
  return result
432
573
 
433
- @log_function_call
434
574
  def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
435
575
 
436
576
  from .plot import _reg_v_plot
@@ -609,7 +749,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=5
609
749
 
610
750
  return max_effects, max_effects_pvalues, model, df
611
751
 
612
- @log_function_call
613
752
  def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
614
753
 
615
754
  from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
@@ -777,7 +916,6 @@ def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wel
777
916
 
778
917
  return
779
918
 
780
- @log_function_call
781
919
  def merge_pred_mes(src,
782
920
  pred_loc,
783
921
  target='protein of interest',
@@ -941,30 +1079,38 @@ def annotate_results(pred_loc):
941
1079
  display(df)
942
1080
  return df
943
1081
 
944
- def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=None):
1082
+ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
945
1083
 
946
- from .utils import init_globals, add_images_to_tar
947
-
948
- db_path = os.path.join(src, 'measurements','measurements.db')
1084
+ from .utils import initiate_counter, add_images_to_tar
1085
+
1086
+ db_path = os.path.join(src, 'measurements', 'measurements.db')
949
1087
  dst = os.path.join(src, 'datasets')
950
-
951
- global total_images
952
1088
  all_paths = []
953
-
1089
+
954
1090
  # Connect to the database and retrieve the image paths
955
1091
  print(f'Reading DataBase: {db_path}')
956
- with sqlite3.connect(db_path) as conn:
957
- cursor = conn.cursor()
958
- if file_type:
959
- cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_type}%",))
960
- else:
961
- cursor.execute("SELECT png_path FROM png_list")
962
- while True:
963
- rows = cursor.fetchmany(1000)
964
- if not rows:
965
- break
966
- all_paths.extend([row[0] for row in rows])
967
-
1092
+ try:
1093
+ with sqlite3.connect(db_path) as conn:
1094
+ cursor = conn.cursor()
1095
+ if file_metadata:
1096
+ if isinstance(file_metadata, str):
1097
+ cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
1098
+ else:
1099
+ cursor.execute("SELECT png_path FROM png_list")
1100
+
1101
+ while True:
1102
+ rows = cursor.fetchmany(1000)
1103
+ if not rows:
1104
+ break
1105
+ all_paths.extend([row[0] for row in rows])
1106
+
1107
+ except sqlite3.Error as e:
1108
+ print(f"Database error: {e}")
1109
+ return
1110
+ except Exception as e:
1111
+ print(f"Error: {e}")
1112
+ return
1113
+
968
1114
  if isinstance(sample, int):
969
1115
  selected_paths = random.sample(all_paths, sample)
970
1116
  print(f'Random selection of {len(selected_paths)} paths')
@@ -972,23 +1118,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
972
1118
  selected_paths = all_paths
973
1119
  random.shuffle(selected_paths)
974
1120
  print(f'All paths: {len(selected_paths)} paths')
975
-
1121
+
976
1122
  total_images = len(selected_paths)
977
- print(f'found {total_images} images')
978
-
1123
+ print(f'Found {total_images} images')
1124
+
979
1125
  # Create a temp folder in dst
980
1126
  temp_dir = os.path.join(dst, "temp_tars")
981
1127
  os.makedirs(temp_dir, exist_ok=True)
982
1128
 
983
1129
  # Chunking the data
984
- if len(selected_paths) > 10000:
985
- num_procs = cpu_count()-2
986
- chunk_size = len(selected_paths) // num_procs
987
- remainder = len(selected_paths) % num_procs
988
- else:
989
- num_procs = 2
990
- chunk_size = len(selected_paths) // 2
991
- remainder = 0
1130
+ num_procs = max(2, cpu_count() - 2)
1131
+ chunk_size = len(selected_paths) // num_procs
1132
+ remainder = len(selected_paths) % num_procs
992
1133
 
993
1134
  paths_chunks = []
994
1135
  start = 0
@@ -998,45 +1139,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
998
1139
  start = end
999
1140
 
1000
1141
  temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
1001
-
1002
- # Initialize the shared objects
1003
- counter_ = Value('i', 0)
1004
- lock_ = Lock()
1005
1142
 
1006
- ctx = multiprocessing.get_context('spawn')
1007
-
1008
1143
  print(f'Generating temporary tar files in {dst}')
1009
-
1144
+
1145
+ # Initialize shared counter and lock
1146
+ counter = Value('i', 0)
1147
+ lock = Lock()
1148
+
1149
+ with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
1150
+ pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
1151
+
1010
1152
  # Combine the temporary tar files into a final tar
1011
1153
  date_name = datetime.date.today().strftime('%y%m%d')
1012
- tar_name = f'{date_name}_{experiment}_{file_type}.tar'
1154
+ if not file_metadata is None:
1155
+ tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
1156
+ else:
1157
+ tar_name = f'{date_name}_{experiment}.tar'
1158
+ tar_name = os.path.join(dst, tar_name)
1013
1159
  if os.path.exists(tar_name):
1014
1160
  number = random.randint(1, 100)
1015
- tar_name_2 = f'{date_name}_{experiment}_{file_type}_{number}.tar'
1016
- print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
1017
- tar_name = tar_name_2
1018
-
1019
- # Add the counter and lock to the arguments for pool.map
1161
+ tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
1162
+ print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
1163
+ tar_name = os.path.join(dst, tar_name_2)
1164
+
1020
1165
  print(f'Merging temporary files')
1021
- #with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1022
- # results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1023
1166
 
1024
- with ctx.Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1025
- results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1026
-
1027
- with tarfile.open(os.path.join(dst, tar_name), 'w') as final_tar:
1028
- for tar_path in results:
1029
- with tarfile.open(tar_path, 'r') as t:
1030
- for member in t.getmembers():
1031
- t.extract(member, path=dst)
1032
- final_tar.add(os.path.join(dst, member.name), arcname=member.name)
1033
- os.remove(os.path.join(dst, member.name))
1034
- os.remove(tar_path)
1167
+ with tarfile.open(tar_name, 'w') as final_tar:
1168
+ for temp_tar_path in temp_tar_files:
1169
+ with tarfile.open(temp_tar_path, 'r') as temp_tar:
1170
+ for member in temp_tar.getmembers():
1171
+ file_obj = temp_tar.extractfile(member)
1172
+ final_tar.addfile(member, file_obj)
1173
+ os.remove(temp_tar_path)
1035
1174
 
1036
1175
  # Delete the temp folder
1037
1176
  shutil.rmtree(temp_dir)
1038
- print(f"\nSaved {total_images} images to {os.path.join(dst, tar_name)}")
1039
-
1177
+ print(f"\nSaved {total_images} images to {tar_name}")
1178
+
1040
1179
  def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
1041
1180
 
1042
1181
  from .io import TarImageDataset, DataLoader
@@ -1272,7 +1411,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1272
1411
 
1273
1412
  db_path = os.path.join(src, 'measurements','measurements.db')
1274
1413
  dst = os.path.join(src, 'datasets', 'training')
1275
-
1414
+
1415
+ if os.path.exists(dst):
1416
+ for i in range(1, 1000):
1417
+ dst = os.path.join(src, 'datasets', f'training_{i}')
1418
+ if not os.path.exists(dst):
1419
+ print(f'Creating new directory for training: {dst}')
1420
+ break
1421
+
1276
1422
  if mode == 'annotation':
1277
1423
  class_paths_ls_2 = []
1278
1424
  class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
@@ -1283,6 +1429,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1283
1429
 
1284
1430
  elif mode == 'metadata':
1285
1431
  class_paths_ls = []
1432
+ class_len_ls = []
1286
1433
  [df] = _read_db(db_loc=db_path, tables=['png_list'])
1287
1434
  df['metadata_based_class'] = pd.NA
1288
1435
  for i, class_ in enumerate(classes):
@@ -1290,7 +1437,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1290
1437
  df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
1291
1438
 
1292
1439
  for class_ in classes:
1440
+ if size == None:
1441
+ c_s = []
1442
+ for c in classes:
1443
+ c_s_t_df = df[df['metadata_based_class'] == c]
1444
+ c_s.append(len(c_s_t_df))
1445
+ print(f'Found {len(c_s_t_df)} images for class {c}')
1446
+ size = min(c_s)
1447
+ print(f'Using the smallest class size: {size}')
1448
+
1293
1449
  class_temp_df = df[df['metadata_based_class'] == class_]
1450
+ class_len_ls.append(len(class_temp_df))
1451
+ print(f'Found {len(class_temp_df)} images for class {class_}')
1294
1452
  class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
1295
1453
  class_paths_ls.append(class_paths_temp)
1296
1454
 
@@ -1347,7 +1505,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1347
1505
 
1348
1506
  return
1349
1507
 
1350
- def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
1508
+ def generate_loaders_v1(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
1351
1509
  """
1352
1510
  Generate data loaders for training and validation/test datasets.
1353
1511
 
@@ -1478,55 +1636,222 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
1478
1636
 
1479
1637
  return train_loaders, val_loaders, plate_names
1480
1638
 
1481
- def analyze_recruitment(src, metadata_settings, advanced_settings):
1639
+ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
1640
+
1482
1641
  """
1483
- Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
1642
+ Generate data loaders for training and validation/test datasets.
1484
1643
 
1485
1644
  Parameters:
1486
- src (str): The source of the recruitment data.
1487
- metadata_settings (dict): The settings for metadata.
1488
- advanced_settings (dict): The advanced settings for recruitment analysis.
1645
+ - src (str): The source directory containing the data.
1646
+ - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
1647
+ - mode (str): The mode of operation. Options are 'train' or 'test'.
1648
+ - image_size (int): The size of the input images.
1649
+ - batch_size (int): The batch size for the data loaders.
1650
+ - classes (list): The list of classes to consider.
1651
+ - num_workers (int): The number of worker threads for data loading.
1652
+ - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
1653
+ - max_show (int): The maximum number of images to show when verbose is True.
1654
+ - pin_memory (bool): Whether to pin memory for faster data transfer.
1655
+ - normalize (bool): Whether to normalize the input images.
1656
+ - verbose (bool): Whether to print additional information and show images.
1657
+ - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
1489
1658
 
1490
1659
  Returns:
1491
- None
1660
+ - train_loaders (list): List of data loaders for training datasets.
1661
+ - val_loaders (list): List of data loaders for validation datasets.
1662
+ - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
1492
1663
  """
1493
-
1494
- from .io import _read_and_merge_data, _results_to_csv
1495
- from .plot import plot_merged, _plot_controls, _plot_recruitment
1496
- from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
1497
-
1498
- settings_dict = {**metadata_settings, **advanced_settings}
1499
- settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
1500
- settings_csv = os.path.join(src,'settings','analyze_settings.csv')
1501
- os.makedirs(os.path.join(src,'settings'), exist_ok=True)
1502
- settings_df.to_csv(settings_csv, index=False)
1503
1664
 
1504
- # metadata settings
1505
- target = metadata_settings['target']
1506
- cell_types = metadata_settings['cell_types']
1507
- cell_plate_metadata = metadata_settings['cell_plate_metadata']
1508
- pathogen_types = metadata_settings['pathogen_types']
1509
- pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
1510
- treatments = metadata_settings['treatments']
1511
- treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
1512
- metadata_types = metadata_settings['metadata_types']
1513
- channel_dims = metadata_settings['channel_dims']
1514
- cell_chann_dim = metadata_settings['cell_chann_dim']
1515
- cell_mask_dim = metadata_settings['cell_mask_dim']
1516
- nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
1517
- nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
1518
- pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
1519
- pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
1520
- channel_of_interest = metadata_settings['channel_of_interest']
1521
-
1522
- # Advanced settings
1523
- plot = advanced_settings['plot']
1524
- plot_nr = advanced_settings['plot_nr']
1525
- plot_control = advanced_settings['plot_control']
1526
- figuresize = advanced_settings['figuresize']
1527
- remove_background = advanced_settings['remove_background']
1528
- backgrounds = advanced_settings['backgrounds']
1529
- include_noninfected = advanced_settings['include_noninfected']
1665
+ from .io import MyDataset
1666
+ from .plot import _imshow
1667
+ from torchvision import transforms
1668
+ from torch.utils.data import DataLoader, random_split
1669
+ from collections import defaultdict
1670
+ import os
1671
+ import random
1672
+ from PIL import Image
1673
+ from torchvision.transforms import ToTensor
1674
+
1675
+ chans = []
1676
+
1677
+ if 'r' in channels:
1678
+ chans.append(1)
1679
+ if 'g' in channels:
1680
+ chans.append(2)
1681
+ if 'b' in channels:
1682
+ chans.append(3)
1683
+
1684
+ channels = chans
1685
+
1686
+ if verbose:
1687
+ print(f'Training a network on channels: {channels}')
1688
+ print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
1689
+
1690
+ class SelectChannels:
1691
+ def __init__(self, channels):
1692
+ self.channels = channels
1693
+
1694
+ def __call__(self, img):
1695
+ img = img.clone()
1696
+ if 1 not in self.channels:
1697
+ img[0, :, :] = 0 # Zero out the red channel
1698
+ if 2 not in self.channels:
1699
+ img[1, :, :] = 0 # Zero out the green channel
1700
+ if 3 not in self.channels:
1701
+ img[2, :, :] = 0 # Zero out the blue channel
1702
+ return img
1703
+
1704
+ plate_to_filenames = defaultdict(list)
1705
+ plate_to_labels = defaultdict(list)
1706
+ train_loaders = []
1707
+ val_loaders = []
1708
+ plate_names = []
1709
+
1710
+ if normalize:
1711
+ transform = transforms.Compose([
1712
+ transforms.ToTensor(),
1713
+ transforms.CenterCrop(size=(image_size, image_size)),
1714
+ SelectChannels(channels),
1715
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1716
+ else:
1717
+ transform = transforms.Compose([
1718
+ transforms.ToTensor(),
1719
+ transforms.CenterCrop(size=(image_size, image_size)),
1720
+ SelectChannels(channels)])
1721
+
1722
+ if mode == 'train':
1723
+ data_dir = os.path.join(src, 'train')
1724
+ shuffle = True
1725
+ print('Generating Train and validation datasets')
1726
+ elif mode == 'test':
1727
+ data_dir = os.path.join(src, 'test')
1728
+ val_loaders = []
1729
+ validation_split = 0.0
1730
+ shuffle = True
1731
+ print('Generating test dataset')
1732
+ else:
1733
+ print(f'mode:{mode} is not valid, use mode = train or test')
1734
+ return
1735
+
1736
+ if train_mode == 'erm':
1737
+ data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1738
+ if validation_split > 0:
1739
+ train_size = int((1 - validation_split) * len(data))
1740
+ val_size = len(data) - train_size
1741
+
1742
+ print(f'Train data:{train_size}, Validation data:{val_size}')
1743
+
1744
+ train_dataset, val_dataset = random_split(data, [train_size, val_size])
1745
+
1746
+ train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1747
+ val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1748
+ else:
1749
+ train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1750
+
1751
+ elif train_mode == 'irm':
1752
+ data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1753
+
1754
+ for filename, label in zip(data.filenames, data.labels):
1755
+ plate = data.get_plate(filename)
1756
+ plate_to_filenames[plate].append(filename)
1757
+ plate_to_labels[plate].append(label)
1758
+
1759
+ for plate, filenames in plate_to_filenames.items():
1760
+ labels = plate_to_labels[plate]
1761
+ plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
1762
+ plate_names.append(plate)
1763
+
1764
+ if validation_split > 0:
1765
+ train_size = int((1 - validation_split) * len(plate_data))
1766
+ val_size = len(plate_data) - train_size
1767
+
1768
+ print(f'Train data:{train_size}, Validation data:{val_size}')
1769
+
1770
+ train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
1771
+
1772
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1773
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1774
+
1775
+ train_loaders.append(train_loader)
1776
+ val_loaders.append(val_loader)
1777
+ else:
1778
+ train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1779
+ train_loaders.append(train_loader)
1780
+ val_loaders.append(None)
1781
+
1782
+ else:
1783
+ print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
1784
+ return
1785
+
1786
+ if verbose:
1787
+ if train_mode == 'erm':
1788
+ for idx, (images, labels, filenames) in enumerate(train_loaders):
1789
+ if idx >= max_show:
1790
+ break
1791
+ images = images.cpu()
1792
+ label_strings = [str(label.item()) for label in labels]
1793
+ _imshow(images, label_strings, nrow=20, fontsize=12)
1794
+ elif train_mode == 'irm':
1795
+ for plate_name, train_loader in zip(plate_names, train_loaders):
1796
+ print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
1797
+ for idx, (images, labels, filenames) in enumerate(train_loader):
1798
+ if idx >= max_show:
1799
+ break
1800
+ images = images.cpu()
1801
+ label_strings = [str(label.item()) for label in labels]
1802
+ _imshow(images, label_strings, nrow=20, fontsize=12)
1803
+
1804
+ return train_loaders, val_loaders, plate_names
1805
+
1806
+ def analyze_recruitment(src, metadata_settings, advanced_settings):
1807
+ """
1808
+ Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
1809
+
1810
+ Parameters:
1811
+ src (str): The source of the recruitment data.
1812
+ metadata_settings (dict): The settings for metadata.
1813
+ advanced_settings (dict): The advanced settings for recruitment analysis.
1814
+
1815
+ Returns:
1816
+ None
1817
+ """
1818
+
1819
+ from .io import _read_and_merge_data, _results_to_csv
1820
+ from .plot import plot_merged, _plot_controls, _plot_recruitment
1821
+ from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
1822
+
1823
+ settings_dict = {**metadata_settings, **advanced_settings}
1824
+ settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
1825
+ settings_csv = os.path.join(src,'settings','analyze_settings.csv')
1826
+ os.makedirs(os.path.join(src,'settings'), exist_ok=True)
1827
+ settings_df.to_csv(settings_csv, index=False)
1828
+
1829
+ # metadata settings
1830
+ target = metadata_settings['target']
1831
+ cell_types = metadata_settings['cell_types']
1832
+ cell_plate_metadata = metadata_settings['cell_plate_metadata']
1833
+ pathogen_types = metadata_settings['pathogen_types']
1834
+ pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
1835
+ treatments = metadata_settings['treatments']
1836
+ treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
1837
+ metadata_types = metadata_settings['metadata_types']
1838
+ channel_dims = metadata_settings['channel_dims']
1839
+ cell_chann_dim = metadata_settings['cell_chann_dim']
1840
+ cell_mask_dim = metadata_settings['cell_mask_dim']
1841
+ nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
1842
+ nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
1843
+ pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
1844
+ pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
1845
+ channel_of_interest = metadata_settings['channel_of_interest']
1846
+
1847
+ # Advanced settings
1848
+ plot = advanced_settings['plot']
1849
+ plot_nr = advanced_settings['plot_nr']
1850
+ plot_control = advanced_settings['plot_control']
1851
+ figuresize = advanced_settings['figuresize']
1852
+ remove_background = advanced_settings['remove_background']
1853
+ backgrounds = advanced_settings['backgrounds']
1854
+ include_noninfected = advanced_settings['include_noninfected']
1530
1855
  include_multiinfected = advanced_settings['include_multiinfected']
1531
1856
  include_multinucleated = advanced_settings['include_multinucleated']
1532
1857
  cells_per_well = advanced_settings['cells_per_well']
@@ -1584,15 +1909,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1584
1909
  df = df.dropna(subset=['condition'])
1585
1910
  print(f'After dropping non-annotated wells: {len(df)} rows')
1586
1911
  files = df['file_name'].tolist()
1912
+ print(f'found: {len(files)} files')
1587
1913
  files = [item + '.npy' for item in files]
1588
1914
  random.shuffle(files)
1589
-
1915
+
1916
+ _max = 10**100
1917
+
1918
+ if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
1919
+ filter_min_max = None
1920
+ else:
1921
+ if cell_size_range is None:
1922
+ cell_size_range = [0,_max]
1923
+ if nucleus_size_range is None:
1924
+ nucleus_size_range = [0,_max]
1925
+ if pathogen_size_range is None:
1926
+ pathogen_size_range = [0,_max]
1927
+
1928
+ filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
1929
+
1590
1930
  if plot:
1591
1931
  plot_settings = {'include_noninfected':include_noninfected,
1592
1932
  'include_multiinfected':include_multiinfected,
1593
1933
  'include_multinucleated':include_multinucleated,
1594
1934
  'remove_background':remove_background,
1595
- 'filter_min_max':[[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]],
1935
+ 'filter_min_max':filter_min_max,
1596
1936
  'channel_dims':channel_dims,
1597
1937
  'backgrounds':backgrounds,
1598
1938
  'cell_mask_dim':mask_dims[0],
@@ -1649,15 +1989,37 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
1649
1989
  cells,wells = _results_to_csv(src, df, df_well)
1650
1990
  return [cells,wells]
1651
1991
 
1652
- @log_function_call
1653
- def preprocess_generate_masks(src, settings={},advanced_settings={}):
1992
+ def preprocess_generate_masks(src, settings={}):
1654
1993
 
1655
1994
  from .io import preprocess_img_data, _load_and_concatenate_arrays
1656
1995
  from .plot import plot_merged, plot_arrays
1657
1996
  from .utils import _pivot_counts_table
1658
-
1659
- settings = {**settings, **advanced_settings}
1997
+
1998
+ settings['plot'] = False
1999
+ settings['fps'] = 2
2000
+ settings['remove_background'] = True
2001
+ settings['lower_quantile'] = 0.02
2002
+ settings['merge'] = False
2003
+ settings['normalize_plots'] = True
2004
+ settings['all_to_mip'] = False
2005
+ settings['pick_slice'] = False
2006
+ settings['skip_mode'] = src
2007
+ settings['workers'] = os.cpu_count()-4
2008
+ settings['verbose'] = True
2009
+ settings['examples_to_plot'] = 1
1660
2010
  settings['src'] = src
2011
+ settings['upscale'] = False
2012
+ settings['upscale_factor'] = 2.0
2013
+
2014
+ settings['randomize'] = True
2015
+ settings['timelapse'] = False
2016
+ settings['timelapse_displacement'] = None
2017
+ settings['timelapse_memory'] = 3
2018
+ settings['timelapse_frame_limits'] = None
2019
+ settings['timelapse_remove_transient'] = False
2020
+ settings['timelapse_mode'] = 'trackpy'
2021
+ settings['timelapse_objects'] = ['cells']
2022
+
1661
2023
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
1662
2024
  settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
1663
2025
  os.makedirs(os.path.join(src,'settings'), exist_ok=True)
@@ -1676,7 +2038,7 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1676
2038
  settings['save'] = [settings['save']]*3
1677
2039
 
1678
2040
  if settings['preprocess']:
1679
- preprocess_img_data(settings)
2041
+ settings, src = preprocess_img_data(settings)
1680
2042
 
1681
2043
  if settings['masks']:
1682
2044
  mask_src = os.path.join(src, 'norm_channel_stack')
@@ -1726,7 +2088,6 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1726
2088
  'cell_mask_dim':cell_mask_dim,
1727
2089
  'nucleus_mask_dim':nucleus_mask_dim,
1728
2090
  'pathogen_mask_dim':pathogen_mask_dim,
1729
- 'overlay_chans':[0,2,3],
1730
2091
  'outline_thickness':3,
1731
2092
  'outline_color':'gbr',
1732
2093
  'overlay_chans':overlay_channels,
@@ -1738,6 +2099,10 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1738
2099
  'figuresize':20,
1739
2100
  'cmap':'inferno',
1740
2101
  'verbose':False}
2102
+
2103
+ if settings['test_mode'] == True:
2104
+ plot_settings['nr'] = len(os.path.join(src,'merged'))
2105
+
1741
2106
  try:
1742
2107
  fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
1743
2108
  except Exception as e:
@@ -1750,26 +2115,61 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
1750
2115
  print("Successfully completed run")
1751
2116
  return
1752
2117
 
1753
- def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', verbose=False, plot=False, save=False, custom_model=None, signal_thresholds=1000, normalize=True, resize=False, target_height=None, target_width=None, rescale=True, resample=True, net_avg=False, invert=False, circular=False, percentiles=None, overlay=True, grayscale=False):
2118
+ def identify_masks_finetune(settings):
1754
2119
 
1755
2120
  from .plot import print_mask_and_flows
1756
2121
  from .utils import get_files_from_dir, resize_images_and_labels
1757
2122
  from .io import _load_normalized_images_and_labels, _load_images_and_labels
1758
2123
 
2124
+ src=settings['src']
2125
+ dst=settings['dst']
2126
+ model_name=settings['model_name']
2127
+ diameter=settings['diameter']
2128
+ batch_size=settings['batch_size']
2129
+ flow_threshold=settings['flow_threshold']
2130
+ cellprob_threshold=settings['cellprob_threshold']
2131
+
2132
+ verbose=settings['verbose']
2133
+ plot=settings['plot']
2134
+ save=settings['save']
2135
+ custom_model=settings['custom_model']
2136
+ overlay=settings['overlay']
2137
+
2138
+ figuresize=25
2139
+ cmap='inferno'
2140
+ channels = [0,0]
2141
+ signal_thresholds = 1000
2142
+ normalize = True
2143
+ percentiles = [2,98]
2144
+ circular = False
2145
+ invert = False
2146
+ resize = False
2147
+ settings['width_height'] = [1000,1000]
2148
+ target_height = settings['width_height'][1]
2149
+ target_width = settings['width_height'][0]
2150
+ rescale = False
2151
+ resample = False
2152
+ grayscale = True
2153
+ test = False
2154
+
2155
+ os.makedirs(dst, exist_ok=True)
2156
+
2157
+ if not custom_model is None:
2158
+ if not os.path.exists(custom_model):
2159
+ print(f'Custom model not found: {custom_model}')
2160
+ return
2161
+
1759
2162
  if not torch.cuda.is_available():
1760
2163
  print(f'Torch CUDA is not available, using CPU')
1761
2164
 
1762
2165
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1763
2166
 
1764
2167
  if custom_model == None:
1765
- if model_name =='cyto':
1766
- model = cp_models.CellposeModel(gpu=True, model_type=model_name, net_avg=False, diam_mean=diameter, pretrained_model=None)
1767
- else:
1768
- model = cp_models.CellposeModel(gpu=True, model_type=model_name)
1769
-
1770
- if custom_model != None:
1771
- model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
1772
- print(f'loaded custom model:{custom_model}')
2168
+ model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
2169
+ print(f'Loaded model: {model_name}')
2170
+ else:
2171
+ model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
2172
+ print("Pretrained Model Loaded:", model.pretrained_model)
1773
2173
 
1774
2174
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
1775
2175
 
@@ -1781,14 +2181,16 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1781
2181
  if verbose == True:
1782
2182
  print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
1783
2183
 
1784
- all_image_files = get_files_from_dir(src, file_extension="*.tif")
2184
+ all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
2185
+
1785
2186
  random.shuffle(all_image_files)
1786
2187
 
1787
2188
  time_ls = []
1788
2189
  for i in range(0, len(all_image_files), batch_size):
1789
2190
  image_files = all_image_files[i:i+batch_size]
2191
+
1790
2192
  if normalize:
1791
- images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
2193
+ images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
1792
2194
  images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
1793
2195
  orig_dims = [(image.shape[0], image.shape[1]) for image in images]
1794
2196
  else:
@@ -1809,8 +2211,7 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1809
2211
  cellprob_threshold=cellprob_threshold,
1810
2212
  rescale=rescale,
1811
2213
  resample=resample,
1812
- net_avg=net_avg,
1813
- progress=False)
2214
+ progress=True)
1814
2215
 
1815
2216
  if len(output) == 4:
1816
2217
  mask, flows, _, _ = output
@@ -1837,8 +2238,7 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
1837
2238
  cv2.imwrite(output_filename, mask)
1838
2239
  return
1839
2240
 
1840
- @log_function_call
1841
- def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
2241
+ def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
1842
2242
  """
1843
2243
  Identify masks from the source images.
1844
2244
 
@@ -1886,13 +2286,13 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1886
2286
 
1887
2287
  #Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
1888
2288
  gc.collect()
1889
- #print('========== generating masks ==========')
1890
2289
 
1891
2290
  if not torch.cuda.is_available():
1892
2291
  print(f'Torch CUDA is not available, using CPU')
1893
2292
 
1894
2293
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1895
- model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
2294
+ model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
2295
+
1896
2296
  if file_type == '.npz':
1897
2297
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
1898
2298
  else:
@@ -1919,9 +2319,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1919
2319
 
1920
2320
  average_sizes = []
1921
2321
  time_ls = []
1922
- moving_avg_q1 = 0
1923
- moving_avg_q3 = 0
1924
- moving_count = 0
1925
2322
  for file_index, path in enumerate(paths):
1926
2323
 
1927
2324
  name = os.path.basename(path)
@@ -1979,7 +2376,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1979
2376
 
1980
2377
  cellpose_batch_size = _get_cellpose_batch_size()
1981
2378
 
1982
- model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
2379
+ #model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
1983
2380
 
1984
2381
  masks, flows, _, _ = model.eval(x=batch,
1985
2382
  batch_size=cellpose_batch_size,
@@ -1991,9 +2388,9 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
1991
2388
  cellprob_threshold=cellprob_threshold,
1992
2389
  rescale=None,
1993
2390
  resample=resample,
1994
- #net_avg=net_avg,
1995
2391
  stitch_threshold=stitch_threshold,
1996
2392
  progress=None)
2393
+
1997
2394
  print('Masks shape',masks.shape)
1998
2395
  if timelapse:
1999
2396
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
@@ -2017,7 +2414,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2017
2414
 
2018
2415
  else:
2019
2416
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2020
- mask_stack = _filter_cp_masks(masks, flows, refine_masks, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize)
2417
+ mask_stack = _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize)
2021
2418
  _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2022
2419
 
2023
2420
  if not np.any(mask_stack):
@@ -2049,10 +2446,13 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
2049
2446
  gc.collect()
2050
2447
  return
2051
2448
 
2052
- @log_function_call
2449
+ def all_elements_match(list1, list2):
2450
+ # Check if all elements in list1 are in list2
2451
+ return all(element in list2 for element in list1)
2452
+
2053
2453
  def generate_cellpose_masks(src, settings, object_type):
2054
2454
 
2055
- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels
2455
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count
2056
2456
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2057
2457
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2058
2458
  from .plot import plot_masks
@@ -2079,16 +2479,15 @@ def generate_cellpose_masks(src, settings, object_type):
2079
2479
  object_settings = _get_object_settings(object_type, settings)
2080
2480
  model_name = object_settings['model_name']
2081
2481
 
2082
- cellpose_channels = _get_cellpose_channels(settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2482
+ cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2483
+ if settings['verbose']:
2484
+ print(cellpose_channels)
2485
+
2083
2486
  channels = cellpose_channels[object_type]
2084
2487
  cellpose_batch_size = _get_cellpose_batch_size()
2085
-
2086
2488
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2087
- model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
2088
- #dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
2089
-
2489
+ model = _choose_model(model_name, device, object_type='cell', restore_type=None)
2090
2490
  chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
2091
-
2092
2491
  paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
2093
2492
 
2094
2493
  count_loc = os.path.dirname(src)+'/measurements/measurements.db'
@@ -2097,10 +2496,6 @@ def generate_cellpose_masks(src, settings, object_type):
2097
2496
 
2098
2497
  average_sizes = []
2099
2498
  time_ls = []
2100
- moving_avg_q1 = 0
2101
- moving_avg_q3 = 0
2102
- moving_count = 0
2103
-
2104
2499
  for file_index, path in enumerate(paths):
2105
2500
  name = os.path.basename(path)
2106
2501
  name, ext = os.path.splitext(name)
@@ -2111,16 +2506,22 @@ def generate_cellpose_masks(src, settings, object_type):
2111
2506
  stack = data['data']
2112
2507
  filenames = data['filenames']
2113
2508
  if settings['timelapse']:
2509
+
2510
+ trackable_objects = ['cell','nucleus','pathogen']
2511
+ if not all_elements_match(settings['timelapse_objects'], trackable_objects):
2512
+ print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
2513
+ return
2514
+
2114
2515
  if len(stack) != batch_size:
2115
2516
  print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
2116
- settings['batch_size'] = len(stack)
2517
+ settings['timelapse_batch_size'] = len(stack)
2117
2518
  batch_size = len(stack)
2118
2519
  if isinstance(timelapse_frame_limits, list):
2119
2520
  if len(timelapse_frame_limits) >= 2:
2120
2521
  stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
2121
2522
  filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
2122
2523
  batch_size = len(stack)
2123
- print(f'Cut batch an indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2524
+ print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
2124
2525
 
2125
2526
  for i in range(0, stack.shape[0], batch_size):
2126
2527
  mask_stack = []
@@ -2136,8 +2537,7 @@ def generate_cellpose_masks(src, settings, object_type):
2136
2537
  if not settings['plot']:
2137
2538
  batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
2138
2539
  if batch.size == 0:
2139
- print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}')
2140
- #print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
2540
+ print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
2141
2541
  continue
2142
2542
  if batch.max() > 1:
2143
2543
  batch = batch / batch.max()
@@ -2150,10 +2550,8 @@ def generate_cellpose_masks(src, settings, object_type):
2150
2550
  _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2151
2551
  else:
2152
2552
  stitch_threshold=0.0
2153
- #print(batch.shape)
2154
- #batch, _, _, _ = dn.eval(x=batch, channels=chans, diameter=object_settings['diameter'])
2155
- #batch = np.stack((batch, batch), axis=-1)
2156
- #print(f'object: {object_type} chans : {chans} channels : {channels} model: {model_name}')
2553
+
2554
+ print('batch.shape',batch.shape)
2157
2555
  masks, flows, _, _ = model.eval(x=batch,
2158
2556
  batch_size=cellpose_batch_size,
2159
2557
  normalize=False,
@@ -2165,9 +2563,15 @@ def generate_cellpose_masks(src, settings, object_type):
2165
2563
  rescale=None,
2166
2564
  resample=object_settings['resample'],
2167
2565
  stitch_threshold=stitch_threshold)
2168
- #progress=None)
2169
-
2566
+
2170
2567
  if timelapse:
2568
+ if settings['plot']:
2569
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2570
+ if idx == 0:
2571
+ num_objects = mask_object_count(mask)
2572
+ print(f'Number of objects: {num_objects}')
2573
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2574
+
2171
2575
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
2172
2576
  if object_type in timelapse_objects:
2173
2577
  if timelapse_mode == 'btrack':
@@ -2196,35 +2600,54 @@ def generate_cellpose_masks(src, settings, object_type):
2196
2600
  name=name,
2197
2601
  batch_filenames=batch_filenames,
2198
2602
  object_type=object_type,
2199
- masks_3D=masks,
2603
+ masks=masks,
2200
2604
  timelapse_displacement=timelapse_displacement,
2201
2605
  timelapse_memory=timelapse_memory,
2202
2606
  timelapse_remove_transient=timelapse_remove_transient,
2203
2607
  plot=settings['plot'],
2204
2608
  save=settings['save'],
2205
- timelapse_mode=timelapse_mode)
2609
+ mode=timelapse_mode)
2206
2610
  else:
2207
2611
  mask_stack = _masks_to_masks_stack(masks)
2208
-
2209
2612
  else:
2210
2613
  _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2211
- mask_stack = _filter_cp_masks(masks=masks,
2212
- flows=flows,
2213
- filter_size=object_settings['filter_size'],
2214
- minimum_size=object_settings['minimum_size'],
2215
- maximum_size=object_settings['maximum_size'],
2216
- remove_border_objects=object_settings['remove_border_objects'],
2217
- merge=False,
2218
- filter_dimm=object_settings['filter_dimm'],
2219
- batch=batch,
2220
- moving_avg_q1=moving_avg_q1,
2221
- moving_avg_q3=moving_avg_q3,
2222
- moving_count=moving_count,
2223
- plot=settings['plot'],
2224
- figuresize=figuresize)
2225
-
2226
- _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2614
+ if object_settings['merge'] and not settings['filter']:
2615
+ mask_stack = _filter_cp_masks(masks=masks,
2616
+ flows=flows,
2617
+ filter_size=False,
2618
+ filter_intensity=False,
2619
+ minimum_size=object_settings['minimum_size'],
2620
+ maximum_size=object_settings['maximum_size'],
2621
+ remove_border_objects=False,
2622
+ merge=object_settings['merge'],
2623
+ batch=batch,
2624
+ plot=settings['plot'],
2625
+ figuresize=figuresize)
2626
+
2627
+ if settings['filter']:
2628
+ mask_stack = _filter_cp_masks(masks=masks,
2629
+ flows=flows,
2630
+ filter_size=object_settings['filter_size'],
2631
+ filter_intensity=object_settings['filter_intensity'],
2632
+ minimum_size=object_settings['minimum_size'],
2633
+ maximum_size=object_settings['maximum_size'],
2634
+ remove_border_objects=object_settings['remove_border_objects'],
2635
+ merge=object_settings['merge'],
2636
+ batch=batch,
2637
+ plot=settings['plot'],
2638
+ figuresize=figuresize)
2639
+
2640
+ _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2641
+ else:
2642
+ mask_stack = _masks_to_masks_stack(masks)
2227
2643
 
2644
+ if settings['plot']:
2645
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2646
+ if idx == 0:
2647
+ num_objects = mask_object_count(mask)
2648
+ print(f'Number of objects, : {num_objects}')
2649
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2650
+
2228
2651
  if not np.any(mask_stack):
2229
2652
  average_obj_size = 0
2230
2653
  else:
@@ -2240,7 +2663,6 @@ def generate_cellpose_masks(src, settings, object_type):
2240
2663
  time_in_min = average_time/60
2241
2664
  time_per_mask = average_time/batch_size
2242
2665
  print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
2243
- #print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2244
2666
  if not timelapse:
2245
2667
  if settings['plot']:
2246
2668
  plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
@@ -2252,4 +2674,663 @@ def generate_cellpose_masks(src, settings, object_type):
2252
2674
  batch_filenames = []
2253
2675
  gc.collect()
2254
2676
  torch.cuda.empty_cache()
2255
- return
2677
+ return
2678
+
2679
def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, verbose):
    """
    Segment every ``.tif`` image in *src* with a Cellpose model.

    Images are shuffled and processed in batches of ``batch_size``. Each image is
    passed individually to ``model.eval``. Masks are written to
    ``<src>/<model_name>/<image name>`` when *save* is true.

    Args:
        src (str): Folder containing the ``.tif`` images.
        model: Loaded Cellpose model exposing ``eval``.
        model_name (str): Model name. It selects the Cellpose channel pair and
            names the output sub-folder.
        batch_size (int): Number of images loaded per batch.
        diameter: Object diameter forwarded to ``model.eval``.
        cellprob_threshold (float): Forwarded to ``model.eval``.
        grayscale (bool): Force the Cellpose channel pair ``[0, 0]``.
        save (bool): Write each mask with ``cv2.imwrite``.
        normalize (bool): Load images through ``_load_normalized_images_and_labels``
            (uses *channels* and *percentiles*) instead of ``_load_images_and_labels``.
        channels, percentiles, circular, invert: Forwarded to the image loaders.
        plot (bool): Display each image with its mask and flows.
        resize (bool): Resize images to ``target_height`` x ``target_width`` before
            segmentation. Masks (and plotted images) are resized back to the
            original size afterwards.
        target_height, target_width (int): Resize target.
        verbose (bool): Print the Cellpose settings before processing.

    Raises:
        ValueError: If ``model.eval`` returns neither 3 nor 4 values.
    """
    from .io import _load_images_and_labels, _load_normalized_images_and_labels
    from .utils import resize_images_and_labels, resizescikit
    from .plot import print_mask_and_flows

    dst = os.path.join(src, model_name)
    os.makedirs(dst, exist_ok=True)

    flow_threshold = 30

    # Cellpose channel pair passed to model.eval. The built-in nuclear model is
    # named 'nuclei' (the name check_cellpose_models uses), so accept both
    # spellings; previously 'nuclei' fell through to [2, 0].
    if grayscale or model_name in ('nucleus', 'nuclei'):
        chans = [0, 0]
    elif model_name == 'cyto2':
        chans = [2, 1]
    elif model_name == 'cyto':
        chans = [1, 0]
    else:
        chans = [2, 0]

    all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
    random.shuffle(all_image_files)

    if verbose:
        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')

    time_ls = []
    for i in range(0, len(all_image_files), batch_size):
        image_files = all_image_files[i:i + batch_size]

        if normalize:
            images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=100, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=plot)
        else:
            images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)

        # Drop a trailing singleton axis, and remember each image's original
        # (height, width) so masks can be resized back after segmentation.
        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
        orig_dims = [(image.shape[0], image.shape[1]) for image in images]
        if resize:
            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)

        for file_index, stack in enumerate(images):
            start = time.time()
            output = model.eval(x=stack,
                                normalize=False,
                                channels=chans,
                                channel_axis=3,
                                diameter=diameter,
                                flow_threshold=flow_threshold,
                                cellprob_threshold=cellprob_threshold,
                                rescale=False,
                                resample=False,
                                progress=True)

            # Accept both the 4-tuple and 3-tuple forms of model.eval's result.
            if len(output) == 4:
                mask, flows, _, _ = output
            elif len(output) == 3:
                mask, flows, _ = output
            else:
                raise ValueError("Unexpected number of return values from model.eval()")

            if resize:
                dims = orig_dims[file_index]
                # order=0 should be nearest-neighbour (assuming resizescikit wraps
                # skimage's resize), so label values are not interpolated.
                mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)

            time_ls.append(time.time() - start)
            average_time = np.mean(time_ls)
            print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)

            if plot:
                if resize:
                    stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
                print_mask_and_flows(stack, mask, flows, overlay=True)

            if save:
                output_filename = os.path.join(dst, image_names[file_index])
                cv2.imwrite(output_filename, mask)
2751
+
2752
+
2753
def check_cellpose_models(settings):
    """
    Segment the images in ``settings['src']`` with each built-in Cellpose model.

    A ``CellposeModel`` is loaded for ``cyto``, ``nuclei``, ``cyto2`` and
    ``cyto3`` in turn and handed to ``generate_masks_from_imgs``. Each model's
    output therefore lands in its own sub-folder of ``src``.

    Args:
        settings (dict): Must contain ``src``, ``batch_size``,
            ``cellprob_threshold``, ``save``, ``normalize``, ``channels``,
            ``percentiles``, ``circular``, ``invert``, ``plot``, ``diameter``,
            ``resize``, ``grayscale``, ``verbose`` and ``width_height``
            (indexed as ``[height, width]``).
    """
    # These keys match generate_masks_from_imgs' parameter names one-to-one,
    # so they can be forwarded as keyword arguments.
    forwarded_keys = ('src', 'batch_size', 'cellprob_threshold', 'save', 'normalize',
                      'channels', 'percentiles', 'circular', 'invert', 'plot',
                      'diameter', 'resize', 'grayscale', 'verbose')
    options = {key: settings[key] for key in forwarded_keys}
    height_width = settings['width_height']
    options['target_height'] = height_width[0]
    options['target_width'] = height_width[1]

    run_device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    for name in ('cyto', 'nuclei', 'cyto2', 'cyto3'):
        cellpose_model = cp_models.CellposeModel(gpu=True, model_type=name, device=run_device)
        print(f'Using {name}')
        generate_masks_from_imgs(model=cellpose_model, model_name=name, **options)
2782
+
2783
def compare_masks_v1(dir1, dir2, dir3, verbose=False):
    """
    Compare masks that share a filename across three folders.

    Each file in *dir1* is looked up in *dir2* and *dir3*. Files missing from
    either folder are reported and skipped. So are files whose three masks are
    all background (all zeros).

    For the remaining files, Jaccard index, Dice coefficient and boundary F1 are
    computed for every folder pair. Segmentation average precision is computed
    with *dir1* as the reference for *dir2* and *dir3*. Result keys embed the
    folder basenames.

    Args:
        dir1 (str): Reference mask folder.
        dir2, dir3 (str): Mask folders compared against *dir1* and each other.
        verbose (bool): Show boundary and mask visualizations for each compared file.

    Returns:
        tuple: ``(results, fig)``, where *results* is a list of per-file metric
        dicts and *fig* is what ``plot_comparison_results(results)`` returns.
    """
    from .io import _read_mask
    from .plot import visualize_masks, plot_comparison_results
    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient

    cond_1, cond_2, cond_3 = (os.path.basename(d) for d in (dir1, dir2, dir3))
    results = []

    for index, filename in enumerate(os.listdir(dir1)):
        print(f'Processing image:{index+1}', end='\r', flush=True)
        path1, path2, path3 = (os.path.join(d, filename) for d in (dir1, dir2, dir3))

        if not (os.path.exists(path2) and os.path.exists(path3)):
            print(f'Cannot find {path1} or {path2} or {path3}')
            continue

        mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)

        # Nothing to compare when all three masks are pure background; skip
        # before spending time on any metric.
        if not (np.any(mask1) or np.any(mask2) or np.any(mask3)):
            continue

        if verbose:
            visualize_masks(extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3), title=f"Boundaries - {filename}")
            visualize_masks(mask1, mask2, mask3, title=filename)

        results.append({
            'filename': filename,
            f'jaccard_{cond_1}_{cond_2}': jaccard_index(mask1, mask2),
            f'dice_{cond_1}_{cond_2}': dice_coefficient(mask1, mask2),
            f'jaccard_{cond_1}_{cond_3}': jaccard_index(mask1, mask3),
            f'dice_{cond_1}_{cond_3}': dice_coefficient(mask1, mask3),
            f'jaccard_{cond_2}_{cond_3}': jaccard_index(mask2, mask3),
            f'dice_{cond_2}_{cond_3}': dice_coefficient(mask2, mask3),
            f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_score(mask1, mask2),
            f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_score(mask1, mask3),
            f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_score(mask2, mask3),
            f'average_precision_{cond_1}_{cond_2}': compute_segmentation_ap(mask1, mask2),
            f'average_precision_{cond_1}_{cond_3}': compute_segmentation_ap(mask1, mask3),
        })

    fig = plot_comparison_results(results)
    return results, fig
2858
+
2859
def compare_cellpose_masks_v1(src, verbose=False):
    """
    Pairwise-compare masks stored in the sub-folders of *src*.

    Each sub-folder of *src* is one condition. Only filenames present in every
    sub-folder are compared, in sorted order. For every pair of conditions the
    Jaccard index, boundary F1 score and segmentation average precision are
    recorded.

    Args:
        src (str): Folder whose sub-folders hold masks with matching filenames.
        verbose (bool): Visualize the masks of each compared file.

    Returns:
        tuple: ``(results, fig)``, where *results* is a list of per-file dicts
        (``filename`` plus ``jaccard_<a>_<b>``, ``boundary_f1_<a>_<b>`` and
        ``average_precision_<a>_<b>``) and *fig* is the output of
        ``plot_comparison_results``.

    Raises:
        ValueError: If *src* contains no sub-folders.
    """
    dirs = sorted(os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)))
    if not dirs:
        raise ValueError(f'No mask sub-directories found in {src}')

    from .io import _read_mask
    from .plot import plot_comparison_results, visualize_cellpose_masks
    from .utils import boundary_f1_score, compute_segmentation_ap, jaccard_index

    # Filenames present in every condition folder; sorted for reproducible order.
    common_files = set(os.listdir(dirs[0]))
    for d in dirs[1:]:
        common_files.intersection_update(os.listdir(d))
    common_files = sorted(common_files)

    conditions = [os.path.basename(d) for d in dirs]
    results = []

    for index, filename in enumerate(common_files):
        print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
        paths = [os.path.join(d, filename) for d in dirs]

        if not all(os.path.exists(path) for path in paths):
            print(f'Skipping {filename} as it is not present in all directories.')
            continue

        masks = [_read_mask(path) for path in paths]

        if verbose:
            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")

        file_results = {'filename': filename}
        for i in range(len(masks)):
            for j in range(i + 1, len(masks)):
                mask_i, mask_j = masks[i], masks[j]
                boundary_f1 = boundary_f1_score(mask_i, mask_j)
                jaccard = jaccard_index(mask_i, mask_j)
                average_precision = compute_segmentation_ap(mask_i, mask_j)

                pair = f'{conditions[i]}_{conditions[j]}'
                file_results[f'jaccard_{pair}'] = jaccard
                file_results[f'boundary_f1_{pair}'] = boundary_f1
                file_results[f'average_precision_{pair}'] = average_precision

        results.append(file_results)

    fig = plot_comparison_results(results)
    return results, fig
2922
+
2923
def compare_mask(args):
    """
    Compute pairwise mask-agreement metrics for one filename.

    Takes a single tuple so it can be used directly with ``Pool.map``.

    Args:
        args (tuple): ``(src, filename, dirs, conditions)``. ``src`` is not used.
            ``dirs`` lists one folder per condition, and ``conditions`` holds
            the matching labels used in the result keys.

    Returns:
        dict or None: ``{'filename': filename}`` plus ``jaccard_<a>_<b>``,
        ``boundary_f1_<a>_<b>`` and ``ap_<a>_<b>`` for every condition pair.
        Returns ``None`` if the file is missing from any folder.
    """
    _src, filename, dirs, conditions = args
    paths = [os.path.join(d, filename) for d in dirs]

    if not all(os.path.exists(path) for path in paths):
        return None

    # Imported here (after the cheap existence check) to avoid issues in
    # multiprocessing; only the helpers actually used are imported.
    from .io import _read_mask
    from .utils import boundary_f1_score, compute_segmentation_ap, jaccard_index

    masks = [_read_mask(path) for path in paths]
    file_results = {'filename': filename}

    for i in range(len(masks)):
        for j in range(i + 1, len(masks)):
            mask_i, mask_j = masks[i], masks[j]
            f1_score = boundary_f1_score(mask_i, mask_j)
            jac_index = jaccard_index(mask_i, mask_j)
            ap_score = compute_segmentation_ap(mask_i, mask_j)

            file_results.update({
                f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
                f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
                f'ap_{conditions[i]}_{conditions[j]}': ap_score,
            })

    return file_results
2951
+
2952
def compare_cellpose_masks(src, verbose=False, processes=None):
    """
    Pairwise-compare masks in the sub-folders of *src* using a process pool.

    Each sub-folder is one condition. Filenames present in every sub-folder
    are scored, in sorted order, by ``compare_mask`` in parallel worker
    processes.

    Args:
        src (str): Folder whose sub-folders hold masks with matching filenames.
        verbose (bool): After scoring, visualize the masks of every compared file.
        processes (int or None): Number of worker processes (``None`` lets
            ``multiprocessing.Pool`` decide).

    Returns:
        tuple: ``(results, fig)``, where *results* is the list of non-``None``
        dicts returned by ``compare_mask`` and *fig* is the output of
        ``plot_comparison_results``.

    Raises:
        ValueError: If *src* contains no sub-folders.
    """
    from multiprocessing import Pool

    dirs = sorted(os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)))
    if not dirs:
        raise ValueError(f'No mask sub-directories found in {src}')

    from .plot import visualize_cellpose_masks, plot_comparison_results
    from .io import _read_mask

    conditions = [os.path.basename(d) for d in dirs]

    # Filenames present in every condition folder; sorted for reproducible order.
    common_files = set(os.listdir(dirs[0]))
    for d in dirs[1:]:
        common_files.intersection_update(os.listdir(d))
    common_files = sorted(common_files)

    args = [(src, filename, dirs, conditions) for filename in common_files]
    with Pool(processes=processes) as pool:
        results = pool.map(compare_mask, args)

    # compare_mask returns None for files that could not be compared.
    results = [res for res in results if res is not None]

    if verbose:
        for result in results:
            filename = result['filename']
            masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}")

    fig = plot_comparison_results(results)
    return results, fig
2981
+
2982
+
2983
def _calculate_similarity(df, features, col_to_compare, val1, val2):
    """
    Add distances from every row to the positive and negative control centroids.

    The centroids are the column means of *features* over the rows where
    ``df[col_to_compare] == val1`` (positive) and ``== val2`` (negative).
    Despite the ``similarity_to_*`` column names, every value is a
    scipy.spatial.distance *distance*: larger means less similar.

    The Mahalanobis distance is computed entirely in standardized feature space.
    Rows and centroids are both standardized with the same StandardScaler, and
    the inverse covariance is estimated from those standardized rows. (The
    covariance is ridge-regularized if it is singular.)

    Args:
        df (pandas.DataFrame): DataFrame containing the data; modified in place.
        features (list): Feature columns used for the distance calculation.
        col_to_compare (str): Column used to select the control groups.
        val1, val2: Values of *col_to_compare* for the positive and negative controls.

    Returns:
        pandas.DataFrame: *df* with ``similarity_to_{pos,neg}_{euclidean,cosine,
        mahalanobis,manhattan,minkowski,chebyshev,hamming,jaccard,braycurtis}``
        columns added.
    """
    pos_rows = (df[col_to_compare] == val1).to_numpy()
    neg_rows = (df[col_to_compare] == val2).to_numpy()

    raw = df[features]
    pos_control = raw[pos_rows].mean()
    neg_control = raw[neg_rows].mean()

    # Standardize once; centroids in standardized space are the means of the
    # standardized control rows (the scaler is affine, so this equals scaling
    # the raw centroids).
    scaler = StandardScaler()
    scaled = pd.DataFrame(scaler.fit_transform(raw), index=df.index, columns=features)
    pos_control_scaled = scaled[pos_rows].mean()
    neg_control_scaled = scaled[neg_rows].mean()

    # np.cov returns a 0-d array for a single feature; keep it 2-D for inv/eye.
    cov_matrix = np.atleast_2d(np.cov(scaled.to_numpy(), rowvar=False))
    try:
        inv_cov_matrix = np.linalg.inv(cov_matrix)
    except np.linalg.LinAlgError:
        # Add a small value to the diagonal elements for regularization
        epsilon = 1e-5
        inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)

    def _distances(values, centroid, metric):
        # Row-wise distance from each row of *values* to *centroid*.
        return values.apply(lambda row: metric(row, centroid), axis=1)

    def _mahalanobis(u, v):
        return mahalanobis(u, v, inv_cov_matrix)

    def _minkowski3(u, v):
        return minkowski(u, v, p=3)

    df['similarity_to_pos_euclidean'] = _distances(raw, pos_control, euclidean)
    df['similarity_to_neg_euclidean'] = _distances(raw, neg_control, euclidean)
    df['similarity_to_pos_cosine'] = _distances(raw, pos_control, cosine)
    df['similarity_to_neg_cosine'] = _distances(raw, neg_control, cosine)
    # Mahalanobis: rows, centroid and inverse covariance all in standardized space.
    df['similarity_to_pos_mahalanobis'] = _distances(scaled, pos_control_scaled, _mahalanobis)
    df['similarity_to_neg_mahalanobis'] = _distances(scaled, neg_control_scaled, _mahalanobis)
    df['similarity_to_pos_manhattan'] = _distances(raw, pos_control, cityblock)
    df['similarity_to_neg_manhattan'] = _distances(raw, neg_control, cityblock)
    df['similarity_to_pos_minkowski'] = _distances(raw, pos_control, _minkowski3)
    df['similarity_to_neg_minkowski'] = _distances(raw, neg_control, _minkowski3)
    df['similarity_to_pos_chebyshev'] = _distances(raw, pos_control, chebyshev)
    df['similarity_to_neg_chebyshev'] = _distances(raw, neg_control, chebyshev)
    df['similarity_to_pos_hamming'] = _distances(raw, pos_control, hamming)
    df['similarity_to_neg_hamming'] = _distances(raw, neg_control, hamming)
    df['similarity_to_pos_jaccard'] = _distances(raw, pos_control, jaccard)
    df['similarity_to_neg_jaccard'] = _distances(raw, neg_control, jaccard)
    df['similarity_to_pos_braycurtis'] = _distances(raw, pos_control, braycurtis)
    df['similarity_to_neg_braycurtis'] = _distances(raw, neg_control, braycurtis)

    return df
3035
+
3036
def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
    """
    Train a classifier separating two groups and report feature importances.

    Rows with ``df[col_to_compare] == pos`` get target 0 and rows with
    ``== neg`` get target 1. A model is trained on a train/test split of these
    rows. Permutation importance (computed on the training split) is plotted,
    and for tree models so is the built-in feature importance. The model then
    predicts a class for every row of *df*.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data. It is not
            modified; a copy is returned.
        feature_string (str or None): Keep only numeric features containing this
            substring, and drop those that also mention another
            ``channel_0``..``channel_3``.
        col_to_compare (str): Column name to use for comparing groups.
        pos, neg: Values in col_to_compare that define the two classes.
        exclude (list or str, optional): Column(s) to exclude from the features;
            names that are not features are ignored.
        n_repeats (int): Number of repeats for permutation importance.
        clean (bool): Whether to drop columns with a single value among the pos/neg rows.
        nr_to_plot (int): Number of top features to plot/return.
        n_estimators (int): Number of trees/iterations for the random forest,
            gradient boosting or XGBoost model.
        test_size (float): Proportion of the pos/neg rows held out for testing.
        random_state (int): Random seed for reproducibility.
        model_type (str): 'random_forest', 'logistic_regression', 'gradient_boosting' or 'xgboost'.
        n_jobs (int): Number of parallel jobs for models/steps that accept it.

    Returns:
        list: ``[df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]``.
        *df* is a copy of the input (without ``cells_per_well``) with added
        ``predictions`` and ``data_usage`` columns. ``data_usage`` is 'train'
        or 'test' for the rows used in training/evaluation and NaN elsewhere.
        The ``similarity_to_*`` columns from ``_calculate_similarity`` are also
        added. *feature_importance_df* is empty for logistic regression.

    Raises:
        ValueError: If *model_type* is not supported.
    """
    # drop() always returns a new frame, so the caller's DataFrame is never
    # mutated by the column assignments below.
    df = df.drop(columns=['cells_per_well'], errors='ignore')

    # Subset the dataframe based on specified column values
    df1 = df[df[col_to_compare] == pos].copy()
    df2 = df[df[col_to_compare] == neg].copy()

    # Create target variable
    df1['target'] = 0
    df2['target'] = 1
    combined_df = pd.concat([df1, df2])

    # Automatically select numerical features
    features = combined_df.select_dtypes(include=[np.number]).columns.tolist()
    features.remove('target')

    if clean:
        combined_df = combined_df.loc[:, combined_df.nunique() > 1]
        features = [feature for feature in features if feature in combined_df.columns]

    if feature_string is not None:
        other_channels = [c for c in ['channel_0', 'channel_1', 'channel_2', 'channel_3'] if c != feature_string]
        features = [feature for feature in features if feature_string in feature]
        # Drop features that also involve another channel.
        for other in other_channels:
            features = [feature for feature in features if other not in feature]
            print(f'After removing {other} features: {len(features)}')

    if exclude:
        # Filtering (instead of list.remove) tolerates names already dropped above.
        excluded = set(exclude) if isinstance(exclude, list) else {exclude}
        features = [feature for feature in features if feature not in excluded]

    X = combined_df[features]
    y = combined_df['target']

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Label the data used for training/testing
    combined_df['data_usage'] = 'train'
    combined_df.loc[X_test.index, 'data_usage'] = 'test'

    # Initialize the model based on model_type
    if model_type == 'random_forest':
        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'logistic_regression':
        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'gradient_boosting':
        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state)  # Supports n_jobs internally
    elif model_type == 'xgboost':
        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    model.fit(X_train, y_train)

    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)

    # Permutation importances in ascending order; keep the top nr_to_plot.
    order = perm_importance.importances_mean.argsort()
    permutation_df = pd.DataFrame({
        'feature': [features[i] for i in order],
        'importance_mean': perm_importance.importances_mean[order],
        'importance_std': perm_importance.importances_std[order],
    }).tail(nr_to_plot)

    fig, ax = plt.subplots()
    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()
    plt.show()

    # Feature importance for models that support it
    if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
        feature_importance_df = pd.DataFrame({
            'feature': features,
            'importance': model.feature_importances_,
        }).sort_values(by='importance', ascending=False).head(nr_to_plot)

        fig, ax = plt.subplots()
        ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
        ax.set_xlabel('Feature Importance')
        plt.tight_layout()
        plt.show()
    else:
        feature_importance_df = pd.DataFrame()

    predictions_test = model.predict(X_test)

    # Predict every row, then attach train/test labels by index. Direct
    # index-aligned assignment avoids the row duplication that a concat+join
    # would cause; rows outside the pos/neg groups get NaN.
    df['predictions'] = model.predict(df[features])
    df['data_usage'] = combined_df['data_usage']

    # Calculating and printing the accuracy metrics
    accuracy = accuracy_score(y_test, predictions_test)
    precision = precision_score(y_test, predictions_test)
    recall = recall_score(y_test, predictions_test)
    f1 = f1_score(y_test, predictions_test)
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1}")

    # Printing class-specific accuracy metrics
    print("\nClassification Report:")
    print(classification_report(y_test, predictions_test))

    df = _calculate_similarity(df, features, col_to_compare, pos, neg)

    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
3196
+
3197
def _shap_analysis(model, X_train, X_test):
    """
    Explain ``model`` with SHAP and draw a summary plot.

    A ``shap.Explainer`` is constructed from the model using ``X_train`` as
    background data. It is then evaluated on ``X_test``, and the resulting
    explanation is handed to ``shap.summary_plot`` together with ``X_test``.

    Args:
        model: The trained model to explain.
        X_train (pandas.DataFrame): Training feature set (background data).
        X_test (pandas.DataFrame): Testing feature set to explain.

    Returns:
        None. The only effect is the summary plot.
    """
    background = X_train
    explanation = shap.Explainer(model, background)(X_test)

    # Summary plot of the per-feature contributions on the test set
    shap.summary_plot(explanation, X_test)
+
3214
def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
    """
    Fit a model on the measurement database and plot ``variable`` as a plate heatmap.

    Steps:
        1. Merge the 'cell', 'nucleus', 'pathogen' and 'cytoplasm' tables of
           ``<src>/measurements/measurements.db``.
        2. Fit a model with ``_permutation_importance``. This presumably
           classifies ``pos`` vs ``neg`` values of ``col_to_compare``; confirm
           against that helper.
        3. Check that ``variable`` is a numeric column of the model output.
        4. Run a SHAP summary for the fitted model.
        5. Render ``variable`` with ``_plot_plates``.

    Args:
        src (str): Folder containing ``measurements/measurements.db``.
        model_type (str): Forwarded to ``_permutation_importance``.
        variable (str): Numeric column to plot. The default 'predictions' is
            produced by the model step.
        grouping (str): Forwarded to ``_plot_plates``.
        min_max (str): Forwarded to ``_plot_plates``.
        cmap (str): Colormap name, forwarded to ``_plot_plates``.
        channel_of_interest (int or None): If not None, a 'recruitment' column
            is added. It is pathogen mean intensity divided by cytoplasm mean
            intensity in this channel. The string 'channel_<n>' is passed on as
            the feature string. If None, the feature string is None.
        min_count (int): Forwarded to ``_plot_plates``.
        n_estimators, col_to_compare, pos, neg, exclude, n_repeats, clean,
        nr_to_plot, n_jobs: Forwarded to ``_permutation_importance``.
        verbose (bool): Forwarded to ``_read_and_merge_data``.

    Returns:
        list: ``[output, heatmap]``. ``output`` is the list returned by
        ``_permutation_importance``. ``heatmap`` is the return value of
        ``_plot_plates``.

    Raises:
        ValueError: If ``variable`` is not a numeric column of the model output.
    """
    from .io import _read_and_merge_data
    from .plot import _plot_plates

    db_loc = [os.path.join(src, 'measurements', 'measurements.db')]
    tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
    # NOTE(review): include_multiinfected is a float (2.0) while the other two
    # flags are booleans; its meaning is defined by _read_and_merge_data — confirm.
    include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True

    df, _ = _read_and_merge_data(db_loc,
                                 tables,
                                 verbose=verbose,
                                 include_multinucleated=include_multinucleated,
                                 include_multiinfected=include_multiinfected,
                                 include_noninfected=include_noninfected)

    if channel_of_interest is not None:
        pathogen_col = f'pathogen_channel_{channel_of_interest}_mean_intensity'
        cytoplasm_col = f'cytoplasm_channel_{channel_of_interest}_mean_intensity'
        df['recruitment'] = df[pathogen_col] / df[cytoplasm_col]
        feature_string = f'channel_{channel_of_interest}'
    else:
        feature_string = None

    output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)

    # Validate the requested column before the (potentially slow) SHAP step so
    # a bad `variable` fails fast.
    result_df = output[0]
    features = result_df.select_dtypes(include=[np.number]).columns.tolist()
    if variable not in features:
        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")

    # Elements 3, 4, 5 of the output are the model, X_train and X_test
    # (the argument order of _shap_analysis).
    model, X_train, X_test = output[3], output[4], output[5]
    _shap_analysis(model, X_train, X_test)

    # Named `heatmap` so it does not shadow this function's own name
    heatmap = _plot_plates(result_df, variable, grouping, min_max, cmap, min_count)
    return [output, heatmap]
+
3247
def join_measurments_and_annotation(src, tables=None):
    """
    Merge per-object measurements with the ``png_list`` table of the same database.

    Steps:
        1. Read ``<src>/measurements/measurements.db``.
        2. Merge the requested measurement tables with ``_read_and_merge_data``
           (``verbose=True``, with multinucleated, multi-infected and
           non-infected objects all included).
        3. Left-join the ``png_list`` table onto the result on the ``prcfo``
           column.

    Args:
        src (str): Folder containing ``measurements/measurements.db``.
        tables (list of str, optional): Measurement tables to merge. Defaults
            to ``['cell', 'nucleus', 'pathogen', 'cytoplasm']``.

    Returns:
        pandas.DataFrame: The merged measurements, with the ``png_list``
        columns left-joined on ``prcfo``.
    """
    from .io import _read_and_merge_data, _read_db

    if tables is None:
        # Build a fresh list per call; a list literal as the default would be
        # one shared object across calls (and is passed on to another function).
        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']

    loc = os.path.join(src, 'measurements', 'measurements.db')
    df, _ = _read_and_merge_data([loc],
                                 tables,
                                 verbose=True,
                                 include_multinucleated=True,
                                 include_multiinfected=True,
                                 include_noninfected=True)

    # _read_db returns an indexable collection; element 0 holds the png_list data
    paths_df = _read_db(loc, tables=['png_list'])[0]

    merged_df = pd.merge(df, paths_df, on='prcfo', how='left')
    return merged_df
+
3266
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Draw a class-balanced jitter (strip) plot of ``y_column`` grouped by ``x_column``.

    The data come from ``join_measurments_and_annotation(src)``: the
    measurement tables of ``<src>/measurements/measurements.db`` joined with
    the ``png_list`` table.

    Processing steps:
        1. Missing values in ``x_column`` are replaced by the string ``'NaN'``.
        2. Rows are optionally filtered with ``filter_column`` / ``filter_values``.
        3. Only rows are kept whose (``plate_x``, ``row_x``, ``col_x``)
           combination also occurs in a row with a non-``'NaN'`` ``x_column``.
        4. Every ``x_column`` group is down-sampled to the size of the smallest
           group. Each group is sampled with ``random_state=42``.

    Args:
        src (str): Folder containing ``measurements/measurements.db``.
        x_column (str): Column for the x-axis and for grouping/balancing.
        y_column (str): Column for the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot is
            shown instead. Default is None.
        filter_column (str or list of str, optional): Column(s) to filter on.
        filter_values (list, optional): Allowed values. With a single column
            name this is one list of values. With a list of columns it is a
            list holding one list of allowed values per column, matched by
            position.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        KeyError: If ``plate_x``, ``row_x`` or ``col_x`` is missing.
        ValueError: If no rows remain after filtering.
    """
    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Missing labels become the literal string 'NaN' so they form their own group
    df[x_column] = df[x_column].fillna('NaN')

    if filter_column is not None:
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            for i, column in enumerate(filter_column):
                df = df[df[column].isin(filter_values[i])]

    required_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in required_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")

    # Keep every row whose (plate, row, col) key also appears in an annotated row
    non_nan_df = df[df[x_column] != 'NaN']
    annotated_keys = non_nan_df[required_columns].apply(tuple, axis=1)
    retained_rows = df[df[required_columns].apply(tuple, axis=1).isin(annotated_keys)]
    if retained_rows.empty:
        raise ValueError(f"No rows with a non-missing '{x_column}' value remain after filtering")

    # Size of the smallest x_column group; every group is down-sampled to it
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')

    # Sample each group independently with the same seed. Concatenating the
    # samples (instead of groupby.apply) keeps x_column in the result; pandas
    # >= 2.2 deprecates groupby.apply operating on the grouping columns.
    balanced_df = pd.concat(
        [group.sample(min_count, random_state=42) for _, group in retained_rows.groupby(x_column)],
        ignore_index=True,
    )

    fig = plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotate the x tick labels and center them below their categories
    plt.xticks(rotation=45, ha='center')

    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
        # Release the figure so repeated calls do not accumulate open figures
        plt.close(fig)
    else:
        plt.show()

    return balanced_df