spacr 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/core.py ADDED
@@ -0,0 +1,2250 @@
1
+ import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
2
+
3
+ # image and array processing
4
+ import numpy as np
5
+ import pandas as pd
6
+
7
+ import cellpose
8
+ from cellpose import models as cp_models
9
+ from cellpose import denoise
10
+
11
+ import statsmodels.formula.api as smf
12
+ import statsmodels.api as sm
13
+ from functools import reduce
14
+ from IPython.display import display
15
+ from multiprocessing import Pool, cpu_count, Value, Lock
16
+
17
+ import seaborn as sns
18
+ import matplotlib.pyplot as plt
19
+ from skimage.measure import regionprops, label
20
+ import skimage.measure as measure
21
+ from skimage.transform import resize as resizescikit
22
+ from sklearn.model_selection import train_test_split
23
+ from collections import defaultdict
24
+ import multiprocessing
25
+ from torch.utils.data import DataLoader, random_split
26
+ import matplotlib
27
+ matplotlib.use('Agg')
28
+
29
+ import torchvision.transforms as transforms
30
+ from sklearn.model_selection import train_test_split
31
+ from sklearn.ensemble import IsolationForest
32
+
33
+ from .logger import log_function_call
34
+
35
+ #from .io import TarImageDataset, NoClassDataset, MyDataset, read_db, _copy_missclassified, read_mask, load_normalized_images_and_labels, load_images_and_labels
36
+ #from .plot import plot_merged, plot_arrays, _plot_controls, _plot_recruitment, _imshow, _plot_histograms_and_stats, _reg_v_plot, visualize_masks, plot_comparison_results
37
+ #from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient, _object_filter
38
+ #from .utils import resize_images_and_labels, generate_fraction_map, MLR, fishers_odds, lasso_reg, model_metrics, _map_wells_png, check_multicollinearity, init_globals, add_images_to_tar
39
+ #from .utils import get_paths_from_db, pick_best_model, test_model_performance, evaluate_model_performance, compute_irm_penalty
40
+ #from .utils import _pivot_counts_table, _generate_masks, _get_cellpose_channels, annotate_conditions, _calculate_recruitment, calculate_loss, _group_by_well, choose_model
41
+
42
+ @log_function_call
43
+ def analyze_plaques(folder):
44
+ summary_data = []
45
+ details_data = []
46
+
47
+ for filename in os.listdir(folder):
48
+ filepath = os.path.join(folder, filename)
49
+ if os.path.isfile(filepath):
50
+ # Assuming each file is a NumPy array file (.npy) containing a 16-bit labeled image
51
+ image = np.load(filepath)
52
+
53
+ labeled_image = label(image)
54
+ regions = regionprops(labeled_image)
55
+
56
+ object_count = len(regions)
57
+ sizes = [region.area for region in regions]
58
+ average_size = np.mean(sizes) if sizes else 0
59
+
60
+ summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
61
+ for size in sizes:
62
+ details_data.append({'file': filename, 'plaque_size': size})
63
+
64
+ # Convert lists to pandas DataFrames
65
+ summary_df = pd.DataFrame(summary_data)
66
+ details_df = pd.DataFrame(details_data)
67
+
68
+ # Save DataFrames to a SQLite database
69
+ db_name = 'plaques_analysis.db'
70
+ conn = sqlite3.connect(db_name)
71
+
72
+ summary_df.to_sql('summary', conn, if_exists='replace', index=False)
73
+ details_df.to_sql('details', conn, if_exists='replace', index=False)
74
+
75
+ conn.close()
76
+
77
+ print(f"Analysis completed and saved to database '{db_name}'.")
78
+
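+ # Illustrative usage (hypothetical path; assumes the folder holds 16-bit labeled .npy arrays):
+ #   analyze_plaques('/data/plaque_arrays')
+ #   # -> writes 'summary' and 'details' tables to 'plaques_analysis.db' in the current working directory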
79
+ @log_function_call
80
+ def compare_masks(dir1, dir2, dir3, verbose=False):
81
+
82
+ from .io import _read_mask
83
+ from .plot import visualize_masks, plot_comparison_results
84
+ from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
85
+
86
+ filenames = os.listdir(dir1)
87
+ results = []
88
+ cond_1 = os.path.basename(dir1)
89
+ cond_2 = os.path.basename(dir2)
90
+ cond_3 = os.path.basename(dir3)
91
+ for index, filename in enumerate(filenames):
92
+ print(f'Processing image: {index+1}', end='\r', flush=True)
93
+ path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
94
+ if os.path.exists(path2) and os.path.exists(path3):
95
+
96
+ mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
97
+ boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
98
+
99
+
100
+ true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
101
+ true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
102
+ average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
103
+ ap_scores = [average_precision_0, average_precision_1]
104
+
105
+ if verbose:
106
+ unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
107
+ print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
108
+ visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
109
+
110
+ boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
111
+
112
+ if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
113
+ (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
114
+ (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
115
+ continue
116
+
117
+ if verbose:
118
+ unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
119
+ print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
120
+ visualize_masks(mask1, mask2, mask3, title=filename)
121
+
122
+ jaccard12 = jaccard_index(mask1, mask2)
123
+ dice12 = dice_coefficient(mask1, mask2)
124
+ jaccard13 = jaccard_index(mask1, mask3)
125
+ dice13 = dice_coefficient(mask1, mask3)
126
+ jaccard23 = jaccard_index(mask2, mask3)
127
+ dice23 = dice_coefficient(mask2, mask3)
128
+
129
+ results.append({
130
+ 'filename': filename,
131
+ f'jaccard_{cond_1}_{cond_2}': jaccard12,
132
+ f'dice_{cond_1}_{cond_2}': dice12,
133
+ f'jaccard_{cond_1}_{cond_3}': jaccard13,
134
+ f'dice_{cond_1}_{cond_3}': dice13,
135
+ f'jaccard_{cond_2}_{cond_3}': jaccard23,
136
+ f'dice_{cond_2}_{cond_3}': dice23,
137
+ f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
138
+ f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
139
+ f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
140
+ f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
141
+ f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
142
+ })
143
+ else:
144
+ print(f'Cannot find {path1} or {path2} or {path3}')
145
+ fig = plot_comparison_results(results)
146
+ return results, fig
147
+
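+ # Illustrative usage (hypothetical directories; dir1 is treated as the reference condition):
+ #   results, fig = compare_masks('/masks/manual', '/masks/cellpose', '/masks/custom', verbose=False)
+ #   pd.DataFrame(results).to_csv('mask_comparison.csv', index=False)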
148
+ def generate_cp_masks(settings):
149
+
150
+ src = settings['src']
151
+ model_name = settings['model_name']
152
+ channels = settings['channels']
153
+ diameter = settings['diameter']
154
+ regex = '.tif'
155
+ #flow_threshold = 30
156
+ cellprob_threshold = settings['cellprob_threshold']
157
+ figuresize = 25
158
+ cmap = 'inferno'
159
+ verbose = settings['verbose']
160
+ plot = settings['plot']
161
+ save = settings['save']
162
+ custom_model = settings['custom_model']
163
+ signal_thresholds = 1000
164
+ normalize = settings['normalize']
165
+ resize = settings['resize']
166
+ target_height = settings['width_height'][1]
167
+ target_width = settings['width_height'][0]
168
+ rescale = settings['rescale']
169
+ resample = settings['resample']
170
+ net_avg = settings['net_avg']
171
+ invert = settings['invert']
172
+ circular = settings['circular']
173
+ percentiles = settings['percentiles']
174
+ overlay = settings['overlay']
175
+ grayscale = settings['grayscale']
176
+ flow_threshold = settings['flow_threshold']
177
+ batch_size = settings['batch_size']
178
+
179
+ dst = os.path.join(src,'masks')
180
+ os.makedirs(dst, exist_ok=True)
181
+
182
+ identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
183
+
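+ # Illustrative settings sketch (hypothetical values; keys mirror those read above):
+ #   settings = {'src': '/data/images', 'model_name': 'cyto', 'channels': [0, 0], 'diameter': 30,
+ #               'cellprob_threshold': 0.0, 'flow_threshold': 0.4, 'batch_size': 8, 'verbose': False,
+ #               'plot': False, 'save': True, 'custom_model': None, 'normalize': True, 'resize': False,
+ #               'width_height': [1000, 1000], 'rescale': False, 'resample': False, 'net_avg': False,
+ #               'invert': False, 'circular': False, 'percentiles': None, 'overlay': False, 'grayscale': True}
+ #   generate_cp_masks(settings)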
184
+ @log_function_call
185
+ def train_cellpose(settings):
186
+
187
+ from .io import _load_normalized_images_and_labels, _load_images_and_labels
188
+ from .utils import resize_images_and_labels
189
+
190
+ img_src = settings['img_src']
191
+ mask_src= settings['mask_src']
192
+ secondary_image_dir = None
193
+ model_name = settings['model_name']
194
+ model_type = settings['model_type']
195
+ learning_rate = settings['learning_rate']
196
+ weight_decay = settings['weight_decay']
197
+ batch_size = settings['batch_size']
198
+ n_epochs = settings['n_epochs']
199
+ verbose = settings['verbose']
200
+ signal_thresholds = settings['signal_thresholds']
201
+ channels = settings['channels']
202
+ from_scratch = settings['from_scratch']
203
+ diameter = settings['diameter']
204
+ resize = settings['resize']
205
+ rescale = settings['rescale']
206
+ normalize = settings['normalize']
207
+ target_height = settings['width_height'][1]
208
+ target_width = settings['width_height'][0]
209
+ circular = settings['circular']
210
+ invert = settings['invert']
211
+ percentiles = settings['percentiles']
212
+ grayscale = settings['grayscale']
213
+
214
+ print(settings)
215
+
216
+ if from_scratch:
217
+ model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
218
+ else:
219
+ model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
220
+
221
+ model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
222
+ os.makedirs(model_save_path, exist_ok=True)  # the settings CSV below is written into this directory
223
+
224
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
225
+ settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
226
+ settings_df.to_csv(settings_csv, index=False)
227
+
228
+ if model_type =='cyto':
229
+ if not from_scratch:
230
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type)
231
+ else:
232
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
233
+ if model_type !='cyto':
234
+ model = cp_models.CellposeModel(gpu=True, model_type=model_type)
235
+
236
+
237
+
238
+ if normalize:
239
+ images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_dir=img_src, label_dir=mask_src, secondary_image_dir=secondary_image_dir, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
240
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
241
+ else:
242
+ images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
243
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
244
+
245
+ if resize:
246
+ images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
247
+
248
+ if model_type == 'cyto':
249
+ cp_channels = [0,1]
250
+ if model_type == 'cyto2':
251
+ cp_channels = [0,2]
252
+ if model_type == 'nucleus':
253
+ cp_channels = [0,0]
254
+ if grayscale:
255
+ cp_channels = [0,0]
256
+ images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
257
+
258
+ masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
259
+
260
+ print(f'image shape: {images[0].shape}, mask shape: {masks[0].shape}')
261
+ save_every = int(n_epochs/10)
262
+ print('cellpose image input dtype', images[0].dtype)
263
+ print('cellpose mask input dtype', masks[0].dtype)
264
+ # Train the model
265
+ model.train(train_data=images, #(list of arrays (2D or 3D)) – images for training
266
+ train_labels=masks, #(list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels can include flows as additional images
267
+ train_files=image_names, #(list of strings) – file names for images in train_data (to save flows for future runs)
268
+ channels=cp_channels, #(list of ints (default, None)) – channels to use for training
269
+ normalize=False, #(bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
270
+ save_path=model_save_path, #(string (default, None)) – where to save trained model, if None it is not saved
271
+ save_every=save_every, #(int (default, 100)) – save network every [save_every] epochs
272
+ learning_rate=learning_rate, #(float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
273
+ n_epochs=n_epochs, #(int (default, 500)) – how many times to go through whole training set during training
274
+ weight_decay=weight_decay, #(float (default, 0.00001)) –
275
+ SGD=True, #(bool (default, True)) – use SGD as optimization instead of RAdam
276
+ batch_size=batch_size, #(int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
277
+ nimg_per_epoch=None, #(int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
278
+ rescale=rescale, #(bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
279
+ min_train_masks=1, #(int (default, 5)) – minimum number of masks an image must have to use in training set
280
+ model_name=model_name) #(str (default, None)) – name of network, otherwise saved with name as params + training start time
281
+
282
+ print(f"Model saved at: {model_save_path}/{model_name}")
283
+
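+ # Illustrative settings sketch (hypothetical values; keys mirror those read above):
+ #   settings = {'img_src': '/data/train/images', 'mask_src': '/data/train/masks', 'model_name': 'toxo',
+ #               'model_type': 'cyto', 'learning_rate': 0.2, 'weight_decay': 1e-05, 'batch_size': 8,
+ #               'n_epochs': 500, 'verbose': False, 'signal_thresholds': 1000, 'channels': [0, 0],
+ #               'from_scratch': False, 'diameter': 30, 'resize': True, 'rescale': True, 'normalize': True,
+ #               'width_height': [1000, 1000], 'circular': False, 'invert': False, 'percentiles': None,
+ #               'grayscale': True}
+ #   train_cellpose(settings)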
284
+ @log_function_call
285
+ def analyze_data_reg(sequencing_loc, dv_loc, agg_type='mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0, remove_outlier_genes=False, refine_model=False, by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
286
+
287
+ from .plot import _reg_v_plot
288
+ from .utils import generate_fraction_map, MLR, fishers_odds, lasso_reg
289
+
290
+ def qstring_to_float(qstr):
291
+ number = int(qstr[1:]) # Remove the "q" and convert the rest to an integer
292
+ return number / 100.0
293
+
294
+ columns_list = ['c1', 'c2', 'c3']
295
+ plate_list = ['p1','p3','p4']
296
+
297
+ dv_df = pd.read_csv(dv_loc)#, index_col='prc')
298
+
299
+ if agg_type.startswith('q'):
300
+ val = qstring_to_float(agg_type)
301
+ agg_type = lambda x: x.quantile(val)
302
+
303
+ # Aggregating for mean prediction, total count and count of values > 0.95
304
+ dv_df = dv_df.groupby('prc').agg(
305
+ pred=(dv_col, agg_type),
306
+ count_prc=('prc', 'size'),
307
+ mean_pathogen_area=('pathogen_area', 'mean')
308
+ )
309
+
310
+ dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
311
+ sequencing_df = pd.read_csv(sequencing_loc)
312
+
313
+
314
+ reads_df, stats_dict = process_reads(df=sequencing_df,
315
+ min_reads=min_reads,
316
+ min_wells=min_wells,
317
+ max_wells=max_wells,
318
+ gene_column='gene',
319
+ remove_outliers=remove_outlier_genes)
320
+
321
+ reads_df['value'] = reads_df['count']/reads_df['well_read_sum']
322
+ reads_df['gene_grna'] = reads_df['gene']+'_'+reads_df['grna']
323
+
324
+ display(reads_df)
325
+
326
+ df_long = reads_df
327
+
328
+ df_long = df_long[df_long['value'] > min_frequency] # removes gRNAs under a certain proportion
329
+ #df_long = df_long[df_long['value']<1.0] # removes gRNAs in wells with only one gRNA
330
+
331
+ # Extract gene and grna info from gene_grna column
332
+ df_long["gene"] = df_long["grna"].str.split("_").str[1]
333
+ df_long["grna"] = df_long["grna"].str.split("_").str[2]
334
+
335
+ agg_df = df_long.groupby('prc')['count'].sum().reset_index()
336
+ agg_df = agg_df.rename(columns={'count': 'count_sum'})
337
+ df_long = pd.merge(df_long, agg_df, on='prc', how='left')
338
+ df_long['value'] = df_long['count']/df_long['count_sum']
339
+
340
+ merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
341
+ merged_df = merged_df[merged_df['value'] > 0]
342
+ merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
343
+ merged_df['row'] = merged_df['prc'].str.split('_').str[1]
344
+ merged_df['column'] = merged_df['prc'].str.split('_').str[2]
345
+
346
+ merged_df = merged_df[~merged_df['column'].isin(columns_list)]
347
+ merged_df = merged_df[merged_df['plate'].isin(plate_list)]
348
+
349
+ if transform == 'log':
350
+ merged_df['pred'] = np.log(merged_df['pred'] + 1e-10)
351
+
352
+ # Printing the unique values in 'col' and 'plate' columns
353
+ print("Unique values in col:", merged_df['column'].unique())
354
+ print("Unique values in plate:", merged_df['plate'].unique())
355
+ display(merged_df)
356
+
357
+ if fishers:
358
+ iv_df = generate_fraction_map(df=reads_df,
359
+ gene_column='grna',
360
+ min_frequency=min_frequency)
361
+
362
+ fishers_df = iv_df.join(dv_df, on='prc', how='inner')
363
+
364
+ significant_mutants = fishers_odds(df=fishers_df, threshold=fisher_threshold, phenotyp_col='pred')
365
+ significant_mutants = significant_mutants.sort_values(by='OddsRatio', ascending=False)
366
+ display(significant_mutants)
367
+
368
+ if regression_type == 'mlr':
369
+ if by_plate:
370
+ merged_df2 = merged_df.copy()
371
+ for plate in merged_df2['plate'].unique():
372
+ merged_df = merged_df2[merged_df2['plate'] == plate]
373
+ print(f'merged_df: {len(merged_df)}, plate: {plate}')
374
+ if len(merged_df) <100:
375
+ break
376
+
377
+ max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
378
+ else:
379
+
380
+ max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
381
+ return max_effects, max_effects_pvalues, model, df
382
+
383
+ if regression_type == 'ridge' or regression_type == 'lasso':
384
+ coeffs = lasso_reg(merged_df, alpha_value=alpha_value, reg_type=regression_type)
385
+ return coeffs
386
+
387
+ if regression_type == 'mixed':
388
+ model = smf.mixedlm("pred ~ gene_grna - 1", merged_df, groups=merged_df["plate"], re_formula="~1")
389
+ result = model.fit(method="bfgs")
390
+ print(result.summary())
391
+
392
+ # Print AIC and BIC
393
+ print("AIC:", result.aic)
394
+ print("BIC:", result.bic)
395
+
396
+
397
+ results_df = pd.DataFrame({
398
+ 'effect': result.params,
399
+ 'Standard Error': result.bse,
400
+ 'T-Value': result.tvalues,
401
+ 'p': result.pvalues
402
+ })
403
+
404
+ display(results_df)
405
+ _reg_v_plot(df=results_df)
406
+
407
+ std_resid = result.resid
408
+
409
+ # Create subplots
410
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
411
+
412
+ # Histogram of Residuals
413
+ axes[0].hist(std_resid, bins=50, edgecolor='k')
414
+ axes[0].set_xlabel('Residuals')
415
+ axes[0].set_ylabel('Frequency')
416
+ axes[0].set_title('Histogram of Residuals')
417
+
418
+ # Boxplot of Residuals
419
+ axes[1].boxplot(std_resid)
420
+ axes[1].set_ylabel('Residuals')
421
+ axes[1].set_title('Boxplot of Residuals')
422
+
423
+ # QQ Plot
424
+ sm.qqplot(std_resid, line='45', ax=axes[2])
425
+ axes[2].set_title('QQ Plot')
426
+
427
+ # Show plots
428
+ plt.tight_layout()
429
+ plt.show()
430
+
431
+ return result
432
+
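+ # NOTE: the following definition reuses the name analyze_data_reg and shadows the version defined above; only this later definition is bound at import time.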
433
+ @log_function_call
434
+ def analyze_data_reg(sequencing_loc, dv_loc, agg_type='mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
435
+
436
+ from .plot import _reg_v_plot
437
+ from .utils import generate_fraction_map, fishers_odds, model_metrics
438
+
439
+ def qstring_to_float(qstr):
440
+ number = int(qstr[1:]) # Remove the "q" and convert the rest to an integer
441
+ return number / 100.0
442
+
443
+ columns_list = ['c1', 'c2', 'c3', 'c15']
444
+ plate_list = ['p1','p2','p3','p4']
445
+
446
+ dv_df = pd.read_csv(dv_loc)#, index_col='prc')
447
+
448
+ if agg_type.startswith('q'):
449
+ val = qstring_to_float(agg_type)
450
+ agg_type = lambda x: x.quantile(val)
451
+
452
+ # Aggregating for mean prediction, total count and count of values > 0.95
453
+ dv_df = dv_df.groupby('prc').agg(
454
+ pred=('pred', agg_type),
455
+ count_prc=('prc', 'size'),
456
+ #count_above_95=('pred', lambda x: (x > 0.95).sum()),
457
+ mean_pathogen_area=('pathogen_area', 'mean')
458
+ )
459
+
460
+ dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
461
+ sequencing_df = pd.read_csv(sequencing_loc)
462
+
463
+ reads_df, stats_dict = process_reads(df=sequencing_df,
464
+ min_reads=min_reads,
465
+ min_wells=min_wells,
466
+ max_wells=max_wells,
467
+ gene_column='gene',
468
+ remove_outliers=remove_outlier_genes)
469
+
470
+ iv_df = generate_fraction_map(df=reads_df,
471
+ gene_column='grna',
472
+ min_frequency=0.0)
473
+
474
+ # Melt the iv_df to long format
475
+ df_long = iv_df.reset_index().melt(id_vars=["prc"],
476
+ value_vars=iv_df.columns,
477
+ var_name="gene_grna",
478
+ value_name="value")
479
+
480
+ # Extract gene and grna info from gene_grna column
481
+ df_long["gene"] = df_long["gene_grna"].str.split("_").str[1]
482
+ df_long["grna"] = df_long["gene_grna"].str.split("_").str[2]
483
+
484
+ merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
485
+ merged_df = merged_df[merged_df['value'] > 0]
486
+ merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
487
+ merged_df['row'] = merged_df['prc'].str.split('_').str[1]
488
+ merged_df['column'] = merged_df['prc'].str.split('_').str[2]
489
+
490
+ merged_df = merged_df[~merged_df['column'].isin(columns_list)]
491
+ merged_df = merged_df[merged_df['plate'].isin(plate_list)]
492
+
493
+ # Printing the unique values in 'col' and 'plate' columns
494
+ print("Unique values in col:", merged_df['column'].unique())
495
+ print("Unique values in plate:", merged_df['plate'].unique())
496
+
497
+ if not by_plate:
498
+ if fishers:
499
+ fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
500
+
501
+ if by_plate:
502
+ merged_df2 = merged_df.copy()
503
+ for plate in merged_df2['plate'].unique():
504
+ merged_df = merged_df2[merged_df2['plate'] == plate]
505
+ print(f'merged_df: {len(merged_df)}, plate: {plate}')
506
+ if len(merged_df) <100:
507
+ break
508
+ display(merged_df)
509
+
510
+ model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
511
+ #model = smf.ols("pred ~ infection_time + gene + grna + gene:grna + plate + row + column", merged_df).fit()
512
+
513
+ # Display model metrics and summary
514
+ model_metrics(model)
515
+ #print(model.summary())
516
+
517
+ if refine_model:
518
+ # Filter outliers
519
+ std_resid = model.get_influence().resid_studentized_internal
520
+ outliers_resid = np.where(np.abs(std_resid) > 3)[0]
521
+ (c, p) = model.get_influence().cooks_distance
522
+ outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
523
+ outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
524
+ merged_df_filtered = merged_df.drop(merged_df.index[outliers])
525
+
526
+ display(merged_df_filtered)
527
+
528
+ # Refit the model with filtered data
529
+ model = smf.ols("pred ~ gene + grna + gene:grna + row + column", merged_df_filtered).fit()
530
+ print("Number of outliers detected by standardized residuals:", len(outliers_resid))
531
+ print("Number of outliers detected by Cook's distance:", len(outliers_cooks))
532
+
533
+ model_metrics(model)
534
+
535
+ # Extract interaction coefficients and determine the maximum effect size
536
+ interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
537
+ interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}
538
+
539
+ max_effects = {}
540
+ max_effects_pvalues = {}
541
+ for key, val in interaction_coeffs.items():
542
+ gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
543
+ if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
544
+ max_effects[gene_name] = val
545
+ max_effects_pvalues[gene_name] = interaction_pvalues[key]
546
+
547
+ for key in max_effects:
548
+ print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")
549
+
550
+ df = pd.DataFrame([max_effects, max_effects_pvalues])
551
+ df = df.transpose()
552
+ df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
553
+ df = df.sort_values(by=['effect', 'p'], ascending=[False, True])
554
+
555
+ _reg_v_plot(df)
556
+
557
+ if fishers:
558
+ fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
559
+ else:
560
+ display(merged_df)
561
+
562
+ model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
563
+
564
+ # Display model metrics and summary
565
+ model_metrics(model)
566
+
567
+ if refine_model:
568
+ # Filter outliers
569
+ std_resid = model.get_influence().resid_studentized_internal
570
+ outliers_resid = np.where(np.abs(std_resid) > 3)[0]
571
+ (c, p) = model.get_influence().cooks_distance
572
+ outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
573
+ outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
574
+ merged_df_filtered = merged_df.drop(merged_df.index[outliers])
575
+
576
+ display(merged_df_filtered)
577
+
578
+ # Refit the model with filtered data
579
+ model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df_filtered).fit()
580
+ print("Number of outliers detected by standardized residuals:", len(outliers_resid))
581
+ print("Number of outliers detected by Cook's distance:", len(outliers_cooks))
582
+
583
+ model_metrics(model)
584
+
585
+ # Extract interaction coefficients and determine the maximum effect size
586
+ interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
587
+ interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}
588
+
589
+ max_effects = {}
590
+ max_effects_pvalues = {}
591
+ for key, val in interaction_coeffs.items():
592
+ gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
593
+ if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
594
+ max_effects[gene_name] = val
595
+ max_effects_pvalues[gene_name] = interaction_pvalues[key]
596
+
597
+ for key in max_effects:
598
+ print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")
599
+
600
+ df = pd.DataFrame([max_effects, max_effects_pvalues])
601
+ df = df.transpose()
602
+ df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
603
+ df = df.sort_values(by=['effect', 'p'], ascending=[False, True])
604
+
605
+ _reg_v_plot(df)
606
+
607
+ if fishers:
608
+ fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
609
+
610
+ return max_effects, max_effects_pvalues, model, df
611
+
612
+ @log_function_call
613
+ def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
614
+
615
+ from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
616
+
617
+ sequencing_df = pd.read_csv(sequencing_loc)
618
+ columns_list = ['c1','c2','c3', 'c15']
619
+ sequencing_df = sequencing_df[~sequencing_df['col'].isin(columns_list)]
620
+
621
+ reads_df, stats_dict = process_reads(df=sequencing_df,
622
+ min_reads=min_reads,
623
+ min_wells=min_wells,
624
+ max_wells=max_wells,
625
+ gene_column='gene')
626
+
627
+ display(reads_df)
628
+
629
+ iv_df = generate_fraction_map(df=reads_df,
630
+ gene_column=gene_column,
631
+ min_frequency=min_frequency)
632
+
633
+ display(iv_df)
634
+
635
+ dv_df = dv_df[dv_df['count_prc']>min_cells]
636
+ display(dv_df)
637
+ merged_df = iv_df.join(dv_df, on='prc', how='inner')
638
+ display(merged_df)
639
+ fisher_df = merged_df.copy()
640
+
641
+ merged_df.reset_index(inplace=True)
642
+ merged_df[['plate', 'row', 'col']] = merged_df['prc'].str.split('_', expand=True)
643
+ merged_df = merged_df.drop(columns=['prc'])
644
+ merged_df.dropna(inplace=True)
645
+ merged_df = pd.get_dummies(merged_df, columns=['plate', 'row', 'col'], drop_first=True)
646
+
647
+ y = merged_df['mean_pred']
648
+
649
+ if model_type == 'mlr':
650
+ merged_df = merged_df.drop(columns=['count_prc'])
651
+
652
+ elif model_type == 'wls':
653
+ weights = merged_df['count_prc']
654
+
655
+ elif model_type == 'glm':
656
+ merged_df = merged_df.drop(columns=['count_prc'])
657
+
658
+ if transform == 'logit':
659
+ # logit transformation
660
+ epsilon = 1e-15
661
+ y = np.log(y + epsilon) - np.log(1 - y + epsilon)
662
+
663
+ elif transform == 'log':
664
+ # log transformation
665
+ y = np.log10(y+1)
666
+
667
+ elif transform == 'center':
668
+ # Centering the y around 0
669
+ y_mean = y.mean()
670
+ y = y - y_mean
671
+
672
+ x = merged_df.drop('mean_pred', axis=1)
673
+ x = x.select_dtypes(include=[np.number])
674
+ #x = sm.add_constant(x)
675
+ x['const'] = 0.0
676
+
677
+ if model_type == 'mlr':
678
+ model = sm.OLS(y, x).fit()
679
+ model_metrics(model)
680
+
681
+ # Check for Multicollinearity
682
+ vif_data = check_multicollinearity(x.drop('const', axis=1)) # assuming you've added a constant to x
683
+ high_vif_columns = vif_data[vif_data["VIF"] > VIF_threshold]["Variable"].values # VIF threshold of 10 is common, but this can vary based on context
684
+
685
+ print(f"Columns with high VIF: {high_vif_columns}")
686
+ x = x.drop(columns=high_vif_columns) # dropping columns with high VIF
687
+
688
+ if clean_regression:
689
+ # 1. Filter by standardized residuals
690
+ std_resid = model.get_influence().resid_studentized_internal
691
+ outliers_resid = np.where(np.abs(std_resid) > 3)[0]
692
+
693
+ # 2. Filter by leverage
694
+ influence = model.get_influence().hat_matrix_diag
695
+ outliers_lev = np.where(influence > 2*(x.shape[1])/len(y))[0]
696
+
697
+ # 3. Filter by Cook's distance
698
+ (c, p) = model.get_influence().cooks_distance
699
+ outliers_cooks = np.where(c > 4/(len(y)-x.shape[1]-1))[0]
700
+
701
+ # Combine all identified outliers
702
+ outliers = reduce(np.union1d, (outliers_resid, outliers_lev, outliers_cooks))
703
+
704
+ # Filter out outliers
705
+ x_clean = x.drop(x.index[outliers])
706
+ y_clean = y.drop(y.index[outliers])
707
+
708
+ # Re-run the regression with the filtered data
709
+ model = sm.OLS(y_clean, x_clean).fit()
710
+ model_metrics(model)
711
+
712
+ elif model_type == 'wls':
713
+ model = sm.WLS(y, x, weights=weights).fit()
714
+
715
+ elif model_type == 'glm':
716
+ model = sm.GLM(y, x, family=sm.families.Binomial()).fit()
717
+
718
+ print(model.summary())
719
+
720
+ results_summary = model.summary()
721
+
722
+ results_as_html = results_summary.tables[1].as_html()
723
+ results_df = pd.read_html(results_as_html, header=0, index_col=0)[0]
724
+ results_df = results_df.sort_values(by='coef', ascending=False)
725
+
726
+ if model_type == 'mlr':
727
+ results_df['p'] = results_df['P>|t|']
728
+ elif model_type == 'wls':
729
+ results_df['p'] = results_df['P>|t|']
730
+ elif model_type == 'glm':
731
+ results_df['p'] = results_df['P>|z|']
732
+
733
+ results_df['type'] = 1
734
+ results_df.loc[results_df['p'] == 0.000, 'p'] = 0.005
735
+ results_df['-log10(p)'] = -np.log10(results_df['p'])
736
+
737
+ display(results_df)
738
+
739
+ # Create subplots
740
+ fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 15))
741
+
742
+ # Plot histogram on ax1
743
+ sns.histplot(data=y, kde=False, element="step", ax=ax1, color='teal')
744
+ ax1.set_xlim([0, 1])
745
+ ax1.spines['top'].set_visible(False)
746
+ ax1.spines['right'].set_visible(False)
747
+
748
+ # Prepare data for volcano plot on ax2
749
+ results_df['-log10(p)'] = -np.log10(results_df['p'])
750
+
751
+ # Assuming the 'type' column is in the merged_df
752
+ sc = ax2.scatter(results_df['coef'], results_df['-log10(p)'], c=results_df['type'], cmap='coolwarm')
753
+ ax2.set_title('Volcano Plot')
754
+ ax2.set_xlabel('Coefficient')
755
+ ax2.set_ylabel('-log10(P-value)')
756
+
757
+ # Adjust colorbar
758
+ cbar = plt.colorbar(sc, ax=ax2, ticks=[-1, 1])
759
+ cbar.set_label('Sign of Coefficient')
760
+ cbar.set_ticklabels(['-ve', '+ve'])
761
+
762
+ # Add text for specified points
763
+ for idx, row in results_df.iterrows():
764
+ if row['p'] < 0.05 and row['coef'] > effect_size_threshold:
765
+ ax2.text(row['coef'], -np.log10(row['p']), idx, fontsize=8, ha='center', va='bottom', color='black')
766
+
767
+ ax2.axhline(y=-np.log10(0.05), color='gray', linestyle='--')
768
+
769
+ plt.show()
770
+
771
+ #if model_type == 'mlr':
772
+ # show_residules(model)
773
+
774
+ if fishers:
775
+ threshold = 2*effect_size_threshold
776
+ fishers_odds(df=fisher_df, threshold=threshold, phenotyp_col='mean_pred')
777
+
778
+ return
779
+
780
+ @log_function_call
781
+ def merge_pred_mes(src,
782
+ pred_loc,
783
+ target='protein of interest',
784
+ cell_dim=4,
785
+ nucleus_dim=5,
786
+ pathogen_dim=6,
787
+ channel_of_interest=1,
788
+ pathogen_size_min=0,
789
+ nucleus_size_min=0,
790
+ cell_size_min=0,
791
+ pathogen_min=0,
792
+ nucleus_min=0,
793
+ cell_min=0,
794
+ target_min=0,
795
+ mask_chans=[0,1,2],
796
+ filter_data=False,
797
+ include_noninfected=False,
798
+ include_multiinfected=False,
799
+ include_multinucleated=False,
800
+ cells_per_well=10,
801
+ save_filtered_filelist=False,
802
+ verbose=False):
803
+
804
+ from .io import _read_and_merge_data
805
+ from .plot import _plot_histograms_and_stats
806
+
807
+ mask_chans=[cell_dim,nucleus_dim,pathogen_dim]
808
+ sns.color_palette("mako", as_cmap=True)
809
+ print(f'channel:{channel_of_interest} = {target}')
810
+ overlay_channels = [0, 1, 2, 3]
811
+ overlay_channels.remove(channel_of_interest)
812
+ overlay_channels.reverse()
813
+
814
+ db_loc = [src+'/measurements/measurements.db']
815
+ tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
816
+ df, object_dfs = _read_and_merge_data(db_loc,
817
+ tables,
818
+ verbose=True,
819
+ include_multinucleated=include_multinucleated,
820
+ include_multiinfected=include_multiinfected,
821
+ include_noninfected=include_noninfected)
822
+ if filter_data:
823
+ df = df[df['cell_area'] > cell_size_min]
824
+ df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
825
+ print(f'After cell filtration {len(df)}')
826
+ df = df[df['nucleus_area'] > nucleus_size_min]
827
+ df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
828
+ print(f'After nucleus filtration {len(df)}')
829
+ df = df[df['pathogen_area'] > pathogen_size_min]
830
+ df=df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
831
+ print(f'After pathogen filtration {len(df)}')
832
+ df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
833
+ print(f'After channel {channel_of_interest} filtration', len(df))
834
+
835
+ df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
836
+
837
+ pred_df = annotate_results(pred_loc=pred_loc)
838
+
839
+ if verbose:
840
+ _plot_histograms_and_stats(df=pred_df)
841
+
842
+ pred_df.set_index('prcfo', inplace=True)
843
+ pred_df = pred_df.drop(columns=['plate', 'row', 'col', 'field'])
844
+
845
+ joined_df = df.join(pred_df, how='inner')
846
+
847
+ if verbose:
848
+ _plot_histograms_and_stats(df=joined_df)
849
+
850
+ #dv = joined_df.copy()
851
+ #if 'prc' not in dv.columns:
852
+ #dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
853
+ #dv = dv[['pred']].groupby('prc').mean()
854
+ #dv.set_index('prc', inplace=True)
855
+
856
+ #loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
857
+ #dv.to_csv(loc, index=True, header=True, mode='w')
858
+
859
+ return joined_df
860
+
861
+ def process_reads(df, min_reads, min_wells, max_wells, gene_column, remove_outliers=False):
862
+ print('start',len(df))
863
+ df = df[df['count'] >= min_reads]
864
+ print('after filtering min reads',min_reads, len(df))
865
+ reads_ls = df['count']
866
+ stats_dict = {}
867
+ stats_dict['screen_reads_mean'] = np.mean(reads_ls)
868
+ stats_dict['screen_reads_sd'] = np.std(reads_ls)
869
+ stats_dict['screen_reads_var'] = np.var(reads_ls)
870
+
871
+ well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
872
+ well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
873
+ well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
874
+ well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
875
+ well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
876
+ gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
877
+ gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
878
+ df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))
879
+ df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))
880
+
881
+ df = df[df['gRNA_well_count'] >= min_wells]
882
+ df = df[df['gRNA_well_count'] <= max_wells]
883
+
884
+ if remove_outliers:
885
+ clf = IsolationForest(contamination='auto', random_state=42, n_jobs=20)
886
+ #clf.fit(df.select_dtypes(include=['int', 'float']))
887
+ clf.fit(df[["gRNA_well_count", "count"]])
888
+ outlier_array = clf.predict(df[["gRNA_well_count", "count"]])
889
+ #outlier_array = clf.predict(df.select_dtypes(include=['int', 'float']))
890
+ outlier_df = pd.DataFrame(outlier_array, columns=['outlier'])
891
+ df['outlier'] = outlier_df['outlier']
892
+ outliers = pd.DataFrame(df[df['outlier']==-1])
893
+ df = pd.DataFrame(df[df['outlier']==1])
894
+ print('removed', len(outliers), 'outliers; inliers:', len(df))
895
+
896
+ columns_to_drop = ['gRNA_well_count','gRNAs_per_well', 'well_read_sum']#, 'outlier']
897
+ df = df.drop(columns_to_drop, axis=1)
898
+
899
+ plates = ['p1', 'p2', 'p3', 'p4']
900
+ df = df[df.plate.isin(plates)]
901
+ print('after keeping only plates p1-p4', len(df))
902
+
903
+ gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
904
+ gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
905
+ df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))
906
+ well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
907
+ well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
908
+ well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
909
+ well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
910
+ well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
911
+ df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))
912
+
913
+ columns_to_drop = [col for col in df.columns if col.endswith('_right')]
914
+ columns_to_drop2 = [col for col in df.columns if col.endswith('0')]
915
+ columns_to_drop = columns_to_drop + columns_to_drop2
916
+ df = df.drop(columns_to_drop, axis=1)
917
+ return df, stats_dict
918
+
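+ # Illustrative usage (hypothetical CSV; expects at least 'count', 'prc', 'plate' and the gene column):
+ #   reads, stats = process_reads(pd.read_csv('sequencing_counts.csv'), min_reads=100, min_wells=2,
+ #                                max_wells=1000, gene_column='gene', remove_outliers=False)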
919
+ def annotate_results(pred_loc):
920
+
921
+ from .utils import _map_wells_png
922
+
923
+ df = pd.read_csv(pred_loc)
924
+ df = df.copy()
925
+ pc_col_list = ['c4','c5','c6','c7','c8','c9','c10','c11','c12','c13','c14','c15','c16','c17','c18','c19','c20','c21','c22','c23','c24']
926
+ pc_plate_list = ['p6','p7','p8', 'p9']
927
+
928
+ nc_col_list = ['c1','c2','c3']
929
+ nc_plate_list = ['p1','p2','p3','p4','p6','p7','p8', 'p9']
930
+
931
+ screen_col_list = ['c4','c5','c6','c7','c8','c9','c10','c11','c12','c13','c14','c15','c16','c17','c18','c19','c20','c21','c22','c23','c24']
932
+ screen_plate_list = ['p1','p2','p3','p4']
933
+
934
+ df[['plate', 'row', 'col', 'field', 'cell_id', 'prcfo']] = df['path'].apply(lambda x: pd.Series(_map_wells_png(x)))
935
+
936
+ df.loc[(df['col'].isin(pc_col_list)) & (df['plate'].isin(pc_plate_list)), 'condition'] = 'pc'
937
+ df.loc[(df['col'].isin(nc_col_list)) & (df['plate'].isin(nc_plate_list)), 'condition'] = 'nc'
938
+ df.loc[(df['col'].isin(screen_col_list)) & (df['plate'].isin(screen_plate_list)), 'condition'] = 'screen'
939
+
940
+ df = df.dropna(subset=['condition'])
941
+ display(df)
942
+ return df
943
+
944
+ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=None):
945
+
946
+ from .utils import init_globals, add_images_to_tar
947
+
948
+ db_path = os.path.join(src, 'measurements','measurements.db')
949
+ dst = os.path.join(src, 'datasets')
950
+
951
+ global total_images
952
+ all_paths = []
953
+
954
+ # Connect to the database and retrieve the image paths
955
+ print(f'Reading DataBase: {db_path}')
956
+ with sqlite3.connect(db_path) as conn:
957
+ cursor = conn.cursor()
958
+ if file_type:
959
+ cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_type}%",))
960
+ else:
961
+ cursor.execute("SELECT png_path FROM png_list")
962
+ while True:
963
+ rows = cursor.fetchmany(1000)
964
+ if not rows:
965
+ break
966
+ all_paths.extend([row[0] for row in rows])
967
+
968
+ if isinstance(sample, int):
969
+ selected_paths = random.sample(all_paths, sample)
970
+ print(f'Random selection of {len(selected_paths)} paths')
971
+ else:
972
+ selected_paths = all_paths
973
+ random.shuffle(selected_paths)
974
+ print(f'All paths: {len(selected_paths)} paths')
975
+
976
+ total_images = len(selected_paths)
977
+ print(f'found {total_images} images')
978
+
979
+ # Create a temp folder in dst
980
+ temp_dir = os.path.join(dst, "temp_tars")
981
+ os.makedirs(temp_dir, exist_ok=True)
982
+
983
+ # Chunking the data
984
+ if len(selected_paths) > 10000:
985
+ num_procs = cpu_count()-2
986
+ chunk_size = len(selected_paths) // num_procs
987
+ remainder = len(selected_paths) % num_procs
988
+ else:
989
+ num_procs = 2
990
+ chunk_size = len(selected_paths) // 2
991
+ remainder = 0
992
+
993
+ paths_chunks = []
994
+ start = 0
995
+ for i in range(num_procs):
996
+ end = start + chunk_size + (1 if i < remainder else 0)
997
+ paths_chunks.append(selected_paths[start:end])
998
+ start = end
999
+
1000
+ temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
1001
+
1002
+ # Initialize the shared objects
1003
+ counter_ = Value('i', 0)
1004
+ lock_ = Lock()
1005
+
1006
+ ctx = multiprocessing.get_context('spawn')
1007
+
1008
+ print(f'Generating temporary tar files in {dst}')
1009
+
1010
+ # Combine the temporary tar files into a final tar
1011
+ date_name = datetime.date.today().strftime('%y%m%d')
1012
+ tar_name = f'{date_name}_{experiment}_{file_type}.tar'
1013
+ if os.path.exists(os.path.join(dst, tar_name)):
1014
+ number = random.randint(1, 100)
1015
+ tar_name_2 = f'{date_name}_{experiment}_{file_type}_{number}.tar'
1016
+ print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)}')
1017
+ tar_name = tar_name_2
1018
+
1019
+ # Add the counter and lock to the arguments for pool.map
1020
+ print(f'Merging temporary files')
1021
+ #with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1022
+ # results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1023
+
1024
+ with ctx.Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
1025
+ results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
1026
+
1027
+ with tarfile.open(os.path.join(dst, tar_name), 'w') as final_tar:
1028
+ for tar_path in results:
1029
+ with tarfile.open(tar_path, 'r') as t:
1030
+ for member in t.getmembers():
1031
+ t.extract(member, path=dst)
1032
+ final_tar.add(os.path.join(dst, member.name), arcname=member.name)
1033
+ os.remove(os.path.join(dst, member.name))
1034
+ os.remove(tar_path)
1035
+
1036
+ # Delete the temp folder
1037
+ shutil.rmtree(temp_dir)
1038
+ print(f"\nSaved {total_images} images to {os.path.join(dst, tar_name)}")
1039
+
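+ # Illustrative usage (hypothetical path): bundle all 'cell_png' crops referenced in
+ # <src>/measurements/measurements.db into a tar under <src>/datasets:
+ #   generate_dataset('/data/screen1', file_type='cell_png', experiment='TSG101_screen', sample=None)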
1040
+ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
1041
+
1042
+ from .io import TarImageDataset  # DataLoader is already imported from torch.utils.data at module level
1043
+
1044
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1045
+ if normalize:
1046
+ transform = transforms.Compose([
1047
+ transforms.ToTensor(),
1048
+ transforms.CenterCrop(size=(image_size, image_size)),
1049
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1050
+ else:
1051
+ transform = transforms.Compose([
1052
+ transforms.ToTensor(),
1053
+ transforms.CenterCrop(size=(image_size, image_size))])
1054
+
1055
+ if verbose:
1056
+ print(f'Loading model from {model_path}')
1057
+ print(f'Loading dataset from {tar_path}')
1058
+
1059
+ model = torch.load(model_path)
1060
+
1061
+ dataset = TarImageDataset(tar_path, transform=transform)
1062
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)
1063
+
1064
+ model_name = os.path.splitext(os.path.basename(model_path))[0]
1065
+ dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
1066
+ date_name = datetime.date.today().strftime('%y%m%d')
1067
+ dst = os.path.dirname(tar_path)
1068
+ result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
1069
+
1070
+ model.eval()
1071
+ model = model.to(device)
1072
+
1073
+ if verbose:
1074
+ print(model)
1075
+ print(f'Generated dataset with {len(dataset)} images')
1076
+ print(f'Generating loader from {len(data_loader)} batches')
1077
+ print(f'Results will be saved in: {result_loc}')
1078
+ print(f'Model is in eval mode')
1079
+ print(f'Model loaded to device')
1080
+
1081
+ prediction_pos_probs = []
1082
+ filenames_list = []
1083
+ gc.collect()
1084
+ with torch.no_grad():
1085
+ for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
1086
+ images = batch_images.to(torch.float).to(device)
1087
+ outputs = model(images)
1088
+ batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
1089
+ prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
1090
+ filenames_list.extend(filenames)
1091
+ print(f'\rbatch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
1092
+
1093
+ data = {'path':filenames_list, 'pred':prediction_pos_probs}
1094
+ df = pd.DataFrame(data, index=None)
1095
+ df.to_csv(result_loc, index=True, header=True, mode='w')
1096
+ torch.cuda.empty_cache()
1097
+ torch.cuda.memory.empty_cache()
1098
+ return df
1099
+
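+ # Illustrative usage (hypothetical paths; the model is expected to be a torch.save()'d binary classifier
+ # whose raw outputs are passed through a sigmoid above):
+ #   df = apply_model_to_tar('/data/datasets/240101_screen_cell_png.tar', '/models/classifier.pth',
+ #                           batch_size=64, normalize=True, num_workers=4)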
1100
+ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, num_workers=10):
1101
+
1102
+ from .io import NoClassDataset
1103
+
1104
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1105
+
1106
+ if normalize:
1107
+ transform = transforms.Compose([
1108
+ transforms.ToTensor(),
1109
+ transforms.CenterCrop(size=(image_size, image_size)),
1110
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1111
+ else:
1112
+ transform = transforms.Compose([
1113
+ transforms.ToTensor(),
1114
+ transforms.CenterCrop(size=(image_size, image_size))])
1115
+
1116
+ model = torch.load(model_path)
1117
+ print(model)
1118
+
1119
+ print(f'Loading dataset from {src}')
1120
+ dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
1121
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
1122
+ print(f'Loaded {len(dataset)} images')
1123
+
1124
+ result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
1125
+ print(f'Results will be saved in: {result_loc}')
1126
+
1127
+ model.eval()
1128
+ model = model.to(device)
1129
+ prediction_pos_probs = []
1130
+ filenames_list = []
1131
+ with torch.no_grad():
1132
+ for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
1133
+ images = batch_images.to(torch.float).to(device)
1134
+ outputs = model(images)
1135
+ batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
1136
+ prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
1137
+ filenames_list.extend(filenames)
1138
+ print(f'\rbatch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
1139
+ data = {'path':filenames_list, 'pred':prediction_pos_probs}
1140
+ df = pd.DataFrame(data, index=None)
1141
+ df.to_csv(result_loc, index=True, header=True, mode='w')
1142
+ torch.cuda.empty_cache()
1143
+ torch.cuda.memory.empty_cache()
1144
+ return df
1145
+
1146
+
1147
+ def generate_training_data_file_list(src,
1148
+ target='protein of interest',
1149
+ cell_dim=4,
1150
+ nucleus_dim=5,
1151
+ pathogen_dim=6,
1152
+ channel_of_interest=1,
1153
+ pathogen_size_min=0,
1154
+ nucleus_size_min=0,
1155
+ cell_size_min=0,
1156
+ pathogen_min=0,
1157
+ nucleus_min=0,
1158
+ cell_min=0,
1159
+ target_min=0,
1160
+ mask_chans=[0,1,2],
1161
+ filter_data=False,
1162
+ include_noninfected=False,
1163
+ include_multiinfected=False,
1164
+ include_multinucleated=False,
1165
+ cells_per_well=10,
1166
+ save_filtered_filelist=False):
1167
+
1168
+ from .io import _read_and_merge_data
1169
+
1170
+ mask_dims=[cell_dim,nucleus_dim,pathogen_dim]
1171
+ sns.color_palette("mako", as_cmap=True)
1172
+ print(f'channel:{channel_of_interest} = {target}')
1173
+ overlay_channels = [0, 1, 2, 3]
1174
+ overlay_channels.remove(channel_of_interest)
1175
+ overlay_channels.reverse()
1176
+
1177
+ db_loc = [src+'/measurements/measurements.db']
1178
+ tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1179
+ df, object_dfs = _read_and_merge_data(db_loc,
1180
+ tables,
1181
+ verbose=True,
1182
+ include_multinucleated=include_multinucleated,
1183
+ include_multiinfected=include_multiinfected,
1184
+ include_noninfected=include_noninfected)
1185
+
1186
+ if filter_data:
1187
+ df = df[df['cell_area'] > cell_size_min]
1188
+ df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
1189
+ print(f'After cell filtration {len(df)}')
1190
+ df = df[df['nucleus_area'] > nucleus_size_min]
1191
+ df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
1192
+ print(f'After nucleus filtration {len(df)}')
1193
+ df = df[df['pathogen_area'] > pathogen_size_min]
1194
+ df=df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
1195
+ print(f'After pathogen filtration {len(df)}')
1196
+ df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
1197
+ print(f'After channel {channel_of_interest} filtration', len(df))
1198
+
1199
+ df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1200
+ return df
1201
+
1202
+ def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
1203
+ all_paths = []
1204
+
1205
+ # Connect to the database and retrieve the image paths and annotations
1206
+ print(f'Reading DataBase: {db_path}')
1207
+ with sqlite3.connect(db_path) as conn:
1208
+ cursor = conn.cursor()
1209
+ # Prepare the query with parameterized placeholders for annotated_classes
1210
+ placeholders = ','.join('?' * len(annotated_classes))
1211
+ query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
1212
+ cursor.execute(query, annotated_classes)
1213
+
1214
+ while True:
1215
+ rows = cursor.fetchmany(1000)
1216
+ if not rows:
1217
+ break
1218
+ for row in rows:
1219
+ all_paths.append(row)
1220
+
1221
+ # Filter paths based on annotation
1222
+ class_paths = []
1223
+ for class_ in annotated_classes:
1224
+ class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
1225
+ class_paths.append(class_paths_temp)
1226
+
1227
+ print(f'Generated image-path lists for {len(class_paths)} annotated classes')
1228
+ return class_paths
1229
+
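+ # Illustrative usage (hypothetical path; annotation values 1 and 2 are read from the 'test' column of png_list):
+ #   class_paths = training_dataset_from_annotation('/data/measurements/measurements.db', dst='/data/datasets',
+ #                                                  annotation_column='test', annotated_classes=(1, 2))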
1230
+ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
1231
+ # Make sure that the length of class_data matches the length of classes
1232
+ if len(class_data) != len(classes):
1233
+ raise ValueError("class_data and classes must have the same length.")
1234
+
1235
+ total_files = sum(len(data) for data in class_data)
1236
+ processed_files = 0
1237
+
1238
+ for cls, data in zip(classes, class_data):
1239
+ # Create directories
1240
+ train_class_dir = os.path.join(dst, f'train/{cls}')
1241
+ test_class_dir = os.path.join(dst, f'test/{cls}')
1242
+ os.makedirs(train_class_dir, exist_ok=True)
1243
+ os.makedirs(test_class_dir, exist_ok=True)
1244
+
1245
+ # Split the data
1246
+ train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
1247
+
1248
+ # Copy train files
1249
+ for path in train_data:
1250
+ shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
1251
+ processed_files += 1
1252
+ print(f'{processed_files}/{total_files}', end='\r', flush=True)
1253
+
1254
+ # Copy test files
1255
+ for path in test_data:
1256
+ shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
1257
+ processed_files += 1
1258
+ print(f'{processed_files}/{total_files}', end='\r', flush=True)
1259
+
1260
+ # Print summary
1261
+ for cls in classes:
1262
+ train_class_dir = os.path.join(dst, f'train/{cls}')
1263
+ test_class_dir = os.path.join(dst, f'test/{cls}')
1264
+ print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
1265
+
1266
+ return
1267
+
1268
+ def generate_training_dataset(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
1269
+
1270
+ from .io import _read_and_merge_data, _read_db
1271
+ from .utils import get_paths_from_db, annotate_conditions
1272
+
1273
+ db_path = os.path.join(src, 'measurements','measurements.db')
1274
+ dst = os.path.join(src, 'datasets', 'training')
1275
+
1276
+ if mode == 'annotation':
1277
+ class_paths_ls_2 = []
1278
+ class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
1279
+ for class_paths in class_paths_ls:
1280
+ class_paths_temp = random.sample(class_paths, size)
1281
+ class_paths_ls_2.append(class_paths_temp)
1282
+ class_paths_ls = class_paths_ls_2
1283
+
1284
+ elif mode == 'metadata':
1285
+ class_paths_ls = []
1286
+ [df] = _read_db(db_loc=db_path, tables=['png_list'])
1287
+ df['metadata_based_class'] = pd.NA
1288
+ for i, class_ in enumerate(classes):
1289
+ ls = class_metadata[i]
1290
+ df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
1291
+
1292
+ for class_ in classes:
1293
+ class_temp_df = df[df['metadata_based_class'] == class_]
1294
+ class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
1295
+ class_paths_ls.append(class_paths_temp)
1296
+
1297
+ elif mode == 'recruitment':
1298
+ class_paths_ls = []
1299
+ if not isinstance(tables, list):
1300
+ tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1301
+
1302
+ df, _ = _read_and_merge_data(locs=[db_path],
1303
+ tables=tables,
1304
+ verbose=False,
1305
+ include_multinucleated=True,
1306
+ include_multiinfected=True,
1307
+ include_noninfected=True)
1308
+
1309
+ print('length df 1', len(df))
1310
+
1311
+ df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=classes, treatment_loc=class_metadata, types = ['col','col',metadata_type_by])
1312
+ print('length df 2', len(df))
1313
+ [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
1314
+
1315
+ if custom_measurement is not None:
1316
+
1317
+ if not isinstance(custom_measurement, list):
1318
+ print('custom_measurement should be a list, e.g. [measurement_1, measurement_2] or [measurement]')
1319
+ return
1320
+
1321
+ if isinstance(custom_measurement, list):
1322
+ if len(custom_measurement) == 2:
1323
+ print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]}/{custom_measurement[1]})')
1324
+ df['recruitment'] = df[f'{custom_measurement[0]}']/df[f'{custom_measurement[1]}']
1325
+ if len(custom_measurement) == 1:
1326
+ print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
1327
+ df['recruitment'] = df[f'{custom_measurement[0]}']
1328
+ else:
1329
+ print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
1330
+ df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1331
+
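+ # Split objects into two classes using the lower (Q1) and upper (Q3) quartiles of the recruitment score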
1332
+ q25 = df['recruitment'].quantile(0.25)
1333
+ q75 = df['recruitment'].quantile(0.75)
1334
+ df_lower = df[df['recruitment'] <= q25]
1335
+ df_upper = df[df['recruitment'] >= q75]
1336
+
1337
+ class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
1338
+
1339
+ class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
1340
+ class_paths_ls.append(class_paths_lower)
1341
+
1342
+ class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
1343
+ class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
1344
+ class_paths_ls.append(class_paths_upper)
1345
+
1346
+ generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=test_split)
1347
+
1348
+ return
1349
+
1350
+ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
1351
+ """
1352
+ Generate data loaders for training and validation/test datasets.
1353
+
1354
+ Parameters:
1355
+ - src (str): The source directory containing the data.
1356
+ - train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
1357
+ - mode (str): The mode of operation. Options are 'train' or 'test'.
1358
+ - image_size (int): The size of the input images.
1359
+ - batch_size (int): The batch size for the data loaders.
1360
+ - classes (list): The list of classes to consider.
1361
+ - num_workers (int): The number of worker threads for data loading.
1362
+ - validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
1363
+ - max_show (int): The maximum number of images to show when verbose is True.
1364
+ - pin_memory (bool): Whether to pin memory for faster data transfer.
1365
+ - normalize (bool): Whether to normalize the input images.
1366
+ - verbose (bool): Whether to print additional information and show images.
1367
+
1368
+ Returns:
1369
+ - train_loaders (list): List of data loaders for training datasets.
1370
+ - val_loaders (list): List of data loaders for validation datasets.
1371
+ - plate_names (list): List of plate names (only applicable when train_mode is 'irm').
1372
+ """
1373
+
1374
+ from .io import MyDataset
1375
+ from .plot import _imshow
1376
+
1377
+ plate_to_filenames = defaultdict(list)
1378
+ plate_to_labels = defaultdict(list)
1379
+ train_loaders = []
1380
+ val_loaders = []
1381
+ plate_names = []
1382
+
1383
+ if normalize:
1384
+ transform = transforms.Compose([
1385
+ transforms.ToTensor(),
1386
+ transforms.CenterCrop(size=(image_size, image_size)),
1387
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1388
+ else:
1389
+ transform = transforms.Compose([
1390
+ transforms.ToTensor(),
1391
+ transforms.CenterCrop(size=(image_size, image_size))])
1392
+
1393
+ if mode == 'train':
1394
+ data_dir = os.path.join(src, 'train')
1395
+ shuffle = True
1396
+ print('Generating train and validation datasets')
1397
+
1398
+ elif mode == 'test':
1399
+ data_dir = os.path.join(src, 'test')
1400
+ val_loaders = []
1401
+ validation_split=0.0
1402
+ shuffle = True
1403
+ print(f'Generating test dataset')
1404
+
1405
+ else:
1406
+ print(f'mode:{mode} is not valid, use mode = train or test')
1407
+ return
1408
+
1409
+ if train_mode == 'erm':
1410
+ data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1411
+ #train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1412
+ if validation_split > 0:
1413
+ train_size = int((1 - validation_split) * len(data))
1414
+ val_size = len(data) - train_size
1415
+
1416
+ print(f'Train data:{train_size}, Validation data:{val_size}')
1417
+
1418
+ train_dataset, val_dataset = random_split(data, [train_size, val_size])
1419
+
1420
+ train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1421
+ val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1422
+ else:
1423
+ train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1424
+
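+ # In 'irm' mode the dataset is grouped by plate so that each plate becomes its own environment with its own loader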
1425
+ elif train_mode == 'irm':
1426
+ data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1427
+
1428
+ for filename, label in zip(data.filenames, data.labels):
1429
+ plate = data.get_plate(filename)
1430
+ plate_to_filenames[plate].append(filename)
1431
+ plate_to_labels[plate].append(label)
1432
+
1433
+ for plate, filenames in plate_to_filenames.items():
1434
+ labels = plate_to_labels[plate]
1435
+ plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
1436
+ plate_names.append(plate)
1437
+
1438
+ if validation_split > 0:
1439
+ train_size = int((1 - validation_split) * len(plate_data))
1440
+ val_size = len(plate_data) - train_size
1441
+
1442
+ print(f'Train data:{train_size}, Validation data:{val_size}')
1443
+
1444
+ train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
1445
+
1446
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1447
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1448
+
1449
+ train_loaders.append(train_loader)
1450
+ val_loaders.append(val_loader)
1451
+ else:
1452
+ train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
1453
+ train_loaders.append(train_loader)
1454
+ val_loaders.append(None)
1455
+
1456
+ else:
1457
+ print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
1458
+ return
1459
+
1460
+ if verbose:
1461
+ if train_mode == 'erm':
1462
+ for idx, (images, labels, filenames) in enumerate(train_loaders):
1463
+ if idx >= max_show:
1464
+ break
1465
+ images = images.cpu()
1466
+ label_strings = [str(label.item()) for label in labels]
1467
+ _imshow(images, label_strings, nrow=20, fontsize=12)
1468
+
1469
+ elif train_mode == 'irm':
1470
+ for plate_name, train_loader in zip(plate_names, train_loaders):
1471
+ print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
1472
+ for idx, (images, labels, filenames) in enumerate(train_loader):
1473
+ if idx >= max_show:
1474
+ break
1475
+ images = images.cpu()
1476
+ label_strings = [str(label.item()) for label in labels]
1477
+ _imshow(images, label_strings, nrow=20, fontsize=12)
1478
+
1479
+ return train_loaders, val_loaders, plate_names
1480
+
1481
+ def analyze_recruitment(src, metadata_settings, advanced_settings):
1482
+ """
1483
+ Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
1484
+
1485
+ Parameters:
1486
+ src (str): The source of the recruitment data.
1487
+ metadata_settings (dict): The settings for metadata.
1488
+ advanced_settings (dict): The advanced settings for recruitment analysis.
1489
+
1490
+ Returns:
1491
+ None
1492
+ """
1493
+
1494
+ from .io import _read_and_merge_data, _results_to_csv
1495
+ from .plot import plot_merged, _plot_controls, _plot_recruitment
1496
+ from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
1497
+
1498
+ settings_dict = {**metadata_settings, **advanced_settings}
1499
+ settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
1500
+ settings_csv = os.path.join(src,'settings','analyze_settings.csv')
1501
+ os.makedirs(os.path.join(src,'settings'), exist_ok=True)
1502
+ settings_df.to_csv(settings_csv, index=False)
1503
+
1504
+ # metadata settings
1505
+ target = metadata_settings['target']
1506
+ cell_types = metadata_settings['cell_types']
1507
+ cell_plate_metadata = metadata_settings['cell_plate_metadata']
1508
+ pathogen_types = metadata_settings['pathogen_types']
1509
+ pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
1510
+ treatments = metadata_settings['treatments']
1511
+ treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
1512
+ metadata_types = metadata_settings['metadata_types']
1513
+ channel_dims = metadata_settings['channel_dims']
1514
+ cell_chann_dim = metadata_settings['cell_chann_dim']
1515
+ cell_mask_dim = metadata_settings['cell_mask_dim']
1516
+ nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
1517
+ nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
1518
+ pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
1519
+ pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
1520
+ channel_of_interest = metadata_settings['channel_of_interest']
1521
+
1522
+ # Advanced settings
1523
+ plot = advanced_settings['plot']
1524
+ plot_nr = advanced_settings['plot_nr']
1525
+ plot_control = advanced_settings['plot_control']
1526
+ figuresize = advanced_settings['figuresize']
1527
+ remove_background = advanced_settings['remove_background']
1528
+ backgrounds = advanced_settings['backgrounds']
1529
+ include_noninfected = advanced_settings['include_noninfected']
1530
+ include_multiinfected = advanced_settings['include_multiinfected']
1531
+ include_multinucleated = advanced_settings['include_multinucleated']
1532
+ cells_per_well = advanced_settings['cells_per_well']
1533
+ pathogen_size_range = advanced_settings['pathogen_size_range']
1534
+ nucleus_size_range = advanced_settings['nucleus_size_range']
1535
+ cell_size_range = advanced_settings['cell_size_range']
1536
+ pathogen_intensity_range = advanced_settings['pathogen_intensity_range']
1537
+ nucleus_intensity_range = advanced_settings['nucleus_intensity_range']
1538
+ cell_intensity_range = advanced_settings['cell_intensity_range']
1539
+ target_intensity_min = advanced_settings['target_intensity_min']
1540
+
1541
+ print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
1542
+ print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
1543
+ print(f'Treatment(s): {treatments}, in {treatment_plate_metadata}')
1544
+
1545
+ mask_dims=[cell_mask_dim,nucleus_mask_dim,pathogen_mask_dim]
1546
+ mask_chans=[nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim]
1547
+
1548
+ if isinstance(metadata_types, str):
1549
+ metadata_types = [metadata_types, metadata_types, metadata_types]
1550
+ if isinstance(metadata_types, list):
1551
+ if len(metadata_types) < 3:
1552
+ metadata_types = [metadata_types[0], metadata_types[0], metadata_types[0]]
1553
+ print(f'WARNING: metadata_types has fewer than 3 elements; using the first element for all three: {metadata_types}. To avoid this behaviour, set metadata_types to a list with 3 elements, each being col, row or plate.')
1554
+ else:
1555
+ metadata_types = metadata_types
1556
+
1557
+ if isinstance(backgrounds, (int,float)):
1558
+ backgrounds = [backgrounds, backgrounds, backgrounds, backgrounds]
1559
+
1560
+ sns.color_palette("mako", as_cmap=True)
1561
+ print(f'channel:{channel_of_interest} = {target}')
1562
+ overlay_channels = list(channel_dims) # copy so that channel_dims is not modified by remove/reverse below
1563
+ overlay_channels.remove(channel_of_interest)
1564
+ overlay_channels.reverse()
1565
+
1566
+ db_loc = [src+'/measurements/measurements.db']
1567
+ tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1568
+ df, _ = _read_and_merge_data(db_loc,
1569
+ tables,
1570
+ verbose=True,
1571
+ include_multinucleated=include_multinucleated,
1572
+ include_multiinfected=include_multiinfected,
1573
+ include_noninfected=include_noninfected)
1574
+
1575
+ df = annotate_conditions(df,
1576
+ cells=cell_types,
1577
+ cell_loc=cell_plate_metadata,
1578
+ pathogens=pathogen_types,
1579
+ pathogen_loc=pathogen_plate_metadata,
1580
+ treatments=treatments,
1581
+ treatment_loc=treatment_plate_metadata,
1582
+ types=metadata_types)
1583
+
1584
+ df = df.dropna(subset=['condition'])
1585
+ print(f'After dropping non-annotated wells: {len(df)} rows')
1586
+ files = df['file_name'].tolist()
1587
+ files = [item + '.npy' for item in files]
1588
+ random.shuffle(files)
1589
+
1590
+ if plot:
1591
+ plot_settings = {'include_noninfected':include_noninfected,
1592
+ 'include_multiinfected':include_multiinfected,
1593
+ 'include_multinucleated':include_multinucleated,
1594
+ 'remove_background':remove_background,
1595
+ 'filter_min_max':[[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]],
1596
+ 'channel_dims':channel_dims,
1597
+ 'backgrounds':backgrounds,
1598
+ 'cell_mask_dim':mask_dims[0],
1599
+ 'nucleus_mask_dim':mask_dims[1],
1600
+ 'pathogen_mask_dim':mask_dims[2],
1601
+ 'overlay_chans':overlay_channels,
1602
+ 'outline_thickness':3,
1603
+ 'outline_color':'gbr',
1605
+ 'overlay':True,
1606
+ 'normalization_percentiles':[1,99],
1607
+ 'normalize':True,
1608
+ 'print_object_number':True,
1609
+ 'nr':plot_nr,
1610
+ 'figuresize':20,
1611
+ 'cmap':'inferno',
1612
+ 'verbose':False}
1613
+
1614
+ if os.path.exists(os.path.join(src,'merged')):
1615
+ try:
1616
+ plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
1617
+ except Exception as e:
1618
+ print(f'Failed to plot images with outlines, Error: {e}')
1619
+
1620
+ if cell_chann_dim is not None:
1621
+ df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
1622
+ if target_intensity_min is not None:
1623
+ df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_intensity_min]
1624
+ print(f'After channel {channel_of_interest} filtration', len(df))
1625
+ if nucleus_chann_dim is not None:
1626
+ df = _object_filter(df, object_type='nucleus', size_range=nucleus_size_range, intensity_range=nucleus_intensity_range, mask_chans=mask_chans, mask_chan=1)
1627
+ if pathogen_chann_dim is not None:
1628
+ df = _object_filter(df, object_type='pathogen', size_range=pathogen_size_range, intensity_range=pathogen_intensity_range, mask_chans=mask_chans, mask_chan=2)
1629
+
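+ # Recruitment is defined as the ratio of pathogen to cytoplasm mean intensity in the channel of interest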
1630
+ df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1631
+ for chan in channel_dims:
1632
+ df = _calculate_recruitment(df, channel=chan)
1633
+ print(f'calculated recruitment for: {len(df)} rows')
1634
+ df_well = _group_by_well(df)
1635
+ print(f'found: {len(df_well)} wells')
1636
+
1637
+ df_well = df_well[df_well['cells_per_well'] >= cells_per_well]
1638
+ prc_list = df_well['prc'].unique().tolist()
1639
+ df = df[df['prc'].isin(prc_list)]
1640
+ print(f'After cells per well filter: {len(df)} cells in {len(df_well)} wells left with threshold {cells_per_well}')
1641
+
1642
+ if plot_control:
1643
+ _plot_controls(df, mask_chans, channel_of_interest, figuresize=5)
1644
+
1645
+ print(f'PV level: {len(df)} rows')
1646
+ _plot_recruitment(df=df, df_type='by PV', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
1647
+ print(f'well level: {len(df_well)} rows')
1648
+ _plot_recruitment(df=df_well, df_type='by well', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
1649
+ cells,wells = _results_to_csv(src, df, df_well)
1650
+ return [cells,wells]
1651
+
1652
+ @log_function_call
1653
+ def preprocess_generate_masks(src, settings={},advanced_settings={}):
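+ # Preprocess raw images, generate Cellpose masks for the cell/nucleus/pathogen channels defined in the settings,
+ # concatenate the masks with the normalized channel stacks and optionally plot example overlays.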
1654
+
1655
+ from .io import preprocess_img_data, _load_and_concatenate_arrays
1656
+ from .plot import plot_merged, plot_arrays
1657
+ from .utils import _pivot_counts_table
1658
+
1659
+ settings = {**settings, **advanced_settings}
1660
+ settings['src'] = src
1661
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
1662
+ settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
1663
+ os.makedirs(os.path.join(src,'settings'), exist_ok=True)
1664
+ settings_df.to_csv(settings_csv, index=False)
1665
+
1666
+ if settings['timelapse']:
1667
+ settings['randomize'] = False
1668
+
1669
+ if settings['preprocess']:
1670
+ if not settings['masks']:
1671
+ print('WARNING: channels for mask generation must be defined when preprocess = True')
1672
+
1673
+ if isinstance(settings['merge'], bool):
1674
+ settings['merge'] = [settings['merge']]*3
1675
+ if isinstance(settings['save'], bool):
1676
+ settings['save'] = [settings['save']]*3
1677
+
1678
+ if settings['preprocess']:
1679
+ preprocess_img_data(settings)
1680
+
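+ # Generate Cellpose masks for every object type whose channel is defined in the settings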
1681
+ if settings['masks']:
1682
+ mask_src = os.path.join(src, 'norm_channel_stack')
1683
+ if settings['cell_channel'] is not None:
1684
+ generate_cellpose_masks(src=mask_src, settings=settings, object_type='cell')
1685
+
1686
+ if settings['nucleus_channel'] is not None:
1687
+ generate_cellpose_masks(src=mask_src, settings=settings, object_type='nucleus')
1688
+
1689
+ if settings['pathogen_channel'] is not None:
1690
+ generate_cellpose_masks(src=mask_src, settings=settings, object_type='pathogen')
1691
+
1692
+ if os.path.exists(os.path.join(src,'measurements')):
1693
+ _pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
1694
+
1695
+ # Concatenate stacks with masks
1696
+ _load_and_concatenate_arrays(src, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'])
1697
+
1698
+ if settings['plot']:
1699
+ if not settings['timelapse']:
1700
+ plot_dims = len(settings['channels'])
1701
+ overlay_channels = [2,1,0]
1702
+ cell_mask_dim = nucleus_mask_dim = pathogen_mask_dim = None
1703
+ plot_counter = plot_dims
1704
+
1705
+ if settings['cell_channel'] is not None:
1706
+ cell_mask_dim = plot_counter
1707
+ plot_counter += 1
1708
+
1709
+ if settings['nucleus_channel'] is not None:
1710
+ nucleus_mask_dim = plot_counter
1711
+ plot_counter += 1
1712
+
1713
+ if settings['pathogen_channel'] is not None:
1714
+ pathogen_mask_dim = plot_counter
1715
+
1716
+ overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
1717
+ overlay_channels = [element for element in overlay_channels if element is not None]
1718
+
1719
+ plot_settings = {'include_noninfected':True,
1720
+ 'include_multiinfected':True,
1721
+ 'include_multinucleated':True,
1722
+ 'remove_background':False,
1723
+ 'filter_min_max':None,
1724
+ 'channel_dims':settings['channels'],
1725
+ 'backgrounds':[100,100,100,100],
1726
+ 'cell_mask_dim':cell_mask_dim,
1727
+ 'nucleus_mask_dim':nucleus_mask_dim,
1728
+ 'pathogen_mask_dim':pathogen_mask_dim,
1730
+ 'outline_thickness':3,
1731
+ 'outline_color':'gbr',
1732
+ 'overlay_chans':overlay_channels,
1733
+ 'overlay':True,
1734
+ 'normalization_percentiles':[1,99],
1735
+ 'normalize':True,
1736
+ 'print_object_number':True,
1737
+ 'nr':settings['examples_to_plot'],
1738
+ 'figuresize':20,
1739
+ 'cmap':'inferno',
1740
+ 'verbose':False}
1741
+ try:
1742
+ fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
1743
+ except Exception as e:
1744
+ print(f'Failed to plot image mask overlay. Error: {e}')
1745
+ else:
1746
+ plot_arrays(src=os.path.join(src,'merged'), figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99)
1747
+
1748
+ torch.cuda.empty_cache()
1749
+ gc.collect()
1750
+ return
1751
+
1752
+ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', verbose=False, plot=False, save=False, custom_model=None, signal_thresholds=1000, normalize=True, resize=False, target_height=None, target_width=None, rescale=True, resample=True, net_avg=False, invert=False, circular=False, percentiles=None, overlay=True, grayscale=False):
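+ # Segment the .tif images in src with a pretrained or custom Cellpose model and optionally plot/save the masks to dst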
1753
+
1754
+ from .plot import print_mask_and_flows
1755
+ from .utils import get_files_from_dir, resize_images_and_labels
1756
+ from .io import _load_normalized_images_and_labels, _load_images_and_labels
1757
+
1758
+ if not torch.cuda.is_available():
1759
+ print(f'Torch CUDA is not available, using CPU')
1760
+
1761
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1762
+
1763
+ if custom_model is None:
1764
+ if model_name =='cyto':
1765
+ model = cp_models.CellposeModel(gpu=True, model_type=model_name, net_avg=False, diam_mean=diameter, pretrained_model=None)
1766
+ else:
1767
+ model = cp_models.CellposeModel(gpu=True, model_type=model_name)
1768
+
1769
+ if custom_model is not None:
1770
+ model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) # diameter is taken from the function argument
1771
+ print(f'loaded custom model:{custom_model}')
1772
+
1773
+ chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
1774
+
1775
+ if grayscale:
1776
+ chans=[0, 0]
1777
+
1778
+ print(f'Using channels: {chans} for model of type {model_name}')
1779
+
1780
+ if verbose:
1781
+ print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
1782
+
1783
+ all_image_files = get_files_from_dir(src, file_extension="*.tif")
1784
+ random.shuffle(all_image_files)
1785
+
1786
+ time_ls = []
1787
+ for i in range(0, len(all_image_files), batch_size):
1788
+ image_files = all_image_files[i:i+batch_size]
1789
+ if normalize:
1790
+ images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
1791
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
1792
+ orig_dims = [(image.shape[0], image.shape[1]) for image in images]
1793
+ else:
1794
+ images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
1795
+ images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
1796
+ orig_dims = [(image.shape[0], image.shape[1]) for image in images]
1797
+ if resize:
1798
+ images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
1799
+
1800
+ for file_index, stack in enumerate(images):
1801
+ start = time.time()
1802
+ output = model.eval(x=stack,
1803
+ normalize=False,
1804
+ channels=chans,
1805
+ channel_axis=3,
1806
+ diameter=diameter,
1807
+ flow_threshold=flow_threshold,
1808
+ cellprob_threshold=cellprob_threshold,
1809
+ rescale=rescale,
1810
+ resample=resample,
1811
+ net_avg=net_avg,
1812
+ progress=False)
1813
+
1814
+ if len(output) == 4:
1815
+ mask, flows, _, _ = output
1816
+ elif len(output) == 3:
1817
+ mask, flows, _ = output
1818
+ else:
1819
+ raise ValueError("Unexpected number of return values from model.eval()")
1820
+
1821
+ if resize:
1822
+ dims = orig_dims[file_index]
1823
+ mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
1824
+
1825
+ stop = time.time()
1826
+ duration = (stop - start)
1827
+ time_ls.append(duration)
1828
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1829
+ print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
1830
+ if plot:
1831
+ if resize:
1832
+ stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
1833
+ print_mask_and_flows(stack, mask, flows, overlay=overlay)
1834
+ if save:
1835
+ output_filename = os.path.join(dst, image_names[file_index])
1836
+ cv2.imwrite(output_filename, mask)
1837
+ return
1838
+
1839
+ @log_function_call
1840
+ def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
1841
+ """
1842
+ Identify masks from the source images.
1843
+
1844
+ Args:
1845
+ src (str): Path to the source images.
1846
+ object_type (str): Type of object to identify.
1847
+ model_name (str): Name of the model to use for identification.
1848
+ batch_size (int): Number of images to process in each batch.
1849
+ channels (list): List of channel names.
1850
+ diameter (float): Diameter of the objects to identify.
1851
+ minimum_size (int): Minimum size of objects to keep.
1852
+ maximum_size (int): Maximum size of objects to keep.
1853
+ flow_threshold (int, optional): Threshold for flow detection. Defaults to 30.
1854
+ cellprob_threshold (int, optional): Threshold for cell probability. Defaults to 1.
1855
+ figuresize (int, optional): Size of the figure. Defaults to 25.
1856
+ cmap (str, optional): Colormap for plotting. Defaults to 'inferno'.
1857
+ refine_masks (bool, optional): Flag indicating whether to refine masks. Defaults to True.
1858
+ filter_size (bool, optional): Flag indicating whether to filter based on size. Defaults to True.
1859
+ filter_dimm (bool, optional): Flag indicating whether to filter based on intensity. Defaults to True.
1860
+ remove_border_objects (bool, optional): Flag indicating whether to remove border objects. Defaults to False.
1861
+ verbose (bool, optional): Flag indicating whether to display verbose output. Defaults to False.
1862
+ plot (bool, optional): Flag indicating whether to plot the masks. Defaults to False.
1863
+ merge (bool, optional): Flag indicating whether to merge adjacent objects. Defaults to False.
1864
+ save (bool, optional): Flag indicating whether to save the masks. Defaults to True.
1865
+ start_at (int, optional): Index to start processing from. Defaults to 0.
1866
+ file_type (str, optional): File type for saving the masks. Defaults to '.npz'.
1867
+ net_avg (bool, optional): Flag indicating whether to use network averaging. Defaults to True.
1868
+ resample (bool, optional): Flag indicating whether to resample the images. Defaults to True.
1869
+ timelapse (bool, optional): Flag indicating whether to generate a timelapse. Defaults to False.
1870
+ timelapse_displacement (float, optional): Displacement threshold for timelapse. Defaults to None.
1871
+ timelapse_frame_limits (tuple, optional): Frame limits for timelapse. Defaults to None.
1872
+ timelapse_memory (int, optional): Memory for timelapse. Defaults to 3.
1873
+ timelapse_remove_transient (bool, optional): Flag indicating whether to remove transient objects in timelapse. Defaults to False.
1874
+ timelapse_mode (str, optional): Mode for timelapse. Defaults to 'btrack'.
1875
+ timelapse_objects (str, optional): Objects to track in timelapse. Defaults to 'cell'.
1876
+
1877
+ Returns:
1878
+ None
1879
+ """
1880
+
1881
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size
1882
+ from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
1883
+ from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
1884
+ from .plot import plot_masks
1885
+
1886
+ # Note: add logic to handle batches of size 1, as these will break the code; all batches must contain > 2 images
1887
+ gc.collect()
1888
+ #print('========== generating masks ==========')
1889
+
1890
+ if not torch.cuda.is_available():
1891
+ print(f'Torch CUDA is not available, using CPU')
1892
+
1893
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1894
+ model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
1895
+ if file_type == '.npz':
1896
+ paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
1897
+ else:
1898
+ paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.png')]
1899
+ if timelapse:
1900
+ print('timelapse is only compatible with .npz files')
1901
+ return
1902
+
1903
+ chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0]
1904
+
1905
+ if verbose:
1906
+ print(f'source: {src}')
1907
+ print()
1908
+ print(f'Settings: object_type: {object_type}, minimum_size: {minimum_size}, maximum_size: {maximum_size}, figuresize: {figuresize}, cmap: {cmap}, net_avg: {net_avg}, resample: {resample}')
1909
+ print()
1910
+ print(f'Cellpose settings: Model: {model_name}, batch_size: {batch_size}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
1911
+ print()
1912
+ print(f'Bool Settings: verbose:{verbose}, plot:{plot}, merge:{merge}, save:{save}, start_at:{start_at}, file_type:{file_type}, timelapse:{timelapse}')
1913
+ print()
1914
+
1915
+ count_loc = os.path.dirname(src)+'/measurements/measurements.db'
1916
+ os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
1917
+ _create_database(count_loc)
1918
+
1919
+ average_sizes = []
1920
+ time_ls = []
1921
+ moving_avg_q1 = 0
1922
+ moving_avg_q3 = 0
1923
+ moving_count = 0
1924
+ for file_index, path in enumerate(paths):
1925
+
1926
+ name = os.path.basename(path)
1927
+ name, ext = os.path.splitext(name)
1928
+ if file_type == '.npz':
1929
+ if start_at:
1930
+ print(f'starting at file index:{start_at}')
1931
+ if file_index < start_at:
1932
+ continue
1933
+ output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
1934
+ os.makedirs(output_folder, exist_ok=True)
1935
+ overall_average_size = 0
1936
+ with np.load(path) as data:
1937
+ stack = data['data']
1938
+ filenames = data['filenames']
1939
+ if timelapse:
1940
+ if len(stack) != batch_size:
1941
+ print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
1942
+ batch_size = len(stack)
1943
+ if isinstance(timelapse_frame_limits, list):
1944
+ if len(timelapse_frame_limits) >= 2:
1945
+ stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
1946
+ filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
1947
+ batch_size = len(stack)
1948
+ print(f'Cut batch at indices: {timelapse_frame_limits}, new batch_size: {batch_size}')
1949
+
1950
+ for i in range(0, stack.shape[0], batch_size):
1951
+ mask_stack = []
1952
+ start = time.time()
1953
+
1954
+ if stack.shape[3] == 1:
1955
+ batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
1956
+ else:
1957
+ batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
1958
+
1959
+ batch_filenames = filenames[i: i+batch_size].tolist()
1960
+
1961
+ if not plot:
1962
+ batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
1963
+ if batch.size == 0:
1964
+ print(f'Processing {file_index}/{len(paths)}: images in npz {batch.shape[0]}', end='\r', flush=True)
1965
+ continue
1966
+ if batch.max() > 1:
1967
+ batch = batch / batch.max()
1968
+
1969
+ if timelapse:
1970
+ stitch_threshold=100.0
1971
+ movie_path = os.path.join(os.path.dirname(src), 'movies')
1972
+ os.makedirs(movie_path, exist_ok=True)
1973
+ save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
1974
+ _npz_to_movie(batch, batch_filenames, save_path, fps=2)
1975
+ else:
1976
+ stitch_threshold=0.0
1977
+
1978
+ cellpose_batch_size = _get_cellpose_batch_size()
1979
+
1980
+ #model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True) # disabled: reassigning model here would replace the Cellpose segmentation model used below
1981
+
1982
+ masks, flows, _, _ = model.eval(x=batch,
1983
+ batch_size=cellpose_batch_size,
1984
+ normalize=False,
1985
+ channels=chans,
1986
+ channel_axis=3,
1987
+ diameter=diameter,
1988
+ flow_threshold=flow_threshold,
1989
+ cellprob_threshold=cellprob_threshold,
1990
+ rescale=None,
1991
+ resample=resample,
1992
+ #net_avg=net_avg,
1993
+ stitch_threshold=stitch_threshold,
1994
+ progress=None)
1995
+ print('Masks shape',masks.shape)
1996
+ if timelapse:
1997
+ _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
1998
+ if object_type in timelapse_objects:
1999
+ if timelapse_mode == 'btrack':
2000
+ if timelapse_displacement is not None:
2001
+ radius = timelapse_displacement
2002
+ else:
2003
+ radius = 100
2004
+
2005
+ workers = os.cpu_count()-2
2006
+ if workers < 1:
2007
+ workers = 1
2008
+
2009
+ mask_stack = _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, masks_3D=masks, mode=timelapse_mode, timelapse_remove_transient=timelapse_remove_transient, radius=radius, workers=workers)
2010
+ if timelapse_mode == 'trackpy':
2011
+ mask_stack = _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse_displacement, timelapse_memory, timelapse_remove_transient, plot, save, timelapse_mode)
2012
+
2013
+ else:
2014
+ mask_stack = _masks_to_masks_stack(masks)
2015
+
2016
+ else:
2017
+ _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
2018
+ mask_stack = _filter_cp_masks(masks, flows, refine_masks, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize)
2019
+ _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2020
+
2021
+ if not np.any(mask_stack):
2022
+ average_obj_size = 0
2023
+ else:
2024
+ average_obj_size = _get_avg_object_size(mask_stack)
2025
+
2026
+ average_sizes.append(average_obj_size)
2027
+ overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
2028
+
2029
+ stop = time.time()
2030
+ duration = (stop - start)
2031
+ time_ls.append(duration)
2032
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2033
+ time_in_min = average_time/60
2034
+ time_per_mask = average_time/batch_size
2035
+ print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2036
+ if not timelapse:
2037
+ if plot:
2038
+ plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap=cmap, nr=batch_size, file_type='.npz')
2039
+ if save:
2040
+ if file_type == '.npz':
2041
+ for mask_index, mask in enumerate(mask_stack):
2042
+ output_filename = os.path.join(output_folder, batch_filenames[mask_index])
2043
+ np.save(output_filename, mask)
2044
+ mask_stack = []
2045
+ batch_filenames = []
2046
+ gc.collect()
2047
+ return
2048
+
2049
+ @log_function_call
2050
+ def generate_cellpose_masks(src, settings, object_type):
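+ # Generate Cellpose masks for one object type from the .npz channel stacks in src, with optional timelapse
+ # tracking (btrack/trackpy) and size/intensity filtering, and save object counts to the measurements database.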
2051
+
2052
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels
2053
+ from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
2054
+ from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
2055
+ from .plot import plot_masks
2056
+
2057
+ gc.collect()
2058
+ if not torch.cuda.is_available():
2059
+ print(f'Torch CUDA is not available, using CPU')
2060
+
2061
+ figuresize=25
2062
+ timelapse = settings['timelapse']
2063
+
2064
+ if timelapse:
2065
+ timelapse_displacement = settings['timelapse_displacement']
2066
+ timelapse_frame_limits = settings['timelapse_frame_limits']
2067
+ timelapse_memory = settings['timelapse_memory']
2068
+ timelapse_remove_transient = settings['timelapse_remove_transient']
2069
+ timelapse_mode = settings['timelapse_mode']
2070
+ timelapse_objects = settings['timelapse_objects']
2071
+
2072
+ batch_size = settings['batch_size']
2073
+ cellprob_threshold = settings[f'{object_type}_CP_prob']
2074
+ flow_threshold = 30
2075
+
2076
+ object_settings = _get_object_settings(object_type, settings)
2077
+ model_name = object_settings['model_name']
2078
+
2079
+ cellpose_channels = _get_cellpose_channels(settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
2080
+ channels = cellpose_channels[object_type]
2081
+ cellpose_batch_size = _get_cellpose_batch_size()
2082
+
2083
+ device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2084
+ model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
2085
+ #dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
2086
+
2087
+ chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0]
2088
+
2089
+ paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
2090
+
2091
+ count_loc = os.path.dirname(src)+'/measurements/measurements.db'
2092
+ os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
2093
+ _create_database(count_loc)
2094
+
2095
+ average_sizes = []
2096
+ time_ls = []
2097
+ moving_avg_q1 = 0
2098
+ moving_avg_q3 = 0
2099
+ moving_count = 0
2100
+
2101
+ for file_index, path in enumerate(paths):
2102
+ name = os.path.basename(path)
2103
+ name, ext = os.path.splitext(name)
2104
+ output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
2105
+ os.makedirs(output_folder, exist_ok=True)
2106
+ overall_average_size = 0
2107
+ with np.load(path) as data:
2108
+ stack = data['data']
2109
+ filenames = data['filenames']
2110
+ if settings['timelapse']:
2111
+ if len(stack) != batch_size:
2112
+ print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
2113
+ settings['batch_size'] = len(stack)
2114
+ batch_size = len(stack)
2115
+ if isinstance(timelapse_frame_limits, list):
2116
+ if len(timelapse_frame_limits) >= 2:
2117
+ stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
2118
+ filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
2119
+ batch_size = len(stack)
2120
+ print(f'Cut batch at indices: {timelapse_frame_limits}, new batch_size: {batch_size}')
2121
+
2122
+ for i in range(0, stack.shape[0], batch_size):
2123
+ mask_stack = []
2124
+ start = time.time()
2125
+
2126
+ if stack.shape[3] == 1:
2127
+ batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
2128
+ else:
2129
+ batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
2130
+
2131
+ batch_filenames = filenames[i: i+batch_size].tolist()
2132
+
2133
+ if not settings['plot']:
2134
+ batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
2135
+ if batch.size == 0:
2136
+ print(f'Processing {file_index}/{len(paths)}: images in npz {batch.shape[0]}', end='\r', flush=True)
2137
+ continue
2138
+ if batch.max() > 1:
2139
+ batch = batch / batch.max()
2140
+
2141
+ if timelapse:
2142
+ stitch_threshold=100.0
2143
+ movie_path = os.path.join(os.path.dirname(src), 'movies')
2144
+ os.makedirs(movie_path, exist_ok=True)
2145
+ save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
2146
+ _npz_to_movie(batch, batch_filenames, save_path, fps=2)
2147
+ else:
2148
+ stitch_threshold=0.0
2149
+ #print(batch.shape)
2150
+ #batch, _, _, _ = dn.eval(x=batch, channels=chans, diameter=object_settings['diameter'])
2151
+ #batch = np.stack((batch, batch), axis=-1)
2152
+ #print(f'object: {object_type} chans : {chans} channels : {channels} model: {model_name}')
2153
+ masks, flows, _, _ = model.eval(x=batch,
2154
+ batch_size=cellpose_batch_size,
2155
+ normalize=False,
2156
+ channels=chans,
2157
+ channel_axis=3,
2158
+ diameter=object_settings['diameter'],
2159
+ flow_threshold=flow_threshold,
2160
+ cellprob_threshold=cellprob_threshold,
2161
+ rescale=None,
2162
+ resample=object_settings['resample'],
2163
+ stitch_threshold=stitch_threshold)
2164
+ #progress=None)
2165
+
2166
+ if timelapse:
2167
+ _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
2168
+ if object_type in timelapse_objects:
2169
+ if timelapse_mode == 'btrack':
2170
+ if timelapse_displacement is not None:
2171
+ radius = timelapse_displacement
2172
+ else:
2173
+ radius = 100
2174
+
2175
+ workers = os.cpu_count()-2
2176
+ if workers < 1:
2177
+ workers = 1
2178
+
2179
+ mask_stack = _btrack_track_cells(src=src,
2180
+ name=name,
2181
+ batch_filenames=batch_filenames,
2182
+ object_type=object_type,
2183
+ plot=settings['plot'],
2184
+ save=settings['save'],
2185
+ masks_3D=masks,
2186
+ mode=timelapse_mode,
2187
+ timelapse_remove_transient=timelapse_remove_transient,
2188
+ radius=radius,
2189
+ workers=workers)
2190
+ if timelapse_mode == 'trackpy':
2191
+ mask_stack = _trackpy_track_cells(src=src,
2192
+ name=name,
2193
+ batch_filenames=batch_filenames,
2194
+ object_type=object_type,
2195
+ masks_3D=masks,
2196
+ timelapse_displacement=timelapse_displacement,
2197
+ timelapse_memory=timelapse_memory,
2198
+ timelapse_remove_transient=timelapse_remove_transient,
2199
+ plot=settings['plot'],
2200
+ save=settings['save'],
2201
+ timelapse_mode=timelapse_mode)
2202
+ else:
2203
+ mask_stack = _masks_to_masks_stack(masks)
2204
+
2205
+ else:
2206
+ _save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
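+ # Filter masks on size, border contact and intensity before saving the post-filtration object counts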
2207
+ mask_stack = _filter_cp_masks(masks=masks,
2208
+ flows=flows,
2209
+ filter_size=object_settings['filter_size'],
2210
+ minimum_size=object_settings['minimum_size'],
2211
+ maximum_size=object_settings['maximum_size'],
2212
+ remove_border_objects=object_settings['remove_border_objects'],
2213
+ merge=False,
2214
+ filter_dimm=object_settings['filter_dimm'],
2215
+ batch=batch,
2216
+ moving_avg_q1=moving_avg_q1,
2217
+ moving_avg_q3=moving_avg_q3,
2218
+ moving_count=moving_count,
2219
+ plot=settings['plot'],
2220
+ figuresize=figuresize)
2221
+
2222
+ _save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
2223
+
2224
+ if not np.any(mask_stack):
2225
+ average_obj_size = 0
2226
+ else:
2227
+ average_obj_size = _get_avg_object_size(mask_stack)
2228
+
2229
+ average_sizes.append(average_obj_size)
2230
+ overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
2231
+
2232
+ stop = time.time()
2233
+ duration = (stop - start)
2234
+ time_ls.append(duration)
2235
+ average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
2236
+ time_in_min = average_time/60
2237
+ time_per_mask = average_time/batch_size
2238
+ print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
2239
+ if not timelapse:
2240
+ if settings['plot']:
2241
+ plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
2242
+ if settings['save']:
2243
+ for mask_index, mask in enumerate(mask_stack):
2244
+ output_filename = os.path.join(output_folder, batch_filenames[mask_index])
2245
+ np.save(output_filename, mask)
2246
+ mask_stack = []
2247
+ batch_filenames = []
2248
+ gc.collect()
2249
+ torch.cuda.empty_cache()
2250
+ return