spacr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +37 -0
- spacr/__main__.py +15 -0
- spacr/annotate_app.py +495 -0
- spacr/cli.py +203 -0
- spacr/core.py +2250 -0
- spacr/gui_mask_app.py +247 -0
- spacr/gui_measure_app.py +214 -0
- spacr/gui_utils.py +488 -0
- spacr/io.py +2271 -0
- spacr/logger.py +20 -0
- spacr/mask_app.py +818 -0
- spacr/measure.py +1014 -0
- spacr/old_code.py +104 -0
- spacr/plot.py +1273 -0
- spacr/sim.py +1187 -0
- spacr/timelapse.py +576 -0
- spacr/train.py +494 -0
- spacr/umap.py +689 -0
- spacr/utils.py +2726 -0
- spacr/version.py +19 -0
- spacr-0.0.1.dist-info/LICENSE +21 -0
- spacr-0.0.1.dist-info/METADATA +64 -0
- spacr-0.0.1.dist-info/RECORD +26 -0
- spacr-0.0.1.dist-info/WHEEL +5 -0
- spacr-0.0.1.dist-info/entry_points.txt +5 -0
- spacr-0.0.1.dist-info/top_level.txt +1 -0
spacr/core.py
ADDED
@@ -0,0 +1,2250 @@
import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime

# image and array processing
import numpy as np
import pandas as pd

import cellpose
from cellpose import models as cp_models
from cellpose import denoise

import statsmodels.formula.api as smf
import statsmodels.api as sm
from functools import reduce
from IPython.display import display
from multiprocessing import Pool, cpu_count, Value, Lock

import seaborn as sns
import matplotlib.pyplot as plt
from skimage.measure import regionprops, label
import skimage.measure as measure
from skimage.transform import resize as resizescikit
from sklearn.model_selection import train_test_split
from collections import defaultdict
import multiprocessing
from torch.utils.data import DataLoader, random_split
import matplotlib
matplotlib.use('Agg')

import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
from sklearn.ensemble import IsolationForest

from .logger import log_function_call

#from .io import TarImageDataset, NoClassDataset, MyDataset, read_db, _copy_missclassified, read_mask, load_normalized_images_and_labels, load_images_and_labels
#from .plot import plot_merged, plot_arrays, _plot_controls, _plot_recruitment, _imshow, _plot_histograms_and_stats, _reg_v_plot, visualize_masks, plot_comparison_results
#from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient, _object_filter
#from .utils import resize_images_and_labels, generate_fraction_map, MLR, fishers_odds, lasso_reg, model_metrics, _map_wells_png, check_multicollinearity, init_globals, add_images_to_tar
#from .utils import get_paths_from_db, pick_best_model, test_model_performance, evaluate_model_performance, compute_irm_penalty
#from .utils import _pivot_counts_table, _generate_masks, _get_cellpose_channels, annotate_conditions, _calculate_recruitment, calculate_loss, _group_by_well, choose_model

@log_function_call
def analyze_plaques(folder):
    summary_data = []
    details_data = []

    for filename in os.listdir(folder):
        filepath = os.path.join(folder, filename)
        if os.path.isfile(filepath):
            # Assuming each file is a NumPy array file (.npy) containing a 16-bit labeled image
            image = np.load(filepath)

            labeled_image = label(image)
            regions = regionprops(labeled_image)

            object_count = len(regions)
            sizes = [region.area for region in regions]
            average_size = np.mean(sizes) if sizes else 0

            summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
            for size in sizes:
                details_data.append({'file': filename, 'plaque_size': size})

    # Convert lists to pandas DataFrames
    summary_df = pd.DataFrame(summary_data)
    details_df = pd.DataFrame(details_data)

    # Save DataFrames to a SQLite database
    db_name = 'plaques_analysis.db'
    conn = sqlite3.connect(db_name)

    summary_df.to_sql('summary', conn, if_exists='replace', index=False)
    details_df.to_sql('details', conn, if_exists='replace', index=False)

    conn.close()

    print(f"Analysis completed and saved to database '{db_name}'.")

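# Illustrative usage sketch (not part of the original module). The folder path is a
# placeholder; analyze_plaques expects it to contain .npy files of labeled images, as
# noted in the function above, and writes 'plaques_analysis.db' to the working directory.
def _example_analyze_plaques():
    analyze_plaques('/path/to/labeled_plaque_masks')
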
@log_function_call
def compare_masks(dir1, dir2, dir3, verbose=False):

    from .io import _read_mask
    from .plot import visualize_masks, plot_comparison_results
    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient

    filenames = os.listdir(dir1)
    results = []
    cond_1 = os.path.basename(dir1)
    cond_2 = os.path.basename(dir2)
    cond_3 = os.path.basename(dir3)
    for index, filename in enumerate(filenames):
        print(f'Processing image:{index+1}', end='\r', flush=True)
        path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
        if os.path.exists(path2) and os.path.exists(path3):

            mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
            boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)

            true_masks, pred_masks = [mask1], [mask2, mask3]  # Assuming mask1 is the ground truth for simplicity
            true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
            average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
            ap_scores = [average_precision_0, average_precision_1]

            if verbose:
                unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
                print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
                visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")

            boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)

            if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
               (np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
               (np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
                continue

            if verbose:
                unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
                print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
                visualize_masks(mask1, mask2, mask3, title=filename)

            jaccard12 = jaccard_index(mask1, mask2)
            dice12 = dice_coefficient(mask1, mask2)
            jaccard13 = jaccard_index(mask1, mask3)
            dice13 = dice_coefficient(mask1, mask3)
            jaccard23 = jaccard_index(mask2, mask3)
            dice23 = dice_coefficient(mask2, mask3)

            results.append({
                'filename': filename,
                f'jaccard_{cond_1}_{cond_2}': jaccard12,
                f'dice_{cond_1}_{cond_2}': dice12,
                f'jaccard_{cond_1}_{cond_3}': jaccard13,
                f'dice_{cond_1}_{cond_3}': dice13,
                f'jaccard_{cond_2}_{cond_3}': jaccard23,
                f'dice_{cond_2}_{cond_3}': dice23,
                f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
                f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
                f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
                f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
                f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
            })
        else:
            print(f'Cannot find {path1} or {path2} or {path3}')
    fig = plot_comparison_results(results)
    return results, fig

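# Illustrative usage sketch (not part of the original module). The three folders are
# placeholders; they are expected to contain identically named mask files, with dir1
# acting as the reference condition per the assumption noted in compare_masks above.
def _example_compare_masks():
    results, fig = compare_masks('/masks/ground_truth', '/masks/model_a', '/masks/model_b', verbose=False)
    return results, fig
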
def generate_cp_masks(settings):

    src = settings['src']
    model_name = settings['model_name']
    channels = settings['channels']
    diameter = settings['diameter']
    regex = '.tif'
    #flow_threshold = 30
    cellprob_threshold = settings['cellprob_threshold']
    figuresize = 25
    cmap = 'inferno'
    verbose = settings['verbose']
    plot = settings['plot']
    save = settings['save']
    custom_model = settings['custom_model']
    signal_thresholds = 1000
    normalize = settings['normalize']
    resize = settings['resize']
    target_height = settings['width_height'][1]
    target_width = settings['width_height'][0]
    rescale = settings['rescale']
    resample = settings['resample']
    net_avg = settings['net_avg']
    invert = settings['invert']
    circular = settings['circular']
    percentiles = settings['percentiles']
    overlay = settings['overlay']
    grayscale = settings['grayscale']
    flow_threshold = settings['flow_threshold']
    batch_size = settings['batch_size']

    dst = os.path.join(src, 'masks')
    os.makedirs(dst, exist_ok=True)

    identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)

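# Illustrative settings sketch (not part of the original module). The keys are exactly
# the ones generate_cp_masks reads above; every value is a placeholder assumption, not a
# documented or recommended default.
def _example_generate_cp_masks():
    settings = {
        'src': '/path/to/tif_images',
        'model_name': 'cyto',
        'channels': [0, 0],
        'diameter': 30,
        'cellprob_threshold': 0,
        'flow_threshold': 0.4,
        'verbose': False,
        'plot': False,
        'save': True,
        'custom_model': None,
        'normalize': True,
        'resize': False,
        'width_height': [1000, 1000],
        'rescale': False,
        'resample': False,
        'net_avg': False,
        'invert': False,
        'circular': False,
        'percentiles': None,
        'overlay': False,
        'grayscale': True,
        'batch_size': 8,
    }
    generate_cp_masks(settings)
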
@log_function_call
def train_cellpose(settings):

    from .io import _load_normalized_images_and_labels, _load_images_and_labels
    from .utils import resize_images_and_labels

    img_src = settings['img_src']
    mask_src = settings['mask_src']
    secondary_image_dir = None
    model_name = settings['model_name']
    model_type = settings['model_type']
    learning_rate = settings['learning_rate']
    weight_decay = settings['weight_decay']
    batch_size = settings['batch_size']
    n_epochs = settings['n_epochs']
    verbose = settings['verbose']
    signal_thresholds = settings['signal_thresholds']
    channels = settings['channels']
    from_scratch = settings['from_scratch']
    diameter = settings['diameter']
    resize = settings['resize']
    rescale = settings['rescale']
    normalize = settings['normalize']
    target_height = settings['width_height'][1]
    target_width = settings['width_height'][0]
    circular = settings['circular']
    invert = settings['invert']
    percentiles = settings['percentiles']
    grayscale = settings['grayscale']

    print(settings)

    if from_scratch:
        model_name = f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
    else:
        model_name = f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'

    model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
    os.makedirs(os.path.dirname(model_save_path), exist_ok=True)

    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_csv = os.path.join(model_save_path, f'{model_name}_settings.csv')
    settings_df.to_csv(settings_csv, index=False)

    if model_type == 'cyto':
        if not from_scratch:
            model = cp_models.CellposeModel(gpu=True, model_type=model_type)
        else:
            model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
    if model_type != 'cyto':
        model = cp_models.CellposeModel(gpu=True, model_type=model_type)

    if normalize:
        images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_dir=img_src, label_dir=mask_src, secondary_image_dir=secondary_image_dir, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
    else:
        images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]

    if resize:
        images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)

    if model_type == 'cyto':
        cp_channels = [0, 1]
    if model_type == 'cyto2':
        cp_channels = [0, 2]
    if model_type == 'nucleus':
        cp_channels = [0, 0]
    if grayscale:
        cp_channels = [0, 0]
        images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]

    masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]

    print(f'image shape: {images[0].shape}, image dtype: {images[0].dtype}, mask shape: {masks[0].shape}, mask dtype: {masks[0].dtype}')
    save_every = int(n_epochs/10)
    print('cellpose image input dtype', images[0].dtype)
    print('cellpose mask input dtype', masks[0].dtype)
    # Train the model
    model.train(train_data=images,  # (list of arrays (2D or 3D)) – images for training
                train_labels=masks,  # (list of arrays (2D or 3D)) – labels for train_data, where 0=no masks; 1,2,…=mask labels; can include flows as additional images
                train_files=image_names,  # (list of strings) – file names for images in train_data (to save flows for future runs)
                channels=cp_channels,  # (list of ints (default, None)) – channels to use for training
                normalize=False,  # (bool (default, True)) – normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel
                save_path=model_save_path,  # (string (default, None)) – where to save trained model, if None it is not saved
                save_every=save_every,  # (int (default, 100)) – save network every [save_every] epochs
                learning_rate=learning_rate,  # (float or list/np.ndarray (default, 0.2)) – learning rate for training, if list, must be same length as n_epochs
                n_epochs=n_epochs,  # (int (default, 500)) – how many times to go through whole training set during training
                weight_decay=weight_decay,  # (float (default, 0.00001))
                SGD=True,  # (bool (default, True)) – use SGD as optimization instead of RAdam
                batch_size=batch_size,  # (int (optional, default 8)) – number of 224x224 patches to run simultaneously on the GPU (can make smaller or bigger depending on GPU memory usage)
                nimg_per_epoch=None,  # (int (optional, default None)) – minimum number of images to train on per epoch, with a small training set (< 8 images) it may help to set to 8
                rescale=rescale,  # (bool (default, True)) – whether or not to rescale images to diam_mean during training, if True it assumes you will fit a size model after training or resize your images accordingly, if False it will try to train the model to be scale-invariant (works worse)
                min_train_masks=1,  # (int (default, 5)) – minimum number of masks an image must have to use in training set
                model_name=model_name)  # (str (default, None)) – name of network, otherwise saved with name as params + training start time

    return print(f"Model saved at: {model_save_path}/{model_name}")

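# Illustrative settings sketch (not part of the original module). The keys are the ones
# train_cellpose reads above; the values shown are placeholder assumptions only.
def _example_train_cellpose():
    settings = {
        'img_src': '/path/to/training_images',
        'mask_src': '/path/to/training_masks',
        'model_name': 'my_model',
        'model_type': 'cyto',
        'learning_rate': 0.2,
        'weight_decay': 1e-05,
        'batch_size': 8,
        'n_epochs': 500,
        'verbose': False,
        'signal_thresholds': 1000,
        'channels': [0, 0],
        'from_scratch': False,
        'diameter': 30,
        'resize': False,
        'rescale': True,
        'normalize': True,
        'width_height': [1000, 1000],
        'circular': False,
        'invert': False,
        'percentiles': None,
        'grayscale': True,
    }
    train_cellpose(settings)
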
@log_function_call
def analyze_data_reg(sequencing_loc, dv_loc, agg_type='mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0, remove_outlier_genes=False, refine_model=False, by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):

    from .plot import _reg_v_plot
    from .utils import generate_fraction_map, MLR, fishers_odds, lasso_reg

    def qstring_to_float(qstr):
        number = int(qstr[1:])  # Remove the "q" and convert the rest to an integer
        return number / 100.0

    columns_list = ['c1', 'c2', 'c3']
    plate_list = ['p1', 'p3', 'p4']

    dv_df = pd.read_csv(dv_loc)  #, index_col='prc')

    if agg_type.startswith('q'):
        val = qstring_to_float(agg_type)
        agg_type = lambda x: x.quantile(val)

    # Aggregating for mean prediction, total count and count of values > 0.95
    dv_df = dv_df.groupby('prc').agg(
        pred=(dv_col, agg_type),
        count_prc=('prc', 'size'),
        mean_pathogen_area=('pathogen_area', 'mean')
    )

    dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
    sequencing_df = pd.read_csv(sequencing_loc)

    reads_df, stats_dict = process_reads(df=sequencing_df,
                                         min_reads=min_reads,
                                         min_wells=min_wells,
                                         max_wells=max_wells,
                                         gene_column='gene',
                                         remove_outliers=remove_outlier_genes)

    reads_df['value'] = reads_df['count']/reads_df['well_read_sum']
    reads_df['gene_grna'] = reads_df['gene']+'_'+reads_df['grna']

    display(reads_df)

    df_long = reads_df

    df_long = df_long[df_long['value'] > min_frequency]  # removes gRNAs under a certain proportion
    #df_long = df_long[df_long['value'] < 1.0]  # removes gRNAs in wells with only one gRNA

    # Extract gene and grna info from gene_grna column
    df_long["gene"] = df_long["grna"].str.split("_").str[1]
    df_long["grna"] = df_long["grna"].str.split("_").str[2]

    agg_df = df_long.groupby('prc')['count'].sum().reset_index()
    agg_df = agg_df.rename(columns={'count': 'count_sum'})
    df_long = pd.merge(df_long, agg_df, on='prc', how='left')
    df_long['value'] = df_long['count']/df_long['count_sum']

    merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
    merged_df = merged_df[merged_df['value'] > 0]
    merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
    merged_df['row'] = merged_df['prc'].str.split('_').str[1]
    merged_df['column'] = merged_df['prc'].str.split('_').str[2]

    merged_df = merged_df[~merged_df['column'].isin(columns_list)]
    merged_df = merged_df[merged_df['plate'].isin(plate_list)]

    if transform == 'log':
        merged_df['pred'] = np.log(merged_df['pred'] + 1e-10)

    # Printing the unique values in 'col' and 'plate' columns
    print("Unique values in col:", merged_df['column'].unique())
    print("Unique values in plate:", merged_df['plate'].unique())
    display(merged_df)

    if fishers:
        iv_df = generate_fraction_map(df=reads_df,
                                      gene_column='grna',
                                      min_frequency=min_frequency)

        fishers_df = iv_df.join(dv_df, on='prc', how='inner')

        significant_mutants = fishers_odds(df=fishers_df, threshold=fisher_threshold, phenotyp_col='pred')
        significant_mutants = significant_mutants.sort_values(by='OddsRatio', ascending=False)
        display(significant_mutants)

    if regression_type == 'mlr':
        if by_plate:
            merged_df2 = merged_df.copy()
            for plate in merged_df2['plate'].unique():
                merged_df = merged_df2[merged_df2['plate'] == plate]
                print(f'merged_df: {len(merged_df)}, plate: {plate}')
                if len(merged_df) < 100:
                    break

                max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
        else:

            max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
        return max_effects, max_effects_pvalues, model, df

    if regression_type == 'ridge' or regression_type == 'lasso':
        coeffs = lasso_reg(merged_df, alpha_value=alpha_value, reg_type=regression_type)
        return coeffs

    if regression_type == 'mixed':
        model = smf.mixedlm("pred ~ gene_grna - 1", merged_df, groups=merged_df["plate"], re_formula="~1")
        result = model.fit(method="bfgs")
        print(result.summary())

        # Print AIC and BIC
        print("AIC:", result.aic)
        print("BIC:", result.bic)

        results_df = pd.DataFrame({
            'effect': result.params,
            'Standard Error': result.bse,
            'T-Value': result.tvalues,
            'p': result.pvalues
        })

        display(results_df)
        _reg_v_plot(df=results_df)

        std_resid = result.resid

        # Create subplots
        fig, axes = plt.subplots(1, 3, figsize=(15, 5))

        # Histogram of Residuals
        axes[0].hist(std_resid, bins=50, edgecolor='k')
        axes[0].set_xlabel('Residuals')
        axes[0].set_ylabel('Frequency')
        axes[0].set_title('Histogram of Residuals')

        # Boxplot of Residuals
        axes[1].boxplot(std_resid)
        axes[1].set_ylabel('Residuals')
        axes[1].set_title('Boxplot of Residuals')

        # QQ Plot
        sm.qqplot(std_resid, line='45', ax=axes[2])
        axes[2].set_title('QQ Plot')

        # Show plots
        plt.tight_layout()
        plt.show()

        return result

# NOTE: this second definition shares the name analyze_data_reg and therefore replaces
# the variant above when the module is imported.
@log_function_call
def analyze_data_reg(sequencing_loc, dv_loc, agg_type='mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):

    from .plot import _reg_v_plot
    from .utils import generate_fraction_map, fishers_odds, model_metrics

    def qstring_to_float(qstr):
        number = int(qstr[1:])  # Remove the "q" and convert the rest to an integer
        return number / 100.0

    columns_list = ['c1', 'c2', 'c3', 'c15']
    plate_list = ['p1', 'p2', 'p3', 'p4']

    dv_df = pd.read_csv(dv_loc)  #, index_col='prc')

    if agg_type.startswith('q'):
        val = qstring_to_float(agg_type)
        agg_type = lambda x: x.quantile(val)

    # Aggregating for mean prediction, total count and count of values > 0.95
    dv_df = dv_df.groupby('prc').agg(
        pred=('pred', agg_type),
        count_prc=('prc', 'size'),
        #count_above_95=('pred', lambda x: (x > 0.95).sum()),
        mean_pathogen_area=('pathogen_area', 'mean')
    )

    dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
    sequencing_df = pd.read_csv(sequencing_loc)

    reads_df, stats_dict = process_reads(df=sequencing_df,
                                         min_reads=min_reads,
                                         min_wells=min_wells,
                                         max_wells=max_wells,
                                         gene_column='gene',
                                         remove_outliers=remove_outlier_genes)

    iv_df = generate_fraction_map(df=reads_df,
                                  gene_column='grna',
                                  min_frequency=0.0)

    # Melt the iv_df to long format
    df_long = iv_df.reset_index().melt(id_vars=["prc"],
                                       value_vars=iv_df.columns,
                                       var_name="gene_grna",
                                       value_name="value")

    # Extract gene and grna info from gene_grna column
    df_long["gene"] = df_long["gene_grna"].str.split("_").str[1]
    df_long["grna"] = df_long["gene_grna"].str.split("_").str[2]

    merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
    merged_df = merged_df[merged_df['value'] > 0]
    merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
    merged_df['row'] = merged_df['prc'].str.split('_').str[1]
    merged_df['column'] = merged_df['prc'].str.split('_').str[2]

    merged_df = merged_df[~merged_df['column'].isin(columns_list)]
    merged_df = merged_df[merged_df['plate'].isin(plate_list)]

    # Printing the unique values in 'col' and 'plate' columns
    print("Unique values in col:", merged_df['column'].unique())
    print("Unique values in plate:", merged_df['plate'].unique())

    if not by_plate:
        if fishers:
            fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')

    if by_plate:
        merged_df2 = merged_df.copy()
        for plate in merged_df2['plate'].unique():
            merged_df = merged_df2[merged_df2['plate'] == plate]
            print(f'merged_df: {len(merged_df)}, plate: {plate}')
            if len(merged_df) < 100:
                break
            display(merged_df)

            model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
            #model = smf.ols("pred ~ infection_time + gene + grna + gene:grna + plate + row + column", merged_df).fit()

            # Display model metrics and summary
            model_metrics(model)
            #print(model.summary())

            if refine_model:
                # Filter outliers
                std_resid = model.get_influence().resid_studentized_internal
                outliers_resid = np.where(np.abs(std_resid) > 3)[0]
                (c, p) = model.get_influence().cooks_distance
                outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
                outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
                merged_df_filtered = merged_df.drop(merged_df.index[outliers])

                display(merged_df_filtered)

                # Refit the model with filtered data
                model = smf.ols("pred ~ gene + grna + gene:grna + row + column", merged_df_filtered).fit()
                print("Number of outliers detected by standardized residuals:", len(outliers_resid))
                print("Number of outliers detected by Cook's distance:", len(outliers_cooks))

                model_metrics(model)

            # Extract interaction coefficients and determine the maximum effect size
            interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
            interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}

            max_effects = {}
            max_effects_pvalues = {}
            for key, val in interaction_coeffs.items():
                gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
                if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
                    max_effects[gene_name] = val
                    max_effects_pvalues[gene_name] = interaction_pvalues[key]

            for key in max_effects:
                print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")

            df = pd.DataFrame([max_effects, max_effects_pvalues])
            df = df.transpose()
            df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
            df = df.sort_values(by=['effect', 'p'], ascending=[False, True])

            _reg_v_plot(df)

            if fishers:
                fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
    else:
        display(merged_df)

        model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()

        # Display model metrics and summary
        model_metrics(model)

        if refine_model:
            # Filter outliers
            std_resid = model.get_influence().resid_studentized_internal
            outliers_resid = np.where(np.abs(std_resid) > 3)[0]
            (c, p) = model.get_influence().cooks_distance
            outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
            outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
            merged_df_filtered = merged_df.drop(merged_df.index[outliers])

            display(merged_df_filtered)

            # Refit the model with filtered data
            model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df_filtered).fit()
            print("Number of outliers detected by standardized residuals:", len(outliers_resid))
            print("Number of outliers detected by Cook's distance:", len(outliers_cooks))

            model_metrics(model)

        # Extract interaction coefficients and determine the maximum effect size
        interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
        interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}

        max_effects = {}
        max_effects_pvalues = {}
        for key, val in interaction_coeffs.items():
            gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
            if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
                max_effects[gene_name] = val
                max_effects_pvalues[gene_name] = interaction_pvalues[key]

        for key in max_effects:
            print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")

        df = pd.DataFrame([max_effects, max_effects_pvalues])
        df = df.transpose()
        df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
        df = df.sort_values(by=['effect', 'p'], ascending=[False, True])

        _reg_v_plot(df)

        if fishers:
            fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')

    return max_effects, max_effects_pvalues, model, df

@log_function_call
def regression_analasys(dv_df, sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type='mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):

    from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity

    sequencing_df = pd.read_csv(sequencing_loc)
    columns_list = ['c1', 'c2', 'c3', 'c15']
    sequencing_df = sequencing_df[~sequencing_df['col'].isin(columns_list)]

    reads_df, stats_dict = process_reads(df=sequencing_df,
                                         min_reads=min_reads,
                                         min_wells=min_wells,
                                         max_wells=max_wells,
                                         gene_column='gene')

    display(reads_df)

    iv_df = generate_fraction_map(df=reads_df,
                                  gene_column=gene_column,
                                  min_frequency=min_frequency)

    display(iv_df)

    dv_df = dv_df[dv_df['count_prc'] > min_cells]
    display(dv_df)
    merged_df = iv_df.join(dv_df, on='prc', how='inner')
    display(merged_df)
    fisher_df = merged_df.copy()

    merged_df.reset_index(inplace=True)
    merged_df[['plate', 'row', 'col']] = merged_df['prc'].str.split('_', expand=True)
    merged_df = merged_df.drop(columns=['prc'])
    merged_df.dropna(inplace=True)
    merged_df = pd.get_dummies(merged_df, columns=['plate', 'row', 'col'], drop_first=True)

    y = merged_df['mean_pred']

    if model_type == 'mlr':
        merged_df = merged_df.drop(columns=['count_prc'])

    elif model_type == 'wls':
        weights = merged_df['count_prc']

    elif model_type == 'glm':
        merged_df = merged_df.drop(columns=['count_prc'])

    if transform == 'logit':
        # logit transformation
        epsilon = 1e-15
        y = np.log(y + epsilon) - np.log(1 - y + epsilon)

    elif transform == 'log':
        # log transformation
        y = np.log10(y+1)

    elif transform == 'center':
        # Centering the y around 0
        y_mean = y.mean()
        y = y - y_mean

    x = merged_df.drop('mean_pred', axis=1)
    x = x.select_dtypes(include=[np.number])
    #x = sm.add_constant(x)
    x['const'] = 0.0

    if model_type == 'mlr':
        model = sm.OLS(y, x).fit()
        model_metrics(model)

        # Check for Multicollinearity
        vif_data = check_multicollinearity(x.drop('const', axis=1))  # assuming you've added a constant to x
        high_vif_columns = vif_data[vif_data["VIF"] > VIF_threshold]["Variable"].values  # VIF threshold of 10 is common, but this can vary based on context

        print(f"Columns with high VIF: {high_vif_columns}")
        x = x.drop(columns=high_vif_columns)  # dropping columns with high VIF

        if clean_regression:
            # 1. Filter by standardized residuals
            std_resid = model.get_influence().resid_studentized_internal
            outliers_resid = np.where(np.abs(std_resid) > 3)[0]

            # 2. Filter by leverage
            influence = model.get_influence().hat_matrix_diag
            outliers_lev = np.where(influence > 2*(x.shape[1])/len(y))[0]

            # 3. Filter by Cook's distance
            (c, p) = model.get_influence().cooks_distance
            outliers_cooks = np.where(c > 4/(len(y)-x.shape[1]-1))[0]

            # Combine all identified outliers
            outliers = reduce(np.union1d, (outliers_resid, outliers_lev, outliers_cooks))

            # Filter out outliers
            x_clean = x.drop(x.index[outliers])
            y_clean = y.drop(y.index[outliers])

            # Re-run the regression with the filtered data
            model = sm.OLS(y_clean, x_clean).fit()
            model_metrics(model)

    elif model_type == 'wls':
        model = sm.WLS(y, x, weights=weights).fit()

    elif model_type == 'glm':
        model = sm.GLM(y, x, family=sm.families.Binomial()).fit()

    print(model.summary())

    results_summary = model.summary()

    results_as_html = results_summary.tables[1].as_html()
    results_df = pd.read_html(results_as_html, header=0, index_col=0)[0]
    results_df = results_df.sort_values(by='coef', ascending=False)

    if model_type == 'mlr':
        results_df['p'] = results_df['P>|t|']
    elif model_type == 'wls':
        results_df['p'] = results_df['P>|t|']
    elif model_type == 'glm':
        results_df['p'] = results_df['P>|z|']

    results_df['type'] = 1
    results_df.loc[results_df['p'] == 0.000, 'p'] = 0.005
    results_df['-log10(p)'] = -np.log10(results_df['p'])

    display(results_df)

    # Create subplots
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 15))

    # Plot histogram on ax1
    sns.histplot(data=y, kde=False, element="step", ax=ax1, color='teal')
    ax1.set_xlim([0, 1])
    ax1.spines['top'].set_visible(False)
    ax1.spines['right'].set_visible(False)

    # Prepare data for volcano plot on ax2
    results_df['-log10(p)'] = -np.log10(results_df['p'])

    # Assuming the 'type' column is in the merged_df
    sc = ax2.scatter(results_df['coef'], results_df['-log10(p)'], c=results_df['type'], cmap='coolwarm')
    ax2.set_title('Volcano Plot')
    ax2.set_xlabel('Coefficient')
    ax2.set_ylabel('-log10(P-value)')

    # Adjust colorbar
    cbar = plt.colorbar(sc, ax=ax2, ticks=[-1, 1])
    cbar.set_label('Sign of Coefficient')
    cbar.set_ticklabels(['-ve', '+ve'])

    # Add text for specified points
    for idx, row in results_df.iterrows():
        if row['p'] < 0.05 and row['coef'] > effect_size_threshold:
            ax2.text(row['coef'], -np.log10(row['p']), idx, fontsize=8, ha='center', va='bottom', color='black')

    ax2.axhline(y=-np.log10(0.05), color='gray', linestyle='--')

    plt.show()

    #if model_type == 'mlr':
    #    show_residules(model)

    if fishers:
        threshold = 2*effect_size_threshold
        fishers_odds(df=fisher_df, threshold=threshold, phenotyp_col='mean_pred')

    return

@log_function_call
def merge_pred_mes(src,
                   pred_loc,
                   target='protein of interest',
                   cell_dim=4,
                   nucleus_dim=5,
                   pathogen_dim=6,
                   channel_of_interest=1,
                   pathogen_size_min=0,
                   nucleus_size_min=0,
                   cell_size_min=0,
                   pathogen_min=0,
                   nucleus_min=0,
                   cell_min=0,
                   target_min=0,
                   mask_chans=[0, 1, 2],
                   filter_data=False,
                   include_noninfected=False,
                   include_multiinfected=False,
                   include_multinucleated=False,
                   cells_per_well=10,
                   save_filtered_filelist=False,
                   verbose=False):

    from .io import _read_and_merge_data
    from .plot import _plot_histograms_and_stats

    mask_chans = [cell_dim, nucleus_dim, pathogen_dim]
    sns.color_palette("mako", as_cmap=True)
    print(f'channel:{channel_of_interest} = {target}')
    overlay_channels = [0, 1, 2, 3]
    overlay_channels.remove(channel_of_interest)
    overlay_channels.reverse()

    db_loc = [src+'/measurements/measurements.db']
    tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
    df, object_dfs = _read_and_merge_data(db_loc,
                                          tables,
                                          verbose=True,
                                          include_multinucleated=include_multinucleated,
                                          include_multiinfected=include_multiinfected,
                                          include_noninfected=include_noninfected)
    if filter_data:
        df = df[df['cell_area'] > cell_size_min]
        df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
        print(f'After cell filtration {len(df)}')
        df = df[df['nucleus_area'] > nucleus_size_min]
        df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
        print(f'After nucleus filtration {len(df)}')
        df = df[df['pathogen_area'] > pathogen_size_min]
        df = df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
        print(f'After pathogen filtration {len(df)}')
        df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
        print(f'After channel {channel_of_interest} filtration', len(df))

    df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']

    pred_df = annotate_results(pred_loc=pred_loc)

    if verbose:
        _plot_histograms_and_stats(df=pred_df)

    pred_df.set_index('prcfo', inplace=True)
    pred_df = pred_df.drop(columns=['plate', 'row', 'col', 'field'])

    joined_df = df.join(pred_df, how='inner')

    if verbose:
        _plot_histograms_and_stats(df=joined_df)

    #dv = joined_df.copy()
    #if 'prc' not in dv.columns:
    #    dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
    #dv = dv[['pred']].groupby('prc').mean()
    #dv.set_index('prc', inplace=True)

    #loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
    #dv.to_csv(loc, index=True, header=True, mode='w')

    return joined_df

def process_reads(df, min_reads, min_wells, max_wells, gene_column, remove_outliers=False):
    print('start', len(df))
    df = df[df['count'] >= min_reads]
    print('after filtering min reads', min_reads, len(df))
    reads_ls = df['count']
    stats_dict = {}
    stats_dict['screen_reads_mean'] = np.mean(reads_ls)
    stats_dict['screen_reads_sd'] = np.std(reads_ls)
    stats_dict['screen_reads_var'] = np.var(reads_ls)

    well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
    well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
    well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
    well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
    well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
    gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
    gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
    df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))
    df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))

    df = df[df['gRNA_well_count'] >= min_wells]
    df = df[df['gRNA_well_count'] <= max_wells]

    if remove_outliers:
        clf = IsolationForest(contamination='auto', random_state=42, n_jobs=20)
        #clf.fit(df.select_dtypes(include=['int', 'float']))
        clf.fit(df[["gRNA_well_count", "count"]])
        outlier_array = clf.predict(df[["gRNA_well_count", "count"]])
        #outlier_array = clf.predict(df.select_dtypes(include=['int', 'float']))
        outlier_df = pd.DataFrame(outlier_array, columns=['outlier'])
        df['outlier'] = outlier_df['outlier']
        outliers = pd.DataFrame(df[df['outlier'] == -1])
        df = pd.DataFrame(df[df['outlier'] == 1])
        print('removed', len(outliers), 'outliers', 'inliers', len(df))

    columns_to_drop = ['gRNA_well_count', 'gRNAs_per_well', 'well_read_sum']  #, 'outlier']
    df = df.drop(columns_to_drop, axis=1)

    plates = ['p1', 'p2', 'p3', 'p4']
    df = df[df.plate.isin(plates) == True]
    print('after filtering out p5,p6,p7,p8', len(df))

    gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
    gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
    df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))
    well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
    well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
    well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
    well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
    well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
    df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))

    columns_to_drop = [col for col in df.columns if col.endswith('_right')]
    columns_to_drop2 = [col for col in df.columns if col.endswith('0')]
    columns_to_drop = columns_to_drop + columns_to_drop2
    df = df.drop(columns_to_drop, axis=1)
    return df, stats_dict

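# Illustrative sketch of the input process_reads expects (not part of the original
# module): a sequencing dataframe with at least 'count', 'prc' (plate_row_column),
# 'plate' and the chosen gene_column. The exact schema is an assumption inferred from
# the column accesses above, and the CSV path is a placeholder.
def _example_process_reads(sequencing_csv='/path/to/sequencing.csv'):
    sequencing_df = pd.read_csv(sequencing_csv)
    reads_df, stats = process_reads(df=sequencing_df, min_reads=100, min_wells=2,
                                    max_wells=1000, gene_column='gene')
    return reads_df, stats
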
def annotate_results(pred_loc):

    from .utils import _map_wells_png

    df = pd.read_csv(pred_loc)
    df = df.copy()
    pc_col_list = ['c4', 'c5', 'c6', 'c7', 'c8', 'c9', 'c10', 'c11', 'c12', 'c13', 'c14', 'c15', 'c16', 'c17', 'c18', 'c19', 'c20', 'c21', 'c22', 'c23', 'c24']
    pc_plate_list = ['p6', 'p7', 'p8', 'p9']

    nc_col_list = ['c1', 'c2', 'c3']
    nc_plate_list = ['p1', 'p2', 'p3', 'p4', 'p6', 'p7', 'p8', 'p9']

    screen_col_list = ['c4', 'c5', 'c6', 'c7', 'c8', 'c9', 'c10', 'c11', 'c12', 'c13', 'c14', 'c15', 'c16', 'c17', 'c18', 'c19', 'c20', 'c21', 'c22', 'c23', 'c24']
    screen_plate_list = ['p1', 'p2', 'p3', 'p4']

    df[['plate', 'row', 'col', 'field', 'cell_id', 'prcfo']] = df['path'].apply(lambda x: pd.Series(_map_wells_png(x)))

    df.loc[(df['col'].isin(pc_col_list)) & (df['plate'].isin(pc_plate_list)), 'condition'] = 'pc'
    df.loc[(df['col'].isin(nc_col_list)) & (df['plate'].isin(nc_plate_list)), 'condition'] = 'nc'
    df.loc[(df['col'].isin(screen_col_list)) & (df['plate'].isin(screen_plate_list)), 'condition'] = 'screen'

    df = df.dropna(subset=['condition'])
    display(df)
    return df

def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=None):

    from .utils import init_globals, add_images_to_tar

    db_path = os.path.join(src, 'measurements', 'measurements.db')
    dst = os.path.join(src, 'datasets')

    global total_images
    all_paths = []

    # Connect to the database and retrieve the image paths
    print(f'Reading DataBase: {db_path}')
    with sqlite3.connect(db_path) as conn:
        cursor = conn.cursor()
        if file_type:
            cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_type}%",))
        else:
            cursor.execute("SELECT png_path FROM png_list")
        while True:
            rows = cursor.fetchmany(1000)
            if not rows:
                break
            all_paths.extend([row[0] for row in rows])

    if isinstance(sample, int):
        selected_paths = random.sample(all_paths, sample)
        print(f'Random selection of {len(selected_paths)} paths')
    else:
        selected_paths = all_paths
        random.shuffle(selected_paths)
        print(f'All paths: {len(selected_paths)} paths')

    total_images = len(selected_paths)
    print(f'found {total_images} images')

    # Create a temp folder in dst
    temp_dir = os.path.join(dst, "temp_tars")
    os.makedirs(temp_dir, exist_ok=True)

    # Chunking the data
    if len(selected_paths) > 10000:
        num_procs = cpu_count()-2
        chunk_size = len(selected_paths) // num_procs
        remainder = len(selected_paths) % num_procs
    else:
        num_procs = 2
        chunk_size = len(selected_paths) // 2
        remainder = 0

    paths_chunks = []
    start = 0
    for i in range(num_procs):
        end = start + chunk_size + (1 if i < remainder else 0)
        paths_chunks.append(selected_paths[start:end])
        start = end

    temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]

    # Initialize the shared objects
    counter_ = Value('i', 0)
    lock_ = Lock()

    ctx = multiprocessing.get_context('spawn')

    print(f'Generating temporary tar files in {dst}')

    # Combine the temporary tar files into a final tar
    date_name = datetime.date.today().strftime('%y%m%d')
    tar_name = f'{date_name}_{experiment}_{file_type}.tar'
    if os.path.exists(tar_name):
        number = random.randint(1, 100)
        tar_name_2 = f'{date_name}_{experiment}_{file_type}_{number}.tar'
        print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)}')
        tar_name = tar_name_2

    # Add the counter and lock to the arguments for pool.map
    print(f'Merging temporary files')
    #with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
    #    results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))

    with ctx.Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
        results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))

    with tarfile.open(os.path.join(dst, tar_name), 'w') as final_tar:
        for tar_path in results:
            with tarfile.open(tar_path, 'r') as t:
                for member in t.getmembers():
                    t.extract(member, path=dst)
                    final_tar.add(os.path.join(dst, member.name), arcname=member.name)
                    os.remove(os.path.join(dst, member.name))
            os.remove(tar_path)

    # Delete the temp folder
    shutil.rmtree(temp_dir)
    print(f"\nSaved {total_images} images to {os.path.join(dst, tar_name)}")

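# Illustrative usage sketch (not part of the original module); src is a placeholder
# experiment folder that is assumed to already contain measurements/measurements.db
# with a png_list table, as queried above.
def _example_generate_dataset():
    generate_dataset('/path/to/experiment', file_type='cell_png', experiment='TSG101_screen', sample=10000)
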
def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):

    from .io import TarImageDataset, DataLoader

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if normalize:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(size=(image_size, image_size)),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(size=(image_size, image_size))])

    if verbose:
        print(f'Loading model from {model_path}')
        print(f'Loading dataset from {tar_path}')

    model = torch.load(model_path)

    dataset = TarImageDataset(tar_path, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True)

    model_name = os.path.splitext(os.path.basename(model_path))[0]
    dataset_name = os.path.splitext(os.path.basename(tar_path))[0]
    date_name = datetime.date.today().strftime('%y%m%d')
    dst = os.path.dirname(tar_path)
    result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'

    model.eval()
    model = model.to(device)

    if verbose:
        print(model)
        print(f'Generated dataset with {len(dataset)} images')
        print(f'Generating loader from {len(data_loader)} batches')
        print(f'Results will be saved in: {result_loc}')
        print(f'Model is in eval mode')
        print(f'Model loaded to device')

    prediction_pos_probs = []
    filenames_list = []
    gc.collect()
    with torch.no_grad():
        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
            images = batch_images.to(torch.float).to(device)
            outputs = model(images)
            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
            filenames_list.extend(filenames)
            print(f'\rbatch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)

    data = {'path': filenames_list, 'pred': prediction_pos_probs}
    df = pd.DataFrame(data, index=None)
    df.to_csv(result_loc, index=True, header=True, mode='w')
    torch.cuda.empty_cache()
    torch.cuda.memory.empty_cache()
    return df

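# Illustrative usage sketch (not part of the original module); both paths are
# placeholders, and the model file is assumed to be a torch.save()'d model object,
# matching the torch.load call above.
def _example_apply_model_to_tar():
    df = apply_model_to_tar('/path/to/dataset.tar', '/path/to/model.pth', batch_size=64, verbose=True)
    return df
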
def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, num_workers=10):

    from .io import NoClassDataset

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if normalize:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(size=(image_size, image_size)),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    else:
        transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(size=(image_size, image_size))])

    model = torch.load(model_path)
    print(model)

    print(f'Loading dataset in {src} with {len(src)} images')
    dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    print(f'Loaded {len(src)} images')

    result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
    print(f'Results will be saved in: {result_loc}')

    model.eval()
    model = model.to(device)
    prediction_pos_probs = []
    filenames_list = []
    with torch.no_grad():
        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
            images = batch_images.to(torch.float).to(device)
            outputs = model(images)
            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
            filenames_list.extend(filenames)
            print(f'\rbatch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
    data = {'path': filenames_list, 'pred': prediction_pos_probs}
    df = pd.DataFrame(data, index=None)
    df.to_csv(result_loc, index=True, header=True, mode='w')
    torch.cuda.empty_cache()
    torch.cuda.memory.empty_cache()
    return df

def generate_training_data_file_list(src,
|
1148
|
+
target='protein of interest',
|
1149
|
+
cell_dim=4,
|
1150
|
+
nucleus_dim=5,
|
1151
|
+
pathogen_dim=6,
|
1152
|
+
channel_of_interest=1,
|
1153
|
+
pathogen_size_min=0,
|
1154
|
+
nucleus_size_min=0,
|
1155
|
+
cell_size_min=0,
|
1156
|
+
pathogen_min=0,
|
1157
|
+
nucleus_min=0,
|
1158
|
+
cell_min=0,
|
1159
|
+
target_min=0,
|
1160
|
+
mask_chans=[0,1,2],
|
1161
|
+
filter_data=False,
|
1162
|
+
include_noninfected=False,
|
1163
|
+
include_multiinfected=False,
|
1164
|
+
include_multinucleated=False,
|
1165
|
+
cells_per_well=10,
|
1166
|
+
save_filtered_filelist=False):
|
1167
|
+
|
1168
|
+
from .io import _read_and_merge_data
|
1169
|
+
|
1170
|
+
mask_dims=[cell_dim,nucleus_dim,pathogen_dim]
|
1171
|
+
sns.color_palette("mako", as_cmap=True)
|
1172
|
+
print(f'channel:{channel_of_interest} = {target}')
|
1173
|
+
overlay_channels = [0, 1, 2, 3]
|
1174
|
+
overlay_channels.remove(channel_of_interest)
|
1175
|
+
overlay_channels.reverse()
|
1176
|
+
|
1177
|
+
db_loc = [src+'/measurements/measurements.db']
|
1178
|
+
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1179
|
+
df, object_dfs = _read_and_merge_data(db_loc,
|
1180
|
+
tables,
|
1181
|
+
verbose=True,
|
1182
|
+
include_multinucleated=include_multinucleated,
|
1183
|
+
include_multiinfected=include_multiinfected,
|
1184
|
+
include_noninfected=include_noninfected)
|
1185
|
+
|
1186
|
+
if filter_data:
|
1187
|
+
df = df[df['cell_area'] > cell_size_min]
|
1188
|
+
df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
|
1189
|
+
print(f'After cell filtration {len(df)}')
|
1190
|
+
df = df[df['nucleus_area'] > nucleus_size_min]
|
1191
|
+
df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
|
1192
|
+
print(f'After nucleus filtration {len(df)}')
|
1193
|
+
df = df[df['pathogen_area'] > pathogen_size_min]
|
1194
|
+
df=df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
|
1195
|
+
print(f'After pathogen filtration {len(df)}')
|
1196
|
+
df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
|
1197
|
+
print(f'After channel {channel_of_interest} filtration', len(df))
|
1198
|
+
|
1199
|
+
df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
|
1200
|
+
return df
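# --- Illustrative example (not part of the original diff) ---
# Toy illustration of the recruitment ratio computed above: pathogen mean intensity divided by
# cytoplasm mean intensity for the channel of interest. All values below are made up.
import pandas as pd

toy = pd.DataFrame({
    'pathogen_channel_1_mean_intensity': [120.0, 80.0],
    'cytoplasm_channel_1_mean_intensity': [60.0, 80.0],
})
channel_of_interest = 1
toy['recruitment'] = (toy[f'pathogen_channel_{channel_of_interest}_mean_intensity']
                      / toy[f'cytoplasm_channel_{channel_of_interest}_mean_intensity'])
print(toy['recruitment'].tolist())  # [2.0, 1.0]; a ratio > 1 indicates enrichment at the pathogen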
|
1201
|
+
|
1202
|
+
def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
|
1203
|
+
all_paths = []
|
1204
|
+
|
1205
|
+
# Connect to the database and retrieve the image paths and annotations
|
1206
|
+
print(f'Reading DataBase: {db_path}')
|
1207
|
+
with sqlite3.connect(db_path) as conn:
|
1208
|
+
cursor = conn.cursor()
|
1209
|
+
# Prepare the query with parameterized placeholders for annotated_classes
|
1210
|
+
placeholders = ','.join('?' * len(annotated_classes))
|
1211
|
+
query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
|
1212
|
+
cursor.execute(query, annotated_classes)
|
1213
|
+
|
1214
|
+
while True:
|
1215
|
+
rows = cursor.fetchmany(1000)
|
1216
|
+
if not rows:
|
1217
|
+
break
|
1218
|
+
for row in rows:
|
1219
|
+
all_paths.append(row)
|
1220
|
+
|
1221
|
+
# Filter paths based on annotation
|
1222
|
+
class_paths = []
|
1223
|
+
for class_ in annotated_classes:
|
1224
|
+
class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
|
1225
|
+
class_paths.append(class_paths_temp)
|
1226
|
+
|
1227
|
+
print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
|
1228
|
+
return class_paths
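# --- Hedged usage sketch (not part of the original diff) ---
# Hypothetical call to training_dataset_from_annotation; the database path below is a placeholder.
from spacr.core import training_dataset_from_annotation

db_path = '/data/experiment1/measurements/measurements.db'  # hypothetical path
class_paths = training_dataset_from_annotation(db_path,
                                               dst='/data/experiment1/datasets/training',
                                               annotation_column='test',
                                               annotated_classes=(1, 2))
# class_paths[0] holds png paths annotated as class 1, class_paths[1] those annotated as class 2.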
|
1229
|
+
|
1230
|
+
def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
|
1231
|
+
# Make sure that the length of class_data matches the length of classes
|
1232
|
+
if len(class_data) != len(classes):
|
1233
|
+
raise ValueError("class_data and classes must have the same length.")
|
1234
|
+
|
1235
|
+
total_files = sum(len(data) for data in class_data)
|
1236
|
+
processed_files = 0
|
1237
|
+
|
1238
|
+
for cls, data in zip(classes, class_data):
|
1239
|
+
# Create directories
|
1240
|
+
train_class_dir = os.path.join(dst, f'train/{cls}')
|
1241
|
+
test_class_dir = os.path.join(dst, f'test/{cls}')
|
1242
|
+
os.makedirs(train_class_dir, exist_ok=True)
|
1243
|
+
os.makedirs(test_class_dir, exist_ok=True)
|
1244
|
+
|
1245
|
+
# Split the data
|
1246
|
+
train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
|
1247
|
+
|
1248
|
+
# Copy train files
|
1249
|
+
for path in train_data:
|
1250
|
+
shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
|
1251
|
+
processed_files += 1
|
1252
|
+
print(f'{processed_files}/{total_files}', end='\r', flush=True)
|
1253
|
+
|
1254
|
+
# Copy test files
|
1255
|
+
for path in test_data:
|
1256
|
+
shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
|
1257
|
+
processed_files += 1
|
1258
|
+
print(f'{processed_files}/{total_files}', end='\r', flush=True)
|
1259
|
+
|
1260
|
+
# Print summary
|
1261
|
+
for cls in classes:
|
1262
|
+
train_class_dir = os.path.join(dst, f'train/{cls}')
|
1263
|
+
test_class_dir = os.path.join(dst, f'test/{cls}')
|
1264
|
+
print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
|
1265
|
+
|
1266
|
+
return
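# --- Hedged usage sketch (not part of the original diff) ---
# Hypothetical call to generate_dataset_from_lists; the file lists and destination are placeholders.
from spacr.core import generate_dataset_from_lists

nc_paths = ['/data/pngs/nc_001.png', '/data/pngs/nc_002.png']  # hypothetical paths
pc_paths = ['/data/pngs/pc_001.png', '/data/pngs/pc_002.png']  # hypothetical paths
generate_dataset_from_lists(dst='/data/datasets/training',
                            class_data=[nc_paths, pc_paths],
                            classes=['nc', 'pc'],
                            test_split=0.1)
# Creates train/nc, train/pc, test/nc and test/pc folders under dst and copies the files into them.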
|
1267
|
+
|
1268
|
+
def generate_training_dataset(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
|
1269
|
+
|
1270
|
+
from .io import _read_and_merge_data, _read_db
|
1271
|
+
from .utils import get_paths_from_db, annotate_conditions
|
1272
|
+
|
1273
|
+
db_path = os.path.join(src, 'measurements','measurements.db')
|
1274
|
+
dst = os.path.join(src, 'datasets', 'training')
|
1275
|
+
|
1276
|
+
if mode == 'annotation':
|
1277
|
+
class_paths_ls_2 = []
|
1278
|
+
class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
|
1279
|
+
for class_paths in class_paths_ls:
|
1280
|
+
class_paths_temp = random.sample(class_paths, size)
|
1281
|
+
class_paths_ls_2.append(class_paths_temp)
|
1282
|
+
class_paths_ls = class_paths_ls_2
|
1283
|
+
|
1284
|
+
elif mode == 'metadata':
|
1285
|
+
class_paths_ls = []
|
1286
|
+
[df] = _read_db(db_loc=db_path, tables=['png_list'])
|
1287
|
+
df['metadata_based_class'] = pd.NA
|
1288
|
+
for i, class_ in enumerate(classes):
|
1289
|
+
ls = class_metadata[i]
|
1290
|
+
df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
|
1291
|
+
|
1292
|
+
for class_ in classes:
|
1293
|
+
class_temp_df = df[df['metadata_based_class'] == class_]
|
1294
|
+
class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
|
1295
|
+
class_paths_ls.append(class_paths_temp)
|
1296
|
+
|
1297
|
+
elif mode == 'recruitment':
|
1298
|
+
class_paths_ls = []
|
1299
|
+
if not isinstance(tables, list):
|
1300
|
+
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1301
|
+
|
1302
|
+
df, _ = _read_and_merge_data(locs=[db_path],
|
1303
|
+
tables=tables,
|
1304
|
+
verbose=False,
|
1305
|
+
include_multinucleated=True,
|
1306
|
+
include_multiinfected=True,
|
1307
|
+
include_noninfected=True)
|
1308
|
+
|
1309
|
+
print('length df 1', len(df))
|
1310
|
+
|
1311
|
+
df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=classes, treatment_loc=class_metadata, types = ['col','col',metadata_type_by])
|
1312
|
+
print('length df 2', len(df))
|
1313
|
+
[png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
|
1314
|
+
|
1315
|
+
if custom_measurement != None:
|
1316
|
+
|
1317
|
+
if not isinstance(custom_measurement, list):
|
1318
|
+
print('custom_measurement should be a list, e.g. [measurement_1, measurement_2] or [measurement]')
|
1319
|
+
return
|
1320
|
+
|
1321
|
+
if isinstance(custom_measurement, list):
|
1322
|
+
if len(custom_measurement) == 2:
|
1323
|
+
print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]}/{custom_measurement[1]})')
|
1324
|
+
df['recruitment'] = df[f'{custom_measurement[0]}']/df[f'{custom_measurement[1]}']
|
1325
|
+
if len(custom_measurement) == 1:
|
1326
|
+
print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})')
|
1327
|
+
df['recruitment'] = df[f'{custom_measurement[0]}']
|
1328
|
+
else:
|
1329
|
+
print(f'Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel_of_interest})')
|
1330
|
+
df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
|
1331
|
+
|
1332
|
+
q25 = df['recruitment'].quantile(0.25)
|
1333
|
+
q75 = df['recruitment'].quantile(0.75)
|
1334
|
+
df_lower = df[df['recruitment'] <= q25]
|
1335
|
+
df_upper = df[df['recruitment'] >= q75]
|
1336
|
+
|
1337
|
+
class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=png_type)
|
1338
|
+
|
1339
|
+
class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), size)
|
1340
|
+
class_paths_ls.append(class_paths_lower)
|
1341
|
+
|
1342
|
+
class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=png_type)
|
1343
|
+
class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), size)
|
1344
|
+
class_paths_ls.append(class_paths_upper)
|
1345
|
+
|
1346
|
+
generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=classes, test_split=test_split)
|
1347
|
+
|
1348
|
+
return
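# --- Hedged usage sketch (not part of the original diff) ---
# Hypothetical call to generate_training_dataset in 'metadata' mode; src and the metadata values
# are placeholders. src is expected to contain measurements/measurements.db.
from spacr.core import generate_training_dataset

generate_training_dataset(src='/data/experiment1',
                          mode='metadata',
                          classes=['nc', 'pc'],
                          class_metadata=[['c1'], ['c2']],  # metadata values mapping images to each class
                          metadata_type_by='col',
                          size=200,
                          test_split=0.1)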
|
1349
|
+
|
1350
|
+
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
|
1351
|
+
"""
|
1352
|
+
Generate data loaders for training and validation/test datasets.
|
1353
|
+
|
1354
|
+
Parameters:
|
1355
|
+
- src (str): The source directory containing the data.
|
1356
|
+
- train_mode (str): The training mode. Options are 'erm' (Empirical Risk Minimization) or 'irm' (Invariant Risk Minimization).
|
1357
|
+
- mode (str): The mode of operation. Options are 'train' or 'test'.
|
1358
|
+
- image_size (int): The size of the input images.
|
1359
|
+
- batch_size (int): The batch size for the data loaders.
|
1360
|
+
- classes (list): The list of classes to consider.
|
1361
|
+
- num_workers (int): The number of worker threads for data loading.
|
1362
|
+
- validation_split (float): The fraction of data to use for validation when train_mode is 'erm'.
|
1363
|
+
- max_show (int): The maximum number of images to show when verbose is True.
|
1364
|
+
- pin_memory (bool): Whether to pin memory for faster data transfer.
|
1365
|
+
- normalize (bool): Whether to normalize the input images.
|
1366
|
+
- verbose (bool): Whether to print additional information and show images.
|
1367
|
+
|
1368
|
+
Returns:
|
1369
|
+
- train_loaders (list): List of data loaders for training datasets.
|
1370
|
+
- val_loaders (list): List of data loaders for validation datasets.
|
1371
|
+
- plate_names (list): List of plate names (only applicable when train_mode is 'irm').
|
1372
|
+
"""
|
1373
|
+
|
1374
|
+
from .io import MyDataset
|
1375
|
+
from .plot import _imshow
|
1376
|
+
|
1377
|
+
plate_to_filenames = defaultdict(list)
|
1378
|
+
plate_to_labels = defaultdict(list)
|
1379
|
+
train_loaders = []
|
1380
|
+
val_loaders = []
|
1381
|
+
plate_names = []
|
1382
|
+
|
1383
|
+
if normalize:
|
1384
|
+
transform = transforms.Compose([
|
1385
|
+
transforms.ToTensor(),
|
1386
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
1387
|
+
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
1388
|
+
else:
|
1389
|
+
transform = transforms.Compose([
|
1390
|
+
transforms.ToTensor(),
|
1391
|
+
transforms.CenterCrop(size=(image_size, image_size))])
|
1392
|
+
|
1393
|
+
if mode == 'train':
|
1394
|
+
data_dir = os.path.join(src, 'train')
|
1395
|
+
shuffle = True
|
1396
|
+
print('Generating train and validation datasets')
|
1397
|
+
|
1398
|
+
elif mode == 'test':
|
1399
|
+
data_dir = os.path.join(src, 'test')
|
1400
|
+
val_loaders = []
|
1401
|
+
validation_split=0.0
|
1402
|
+
shuffle = True
|
1403
|
+
print(f'Generating test dataset')
|
1404
|
+
|
1405
|
+
else:
|
1406
|
+
print(f'mode:{mode} is not valid, use mode = train or test')
|
1407
|
+
return
|
1408
|
+
|
1409
|
+
if train_mode == 'erm':
|
1410
|
+
data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
1411
|
+
#train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1412
|
+
if validation_split > 0:
|
1413
|
+
train_size = int((1 - validation_split) * len(data))
|
1414
|
+
val_size = len(data) - train_size
|
1415
|
+
|
1416
|
+
print(f'Train data:{train_size}, Validation data:{val_size}')
|
1417
|
+
|
1418
|
+
train_dataset, val_dataset = random_split(data, [train_size, val_size])
|
1419
|
+
|
1420
|
+
train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1421
|
+
val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1422
|
+
else:
|
1423
|
+
train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1424
|
+
|
1425
|
+
elif train_mode == 'irm':
|
1426
|
+
data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
1427
|
+
|
1428
|
+
for filename, label in zip(data.filenames, data.labels):
|
1429
|
+
plate = data.get_plate(filename)
|
1430
|
+
plate_to_filenames[plate].append(filename)
|
1431
|
+
plate_to_labels[plate].append(label)
|
1432
|
+
|
1433
|
+
for plate, filenames in plate_to_filenames.items():
|
1434
|
+
labels = plate_to_labels[plate]
|
1435
|
+
plate_data = MyDataset(data_dir, classes, specific_files=filenames, specific_labels=labels, transform=transform, shuffle=False, pin_memory=pin_memory)
|
1436
|
+
plate_names.append(plate)
|
1437
|
+
|
1438
|
+
if validation_split > 0:
|
1439
|
+
train_size = int((1 - validation_split) * len(plate_data))
|
1440
|
+
val_size = len(plate_data) - train_size
|
1441
|
+
|
1442
|
+
print(f'Train data:{train_size}, Validation data:{val_size}')
|
1443
|
+
|
1444
|
+
train_dataset, val_dataset = random_split(plate_data, [train_size, val_size])
|
1445
|
+
|
1446
|
+
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1447
|
+
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1448
|
+
|
1449
|
+
train_loaders.append(train_loader)
|
1450
|
+
val_loaders.append(val_loader)
|
1451
|
+
else:
|
1452
|
+
train_loader = DataLoader(plate_data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1453
|
+
train_loaders.append(train_loader)
|
1454
|
+
val_loaders.append(None)
|
1455
|
+
|
1456
|
+
else:
|
1457
|
+
print(f'train_mode:{train_mode} is not valid, use: train_mode = irm or erm')
|
1458
|
+
return
|
1459
|
+
|
1460
|
+
if verbose:
|
1461
|
+
if train_mode == 'erm':
|
1462
|
+
for idx, (images, labels, filenames) in enumerate(train_loaders):
|
1463
|
+
if idx >= max_show:
|
1464
|
+
break
|
1465
|
+
images = images.cpu()
|
1466
|
+
label_strings = [str(label.item()) for label in labels]
|
1467
|
+
_imshow(images, label_strings, nrow=20, fontsize=12)
|
1468
|
+
|
1469
|
+
elif train_mode == 'irm':
|
1470
|
+
for plate_name, train_loader in zip(plate_names, train_loaders):
|
1471
|
+
print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
|
1472
|
+
for idx, (images, labels, filenames) in enumerate(train_loader):
|
1473
|
+
if idx >= max_show:
|
1474
|
+
break
|
1475
|
+
images = images.cpu()
|
1476
|
+
label_strings = [str(label.item()) for label in labels]
|
1477
|
+
_imshow(images, label_strings, nrow=20, fontsize=12)
|
1478
|
+
|
1479
|
+
return train_loaders, val_loaders, plate_names
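# --- Hedged usage sketch (not part of the original diff) ---
# Hypothetical call to generate_loaders for standard (ERM) training; src is a placeholder and is
# expected to contain train/<class> and test/<class> folders.
from spacr.core import generate_loaders

train_loaders, val_loaders, plate_names = generate_loaders(src='/data/datasets/training',
                                                           train_mode='erm',
                                                           mode='train',
                                                           image_size=224,
                                                           batch_size=32,
                                                           classes=['nc', 'pc'],
                                                           validation_split=0.1,
                                                           num_workers=4)
# With train_mode='erm', train_loaders and val_loaders are single DataLoaders and plate_names stays empty.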
|
1480
|
+
|
1481
|
+
def analyze_recruitment(src, metadata_settings, advanced_settings):
|
1482
|
+
"""
|
1483
|
+
Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
|
1484
|
+
|
1485
|
+
Parameters:
|
1486
|
+
src (str): The source of the recruitment data.
|
1487
|
+
metadata_settings (dict): The settings for metadata.
|
1488
|
+
advanced_settings (dict): The advanced settings for recruitment analysis.
|
1489
|
+
|
1490
|
+
Returns:
|
1491
|
+
None
|
1492
|
+
"""
|
1493
|
+
|
1494
|
+
from .io import _read_and_merge_data, _results_to_csv
|
1495
|
+
from .plot import plot_merged, _plot_controls, _plot_recruitment
|
1496
|
+
from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well
|
1497
|
+
|
1498
|
+
settings_dict = {**metadata_settings, **advanced_settings}
|
1499
|
+
settings_df = pd.DataFrame(list(settings_dict.items()), columns=['Key', 'Value'])
|
1500
|
+
settings_csv = os.path.join(src,'settings','analyze_settings.csv')
|
1501
|
+
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
1502
|
+
settings_df.to_csv(settings_csv, index=False)
|
1503
|
+
|
1504
|
+
# metadata settings
|
1505
|
+
target = metadata_settings['target']
|
1506
|
+
cell_types = metadata_settings['cell_types']
|
1507
|
+
cell_plate_metadata = metadata_settings['cell_plate_metadata']
|
1508
|
+
pathogen_types = metadata_settings['pathogen_types']
|
1509
|
+
pathogen_plate_metadata = metadata_settings['pathogen_plate_metadata']
|
1510
|
+
treatments = metadata_settings['treatments']
|
1511
|
+
treatment_plate_metadata = metadata_settings['treatment_plate_metadata']
|
1512
|
+
metadata_types = metadata_settings['metadata_types']
|
1513
|
+
channel_dims = metadata_settings['channel_dims']
|
1514
|
+
cell_chann_dim = metadata_settings['cell_chann_dim']
|
1515
|
+
cell_mask_dim = metadata_settings['cell_mask_dim']
|
1516
|
+
nucleus_chann_dim = metadata_settings['nucleus_chann_dim']
|
1517
|
+
nucleus_mask_dim = metadata_settings['nucleus_mask_dim']
|
1518
|
+
pathogen_chann_dim = metadata_settings['pathogen_chann_dim']
|
1519
|
+
pathogen_mask_dim = metadata_settings['pathogen_mask_dim']
|
1520
|
+
channel_of_interest = metadata_settings['channel_of_interest']
|
1521
|
+
|
1522
|
+
# Advanced settings
|
1523
|
+
plot = advanced_settings['plot']
|
1524
|
+
plot_nr = advanced_settings['plot_nr']
|
1525
|
+
plot_control = advanced_settings['plot_control']
|
1526
|
+
figuresize = advanced_settings['figuresize']
|
1527
|
+
remove_background = advanced_settings['remove_background']
|
1528
|
+
backgrounds = advanced_settings['backgrounds']
|
1529
|
+
include_noninfected = advanced_settings['include_noninfected']
|
1530
|
+
include_multiinfected = advanced_settings['include_multiinfected']
|
1531
|
+
include_multinucleated = advanced_settings['include_multinucleated']
|
1532
|
+
cells_per_well = advanced_settings['cells_per_well']
|
1533
|
+
pathogen_size_range = advanced_settings['pathogen_size_range']
|
1534
|
+
nucleus_size_range = advanced_settings['nucleus_size_range']
|
1535
|
+
cell_size_range = advanced_settings['cell_size_range']
|
1536
|
+
pathogen_intensity_range = advanced_settings['pathogen_intensity_range']
|
1537
|
+
nucleus_intensity_range = advanced_settings['nucleus_intensity_range']
|
1538
|
+
cell_intensity_range = advanced_settings['cell_intensity_range']
|
1539
|
+
target_intensity_min = advanced_settings['target_intensity_min']
|
1540
|
+
|
1541
|
+
print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
|
1542
|
+
print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
|
1543
|
+
print(f'Treatment(s): {treatments}, in {treatment_plate_metadata}')
|
1544
|
+
|
1545
|
+
mask_dims=[cell_mask_dim,nucleus_mask_dim,pathogen_mask_dim]
|
1546
|
+
mask_chans=[nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim]
|
1547
|
+
|
1548
|
+
if isinstance(metadata_types, str):
|
1549
|
+
metadata_types = [metadata_types, metadata_types, metadata_types]
|
1550
|
+
if isinstance(metadata_types, list):
|
1551
|
+
if len(metadata_types) < 3:
|
1552
|
+
metadata_types = [metadata_types[0], metadata_types[0], metadata_types[0]]
|
1553
|
+
print(f'WARNING: setting metadata types to first element times 3: {metadata_types}. To avoid this behaviour, set metadata_types to a list of 3 elements, each one of "col", "row" or "plate".')
|
1554
|
+
else:
|
1555
|
+
metadata_types = metadata_types
|
1556
|
+
|
1557
|
+
if isinstance(backgrounds, (int,float)):
|
1558
|
+
backgrounds = [backgrounds, backgrounds, backgrounds, backgrounds]
|
1559
|
+
|
1560
|
+
sns.color_palette("mako", as_cmap=True)
|
1561
|
+
print(f'channel:{channel_of_interest} = {target}')
|
1562
|
+
overlay_channels = channel_dims
|
1563
|
+
overlay_channels.remove(channel_of_interest)
|
1564
|
+
overlay_channels.reverse()
|
1565
|
+
|
1566
|
+
db_loc = [src+'/measurements/measurements.db']
|
1567
|
+
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1568
|
+
df, _ = _read_and_merge_data(db_loc,
|
1569
|
+
tables,
|
1570
|
+
verbose=True,
|
1571
|
+
include_multinucleated=include_multinucleated,
|
1572
|
+
include_multiinfected=include_multiinfected,
|
1573
|
+
include_noninfected=include_noninfected)
|
1574
|
+
|
1575
|
+
df = annotate_conditions(df,
|
1576
|
+
cells=cell_types,
|
1577
|
+
cell_loc=cell_plate_metadata,
|
1578
|
+
pathogens=pathogen_types,
|
1579
|
+
pathogen_loc=pathogen_plate_metadata,
|
1580
|
+
treatments=treatments,
|
1581
|
+
treatment_loc=treatment_plate_metadata,
|
1582
|
+
types=metadata_types)
|
1583
|
+
|
1584
|
+
df = df.dropna(subset=['condition'])
|
1585
|
+
print(f'After dropping non-annotated wells: {len(df)} rows')
|
1586
|
+
files = df['file_name'].tolist()
|
1587
|
+
files = [item + '.npy' for item in files]
|
1588
|
+
random.shuffle(files)
|
1589
|
+
|
1590
|
+
if plot:
|
1591
|
+
plot_settings = {'include_noninfected':include_noninfected,
|
1592
|
+
'include_multiinfected':include_multiinfected,
|
1593
|
+
'include_multinucleated':include_multinucleated,
|
1594
|
+
'remove_background':remove_background,
|
1595
|
+
'filter_min_max':[[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]],
|
1596
|
+
'channel_dims':channel_dims,
|
1597
|
+
'backgrounds':backgrounds,
|
1598
|
+
'cell_mask_dim':mask_dims[0],
|
1599
|
+
'nucleus_mask_dim':mask_dims[1],
|
1600
|
+
'pathogen_mask_dim':mask_dims[2],
|
1601
|
+
'overlay_chans':overlay_channels,
|
1602
|
+
'outline_thickness':3,
|
1603
|
+
'outline_color':'gbr',
|
1604
|
+
'overlay_chans':overlay_channels,
|
1605
|
+
'overlay':True,
|
1606
|
+
'normalization_percentiles':[1,99],
|
1607
|
+
'normalize':True,
|
1608
|
+
'print_object_number':True,
|
1609
|
+
'nr':plot_nr,
|
1610
|
+
'figuresize':20,
|
1611
|
+
'cmap':'inferno',
|
1612
|
+
'verbose':False}
|
1613
|
+
|
1614
|
+
if os.path.exists(os.path.join(src,'merged')):
|
1615
|
+
try:
|
1616
|
+
plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
|
1617
|
+
except Exception as e:
|
1618
|
+
print(f'Failed to plot images with outlines, Error: {e}')
|
1619
|
+
|
1620
|
+
if not cell_chann_dim is None:
|
1621
|
+
df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
|
1622
|
+
if not target_intensity_min is None:
|
1623
|
+
df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_intensity_min]
|
1624
|
+
print(f'After channel {channel_of_interest} filtration', len(df))
|
1625
|
+
if not nucleus_chann_dim is None:
|
1626
|
+
df = _object_filter(df, object_type='nucleus', size_range=nucleus_size_range, intensity_range=nucleus_intensity_range, mask_chans=mask_chans, mask_chan=1)
|
1627
|
+
if not pathogen_chann_dim is None:
|
1628
|
+
df = _object_filter(df, object_type='pathogen', size_range=pathogen_size_range, intensity_range=pathogen_intensity_range, mask_chans=mask_chans, mask_chan=2)
|
1629
|
+
|
1630
|
+
df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
|
1631
|
+
for chan in channel_dims:
|
1632
|
+
df = _calculate_recruitment(df, channel=chan)
|
1633
|
+
print(f'calculated recruitment for: {len(df)} rows')
|
1634
|
+
df_well = _group_by_well(df)
|
1635
|
+
print(f'found: {len(df_well)} wells')
|
1636
|
+
|
1637
|
+
df_well = df_well[df_well['cells_per_well'] >= cells_per_well]
|
1638
|
+
prc_list = df_well['prc'].unique().tolist()
|
1639
|
+
df = df[df['prc'].isin(prc_list)]
|
1640
|
+
print(f'After cells per well filter: {len(df)} cells in {len(df_well)} wells left with threshold {cells_per_well}')
|
1641
|
+
|
1642
|
+
if plot_control:
|
1643
|
+
_plot_controls(df, mask_chans, channel_of_interest, figuresize=5)
|
1644
|
+
|
1645
|
+
print(f'PV level: {len(df)} rows')
|
1646
|
+
_plot_recruitment(df=df, df_type='by PV', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
|
1647
|
+
print(f'well level: {len(df_well)} rows')
|
1648
|
+
_plot_recruitment(df=df_well, df_type='by well', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
|
1649
|
+
cells,wells = _results_to_csv(src, df, df_well)
|
1650
|
+
return [cells,wells]
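# --- Hedged usage sketch (not part of the original diff) ---
# Hypothetical settings for analyze_recruitment. The keys mirror the ones read above; every value
# below is a placeholder and should be adapted to the experiment.
from spacr.core import analyze_recruitment

metadata_settings = {'target': 'protein of interest',
                     'cell_types': ['HeLa'],
                     'cell_plate_metadata': None,
                     'pathogen_types': ['pathogen'],
                     'pathogen_plate_metadata': None,
                     'treatments': ['control', 'treated'],
                     'treatment_plate_metadata': [['c1'], ['c2']],
                     'metadata_types': 'col',
                     'channel_dims': [0, 1, 2, 3],
                     'cell_chann_dim': 3,
                     'cell_mask_dim': 4,
                     'nucleus_chann_dim': 0,
                     'nucleus_mask_dim': 5,
                     'pathogen_chann_dim': 2,
                     'pathogen_mask_dim': 6,
                     'channel_of_interest': 1}

advanced_settings = {'plot': False,
                     'plot_nr': 10,
                     'plot_control': False,
                     'figuresize': 20,
                     'remove_background': False,
                     'backgrounds': 100,
                     'include_noninfected': False,
                     'include_multiinfected': True,
                     'include_multinucleated': True,
                     'cells_per_well': 10,
                     'pathogen_size_range': [200, 10000],
                     'nucleus_size_range': [500, 20000],
                     'cell_size_range': [1000, 100000],
                     'pathogen_intensity_range': [0, 65535],
                     'nucleus_intensity_range': [0, 65535],
                     'cell_intensity_range': [0, 65535],
                     'target_intensity_min': 0}

cells, wells = analyze_recruitment(src='/data/experiment1', metadata_settings=metadata_settings, advanced_settings=advanced_settings)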
|
1651
|
+
|
1652
|
+
@log_function_call
|
1653
|
+
def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
1654
|
+
|
1655
|
+
from .io import preprocess_img_data, _load_and_concatenate_arrays
|
1656
|
+
from .plot import plot_merged, plot_arrays
|
1657
|
+
from .utils import _pivot_counts_table
|
1658
|
+
|
1659
|
+
settings = {**settings, **advanced_settings}
|
1660
|
+
settings['src'] = src
|
1661
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
1662
|
+
settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
|
1663
|
+
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
1664
|
+
settings_df.to_csv(settings_csv, index=False)
|
1665
|
+
|
1666
|
+
if settings['timelapse']:
|
1667
|
+
settings['randomize'] = False
|
1668
|
+
|
1669
|
+
if settings['preprocess']:
|
1670
|
+
if not settings['masks']:
|
1671
|
+
print('WARNING: channels for mask generation must be defined when preprocess = True')
|
1672
|
+
|
1673
|
+
if isinstance(settings['merge'], bool):
|
1674
|
+
settings['merge'] = [settings['merge']]*3
|
1675
|
+
if isinstance(settings['save'], bool):
|
1676
|
+
settings['save'] = [settings['save']]*3
|
1677
|
+
|
1678
|
+
if settings['preprocess']:
|
1679
|
+
preprocess_img_data(settings)
|
1680
|
+
|
1681
|
+
if settings['masks']:
|
1682
|
+
mask_src = os.path.join(src, 'norm_channel_stack')
|
1683
|
+
if settings['cell_channel'] != None:
|
1684
|
+
generate_cellpose_masks(src=mask_src, settings=settings, object_type='cell')
|
1685
|
+
|
1686
|
+
if settings['nucleus_channel'] != None:
|
1687
|
+
generate_cellpose_masks(src=mask_src, settings=settings, object_type='nucleus')
|
1688
|
+
|
1689
|
+
if settings['pathogen_channel'] != None:
|
1690
|
+
generate_cellpose_masks(src=mask_src, settings=settings, object_type='pathogen')
|
1691
|
+
|
1692
|
+
if os.path.exists(os.path.join(src,'measurements')):
|
1693
|
+
_pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
|
1694
|
+
|
1695
|
+
# Concatenate stack with masks
|
1696
|
+
_load_and_concatenate_arrays(src, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'])
|
1697
|
+
|
1698
|
+
if settings['plot']:
|
1699
|
+
if not settings['timelapse']:
|
1700
|
+
plot_dims = len(settings['channels'])
|
1701
|
+
overlay_channels = [2,1,0]
|
1702
|
+
cell_mask_dim = nucleus_mask_dim = pathogen_mask_dim = None
|
1703
|
+
plot_counter = plot_dims
|
1704
|
+
|
1705
|
+
if settings['cell_channel'] is not None:
|
1706
|
+
cell_mask_dim = plot_counter
|
1707
|
+
plot_counter += 1
|
1708
|
+
|
1709
|
+
if settings['nucleus_channel'] is not None:
|
1710
|
+
nucleus_mask_dim = plot_counter
|
1711
|
+
plot_counter += 1
|
1712
|
+
|
1713
|
+
if settings['pathogen_channel'] is not None:
|
1714
|
+
pathogen_mask_dim = plot_counter
|
1715
|
+
|
1716
|
+
overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
|
1717
|
+
overlay_channels = [element for element in overlay_channels if element is not None]
|
1718
|
+
|
1719
|
+
plot_settings = {'include_noninfected':True,
|
1720
|
+
'include_multiinfected':True,
|
1721
|
+
'include_multinucleated':True,
|
1722
|
+
'remove_background':False,
|
1723
|
+
'filter_min_max':None,
|
1724
|
+
'channel_dims':settings['channels'],
|
1725
|
+
'backgrounds':[100,100,100,100],
|
1726
|
+
'cell_mask_dim':cell_mask_dim,
|
1727
|
+
'nucleus_mask_dim':nucleus_mask_dim,
|
1728
|
+
'pathogen_mask_dim':pathogen_mask_dim,
|
1729
|
+
'overlay_chans':[0,2,3],
|
1730
|
+
'outline_thickness':3,
|
1731
|
+
'outline_color':'gbr',
|
1732
|
+
'overlay_chans':overlay_channels,
|
1733
|
+
'overlay':True,
|
1734
|
+
'normalization_percentiles':[1,99],
|
1735
|
+
'normalize':True,
|
1736
|
+
'print_object_number':True,
|
1737
|
+
'nr':settings['examples_to_plot'],
|
1738
|
+
'figuresize':20,
|
1739
|
+
'cmap':'inferno',
|
1740
|
+
'verbose':False}
|
1741
|
+
try:
|
1742
|
+
fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
|
1743
|
+
except Exception as e:
|
1744
|
+
print(f'Failed to plot image mask overlay. Error: {e}')
|
1745
|
+
else:
|
1746
|
+
plot_arrays(src=os.path.join(src,'merged'), figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99)
|
1747
|
+
|
1748
|
+
torch.cuda.empty_cache()
|
1749
|
+
gc.collect()
|
1750
|
+
return
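# --- Hedged usage sketch (not part of the original diff) ---
# Hypothetical settings for preprocess_generate_masks, limited to the keys visible in this
# function. preprocess_img_data and generate_cellpose_masks consume additional keys (for example
# 'batch_size' and the per-object '<object>_CP_prob' values) that are intentionally not listed here.
settings = {'preprocess': True,
            'masks': True,
            'timelapse': False,
            'plot': False,
            'examples_to_plot': 1,
            'merge': False,
            'save': True,
            'channels': [0, 1, 2, 3],
            'cell_channel': 3,
            'nucleus_channel': 0,
            'pathogen_channel': 2}
# The completed dictionary would then be passed as:
# preprocess_generate_masks(src='/data/experiment1', settings=settings, advanced_settings={...})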
|
1751
|
+
|
1752
|
+
def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', verbose=False, plot=False, save=False, custom_model=None, signal_thresholds=1000, normalize=True, resize=False, target_height=None, target_width=None, rescale=True, resample=True, net_avg=False, invert=False, circular=False, percentiles=None, overlay=True, grayscale=False):
|
1753
|
+
|
1754
|
+
from .plot import print_mask_and_flows
|
1755
|
+
from .utils import get_files_from_dir, resize_images_and_labels
|
1756
|
+
from .io import _load_normalized_images_and_labels, _load_images_and_labels
|
1757
|
+
|
1758
|
+
if not torch.cuda.is_available():
|
1759
|
+
print(f'Torch CUDA is not available, using CPU')
|
1760
|
+
|
1761
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1762
|
+
|
1763
|
+
if custom_model == None:
|
1764
|
+
if model_name =='cyto':
|
1765
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_name, net_avg=False, diam_mean=diameter, pretrained_model=None)
|
1766
|
+
else:
|
1767
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_name)
|
1768
|
+
|
1769
|
+
if custom_model != None:
|
1770
|
+
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
|
1771
|
+
print(f'loaded custom model:{custom_model}')
|
1772
|
+
|
1773
|
+
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
|
1774
|
+
|
1775
|
+
if grayscale:
|
1776
|
+
chans=[0, 0]
|
1777
|
+
|
1778
|
+
print(f'Using channels: {chans} for model of type {model_name}')
|
1779
|
+
|
1780
|
+
if verbose == True:
|
1781
|
+
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
|
1782
|
+
|
1783
|
+
all_image_files = get_files_from_dir(src, file_extension="*.tif")
|
1784
|
+
random.shuffle(all_image_files)
|
1785
|
+
|
1786
|
+
time_ls = []
|
1787
|
+
for i in range(0, len(all_image_files), batch_size):
|
1788
|
+
image_files = all_image_files[i:i+batch_size]
|
1789
|
+
if normalize:
|
1790
|
+
images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, signal_thresholds=signal_thresholds, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose)
|
1791
|
+
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
1792
|
+
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
1793
|
+
else:
|
1794
|
+
images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
|
1795
|
+
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
1796
|
+
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
1797
|
+
if resize:
|
1798
|
+
images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
|
1799
|
+
|
1800
|
+
for file_index, stack in enumerate(images):
|
1801
|
+
start = time.time()
|
1802
|
+
output = model.eval(x=stack,
|
1803
|
+
normalize=False,
|
1804
|
+
channels=chans,
|
1805
|
+
channel_axis=3,
|
1806
|
+
diameter=diameter,
|
1807
|
+
flow_threshold=flow_threshold,
|
1808
|
+
cellprob_threshold=cellprob_threshold,
|
1809
|
+
rescale=rescale,
|
1810
|
+
resample=resample,
|
1811
|
+
net_avg=net_avg,
|
1812
|
+
progress=False)
|
1813
|
+
|
1814
|
+
if len(output) == 4:
|
1815
|
+
mask, flows, _, _ = output
|
1816
|
+
elif len(output) == 3:
|
1817
|
+
mask, flows, _ = output
|
1818
|
+
else:
|
1819
|
+
raise ValueError("Unexpected number of return values from model.eval()")
|
1820
|
+
|
1821
|
+
if resize:
|
1822
|
+
dims = orig_dims[file_index]
|
1823
|
+
mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
|
1824
|
+
|
1825
|
+
stop = time.time()
|
1826
|
+
duration = (stop - start)
|
1827
|
+
time_ls.append(duration)
|
1828
|
+
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
1829
|
+
print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
|
1830
|
+
if plot:
|
1831
|
+
if resize:
|
1832
|
+
stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
|
1833
|
+
print_mask_and_flows(stack, mask, flows, overlay=overlay)
|
1834
|
+
if save:
|
1835
|
+
output_filename = os.path.join(dst, image_names[file_index])
|
1836
|
+
cv2.imwrite(output_filename, mask)
|
1837
|
+
return
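# --- Hedged usage sketch (not part of the original diff) ---
# Hypothetical call to identify_masks_finetune with a fine-tuned Cellpose model; every path and
# parameter value below is a placeholder.
from spacr.core import identify_masks_finetune

identify_masks_finetune(src='/data/experiment1/tifs',           # folder of *.tif images
                        dst='/data/experiment1/masks',          # masks are written here when save=True
                        model_name='cyto',
                        custom_model='/models/finetuned_cyto',  # hypothetical pretrained_model file
                        channels=[0, 0],
                        diameter=30,
                        batch_size=8,
                        grayscale=True,
                        save=True)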
|
1838
|
+
|
1839
|
+
@log_function_call
|
1840
|
+
def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
|
1841
|
+
"""
|
1842
|
+
Identify masks from the source images.
|
1843
|
+
|
1844
|
+
Args:
|
1845
|
+
src (str): Path to the source images.
|
1846
|
+
object_type (str): Type of object to identify.
|
1847
|
+
model_name (str): Name of the model to use for identification.
|
1848
|
+
batch_size (int): Number of images to process in each batch.
|
1849
|
+
channels (list): List of channel names.
|
1850
|
+
diameter (float): Diameter of the objects to identify.
|
1851
|
+
minimum_size (int): Minimum size of objects to keep.
|
1852
|
+
maximum_size (int): Maximum size of objects to keep.
|
1853
|
+
flow_threshold (int, optional): Threshold for flow detection. Defaults to 30.
|
1854
|
+
cellprob_threshold (int, optional): Threshold for cell probability. Defaults to 1.
|
1855
|
+
figuresize (int, optional): Size of the figure. Defaults to 25.
|
1856
|
+
cmap (str, optional): Colormap for plotting. Defaults to 'inferno'.
|
1857
|
+
refine_masks (bool, optional): Flag indicating whether to refine masks. Defaults to True.
|
1858
|
+
filter_size (bool, optional): Flag indicating whether to filter based on size. Defaults to True.
|
1859
|
+
filter_dimm (bool, optional): Flag indicating whether to filter based on intensity. Defaults to True.
|
1860
|
+
remove_border_objects (bool, optional): Flag indicating whether to remove border objects. Defaults to False.
|
1861
|
+
verbose (bool, optional): Flag indicating whether to display verbose output. Defaults to False.
|
1862
|
+
plot (bool, optional): Flag indicating whether to plot the masks. Defaults to False.
|
1863
|
+
merge (bool, optional): Flag indicating whether to merge adjacent objects. Defaults to False.
|
1864
|
+
save (bool, optional): Flag indicating whether to save the masks. Defaults to True.
|
1865
|
+
start_at (int, optional): Index to start processing from. Defaults to 0.
|
1866
|
+
file_type (str, optional): File type for saving the masks. Defaults to '.npz'.
|
1867
|
+
net_avg (bool, optional): Flag indicating whether to use network averaging. Defaults to True.
|
1868
|
+
resample (bool, optional): Flag indicating whether to resample the images. Defaults to True.
|
1869
|
+
timelapse (bool, optional): Flag indicating whether to generate a timelapse. Defaults to False.
|
1870
|
+
timelapse_displacement (float, optional): Displacement threshold for timelapse. Defaults to None.
|
1871
|
+
timelapse_frame_limits (tuple, optional): Frame limits for timelapse. Defaults to None.
|
1872
|
+
timelapse_memory (int, optional): Memory for timelapse. Defaults to 3.
|
1873
|
+
timelapse_remove_transient (bool, optional): Flag indicating whether to remove transient objects in timelapse. Defaults to False.
|
1874
|
+
timelapse_mode (str, optional): Mode for timelapse. Defaults to 'btrack'.
|
1875
|
+
timelapse_objects (str, optional): Objects to track in timelapse. Defaults to 'cell'.
|
1876
|
+
|
1877
|
+
Returns:
|
1878
|
+
None
|
1879
|
+
"""
|
1880
|
+
|
1881
|
+
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size
|
1882
|
+
from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
|
1883
|
+
from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
|
1884
|
+
from .plot import plot_masks
|
1885
|
+
|
1886
|
+
# Note: add logic to handle batches of size 1, as these will break the code; batches must all contain > 2 images
|
1887
|
+
gc.collect()
|
1888
|
+
#print('========== generating masks ==========')
|
1889
|
+
|
1890
|
+
if not torch.cuda.is_available():
|
1891
|
+
print(f'Torch CUDA is not available, using CPU')
|
1892
|
+
|
1893
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1894
|
+
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
|
1895
|
+
if file_type == '.npz':
|
1896
|
+
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
|
1897
|
+
else:
|
1898
|
+
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.png')]
|
1899
|
+
if timelapse:
|
1900
|
+
print('timelapse is only compatible with .npz files')
|
1901
|
+
return
|
1902
|
+
|
1903
|
+
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0]
|
1904
|
+
|
1905
|
+
if verbose == True:
|
1906
|
+
print(f'source: {src}')
|
1907
|
+
print()
|
1908
|
+
print(f'Settings: object_type: {object_type}, minimum_size: {minimum_size}, maximum_size:{maximum_size}, figuresize:{figuresize}, cmap:{cmap}, net_avg:{net_avg}, resample:{resample}')
|
1909
|
+
print()
|
1910
|
+
print(f'Cellpose settings: Model: {model_name}, batch_size: {batch_size}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
|
1911
|
+
print()
|
1912
|
+
print(f'Bool Settings: verbose:{verbose}, plot:{plot}, merge:{merge}, save:{save}, start_at:{start_at}, file_type:{file_type}, timelapse:{timelapse}')
|
1913
|
+
print()
|
1914
|
+
|
1915
|
+
count_loc = os.path.dirname(src)+'/measurements/measurements.db'
|
1916
|
+
os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
|
1917
|
+
_create_database(count_loc)
|
1918
|
+
|
1919
|
+
average_sizes = []
|
1920
|
+
time_ls = []
|
1921
|
+
moving_avg_q1 = 0
|
1922
|
+
moving_avg_q3 = 0
|
1923
|
+
moving_count = 0
|
1924
|
+
for file_index, path in enumerate(paths):
|
1925
|
+
|
1926
|
+
name = os.path.basename(path)
|
1927
|
+
name, ext = os.path.splitext(name)
|
1928
|
+
if file_type == '.npz':
|
1929
|
+
if start_at:
|
1930
|
+
print(f'starting at file index:{start_at}')
|
1931
|
+
if file_index < start_at:
|
1932
|
+
continue
|
1933
|
+
output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
|
1934
|
+
os.makedirs(output_folder, exist_ok=True)
|
1935
|
+
overall_average_size = 0
|
1936
|
+
with np.load(path) as data:
|
1937
|
+
stack = data['data']
|
1938
|
+
filenames = data['filenames']
|
1939
|
+
if timelapse:
|
1940
|
+
if len(stack) != batch_size:
|
1941
|
+
print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
|
1942
|
+
batch_size = len(stack)
|
1943
|
+
if isinstance(timelapse_frame_limits, list):
|
1944
|
+
if len(timelapse_frame_limits) >= 2:
|
1945
|
+
stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
|
1946
|
+
filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
|
1947
|
+
batch_size = len(stack)
|
1948
|
+
print(f'Cut batch at indices: {timelapse_frame_limits}, new batch_size: {batch_size}')
|
1949
|
+
|
1950
|
+
for i in range(0, stack.shape[0], batch_size):
|
1951
|
+
mask_stack = []
|
1952
|
+
start = time.time()
|
1953
|
+
|
1954
|
+
if stack.shape[3] == 1:
|
1955
|
+
batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
|
1956
|
+
else:
|
1957
|
+
batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
|
1958
|
+
|
1959
|
+
batch_filenames = filenames[i: i+batch_size].tolist()
|
1960
|
+
|
1961
|
+
if not plot:
|
1962
|
+
batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
|
1963
|
+
if batch.size == 0:
|
1964
|
+
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}', end='\r', flush=True)
|
1965
|
+
continue
|
1966
|
+
if batch.max() > 1:
|
1967
|
+
batch = batch / batch.max()
|
1968
|
+
|
1969
|
+
if timelapse:
|
1970
|
+
stitch_threshold=100.0
|
1971
|
+
movie_path = os.path.join(os.path.dirname(src), 'movies')
|
1972
|
+
os.makedirs(movie_path, exist_ok=True)
|
1973
|
+
save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
|
1974
|
+
_npz_to_movie(batch, batch_filenames, save_path, fps=2)
|
1975
|
+
else:
|
1976
|
+
stitch_threshold=0.0
|
1977
|
+
|
1978
|
+
cellpose_batch_size = _get_cellpose_batch_size()
|
1979
|
+
|
1980
|
+
#model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)  # do not overwrite the Cellpose segmentation model: DenoiseModel.eval() does not accept these kwargs or return masks/flows
|
1981
|
+
|
1982
|
+
masks, flows, _, _ = model.eval(x=batch,
|
1983
|
+
batch_size=cellpose_batch_size,
|
1984
|
+
normalize=False,
|
1985
|
+
channels=chans,
|
1986
|
+
channel_axis=3,
|
1987
|
+
diameter=diameter,
|
1988
|
+
flow_threshold=flow_threshold,
|
1989
|
+
cellprob_threshold=cellprob_threshold,
|
1990
|
+
rescale=None,
|
1991
|
+
resample=resample,
|
1992
|
+
#net_avg=net_avg,
|
1993
|
+
stitch_threshold=stitch_threshold,
|
1994
|
+
progress=None)
|
1995
|
+
print('Masks shape',masks.shape)
|
1996
|
+
if timelapse:
|
1997
|
+
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
|
1998
|
+
if object_type in timelapse_objects:
|
1999
|
+
if timelapse_mode == 'btrack':
|
2000
|
+
if not timelapse_displacement is None:
|
2001
|
+
radius = timelapse_displacement
|
2002
|
+
else:
|
2003
|
+
radius = 100
|
2004
|
+
|
2005
|
+
workers = os.cpu_count()-2
|
2006
|
+
if workers < 1:
|
2007
|
+
workers = 1
|
2008
|
+
|
2009
|
+
mask_stack = _btrack_track_cells(src, name, batch_filenames, object_type, plot, save, masks_3D=masks, mode=timelapse_mode, timelapse_remove_transient=timelapse_remove_transient, radius=radius, workers=workers)
|
2010
|
+
if timelapse_mode == 'trackpy':
|
2011
|
+
mask_stack = _trackpy_track_cells(src, name, batch_filenames, object_type, masks, timelapse_displacement, timelapse_memory, timelapse_remove_transient, plot, save, timelapse_mode)
|
2012
|
+
|
2013
|
+
else:
|
2014
|
+
mask_stack = _masks_to_masks_stack(masks)
|
2015
|
+
|
2016
|
+
else:
|
2017
|
+
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
|
2018
|
+
mask_stack = _filter_cp_masks(masks, flows, refine_masks, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize)
|
2019
|
+
_save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
|
2020
|
+
|
2021
|
+
if not np.any(mask_stack):
|
2022
|
+
average_obj_size = 0
|
2023
|
+
else:
|
2024
|
+
average_obj_size = _get_avg_object_size(mask_stack)
|
2025
|
+
|
2026
|
+
average_sizes.append(average_obj_size)
|
2027
|
+
overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
|
2028
|
+
|
2029
|
+
stop = time.time()
|
2030
|
+
duration = (stop - start)
|
2031
|
+
time_ls.append(duration)
|
2032
|
+
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
2033
|
+
time_in_min = average_time/60
|
2034
|
+
time_per_mask = average_time/batch_size
|
2035
|
+
print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
|
2036
|
+
if not timelapse:
|
2037
|
+
if plot:
|
2038
|
+
plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap=cmap, nr=batch_size, file_type='.npz')
|
2039
|
+
if save:
|
2040
|
+
if file_type == '.npz':
|
2041
|
+
for mask_index, mask in enumerate(mask_stack):
|
2042
|
+
output_filename = os.path.join(output_folder, batch_filenames[mask_index])
|
2043
|
+
np.save(output_filename, mask)
|
2044
|
+
mask_stack = []
|
2045
|
+
batch_filenames = []
|
2046
|
+
gc.collect()
|
2047
|
+
return
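# --- Illustrative sketch (not part of the original diff) ---
# Layout of the .npz files this loop expects: a 4D 'data' stack plus a parallel 'filenames' array.
# The path and the selected channels below are placeholders.
import numpy as np

path = '/data/experiment1/norm_channel_stack/stack_0.npz'  # hypothetical .npz file
with np.load(path) as data:
    stack = data['data']           # shape (n_images, height, width, n_channels)
    filenames = data['filenames']  # one filename per image
batch = stack[0:4, :, :, [0, 1]].astype(stack.dtype)  # slice a batch and pick the two Cellpose channels
if batch.max() > 1:
    batch = batch / batch.max()    # rescale intensities to [0, 1], as done above before model.eval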
|
2048
|
+
|
2049
|
+
@log_function_call
|
2050
|
+
def generate_cellpose_masks(src, settings, object_type):
|
2051
|
+
|
2052
|
+
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels
|
2053
|
+
from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
|
2054
|
+
from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
|
2055
|
+
from .plot import plot_masks
|
2056
|
+
|
2057
|
+
gc.collect()
|
2058
|
+
if not torch.cuda.is_available():
|
2059
|
+
print(f'Torch CUDA is not available, using CPU')
|
2060
|
+
|
2061
|
+
figuresize=25
|
2062
|
+
timelapse = settings['timelapse']
|
2063
|
+
|
2064
|
+
if timelapse:
|
2065
|
+
timelapse_displacement = settings['timelapse_displacement']
|
2066
|
+
timelapse_frame_limits = settings['timelapse_frame_limits']
|
2067
|
+
timelapse_memory = settings['timelapse_memory']
|
2068
|
+
timelapse_remove_transient = settings['timelapse_remove_transient']
|
2069
|
+
timelapse_mode = settings['timelapse_mode']
|
2070
|
+
timelapse_objects = settings['timelapse_objects']
|
2071
|
+
|
2072
|
+
batch_size = settings['batch_size']
|
2073
|
+
cellprob_threshold = settings[f'{object_type}_CP_prob']
|
2074
|
+
flow_threshold = 30
|
2075
|
+
|
2076
|
+
object_settings = _get_object_settings(object_type, settings)
|
2077
|
+
model_name = object_settings['model_name']
|
2078
|
+
|
2079
|
+
cellpose_channels = _get_cellpose_channels(settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
|
2080
|
+
channels = cellpose_channels[object_type]
|
2081
|
+
cellpose_batch_size = _get_cellpose_batch_size()
|
2082
|
+
|
2083
|
+
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
2084
|
+
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
|
2085
|
+
#dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
|
2086
|
+
|
2087
|
+
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
|
2088
|
+
|
2089
|
+
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
|
2090
|
+
|
2091
|
+
count_loc = os.path.dirname(src)+'/measurements/measurements.db'
|
2092
|
+
os.makedirs(os.path.dirname(src)+'/measurements', exist_ok=True)
|
2093
|
+
_create_database(count_loc)
|
2094
|
+
|
2095
|
+
average_sizes = []
|
2096
|
+
time_ls = []
|
2097
|
+
moving_avg_q1 = 0
|
2098
|
+
moving_avg_q3 = 0
|
2099
|
+
moving_count = 0
|
2100
|
+
|
2101
|
+
for file_index, path in enumerate(paths):
|
2102
|
+
name = os.path.basename(path)
|
2103
|
+
name, ext = os.path.splitext(name)
|
2104
|
+
output_folder = os.path.join(os.path.dirname(path), object_type+'_mask_stack')
|
2105
|
+
os.makedirs(output_folder, exist_ok=True)
|
2106
|
+
overall_average_size = 0
|
2107
|
+
with np.load(path) as data:
|
2108
|
+
stack = data['data']
|
2109
|
+
filenames = data['filenames']
|
2110
|
+
if settings['timelapse']:
|
2111
|
+
if len(stack) != batch_size:
|
2112
|
+
print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
|
2113
|
+
settings['batch_size'] = len(stack)
|
2114
|
+
batch_size = len(stack)
|
2115
|
+
if isinstance(timelapse_frame_limits, list):
|
2116
|
+
if len(timelapse_frame_limits) >= 2:
|
2117
|
+
stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
|
2118
|
+
filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
|
2119
|
+
batch_size = len(stack)
|
2120
|
+
print(f'Cut batch at indices: {timelapse_frame_limits}, new batch_size: {batch_size}')
|
2121
|
+
|
2122
|
+
for i in range(0, stack.shape[0], batch_size):
|
2123
|
+
mask_stack = []
|
2124
|
+
start = time.time()
|
2125
|
+
|
2126
|
+
if stack.shape[3] == 1:
|
2127
|
+
batch = stack[i: i+batch_size, :, :, [0,0]].astype(stack.dtype)
|
2128
|
+
else:
|
2129
|
+
batch = stack[i: i+batch_size, :, :, channels].astype(stack.dtype)
|
2130
|
+
|
2131
|
+
batch_filenames = filenames[i: i+batch_size].tolist()
|
2132
|
+
|
2133
|
+
if not settings['plot']:
|
2134
|
+
batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
|
2135
|
+
if batch.size == 0:
|
2136
|
+
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}', end='\r', flush=True)
|
2137
|
+
continue
|
2138
|
+
if batch.max() > 1:
|
2139
|
+
batch = batch / batch.max()
|
2140
|
+
|
2141
|
+
if timelapse:
|
2142
|
+
stitch_threshold=100.0
|
2143
|
+
movie_path = os.path.join(os.path.dirname(src), 'movies')
|
2144
|
+
os.makedirs(movie_path, exist_ok=True)
|
2145
|
+
save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
|
2146
|
+
_npz_to_movie(batch, batch_filenames, save_path, fps=2)
|
2147
|
+
else:
|
2148
|
+
stitch_threshold=0.0
|
2149
|
+
#print(batch.shape)
|
2150
|
+
#batch, _, _, _ = dn.eval(x=batch, channels=chans, diameter=object_settings['diameter'])
|
2151
|
+
#batch = np.stack((batch, batch), axis=-1)
|
2152
|
+
#print(f'object: {object_type} chans : {chans} channels : {channels} model: {model_name}')
|
2153
|
+
masks, flows, _, _ = model.eval(x=batch,
|
2154
|
+
batch_size=cellpose_batch_size,
|
2155
|
+
normalize=False,
|
2156
|
+
channels=chans,
|
2157
|
+
channel_axis=3,
|
2158
|
+
diameter=object_settings['diameter'],
|
2159
|
+
flow_threshold=flow_threshold,
|
2160
|
+
cellprob_threshold=cellprob_threshold,
|
2161
|
+
rescale=None,
|
2162
|
+
resample=object_settings['resample'],
|
2163
|
+
stitch_threshold=stitch_threshold)
|
2164
|
+
#progress=None)
|
2165
|
+
|
2166
|
+
if timelapse:
|
2167
|
+
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
|
2168
|
+
if object_type in timelapse_objects:
|
2169
|
+
if timelapse_mode == 'btrack':
|
2170
|
+
if not timelapse_displacement is None:
|
2171
|
+
radius = timelapse_displacement
|
2172
|
+
else:
|
2173
|
+
radius = 100
|
2174
|
+
|
2175
|
+
workers = os.cpu_count()-2
|
2176
|
+
if workers < 1:
|
2177
|
+
workers = 1
|
2178
|
+
|
2179
|
+
mask_stack = _btrack_track_cells(src=src,
|
2180
|
+
name=name,
|
2181
|
+
batch_filenames=batch_filenames,
|
2182
|
+
object_type=object_type,
|
2183
|
+
plot=settings['plot'],
|
2184
|
+
save=settings['save'],
|
2185
|
+
masks_3D=masks,
|
2186
|
+
mode=timelapse_mode,
|
2187
|
+
timelapse_remove_transient=timelapse_remove_transient,
|
2188
|
+
radius=radius,
|
2189
|
+
workers=workers)
|
2190
|
+
if timelapse_mode == 'trackpy':
|
2191
|
+
mask_stack = _trackpy_track_cells(src=src,
|
2192
|
+
name=name,
|
2193
|
+
batch_filenames=batch_filenames,
|
2194
|
+
object_type=object_type,
|
2195
|
+
masks_3D=masks,
|
2196
|
+
timelapse_displacement=timelapse_displacement,
|
2197
|
+
timelapse_memory=timelapse_memory,
|
2198
|
+
timelapse_remove_transient=timelapse_remove_transient,
|
2199
|
+
plot=settings['plot'],
|
2200
|
+
save=settings['save'],
|
2201
|
+
timelapse_mode=timelapse_mode)
|
2202
|
+
else:
|
2203
|
+
mask_stack = _masks_to_masks_stack(masks)
|
2204
|
+
|
2205
|
+
else:
|
2206
|
+
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
|
2207
|
+
mask_stack = _filter_cp_masks(masks=masks,
|
2208
|
+
flows=flows,
|
2209
|
+
filter_size=object_settings['filter_size'],
|
2210
|
+
minimum_size=object_settings['minimum_size'],
|
2211
|
+
maximum_size=object_settings['maximum_size'],
|
2212
|
+
remove_border_objects=object_settings['remove_border_objects'],
|
2213
|
+
merge=False,
|
2214
|
+
filter_dimm=object_settings['filter_dimm'],
|
2215
|
+
batch=batch,
|
2216
|
+
moving_avg_q1=moving_avg_q1,
|
2217
|
+
moving_avg_q3=moving_avg_q3,
|
2218
|
+
moving_count=moving_count,
|
2219
|
+
plot=settings['plot'],
|
2220
|
+
figuresize=figuresize)
|
2221
|
+
|
2222
|
+
_save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
|
2223
|
+
|
2224
|
+
if not np.any(mask_stack):
|
2225
|
+
average_obj_size = 0
|
2226
|
+
else:
|
2227
|
+
average_obj_size = _get_avg_object_size(mask_stack)
|
2228
|
+
|
2229
|
+
average_sizes.append(average_obj_size)
|
2230
|
+
overall_average_size = np.mean(average_sizes) if len(average_sizes) > 0 else 0
|
2231
|
+
|
2232
|
+
stop = time.time()
|
2233
|
+
duration = (stop - start)
|
2234
|
+
time_ls.append(duration)
|
2235
|
+
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
2236
|
+
time_in_min = average_time/60
|
2237
|
+
time_per_mask = average_time/batch_size
|
2238
|
+
print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
|
2239
|
+
if not timelapse:
|
2240
|
+
if settings['plot']:
|
2241
|
+
plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
|
2242
|
+
if settings['save']:
|
2243
|
+
for mask_index, mask in enumerate(mask_stack):
|
2244
|
+
output_filename = os.path.join(output_folder, batch_filenames[mask_index])
|
2245
|
+
np.save(output_filename, mask)
|
2246
|
+
mask_stack = []
|
2247
|
+
batch_filenames = []
|
2248
|
+
gc.collect()
|
2249
|
+
torch.cuda.empty_cache()
|
2250
|
+
return
|