spacr 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +245 -2494
- spacr/deep_spacr.py +335 -163
- spacr/gui.py +2 -0
- spacr/gui_core.py +85 -65
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +375 -7
- spacr/io.py +680 -141
- spacr/logger.py +28 -9
- spacr/measure.py +108 -133
- spacr/mediar.py +0 -3
- spacr/ml.py +1051 -0
- spacr/openai.py +37 -0
- spacr/plot.py +707 -20
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +181 -50
- spacr/sim.py +0 -2
- spacr/submodules.py +349 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +238 -0
- spacr/utils.py +776 -182
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/METADATA +31 -22
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/LICENSE +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/WHEEL +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -1,1925 +1,169 @@
|
|
1
|
-
import os,
|
2
|
-
|
1
|
+
import os, gc, torch, time, random
|
3
2
|
import numpy as np
|
4
3
|
import pandas as pd
|
5
|
-
|
6
|
-
from cellpose import train
|
7
|
-
from cellpose import models as cp_models
|
8
|
-
|
9
|
-
import statsmodels.formula.api as smf
|
10
|
-
import statsmodels.api as sm
|
11
|
-
from functools import reduce
|
12
|
-
from IPython.display import display
|
13
|
-
from multiprocessing import Pool, cpu_count, Value, Lock
|
14
|
-
|
15
|
-
import seaborn as sns
|
16
|
-
import cellpose
|
17
|
-
from skimage.measure import regionprops, label
|
18
|
-
from skimage.transform import resize as resizescikit
|
19
|
-
|
20
|
-
from skimage import measure
|
21
|
-
from sklearn.model_selection import train_test_split
|
22
|
-
from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
|
23
|
-
from sklearn.linear_model import LogisticRegression
|
24
|
-
from sklearn.inspection import permutation_importance
|
25
|
-
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
|
26
|
-
from sklearn.preprocessing import StandardScaler
|
27
|
-
from sklearn.metrics import precision_recall_curve, f1_score
|
28
|
-
|
29
|
-
from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
|
30
|
-
|
31
|
-
import torchvision.transforms as transforms
|
32
|
-
from xgboost import XGBClassifier
|
33
|
-
import shap
|
34
|
-
|
35
4
|
import matplotlib.pyplot as plt
|
36
|
-
import
|
37
|
-
matplotlib.use('Agg')
|
38
|
-
|
39
|
-
from .logger import log_function_call
|
5
|
+
from IPython.display import display
|
40
6
|
|
41
7
|
import warnings
|
42
8
|
warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
|
43
9
|
|
44
|
-
|
45
|
-
from torchvision import transforms
|
46
|
-
from torch.utils.data import DataLoader, random_split
|
47
|
-
from collections import defaultdict
|
48
|
-
import os
|
49
|
-
import random
|
50
|
-
from PIL import Image
|
51
|
-
from torchvision.transforms import ToTensor
|
52
|
-
|
53
|
-
def analyze_plaques(folder):
|
54
|
-
summary_data = []
|
55
|
-
details_data = []
|
56
|
-
stats_data = []
|
57
|
-
|
58
|
-
for filename in os.listdir(folder):
|
59
|
-
filepath = os.path.join(folder, filename)
|
60
|
-
if os.path.isfile(filepath):
|
61
|
-
# Assuming each file is a NumPy array file (.npy) containing a 16-bit labeled image
|
62
|
-
#image = np.load(filepath)
|
63
|
-
image = cellpose.io.imread(filepath)
|
64
|
-
labeled_image = label(image)
|
65
|
-
regions = regionprops(labeled_image)
|
66
|
-
|
67
|
-
object_count = len(regions)
|
68
|
-
sizes = [region.area for region in regions]
|
69
|
-
average_size = np.mean(sizes) if sizes else 0
|
70
|
-
std_dev_size = np.std(sizes) if sizes else 0
|
71
|
-
|
72
|
-
summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
|
73
|
-
stats_data.append({'file': filename, 'plaque_count': object_count, 'average_size': average_size, 'std_dev_size': std_dev_size})
|
74
|
-
for size in sizes:
|
75
|
-
details_data.append({'file': filename, 'plaque_size': size})
|
76
|
-
|
77
|
-
# Convert lists to pandas DataFrames
|
78
|
-
summary_df = pd.DataFrame(summary_data)
|
79
|
-
details_df = pd.DataFrame(details_data)
|
80
|
-
stats_df = pd.DataFrame(stats_data)
|
81
|
-
|
82
|
-
# Save DataFrames to a SQLite database
|
83
|
-
db_name = os.path.join(folder, 'plaques_analysis.db')
|
84
|
-
conn = sqlite3.connect(db_name)
|
85
|
-
|
86
|
-
summary_df.to_sql('summary', conn, if_exists='replace', index=False)
|
87
|
-
details_df.to_sql('details', conn, if_exists='replace', index=False)
|
88
|
-
stats_df.to_sql('stats', conn, if_exists='replace', index=False)
|
89
|
-
|
90
|
-
conn.close()
|
91
|
-
|
92
|
-
print(f"Analysis completed and saved to database '{db_name}'.")
|
93
|
-
|
94
|
-
def train_cellpose(settings):
|
95
|
-
|
96
|
-
from .io import _load_normalized_images_and_labels, _load_images_and_labels
|
97
|
-
from .settings import get_train_cellpose_default_settings#, resize_images_and_labels
|
98
|
-
|
99
|
-
settings = get_train_cellpose_default_settings()
|
100
|
-
|
101
|
-
img_src = settings['img_src']
|
102
|
-
mask_src = os.path.join(img_src, 'masks')
|
103
|
-
|
104
|
-
model_name = settings.setdefault( 'model_name', '')
|
105
|
-
|
106
|
-
model_name = settings.setdefault('model_name', 'model_name')
|
107
|
-
|
108
|
-
model_type = settings.setdefault( 'model_type', 'cyto')
|
109
|
-
learning_rate = settings.setdefault( 'learning_rate', 0.01)
|
110
|
-
weight_decay = settings.setdefault( 'weight_decay', 1e-05)
|
111
|
-
batch_size = settings.setdefault( 'batch_size', 50)
|
112
|
-
n_epochs = settings.setdefault( 'n_epochs', 100)
|
113
|
-
from_scratch = settings.setdefault( 'from_scratch', False)
|
114
|
-
diameter = settings.setdefault( 'diameter', 40)
|
115
|
-
|
116
|
-
remove_background = settings.setdefault( 'remove_background', False)
|
117
|
-
background = settings.setdefault( 'background', 100)
|
118
|
-
Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
|
119
|
-
verbose = settings.setdefault( 'verbose', False)
|
120
|
-
|
121
|
-
channels = settings.setdefault( 'channels', [0,0])
|
122
|
-
normalize = settings.setdefault( 'normalize', True)
|
123
|
-
percentiles = settings.setdefault( 'percentiles', None)
|
124
|
-
circular = settings.setdefault( 'circular', False)
|
125
|
-
invert = settings.setdefault( 'invert', False)
|
126
|
-
resize = settings.setdefault( 'resize', False)
|
127
|
-
|
128
|
-
if resize:
|
129
|
-
target_height = settings['width_height'][1]
|
130
|
-
target_width = settings['width_height'][0]
|
131
|
-
|
132
|
-
grayscale = settings.setdefault( 'grayscale', True)
|
133
|
-
rescale = settings.setdefault( 'channels', False)
|
134
|
-
test = settings.setdefault( 'test', False)
|
135
|
-
|
136
|
-
if test:
|
137
|
-
test_img_src = os.path.join(os.path.dirname(img_src), 'test')
|
138
|
-
test_mask_src = os.path.join(test_img_src, 'mask')
|
139
|
-
|
140
|
-
test_images, test_masks, test_image_names, test_mask_names = None,None,None,None
|
141
|
-
print(settings)
|
142
|
-
|
143
|
-
if from_scratch:
|
144
|
-
model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
|
145
|
-
else:
|
146
|
-
if resize:
|
147
|
-
model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
|
148
|
-
else:
|
149
|
-
model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
|
150
|
-
|
151
|
-
model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
|
152
|
-
print(model_save_path)
|
153
|
-
os.makedirs(model_save_path, exist_ok=True)
|
154
|
-
|
155
|
-
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
156
|
-
settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
|
157
|
-
settings_df.to_csv(settings_csv, index=False)
|
158
|
-
|
159
|
-
if from_scratch:
|
160
|
-
model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
|
161
|
-
else:
|
162
|
-
model = cp_models.CellposeModel(gpu=True, model_type=model_type)
|
163
|
-
|
164
|
-
if normalize:
|
165
|
-
|
166
|
-
image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
|
167
|
-
label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
|
168
|
-
images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
|
169
|
-
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
170
|
-
|
171
|
-
if test:
|
172
|
-
test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
|
173
|
-
test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
|
174
|
-
test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
|
175
|
-
test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
|
176
|
-
|
177
|
-
else:
|
178
|
-
images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
|
179
|
-
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
180
|
-
|
181
|
-
if test:
|
182
|
-
test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
|
183
|
-
test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
|
184
|
-
|
185
|
-
#if resize:
|
186
|
-
# images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
|
187
|
-
|
188
|
-
if model_type == 'cyto':
|
189
|
-
cp_channels = [0,1]
|
190
|
-
if model_type == 'cyto2':
|
191
|
-
cp_channels = [0,2]
|
192
|
-
if model_type == 'nucleus':
|
193
|
-
cp_channels = [0,0]
|
194
|
-
if grayscale:
|
195
|
-
cp_channels = [0,0]
|
196
|
-
images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
|
197
|
-
|
198
|
-
masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
|
199
|
-
|
200
|
-
print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
|
201
|
-
save_every = int(n_epochs/10)
|
202
|
-
if save_every < 10:
|
203
|
-
save_every = n_epochs
|
204
|
-
|
205
|
-
train.train_seg(model.net,
|
206
|
-
train_data=images,
|
207
|
-
train_labels=masks,
|
208
|
-
train_files=image_names,
|
209
|
-
train_labels_files=mask_names,
|
210
|
-
train_probs=None,
|
211
|
-
test_data=test_images,
|
212
|
-
test_labels=test_masks,
|
213
|
-
test_files=test_image_names,
|
214
|
-
test_labels_files=test_mask_names,
|
215
|
-
test_probs=None,
|
216
|
-
load_files=True,
|
217
|
-
batch_size=batch_size,
|
218
|
-
learning_rate=learning_rate,
|
219
|
-
n_epochs=n_epochs,
|
220
|
-
weight_decay=weight_decay,
|
221
|
-
momentum=0.9,
|
222
|
-
SGD=False,
|
223
|
-
channels=cp_channels,
|
224
|
-
channel_axis=None,
|
225
|
-
#rgb=False,
|
226
|
-
normalize=False,
|
227
|
-
compute_flows=False,
|
228
|
-
save_path=model_save_path,
|
229
|
-
save_every=save_every,
|
230
|
-
nimg_per_epoch=None,
|
231
|
-
nimg_test_per_epoch=None,
|
232
|
-
rescale=rescale,
|
233
|
-
#scale_range=None,
|
234
|
-
#bsize=224,
|
235
|
-
min_train_masks=1,
|
236
|
-
model_name=model_name)
|
237
|
-
|
238
|
-
return print(f"Model saved at: {model_save_path}/{model_name}")
|
239
|
-
|
240
|
-
def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
|
241
|
-
|
242
|
-
from .plot import _reg_v_plot
|
243
|
-
from .utils import generate_fraction_map, MLR, fishers_odds, lasso_reg
|
244
|
-
|
245
|
-
def qstring_to_float(qstr):
|
246
|
-
number = int(qstr[1:]) # Remove the "q" and convert the rest to an integer
|
247
|
-
return number / 100.0
|
248
|
-
|
249
|
-
columns_list = ['c1', 'c2', 'c3']
|
250
|
-
plate_list = ['p1','p3','p4']
|
251
|
-
|
252
|
-
dv_df = pd.read_csv(dv_loc)#, index_col='prc')
|
253
|
-
|
254
|
-
if agg_type.startswith('q'):
|
255
|
-
val = qstring_to_float(agg_type)
|
256
|
-
agg_type = lambda x: x.quantile(val)
|
257
|
-
|
258
|
-
# Aggregating for mean prediction, total count and count of values > 0.95
|
259
|
-
dv_df = dv_df.groupby('prc').agg(
|
260
|
-
pred=(dv_col, agg_type),
|
261
|
-
count_prc=('prc', 'size'),
|
262
|
-
mean_pathogen_area=('pathogen_area', 'mean')
|
263
|
-
)
|
264
|
-
|
265
|
-
dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
|
266
|
-
sequencing_df = pd.read_csv(sequencing_loc)
|
267
|
-
|
268
|
-
|
269
|
-
reads_df, stats_dict = process_reads(df=sequencing_df,
|
270
|
-
min_reads=min_reads,
|
271
|
-
min_wells=min_wells,
|
272
|
-
max_wells=max_wells,
|
273
|
-
gene_column='gene',
|
274
|
-
remove_outliers=remove_outlier_genes)
|
275
|
-
|
276
|
-
reads_df['value'] = reads_df['count']/reads_df['well_read_sum']
|
277
|
-
reads_df['gene_grna'] = reads_df['gene']+'_'+reads_df['grna']
|
278
|
-
|
279
|
-
display(reads_df)
|
280
|
-
|
281
|
-
df_long = reads_df
|
282
|
-
|
283
|
-
df_long = df_long[df_long['value'] > min_frequency] # removes gRNAs under a certain proportion
|
284
|
-
#df_long = df_long[df_long['value']<1.0] # removes gRNAs in wells with only one gRNA
|
285
|
-
|
286
|
-
# Extract gene and grna info from gene_grna column
|
287
|
-
df_long["gene"] = df_long["grna"].str.split("_").str[1]
|
288
|
-
df_long["grna"] = df_long["grna"].str.split("_").str[2]
|
289
|
-
|
290
|
-
agg_df = df_long.groupby('prc')['count'].sum().reset_index()
|
291
|
-
agg_df = agg_df.rename(columns={'count': 'count_sum'})
|
292
|
-
df_long = pd.merge(df_long, agg_df, on='prc', how='left')
|
293
|
-
df_long['value'] = df_long['count']/df_long['count_sum']
|
294
|
-
|
295
|
-
merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
|
296
|
-
merged_df = merged_df[merged_df['value'] > 0]
|
297
|
-
merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
|
298
|
-
merged_df['row'] = merged_df['prc'].str.split('_').str[1]
|
299
|
-
merged_df['column'] = merged_df['prc'].str.split('_').str[2]
|
300
|
-
|
301
|
-
merged_df = merged_df[~merged_df['column'].isin(columns_list)]
|
302
|
-
merged_df = merged_df[merged_df['plate'].isin(plate_list)]
|
303
|
-
|
304
|
-
if transform == 'log':
|
305
|
-
merged_df['pred'] = np.log(merged_df['pred'] + 1e-10)
|
306
|
-
|
307
|
-
# Printing the unique values in 'col' and 'plate' columns
|
308
|
-
print("Unique values in col:", merged_df['column'].unique())
|
309
|
-
print("Unique values in plate:", merged_df['plate'].unique())
|
310
|
-
display(merged_df)
|
311
|
-
|
312
|
-
if fishers:
|
313
|
-
iv_df = generate_fraction_map(df=reads_df,
|
314
|
-
gene_column='grna',
|
315
|
-
min_frequency=min_frequency)
|
316
|
-
|
317
|
-
fishers_df = iv_df.join(dv_df, on='prc', how='inner')
|
318
|
-
|
319
|
-
significant_mutants = fishers_odds(df=fishers_df, threshold=fisher_threshold, phenotyp_col='pred')
|
320
|
-
significant_mutants = significant_mutants.sort_values(by='OddsRatio', ascending=False)
|
321
|
-
display(significant_mutants)
|
322
|
-
|
323
|
-
if regression_type == 'mlr':
|
324
|
-
if by_plate:
|
325
|
-
merged_df2 = merged_df.copy()
|
326
|
-
for plate in merged_df2['plate'].unique():
|
327
|
-
merged_df = merged_df2[merged_df2['plate'] == plate]
|
328
|
-
print(f'merged_df: {len(merged_df)}, plate: {plate}')
|
329
|
-
if len(merged_df) <100:
|
330
|
-
break
|
331
|
-
|
332
|
-
max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
|
333
|
-
else:
|
334
|
-
|
335
|
-
max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
|
336
|
-
return max_effects, max_effects_pvalues, model, df
|
337
|
-
|
338
|
-
if regression_type == 'ridge' or regression_type == 'lasso':
|
339
|
-
coeffs = lasso_reg(merged_df, alpha_value=alpha_value, reg_type=regression_type)
|
340
|
-
return coeffs
|
341
|
-
|
342
|
-
if regression_type == 'mixed':
|
343
|
-
model = smf.mixedlm("pred ~ gene_grna - 1", merged_df, groups=merged_df["plate"], re_formula="~1")
|
344
|
-
result = model.fit(method="bfgs")
|
345
|
-
print(result.summary())
|
346
|
-
|
347
|
-
# Print AIC and BIC
|
348
|
-
print("AIC:", result.aic)
|
349
|
-
print("BIC:", result.bic)
|
350
|
-
|
351
|
-
|
352
|
-
results_df = pd.DataFrame({
|
353
|
-
'effect': result.params,
|
354
|
-
'Standard Error': result.bse,
|
355
|
-
'T-Value': result.tvalues,
|
356
|
-
'p': result.pvalues
|
357
|
-
})
|
358
|
-
|
359
|
-
display(results_df)
|
360
|
-
_reg_v_plot(df=results_df)
|
361
|
-
|
362
|
-
std_resid = result.resid
|
363
|
-
|
364
|
-
# Create subplots
|
365
|
-
fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
366
|
-
|
367
|
-
# Histogram of Residuals
|
368
|
-
axes[0].hist(std_resid, bins=50, edgecolor='k')
|
369
|
-
axes[0].set_xlabel('Residuals')
|
370
|
-
axes[0].set_ylabel('Frequency')
|
371
|
-
axes[0].set_title('Histogram of Residuals')
|
372
|
-
|
373
|
-
# Boxplot of Residuals
|
374
|
-
axes[1].boxplot(std_resid)
|
375
|
-
axes[1].set_ylabel('Residuals')
|
376
|
-
axes[1].set_title('Boxplot of Residuals')
|
377
|
-
|
378
|
-
# QQ Plot
|
379
|
-
sm.qqplot(std_resid, line='45', ax=axes[2])
|
380
|
-
axes[2].set_title('QQ Plot')
|
381
|
-
|
382
|
-
# Show plots
|
383
|
-
plt.tight_layout()
|
384
|
-
plt.show()
|
385
|
-
|
386
|
-
return result
|
387
|
-
|
388
|
-
def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
|
389
|
-
|
390
|
-
from .plot import _reg_v_plot
|
391
|
-
from .utils import generate_fraction_map, fishers_odds, model_metrics
|
392
|
-
|
393
|
-
def qstring_to_float(qstr):
|
394
|
-
number = int(qstr[1:]) # Remove the "q" and convert the rest to an integer
|
395
|
-
return number / 100.0
|
396
|
-
|
397
|
-
columns_list = ['c1', 'c2', 'c3', 'c15']
|
398
|
-
plate_list = ['p1','p2','p3','p4']
|
399
|
-
|
400
|
-
dv_df = pd.read_csv(dv_loc)#, index_col='prc')
|
401
|
-
|
402
|
-
if agg_type.startswith('q'):
|
403
|
-
val = qstring_to_float(agg_type)
|
404
|
-
agg_type = lambda x: x.quantile(val)
|
405
|
-
|
406
|
-
# Aggregating for mean prediction, total count and count of values > 0.95
|
407
|
-
dv_df = dv_df.groupby('prc').agg(
|
408
|
-
pred=('pred', agg_type),
|
409
|
-
count_prc=('prc', 'size'),
|
410
|
-
#count_above_95=('pred', lambda x: (x > 0.95).sum()),
|
411
|
-
mean_pathogen_area=('pathogen_area', 'mean')
|
412
|
-
)
|
413
|
-
|
414
|
-
dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
|
415
|
-
sequencing_df = pd.read_csv(sequencing_loc)
|
416
|
-
|
417
|
-
reads_df, stats_dict = process_reads(df=sequencing_df,
|
418
|
-
min_reads=min_reads,
|
419
|
-
min_wells=min_wells,
|
420
|
-
max_wells=max_wells,
|
421
|
-
gene_column='gene',
|
422
|
-
remove_outliers=remove_outlier_genes)
|
423
|
-
|
424
|
-
iv_df = generate_fraction_map(df=reads_df,
|
425
|
-
gene_column='grna',
|
426
|
-
min_frequency=0.0)
|
427
|
-
|
428
|
-
# Melt the iv_df to long format
|
429
|
-
df_long = iv_df.reset_index().melt(id_vars=["prc"],
|
430
|
-
value_vars=iv_df.columns,
|
431
|
-
var_name="gene_grna",
|
432
|
-
value_name="value")
|
433
|
-
|
434
|
-
# Extract gene and grna info from gene_grna column
|
435
|
-
df_long["gene"] = df_long["gene_grna"].str.split("_").str[1]
|
436
|
-
df_long["grna"] = df_long["gene_grna"].str.split("_").str[2]
|
437
|
-
|
438
|
-
merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
|
439
|
-
merged_df = merged_df[merged_df['value'] > 0]
|
440
|
-
merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
|
441
|
-
merged_df['row'] = merged_df['prc'].str.split('_').str[1]
|
442
|
-
merged_df['column'] = merged_df['prc'].str.split('_').str[2]
|
443
|
-
|
444
|
-
merged_df = merged_df[~merged_df['column'].isin(columns_list)]
|
445
|
-
merged_df = merged_df[merged_df['plate'].isin(plate_list)]
|
446
|
-
|
447
|
-
# Printing the unique values in 'col' and 'plate' columns
|
448
|
-
print("Unique values in col:", merged_df['column'].unique())
|
449
|
-
print("Unique values in plate:", merged_df['plate'].unique())
|
450
|
-
|
451
|
-
if not by_plate:
|
452
|
-
if fishers:
|
453
|
-
fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
|
454
|
-
|
455
|
-
if by_plate:
|
456
|
-
merged_df2 = merged_df.copy()
|
457
|
-
for plate in merged_df2['plate'].unique():
|
458
|
-
merged_df = merged_df2[merged_df2['plate'] == plate]
|
459
|
-
print(f'merged_df: {len(merged_df)}, plate: {plate}')
|
460
|
-
if len(merged_df) <100:
|
461
|
-
break
|
462
|
-
display(merged_df)
|
463
|
-
|
464
|
-
model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
|
465
|
-
#model = smf.ols("pred ~ infection_time + gene + grna + gene:grna + plate + row + column", merged_df).fit()
|
466
|
-
|
467
|
-
# Display model metrics and summary
|
468
|
-
model_metrics(model)
|
469
|
-
#print(model.summary())
|
470
|
-
|
471
|
-
if refine_model:
|
472
|
-
# Filter outliers
|
473
|
-
std_resid = model.get_influence().resid_studentized_internal
|
474
|
-
outliers_resid = np.where(np.abs(std_resid) > 3)[0]
|
475
|
-
(c, p) = model.get_influence().cooks_distance
|
476
|
-
outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
|
477
|
-
outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
|
478
|
-
merged_df_filtered = merged_df.drop(merged_df.index[outliers])
|
479
|
-
|
480
|
-
display(merged_df_filtered)
|
481
|
-
|
482
|
-
# Refit the model with filtered data
|
483
|
-
model = smf.ols("pred ~ gene + grna + gene:grna + row + column", merged_df_filtered).fit()
|
484
|
-
print("Number of outliers detected by standardized residuals:", len(outliers_resid))
|
485
|
-
print("Number of outliers detected by Cook's distance:", len(outliers_cooks))
|
486
|
-
|
487
|
-
model_metrics(model)
|
488
|
-
|
489
|
-
# Extract interaction coefficients and determine the maximum effect size
|
490
|
-
interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
|
491
|
-
interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}
|
492
|
-
|
493
|
-
max_effects = {}
|
494
|
-
max_effects_pvalues = {}
|
495
|
-
for key, val in interaction_coeffs.items():
|
496
|
-
gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
|
497
|
-
if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
|
498
|
-
max_effects[gene_name] = val
|
499
|
-
max_effects_pvalues[gene_name] = interaction_pvalues[key]
|
500
|
-
|
501
|
-
for key in max_effects:
|
502
|
-
print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")
|
503
|
-
|
504
|
-
df = pd.DataFrame([max_effects, max_effects_pvalues])
|
505
|
-
df = df.transpose()
|
506
|
-
df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
|
507
|
-
df = df.sort_values(by=['effect', 'p'], ascending=[False, True])
|
508
|
-
|
509
|
-
_reg_v_plot(df)
|
510
|
-
|
511
|
-
if fishers:
|
512
|
-
fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
|
513
|
-
else:
|
514
|
-
display(merged_df)
|
515
|
-
|
516
|
-
model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
|
517
|
-
|
518
|
-
# Display model metrics and summary
|
519
|
-
model_metrics(model)
|
520
|
-
|
521
|
-
if refine_model:
|
522
|
-
# Filter outliers
|
523
|
-
std_resid = model.get_influence().resid_studentized_internal
|
524
|
-
outliers_resid = np.where(np.abs(std_resid) > 3)[0]
|
525
|
-
(c, p) = model.get_influence().cooks_distance
|
526
|
-
outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
|
527
|
-
outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
|
528
|
-
merged_df_filtered = merged_df.drop(merged_df.index[outliers])
|
529
|
-
|
530
|
-
display(merged_df_filtered)
|
531
|
-
|
532
|
-
# Refit the model with filtered data
|
533
|
-
model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df_filtered).fit()
|
534
|
-
print("Number of outliers detected by standardized residuals:", len(outliers_resid))
|
535
|
-
print("Number of outliers detected by Cook's distance:", len(outliers_cooks))
|
536
|
-
|
537
|
-
model_metrics(model)
|
538
|
-
|
539
|
-
# Extract interaction coefficients and determine the maximum effect size
|
540
|
-
interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
|
541
|
-
interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}
|
542
|
-
|
543
|
-
max_effects = {}
|
544
|
-
max_effects_pvalues = {}
|
545
|
-
for key, val in interaction_coeffs.items():
|
546
|
-
gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
|
547
|
-
if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
|
548
|
-
max_effects[gene_name] = val
|
549
|
-
max_effects_pvalues[gene_name] = interaction_pvalues[key]
|
550
|
-
|
551
|
-
for key in max_effects:
|
552
|
-
print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")
|
553
|
-
|
554
|
-
df = pd.DataFrame([max_effects, max_effects_pvalues])
|
555
|
-
df = df.transpose()
|
556
|
-
df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
|
557
|
-
df = df.sort_values(by=['effect', 'p'], ascending=[False, True])
|
558
|
-
|
559
|
-
_reg_v_plot(df)
|
560
|
-
|
561
|
-
if fishers:
|
562
|
-
fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
|
563
|
-
|
564
|
-
return max_effects, max_effects_pvalues, model, df
|
565
|
-
|
566
|
-
def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
|
567
|
-
|
568
|
-
from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
|
569
|
-
|
570
|
-
sequencing_df = pd.read_csv(sequencing_loc)
|
571
|
-
columns_list = ['c1','c2','c3', 'c15']
|
572
|
-
sequencing_df = sequencing_df[~sequencing_df['col'].isin(columns_list)]
|
573
|
-
|
574
|
-
reads_df, stats_dict = process_reads(df=sequencing_df,
|
575
|
-
min_reads=min_reads,
|
576
|
-
min_wells=min_wells,
|
577
|
-
max_wells=max_wells,
|
578
|
-
gene_column='gene')
|
579
|
-
|
580
|
-
display(reads_df)
|
581
|
-
|
582
|
-
iv_df = generate_fraction_map(df=reads_df,
|
583
|
-
gene_column=gene_column,
|
584
|
-
min_frequency=min_frequency)
|
585
|
-
|
586
|
-
display(iv_df)
|
587
|
-
|
588
|
-
dv_df = dv_df[dv_df['count_prc']>min_cells]
|
589
|
-
display(dv_df)
|
590
|
-
merged_df = iv_df.join(dv_df, on='prc', how='inner')
|
591
|
-
display(merged_df)
|
592
|
-
fisher_df = merged_df.copy()
|
593
|
-
|
594
|
-
merged_df.reset_index(inplace=True)
|
595
|
-
merged_df[['plate', 'row', 'col']] = merged_df['prc'].str.split('_', expand=True)
|
596
|
-
merged_df = merged_df.drop(columns=['prc'])
|
597
|
-
merged_df.dropna(inplace=True)
|
598
|
-
merged_df = pd.get_dummies(merged_df, columns=['plate', 'row', 'col'], drop_first=True)
|
599
|
-
|
600
|
-
y = merged_df['mean_pred']
|
601
|
-
|
602
|
-
if model_type == 'mlr':
|
603
|
-
merged_df = merged_df.drop(columns=['count_prc'])
|
604
|
-
|
605
|
-
elif model_type == 'wls':
|
606
|
-
weights = merged_df['count_prc']
|
607
|
-
|
608
|
-
elif model_type == 'glm':
|
609
|
-
merged_df = merged_df.drop(columns=['count_prc'])
|
610
|
-
|
611
|
-
if transform == 'logit':
|
612
|
-
# logit transformation
|
613
|
-
epsilon = 1e-15
|
614
|
-
y = np.log(y + epsilon) - np.log(1 - y + epsilon)
|
615
|
-
|
616
|
-
elif transform == 'log':
|
617
|
-
# log transformation
|
618
|
-
y = np.log10(y+1)
|
619
|
-
|
620
|
-
elif transform == 'center':
|
621
|
-
# Centering the y around 0
|
622
|
-
y_mean = y.mean()
|
623
|
-
y = y - y_mean
|
624
|
-
|
625
|
-
x = merged_df.drop('mean_pred', axis=1)
|
626
|
-
x = x.select_dtypes(include=[np.number])
|
627
|
-
#x = sm.add_constant(x)
|
628
|
-
x['const'] = 0.0
|
629
|
-
|
630
|
-
if model_type == 'mlr':
|
631
|
-
model = sm.OLS(y, x).fit()
|
632
|
-
model_metrics(model)
|
633
|
-
|
634
|
-
# Check for Multicollinearity
|
635
|
-
vif_data = check_multicollinearity(x.drop('const', axis=1)) # assuming you've added a constant to x
|
636
|
-
high_vif_columns = vif_data[vif_data["VIF"] > VIF_threshold]["Variable"].values # VIF threshold of 10 is common, but this can vary based on context
|
637
|
-
|
638
|
-
print(f"Columns with high VIF: {high_vif_columns}")
|
639
|
-
x = x.drop(columns=high_vif_columns) # dropping columns with high VIF
|
640
|
-
|
641
|
-
if clean_regression:
|
642
|
-
# 1. Filter by standardized residuals
|
643
|
-
std_resid = model.get_influence().resid_studentized_internal
|
644
|
-
outliers_resid = np.where(np.abs(std_resid) > 3)[0]
|
645
|
-
|
646
|
-
# 2. Filter by leverage
|
647
|
-
influence = model.get_influence().hat_matrix_diag
|
648
|
-
outliers_lev = np.where(influence > 2*(x.shape[1])/len(y))[0]
|
649
|
-
|
650
|
-
# 3. Filter by Cook's distance
|
651
|
-
(c, p) = model.get_influence().cooks_distance
|
652
|
-
outliers_cooks = np.where(c > 4/(len(y)-x.shape[1]-1))[0]
|
653
|
-
|
654
|
-
# Combine all identified outliers
|
655
|
-
outliers = reduce(np.union1d, (outliers_resid, outliers_lev, outliers_cooks))
|
656
|
-
|
657
|
-
# Filter out outliers
|
658
|
-
x_clean = x.drop(x.index[outliers])
|
659
|
-
y_clean = y.drop(y.index[outliers])
|
660
|
-
|
661
|
-
# Re-run the regression with the filtered data
|
662
|
-
model = sm.OLS(y_clean, x_clean).fit()
|
663
|
-
model_metrics(model)
|
664
|
-
|
665
|
-
elif model_type == 'wls':
|
666
|
-
model = sm.WLS(y, x, weights=weights).fit()
|
667
|
-
|
668
|
-
elif model_type == 'glm':
|
669
|
-
model = sm.GLM(y, x, family=sm.families.Binomial()).fit()
|
670
|
-
|
671
|
-
print(model.summary())
|
672
|
-
|
673
|
-
results_summary = model.summary()
|
674
|
-
|
675
|
-
results_as_html = results_summary.tables[1].as_html()
|
676
|
-
results_df = pd.read_html(results_as_html, header=0, index_col=0)[0]
|
677
|
-
results_df = results_df.sort_values(by='coef', ascending=False)
|
678
|
-
|
679
|
-
if model_type == 'mlr':
|
680
|
-
results_df['p'] = results_df['P>|t|']
|
681
|
-
elif model_type == 'wls':
|
682
|
-
results_df['p'] = results_df['P>|t|']
|
683
|
-
elif model_type == 'glm':
|
684
|
-
results_df['p'] = results_df['P>|z|']
|
685
|
-
|
686
|
-
results_df['type'] = 1
|
687
|
-
results_df.loc[results_df['p'] == 0.000, 'p'] = 0.005
|
688
|
-
results_df['-log10(p)'] = -np.log10(results_df['p'])
|
689
|
-
|
690
|
-
display(results_df)
|
691
|
-
|
692
|
-
# Create subplots
|
693
|
-
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 15))
|
694
|
-
|
695
|
-
# Plot histogram on ax1
|
696
|
-
sns.histplot(data=y, kde=False, element="step", ax=ax1, color='teal')
|
697
|
-
ax1.set_xlim([0, 1])
|
698
|
-
ax1.spines['top'].set_visible(False)
|
699
|
-
ax1.spines['right'].set_visible(False)
|
700
|
-
|
701
|
-
# Prepare data for volcano plot on ax2
|
702
|
-
results_df['-log10(p)'] = -np.log10(results_df['p'])
|
703
|
-
|
704
|
-
# Assuming the 'type' column is in the merged_df
|
705
|
-
sc = ax2.scatter(results_df['coef'], results_df['-log10(p)'], c=results_df['type'], cmap='coolwarm')
|
706
|
-
ax2.set_title('Volcano Plot')
|
707
|
-
ax2.set_xlabel('Coefficient')
|
708
|
-
ax2.set_ylabel('-log10(P-value)')
|
709
|
-
|
710
|
-
# Adjust colorbar
|
711
|
-
cbar = plt.colorbar(sc, ax=ax2, ticks=[-1, 1])
|
712
|
-
cbar.set_label('Sign of Coefficient')
|
713
|
-
cbar.set_ticklabels(['-ve', '+ve'])
|
714
|
-
|
715
|
-
# Add text for specified points
|
716
|
-
for idx, row in results_df.iterrows():
|
717
|
-
if row['p'] < 0.05 and row['coef'] > effect_size_threshold:
|
718
|
-
ax2.text(row['coef'], -np.log10(row['p']), idx, fontsize=8, ha='center', va='bottom', color='black')
|
719
|
-
|
720
|
-
ax2.axhline(y=-np.log10(0.05), color='gray', linestyle='--')
|
721
|
-
|
722
|
-
plt.show()
|
723
|
-
|
724
|
-
#if model_type == 'mlr':
|
725
|
-
# show_residules(model)
|
726
|
-
|
727
|
-
if fishers:
|
728
|
-
threshold = 2*effect_size_threshold
|
729
|
-
fishers_odds(df=fisher_df, threshold=threshold, phenotyp_col='mean_pred')
|
730
|
-
|
731
|
-
return
|
732
|
-
|
733
|
-
def merge_pred_mes(src,
|
734
|
-
pred_loc,
|
735
|
-
target='protein of interest',
|
736
|
-
cell_dim=4,
|
737
|
-
nucleus_dim=5,
|
738
|
-
pathogen_dim=6,
|
739
|
-
channel_of_interest=1,
|
740
|
-
pathogen_size_min=0,
|
741
|
-
nucleus_size_min=0,
|
742
|
-
cell_size_min=0,
|
743
|
-
pathogen_min=0,
|
744
|
-
nucleus_min=0,
|
745
|
-
cell_min=0,
|
746
|
-
target_min=0,
|
747
|
-
mask_chans=[0,1,2],
|
748
|
-
filter_data=False,
|
749
|
-
include_noninfected=False,
|
750
|
-
include_multiinfected=False,
|
751
|
-
include_multinucleated=False,
|
752
|
-
cells_per_well=10,
|
753
|
-
save_filtered_filelist=False,
|
754
|
-
verbose=False):
|
755
|
-
|
756
|
-
from .io import _read_and_merge_data
|
757
|
-
from .plot import _plot_histograms_and_stats
|
758
|
-
|
759
|
-
mask_chans=[cell_dim,nucleus_dim,pathogen_dim]
|
760
|
-
sns.color_palette("mako", as_cmap=True)
|
761
|
-
print(f'channel:{channel_of_interest} = {target}')
|
762
|
-
overlay_channels = [0, 1, 2, 3]
|
763
|
-
overlay_channels.remove(channel_of_interest)
|
764
|
-
overlay_channels.reverse()
|
765
|
-
|
766
|
-
db_loc = [src+'/measurements/measurements.db']
|
767
|
-
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
768
|
-
df, object_dfs = _read_and_merge_data(db_loc,
|
769
|
-
tables,
|
770
|
-
verbose=True,
|
771
|
-
include_multinucleated=include_multinucleated,
|
772
|
-
include_multiinfected=include_multiinfected,
|
773
|
-
include_noninfected=include_noninfected)
|
774
|
-
if filter_data:
|
775
|
-
df = df[df['cell_area'] > cell_size_min]
|
776
|
-
df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
|
777
|
-
print(f'After cell filtration {len(df)}')
|
778
|
-
df = df[df['nucleus_area'] > nucleus_size_min]
|
779
|
-
df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
|
780
|
-
print(f'After nucleus filtration {len(df)}')
|
781
|
-
df = df[df['pathogen_area'] > pathogen_size_min]
|
782
|
-
df=df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
|
783
|
-
print(f'After pathogen filtration {len(df)}')
|
784
|
-
df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
|
785
|
-
print(f'After channel {channel_of_interest} filtration', len(df))
|
786
|
-
|
787
|
-
df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
|
788
|
-
|
789
|
-
pred_df = annotate_results(pred_loc=pred_loc)
|
790
|
-
|
791
|
-
if verbose:
|
792
|
-
_plot_histograms_and_stats(df=pred_df)
|
793
|
-
|
794
|
-
pred_df.set_index('prcfo', inplace=True)
|
795
|
-
pred_df = pred_df.drop(columns=['plate', 'row', 'col', 'field'])
|
796
|
-
|
797
|
-
joined_df = df.join(pred_df, how='inner')
|
798
|
-
|
799
|
-
if verbose:
|
800
|
-
_plot_histograms_and_stats(df=joined_df)
|
801
|
-
|
802
|
-
return joined_df
|
803
|
-
|
804
|
-
def process_reads(df, min_reads, min_wells, max_wells, gene_column, remove_outliers=False):
|
805
|
-
print('start',len(df))
|
806
|
-
df = df[df['count'] >= min_reads]
|
807
|
-
print('after filtering min reads',min_reads, len(df))
|
808
|
-
reads_ls = df['count']
|
809
|
-
stats_dict = {}
|
810
|
-
stats_dict['screen_reads_mean'] = np.mean(reads_ls)
|
811
|
-
stats_dict['screen_reads_sd'] = np.std(reads_ls)
|
812
|
-
stats_dict['screen_reads_var'] = np.var(reads_ls)
|
813
|
-
|
814
|
-
well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
|
815
|
-
well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
|
816
|
-
well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
|
817
|
-
well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
|
818
|
-
well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
|
819
|
-
gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
|
820
|
-
gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
|
821
|
-
df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))
|
822
|
-
df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))
|
823
|
-
|
824
|
-
df = df[df['gRNA_well_count'] >= min_wells]
|
825
|
-
df = df[df['gRNA_well_count'] <= max_wells]
|
826
|
-
|
827
|
-
if remove_outliers:
|
828
|
-
clf = IsolationForest(contamination='auto', random_state=42, n_jobs=20)
|
829
|
-
#clf.fit(df.select_dtypes(include=['int', 'float']))
|
830
|
-
clf.fit(df[["gRNA_well_count", "count"]])
|
831
|
-
outlier_array = clf.predict(df[["gRNA_well_count", "count"]])
|
832
|
-
#outlier_array = clf.predict(df.select_dtypes(include=['int', 'float']))
|
833
|
-
outlier_df = pd.DataFrame(outlier_array, columns=['outlier'])
|
834
|
-
df['outlier'] = outlier_df['outlier']
|
835
|
-
outliers = pd.DataFrame(df[df['outlier']==-1])
|
836
|
-
df = pd.DataFrame(df[df['outlier']==1])
|
837
|
-
print('removed',len(outliers), 'outliers', 'inlers',len(df))
|
838
|
-
|
839
|
-
columns_to_drop = ['gRNA_well_count','gRNAs_per_well', 'well_read_sum']#, 'outlier']
|
840
|
-
df = df.drop(columns_to_drop, axis=1)
|
841
|
-
|
842
|
-
plates = ['p1', 'p2', 'p3', 'p4']
|
843
|
-
df = df[df.plate.isin(plates) == True]
|
844
|
-
print('after filtering out p5,p6,p7,p8',len(df))
|
845
|
-
|
846
|
-
gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
|
847
|
-
gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
|
848
|
-
df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))
|
849
|
-
well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
|
850
|
-
well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
|
851
|
-
well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
|
852
|
-
well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
|
853
|
-
well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
|
854
|
-
df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))
|
855
|
-
|
856
|
-
columns_to_drop = [col for col in df.columns if col.endswith('_right')]
|
857
|
-
columns_to_drop2 = [col for col in df.columns if col.endswith('0')]
|
858
|
-
columns_to_drop = columns_to_drop + columns_to_drop2
|
859
|
-
df = df.drop(columns_to_drop, axis=1)
|
860
|
-
return df, stats_dict
|
861
|
-
|
862
|
-
def annotate_results(pred_loc):
|
863
|
-
|
864
|
-
from .utils import _map_wells_png
|
865
|
-
|
866
|
-
df = pd.read_csv(pred_loc)
|
867
|
-
df = df.copy()
|
868
|
-
pc_col_list = ['c4','c5','c6','c7','c8','c9','c10','c11','c12','c13','c14','c15','c16','c17','c18','c19','c20','c21','c22','c23','c24']
|
869
|
-
pc_plate_list = ['p6','p7','p8', 'p9']
|
870
|
-
|
871
|
-
nc_col_list = ['c1','c2','c3']
|
872
|
-
nc_plate_list = ['p1','p2','p3','p4','p6','p7','p8', 'p9']
|
873
|
-
|
874
|
-
screen_col_list = ['c4','c5','c6','c7','c8','c9','c10','c11','c12','c13','c14','c15','c16','c17','c18','c19','c20','c21','c22','c23','c24']
|
875
|
-
screen_plate_list = ['p1','p2','p3','p4']
|
876
|
-
|
877
|
-
df[['plate', 'row', 'col', 'field', 'cell_id', 'prcfo']] = df['path'].apply(lambda x: pd.Series(_map_wells_png(x)))
|
878
|
-
|
879
|
-
df.loc[(df['col'].isin(pc_col_list)) & (df['plate'].isin(pc_plate_list)), 'condition'] = 'pc'
|
880
|
-
df.loc[(df['col'].isin(nc_col_list)) & (df['plate'].isin(nc_plate_list)), 'condition'] = 'nc'
|
881
|
-
df.loc[(df['col'].isin(screen_col_list)) & (df['plate'].isin(screen_plate_list)), 'condition'] = 'screen'
|
882
|
-
|
883
|
-
df = df.dropna(subset=['condition'])
|
884
|
-
display(df)
|
885
|
-
return df
|
886
|
-
|
887
|
-
def generate_dataset(settings={}):
|
888
|
-
|
889
|
-
from .utils import initiate_counter, add_images_to_tar
|
890
|
-
|
891
|
-
db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
|
892
|
-
dst = os.path.join(settings['src'], 'datasets')
|
893
|
-
all_paths = []
|
894
|
-
|
895
|
-
# Connect to the database and retrieve the image paths
|
896
|
-
print(f"Reading DataBase: {db_path}")
|
897
|
-
try:
|
898
|
-
with sqlite3.connect(db_path) as conn:
|
899
|
-
cursor = conn.cursor()
|
900
|
-
if settings['file_metadata']:
|
901
|
-
if isinstance(settings['file_metadata'], str):
|
902
|
-
cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{settings['file_metadata']}%",))
|
903
|
-
else:
|
904
|
-
cursor.execute("SELECT png_path FROM png_list")
|
905
|
-
|
906
|
-
while True:
|
907
|
-
rows = cursor.fetchmany(1000)
|
908
|
-
if not rows:
|
909
|
-
break
|
910
|
-
all_paths.extend([row[0] for row in rows])
|
911
|
-
|
912
|
-
except sqlite3.Error as e:
|
913
|
-
print(f"Database error: {e}")
|
914
|
-
return
|
915
|
-
except Exception as e:
|
916
|
-
print(f"Error: {e}")
|
917
|
-
return
|
918
|
-
|
919
|
-
if isinstance(settings['sample'], int):
|
920
|
-
selected_paths = random.sample(all_paths, settings['sample'])
|
921
|
-
print(f"Random selection of {len(selected_paths)} paths")
|
922
|
-
else:
|
923
|
-
selected_paths = all_paths
|
924
|
-
random.shuffle(selected_paths)
|
925
|
-
print(f"All paths: {len(selected_paths)} paths")
|
926
|
-
|
927
|
-
total_images = len(selected_paths)
|
928
|
-
print(f"Found {total_images} images")
|
929
|
-
|
930
|
-
# Create a temp folder in dst
|
931
|
-
temp_dir = os.path.join(dst, "temp_tars")
|
932
|
-
os.makedirs(temp_dir, exist_ok=True)
|
933
|
-
|
934
|
-
# Chunking the data
|
935
|
-
num_procs = max(2, cpu_count() - 2)
|
936
|
-
chunk_size = len(selected_paths) // num_procs
|
937
|
-
remainder = len(selected_paths) % num_procs
|
938
|
-
|
939
|
-
paths_chunks = []
|
940
|
-
start = 0
|
941
|
-
for i in range(num_procs):
|
942
|
-
end = start + chunk_size + (1 if i < remainder else 0)
|
943
|
-
paths_chunks.append(selected_paths[start:end])
|
944
|
-
start = end
|
945
|
-
|
946
|
-
temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]
|
947
|
-
|
948
|
-
print(f"Generating temporary tar files in {dst}")
|
949
|
-
|
950
|
-
# Initialize shared counter and lock
|
951
|
-
counter = Value('i', 0)
|
952
|
-
lock = Lock()
|
953
|
-
|
954
|
-
with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
|
955
|
-
pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
|
956
|
-
|
957
|
-
# Combine the temporary tar files into a final tar
|
958
|
-
date_name = datetime.date.today().strftime('%y%m%d')
|
959
|
-
if not settings['file_metadata'] is None:
|
960
|
-
tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
|
961
|
-
else:
|
962
|
-
tar_name = f"{date_name}_{settings['experiment']}.tar"
|
963
|
-
tar_name = os.path.join(dst, tar_name)
|
964
|
-
if os.path.exists(tar_name):
|
965
|
-
number = random.randint(1, 100)
|
966
|
-
tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
|
967
|
-
print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
|
968
|
-
tar_name = os.path.join(dst, tar_name_2)
|
969
|
-
|
970
|
-
print(f"Merging temporary files")
|
971
|
-
|
972
|
-
with tarfile.open(tar_name, 'w') as final_tar:
|
973
|
-
for temp_tar_path in temp_tar_files:
|
974
|
-
with tarfile.open(temp_tar_path, 'r') as temp_tar:
|
975
|
-
for member in temp_tar.getmembers():
|
976
|
-
file_obj = temp_tar.extractfile(member)
|
977
|
-
final_tar.addfile(member, file_obj)
|
978
|
-
os.remove(temp_tar_path)
|
979
|
-
|
980
|
-
# Delete the temp folder
|
981
|
-
shutil.rmtree(temp_dir)
|
982
|
-
print(f"\nSaved {total_images} images to {tar_name}")
|
983
|
-
|
984
|
-
return tar_name
|
985
|
-
|
986
|
-
def apply_model_to_tar(settings={}):
|
987
|
-
|
988
|
-
from .io import TarImageDataset
|
989
|
-
from .utils import process_vision_results, print_progress
|
990
|
-
|
991
|
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
992
|
-
if settings['normalize']:
|
993
|
-
transform = transforms.Compose([
|
994
|
-
transforms.ToTensor(),
|
995
|
-
transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
|
996
|
-
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
997
|
-
else:
|
998
|
-
transform = transforms.Compose([
|
999
|
-
transforms.ToTensor(),
|
1000
|
-
transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
|
1001
|
-
|
1002
|
-
if settings['verbose']:
|
1003
|
-
print(f"Loading model from {settings['model_path']}")
|
1004
|
-
print(f"Loading dataset from {settings['tar_path']}")
|
1005
|
-
|
1006
|
-
model = torch.load(settings['model_path'])
|
1007
|
-
|
1008
|
-
dataset = TarImageDataset(settings['tar_path'], transform=transform)
|
1009
|
-
data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
|
1010
|
-
|
1011
|
-
model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
|
1012
|
-
dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
|
1013
|
-
date_name = datetime.date.today().strftime('%y%m%d')
|
1014
|
-
dst = os.path.dirname(settings['tar_path'])
|
1015
|
-
result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
|
1016
|
-
|
1017
|
-
model.eval()
|
1018
|
-
model = model.to(device)
|
1019
|
-
|
1020
|
-
if settings['verbose']:
|
1021
|
-
print(model)
|
1022
|
-
print(f'Generated dataset with {len(dataset)} images')
|
1023
|
-
print(f'Generating loader from {len(data_loader)} batches')
|
1024
|
-
print(f'Results wil be saved in: {result_loc}')
|
1025
|
-
print(f'Model is in eval mode')
|
1026
|
-
print(f'Model loaded to device')
|
1027
|
-
|
1028
|
-
prediction_pos_probs = []
|
1029
|
-
filenames_list = []
|
1030
|
-
time_ls = []
|
1031
|
-
gc.collect()
|
1032
|
-
with torch.no_grad():
|
1033
|
-
for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
|
1034
|
-
start = time.time()
|
1035
|
-
images = batch_images.to(torch.float).to(device)
|
1036
|
-
outputs = model(images)
|
1037
|
-
batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
|
1038
|
-
prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
|
1039
|
-
filenames_list.extend(filenames)
|
1040
|
-
stop = time.time()
|
1041
|
-
duration = stop - start
|
1042
|
-
time_ls.append(duration)
|
1043
|
-
files_processed = batch_idx*settings['batch_size']
|
1044
|
-
files_to_process = len(data_loader)
|
1045
|
-
print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")
|
1046
|
-
|
1047
|
-
data = {'path':filenames_list, 'pred':prediction_pos_probs}
|
1048
|
-
df = pd.DataFrame(data, index=None)
|
1049
|
-
df = process_vision_results(df, settings['score_threshold'])
|
1050
|
-
|
1051
|
-
df.to_csv(result_loc, index=True, header=True, mode='w')
|
1052
|
-
torch.cuda.empty_cache()
|
1053
|
-
torch.cuda.memory.empty_cache()
|
1054
|
-
return df
|
1055
|
-
|
1056
|
-
def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
|
1057
|
-
|
1058
|
-
from .io import NoClassDataset
|
1059
|
-
from .utils import print_progress
|
1060
|
-
|
1061
|
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1062
|
-
|
1063
|
-
if normalize:
|
1064
|
-
transform = transforms.Compose([
|
1065
|
-
transforms.ToTensor(),
|
1066
|
-
transforms.CenterCrop(size=(image_size, image_size)),
|
1067
|
-
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
1068
|
-
else:
|
1069
|
-
transform = transforms.Compose([
|
1070
|
-
transforms.ToTensor(),
|
1071
|
-
transforms.CenterCrop(size=(image_size, image_size))])
|
1072
|
-
|
1073
|
-
model = torch.load(model_path)
|
1074
|
-
print(model)
|
1075
|
-
|
1076
|
-
print(f'Loading dataset in {src} with {len(src)} images')
|
1077
|
-
dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
|
1078
|
-
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs)
|
1079
|
-
print(f'Loaded {len(src)} images')
|
1080
|
-
|
1081
|
-
result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
|
1082
|
-
print(f'Results wil be saved in: {result_loc}')
|
1083
|
-
|
1084
|
-
model.eval()
|
1085
|
-
model = model.to(device)
|
1086
|
-
prediction_pos_probs = []
|
1087
|
-
filenames_list = []
|
1088
|
-
time_ls = []
|
1089
|
-
with torch.no_grad():
|
1090
|
-
for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
|
1091
|
-
start = time.time()
|
1092
|
-
images = batch_images.to(torch.float).to(device)
|
1093
|
-
outputs = model(images)
|
1094
|
-
batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
|
1095
|
-
prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
|
1096
|
-
filenames_list.extend(filenames)
|
1097
|
-
stop = time.time()
|
1098
|
-
duration = stop - start
|
1099
|
-
time_ls.append(duration)
|
1100
|
-
files_processed = batch_idx*batch_size
|
1101
|
-
files_to_process = len(data_loader)
|
1102
|
-
print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Generating predictions")
|
1103
|
-
|
1104
|
-
data = {'path':filenames_list, 'pred':prediction_pos_probs}
|
1105
|
-
df = pd.DataFrame(data, index=None)
|
1106
|
-
df.to_csv(result_loc, index=True, header=True, mode='w')
|
1107
|
-
torch.cuda.empty_cache()
|
1108
|
-
torch.cuda.memory.empty_cache()
|
1109
|
-
return df
|
1110
|
-
|
1111
|
-
def generate_training_data_file_list(src,
|
1112
|
-
target='protein of interest',
|
1113
|
-
cell_dim=4,
|
1114
|
-
nucleus_dim=5,
|
1115
|
-
pathogen_dim=6,
|
1116
|
-
channel_of_interest=1,
|
1117
|
-
pathogen_size_min=0,
|
1118
|
-
nucleus_size_min=0,
|
1119
|
-
cell_size_min=0,
|
1120
|
-
pathogen_min=0,
|
1121
|
-
nucleus_min=0,
|
1122
|
-
cell_min=0,
|
1123
|
-
target_min=0,
|
1124
|
-
mask_chans=[0,1,2],
|
1125
|
-
filter_data=False,
|
1126
|
-
include_noninfected=False,
|
1127
|
-
include_multiinfected=False,
|
1128
|
-
include_multinucleated=False,
|
1129
|
-
cells_per_well=10,
|
1130
|
-
save_filtered_filelist=False):
|
1131
|
-
|
1132
|
-
from .io import _read_and_merge_data
|
1133
|
-
|
1134
|
-
mask_dims=[cell_dim,nucleus_dim,pathogen_dim]
|
1135
|
-
sns.color_palette("mako", as_cmap=True)
|
1136
|
-
print(f'channel:{channel_of_interest} = {target}')
|
1137
|
-
overlay_channels = [0, 1, 2, 3]
|
1138
|
-
overlay_channels.remove(channel_of_interest)
|
1139
|
-
overlay_channels.reverse()
|
1140
|
-
|
1141
|
-
db_loc = [src+'/measurements/measurements.db']
|
1142
|
-
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1143
|
-
df, object_dfs = _read_and_merge_data(db_loc,
|
1144
|
-
tables,
|
1145
|
-
verbose=True,
|
1146
|
-
include_multinucleated=include_multinucleated,
|
1147
|
-
include_multiinfected=include_multiinfected,
|
1148
|
-
include_noninfected=include_noninfected)
|
1149
|
-
|
1150
|
-
if filter_data:
|
1151
|
-
df = df[df['cell_area'] > cell_size_min]
|
1152
|
-
df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
|
1153
|
-
print(f'After cell filtration {len(df)}')
|
1154
|
-
df = df[df['nucleus_area'] > nucleus_size_min]
|
1155
|
-
df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
|
1156
|
-
print(f'After nucleus filtration {len(df)}')
|
1157
|
-
df = df[df['pathogen_area'] > pathogen_size_min]
|
1158
|
-
df=df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
|
1159
|
-
print(f'After pathogen filtration {len(df)}')
|
1160
|
-
df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
|
1161
|
-
print(f'After channel {channel_of_interest} filtration', len(df))
|
1162
|
-
|
1163
|
-
df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
|
1164
|
-
return df
|
1165
|
-
|
1166
|
-
def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
|
1167
|
-
all_paths = []
|
1168
|
-
|
1169
|
-
# Connect to the database and retrieve the image paths and annotations
|
1170
|
-
print(f'Reading DataBase: {db_path}')
|
1171
|
-
with sqlite3.connect(db_path) as conn:
|
1172
|
-
cursor = conn.cursor()
|
1173
|
-
# Prepare the query with parameterized placeholders for annotated_classes
|
1174
|
-
placeholders = ','.join('?' * len(annotated_classes))
|
1175
|
-
query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
|
1176
|
-
cursor.execute(query, annotated_classes)
|
1177
|
-
|
1178
|
-
while True:
|
1179
|
-
rows = cursor.fetchmany(1000)
|
1180
|
-
if not rows:
|
1181
|
-
break
|
1182
|
-
for row in rows:
|
1183
|
-
all_paths.append(row)
|
1184
|
-
|
1185
|
-
# Filter paths based on annotation
|
1186
|
-
class_paths = []
|
1187
|
-
for class_ in annotated_classes:
|
1188
|
-
class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
|
1189
|
-
class_paths.append(class_paths_temp)
|
1190
|
-
|
1191
|
-
print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
|
1192
|
-
return class_paths
|
1193
|
-
|
1194
|
-
def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
|
1195
|
-
from .utils import print_progress
|
1196
|
-
# Make sure that the length of class_data matches the length of classes
|
1197
|
-
if len(class_data) != len(classes):
|
1198
|
-
raise ValueError("class_data and classes must have the same length.")
|
1199
|
-
|
1200
|
-
total_files = sum(len(data) for data in class_data)
|
1201
|
-
processed_files = 0
|
1202
|
-
time_ls = []
|
1203
|
-
|
1204
|
-
for cls, data in zip(classes, class_data):
|
1205
|
-
# Create directories
|
1206
|
-
train_class_dir = os.path.join(dst, f'train/{cls}')
|
1207
|
-
test_class_dir = os.path.join(dst, f'test/{cls}')
|
1208
|
-
os.makedirs(train_class_dir, exist_ok=True)
|
1209
|
-
os.makedirs(test_class_dir, exist_ok=True)
|
1210
|
-
|
1211
|
-
# Split the data
|
1212
|
-
train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
|
1213
|
-
|
1214
|
-
# Copy train files
|
1215
|
-
for path in train_data:
|
1216
|
-
start = time.time()
|
1217
|
-
shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
|
1218
|
-
duration = time.time() - start
|
1219
|
-
time_ls.append(duration)
|
1220
|
-
print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
|
1221
|
-
processed_files += 1
|
1222
|
-
|
1223
|
-
# Copy test files
|
1224
|
-
for path in test_data:
|
1225
|
-
start = time.time()
|
1226
|
-
shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
|
1227
|
-
duration = time.time() - start
|
1228
|
-
time_ls.append(duration)
|
1229
|
-
print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
|
1230
|
-
processed_files += 1
|
1231
|
-
|
1232
|
-
# Print summary
|
1233
|
-
for cls in classes:
|
1234
|
-
train_class_dir = os.path.join(dst, f'train/{cls}')
|
1235
|
-
test_class_dir = os.path.join(dst, f'test/{cls}')
|
1236
|
-
print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
|
1237
|
-
|
1238
|
-
return os.path.join(dst, 'train'), os.path.join(dst, 'test')
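generate_dataset_from_lists splits each class's paths with scikit-learn's train_test_split and copies the files into train/ and test/ subfolders. A rough sketch of the same split-and-copy pattern, using placeholder /tmp paths rather than real spacr output:

import os, shutil
from sklearn.model_selection import train_test_split

dst = '/tmp/spacr_dataset_sketch'                              # hypothetical output root
class_data = {'nc': [f'/tmp/nc_{i}.png' for i in range(10)]}   # hypothetical per-class paths

for p in class_data['nc']:                 # create empty placeholder files so the copy runs
    open(p, 'a').close()

for cls, paths in class_data.items():
    train_paths, test_paths = train_test_split(paths, test_size=0.1, shuffle=True, random_state=42)
    for split, split_paths in (('train', train_paths), ('test', test_paths)):
        out_dir = os.path.join(dst, split, cls)
        os.makedirs(out_dir, exist_ok=True)
        for p in split_paths:
            shutil.copy(p, os.path.join(out_dir, os.path.basename(p)))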
|
1239
|
-
|
1240
|
-
def generate_training_dataset(settings):
|
1241
|
-
|
1242
|
-
from .io import _read_and_merge_data, _read_db
|
1243
|
-
from .utils import get_paths_from_db, annotate_conditions
|
1244
|
-
from .settings import set_generate_training_dataset_defaults
|
1245
|
-
|
1246
|
-
settings = set_generate_training_dataset_defaults(settings)
|
1247
|
-
|
1248
|
-
db_path = os.path.join(settings['src'], 'measurements','measurements.db')
|
1249
|
-
dst = os.path.join(settings['src'], 'datasets', 'training')
|
1250
|
-
|
1251
|
-
if os.path.exists(dst):
|
1252
|
-
for i in range(1, 1000):
|
1253
|
-
dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
|
1254
|
-
if not os.path.exists(dst):
|
1255
|
-
print(f'Creating new directory for training: {dst}')
|
1256
|
-
break
|
1257
|
-
|
1258
|
-
if settings['dataset_mode'] == 'annotation':
|
1259
|
-
class_paths_ls_2 = []
|
1260
|
-
class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
|
1261
|
-
for class_paths in class_paths_ls:
|
1262
|
-
class_paths_temp = random.sample(class_paths, settings['size'])
|
1263
|
-
class_paths_ls_2.append(class_paths_temp)
|
1264
|
-
class_paths_ls = class_paths_ls_2
|
1265
|
-
|
1266
|
-
elif settings['dataset_mode'] == 'metadata':
|
1267
|
-
class_paths_ls = []
|
1268
|
-
class_len_ls = []
|
1269
|
-
[df] = _read_db(db_loc=db_path, tables=['png_list'])
|
1270
|
-
df['metadata_based_class'] = pd.NA
|
1271
|
-
for i, class_ in enumerate(settings['classes']):
|
1272
|
-
ls = settings['class_metadata'][i]
|
1273
|
-
df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_
|
1274
|
-
|
1275
|
-
for class_ in settings['classes']:
|
1276
|
-
if settings['size'] == None:
|
1277
|
-
c_s = []
|
1278
|
-
for c in settings['classes']:
|
1279
|
-
c_s_t_df = df[df['metadata_based_class'] == c]
|
1280
|
-
c_s.append(len(c_s_t_df))
|
1281
|
-
print(f'Found {len(c_s_t_df)} images for class {c}')
|
1282
|
-
size = min(c_s)
|
1283
|
-
print(f'Using the smallest class size: {size}')
|
1284
|
-
|
1285
|
-
class_temp_df = df[df['metadata_based_class'] == class_]
|
1286
|
-
class_len_ls.append(len(class_temp_df))
|
1287
|
-
print(f'Found {len(class_temp_df)} images for class {class_}')
|
1288
|
-
class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), settings['size'])
|
1289
|
-
class_paths_ls.append(class_paths_temp)
|
1290
|
-
|
1291
|
-
elif settings['dataset_mode'] == 'recruitment':
|
1292
|
-
class_paths_ls = []
|
1293
|
-
if not isinstance(settings['tables'], list):
|
1294
|
-
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1295
|
-
|
1296
|
-
df, _ = _read_and_merge_data(locs=[db_path],
|
1297
|
-
tables=tables,
|
1298
|
-
verbose=False,
|
1299
|
-
include_multinucleated=True,
|
1300
|
-
include_multiinfected=True,
|
1301
|
-
include_noninfected=True)
|
1302
|
-
|
1303
|
-
print('length df 1', len(df))
|
1304
|
-
|
1305
|
-
df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=settings['classes'], treatment_loc=settings['class_metadata'], types = settings['metadata_type_by'])
|
1306
|
-
print('length df 2', len(df))
|
1307
|
-
[png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
|
1308
|
-
|
1309
|
-
if settings['custom_measurement'] != None:
|
1310
|
-
|
1311
|
-
if not isinstance(settings['custom_measurement'], list):
|
1312
|
-
print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
|
1313
|
-
return
|
1314
|
-
|
1315
|
-
if isinstance(settings['custom_measurement'], list):
|
1316
|
-
if len(settings['custom_measurement']) == 2:
|
1317
|
-
print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]}/{settings['custom_measurement'][1]})")
|
1318
|
-
df['recruitment'] = df[f"{settings['custom_measurement'][0]}']/df[f'{settings['custom_measurement'][1]}"]
|
1319
|
-
if len(settings['custom_measurement']) == 1:
|
1320
|
-
print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]})")
|
1321
|
-
df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]
|
1322
|
-
else:
|
1323
|
-
print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {settings['channel_of_interest']})")
|
1324
|
-
df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity']/df[f'cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
|
1325
|
-
|
1326
|
-
q25 = df['recruitment'].quantile(0.25)
|
1327
|
-
q75 = df['recruitment'].quantile(0.75)
|
1328
|
-
df_lower = df[df['recruitment'] <= q25]
|
1329
|
-
df_upper = df[df['recruitment'] >= q75]
|
1330
|
-
|
1331
|
-
class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])
|
1332
|
-
|
1333
|
-
class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), settings['size'])
|
1334
|
-
class_paths_ls.append(class_paths_lower)
|
1335
|
-
|
1336
|
-
class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
|
1337
|
-
class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), settings['size'])
|
1338
|
-
class_paths_ls.append(class_paths_upper)
|
1339
|
-
|
1340
|
-
train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])
|
1341
|
-
|
1342
|
-
return train_class_dir, test_class_dir
|
1343
|
-
|
1344
|
-
def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, preload_batches=3, verbose=False):
|
1345
|
-
|
1346
|
-
"""
|
1347
|
-
Generate data loaders for training and validation/test datasets.
|
1348
|
-
|
1349
|
-
Parameters:
|
1350
|
-
- src (str): The source directory containing the data.
|
1351
|
-
- mode (str): The mode of operation. Options are 'train' or 'test'.
|
1352
|
-
- image_size (int): The size of the input images.
|
1353
|
-
- batch_size (int): The batch size for the data loaders.
|
1354
|
-
- classes (list): The list of classes to consider.
|
1355
|
-
- n_jobs (int): The number of worker threads for data loading.
|
1356
|
-
- validation_split (float): The fraction of data to use for validation.
|
1357
|
-
- pin_memory (bool): Whether to pin memory for faster data transfer.
|
1358
|
-
- normalize (bool): Whether to normalize the input images.
|
1359
|
-
- verbose (bool): Whether to print additional information and show images.
|
1360
|
-
- channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for red and green, etc.
|
1361
|
-
|
1362
|
-
Returns:
|
1363
|
-
- train_loaders (list): List of data loaders for training datasets.
|
1364
|
-
- val_loaders (list): List of data loaders for validation datasets.
|
1365
|
-
"""
|
1366
|
-
|
1367
|
-
from .io import spacrDataset, spacrDataLoader
|
1368
|
-
from .plot import _imshow_gpu
|
1369
|
-
from .utils import SelectChannels, augment_dataset
|
1370
|
-
|
1371
|
-
chans = []
|
1372
|
-
|
1373
|
-
if 'r' in channels:
|
1374
|
-
chans.append(1)
|
1375
|
-
if 'g' in channels:
|
1376
|
-
chans.append(2)
|
1377
|
-
if 'b' in channels:
|
1378
|
-
chans.append(3)
|
1379
|
-
|
1380
|
-
channels = chans
|
1381
|
-
|
1382
|
-
if verbose:
|
1383
|
-
print(f'Training a network on channels: {channels}')
|
1384
|
-
print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
|
1385
|
-
|
1386
|
-
train_loaders = []
|
1387
|
-
val_loaders = []
|
1388
|
-
|
1389
|
-
if normalize:
|
1390
|
-
transform = transforms.Compose([
|
1391
|
-
transforms.ToTensor(),
|
1392
|
-
transforms.CenterCrop(size=(image_size, image_size)),
|
1393
|
-
SelectChannels(channels),
|
1394
|
-
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
1395
|
-
else:
|
1396
|
-
transform = transforms.Compose([
|
1397
|
-
transforms.ToTensor(),
|
1398
|
-
transforms.CenterCrop(size=(image_size, image_size)),
|
1399
|
-
SelectChannels(channels)])
|
1400
|
-
|
1401
|
-
if mode == 'train':
|
1402
|
-
data_dir = os.path.join(src, 'train')
|
1403
|
-
shuffle = True
|
1404
|
-
print('Generating Train and validation datasets')
|
1405
|
-
elif mode == 'test':
|
1406
|
-
data_dir = os.path.join(src, 'test')
|
1407
|
-
val_loaders = []
|
1408
|
-
validation_split = 0.0
|
1409
|
-
shuffle = True
|
1410
|
-
print('Generating test dataset')
|
1411
|
-
else:
|
1412
|
-
print(f'mode:{mode} is not valid, use mode = train or test')
|
1413
|
-
return
|
1414
|
-
|
1415
|
-
data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
1416
|
-
num_workers = n_jobs if n_jobs is not None else 0
|
1417
|
-
|
1418
|
-
if validation_split > 0:
|
1419
|
-
train_size = int((1 - validation_split) * len(data))
|
1420
|
-
val_size = len(data) - train_size
|
1421
|
-
if not augment:
|
1422
|
-
print(f'Train data:{train_size}, Validation data:{val_size}')
|
1423
|
-
train_dataset, val_dataset = random_split(data, [train_size, val_size])
|
1424
|
-
|
1425
|
-
if augment:
|
1426
|
-
|
1427
|
-
print(f'Data before augmentation: Train: {len(train_dataset)}, Validation: {len(val_dataset)}')
|
1428
|
-
train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
|
1429
|
-
print(f'Data after augmentation: Train: {len(train_dataset)}')
|
1430
|
-
|
1431
|
-
print(f'Generating Dataloader with {n_jobs} workers')
|
1432
|
-
#train_loaders = spacrDataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
|
1433
|
-
#train_loaders = spacrDataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
|
1434
|
-
|
1435
|
-
train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
|
1436
|
-
val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
|
1437
|
-
else:
|
1438
|
-
train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
|
1439
|
-
#train_loaders = spacrDataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
|
1440
|
-
|
1441
|
-
#dataset (Dataset) – dataset from which to load the data.
|
1442
|
-
#batch_size (int, optional) – how many samples per batch to load (default: 1).
|
1443
|
-
#shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
|
1444
|
-
#sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
|
1445
|
-
#batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
|
1446
|
-
#num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
|
1447
|
-
#collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
|
1448
|
-
#pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
|
1449
|
-
#drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
|
1450
|
-
#timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
|
1451
|
-
#worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
|
1452
|
-
#multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
|
1453
|
-
#generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
|
1454
|
-
#prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
|
1455
|
-
#persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
|
1456
|
-
#pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
|
1457
|
-
|
1458
|
-
#images, labels, filenames = next(iter(train_loaders))
|
1459
|
-
#images = images.cpu()
|
1460
|
-
#label_strings = [str(label.item()) for label in labels]
|
1461
|
-
#train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
|
1462
|
-
#if verbose:
|
1463
|
-
# plt.show()
|
1464
|
-
|
1465
|
-
train_fig = None
|
1466
|
-
|
1467
|
-
return train_loaders, val_loaders, train_fig
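The generate_loaders docstring above describes a CenterCrop/Normalize transform feeding PyTorch DataLoaders. A stripped-down sketch of an equivalent torchvision pipeline, assuming a hypothetical ImageFolder-style directory rather than a spacrDataset:

from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.CenterCrop(224),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
])

# hypothetical directory with one subfolder per class, e.g. train/nc and train/pc
data = datasets.ImageFolder('/tmp/spacr_dataset_sketch/train', transform=transform)

train_size = int(0.9 * len(data))                      # 10% held out for validation
train_set, val_set = random_split(data, [train_size, len(data) - train_size])

train_loader = DataLoader(train_set, batch_size=32, shuffle=True, num_workers=1, pin_memory=False)
val_loader = DataLoader(val_set, batch_size=32, shuffle=True, num_workers=1, pin_memory=False)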
|
1468
|
-
|
1469
|
-
def analyze_recruitment(settings={}):
|
1470
|
-
"""
|
1471
|
-
Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
|
1472
|
-
|
1473
|
-
Parameters:
|
1474
|
-
settings (dict): settings.
|
1475
|
-
|
1476
|
-
Returns:
|
1477
|
-
None
|
1478
|
-
"""
|
1479
|
-
|
1480
|
-
from .io import _read_and_merge_data, _results_to_csv
|
1481
|
-
from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
|
1482
|
-
from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
|
1483
|
-
from .settings import get_analyze_recruitment_default_settings
|
1484
|
-
|
1485
|
-
settings = get_analyze_recruitment_default_settings(settings=settings)
|
1486
|
-
save_settings(settings, name='recruitment')
|
1487
|
-
|
1488
|
-
# metadata settings
|
1489
|
-
src = settings['src']
|
1490
|
-
target = settings['target']
|
1491
|
-
cell_types = settings['cell_types']
|
1492
|
-
cell_plate_metadata = settings['cell_plate_metadata']
|
1493
|
-
pathogen_types = settings['pathogen_types']
|
1494
|
-
pathogen_plate_metadata = settings['pathogen_plate_metadata']
|
1495
|
-
treatments = settings['treatments']
|
1496
|
-
treatment_plate_metadata = settings['treatment_plate_metadata']
|
1497
|
-
metadata_types = settings['metadata_types']
|
1498
|
-
channel_dims = settings['channel_dims']
|
1499
|
-
cell_chann_dim = settings['cell_chann_dim']
|
1500
|
-
cell_mask_dim = settings['cell_mask_dim']
|
1501
|
-
nucleus_chann_dim = settings['nucleus_chann_dim']
|
1502
|
-
nucleus_mask_dim = settings['nucleus_mask_dim']
|
1503
|
-
pathogen_chann_dim = settings['pathogen_chann_dim']
|
1504
|
-
pathogen_mask_dim = settings['pathogen_mask_dim']
|
1505
|
-
channel_of_interest = settings['channel_of_interest']
|
1506
|
-
|
1507
|
-
# Advanced settings
|
1508
|
-
plot = settings['plot']
|
1509
|
-
plot_nr = settings['plot_nr']
|
1510
|
-
plot_control = settings['plot_control']
|
1511
|
-
figuresize = settings['figuresize']
|
1512
|
-
include_noninfected = settings['include_noninfected']
|
1513
|
-
include_multiinfected = settings['include_multiinfected']
|
1514
|
-
include_multinucleated = settings['include_multinucleated']
|
1515
|
-
cells_per_well = settings['cells_per_well']
|
1516
|
-
pathogen_size_range = settings['pathogen_size_range']
|
1517
|
-
nucleus_size_range = settings['nucleus_size_range']
|
1518
|
-
cell_size_range = settings['cell_size_range']
|
1519
|
-
pathogen_intensity_range = settings['pathogen_intensity_range']
|
1520
|
-
nucleus_intensity_range = settings['nucleus_intensity_range']
|
1521
|
-
cell_intensity_range = settings['cell_intensity_range']
|
1522
|
-
target_intensity_min = settings['target_intensity_min']
|
1523
|
-
|
1524
|
-
print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
|
1525
|
-
print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
|
1526
|
-
print(f'Treatment(s): {treatments}, in {treatment_plate_metadata}')
|
1527
|
-
|
1528
|
-
mask_dims=[cell_mask_dim,nucleus_mask_dim,pathogen_mask_dim]
|
1529
|
-
mask_chans=[nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim]
|
1530
|
-
|
1531
|
-
if isinstance(metadata_types, str):
|
1532
|
-
metadata_types = [metadata_types, metadata_types, metadata_types]
|
1533
|
-
if isinstance(metadata_types, list):
|
1534
|
-
if len(metadata_types) < 3:
|
1535
|
-
metadata_types = [metadata_types[0], metadata_types[0], metadata_types[0]]
|
1536
|
-
print(f'WARNING: setting metadata_types to the first element repeated 3 times: {metadata_types}. To avoid this behaviour, set metadata_types to a list with 3 elements. Elements should be col, row, or plate.')
|
1537
|
-
else:
|
1538
|
-
metadata_types = metadata_types
|
1539
|
-
|
1540
|
-
sns.color_palette("mako", as_cmap=True)
|
1541
|
-
print(f'channel:{channel_of_interest} = {target}')
|
1542
|
-
overlay_channels = channel_dims
|
1543
|
-
overlay_channels.remove(channel_of_interest)
|
1544
|
-
overlay_channels.reverse()
|
1545
|
-
|
1546
|
-
db_loc = [src+'/measurements/measurements.db']
|
1547
|
-
tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
|
1548
|
-
df, _ = _read_and_merge_data(db_loc,
|
1549
|
-
tables,
|
1550
|
-
verbose=True,
|
1551
|
-
include_multinucleated=include_multinucleated,
|
1552
|
-
include_multiinfected=include_multiinfected,
|
1553
|
-
include_noninfected=include_noninfected)
|
1554
|
-
|
1555
|
-
df = annotate_conditions(df,
|
1556
|
-
cells=cell_types,
|
1557
|
-
cell_loc=cell_plate_metadata,
|
1558
|
-
pathogens=pathogen_types,
|
1559
|
-
pathogen_loc=pathogen_plate_metadata,
|
1560
|
-
treatments=treatments,
|
1561
|
-
treatment_loc=treatment_plate_metadata,
|
1562
|
-
types=metadata_types)
|
1563
|
-
|
1564
|
-
df = df.dropna(subset=['condition'])
|
1565
|
-
print(f'After dropping non-annotated wells: {len(df)} rows')
|
1566
|
-
files = df['file_name'].tolist()
|
1567
|
-
print(f'found: {len(files)} files')
|
1568
|
-
files = [item + '.npy' for item in files]
|
1569
|
-
random.shuffle(files)
|
1570
|
-
|
1571
|
-
_max = 10**100
|
1572
|
-
if cell_size_range is None:
|
1573
|
-
cell_size_range = [0,_max]
|
1574
|
-
if nucleus_size_range is None:
|
1575
|
-
nucleus_size_range = [0,_max]
|
1576
|
-
if pathogen_size_range is None:
|
1577
|
-
pathogen_size_range = [0,_max]
|
1578
|
-
|
1579
|
-
if plot:
|
1580
|
-
merged_path = os.path.join(src,'merged')
|
1581
|
-
if os.path.exists(merged_path):
|
1582
|
-
try:
|
1583
|
-
for idx, file in enumerate(os.listdir(merged_path)):
|
1584
|
-
file_path = os.path.join(merged_path,file)
|
1585
|
-
if idx <= plot_nr:
|
1586
|
-
plot_image_mask_overlay(file_path,
|
1587
|
-
channel_dims,
|
1588
|
-
cell_chann_dim,
|
1589
|
-
nucleus_chann_dim,
|
1590
|
-
pathogen_chann_dim,
|
1591
|
-
figuresize=10,
|
1592
|
-
normalize=True,
|
1593
|
-
thickness=3,
|
1594
|
-
save_pdf=True)
|
1595
|
-
except Exception as e:
|
1596
|
-
print(f'Failed to plot images with outlines, Error: {e}')
|
1597
|
-
|
1598
|
-
if not cell_chann_dim is None:
|
1599
|
-
df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
|
1600
|
-
if not target_intensity_min is None:
|
1601
|
-
df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_intensity_min]
|
1602
|
-
print(f'After channel {channel_of_interest} filtration', len(df))
|
1603
|
-
if not nucleus_chann_dim is None:
|
1604
|
-
df = _object_filter(df, object_type='nucleus', size_range=nucleus_size_range, intensity_range=nucleus_intensity_range, mask_chans=mask_chans, mask_chan=1)
|
1605
|
-
if not pathogen_chann_dim is None:
|
1606
|
-
df = _object_filter(df, object_type='pathogen', size_range=pathogen_size_range, intensity_range=pathogen_intensity_range, mask_chans=mask_chans, mask_chan=2)
|
1607
|
-
|
1608
|
-
df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
|
1609
|
-
for chan in channel_dims:
|
1610
|
-
df = _calculate_recruitment(df, channel=chan)
|
1611
|
-
print(f'calculated recruitment for: {len(df)} rows')
|
1612
|
-
df_well = _group_by_well(df)
|
1613
|
-
print(f'found: {len(df_well)} wells')
|
1614
|
-
|
1615
|
-
df_well = df_well[df_well['cells_per_well'] >= cells_per_well]
|
1616
|
-
prc_list = df_well['prc'].unique().tolist()
|
1617
|
-
df = df[df['prc'].isin(prc_list)]
|
1618
|
-
print(f'After cells per well filter: {len(df)} cells in {len(df_well)} wells left with threshold {cells_per_well}')
|
1619
|
-
|
1620
|
-
if plot_control:
|
1621
|
-
_plot_controls(df, mask_chans, channel_of_interest, figuresize=5)
|
1622
|
-
|
1623
|
-
print(f'PV level: {len(df)} rows')
|
1624
|
-
_plot_recruitment(df=df, df_type='by PV', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
|
1625
|
-
print(f'well level: {len(df_well)} rows')
|
1626
|
-
_plot_recruitment(df=df_well, df_type='by well', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
|
1627
|
-
cells,wells = _results_to_csv(src, df, df_well)
|
1628
|
-
return [cells,wells]
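analyze_recruitment defines recruitment as the pathogen-to-cytoplasm mean-intensity ratio in the channel of interest, then aggregates it per well. A toy pandas sketch of that ratio and grouping (column names mirror the convention above, values are made up):

import pandas as pd

df = pd.DataFrame({
    'prc': ['p1_r1_c1', 'p1_r1_c1', 'p1_r1_c2'],
    'pathogen_channel_1_mean_intensity': [200.0, 150.0, 90.0],
    'cytoplasm_channel_1_mean_intensity': [100.0, 100.0, 90.0],
})

# per-object recruitment ratio
df['recruitment'] = (df['pathogen_channel_1_mean_intensity']
                     / df['cytoplasm_channel_1_mean_intensity'])

# well-level aggregation keyed on the plate_row_column identifier
df_well = df.groupby('prc', as_index=False)['recruitment'].mean()
print(df_well)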
|
1629
|
-
|
1630
10
|
def preprocess_generate_masks(src, settings={}):
|
1631
11
|
|
1632
12
|
from .io import preprocess_img_data, _load_and_concatenate_arrays
|
1633
13
|
from .plot import plot_image_mask_overlay, plot_arrays
|
1634
14
|
from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
|
1635
15
|
from .settings import set_default_settings_preprocess_generate_masks
|
1636
|
-
|
1637
|
-
settings = set_default_settings_preprocess_generate_masks(src, settings)
|
1638
|
-
settings['src'] = src
|
1639
|
-
save_settings(settings, name='gen_mask')
|
1640
16
|
|
1641
|
-
if not settings['
|
1642
|
-
|
1643
|
-
|
1644
|
-
ValueError(f'Pathogen model must be {custom_model_ls} or None')
|
1645
|
-
|
1646
|
-
if settings['timelapse']:
|
1647
|
-
settings['randomize'] = False
|
1648
|
-
|
1649
|
-
if settings['preprocess']:
|
1650
|
-
if not settings['masks']:
|
1651
|
-
print(f'WARNING: channels for mask generation are defined when preprocess = True')
|
17
|
+
if not isinstance(settings['src'], (str, list)):
|
18
|
+
raise ValueError(f'src must be a string or a list of strings')
|
19
|
+
return
|
1652
20
|
|
1653
|
-
if isinstance(settings['
|
1654
|
-
settings['
|
21
|
+
if isinstance(settings['src'], str):
|
22
|
+
settings['src'] = [settings['src']]
|
1655
23
|
|
1656
|
-
if settings['
|
1657
|
-
|
1658
|
-
|
1659
|
-
|
24
|
+
if isinstance(settings['src'], list):
|
25
|
+
source_folders = settings['src']
|
26
|
+
for source_folder in source_folders:
|
27
|
+
print(f'Processing folder: {source_folder}')
|
28
|
+
settings['src'] = source_folder
|
29
|
+
src = source_folder
|
30
|
+
settings = set_default_settings_preprocess_generate_masks(src, settings)
|
31
|
+
|
32
|
+
save_settings(settings, name='gen_mask')
|
1660
33
|
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
|
1665
|
-
settings, src = preprocess_img_data(settings)
|
1666
|
-
|
1667
|
-
files_to_process = 3
|
1668
|
-
files_processed = 0
|
1669
|
-
if settings['masks']:
|
1670
|
-
mask_src = os.path.join(src, 'norm_channel_stack')
|
1671
|
-
if settings['cell_channel'] != None:
|
1672
|
-
time_ls=[]
|
1673
|
-
if check_mask_folder(src, 'cell_mask_stack'):
|
1674
|
-
start = time.time()
|
1675
|
-
if settings['segmentation_mode'] == 'cellpose':
|
1676
|
-
generate_cellpose_masks(mask_src, settings, 'cell')
|
1677
|
-
elif settings['segmentation_mode'] == 'mediar':
|
1678
|
-
generate_mediar_masks(mask_src, settings, 'cell')
|
1679
|
-
stop = time.time()
|
1680
|
-
duration = (stop - start)
|
1681
|
-
time_ls.append(duration)
|
1682
|
-
files_processed += 1
|
1683
|
-
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'cell_mask_gen')
|
34
|
+
if not settings['pathogen_channel'] is None:
|
35
|
+
custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
|
36
|
+
if settings['pathogen_model'] not in custom_model_ls:
|
37
|
+
raise ValueError(f'Pathogen model must be {custom_model_ls} or None')
|
1684
38
|
|
1685
|
-
|
1686
|
-
|
1687
|
-
if check_mask_folder(src, 'nucleus_mask_stack'):
|
1688
|
-
start = time.time()
|
1689
|
-
if settings['segmentation_mode'] == 'cellpose':
|
1690
|
-
generate_cellpose_masks(mask_src, settings, 'nucleus')
|
1691
|
-
elif settings['segmentation_mode'] == 'mediar':
|
1692
|
-
generate_mediar_masks(mask_src, settings, 'nucleus')
|
1693
|
-
stop = time.time()
|
1694
|
-
duration = (stop - start)
|
1695
|
-
time_ls.append(duration)
|
1696
|
-
files_processed += 1
|
1697
|
-
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'nucleus_mask_gen')
|
39
|
+
if settings['timelapse']:
|
40
|
+
settings['randomize'] = False
|
1698
41
|
|
1699
|
-
|
1700
|
-
|
1701
|
-
|
1702
|
-
start = time.time()
|
1703
|
-
if settings['segmentation_mode'] == 'cellpose':
|
1704
|
-
generate_cellpose_masks(mask_src, settings, 'pathogen')
|
1705
|
-
elif settings['segmentation_mode'] == 'mediar':
|
1706
|
-
generate_mediar_masks(mask_src, settings, 'pathogen')
|
1707
|
-
stop = time.time()
|
1708
|
-
duration = (stop - start)
|
1709
|
-
time_ls.append(duration)
|
1710
|
-
files_processed += 1
|
1711
|
-
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'pathogen_mask_gen')
|
1712
|
-
|
1713
|
-
#if settings['organelle'] != None:
|
1714
|
-
# if check_mask_folder(src, 'organelle_mask_stack'):
|
1715
|
-
# generate_cellpose_masks(mask_src, settings, 'organelle')
|
1716
|
-
|
1717
|
-
if settings['adjust_cells']:
|
1718
|
-
if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
|
1719
|
-
|
1720
|
-
start = time.time()
|
1721
|
-
cell_folder = os.path.join(mask_src, 'cell_mask_stack')
|
1722
|
-
nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
|
1723
|
-
parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
|
1724
|
-
#organelle_folder = os.path.join(mask_src, 'organelle_mask_stack')
|
1725
|
-
|
1726
|
-
adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
|
1727
|
-
stop = time.time()
|
1728
|
-
adjust_time = (stop-start)/60
|
1729
|
-
print(f'Cell mask adjustment: {adjust_time} min.')
|
42
|
+
if settings['preprocess']:
|
43
|
+
if not settings['masks']:
|
44
|
+
print(f'WARNING: channels for mask generation are defined when preprocess = True')
|
1730
45
|
|
1731
|
-
|
1732
|
-
|
1733
|
-
|
1734
|
-
|
1735
|
-
|
1736
|
-
|
1737
|
-
|
1738
|
-
|
1739
|
-
|
1740
|
-
|
1741
|
-
|
1742
|
-
|
1743
|
-
|
1744
|
-
|
1745
|
-
|
1746
|
-
|
1747
|
-
|
46
|
+
if isinstance(settings['save'], bool):
|
47
|
+
settings['save'] = [settings['save']]*3
|
48
|
+
|
49
|
+
if settings['verbose']:
|
50
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
51
|
+
settings_df['setting_value'] = settings_df['setting_value'].apply(str)
|
52
|
+
display(settings_df)
|
53
|
+
|
54
|
+
if settings['test_mode']:
|
55
|
+
print(f'Starting Test mode ...')
|
56
|
+
|
57
|
+
if settings['preprocess']:
|
58
|
+
settings, src = preprocess_img_data(settings)
|
59
|
+
|
60
|
+
files_to_process = 3
|
61
|
+
files_processed = 0
|
62
|
+
if settings['masks']:
|
63
|
+
mask_src = os.path.join(src, 'norm_channel_stack')
|
64
|
+
if settings['cell_channel'] != None:
|
65
|
+
time_ls=[]
|
66
|
+
if check_mask_folder(src, 'cell_mask_stack'):
|
67
|
+
start = time.time()
|
68
|
+
if settings['segmentation_mode'] == 'cellpose':
|
69
|
+
generate_cellpose_masks(mask_src, settings, 'cell')
|
70
|
+
elif settings['segmentation_mode'] == 'mediar':
|
71
|
+
generate_mediar_masks(mask_src, settings, 'cell')
|
72
|
+
stop = time.time()
|
73
|
+
duration = (stop - start)
|
74
|
+
time_ls.append(duration)
|
75
|
+
files_processed += 1
|
76
|
+
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'cell_mask_gen')
|
1748
77
|
|
1749
|
-
|
78
|
+
if settings['nucleus_channel'] != None:
|
79
|
+
time_ls=[]
|
80
|
+
if check_mask_folder(src, 'nucleus_mask_stack'):
|
1750
81
|
start = time.time()
|
1751
|
-
if
|
1752
|
-
|
1753
|
-
|
1754
|
-
|
1755
|
-
|
1756
|
-
|
1757
|
-
|
1758
|
-
|
1759
|
-
|
1760
|
-
|
1761
|
-
|
1762
|
-
|
1763
|
-
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
|
82
|
+
if settings['segmentation_mode'] == 'cellpose':
|
83
|
+
generate_cellpose_masks(mask_src, settings, 'nucleus')
|
84
|
+
elif settings['segmentation_mode'] == 'mediar':
|
85
|
+
generate_mediar_masks(mask_src, settings, 'nucleus')
|
86
|
+
stop = time.time()
|
87
|
+
duration = (stop - start)
|
88
|
+
time_ls.append(duration)
|
89
|
+
files_processed += 1
|
90
|
+
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'nucleus_mask_gen')
|
91
|
+
|
92
|
+
if settings['pathogen_channel'] != None:
|
93
|
+
time_ls=[]
|
94
|
+
if check_mask_folder(src, 'pathogen_mask_stack'):
|
95
|
+
start = time.time()
|
96
|
+
if settings['segmentation_mode'] == 'cellpose':
|
97
|
+
generate_cellpose_masks(mask_src, settings, 'pathogen')
|
98
|
+
elif settings['segmentation_mode'] == 'mediar':
|
99
|
+
generate_mediar_masks(mask_src, settings, 'pathogen')
|
100
|
+
stop = time.time()
|
101
|
+
duration = (stop - start)
|
102
|
+
time_ls.append(duration)
|
103
|
+
files_processed += 1
|
104
|
+
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'pathogen_mask_gen')
|
105
|
+
|
106
|
+
#if settings['organelle'] != None:
|
107
|
+
# if check_mask_folder(src, 'organelle_mask_stack'):
|
108
|
+
# generate_cellpose_masks(mask_src, settings, 'organelle')
|
109
|
+
|
110
|
+
if settings['adjust_cells']:
|
111
|
+
if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
|
1770
112
|
|
1771
|
-
|
1772
|
-
|
1773
|
-
|
1774
|
-
|
1775
|
-
|
1776
|
-
|
1777
|
-
|
1778
|
-
|
1779
|
-
|
1780
|
-
|
1781
|
-
|
1782
|
-
|
1783
|
-
|
1784
|
-
background = settings['background']
|
1785
|
-
remove_background=settings['remove_background']
|
1786
|
-
Signal_to_noise = settings['Signal_to_noise']
|
1787
|
-
CP_prob = settings['CP_prob']
|
1788
|
-
diameter=settings['diameter']
|
1789
|
-
batch_size=settings['batch_size']
|
1790
|
-
flow_threshold=settings['flow_threshold']
|
1791
|
-
save=settings['save']
|
1792
|
-
verbose=settings['verbose']
|
1793
|
-
|
1794
|
-
# static settings
|
1795
|
-
normalize = settings['normalize']
|
1796
|
-
percentiles = settings['percentiles']
|
1797
|
-
circular = settings['circular']
|
1798
|
-
invert = settings['invert']
|
1799
|
-
resize = settings['resize']
|
1800
|
-
|
1801
|
-
if resize:
|
1802
|
-
target_height = settings['target_height']
|
1803
|
-
target_width = settings['target_width']
|
1804
|
-
|
1805
|
-
rescale = settings['rescale']
|
1806
|
-
resample = settings['resample']
|
1807
|
-
grayscale = settings['grayscale']
|
1808
|
-
|
1809
|
-
os.makedirs(dst, exist_ok=True)
|
1810
|
-
|
1811
|
-
if not custom_model is None:
|
1812
|
-
if not os.path.exists(custom_model):
|
1813
|
-
print(f'Custom model not found: {custom_model}')
|
1814
|
-
return
|
113
|
+
start = time.time()
|
114
|
+
cell_folder = os.path.join(mask_src, 'cell_mask_stack')
|
115
|
+
nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
|
116
|
+
parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
|
117
|
+
#organelle_folder = os.path.join(mask_src, 'organelle_mask_stack')
|
118
|
+
print(f'Adjusting cell masks with nuclei and pathogen masks')
|
119
|
+
adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
|
120
|
+
stop = time.time()
|
121
|
+
adjust_time = (stop-start)/60
|
122
|
+
print(f'Cell mask adjustment: {adjust_time} min.')
|
123
|
+
|
124
|
+
if os.path.exists(os.path.join(src,'measurements')):
|
125
|
+
_pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
|
1815
126
|
|
1816
|
-
|
1817
|
-
|
1818
|
-
|
1819
|
-
|
1820
|
-
|
1821
|
-
if custom_model == None:
|
1822
|
-
model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
|
1823
|
-
print(f'Loaded model: {model_name}')
|
1824
|
-
else:
|
1825
|
-
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
|
1826
|
-
print("Pretrained Model Loaded:", model.pretrained_model)
|
127
|
+
#Concatenate stack with masks
|
128
|
+
_load_and_concatenate_arrays(src, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'])
|
129
|
+
|
130
|
+
if settings['plot']:
|
131
|
+
if not settings['timelapse']:
|
1827
132
|
|
1828
|
-
|
1829
|
-
|
1830
|
-
if grayscale:
|
1831
|
-
chans=[0, 0]
|
1832
|
-
|
1833
|
-
print(f'Using channels: {chans} for model of type {model_name}')
|
1834
|
-
|
1835
|
-
if verbose == True:
|
1836
|
-
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
|
1837
|
-
|
1838
|
-
all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
|
1839
|
-
mask_files = set(os.listdir(os.path.join(src, 'masks')))
|
1840
|
-
all_image_files = [f for f in all_image_files if os.path.basename(f) not in mask_files]
|
1841
|
-
random.shuffle(all_image_files)
|
1842
|
-
|
1843
|
-
time_ls = []
|
1844
|
-
for i in range(0, len(all_image_files), batch_size):
|
1845
|
-
gc.collect()
|
1846
|
-
image_files = all_image_files[i:i+batch_size]
|
1847
|
-
|
1848
|
-
if normalize:
|
1849
|
-
images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise, target_height=target_height, target_width=target_width)
|
1850
|
-
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
1851
|
-
#orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
1852
|
-
else:
|
1853
|
-
images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
|
1854
|
-
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
1855
|
-
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
1856
|
-
if resize:
|
1857
|
-
images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
|
1858
|
-
|
1859
|
-
for file_index, stack in enumerate(images):
|
1860
|
-
start = time.time()
|
1861
|
-
output = model.eval(x=stack,
|
1862
|
-
normalize=False,
|
1863
|
-
channels=chans,
|
1864
|
-
channel_axis=3,
|
1865
|
-
diameter=diameter,
|
1866
|
-
flow_threshold=flow_threshold,
|
1867
|
-
cellprob_threshold=CP_prob,
|
1868
|
-
rescale=rescale,
|
1869
|
-
resample=resample,
|
1870
|
-
progress=True)
|
133
|
+
if settings['test_mode'] == True:
|
134
|
+
settings['examples_to_plot'] = len(os.listdir(os.path.join(src,'merged')))
|
1871
135
|
|
1872
|
-
|
1873
|
-
|
1874
|
-
|
1875
|
-
|
1876
|
-
|
1877
|
-
|
1878
|
-
|
1879
|
-
|
1880
|
-
|
1881
|
-
|
1882
|
-
|
1883
|
-
|
1884
|
-
|
1885
|
-
|
1886
|
-
|
1887
|
-
|
1888
|
-
|
1889
|
-
|
1890
|
-
|
1891
|
-
|
1892
|
-
|
1893
|
-
|
1894
|
-
|
1895
|
-
|
1896
|
-
|
1897
|
-
|
1898
|
-
output_filename = os.path.join(dst, image_names[file_index])
|
1899
|
-
cv2.imwrite(output_filename, mask)
|
1900
|
-
del images, output, mask, flows
|
1901
|
-
gc.collect()
|
136
|
+
try:
|
137
|
+
merged_src = os.path.join(src,'merged')
|
138
|
+
files = os.listdir(merged_src)
|
139
|
+
random.shuffle(files)
|
140
|
+
time_ls = []
|
141
|
+
|
142
|
+
for i, file in enumerate(files):
|
143
|
+
start = time.time()
|
144
|
+
if i+1 <= settings['examples_to_plot']:
|
145
|
+
file_path = os.path.join(merged_src, file)
|
146
|
+
plot_image_mask_overlay(file_path, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'], figuresize=10, normalize=True, thickness=3, save_pdf=True)
|
147
|
+
stop = time.time()
|
148
|
+
duration = stop-start
|
149
|
+
time_ls.append(duration)
|
150
|
+
files_processed = i+1
|
151
|
+
files_to_process = settings['examples_to_plot']
|
152
|
+
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Plot mask outlines")
|
153
|
+
|
154
|
+
except Exception as e:
|
155
|
+
print(f'Failed to plot image mask overlay. Error: {e}')
|
156
|
+
else:
|
157
|
+
plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
|
158
|
+
|
159
|
+
torch.cuda.empty_cache()
|
160
|
+
gc.collect()
|
161
|
+
print("Successfully completed run")
|
1902
162
|
return
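The rewritten preprocess_generate_masks accepts src as either a single folder or a list of folders and loops over them. A minimal sketch of that string-or-list normalization, with placeholder folder names:

def normalize_sources(src):
    # accept one folder or several, reject anything else
    if not isinstance(src, (str, list)):
        raise ValueError('src must be a string or a list of strings')
    return [src] if isinstance(src, str) else src

for source_folder in normalize_sources(['/data/plate1', '/data/plate2']):
    print(f'Processing folder: {source_folder}')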
|
1903
163
|
|
1904
|
-
def all_elements_match(list1, list2):
|
1905
|
-
# Check if all elements in list1 are in list2
|
1906
|
-
return all(element in list2 for element in list1)
|
1907
|
-
|
1908
|
-
def prepare_batch_for_segmentation(batch):
|
1909
|
-
# Ensure the batch is of dtype float32
|
1910
|
-
if batch.dtype != np.float32:
|
1911
|
-
batch = batch.astype(np.float32)
|
1912
|
-
|
1913
|
-
# Normalize each image in the batch
|
1914
|
-
for i in range(batch.shape[0]):
|
1915
|
-
if batch[i].max() > 1:
|
1916
|
-
batch[i] = batch[i] / batch[i].max()
|
1917
|
-
|
1918
|
-
return batch
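prepare_batch_for_segmentation casts the batch to float32 and rescales each image by its own maximum. A small NumPy sketch of the same per-image normalization on random data:

import numpy as np

batch = np.random.randint(0, 65535, size=(2, 256, 256)).astype(np.float32)
for i in range(batch.shape[0]):
    peak = batch[i].max()
    if peak > 1:                 # only rescale images that are not already in [0, 1]
        batch[i] = batch[i] / peak
assert batch.max() <= 1.0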
|
1919
|
-
|
1920
164
|
def generate_cellpose_masks(src, settings, object_type):
|
1921
165
|
|
1922
|
-
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, mask_object_count,
|
166
|
+
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, mask_object_count, all_elements_match, prepare_batch_for_segmentation
|
1923
167
|
from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
|
1924
168
|
from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
|
1925
169
|
from .plot import plot_masks
|
@@ -2162,593 +406,6 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2162
406
|
torch.cuda.empty_cache()
|
2163
407
|
return
|
2164
408
|
|
2165
|
-
def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
|
2166
|
-
|
2167
|
-
from .io import _load_images_and_labels, _load_normalized_images_and_labels
|
2168
|
-
from .utils import resize_images_and_labels, resizescikit, print_progress
|
2169
|
-
from .plot import print_mask_and_flows
|
2170
|
-
|
2171
|
-
dst = os.path.join(src, model_name)
|
2172
|
-
os.makedirs(dst, exist_ok=True)
|
2173
|
-
|
2174
|
-
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
|
2175
|
-
|
2176
|
-
if grayscale:
|
2177
|
-
chans=[0, 0]
|
2178
|
-
|
2179
|
-
all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
|
2180
|
-
random.shuffle(all_image_files)
|
2181
|
-
|
2182
|
-
if verbose == True:
|
2183
|
-
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
|
2184
|
-
|
2185
|
-
time_ls = []
|
2186
|
-
for i in range(0, len(all_image_files), batch_size):
|
2187
|
-
image_files = all_image_files[i:i+batch_size]
|
2188
|
-
|
2189
|
-
if normalize:
|
2190
|
-
images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise, target_height, target_width)
|
2191
|
-
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
2192
|
-
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
2193
|
-
else:
|
2194
|
-
images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)
|
2195
|
-
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
2196
|
-
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
2197
|
-
if resize:
|
2198
|
-
images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
|
2199
|
-
|
2200
|
-
for file_index, stack in enumerate(images):
|
2201
|
-
start = time.time()
|
2202
|
-
output = model.eval(x=stack,
|
2203
|
-
normalize=False,
|
2204
|
-
channels=chans,
|
2205
|
-
channel_axis=3,
|
2206
|
-
diameter=diameter,
|
2207
|
-
flow_threshold=flow_threshold,
|
2208
|
-
cellprob_threshold=cellprob_threshold,
|
2209
|
-
rescale=False,
|
2210
|
-
resample=False,
|
2211
|
-
progress=False)
|
2212
|
-
|
2213
|
-
if len(output) == 4:
|
2214
|
-
mask, flows, _, _ = output
|
2215
|
-
elif len(output) == 3:
|
2216
|
-
mask, flows, _ = output
|
2217
|
-
else:
|
2218
|
-
raise ValueError("Unexpected number of return values from model.eval()")
|
2219
|
-
|
2220
|
-
if resize:
|
2221
|
-
dims = orig_dims[file_index]
|
2222
|
-
mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
|
2223
|
-
|
2224
|
-
stop = time.time()
|
2225
|
-
duration = (stop - start)
|
2226
|
-
time_ls.append(duration)
|
2227
|
-
files_processed = file_index+1
|
2228
|
-
files_to_process = len(images)
|
2229
|
-
|
2230
|
-
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Generating masks")
|
2231
|
-
|
2232
|
-
if plot:
|
2233
|
-
if resize:
|
2234
|
-
stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
|
2235
|
-
print_mask_and_flows(stack, mask, flows, overlay=True)
|
2236
|
-
if save:
|
2237
|
-
output_filename = os.path.join(dst, image_names[file_index])
|
2238
|
-
cv2.imwrite(output_filename, mask)
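generate_masks_from_imgs drives a Cellpose model over image batches via model.eval. A bare-bones sketch of that call on a random grayscale image; constructor arguments and the return arity vary between cellpose versions (which the code above also has to handle), so treat this as an approximation:

import numpy as np
from cellpose import models as cp_models

model = cp_models.CellposeModel(gpu=False, model_type='cyto')   # CPU, built-in cyto weights
img = np.random.rand(256, 256).astype(np.float32)               # stand-in grayscale image

output = model.eval(x=img, channels=[0, 0], diameter=30,
                    flow_threshold=0.4, cellprob_threshold=0.0)
mask = output[0]   # the label mask is always the first element of the returned tuple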
|
2239
|
-
|
2240
|
-
|
2241
|
-
def check_cellpose_models(settings):
|
2242
|
-
|
2243
|
-
from .settings import get_check_cellpose_models_default_settings
|
2244
|
-
|
2245
|
-
settings = get_check_cellpose_models_default_settings(settings)
|
2246
|
-
src = settings['src']
|
2247
|
-
|
2248
|
-
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
2249
|
-
settings_df['setting_value'] = settings_df['setting_value'].apply(str)
|
2250
|
-
display(settings_df)
|
2251
|
-
|
2252
|
-
cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
|
2253
|
-
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
2254
|
-
|
2255
|
-
for model_name in cellpose_models:
|
2256
|
-
|
2257
|
-
model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
|
2258
|
-
print(f'Using {model_name}')
|
2259
|
-
generate_masks_from_imgs(src, model, model_name, settings['batch_size'], settings['diameter'], settings['CP_prob'], settings['flow_threshold'], settings['grayscale'], settings['save'], settings['normalize'], settings['channels'], settings['percentiles'], settings['circular'], settings['invert'], settings['plot'], settings['resize'], settings['target_height'], settings['target_width'], settings['remove_background'], settings['background'], settings['Signal_to_noise'], settings['verbose'])
|
2260
|
-
|
2261
|
-
return
|
2262
|
-
|
2263
|
-
def save_results_and_figure(src, fig, results):
|
2264
|
-
|
2265
|
-
if not isinstance(results, pd.DataFrame):
|
2266
|
-
results = pd.DataFrame(results)
|
2267
|
-
|
2268
|
-
results_dir = os.path.join(src, 'results')
|
2269
|
-
os.makedirs(results_dir, exist_ok=True)
|
2270
|
-
results_path = os.path.join(results_dir,f'results.csv')
|
2271
|
-
fig_path = os.path.join(results_dir, f'model_comparison_plot.pdf')
|
2272
|
-
results.to_csv(results_path, index=False)
|
2273
|
-
fig.savefig(fig_path, format='pdf')
|
2274
|
-
print(f'Saved figure to {fig_path} and results to {results_path}')
|
2275
|
-
|
2276
|
-
def compare_mask(args):
|
2277
|
-
src, filename, dirs, conditions = args
|
2278
|
-
paths = [os.path.join(d, filename) for d in dirs]
|
2279
|
-
|
2280
|
-
if not all(os.path.exists(path) for path in paths):
|
2281
|
-
return None
|
2282
|
-
|
2283
|
-
from .io import _read_mask # Import here to avoid issues in multiprocessing
|
2284
|
-
from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
|
2285
|
-
from .plot import plot_comparison_results
|
2286
|
-
|
2287
|
-
masks = [_read_mask(path) for path in paths]
|
2288
|
-
file_results = {'filename': filename}
|
2289
|
-
|
2290
|
-
for i in range(len(masks)):
|
2291
|
-
for j in range(i + 1, len(masks)):
|
2292
|
-
mask_i, mask_j = masks[i], masks[j]
|
2293
|
-
f1_score = boundary_f1_score(mask_i, mask_j)
|
2294
|
-
jac_index = jaccard_index(mask_i, mask_j)
|
2295
|
-
ap_score = compute_segmentation_ap(mask_i, mask_j)
|
2296
|
-
|
2297
|
-
file_results.update({
|
2298
|
-
f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
|
2299
|
-
f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
|
2300
|
-
f'ap_{conditions[i]}_{conditions[j]}': ap_score
|
2301
|
-
})
|
2302
|
-
|
2303
|
-
return file_results
|
2304
|
-
|
2305
|
-
def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
|
2306
|
-
from .plot import visualize_cellpose_masks, plot_comparison_results
|
2307
|
-
from .io import _read_mask
|
2308
|
-
|
2309
|
-
dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
|
2310
|
-
dirs.sort()
|
2311
|
-
conditions = [os.path.basename(d) for d in dirs]
|
2312
|
-
|
2313
|
-
# Get common files in all directories
|
2314
|
-
common_files = set(os.listdir(dirs[0]))
|
2315
|
-
for d in dirs[1:]:
|
2316
|
-
common_files.intersection_update(os.listdir(d))
|
2317
|
-
common_files = list(common_files)
|
2318
|
-
|
2319
|
-
# Create a pool of n_jobs
|
2320
|
-
with Pool(processes=processes) as pool:
|
2321
|
-
args = [(src, filename, dirs, conditions) for filename in common_files]
|
2322
|
-
results = pool.map(compare_mask, args)
|
2323
|
-
|
2324
|
-
# Filter out None results (from skipped files)
|
2325
|
-
results = [res for res in results if res is not None]
|
2326
|
-
print(results)
|
2327
|
-
if verbose:
|
2328
|
-
for result in results:
|
2329
|
-
filename = result['filename']
|
2330
|
-
masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
|
2331
|
-
visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)
|
2332
|
-
|
2333
|
-
fig = plot_comparison_results(results)
|
2334
|
-
save_results_and_figure(src, fig, results)
|
2335
|
-
return
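compare_cellpose_masks fans the per-file mask comparison out over a multiprocessing Pool and drops None results. A generic sketch of that fan-out with a stand-in worker (the real comparison metrics are not reproduced here):

from multiprocessing import Pool

def compare_one(args):
    # stand-in for compare_mask: returns a per-file result dict
    filename, conditions = args
    return {'filename': filename, 'n_conditions': len(conditions)}

if __name__ == '__main__':
    tasks = [(f'img_{i}.tif', ['cyto', 'cyto2']) for i in range(4)]
    with Pool(processes=2) as pool:
        results = pool.map(compare_one, tasks)
    results = [r for r in results if r is not None]   # skip files missing in some folders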
|
2336
|
-
|
2337
|
-
def _calculate_similarity(df, features, col_to_compare, val1, val2):
|
2338
|
-
"""
|
2339
|
-
Calculate similarity scores of each well to the positive and negative controls using various metrics.
|
2340
|
-
|
2341
|
-
Args:
|
2342
|
-
df (pandas.DataFrame): DataFrame containing the data.
|
2343
|
-
features (list): List of feature columns to use for similarity calculation.
|
2344
|
-
col_to_compare (str): Column name to use for comparing groups.
|
2345
|
-
val1, val2 (str): Values in col_to_compare to create subsets for comparison.
|
2346
|
-
|
2347
|
-
Returns:
|
2348
|
-
pandas.DataFrame: DataFrame with similarity scores.
|
2349
|
-
"""
|
2350
|
-
# Separate positive and negative control wells
|
2351
|
-
pos_control = df[df[col_to_compare] == val1][features].mean()
|
2352
|
-
neg_control = df[df[col_to_compare] == val2][features].mean()
|
2353
|
-
|
2354
|
-
# Standardize features for Mahalanobis distance
|
2355
|
-
scaler = StandardScaler()
|
2356
|
-
scaled_features = scaler.fit_transform(df[features])
|
2357
|
-
|
2358
|
-
# Regularize the covariance matrix to avoid singularity
|
2359
|
-
cov_matrix = np.cov(scaled_features, rowvar=False)
|
2360
|
-
inv_cov_matrix = None
|
2361
|
-
try:
|
2362
|
-
inv_cov_matrix = np.linalg.inv(cov_matrix)
|
2363
|
-
except np.linalg.LinAlgError:
|
2364
|
-
# Add a small value to the diagonal elements for regularization
|
2365
|
-
epsilon = 1e-5
|
2366
|
-
inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
|
2367
|
-
|
2368
|
-
# Calculate similarity scores
|
2369
|
-
df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
|
2370
|
-
df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
|
2371
|
-
df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
|
2372
|
-
df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
|
2373
|
-
df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
|
2374
|
-
df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
|
2375
|
-
df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
|
2376
|
-
df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
|
2377
|
-
df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
|
2378
|
-
df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
|
2379
|
-
df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
|
2380
|
-
df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
|
2381
|
-
df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
|
2382
|
-
df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
|
2383
|
-
df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
|
2384
|
-
df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
|
2385
|
-
df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
|
2386
|
-
df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
|
2387
|
-
|
2388
|
-
return df
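_calculate_similarity measures each well's distance to the control centroids, including a Mahalanobis distance computed with a regularized covariance inverse. A self-contained sketch of that regularized Mahalanobis step on toy data:

import numpy as np
from scipy.spatial.distance import mahalanobis
from sklearn.preprocessing import StandardScaler

X = np.random.rand(50, 4)                       # toy feature matrix
scaled = StandardScaler().fit_transform(X)

cov = np.cov(scaled, rowvar=False)
try:
    inv_cov = np.linalg.inv(cov)
except np.linalg.LinAlgError:
    # add a small value to the diagonal if the covariance matrix is singular
    inv_cov = np.linalg.inv(cov + np.eye(cov.shape[0]) * 1e-5)

centroid = scaled[:10].mean(axis=0)             # stand-in for the positive-control mean
distances = [mahalanobis(row, centroid, inv_cov) for row in scaled]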
|
2389
|
-
|
2390
|
-
def find_optimal_threshold(y_true, y_pred_proba):
|
2391
|
-
"""
|
2392
|
-
Find the optimal threshold for binary classification based on the F1-score.
|
2393
|
-
|
2394
|
-
Args:
|
2395
|
-
y_true (array-like): True binary labels.
|
2396
|
-
y_pred_proba (array-like): Predicted probabilities for the positive class.
|
2397
|
-
|
2398
|
-
Returns:
|
2399
|
-
float: The optimal threshold.
|
2400
|
-
"""
|
2401
|
-
precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
|
2402
|
-
f1_scores = 2 * (precision * recall) / (precision + recall)
|
2403
|
-
optimal_idx = np.argmax(f1_scores)
|
2404
|
-
optimal_threshold = thresholds[optimal_idx]
|
2405
|
-
return optimal_threshold
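find_optimal_threshold picks the classification threshold that maximizes F1 along the precision-recall curve. A sketch of the same idea on toy labels; note that precision and recall have one more element than thresholds, so the final point is dropped before the argmax, and the denominator is clipped to avoid division by zero:

import numpy as np
from sklearn.metrics import precision_recall_curve

y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_score = np.array([0.1, 0.4, 0.35, 0.8, 0.65, 0.2, 0.9, 0.55])

precision, recall, thresholds = precision_recall_curve(y_true, y_score)
f1 = 2 * precision[:-1] * recall[:-1] / np.clip(precision[:-1] + recall[:-1], 1e-12, None)
optimal_threshold = thresholds[np.argmax(f1)]
print(optimal_threshold)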
|
2406
|
-
-def ml_analysis(df, channel_of_interest=3, location_column='col', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
-    """
-    Calculates permutation importance for numerical features in the dataframe,
-    comparing groups based on specified column values and uses the model to predict
-    the class for all other rows in the dataframe.
-
-    Args:
-        df (pandas.DataFrame): The DataFrame containing the data.
-        feature_string (str): String to filter features that contain this substring.
-        location_column (str): Column name to use for comparing groups.
-        positive_control, negative_control (str): Values in location_column to create subsets for comparison.
-        exclude (list or str, optional): Columns to exclude from features.
-        n_repeats (int): Number of repeats for permutation importance.
-        top_features (int): Number of top features to plot based on permutation importance.
-        n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
-        test_size (float): Proportion of the dataset to include in the test split.
-        random_state (int): Random seed for reproducibility.
-        model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
-        n_jobs (int): Number of jobs to run in parallel for applicable models.
-
-    Returns:
-        pandas.DataFrame: The original dataframe with added prediction and data usage columns.
-        pandas.DataFrame: DataFrame containing the importances and standard deviations.
-    """
-
-    from .utils import filter_dataframe_features
-    from .plot import plot_permutation, plot_feature_importance
-
-    random_state = 42
-
-    if 'cells_per_well' in df.columns:
-        df = df.drop(columns=['cells_per_well'])
-
-    df_metadata = df[[location_column]].copy()
-    df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
-
-    if verbose:
-        print(f'Found {len(features)} numerical features in the dataframe')
-        print(f'Features used in training: {features}')
-    df = pd.concat([df, df_metadata[location_column]], axis=1)
-
-    # Subset the dataframe based on specified column values
-    df1 = df[df[location_column] == negative_control].copy()
-    df2 = df[df[location_column] == positive_control].copy()
-
-    # Create target variable
-    df1['target'] = 0 # Negative control
-    df2['target'] = 1 # Positive control
-
-    # Combine the subsets for analysis
-    combined_df = pd.concat([df1, df2])
-    combined_df = combined_df.drop(columns=[location_column])
-    if verbose:
-        print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')
-
-    X = combined_df[features]
-    y = combined_df['target']
-
-    # Split the data into training and testing sets
-    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
-
-    # Add data usage labels
-    combined_df['data_usage'] = 'train'
-    combined_df.loc[X_test.index, 'data_usage'] = 'test'
-    df['data_usage'] = 'not_used'
-    df.loc[combined_df.index, 'data_usage'] = combined_df['data_usage']
-
-    # Initialize the model based on model_type
-    if model_type == 'random_forest':
-        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
-    elif model_type == 'logistic_regression':
-        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
-    elif model_type == 'gradient_boosting':
-        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
-    elif model_type == 'xgboost':
-        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
-    else:
-        raise ValueError(f"Unsupported model_type: {model_type}")
-
-    model.fit(X_train, y_train)
-
-    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
-
-    # Create a DataFrame for permutation importances
-    permutation_df = pd.DataFrame({
-        'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
-        'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
-        'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
-    }).tail(top_features)
-
-    permutation_fig = plot_permutation(permutation_df)
-    if verbose:
-        permutation_fig.show()
-
-    # Feature importance for models that support it
-    if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
-        feature_importances = model.feature_importances_
-        feature_importance_df = pd.DataFrame({
-            'feature': features,
-            'importance': feature_importances
-        }).sort_values(by='importance', ascending=False).head(top_features)
-
-        feature_importance_fig = plot_feature_importance(feature_importance_df)
-        if verbose:
-            feature_importance_fig.show()
-
-    else:
-        feature_importance_df = pd.DataFrame()
-
-    # Predicting the target variable for the test set
-    predictions_test = model.predict(X_test)
-    combined_df.loc[X_test.index, 'predictions'] = predictions_test
-
-    # Get prediction probabilities for the test set
-    prediction_probabilities_test = model.predict_proba(X_test)
-
-    # Find the optimal threshold
-    optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
-    if verbose:
-        print(f'Optimal threshold: {optimal_threshold}')
-
-    # Predicting the target variable for all other rows in the dataframe
-    X_all = df[features]
-    all_predictions = model.predict(X_all)
-    df['predictions'] = all_predictions
-
-    # Get prediction probabilities for all rows in the dataframe
-    prediction_probabilities = model.predict_proba(X_all)
-    for i in range(prediction_probabilities.shape[1]):
-        df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
-    if verbose:
-        print("\nClassification Report:")
-        print(classification_report(y_test, predictions_test))
-    report_dict = classification_report(y_test, predictions_test, output_dict=True)
-    metrics_df = pd.DataFrame(report_dict).transpose()
-
-    df = _calculate_similarity(df, features, location_column, positive_control, negative_control)
-
-    df['prcfo'] = df.index.astype(str)
-    df[['plate', 'row', 'col', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
-    df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
-
-    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]
-
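A condensed, self-contained sketch of the control-versus-control workflow that the removed ml_analysis() implements (train/test split, model fit, permutation importance) is shown below; it uses a random forest on synthetic data, and none of spacr's internal helpers are involved.

import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)
X = pd.DataFrame(rng.normal(size=(300, 5)), columns=[f"feature_{i}" for i in range(5)])
y = (X["feature_0"] + 0.5 * X["feature_1"] + rng.normal(scale=0.5, size=300) > 0).astype(int)

# Hold out a test split, fit the classifier on the controls, then rank features
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
model = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=-1)
model.fit(X_train, y_train)

perm = permutation_importance(model, X_train, y_train, n_repeats=10, random_state=42, n_jobs=-1)
ranking = pd.DataFrame({
    "feature": X.columns,
    "importance_mean": perm.importances_mean,
    "importance_std": perm.importances_std,
}).sort_values("importance_mean", ascending=False)
print(ranking.head())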
-def shap_analysis(model, X_train, X_test):
-
-    """
-    Performs SHAP analysis on the given model and data.
-
-    Args:
-        model: The trained model.
-        X_train (pandas.DataFrame): Training feature set.
-        X_test (pandas.DataFrame): Testing feature set.
-    Returns:
-        fig: Matplotlib figure object containing the SHAP summary plot.
-    """
-
-    explainer = shap.Explainer(model, X_train)
-    shap_values = explainer(X_test)
-    # Create a new figure
-    fig, ax = plt.subplots()
-    # Summary plot
-    shap.summary_plot(shap_values, X_test, show=False)
-    # Save the current figure (the one that SHAP just created)
-    fig = plt.gcf()
-    plt.close(fig) # Close the figure to prevent it from displaying immediately
-    return fig
-
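The figure-capture pattern in the removed shap_analysis() can be reproduced on a toy model as sketched below; the XGBoost model, the data, and the output file name are placeholders, and shap.Explainer picks a tree explainer for this model automatically.

import numpy as np
import matplotlib.pyplot as plt
import shap
from xgboost import XGBClassifier

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 4))
y = (X[:, 0] - X[:, 2] > 0).astype(int)

model = XGBClassifier(n_estimators=50, eval_metric="logloss").fit(X, y)
explainer = shap.Explainer(model, X)
shap_values = explainer(X)

shap.summary_plot(shap_values, X, show=False)  # draw without displaying
fig = plt.gcf()                                # grab the figure SHAP drew on
plt.close(fig)                                 # keep it for saving instead of showing it
fig.savefig("shap_summary.pdf")                # placeholder output path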
-def check_index(df, elements=5, split_char='_'):
-    problematic_indices = []
-    for idx in df.index:
-        parts = str(idx).split(split_char)
-        if len(parts) != elements:
-            problematic_indices.append(idx)
-    if problematic_indices:
-        print("Indices that cannot be separated into 5 parts:")
-        for idx in problematic_indices:
-            print(idx)
-        raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
-
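The 'plate_row_col_field_object' index convention that check_index() enforces can be illustrated with a toy frame; the index strings below are made up.

import pandas as pd

df = pd.DataFrame(
    {"value": [0.1, 0.2]},
    index=["plate1_r01_c02_f003_o0001", "plate1_r01_c02_f003_o0002"],
)
parts = df.index.to_series().str.split("_", expand=True)
assert parts.shape[1] == 5, "every index entry must split into 5 parts"
parts.columns = ["plate", "row", "col", "field", "object"]
print(parts)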
-def generate_ml_scores(src, settings):
-
-    from .io import _read_and_merge_data
-    from .plot import plot_plates
-    from .utils import get_ml_results_paths
-    from .settings import set_default_analyze_screen
-
-    settings = set_default_analyze_screen(settings)
-
-    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
-    display(settings_df)
-
-    db_loc = [src+'/measurements/measurements.db']
-    tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
-    include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
-
-    df, _ = _read_and_merge_data(db_loc,
-                                 tables,
-                                 settings['verbose'],
-                                 include_multinucleated,
-                                 include_multiinfected,
-                                 include_noninfected)
-
-    if settings['channel_of_interest'] in [0,1,2,3]:
-
-        df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
-
-    output, figs = ml_analysis(df,
-                               settings['channel_of_interest'],
-                               settings['location_column'],
-                               settings['positive_control'],
-                               settings['negative_control'],
-                               settings['exclude'],
-                               settings['n_repeats'],
-                               settings['top_features'],
-                               settings['n_estimators'],
-                               settings['test_size'],
-                               settings['model_type_ml'],
-                               settings['n_jobs'],
-                               settings['remove_low_variance_features'],
-                               settings['remove_highly_correlated_features'],
-                               settings['verbose'])
-
-    shap_fig = shap_analysis(output[3], output[4], output[5])
-
-    features = output[0].select_dtypes(include=[np.number]).columns.tolist()
-
-    if not settings['heatmap_feature'] in features:
-        raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")
-
-    plate_heatmap = plot_plates(df=output[0],
-                                variable=settings['heatmap_feature'],
-                                grouping=settings['grouping'],
-                                min_max=settings['min_max'],
-                                cmap=settings['cmap'],
-                                min_count=settings['minimum_cell_count'],
-                                verbose=settings['verbose'])
-
-    data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
-    df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
-
-    settings_df.to_csv(settings_csv, index=False)
-    df.to_csv(data_path, mode='w', encoding='utf-8')
-    permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
-    feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
-    metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')
-
-    plate_heatmap.savefig(plate_heatmap_path, format='pdf')
-    figs[0].savefig(permutation_fig_path, format='pdf')
-    figs[1].savefig(feature_importance_fig_path, format='pdf')
-    shap_fig.savefig(shap_fig_path, format='pdf')
-
-    return [output, plate_heatmap]
-
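The 'recruitment' readout computed by the removed generate_ml_scores() is a plain column ratio; a toy illustration with synthetic intensities follows (the column names mirror the measurement tables, the values are invented).

import pandas as pd

channel_of_interest = 3
df = pd.DataFrame({
    f"pathogen_channel_{channel_of_interest}_mean_intensity": [1200.0, 800.0, 950.0],
    f"cytoplasm_channel_{channel_of_interest}_mean_intensity": [400.0, 500.0, 380.0],
})
df["recruitment"] = (
    df[f"pathogen_channel_{channel_of_interest}_mean_intensity"]
    / df[f"cytoplasm_channel_{channel_of_interest}_mean_intensity"]
)
print(df["recruitment"])  # values > 1 indicate enrichment at the pathogen over the cytoplasm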
-def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
-
-    from .io import _read_and_merge_data, _read_db
-
-    db_loc = [src+'/measurements/measurements.db']
-    loc = src+'/measurements/measurements.db'
-    df, _ = _read_and_merge_data(db_loc,
-                                 tables,
-                                 verbose=True,
-                                 include_multinucleated=True,
-                                 include_multiinfected=True,
-                                 include_noninfected=True)
-
-    paths_df = _read_db(loc, tables=['png_list'])
-
-    merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
-
-    return merged_df
-
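The join above reduces to a left merge on the shared 'prcfo' key; a toy sketch with made-up keys and paths follows.

import pandas as pd

measurements = pd.DataFrame({"prcfo": ["p1_r1_c1_f1_o1", "p1_r1_c1_f1_o2"], "area": [310, 275]})
png_list = pd.DataFrame({"prcfo": ["p1_r1_c1_f1_o1"], "png_path": ["/path/to/o1.png"]})

merged = pd.merge(measurements, png_list, on="prcfo", how="left")  # left join keeps objects without a PNG
print(merged)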
-def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
-    """
-    Reads a CSV file and creates a jitter plot of one column grouped by another column.
-
-    Args:
-        src (str): Path to the source data.
-        x_column (str): Name of the column to be used for the x-axis.
-        y_column (str): Name of the column to be used for the y-axis.
-        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
-        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
-
-    Returns:
-        pd.DataFrame: The filtered and balanced DataFrame.
-    """
-    # Read the CSV file into a DataFrame
-    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
-
-    # Print column names for debugging
-    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")
-    #print("Columns in DataFrame:", df.columns.tolist())
-
-    # Replace NaN values with a specific label in x_column
-    df[x_column] = df[x_column].fillna('NaN')
-
-    # Filter the DataFrame if filter_column and filter_values are provided
-    if not filter_column is None:
-        if isinstance(filter_column, str):
-            df = df[df[filter_column].isin(filter_values)]
-        if isinstance(filter_column, list):
-            for i,val in enumerate(filter_column):
-                print(f'hello {len(df)}')
-                df = df[df[val].isin(filter_values[i])]
-
-    # Use the correct column names based on your DataFrame
-    required_columns = ['plate_x', 'row_x', 'col_x']
-    if not all(column in df.columns for column in required_columns):
-        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")
-
-    # Filter to retain rows with non-NaN values in x_column and with matching plate, row, col values
-    non_nan_df = df[df[x_column] != 'NaN']
-    retained_rows = df[df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1).isin(non_nan_df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1))]
-
-    # Determine the minimum count of examples across all groups in x_column
-    min_count = retained_rows[x_column].value_counts().min()
-    print(f'Found {min_count} annotated images')
-
-    # Randomly sample min_count examples from each group in x_column
-    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
-
-    # Create the jitter plot
-    plt.figure(figsize=(10, 6))
-    jitter_plot = sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
-    plt.title(plot_title)
-    plt.xlabel(x_column)
-    plt.ylabel(y_column)
-
-    # Customize the x-axis labels
-    plt.xticks(rotation=45, ha='right')
-
-    # Adjust the position of the x-axis labels to be centered below the data
-    ax = plt.gca()
-    ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center')
-
-    # Save the plot to a file or display it
-    if output_path:
-        plt.savefig(output_path, bbox_inches='tight')
-        print(f"Jitter plot saved to {output_path}")
-    else:
-        plt.show()
-
-    return balanced_df
-
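The balance-then-stripplot pattern used by the removed jitterplot_by_annotation() can be reproduced on synthetic data as sketched below; the group names and values are illustrative.

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

rng = np.random.default_rng(42)
df = pd.DataFrame({
    "annotation": rng.choice(["class_a", "class_b", "class_c"], size=120, p=[0.5, 0.3, 0.2]),
    "intensity": rng.normal(loc=1.0, scale=0.2, size=120),
})

# Downsample every group to the size of the smallest one so the plot is not dominated by one class
min_count = df["annotation"].value_counts().min()
balanced = (
    df.groupby("annotation", group_keys=False)
      .apply(lambda g: g.sample(min_count, random_state=42))
      .reset_index(drop=True)
)

plt.figure(figsize=(6, 4))
sns.stripplot(data=balanced, x="annotation", y="intensity", hue="annotation", jitter=True, dodge=False)
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()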
 def generate_image_umap(settings={}):
     """
     Generate UMAP or tSNE embedding and visualize the data with clustering.
@@ -2784,7 +441,7 @@ def generate_image_umap(settings={}):
     """

     from .io import _read_and_join_tables
-    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis
+    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis, map_condition
     from .settings import set_default_umap_image_settings
     settings = set_default_umap_image_settings(settings)

@@ -2933,17 +590,6 @@ def generate_image_umap(settings={}):

     return all_df

-# Define the mapping function
-def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
-    if col_value == neg:
-        return 'neg'
-    elif col_value == pos:
-        return 'pos'
-    elif col_value == mix:
-        return 'mix'
-    else:
-        return 'screen'
-
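The removed helper is now imported from spacr.utils (see the import changes above); applied column-wise with pandas it behaves as in this small example, where the well values are illustrative.

import pandas as pd

def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
    # Same mapping as the removed helper: controls map to labels, everything else is 'screen'
    if col_value == neg:
        return 'neg'
    elif col_value == pos:
        return 'pos'
    elif col_value == mix:
        return 'mix'
    return 'screen'

df = pd.DataFrame({'col': ['c1', 'c2', 'c3', 'c7']})
df['condition'] = df['col'].map(map_condition)
print(df)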
 def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
     """
     Perform a hyperparameter search for UMAP or tSNE on the given data.
@@ -2970,7 +616,7 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
     """

     from .io import _read_and_join_tables
-    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
+    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, map_condition
     from .settings import set_default_umap_image_settings

     settings = set_default_umap_image_settings(settings)
@@ -3122,7 +768,8 @@ def generate_mediar_masks(src, settings, object_type):
     from .mediar import MEDIARPredictor
     from .io import _create_database, _save_object_counts_to_database
     from .plot import plot_masks
-    from .settings import set_default_settings_preprocess_generate_masks
+    from .settings import set_default_settings_preprocess_generate_masks
+    from .utils import prepare_batch_for_segmentation

     # Clear CUDA cache and check if CUDA is available
     gc.collect()
@@ -3197,4 +844,108 @@ def generate_mediar_masks(src, settings, object_type):
     gc.collect()
     torch.cuda.empty_cache()

-    print("Mask generation completed.")
+    print("Mask generation completed.")
+
+def generate_screen_graphs(settings):
+    """
+    Generate screen graphs for different measurements in a given source directory.
+
+    Args:
+        src (str or list): Path(s) to the source directory or directories.
+        tables (list): List of tables to include in the analysis (default: ['cell', 'nucleus', 'pathogen', 'cytoplasm']).
+        graph_type (str): Type of graph to generate (default: 'bar').
+        summary_func (str or function): Function to summarize data (default: 'mean').
+        y_axis_start (float): Starting value for the y-axis (default: 0).
+        error_bar_type (str): Type of error bar to use ('std' or 'sem') (default: 'std').
+        theme (str): Theme for the graph (default: 'pastel').
+        representation (str): Representation for grouping (default: 'well').
+
+    Returns:
+        figs (list): List of generated figures.
+        results (list): List of corresponding result DataFrames.
+    """
+
+    from .plot import spacrGraph
+    from .io import _read_and_merge_data
+    from .utils import annotate_conditions
+
+    if isinstance(settings['src'], str):
+        srcs = [settings['src']]
+    else:
+        srcs = settings['src']
+
+    all_df = pd.DataFrame()
+    figs = []
+    results = []
+
+    for src in srcs:
+        db_loc = [os.path.join(src, 'measurements', 'measurements.db')]
+
+        # Read and merge data from the database
+        df, _ = _read_and_merge_data(db_loc, settings['tables'], verbose=True, nuclei_limit=settings['nuclei_limit'], pathogen_limit=settings['pathogen_limit'], uninfected=settings['uninfected'])
+
+        # Annotate the data
+        df = annotate_conditions(df, cells=settings['cells'], cell_loc=None, pathogens=settings['controls'], pathogen_loc=settings['controls_loc'], treatments=None, treatment_loc=None)
+
+        # Calculate recruitment metric
+        df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
+
+        # Combine with the overall DataFrame
+        all_df = pd.concat([all_df, df], ignore_index=True)
+
+        # Generate individual plot
+        plotter = spacrGraph(df,
+                             grouping_column='pathogen',
+                             data_column='recruitment',
+                             graph_type=settings['graph_type'],
+                             summary_func=settings['summary_func'],
+                             y_axis_start=settings['y_axis_start'],
+                             error_bar_type=settings['error_bar_type'],
+                             theme=settings['theme'],
+                             representation=settings['representation'])
+
+        plotter.create_plot()
+        fig = plotter.get_figure()
+        results_df = plotter.get_results()
+
+        # Append to the lists
+        figs.append(fig)
+        results.append(results_df)
+
+    # Generate plot for the combined data (all_df)
+    plotter = spacrGraph(all_df,
+                         grouping_column='pathogen',
+                         data_column='recruitment',
+                         graph_type=settings['graph_type'],
+                         summary_func=settings['summary_func'],
+                         y_axis_start=settings['y_axis_start'],
+                         error_bar_type=settings['error_bar_type'],
+                         theme=settings['theme'],
+                         representation=settings['representation'])
+
+    plotter.create_plot()
+    fig = plotter.get_figure()
+    results_df = plotter.get_results()
+
+    figs.append(fig)
+    results.append(results_df)
+
+    # Save figures and results
+    for i, fig in enumerate(figs):
+        res = results[i]
+
+        if i < len(srcs):
+            source = srcs[i]
+        else:
+            source = srcs[0]
+
+        # Ensure the destination folder exists
+        dst = os.path.join(source, 'results')
+        print(f"Savings results to {dst}")
+        os.makedirs(dst, exist_ok=True)
+
+        # Save the figure and results DataFrame
+        fig.savefig(os.path.join(dst, f"figure_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.pdf"), format='pdf')
+        res.to_csv(os.path.join(dst, f"results_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.csv"), index=False)
+
+    return
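A hypothetical invocation of the new generate_screen_graphs() is sketched below; the source path, cell and control annotations, and well locations are placeholders, the remaining values mirror the defaults listed in the docstring, and outputs are written to a results folder inside each source rather than returned.

from spacr.core import generate_screen_graphs

settings = {
    'src': '/path/to/screen',                        # one folder or a list of folders
    'tables': ['cell', 'nucleus', 'pathogen', 'cytoplasm'],
    'nuclei_limit': True,                            # placeholder filter values
    'pathogen_limit': True,
    'uninfected': True,
    'cells': ['HeLa'],                               # placeholder cell annotation
    'controls': ['control_pathogen'],                # placeholder pathogen/control annotation
    'controls_loc': [['c1', 'c2']],                  # placeholder well locations
    'graph_type': 'bar',
    'summary_func': 'mean',
    'y_axis_start': 0,
    'error_bar_type': 'std',
    'theme': 'pastel',
    'representation': 'well',
}

generate_screen_graphs(settings)  # figures and CSVs are saved under <src>/results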