spacr-0.3.1-py3-none-any.whl → spacr-0.3.22-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +245 -2494
  4. spacr/deep_spacr.py +316 -48
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +680 -141
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +1051 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +707 -20
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +134 -47
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +349 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +238 -0
  26. spacr/utils.py +419 -180
  27. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/METADATA +31 -22
  28. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/core.py CHANGED
@@ -1,1925 +1,169 @@
1
- import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
2
-
1
+ import os, gc, torch, time, random
3
2
  import numpy as np
4
3
  import pandas as pd
5
-
6
- from cellpose import train
7
- from cellpose import models as cp_models
8
-
9
- import statsmodels.formula.api as smf
10
- import statsmodels.api as sm
11
- from functools import reduce
12
- from IPython.display import display
13
- from multiprocessing import Pool, cpu_count, Value, Lock
14
-
15
- import seaborn as sns
16
- import cellpose
17
- from skimage.measure import regionprops, label
18
- from skimage.transform import resize as resizescikit
19
-
20
- from skimage import measure
21
- from sklearn.model_selection import train_test_split
22
- from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
23
- from sklearn.linear_model import LogisticRegression
24
- from sklearn.inspection import permutation_importance
25
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
26
- from sklearn.preprocessing import StandardScaler
27
- from sklearn.metrics import precision_recall_curve, f1_score
28
-
29
- from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
30
-
31
- import torchvision.transforms as transforms
32
- from xgboost import XGBClassifier
33
- import shap
34
-
35
4
  import matplotlib.pyplot as plt
36
- import matplotlib
37
- matplotlib.use('Agg')
38
-
39
- from .logger import log_function_call
5
+ from IPython.display import display
40
6
 
41
7
  import warnings
42
8
  warnings.filterwarnings("ignore", message="3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only")
43
9
 
44
-
45
- from torchvision import transforms
46
- from torch.utils.data import DataLoader, random_split
47
- from collections import defaultdict
48
- import os
49
- import random
50
- from PIL import Image
51
- from torchvision.transforms import ToTensor
52
-
53
- def analyze_plaques(folder):
54
- summary_data = []
55
- details_data = []
56
- stats_data = []
57
-
58
- for filename in os.listdir(folder):
59
- filepath = os.path.join(folder, filename)
60
- if os.path.isfile(filepath):
61
- # Assuming each file is a NumPy array file (.npy) containing a 16-bit labeled image
62
- #image = np.load(filepath)
63
- image = cellpose.io.imread(filepath)
64
- labeled_image = label(image)
65
- regions = regionprops(labeled_image)
66
-
67
- object_count = len(regions)
68
- sizes = [region.area for region in regions]
69
- average_size = np.mean(sizes) if sizes else 0
70
- std_dev_size = np.std(sizes) if sizes else 0
71
-
72
- summary_data.append({'file': filename, 'object_count': object_count, 'average_size': average_size})
73
- stats_data.append({'file': filename, 'plaque_count': object_count, 'average_size': average_size, 'std_dev_size': std_dev_size})
74
- for size in sizes:
75
- details_data.append({'file': filename, 'plaque_size': size})
76
-
77
- # Convert lists to pandas DataFrames
78
- summary_df = pd.DataFrame(summary_data)
79
- details_df = pd.DataFrame(details_data)
80
- stats_df = pd.DataFrame(stats_data)
81
-
82
- # Save DataFrames to a SQLite database
83
- db_name = os.path.join(folder, 'plaques_analysis.db')
84
- conn = sqlite3.connect(db_name)
85
-
86
- summary_df.to_sql('summary', conn, if_exists='replace', index=False)
87
- details_df.to_sql('details', conn, if_exists='replace', index=False)
88
- stats_df.to_sql('stats', conn, if_exists='replace', index=False)
89
-
90
- conn.close()
91
-
92
- print(f"Analysis completed and saved to database '{db_name}'.")
93
-
94
- def train_cellpose(settings):
95
-
96
- from .io import _load_normalized_images_and_labels, _load_images_and_labels
97
- from .settings import get_train_cellpose_default_settings#, resize_images_and_labels
98
-
99
- settings = get_train_cellpose_default_settings()
100
-
101
- img_src = settings['img_src']
102
- mask_src = os.path.join(img_src, 'masks')
103
-
104
- model_name = settings.setdefault( 'model_name', '')
105
-
106
- model_name = settings.setdefault('model_name', 'model_name')
107
-
108
- model_type = settings.setdefault( 'model_type', 'cyto')
109
- learning_rate = settings.setdefault( 'learning_rate', 0.01)
110
- weight_decay = settings.setdefault( 'weight_decay', 1e-05)
111
- batch_size = settings.setdefault( 'batch_size', 50)
112
- n_epochs = settings.setdefault( 'n_epochs', 100)
113
- from_scratch = settings.setdefault( 'from_scratch', False)
114
- diameter = settings.setdefault( 'diameter', 40)
115
-
116
- remove_background = settings.setdefault( 'remove_background', False)
117
- background = settings.setdefault( 'background', 100)
118
- Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
119
- verbose = settings.setdefault( 'verbose', False)
120
-
121
- channels = settings.setdefault( 'channels', [0,0])
122
- normalize = settings.setdefault( 'normalize', True)
123
- percentiles = settings.setdefault( 'percentiles', None)
124
- circular = settings.setdefault( 'circular', False)
125
- invert = settings.setdefault( 'invert', False)
126
- resize = settings.setdefault( 'resize', False)
127
-
128
- if resize:
129
- target_height = settings['width_height'][1]
130
- target_width = settings['width_height'][0]
131
-
132
- grayscale = settings.setdefault( 'grayscale', True)
133
- rescale = settings.setdefault( 'channels', False)
134
- test = settings.setdefault( 'test', False)
135
-
136
- if test:
137
- test_img_src = os.path.join(os.path.dirname(img_src), 'test')
138
- test_mask_src = os.path.join(test_img_src, 'mask')
139
-
140
- test_images, test_masks, test_image_names, test_mask_names = None,None,None,None
141
- print(settings)
142
-
143
- if from_scratch:
144
- model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
145
- else:
146
- if resize:
147
- model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
148
- else:
149
- model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
150
-
151
- model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
152
- print(model_save_path)
153
- os.makedirs(model_save_path, exist_ok=True)
154
-
155
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
156
- settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
157
- settings_df.to_csv(settings_csv, index=False)
158
-
159
- if from_scratch:
160
- model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
161
- else:
162
- model = cp_models.CellposeModel(gpu=True, model_type=model_type)
163
-
164
- if normalize:
165
-
166
- image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
167
- label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
168
- images, masks, image_names, mask_names, orig_dims = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
169
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
170
-
171
- if test:
172
- test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
173
- test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
174
- test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise, target_height, target_width)
175
- test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
176
-
177
- else:
178
- images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
179
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
180
-
181
- if test:
182
- test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
183
- test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
184
-
185
- #if resize:
186
- # images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
187
-
188
- if model_type == 'cyto':
189
- cp_channels = [0,1]
190
- if model_type == 'cyto2':
191
- cp_channels = [0,2]
192
- if model_type == 'nucleus':
193
- cp_channels = [0,0]
194
- if grayscale:
195
- cp_channels = [0,0]
196
- images = [np.squeeze(img) if img.ndim == 3 and 1 in img.shape else img for img in images]
197
-
198
- masks = [np.squeeze(mask) if mask.ndim == 3 and 1 in mask.shape else mask for mask in masks]
199
-
200
- print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
201
- save_every = int(n_epochs/10)
202
- if save_every < 10:
203
- save_every = n_epochs
204
-
205
- train.train_seg(model.net,
206
- train_data=images,
207
- train_labels=masks,
208
- train_files=image_names,
209
- train_labels_files=mask_names,
210
- train_probs=None,
211
- test_data=test_images,
212
- test_labels=test_masks,
213
- test_files=test_image_names,
214
- test_labels_files=test_mask_names,
215
- test_probs=None,
216
- load_files=True,
217
- batch_size=batch_size,
218
- learning_rate=learning_rate,
219
- n_epochs=n_epochs,
220
- weight_decay=weight_decay,
221
- momentum=0.9,
222
- SGD=False,
223
- channels=cp_channels,
224
- channel_axis=None,
225
- #rgb=False,
226
- normalize=False,
227
- compute_flows=False,
228
- save_path=model_save_path,
229
- save_every=save_every,
230
- nimg_per_epoch=None,
231
- nimg_test_per_epoch=None,
232
- rescale=rescale,
233
- #scale_range=None,
234
- #bsize=224,
235
- min_train_masks=1,
236
- model_name=model_name)
237
-
238
- return print(f"Model saved at: {model_save_path}/{model_name}")
239
-
240
- def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
241
-
242
- from .plot import _reg_v_plot
243
- from .utils import generate_fraction_map, MLR, fishers_odds, lasso_reg
244
-
245
- def qstring_to_float(qstr):
246
- number = int(qstr[1:]) # Remove the "q" and convert the rest to an integer
247
- return number / 100.0
248
-
249
- columns_list = ['c1', 'c2', 'c3']
250
- plate_list = ['p1','p3','p4']
251
-
252
- dv_df = pd.read_csv(dv_loc)#, index_col='prc')
253
-
254
- if agg_type.startswith('q'):
255
- val = qstring_to_float(agg_type)
256
- agg_type = lambda x: x.quantile(val)
257
-
258
- # Aggregating for mean prediction, total count and count of values > 0.95
259
- dv_df = dv_df.groupby('prc').agg(
260
- pred=(dv_col, agg_type),
261
- count_prc=('prc', 'size'),
262
- mean_pathogen_area=('pathogen_area', 'mean')
263
- )
264
-
265
- dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
266
- sequencing_df = pd.read_csv(sequencing_loc)
267
-
268
-
269
- reads_df, stats_dict = process_reads(df=sequencing_df,
270
- min_reads=min_reads,
271
- min_wells=min_wells,
272
- max_wells=max_wells,
273
- gene_column='gene',
274
- remove_outliers=remove_outlier_genes)
275
-
276
- reads_df['value'] = reads_df['count']/reads_df['well_read_sum']
277
- reads_df['gene_grna'] = reads_df['gene']+'_'+reads_df['grna']
278
-
279
- display(reads_df)
280
-
281
- df_long = reads_df
282
-
283
- df_long = df_long[df_long['value'] > min_frequency] # removes gRNAs under a certain proportion
284
- #df_long = df_long[df_long['value']<1.0] # removes gRNAs in wells with only one gRNA
285
-
286
- # Extract gene and grna info from gene_grna column
287
- df_long["gene"] = df_long["grna"].str.split("_").str[1]
288
- df_long["grna"] = df_long["grna"].str.split("_").str[2]
289
-
290
- agg_df = df_long.groupby('prc')['count'].sum().reset_index()
291
- agg_df = agg_df.rename(columns={'count': 'count_sum'})
292
- df_long = pd.merge(df_long, agg_df, on='prc', how='left')
293
- df_long['value'] = df_long['count']/df_long['count_sum']
294
-
295
- merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
296
- merged_df = merged_df[merged_df['value'] > 0]
297
- merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
298
- merged_df['row'] = merged_df['prc'].str.split('_').str[1]
299
- merged_df['column'] = merged_df['prc'].str.split('_').str[2]
300
-
301
- merged_df = merged_df[~merged_df['column'].isin(columns_list)]
302
- merged_df = merged_df[merged_df['plate'].isin(plate_list)]
303
-
304
- if transform == 'log':
305
- merged_df['pred'] = np.log(merged_df['pred'] + 1e-10)
306
-
307
- # Printing the unique values in 'col' and 'plate' columns
308
- print("Unique values in col:", merged_df['column'].unique())
309
- print("Unique values in plate:", merged_df['plate'].unique())
310
- display(merged_df)
311
-
312
- if fishers:
313
- iv_df = generate_fraction_map(df=reads_df,
314
- gene_column='grna',
315
- min_frequency=min_frequency)
316
-
317
- fishers_df = iv_df.join(dv_df, on='prc', how='inner')
318
-
319
- significant_mutants = fishers_odds(df=fishers_df, threshold=fisher_threshold, phenotyp_col='pred')
320
- significant_mutants = significant_mutants.sort_values(by='OddsRatio', ascending=False)
321
- display(significant_mutants)
322
-
323
- if regression_type == 'mlr':
324
- if by_plate:
325
- merged_df2 = merged_df.copy()
326
- for plate in merged_df2['plate'].unique():
327
- merged_df = merged_df2[merged_df2['plate'] == plate]
328
- print(f'merged_df: {len(merged_df)}, plate: {plate}')
329
- if len(merged_df) <100:
330
- break
331
-
332
- max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
333
- else:
334
-
335
- max_effects, max_effects_pvalues, model, df = MLR(merged_df, refine_model)
336
- return max_effects, max_effects_pvalues, model, df
337
-
338
- if regression_type == 'ridge' or regression_type == 'lasso':
339
- coeffs = lasso_reg(merged_df, alpha_value=alpha_value, reg_type=regression_type)
340
- return coeffs
341
-
342
- if regression_type == 'mixed':
343
- model = smf.mixedlm("pred ~ gene_grna - 1", merged_df, groups=merged_df["plate"], re_formula="~1")
344
- result = model.fit(method="bfgs")
345
- print(result.summary())
346
-
347
- # Print AIC and BIC
348
- print("AIC:", result.aic)
349
- print("BIC:", result.bic)
350
-
351
-
352
- results_df = pd.DataFrame({
353
- 'effect': result.params,
354
- 'Standard Error': result.bse,
355
- 'T-Value': result.tvalues,
356
- 'p': result.pvalues
357
- })
358
-
359
- display(results_df)
360
- _reg_v_plot(df=results_df)
361
-
362
- std_resid = result.resid
363
-
364
- # Create subplots
365
- fig, axes = plt.subplots(1, 3, figsize=(15, 5))
366
-
367
- # Histogram of Residuals
368
- axes[0].hist(std_resid, bins=50, edgecolor='k')
369
- axes[0].set_xlabel('Residuals')
370
- axes[0].set_ylabel('Frequency')
371
- axes[0].set_title('Histogram of Residuals')
372
-
373
- # Boxplot of Residuals
374
- axes[1].boxplot(std_resid)
375
- axes[1].set_ylabel('Residuals')
376
- axes[1].set_title('Boxplot of Residuals')
377
-
378
- # QQ Plot
379
- sm.qqplot(std_resid, line='45', ax=axes[2])
380
- axes[2].set_title('QQ Plot')
381
-
382
- # Show plots
383
- plt.tight_layout()
384
- plt.show()
385
-
386
- return result
387
-
388
- def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
389
-
390
- from .plot import _reg_v_plot
391
- from .utils import generate_fraction_map, fishers_odds, model_metrics
392
-
393
- def qstring_to_float(qstr):
394
- number = int(qstr[1:]) # Remove the "q" and convert the rest to an integer
395
- return number / 100.0
396
-
397
- columns_list = ['c1', 'c2', 'c3', 'c15']
398
- plate_list = ['p1','p2','p3','p4']
399
-
400
- dv_df = pd.read_csv(dv_loc)#, index_col='prc')
401
-
402
- if agg_type.startswith('q'):
403
- val = qstring_to_float(agg_type)
404
- agg_type = lambda x: x.quantile(val)
405
-
406
- # Aggregating for mean prediction, total count and count of values > 0.95
407
- dv_df = dv_df.groupby('prc').agg(
408
- pred=('pred', agg_type),
409
- count_prc=('prc', 'size'),
410
- #count_above_95=('pred', lambda x: (x > 0.95).sum()),
411
- mean_pathogen_area=('pathogen_area', 'mean')
412
- )
413
-
414
- dv_df = dv_df[dv_df['count_prc'] >= min_cell_count]
415
- sequencing_df = pd.read_csv(sequencing_loc)
416
-
417
- reads_df, stats_dict = process_reads(df=sequencing_df,
418
- min_reads=min_reads,
419
- min_wells=min_wells,
420
- max_wells=max_wells,
421
- gene_column='gene',
422
- remove_outliers=remove_outlier_genes)
423
-
424
- iv_df = generate_fraction_map(df=reads_df,
425
- gene_column='grna',
426
- min_frequency=0.0)
427
-
428
- # Melt the iv_df to long format
429
- df_long = iv_df.reset_index().melt(id_vars=["prc"],
430
- value_vars=iv_df.columns,
431
- var_name="gene_grna",
432
- value_name="value")
433
-
434
- # Extract gene and grna info from gene_grna column
435
- df_long["gene"] = df_long["gene_grna"].str.split("_").str[1]
436
- df_long["grna"] = df_long["gene_grna"].str.split("_").str[2]
437
-
438
- merged_df = df_long.merge(dv_df, left_on='prc', right_index=True)
439
- merged_df = merged_df[merged_df['value'] > 0]
440
- merged_df['plate'] = merged_df['prc'].str.split('_').str[0]
441
- merged_df['row'] = merged_df['prc'].str.split('_').str[1]
442
- merged_df['column'] = merged_df['prc'].str.split('_').str[2]
443
-
444
- merged_df = merged_df[~merged_df['column'].isin(columns_list)]
445
- merged_df = merged_df[merged_df['plate'].isin(plate_list)]
446
-
447
- # Printing the unique values in 'col' and 'plate' columns
448
- print("Unique values in col:", merged_df['column'].unique())
449
- print("Unique values in plate:", merged_df['plate'].unique())
450
-
451
- if not by_plate:
452
- if fishers:
453
- fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
454
-
455
- if by_plate:
456
- merged_df2 = merged_df.copy()
457
- for plate in merged_df2['plate'].unique():
458
- merged_df = merged_df2[merged_df2['plate'] == plate]
459
- print(f'merged_df: {len(merged_df)}, plate: {plate}')
460
- if len(merged_df) <100:
461
- break
462
- display(merged_df)
463
-
464
- model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
465
- #model = smf.ols("pred ~ infection_time + gene + grna + gene:grna + plate + row + column", merged_df).fit()
466
-
467
- # Display model metrics and summary
468
- model_metrics(model)
469
- #print(model.summary())
470
-
471
- if refine_model:
472
- # Filter outliers
473
- std_resid = model.get_influence().resid_studentized_internal
474
- outliers_resid = np.where(np.abs(std_resid) > 3)[0]
475
- (c, p) = model.get_influence().cooks_distance
476
- outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
477
- outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
478
- merged_df_filtered = merged_df.drop(merged_df.index[outliers])
479
-
480
- display(merged_df_filtered)
481
-
482
- # Refit the model with filtered data
483
- model = smf.ols("pred ~ gene + grna + gene:grna + row + column", merged_df_filtered).fit()
484
- print("Number of outliers detected by standardized residuals:", len(outliers_resid))
485
- print("Number of outliers detected by Cook's distance:", len(outliers_cooks))
486
-
487
- model_metrics(model)
488
-
489
- # Extract interaction coefficients and determine the maximum effect size
490
- interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
491
- interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}
492
-
493
- max_effects = {}
494
- max_effects_pvalues = {}
495
- for key, val in interaction_coeffs.items():
496
- gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
497
- if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
498
- max_effects[gene_name] = val
499
- max_effects_pvalues[gene_name] = interaction_pvalues[key]
500
-
501
- for key in max_effects:
502
- print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")
503
-
504
- df = pd.DataFrame([max_effects, max_effects_pvalues])
505
- df = df.transpose()
506
- df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
507
- df = df.sort_values(by=['effect', 'p'], ascending=[False, True])
508
-
509
- _reg_v_plot(df)
510
-
511
- if fishers:
512
- fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
513
- else:
514
- display(merged_df)
515
-
516
- model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
517
-
518
- # Display model metrics and summary
519
- model_metrics(model)
520
-
521
- if refine_model:
522
- # Filter outliers
523
- std_resid = model.get_influence().resid_studentized_internal
524
- outliers_resid = np.where(np.abs(std_resid) > 3)[0]
525
- (c, p) = model.get_influence().cooks_distance
526
- outliers_cooks = np.where(c > 4/(len(merged_df)-merged_df.shape[1]-1))[0]
527
- outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
528
- merged_df_filtered = merged_df.drop(merged_df.index[outliers])
529
-
530
- display(merged_df_filtered)
531
-
532
- # Refit the model with filtered data
533
- model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df_filtered).fit()
534
- print("Number of outliers detected by standardized residuals:", len(outliers_resid))
535
- print("Number of outliers detected by Cook's distance:", len(outliers_cooks))
536
-
537
- model_metrics(model)
538
-
539
- # Extract interaction coefficients and determine the maximum effect size
540
- interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
541
- interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}
542
-
543
- max_effects = {}
544
- max_effects_pvalues = {}
545
- for key, val in interaction_coeffs.items():
546
- gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
547
- if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
548
- max_effects[gene_name] = val
549
- max_effects_pvalues[gene_name] = interaction_pvalues[key]
550
-
551
- for key in max_effects:
552
- print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")
553
-
554
- df = pd.DataFrame([max_effects, max_effects_pvalues])
555
- df = df.transpose()
556
- df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
557
- df = df.sort_values(by=['effect', 'p'], ascending=[False, True])
558
-
559
- _reg_v_plot(df)
560
-
561
- if fishers:
562
- fishers_odds(df=merged_df, threshold=threshold, phenotyp_col='pred')
563
-
564
- return max_effects, max_effects_pvalues, model, df
565
-
566
- def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
567
-
568
- from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
569
-
570
- sequencing_df = pd.read_csv(sequencing_loc)
571
- columns_list = ['c1','c2','c3', 'c15']
572
- sequencing_df = sequencing_df[~sequencing_df['col'].isin(columns_list)]
573
-
574
- reads_df, stats_dict = process_reads(df=sequencing_df,
575
- min_reads=min_reads,
576
- min_wells=min_wells,
577
- max_wells=max_wells,
578
- gene_column='gene')
579
-
580
- display(reads_df)
581
-
582
- iv_df = generate_fraction_map(df=reads_df,
583
- gene_column=gene_column,
584
- min_frequency=min_frequency)
585
-
586
- display(iv_df)
587
-
588
- dv_df = dv_df[dv_df['count_prc']>min_cells]
589
- display(dv_df)
590
- merged_df = iv_df.join(dv_df, on='prc', how='inner')
591
- display(merged_df)
592
- fisher_df = merged_df.copy()
593
-
594
- merged_df.reset_index(inplace=True)
595
- merged_df[['plate', 'row', 'col']] = merged_df['prc'].str.split('_', expand=True)
596
- merged_df = merged_df.drop(columns=['prc'])
597
- merged_df.dropna(inplace=True)
598
- merged_df = pd.get_dummies(merged_df, columns=['plate', 'row', 'col'], drop_first=True)
599
-
600
- y = merged_df['mean_pred']
601
-
602
- if model_type == 'mlr':
603
- merged_df = merged_df.drop(columns=['count_prc'])
604
-
605
- elif model_type == 'wls':
606
- weights = merged_df['count_prc']
607
-
608
- elif model_type == 'glm':
609
- merged_df = merged_df.drop(columns=['count_prc'])
610
-
611
- if transform == 'logit':
612
- # logit transformation
613
- epsilon = 1e-15
614
- y = np.log(y + epsilon) - np.log(1 - y + epsilon)
615
-
616
- elif transform == 'log':
617
- # log transformation
618
- y = np.log10(y+1)
619
-
620
- elif transform == 'center':
621
- # Centering the y around 0
622
- y_mean = y.mean()
623
- y = y - y_mean
624
-
625
- x = merged_df.drop('mean_pred', axis=1)
626
- x = x.select_dtypes(include=[np.number])
627
- #x = sm.add_constant(x)
628
- x['const'] = 0.0
629
-
630
- if model_type == 'mlr':
631
- model = sm.OLS(y, x).fit()
632
- model_metrics(model)
633
-
634
- # Check for Multicollinearity
635
- vif_data = check_multicollinearity(x.drop('const', axis=1)) # assuming you've added a constant to x
636
- high_vif_columns = vif_data[vif_data["VIF"] > VIF_threshold]["Variable"].values # VIF threshold of 10 is common, but this can vary based on context
637
-
638
- print(f"Columns with high VIF: {high_vif_columns}")
639
- x = x.drop(columns=high_vif_columns) # dropping columns with high VIF
640
-
641
- if clean_regression:
642
- # 1. Filter by standardized residuals
643
- std_resid = model.get_influence().resid_studentized_internal
644
- outliers_resid = np.where(np.abs(std_resid) > 3)[0]
645
-
646
- # 2. Filter by leverage
647
- influence = model.get_influence().hat_matrix_diag
648
- outliers_lev = np.where(influence > 2*(x.shape[1])/len(y))[0]
649
-
650
- # 3. Filter by Cook's distance
651
- (c, p) = model.get_influence().cooks_distance
652
- outliers_cooks = np.where(c > 4/(len(y)-x.shape[1]-1))[0]
653
-
654
- # Combine all identified outliers
655
- outliers = reduce(np.union1d, (outliers_resid, outliers_lev, outliers_cooks))
656
-
657
- # Filter out outliers
658
- x_clean = x.drop(x.index[outliers])
659
- y_clean = y.drop(y.index[outliers])
660
-
661
- # Re-run the regression with the filtered data
662
- model = sm.OLS(y_clean, x_clean).fit()
663
- model_metrics(model)
664
-
665
- elif model_type == 'wls':
666
- model = sm.WLS(y, x, weights=weights).fit()
667
-
668
- elif model_type == 'glm':
669
- model = sm.GLM(y, x, family=sm.families.Binomial()).fit()
670
-
671
- print(model.summary())
672
-
673
- results_summary = model.summary()
674
-
675
- results_as_html = results_summary.tables[1].as_html()
676
- results_df = pd.read_html(results_as_html, header=0, index_col=0)[0]
677
- results_df = results_df.sort_values(by='coef', ascending=False)
678
-
679
- if model_type == 'mlr':
680
- results_df['p'] = results_df['P>|t|']
681
- elif model_type == 'wls':
682
- results_df['p'] = results_df['P>|t|']
683
- elif model_type == 'glm':
684
- results_df['p'] = results_df['P>|z|']
685
-
686
- results_df['type'] = 1
687
- results_df.loc[results_df['p'] == 0.000, 'p'] = 0.005
688
- results_df['-log10(p)'] = -np.log10(results_df['p'])
689
-
690
- display(results_df)
691
-
692
- # Create subplots
693
- fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 15))
694
-
695
- # Plot histogram on ax1
696
- sns.histplot(data=y, kde=False, element="step", ax=ax1, color='teal')
697
- ax1.set_xlim([0, 1])
698
- ax1.spines['top'].set_visible(False)
699
- ax1.spines['right'].set_visible(False)
700
-
701
- # Prepare data for volcano plot on ax2
702
- results_df['-log10(p)'] = -np.log10(results_df['p'])
703
-
704
- # Assuming the 'type' column is in the merged_df
705
- sc = ax2.scatter(results_df['coef'], results_df['-log10(p)'], c=results_df['type'], cmap='coolwarm')
706
- ax2.set_title('Volcano Plot')
707
- ax2.set_xlabel('Coefficient')
708
- ax2.set_ylabel('-log10(P-value)')
709
-
710
- # Adjust colorbar
711
- cbar = plt.colorbar(sc, ax=ax2, ticks=[-1, 1])
712
- cbar.set_label('Sign of Coefficient')
713
- cbar.set_ticklabels(['-ve', '+ve'])
714
-
715
- # Add text for specified points
716
- for idx, row in results_df.iterrows():
717
- if row['p'] < 0.05 and row['coef'] > effect_size_threshold:
718
- ax2.text(row['coef'], -np.log10(row['p']), idx, fontsize=8, ha='center', va='bottom', color='black')
719
-
720
- ax2.axhline(y=-np.log10(0.05), color='gray', linestyle='--')
721
-
722
- plt.show()
723
-
724
- #if model_type == 'mlr':
725
- # show_residules(model)
726
-
727
- if fishers:
728
- threshold = 2*effect_size_threshold
729
- fishers_odds(df=fisher_df, threshold=threshold, phenotyp_col='mean_pred')
730
-
731
- return
732
-
733
- def merge_pred_mes(src,
734
- pred_loc,
735
- target='protein of interest',
736
- cell_dim=4,
737
- nucleus_dim=5,
738
- pathogen_dim=6,
739
- channel_of_interest=1,
740
- pathogen_size_min=0,
741
- nucleus_size_min=0,
742
- cell_size_min=0,
743
- pathogen_min=0,
744
- nucleus_min=0,
745
- cell_min=0,
746
- target_min=0,
747
- mask_chans=[0,1,2],
748
- filter_data=False,
749
- include_noninfected=False,
750
- include_multiinfected=False,
751
- include_multinucleated=False,
752
- cells_per_well=10,
753
- save_filtered_filelist=False,
754
- verbose=False):
755
-
756
- from .io import _read_and_merge_data
757
- from .plot import _plot_histograms_and_stats
758
-
759
- mask_chans=[cell_dim,nucleus_dim,pathogen_dim]
760
- sns.color_palette("mako", as_cmap=True)
761
- print(f'channel:{channel_of_interest} = {target}')
762
- overlay_channels = [0, 1, 2, 3]
763
- overlay_channels.remove(channel_of_interest)
764
- overlay_channels.reverse()
765
-
766
- db_loc = [src+'/measurements/measurements.db']
767
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
768
- df, object_dfs = _read_and_merge_data(db_loc,
769
- tables,
770
- verbose=True,
771
- include_multinucleated=include_multinucleated,
772
- include_multiinfected=include_multiinfected,
773
- include_noninfected=include_noninfected)
774
- if filter_data:
775
- df = df[df['cell_area'] > cell_size_min]
776
- df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
777
- print(f'After cell filtration {len(df)}')
778
- df = df[df['nucleus_area'] > nucleus_size_min]
779
- df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
780
- print(f'After nucleus filtration {len(df)}')
781
- df = df[df['pathogen_area'] > pathogen_size_min]
782
- df=df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
783
- print(f'After pathogen filtration {len(df)}')
784
- df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
785
- print(f'After channel {channel_of_interest} filtration', len(df))
786
-
787
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
788
-
789
- pred_df = annotate_results(pred_loc=pred_loc)
790
-
791
- if verbose:
792
- _plot_histograms_and_stats(df=pred_df)
793
-
794
- pred_df.set_index('prcfo', inplace=True)
795
- pred_df = pred_df.drop(columns=['plate', 'row', 'col', 'field'])
796
-
797
- joined_df = df.join(pred_df, how='inner')
798
-
799
- if verbose:
800
- _plot_histograms_and_stats(df=joined_df)
801
-
802
- return joined_df
803
-
804
- def process_reads(df, min_reads, min_wells, max_wells, gene_column, remove_outliers=False):
805
- print('start',len(df))
806
- df = df[df['count'] >= min_reads]
807
- print('after filtering min reads',min_reads, len(df))
808
- reads_ls = df['count']
809
- stats_dict = {}
810
- stats_dict['screen_reads_mean'] = np.mean(reads_ls)
811
- stats_dict['screen_reads_sd'] = np.std(reads_ls)
812
- stats_dict['screen_reads_var'] = np.var(reads_ls)
813
-
814
- well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
815
- well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
816
- well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
817
- well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
818
- well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
819
- gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
820
- gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
821
- df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))
822
- df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))
823
-
824
- df = df[df['gRNA_well_count'] >= min_wells]
825
- df = df[df['gRNA_well_count'] <= max_wells]
826
-
827
- if remove_outliers:
828
- clf = IsolationForest(contamination='auto', random_state=42, n_jobs=20)
829
- #clf.fit(df.select_dtypes(include=['int', 'float']))
830
- clf.fit(df[["gRNA_well_count", "count"]])
831
- outlier_array = clf.predict(df[["gRNA_well_count", "count"]])
832
- #outlier_array = clf.predict(df.select_dtypes(include=['int', 'float']))
833
- outlier_df = pd.DataFrame(outlier_array, columns=['outlier'])
834
- df['outlier'] = outlier_df['outlier']
835
- outliers = pd.DataFrame(df[df['outlier']==-1])
836
- df = pd.DataFrame(df[df['outlier']==1])
837
- print('removed',len(outliers), 'outliers', 'inlers',len(df))
838
-
839
- columns_to_drop = ['gRNA_well_count','gRNAs_per_well', 'well_read_sum']#, 'outlier']
840
- df = df.drop(columns_to_drop, axis=1)
841
-
842
- plates = ['p1', 'p2', 'p3', 'p4']
843
- df = df[df.plate.isin(plates) == True]
844
- print('after filtering out p5,p6,p7,p8',len(df))
845
-
846
- gRNA_well_count = pd.DataFrame(df.groupby([gene_column]).count()['prc'])
847
- gRNA_well_count = gRNA_well_count.rename({'prc': 'gRNA_well_count'}, axis=1)
848
- df = pd.merge(df, gRNA_well_count, on=gene_column, how='inner', suffixes=('', '_right'))
849
- well_read_sum = pd.DataFrame(df.groupby(['prc']).sum())
850
- well_read_sum = well_read_sum.rename({'count': 'well_read_sum'}, axis=1)
851
- well_sgRNA_count = pd.DataFrame(df.groupby(['prc']).count()[gene_column])
852
- well_sgRNA_count = well_sgRNA_count.rename({gene_column: 'gRNAs_per_well'}, axis=1)
853
- well_seq = pd.merge(well_read_sum, well_sgRNA_count, how='inner', suffixes=('', '_right'), left_index=True, right_index=True)
854
- df = pd.merge(df, well_seq, on='prc', how='inner', suffixes=('', '_right'))
855
-
856
- columns_to_drop = [col for col in df.columns if col.endswith('_right')]
857
- columns_to_drop2 = [col for col in df.columns if col.endswith('0')]
858
- columns_to_drop = columns_to_drop + columns_to_drop2
859
- df = df.drop(columns_to_drop, axis=1)
860
- return df, stats_dict
861
-
862
- def annotate_results(pred_loc):
863
-
864
- from .utils import _map_wells_png
865
-
866
- df = pd.read_csv(pred_loc)
867
- df = df.copy()
868
- pc_col_list = ['c4','c5','c6','c7','c8','c9','c10','c11','c12','c13','c14','c15','c16','c17','c18','c19','c20','c21','c22','c23','c24']
869
- pc_plate_list = ['p6','p7','p8', 'p9']
870
-
871
- nc_col_list = ['c1','c2','c3']
872
- nc_plate_list = ['p1','p2','p3','p4','p6','p7','p8', 'p9']
873
-
874
- screen_col_list = ['c4','c5','c6','c7','c8','c9','c10','c11','c12','c13','c14','c15','c16','c17','c18','c19','c20','c21','c22','c23','c24']
875
- screen_plate_list = ['p1','p2','p3','p4']
876
-
877
- df[['plate', 'row', 'col', 'field', 'cell_id', 'prcfo']] = df['path'].apply(lambda x: pd.Series(_map_wells_png(x)))
878
-
879
- df.loc[(df['col'].isin(pc_col_list)) & (df['plate'].isin(pc_plate_list)), 'condition'] = 'pc'
880
- df.loc[(df['col'].isin(nc_col_list)) & (df['plate'].isin(nc_plate_list)), 'condition'] = 'nc'
881
- df.loc[(df['col'].isin(screen_col_list)) & (df['plate'].isin(screen_plate_list)), 'condition'] = 'screen'
882
-
883
- df = df.dropna(subset=['condition'])
884
- display(df)
885
- return df
886
-
887
- def generate_dataset(settings={}):
888
-
889
- from .utils import initiate_counter, add_images_to_tar
890
-
891
- db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
892
- dst = os.path.join(settings['src'], 'datasets')
893
- all_paths = []
894
-
895
- # Connect to the database and retrieve the image paths
896
- print(f"Reading DataBase: {db_path}")
897
- try:
898
- with sqlite3.connect(db_path) as conn:
899
- cursor = conn.cursor()
900
- if settings['file_metadata']:
901
- if isinstance(settings['file_metadata'], str):
902
- cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{settings['file_metadata']}%",))
903
- else:
904
- cursor.execute("SELECT png_path FROM png_list")
905
-
906
- while True:
907
- rows = cursor.fetchmany(1000)
908
- if not rows:
909
- break
910
- all_paths.extend([row[0] for row in rows])
911
-
912
- except sqlite3.Error as e:
913
- print(f"Database error: {e}")
914
- return
915
- except Exception as e:
916
- print(f"Error: {e}")
917
- return
918
-
919
- if isinstance(settings['sample'], int):
920
- selected_paths = random.sample(all_paths, settings['sample'])
921
- print(f"Random selection of {len(selected_paths)} paths")
922
- else:
923
- selected_paths = all_paths
924
- random.shuffle(selected_paths)
925
- print(f"All paths: {len(selected_paths)} paths")
926
-
927
- total_images = len(selected_paths)
928
- print(f"Found {total_images} images")
929
-
930
- # Create a temp folder in dst
931
- temp_dir = os.path.join(dst, "temp_tars")
932
- os.makedirs(temp_dir, exist_ok=True)
933
-
934
- # Chunking the data
935
- num_procs = max(2, cpu_count() - 2)
936
- chunk_size = len(selected_paths) // num_procs
937
- remainder = len(selected_paths) % num_procs
938
-
939
- paths_chunks = []
940
- start = 0
941
- for i in range(num_procs):
942
- end = start + chunk_size + (1 if i < remainder else 0)
943
- paths_chunks.append(selected_paths[start:end])
944
- start = end
945
-
946
- temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]
947
-
948
- print(f"Generating temporary tar files in {dst}")
949
-
950
- # Initialize shared counter and lock
951
- counter = Value('i', 0)
952
- lock = Lock()
953
-
954
- with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
955
- pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
956
-
957
- # Combine the temporary tar files into a final tar
958
- date_name = datetime.date.today().strftime('%y%m%d')
959
- if not settings['file_metadata'] is None:
960
- tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
961
- else:
962
- tar_name = f"{date_name}_{settings['experiment']}.tar"
963
- tar_name = os.path.join(dst, tar_name)
964
- if os.path.exists(tar_name):
965
- number = random.randint(1, 100)
966
- tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
967
- print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
968
- tar_name = os.path.join(dst, tar_name_2)
969
-
970
- print(f"Merging temporary files")
971
-
972
- with tarfile.open(tar_name, 'w') as final_tar:
973
- for temp_tar_path in temp_tar_files:
974
- with tarfile.open(temp_tar_path, 'r') as temp_tar:
975
- for member in temp_tar.getmembers():
976
- file_obj = temp_tar.extractfile(member)
977
- final_tar.addfile(member, file_obj)
978
- os.remove(temp_tar_path)
979
-
980
- # Delete the temp folder
981
- shutil.rmtree(temp_dir)
982
- print(f"\nSaved {total_images} images to {tar_name}")
983
-
984
- return tar_name
985
-
986
- def apply_model_to_tar(settings={}):
987
-
988
- from .io import TarImageDataset
989
- from .utils import process_vision_results, print_progress
990
-
991
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
992
- if settings['normalize']:
993
- transform = transforms.Compose([
994
- transforms.ToTensor(),
995
- transforms.CenterCrop(size=(settings['image_size'], settings['image_size'])),
996
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
997
- else:
998
- transform = transforms.Compose([
999
- transforms.ToTensor(),
1000
- transforms.CenterCrop(size=(settings['image_size'], settings['image_size']))])
1001
-
1002
- if settings['verbose']:
1003
- print(f"Loading model from {settings['model_path']}")
1004
- print(f"Loading dataset from {settings['tar_path']}")
1005
-
1006
- model = torch.load(settings['model_path'])
1007
-
1008
- dataset = TarImageDataset(settings['tar_path'], transform=transform)
1009
- data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)
1010
-
1011
- model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
1012
- dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
1013
- date_name = datetime.date.today().strftime('%y%m%d')
1014
- dst = os.path.dirname(settings['tar_path'])
1015
- result_loc = f'{dst}/{date_name}_{dataset_name}_{model_name}_result.csv'
1016
-
1017
- model.eval()
1018
- model = model.to(device)
1019
-
1020
- if settings['verbose']:
1021
- print(model)
1022
- print(f'Generated dataset with {len(dataset)} images')
1023
- print(f'Generating loader from {len(data_loader)} batches')
1024
- print(f'Results wil be saved in: {result_loc}')
1025
- print(f'Model is in eval mode')
1026
- print(f'Model loaded to device')
1027
-
1028
- prediction_pos_probs = []
1029
- filenames_list = []
1030
- time_ls = []
1031
- gc.collect()
1032
- with torch.no_grad():
1033
- for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
1034
- start = time.time()
1035
- images = batch_images.to(torch.float).to(device)
1036
- outputs = model(images)
1037
- batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
1038
- prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
1039
- filenames_list.extend(filenames)
1040
- stop = time.time()
1041
- duration = stop - start
1042
- time_ls.append(duration)
1043
- files_processed = batch_idx*settings['batch_size']
1044
- files_to_process = len(data_loader)
1045
- print_progress(files_processed, files_to_process, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")
1046
-
1047
- data = {'path':filenames_list, 'pred':prediction_pos_probs}
1048
- df = pd.DataFrame(data, index=None)
1049
- df = process_vision_results(df, settings['score_threshold'])
1050
-
1051
- df.to_csv(result_loc, index=True, header=True, mode='w')
1052
- torch.cuda.empty_cache()
1053
- torch.cuda.memory.empty_cache()
1054
- return df
1055
-
1056
- def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
1057
-
1058
- from .io import NoClassDataset
1059
- from .utils import print_progress
1060
-
1061
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1062
-
1063
- if normalize:
1064
- transform = transforms.Compose([
1065
- transforms.ToTensor(),
1066
- transforms.CenterCrop(size=(image_size, image_size)),
1067
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1068
- else:
1069
- transform = transforms.Compose([
1070
- transforms.ToTensor(),
1071
- transforms.CenterCrop(size=(image_size, image_size))])
1072
-
1073
- model = torch.load(model_path)
1074
- print(model)
1075
-
1076
- print(f'Loading dataset in {src} with {len(src)} images')
1077
- dataset = NoClassDataset(data_dir=src, transform=transform, shuffle=True, load_to_memory=False)
1078
- data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=n_jobs)
1079
- print(f'Loaded {len(src)} images')
1080
-
1081
- result_loc = os.path.splitext(model_path)[0]+datetime.date.today().strftime('%y%m%d')+'_'+os.path.splitext(model_path)[1]+'_test_result.csv'
1082
- print(f'Results wil be saved in: {result_loc}')
1083
-
1084
- model.eval()
1085
- model = model.to(device)
1086
- prediction_pos_probs = []
1087
- filenames_list = []
1088
- time_ls = []
1089
- with torch.no_grad():
1090
- for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
1091
- start = time.time()
1092
- images = batch_images.to(torch.float).to(device)
1093
- outputs = model(images)
1094
- batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
1095
- prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
1096
- filenames_list.extend(filenames)
1097
- stop = time.time()
1098
- duration = stop - start
1099
- time_ls.append(duration)
1100
- files_processed = batch_idx*batch_size
1101
- files_to_process = len(data_loader)
1102
- print_progress(files_processed, files_to_process, n_jobs=n_jobs, time_ls=time_ls, batch_size=batch_size, operation_type="Generating predictions")
1103
-
1104
- data = {'path':filenames_list, 'pred':prediction_pos_probs}
1105
- df = pd.DataFrame(data, index=None)
1106
- df.to_csv(result_loc, index=True, header=True, mode='w')
1107
- torch.cuda.empty_cache()
1108
- torch.cuda.memory.empty_cache()
1109
- return df
1110
-
1111
- def generate_training_data_file_list(src,
1112
- target='protein of interest',
1113
- cell_dim=4,
1114
- nucleus_dim=5,
1115
- pathogen_dim=6,
1116
- channel_of_interest=1,
1117
- pathogen_size_min=0,
1118
- nucleus_size_min=0,
1119
- cell_size_min=0,
1120
- pathogen_min=0,
1121
- nucleus_min=0,
1122
- cell_min=0,
1123
- target_min=0,
1124
- mask_chans=[0,1,2],
1125
- filter_data=False,
1126
- include_noninfected=False,
1127
- include_multiinfected=False,
1128
- include_multinucleated=False,
1129
- cells_per_well=10,
1130
- save_filtered_filelist=False):
1131
-
1132
- from .io import _read_and_merge_data
1133
-
1134
- mask_dims=[cell_dim,nucleus_dim,pathogen_dim]
1135
- sns.color_palette("mako", as_cmap=True)
1136
- print(f'channel:{channel_of_interest} = {target}')
1137
- overlay_channels = [0, 1, 2, 3]
1138
- overlay_channels.remove(channel_of_interest)
1139
- overlay_channels.reverse()
1140
-
1141
- db_loc = [src+'/measurements/measurements.db']
1142
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1143
- df, object_dfs = _read_and_merge_data(db_loc,
1144
- tables,
1145
- verbose=True,
1146
- include_multinucleated=include_multinucleated,
1147
- include_multiinfected=include_multiinfected,
1148
- include_noninfected=include_noninfected)
1149
-
1150
- if filter_data:
1151
- df = df[df['cell_area'] > cell_size_min]
1152
- df = df[df[f'cell_channel_{mask_chans[2]}_mean_intensity'] > cell_min]
1153
- print(f'After cell filtration {len(df)}')
1154
- df = df[df['nucleus_area'] > nucleus_size_min]
1155
- df = df[df[f'nucleus_channel_{mask_chans[0]}_mean_intensity'] > nucleus_min]
1156
- print(f'After nucleus filtration {len(df)}')
1157
- df = df[df['pathogen_area'] > pathogen_size_min]
1158
- df=df[df[f'pathogen_channel_{mask_chans[1]}_mean_intensity'] > pathogen_min]
1159
- print(f'After pathogen filtration {len(df)}')
1160
- df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_min]
1161
- print(f'After channel {channel_of_interest} filtration', len(df))
1162
-
1163
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1164
- return df
1165
-
1166
- def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
1167
- all_paths = []
1168
-
1169
- # Connect to the database and retrieve the image paths and annotations
1170
- print(f'Reading DataBase: {db_path}')
1171
- with sqlite3.connect(db_path) as conn:
1172
- cursor = conn.cursor()
1173
- # Prepare the query with parameterized placeholders for annotated_classes
1174
- placeholders = ','.join('?' * len(annotated_classes))
1175
- query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
1176
- cursor.execute(query, annotated_classes)
1177
-
1178
- while True:
1179
- rows = cursor.fetchmany(1000)
1180
- if not rows:
1181
- break
1182
- for row in rows:
1183
- all_paths.append(row)
1184
-
1185
- # Filter paths based on annotation
1186
- class_paths = []
1187
- for class_ in annotated_classes:
1188
- class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
1189
- class_paths.append(class_paths_temp)
1190
-
1191
- print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
1192
- return class_paths
1193
-
1194
- def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
1195
- from .utils import print_progress
1196
- # Make sure that the length of class_data matches the length of classes
1197
- if len(class_data) != len(classes):
1198
- raise ValueError("class_data and classes must have the same length.")
1199
-
1200
- total_files = sum(len(data) for data in class_data)
1201
- processed_files = 0
1202
- time_ls = []
1203
-
1204
- for cls, data in zip(classes, class_data):
1205
- # Create directories
1206
- train_class_dir = os.path.join(dst, f'train/{cls}')
1207
- test_class_dir = os.path.join(dst, f'test/{cls}')
1208
- os.makedirs(train_class_dir, exist_ok=True)
1209
- os.makedirs(test_class_dir, exist_ok=True)
1210
-
1211
- # Split the data
1212
- train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
1213
-
1214
- # Copy train files
1215
- for path in train_data:
1216
- start = time.time()
1217
- shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
1218
- duration = time.time() - start
1219
- time_ls.append(duration)
1220
- print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
1221
- processed_files += 1
1222
-
1223
- # Copy test files
1224
- for path in test_data:
1225
- start = time.time()
1226
- shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
1227
- duration = time.time() - start
1228
- time_ls.append(duration)
1229
- print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
1230
- processed_files += 1
1231
-
1232
- # Print summary
1233
- for cls in classes:
1234
- train_class_dir = os.path.join(dst, f'train/{cls}')
1235
- test_class_dir = os.path.join(dst, f'test/{cls}')
1236
- print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
1237
-
1238
- return os.path.join(dst, 'train'), os.path.join(dst, 'test')
1239
-
1240
- def generate_training_dataset(settings):
1241
-
1242
- from .io import _read_and_merge_data, _read_db
1243
- from .utils import get_paths_from_db, annotate_conditions
1244
- from .settings import set_generate_training_dataset_defaults
1245
-
1246
- settings = set_generate_training_dataset_defaults(settings)
1247
-
1248
- db_path = os.path.join(settings['src'], 'measurements','measurements.db')
1249
- dst = os.path.join(settings['src'], 'datasets', 'training')
1250
-
1251
- if os.path.exists(dst):
1252
- for i in range(1, 1000):
1253
- dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
1254
- if not os.path.exists(dst):
1255
- print(f'Creating new directory for training: {dst}')
1256
- break
1257
-
1258
- if settings['dataset_mode'] == 'annotation':
1259
- class_paths_ls_2 = []
1260
- class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
1261
- for class_paths in class_paths_ls:
1262
- class_paths_temp = random.sample(class_paths, settings['size'])
1263
- class_paths_ls_2.append(class_paths_temp)
1264
- class_paths_ls = class_paths_ls_2
1265
-
1266
- elif settings['dataset_mode'] == 'metadata':
1267
- class_paths_ls = []
1268
- class_len_ls = []
1269
- [df] = _read_db(db_loc=db_path, tables=['png_list'])
1270
- df['metadata_based_class'] = pd.NA
1271
- for i, class_ in enumerate(settings['classes']):
1272
- ls = settings['class_metadata'][i]
1273
- df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_
1274
-
1275
- for class_ in settings['classes']:
1276
- if settings['size'] == None:
1277
- c_s = []
1278
- for c in settings['classes']:
1279
- c_s_t_df = df[df['metadata_based_class'] == c]
1280
- c_s.append(len(c_s_t_df))
1281
- print(f'Found {len(c_s_t_df)} images for class {c}')
1282
- size = min(c_s)
1283
- print(f'Using the smallest class size: {size}')
1284
-
1285
- class_temp_df = df[df['metadata_based_class'] == class_]
1286
- class_len_ls.append(len(class_temp_df))
1287
- print(f'Found {len(class_temp_df)} images for class {class_}')
1288
- class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), settings['size'])
1289
- class_paths_ls.append(class_paths_temp)
1290
-
1291
- elif settings['dataset_mode'] == 'recruitment':
1292
- class_paths_ls = []
1293
- if not isinstance(settings['tables'], list):
1294
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1295
-
1296
- df, _ = _read_and_merge_data(locs=[db_path],
1297
- tables=tables,
1298
- verbose=False,
1299
- include_multinucleated=True,
1300
- include_multiinfected=True,
1301
- include_noninfected=True)
1302
-
1303
- print('length df 1', len(df))
1304
-
1305
- df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=settings['classes'], treatment_loc=settings['class_metadata'], types = settings['metadata_type_by'])
1306
- print('length df 2', len(df))
1307
- [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
1308
-
1309
- if settings['custom_measurement'] != None:
1310
-
1311
- if not isinstance(settings['custom_measurement'], list):
1312
- print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
1313
- return
1314
-
1315
- if isinstance(settings['custom_measurement'], list):
1316
- if len(settings['custom_measurement']) == 2:
1317
- print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]}/{settings['custom_measurement'][1]})")
1318
- df['recruitment'] = df[f"{settings['custom_measurement'][0]}']/df[f'{settings['custom_measurement'][1]}"]
1319
- if len(settings['custom_measurement']) == 1:
1320
- print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({settings['custom_measurement'][0]})")
1321
- df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]
1322
- else:
1323
- print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {settings['channel_of_interest']})")
1324
- df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity']/df[f'cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
1325
-
1326
- q25 = df['recruitment'].quantile(0.25)
1327
- q75 = df['recruitment'].quantile(0.75)
1328
- df_lower = df[df['recruitment'] <= q25]
1329
- df_upper = df[df['recruitment'] >= q75]
1330
-
1331
- class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])
1332
-
1333
- class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), settings['size'])
1334
- class_paths_ls.append(class_paths_lower)
1335
-
1336
- class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
1337
- class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), settings['size'])
1338
- class_paths_ls.append(class_paths_upper)
1339
-
1340
- train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])
1341
-
1342
- return train_class_dir, test_class_dir
1343
-
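The 'recruitment' dataset mode above labels images by splitting the recruitment ratio at its first and third quartiles. A minimal, self-contained pandas sketch of that split; column names and values are illustrative only:

    import pandas as pd

    df = pd.DataFrame({'recruitment': [0.2, 0.5, 0.9, 1.4, 2.1, 3.0],
                       'png_path': [f'img_{i}.png' for i in range(6)]})
    q25 = df['recruitment'].quantile(0.25)
    q75 = df['recruitment'].quantile(0.75)
    lower_class = df[df['recruitment'] <= q25]['png_path'].tolist()   # low-recruitment class
    upper_class = df[df['recruitment'] >= q75]['png_path'].tolist()   # high-recruitment class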
1344
- def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, preload_batches=3, verbose=False):
1345
-
1346
- """
1347
- Generate data loaders for training and validation/test datasets.
1348
-
1349
- Parameters:
1350
- - src (str): The source directory containing the data.
1351
- - mode (str): The mode of operation. Options are 'train' or 'test'.
1352
- - image_size (int): The size of the input images.
1353
- - batch_size (int): The batch size for the data loaders.
1354
- - classes (list): The list of classes to consider.
1355
- - n_jobs (int): The number of worker threads for data loading.
1356
- - validation_split (float): The fraction of data to use for validation.
1357
- - pin_memory (bool): Whether to pin memory for faster data transfer.
1358
- - normalize (bool): Whether to normalize the input images.
1359
- - verbose (bool): Whether to print additional information and show images.
1360
- - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for red and green, etc.
1361
-
1362
- Returns:
1363
- - train_loaders (list): List of data loaders for training datasets.
1364
- - val_loaders (list): List of data loaders for validation datasets.
1365
- """
1366
-
1367
- from .io import spacrDataset, spacrDataLoader
1368
- from .plot import _imshow_gpu
1369
- from .utils import SelectChannels, augment_dataset
1370
-
1371
- chans = []
1372
-
1373
- if 'r' in channels:
1374
- chans.append(1)
1375
- if 'g' in channels:
1376
- chans.append(2)
1377
- if 'b' in channels:
1378
- chans.append(3)
1379
-
1380
- channels = chans
1381
-
1382
- if verbose:
1383
- print(f'Training a network on channels: {channels}')
1384
- print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
1385
-
1386
- train_loaders = []
1387
- val_loaders = []
1388
-
1389
- if normalize:
1390
- transform = transforms.Compose([
1391
- transforms.ToTensor(),
1392
- transforms.CenterCrop(size=(image_size, image_size)),
1393
- SelectChannels(channels),
1394
- transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
1395
- else:
1396
- transform = transforms.Compose([
1397
- transforms.ToTensor(),
1398
- transforms.CenterCrop(size=(image_size, image_size)),
1399
- SelectChannels(channels)])
1400
-
1401
- if mode == 'train':
1402
- data_dir = os.path.join(src, 'train')
1403
- shuffle = True
1404
- print('Generating Train and validation datasets')
1405
- elif mode == 'test':
1406
- data_dir = os.path.join(src, 'test')
1407
- val_loaders = []
1408
- validation_split = 0.0
1409
- shuffle = True
1410
- print('Generating test dataset')
1411
- else:
1412
- print(f'mode:{mode} is not valid, use mode = train or test')
1413
- return
1414
-
1415
- data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
1416
- num_workers = n_jobs if n_jobs is not None else 0
1417
-
1418
- if validation_split > 0:
1419
- train_size = int((1 - validation_split) * len(data))
1420
- val_size = len(data) - train_size
1421
- if not augment:
1422
- print(f'Train data:{train_size}, Validation data:{val_size}')
1423
- train_dataset, val_dataset = random_split(data, [train_size, val_size])
1424
-
1425
- if augment:
1426
-
1427
- print(f'Data before augmentation: Train: {len(train_dataset)}, Validation: {len(val_dataset)}')
1428
- train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
1429
- print(f'Data after augmentation: Train: {len(train_dataset)}')
1430
-
1431
- print(f'Generating Dataloader with {n_jobs} workers')
1432
- #train_loaders = spacrDataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
1433
- #train_loaders = spacrDataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
1434
-
1435
- train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
1436
- val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
1437
- else:
1438
- train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
1439
- #train_loaders = spacrDataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers, pin_memory=pin_memory, persistent_workers=True, preload_batches=preload_batches)
1440
-
1441
- #dataset (Dataset) – dataset from which to load the data.
1442
- #batch_size (int, optional) – how many samples per batch to load (default: 1).
1443
- #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
1444
- #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
1445
- #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
1446
- #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
1447
- #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
1448
- #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
1449
- #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
1450
- #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
1451
- #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
1452
- #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
1453
- #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
1454
- #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
1455
- #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
1456
- #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
1457
-
1458
- #images, labels, filenames = next(iter(train_loaders))
1459
- #images = images.cpu()
1460
- #label_strings = [str(label.item()) for label in labels]
1461
- #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
1462
- #if verbose:
1463
- # plt.show()
1464
-
1465
- train_fig = None
1466
-
1467
- return train_loaders, val_loaders, train_fig
1468
-
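A hedged usage sketch of generate_loaders as defined above. Although the channels default is [1, 2, 3], the body maps the letter codes 'r', 'g' and 'b' to channel indices, so letters appear to be the intended input; the path and class names below are placeholders:

    train_loaders, val_loaders, _ = generate_loaders(
        src='/path/to/datasets/training',   # expects train/ and test/ subfolders
        mode='train',
        image_size=224,
        batch_size=32,
        classes=['nc', 'pc'],
        n_jobs=4,
        validation_split=0.1,
        normalize=True,
        channels=['r', 'g', 'b'],
        augment=False,
        verbose=True)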
1469
- def analyze_recruitment(settings={}):
1470
- """
1471
- Analyze recruitment data by grouping the DataFrame by well coordinates and plotting controls and recruitment data.
1472
-
1473
- Parameters:
1474
- settings (dict): Analysis settings; missing keys are filled in by get_analyze_recruitment_default_settings.
1475
-
1476
- Returns:
1477
- None
1478
- """
1479
-
1480
- from .io import _read_and_merge_data, _results_to_csv
1481
- from .plot import plot_image_mask_overlay, _plot_controls, _plot_recruitment
1482
- from .utils import _object_filter, annotate_conditions, _calculate_recruitment, _group_by_well, save_settings
1483
- from .settings import get_analyze_recruitment_default_settings
1484
-
1485
- settings = get_analyze_recruitment_default_settings(settings=settings)
1486
- save_settings(settings, name='recruitment')
1487
-
1488
- # metadata settings
1489
- src = settings['src']
1490
- target = settings['target']
1491
- cell_types = settings['cell_types']
1492
- cell_plate_metadata = settings['cell_plate_metadata']
1493
- pathogen_types = settings['pathogen_types']
1494
- pathogen_plate_metadata = settings['pathogen_plate_metadata']
1495
- treatments = settings['treatments']
1496
- treatment_plate_metadata = settings['treatment_plate_metadata']
1497
- metadata_types = settings['metadata_types']
1498
- channel_dims = settings['channel_dims']
1499
- cell_chann_dim = settings['cell_chann_dim']
1500
- cell_mask_dim = settings['cell_mask_dim']
1501
- nucleus_chann_dim = settings['nucleus_chann_dim']
1502
- nucleus_mask_dim = settings['nucleus_mask_dim']
1503
- pathogen_chann_dim = settings['pathogen_chann_dim']
1504
- pathogen_mask_dim = settings['pathogen_mask_dim']
1505
- channel_of_interest = settings['channel_of_interest']
1506
-
1507
- # Advanced settings
1508
- plot = settings['plot']
1509
- plot_nr = settings['plot_nr']
1510
- plot_control = settings['plot_control']
1511
- figuresize = settings['figuresize']
1512
- include_noninfected = settings['include_noninfected']
1513
- include_multiinfected = settings['include_multiinfected']
1514
- include_multinucleated = settings['include_multinucleated']
1515
- cells_per_well = settings['cells_per_well']
1516
- pathogen_size_range = settings['pathogen_size_range']
1517
- nucleus_size_range = settings['nucleus_size_range']
1518
- cell_size_range = settings['cell_size_range']
1519
- pathogen_intensity_range = settings['pathogen_intensity_range']
1520
- nucleus_intensity_range = settings['nucleus_intensity_range']
1521
- cell_intensity_range = settings['cell_intensity_range']
1522
- target_intensity_min = settings['target_intensity_min']
1523
-
1524
- print(f'Cell(s): {cell_types}, in {cell_plate_metadata}')
1525
- print(f'Pathogen(s): {pathogen_types}, in {pathogen_plate_metadata}')
1526
- print(f'Treatment(s): {treatments}, in {treatment_plate_metadata}')
1527
-
1528
- mask_dims=[cell_mask_dim,nucleus_mask_dim,pathogen_mask_dim]
1529
- mask_chans=[nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim]
1530
-
1531
- if isinstance(metadata_types, str):
1532
- metadata_types = [metadata_types, metadata_types, metadata_types]
1533
- if isinstance(metadata_types, list):
1534
- if len(metadata_types) < 3:
1535
- metadata_types = [metadata_types[0], metadata_types[0], metadata_types[0]]
1536
- print(f'WARNING: setting metadata types to the first element repeated 3 times: {metadata_types}. To avoid this behaviour, set metadata_types to a list with 3 elements, each being col, row or plate.')
1537
- else:
1538
- metadata_types = metadata_types
1539
-
1540
- sns.color_palette("mako", as_cmap=True)
1541
- print(f'channel:{channel_of_interest} = {target}')
1542
- overlay_channels = channel_dims
1543
- overlay_channels.remove(channel_of_interest)
1544
- overlay_channels.reverse()
1545
-
1546
- db_loc = [src+'/measurements/measurements.db']
1547
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
1548
- df, _ = _read_and_merge_data(db_loc,
1549
- tables,
1550
- verbose=True,
1551
- include_multinucleated=include_multinucleated,
1552
- include_multiinfected=include_multiinfected,
1553
- include_noninfected=include_noninfected)
1554
-
1555
- df = annotate_conditions(df,
1556
- cells=cell_types,
1557
- cell_loc=cell_plate_metadata,
1558
- pathogens=pathogen_types,
1559
- pathogen_loc=pathogen_plate_metadata,
1560
- treatments=treatments,
1561
- treatment_loc=treatment_plate_metadata,
1562
- types=metadata_types)
1563
-
1564
- df = df.dropna(subset=['condition'])
1565
- print(f'After dropping non-annotated wells: {len(df)} rows')
1566
- files = df['file_name'].tolist()
1567
- print(f'found: {len(files)} files')
1568
- files = [item + '.npy' for item in files]
1569
- random.shuffle(files)
1570
-
1571
- _max = 10**100
1572
- if cell_size_range is None:
1573
- cell_size_range = [0,_max]
1574
- if nucleus_size_range is None:
1575
- nucleus_size_range = [0,_max]
1576
- if pathogen_size_range is None:
1577
- pathogen_size_range = [0,_max]
1578
-
1579
- if plot:
1580
- merged_path = os.path.join(src,'merged')
1581
- if os.path.exists(merged_path):
1582
- try:
1583
- for idx, file in enumerate(os.listdir(merged_path)):
1584
- file_path = os.path.join(merged_path,file)
1585
- if idx <= plot_nr:
1586
- plot_image_mask_overlay(file_path,
1587
- channel_dims,
1588
- cell_chann_dim,
1589
- nucleus_chann_dim,
1590
- pathogen_chann_dim,
1591
- figuresize=10,
1592
- normalize=True,
1593
- thickness=3,
1594
- save_pdf=True)
1595
- except Exception as e:
1596
- print(f'Failed to plot images with outlines, Error: {e}')
1597
-
1598
- if not cell_chann_dim is None:
1599
- df = _object_filter(df, object_type='cell', size_range=cell_size_range, intensity_range=cell_intensity_range, mask_chans=mask_chans, mask_chan=0)
1600
- if not target_intensity_min is None:
1601
- df = df[df[f'cell_channel_{channel_of_interest}_percentile_95'] > target_intensity_min]
1602
- print(f'After channel {channel_of_interest} filtration', len(df))
1603
- if not nucleus_chann_dim is None:
1604
- df = _object_filter(df, object_type='nucleus', size_range=nucleus_size_range, intensity_range=nucleus_intensity_range, mask_chans=mask_chans, mask_chan=1)
1605
- if not pathogen_chann_dim is None:
1606
- df = _object_filter(df, object_type='pathogen', size_range=pathogen_size_range, intensity_range=pathogen_intensity_range, mask_chans=mask_chans, mask_chan=2)
1607
-
1608
- df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
1609
- for chan in channel_dims:
1610
- df = _calculate_recruitment(df, channel=chan)
1611
- print(f'calculated recruitment for: {len(df)} rows')
1612
- df_well = _group_by_well(df)
1613
- print(f'found: {len(df_well)} wells')
1614
-
1615
- df_well = df_well[df_well['cells_per_well'] >= cells_per_well]
1616
- prc_list = df_well['prc'].unique().tolist()
1617
- df = df[df['prc'].isin(prc_list)]
1618
- print(f'After cells per well filter: {len(df)} cells in {len(df_well)} wells left with threshold {cells_per_well}')
1619
-
1620
- if plot_control:
1621
- _plot_controls(df, mask_chans, channel_of_interest, figuresize=5)
1622
-
1623
- print(f'PV level: {len(df)} rows')
1624
- _plot_recruitment(df=df, df_type='by PV', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
1625
- print(f'well level: {len(df_well)} rows')
1626
- _plot_recruitment(df=df_well, df_type='by well', channel_of_interest=channel_of_interest, target=target, figuresize=figuresize)
1627
- cells,wells = _results_to_csv(src, df, df_well)
1628
- return [cells,wells]
1629
-
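A hedged sketch of calling analyze_recruitment; only a few of the settings keys referenced above are shown with illustrative values, on the assumption that get_analyze_recruitment_default_settings supplies the rest:

    settings = {
        'src': '/path/to/experiment',
        'target': 'protein_of_interest',
        'channel_of_interest': 1,
        'cell_types': ['HeLa'],
        'pathogen_types': ['pathogen'],
        'treatments': ['control', 'drug'],
        'cells_per_well': 10,
        'plot': True,
    }
    cells_df, wells_df = analyze_recruitment(settings=settings)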
1630
10
  def preprocess_generate_masks(src, settings={}):
1631
11
 
1632
12
  from .io import preprocess_img_data, _load_and_concatenate_arrays
1633
13
  from .plot import plot_image_mask_overlay, plot_arrays
1634
14
  from .utils import _pivot_counts_table, check_mask_folder, adjust_cell_masks, print_progress, save_settings
1635
15
  from .settings import set_default_settings_preprocess_generate_masks
1636
-
1637
- settings = set_default_settings_preprocess_generate_masks(src, settings)
1638
- settings['src'] = src
1639
- save_settings(settings, name='gen_mask')
1640
16
 
1641
- if not settings['pathogen_channel'] is None:
1642
- custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
1643
- if settings['pathogen_model'] not in custom_model_ls:
1644
- ValueError(f'Pathogen model must be {custom_model_ls} or None')
1645
-
1646
- if settings['timelapse']:
1647
- settings['randomize'] = False
1648
-
1649
- if settings['preprocess']:
1650
- if not settings['masks']:
1651
- print(f'WARNING: channels for mask generation are defined when preprocess = True')
17
+ if not isinstance(settings['src'], (str, list)):
18
+ raise ValueError('src must be a string or a list of strings')
19
+ return
1652
20
 
1653
- if isinstance(settings['save'], bool):
1654
- settings['save'] = [settings['save']]*3
21
+ if isinstance(settings['src'], str):
22
+ settings['src'] = [settings['src']]
1655
23
 
1656
- if settings['verbose']:
1657
- settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
1658
- settings_df['setting_value'] = settings_df['setting_value'].apply(str)
1659
- display(settings_df)
24
+ if isinstance(settings['src'], list):
25
+ source_folders = settings['src']
26
+ for source_folder in source_folders:
27
+ print(f'Processing folder: {source_folder}')
28
+ settings['src'] = source_folder
29
+ src = source_folder
30
+ settings = set_default_settings_preprocess_generate_masks(src, settings)
31
+
32
+ save_settings(settings, name='gen_mask')
1660
33
 
1661
- if settings['test_mode']:
1662
- print(f'Starting Test mode ...')
1663
-
1664
- if settings['preprocess']:
1665
- settings, src = preprocess_img_data(settings)
1666
-
1667
- files_to_process = 3
1668
- files_processed = 0
1669
- if settings['masks']:
1670
- mask_src = os.path.join(src, 'norm_channel_stack')
1671
- if settings['cell_channel'] != None:
1672
- time_ls=[]
1673
- if check_mask_folder(src, 'cell_mask_stack'):
1674
- start = time.time()
1675
- if settings['segmentation_mode'] == 'cellpose':
1676
- generate_cellpose_masks(mask_src, settings, 'cell')
1677
- elif settings['segmentation_mode'] == 'mediar':
1678
- generate_mediar_masks(mask_src, settings, 'cell')
1679
- stop = time.time()
1680
- duration = (stop - start)
1681
- time_ls.append(duration)
1682
- files_processed += 1
1683
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'cell_mask_gen')
34
+ if not settings['pathogen_channel'] is None:
35
+ custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
36
+ if settings['pathogen_model'] not in custom_model_ls:
37
+ raise ValueError(f'Pathogen model must be {custom_model_ls} or None')
1684
38
 
1685
- if settings['nucleus_channel'] != None:
1686
- time_ls=[]
1687
- if check_mask_folder(src, 'nucleus_mask_stack'):
1688
- start = time.time()
1689
- if settings['segmentation_mode'] == 'cellpose':
1690
- generate_cellpose_masks(mask_src, settings, 'nucleus')
1691
- elif settings['segmentation_mode'] == 'mediar':
1692
- generate_mediar_masks(mask_src, settings, 'nucleus')
1693
- stop = time.time()
1694
- duration = (stop - start)
1695
- time_ls.append(duration)
1696
- files_processed += 1
1697
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'nucleus_mask_gen')
39
+ if settings['timelapse']:
40
+ settings['randomize'] = False
1698
41
 
1699
- if settings['pathogen_channel'] != None:
1700
- time_ls=[]
1701
- if check_mask_folder(src, 'pathogen_mask_stack'):
1702
- start = time.time()
1703
- if settings['segmentation_mode'] == 'cellpose':
1704
- generate_cellpose_masks(mask_src, settings, 'pathogen')
1705
- elif settings['segmentation_mode'] == 'mediar':
1706
- generate_mediar_masks(mask_src, settings, 'pathogen')
1707
- stop = time.time()
1708
- duration = (stop - start)
1709
- time_ls.append(duration)
1710
- files_processed += 1
1711
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'pathogen_mask_gen')
1712
-
1713
- #if settings['organelle'] != None:
1714
- # if check_mask_folder(src, 'organelle_mask_stack'):
1715
- # generate_cellpose_masks(mask_src, settings, 'organelle')
1716
-
1717
- if settings['adjust_cells']:
1718
- if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
1719
-
1720
- start = time.time()
1721
- cell_folder = os.path.join(mask_src, 'cell_mask_stack')
1722
- nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
1723
- parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
1724
- #organelle_folder = os.path.join(mask_src, 'organelle_mask_stack')
1725
-
1726
- adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
1727
- stop = time.time()
1728
- adjust_time = (stop-start)/60
1729
- print(f'Cell mask adjustment: {adjust_time} min.')
42
+ if settings['preprocess']:
43
+ if not settings['masks']:
44
+ print(f'WARNING: channels for mask generation are defined when preprocess = True')
1730
45
 
1731
- if os.path.exists(os.path.join(src,'measurements')):
1732
- _pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
1733
-
1734
- #Concatenate stack with masks
1735
- _load_and_concatenate_arrays(src, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'])
1736
-
1737
- if settings['plot']:
1738
- if not settings['timelapse']:
1739
-
1740
- if settings['test_mode'] == True:
1741
- settings['examples_to_plot'] = len(os.listdir(os.path.join(src,'merged')))
1742
-
1743
- try:
1744
- merged_src = os.path.join(src,'merged')
1745
- files = os.listdir(merged_src)
1746
- random.shuffle(files)
1747
- time_ls = []
46
+ if isinstance(settings['save'], bool):
47
+ settings['save'] = [settings['save']]*3
48
+
49
+ if settings['verbose']:
50
+ settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
51
+ settings_df['setting_value'] = settings_df['setting_value'].apply(str)
52
+ display(settings_df)
53
+
54
+ if settings['test_mode']:
55
+ print(f'Starting Test mode ...')
56
+
57
+ if settings['preprocess']:
58
+ settings, src = preprocess_img_data(settings)
59
+
60
+ files_to_process = 3
61
+ files_processed = 0
62
+ if settings['masks']:
63
+ mask_src = os.path.join(src, 'norm_channel_stack')
64
+ if settings['cell_channel'] != None:
65
+ time_ls=[]
66
+ if check_mask_folder(src, 'cell_mask_stack'):
67
+ start = time.time()
68
+ if settings['segmentation_mode'] == 'cellpose':
69
+ generate_cellpose_masks(mask_src, settings, 'cell')
70
+ elif settings['segmentation_mode'] == 'mediar':
71
+ generate_mediar_masks(mask_src, settings, 'cell')
72
+ stop = time.time()
73
+ duration = (stop - start)
74
+ time_ls.append(duration)
75
+ files_processed += 1
76
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'cell_mask_gen')
1748
77
 
1749
- for i, file in enumerate(files):
78
+ if settings['nucleus_channel'] != None:
79
+ time_ls=[]
80
+ if check_mask_folder(src, 'nucleus_mask_stack'):
1750
81
  start = time.time()
1751
- if i+1 <= settings['examples_to_plot']:
1752
- file_path = os.path.join(merged_src, file)
1753
- plot_image_mask_overlay(file_path, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'], figuresize=10, normalize=True, thickness=3, save_pdf=True)
1754
- stop = time.time()
1755
- duration = stop-start
1756
- time_ls.append(duration)
1757
- files_processed = i+1
1758
- files_to_process = settings['examples_to_plot']
1759
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Plot mask outlines")
1760
-
1761
- except Exception as e:
1762
- print(f'Failed to plot image mask overlay. Error: {e}')
1763
- else:
1764
- plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
1765
-
1766
- torch.cuda.empty_cache()
1767
- gc.collect()
1768
- print("Successfully completed run")
1769
- return
82
+ if settings['segmentation_mode'] == 'cellpose':
83
+ generate_cellpose_masks(mask_src, settings, 'nucleus')
84
+ elif settings['segmentation_mode'] == 'mediar':
85
+ generate_mediar_masks(mask_src, settings, 'nucleus')
86
+ stop = time.time()
87
+ duration = (stop - start)
88
+ time_ls.append(duration)
89
+ files_processed += 1
90
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'nucleus_mask_gen')
91
+
92
+ if settings['pathogen_channel'] != None:
93
+ time_ls=[]
94
+ if check_mask_folder(src, 'pathogen_mask_stack'):
95
+ start = time.time()
96
+ if settings['segmentation_mode'] == 'cellpose':
97
+ generate_cellpose_masks(mask_src, settings, 'pathogen')
98
+ elif settings['segmentation_mode'] == 'mediar':
99
+ generate_mediar_masks(mask_src, settings, 'pathogen')
100
+ stop = time.time()
101
+ duration = (stop - start)
102
+ time_ls.append(duration)
103
+ files_processed += 1
104
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'pathogen_mask_gen')
105
+
106
+ #if settings['organelle'] != None:
107
+ # if check_mask_folder(src, 'organelle_mask_stack'):
108
+ # generate_cellpose_masks(mask_src, settings, 'organelle')
109
+
110
+ if settings['adjust_cells']:
111
+ if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
1770
112
 
1771
- def identify_masks_finetune(settings):
1772
-
1773
- from .plot import print_mask_and_flows
1774
- from .utils import get_files_from_dir, resize_images_and_labels, print_progress
1775
- from .io import _load_normalized_images_and_labels, _load_images_and_labels
1776
- from .settings import get_identify_masks_finetune_default_settings
1777
-
1778
- settings = get_identify_masks_finetune_default_settings(settings)
1779
- src=settings['src']
1780
- dst=settings['dst']
1781
- model_name=settings['model_name']
1782
- custom_model=settings['custom_model']
1783
- channels = settings['channels']
1784
- background = settings['background']
1785
- remove_background=settings['remove_background']
1786
- Signal_to_noise = settings['Signal_to_noise']
1787
- CP_prob = settings['CP_prob']
1788
- diameter=settings['diameter']
1789
- batch_size=settings['batch_size']
1790
- flow_threshold=settings['flow_threshold']
1791
- save=settings['save']
1792
- verbose=settings['verbose']
1793
-
1794
- # static settings
1795
- normalize = settings['normalize']
1796
- percentiles = settings['percentiles']
1797
- circular = settings['circular']
1798
- invert = settings['invert']
1799
- resize = settings['resize']
1800
-
1801
- if resize:
1802
- target_height = settings['target_height']
1803
- target_width = settings['target_width']
1804
-
1805
- rescale = settings['rescale']
1806
- resample = settings['resample']
1807
- grayscale = settings['grayscale']
1808
-
1809
- os.makedirs(dst, exist_ok=True)
1810
-
1811
- if not custom_model is None:
1812
- if not os.path.exists(custom_model):
1813
- print(f'Custom model not found: {custom_model}')
1814
- return
113
+ start = time.time()
114
+ cell_folder = os.path.join(mask_src, 'cell_mask_stack')
115
+ nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
116
+ parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
117
+ #organelle_folder = os.path.join(mask_src, 'organelle_mask_stack')
118
+ print(f'Adjusting cell masks with nuclei and pathogen masks')
119
+ adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
120
+ stop = time.time()
121
+ adjust_time = (stop-start)/60
122
+ print(f'Cell mask adjustment: {adjust_time} min.')
123
+
124
+ if os.path.exists(os.path.join(src,'measurements')):
125
+ _pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
1815
126
 
1816
- if not torch.cuda.is_available():
1817
- print(f'Torch CUDA is not available, using CPU')
1818
-
1819
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1820
-
1821
- if custom_model == None:
1822
- model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
1823
- print(f'Loaded model: {model_name}')
1824
- else:
1825
- model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
1826
- print("Pretrained Model Loaded:", model.pretrained_model)
127
+ #Concatenate stack with masks
128
+ _load_and_concatenate_arrays(src, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'])
129
+
130
+ if settings['plot']:
131
+ if not settings['timelapse']:
1827
132
 
1828
- chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
1829
-
1830
- if grayscale:
1831
- chans=[0, 0]
1832
-
1833
- print(f'Using channels: {chans} for model of type {model_name}')
1834
-
1835
- if verbose == True:
1836
- print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
1837
-
1838
- all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
1839
- mask_files = set(os.listdir(os.path.join(src, 'masks')))
1840
- all_image_files = [f for f in all_image_files if os.path.basename(f) not in mask_files]
1841
- random.shuffle(all_image_files)
1842
-
1843
- time_ls = []
1844
- for i in range(0, len(all_image_files), batch_size):
1845
- gc.collect()
1846
- image_files = all_image_files[i:i+batch_size]
1847
-
1848
- if normalize:
1849
- images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise, target_height=target_height, target_width=target_width)
1850
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
1851
- #orig_dims = [(image.shape[0], image.shape[1]) for image in images]
1852
- else:
1853
- images, _, image_names, _ = _load_images_and_labels(image_files=image_files, label_files=None, circular=circular, invert=invert)
1854
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
1855
- orig_dims = [(image.shape[0], image.shape[1]) for image in images]
1856
- if resize:
1857
- images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
1858
-
1859
- for file_index, stack in enumerate(images):
1860
- start = time.time()
1861
- output = model.eval(x=stack,
1862
- normalize=False,
1863
- channels=chans,
1864
- channel_axis=3,
1865
- diameter=diameter,
1866
- flow_threshold=flow_threshold,
1867
- cellprob_threshold=CP_prob,
1868
- rescale=rescale,
1869
- resample=resample,
1870
- progress=True)
133
+ if settings['test_mode'] == True:
134
+ settings['examples_to_plot'] = len(os.listdir(os.path.join(src,'merged')))
1871
135
 
1872
- if len(output) == 4:
1873
- mask, flows, _, _ = output
1874
- elif len(output) == 3:
1875
- mask, flows, _ = output
1876
- else:
1877
- raise ValueError("Unexpected number of return values from model.eval()")
1878
-
1879
- if resize:
1880
- dims = orig_dims[file_index]
1881
- mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
1882
-
1883
- stop = time.time()
1884
- duration = (stop - start)
1885
- time_ls.append(duration)
1886
- files_processed = len(images)
1887
- files_to_process = file_index+1
1888
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls)
1889
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="")
1890
-
1891
-
1892
- if verbose:
1893
- if resize:
1894
- stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
1895
- print_mask_and_flows(stack, mask, flows, overlay=True)
1896
- if save:
1897
- os.makedirs(dst, exist_ok=True)
1898
- output_filename = os.path.join(dst, image_names[file_index])
1899
- cv2.imwrite(output_filename, mask)
1900
- del images, output, mask, flows
1901
- gc.collect()
136
+ try:
137
+ merged_src = os.path.join(src,'merged')
138
+ files = os.listdir(merged_src)
139
+ random.shuffle(files)
140
+ time_ls = []
141
+
142
+ for i, file in enumerate(files):
143
+ start = time.time()
144
+ if i+1 <= settings['examples_to_plot']:
145
+ file_path = os.path.join(merged_src, file)
146
+ plot_image_mask_overlay(file_path, settings['channels'], settings['cell_channel'], settings['nucleus_channel'], settings['pathogen_channel'], figuresize=10, normalize=True, thickness=3, save_pdf=True)
147
+ stop = time.time()
148
+ duration = stop-start
149
+ time_ls.append(duration)
150
+ files_processed = i+1
151
+ files_to_process = settings['examples_to_plot']
152
+ print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Plot mask outlines")
153
+
154
+ except Exception as e:
155
+ print(f'Failed to plot image mask overlay. Error: {e}')
156
+ else:
157
+ plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
158
+
159
+ torch.cuda.empty_cache()
160
+ gc.collect()
161
+ print("Successfully completed run")
1902
162
  return
1903
163
 
1904
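A hedged usage sketch of the reworked preprocess_generate_masks. The keys shown are ones the new body reads; values are illustrative only, and set_default_settings_preprocess_generate_masks is assumed to supply any keys not listed:

    settings = {
        'src': ['/path/to/plate1', '/path/to/plate2'],  # a single string also works
        'preprocess': True,
        'masks': True,
        'segmentation_mode': 'cellpose',
        'cell_channel': 3,
        'nucleus_channel': 0,
        'pathogen_channel': 2,
        'pathogen_model': 'toxo_pv_lumen',
        'adjust_cells': True,
        'timelapse': False,
        'plot': True,
        'verbose': False,
        'test_mode': False,
    }
    preprocess_generate_masks(settings['src'], settings=settings)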
- def all_elements_match(list1, list2):
1905
- # Check if all elements in list1 are in list2
1906
- return all(element in list2 for element in list1)
1907
-
1908
- def prepare_batch_for_segmentation(batch):
1909
- # Ensure the batch is of dtype float32
1910
- if batch.dtype != np.float32:
1911
- batch = batch.astype(np.float32)
1912
-
1913
- # Normalize each image in the batch
1914
- for i in range(batch.shape[0]):
1915
- if batch[i].max() > 1:
1916
- batch[i] = batch[i] / batch[i].max()
1917
-
1918
- return batch
1919
-
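prepare_batch_for_segmentation above rescales each image of a batch to the [0, 1] range before segmentation. The same per-image normalization on a synthetic batch:

    import numpy as np

    batch = np.random.randint(0, 65535, size=(4, 256, 256)).astype(np.float32)
    for i in range(batch.shape[0]):
        if batch[i].max() > 1:
            batch[i] = batch[i] / batch[i].max()   # each image now spans [0, 1]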
1920
164
  def generate_cellpose_masks(src, settings, object_type):
1921
165
 
1922
- from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, mask_object_count, print_progress
166
+ from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_cellpose_channels, _choose_model, mask_object_count, all_elements_match, prepare_batch_for_segmentation
1923
167
  from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
1924
168
  from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
1925
169
  from .plot import plot_masks
@@ -2162,593 +406,6 @@ def generate_cellpose_masks(src, settings, object_type):
2162
406
  torch.cuda.empty_cache()
2163
407
  return
2164
408
 
2165
- def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
2166
-
2167
- from .io import _load_images_and_labels, _load_normalized_images_and_labels
2168
- from .utils import resize_images_and_labels, resizescikit, print_progress
2169
- from .plot import print_mask_and_flows
2170
-
2171
- dst = os.path.join(src, model_name)
2172
- os.makedirs(dst, exist_ok=True)
2173
-
2174
- chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
2175
-
2176
- if grayscale:
2177
- chans=[0, 0]
2178
-
2179
- all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
2180
- random.shuffle(all_image_files)
2181
-
2182
- if verbose == True:
2183
- print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')
2184
-
2185
- time_ls = []
2186
- for i in range(0, len(all_image_files), batch_size):
2187
- image_files = all_image_files[i:i+batch_size]
2188
-
2189
- if normalize:
2190
- images, _, image_names, _, orig_dims = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise, target_height, target_width)
2191
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2192
- orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2193
- else:
2194
- images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)
2195
- images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
2196
- orig_dims = [(image.shape[0], image.shape[1]) for image in images]
2197
- if resize:
2198
- images, _ = resize_images_and_labels(images, None, target_height, target_width, True)
2199
-
2200
- for file_index, stack in enumerate(images):
2201
- start = time.time()
2202
- output = model.eval(x=stack,
2203
- normalize=False,
2204
- channels=chans,
2205
- channel_axis=3,
2206
- diameter=diameter,
2207
- flow_threshold=flow_threshold,
2208
- cellprob_threshold=cellprob_threshold,
2209
- rescale=False,
2210
- resample=False,
2211
- progress=False)
2212
-
2213
- if len(output) == 4:
2214
- mask, flows, _, _ = output
2215
- elif len(output) == 3:
2216
- mask, flows, _ = output
2217
- else:
2218
- raise ValueError("Unexpected number of return values from model.eval()")
2219
-
2220
- if resize:
2221
- dims = orig_dims[file_index]
2222
- mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)
2223
-
2224
- stop = time.time()
2225
- duration = (stop - start)
2226
- time_ls.append(duration)
2227
- files_processed = file_index+1
2228
- files_to_process = len(images)
2229
-
2230
- print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Generating masks")
2231
-
2232
- if plot:
2233
- if resize:
2234
- stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
2235
- print_mask_and_flows(stack, mask, flows, overlay=True)
2236
- if save:
2237
- output_filename = os.path.join(dst, image_names[file_index])
2238
- cv2.imwrite(output_filename, mask)
2239
-
2240
-
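The chans expression in generate_masks_from_imgs above follows the Cellpose convention of passing a [channel_to_segment, optional_nuclear_channel] pair, where 0 is grayscale and 1, 2, 3 are red, green and blue. A small sketch of how the mapping resolves for one model name:

    # Cellpose channel pairs are [channel_to_segment, optional_nuclear_channel];
    # 0 = grayscale, 1 = red, 2 = green, 3 = blue.
    model_name = 'cyto2'
    chans = [2, 1] if model_name == 'cyto2' else [0, 0] if model_name == 'nucleus' \
        else [1, 0] if model_name == 'cyto' else [2, 0]
    print(chans)  # [2, 1]: segment the green channel, with red as the nuclear channel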
2241
- def check_cellpose_models(settings):
2242
-
2243
- from .settings import get_check_cellpose_models_default_settings
2244
-
2245
- settings = get_check_cellpose_models_default_settings(settings)
2246
- src = settings['src']
2247
-
2248
- settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
2249
- settings_df['setting_value'] = settings_df['setting_value'].apply(str)
2250
- display(settings_df)
2251
-
2252
- cellpose_models = ['cyto', 'nuclei', 'cyto2', 'cyto3']
2253
- device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
2254
-
2255
- for model_name in cellpose_models:
2256
-
2257
- model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
2258
- print(f'Using {model_name}')
2259
- generate_masks_from_imgs(src, model, model_name, settings['batch_size'], settings['diameter'], settings['CP_prob'], settings['flow_threshold'], settings['grayscale'], settings['save'], settings['normalize'], settings['channels'], settings['percentiles'], settings['circular'], settings['invert'], settings['plot'], settings['resize'], settings['target_height'], settings['target_width'], settings['remove_background'], settings['background'], settings['Signal_to_noise'], settings['verbose'])
2260
-
2261
- return
2262
-
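check_cellpose_models above runs the same images through the stock 'cyto', 'nuclei', 'cyto2' and 'cyto3' models so their segmentations can be compared. A hedged call sketch; the path is a placeholder and the remaining keys are assumed to be supplied by get_check_cellpose_models_default_settings:

    settings = {
        'src': '/path/to/tif_images',   # folder of .tif images
        'batch_size': 8,
        'diameter': 40,
        'CP_prob': 0.0,
        'flow_threshold': 0.4,
        'grayscale': True,
        'save': True,
        'plot': False,
        'verbose': True,
    }
    check_cellpose_models(settings)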
2263
- def save_results_and_figure(src, fig, results):
2264
-
2265
- if not isinstance(results, pd.DataFrame):
2266
- results = pd.DataFrame(results)
2267
-
2268
- results_dir = os.path.join(src, 'results')
2269
- os.makedirs(results_dir, exist_ok=True)
2270
- results_path = os.path.join(results_dir,f'results.csv')
2271
- fig_path = os.path.join(results_dir, f'model_comparison_plot.pdf')
2272
- results.to_csv(results_path, index=False)
2273
- fig.savefig(fig_path, format='pdf')
2274
- print(f'Saved figure to {fig_path} and results to {results_path}')
2275
-
2276
- def compare_mask(args):
2277
- src, filename, dirs, conditions = args
2278
- paths = [os.path.join(d, filename) for d in dirs]
2279
-
2280
- if not all(os.path.exists(path) for path in paths):
2281
- return None
2282
-
2283
- from .io import _read_mask # Import here to avoid issues in multiprocessing
2284
- from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
2285
- from .plot import plot_comparison_results
2286
-
2287
- masks = [_read_mask(path) for path in paths]
2288
- file_results = {'filename': filename}
2289
-
2290
- for i in range(len(masks)):
2291
- for j in range(i + 1, len(masks)):
2292
- mask_i, mask_j = masks[i], masks[j]
2293
- f1_score = boundary_f1_score(mask_i, mask_j)
2294
- jac_index = jaccard_index(mask_i, mask_j)
2295
- ap_score = compute_segmentation_ap(mask_i, mask_j)
2296
-
2297
- file_results.update({
2298
- f'jaccard_{conditions[i]}_{conditions[j]}': jac_index,
2299
- f'boundary_f1_{conditions[i]}_{conditions[j]}': f1_score,
2300
- f'ap_{conditions[i]}_{conditions[j]}': ap_score
2301
- })
2302
-
2303
- return file_results
2304
-
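compare_mask above scores every pair of conditions with a Jaccard index, a boundary F1 score and a segmentation average precision. A self-contained numpy sketch of the simplest of the three, a foreground Jaccard index; it approximates what jaccard_index in spacr.utils is expected to compute rather than reproducing it:

    import numpy as np

    a = np.zeros((64, 64), dtype=np.uint16); a[10:30, 10:30] = 1
    b = np.zeros((64, 64), dtype=np.uint16); b[15:35, 15:35] = 1
    inter = np.logical_and(a > 0, b > 0).sum()
    union = np.logical_or(a > 0, b > 0).sum()
    jaccard = inter / union if union else 1.0   # overlap of foreground pixels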
2305
- def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
2306
- from .plot import visualize_cellpose_masks, plot_comparison_results
2307
- from .io import _read_mask
2308
-
2309
- dirs = [os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d)) and d != 'results']
2310
- dirs.sort()
2311
- conditions = [os.path.basename(d) for d in dirs]
2312
-
2313
- # Get common files in all directories
2314
- common_files = set(os.listdir(dirs[0]))
2315
- for d in dirs[1:]:
2316
- common_files.intersection_update(os.listdir(d))
2317
- common_files = list(common_files)
2318
-
2319
- # Create a pool of n_jobs
2320
- with Pool(processes=processes) as pool:
2321
- args = [(src, filename, dirs, conditions) for filename in common_files]
2322
- results = pool.map(compare_mask, args)
2323
-
2324
- # Filter out None results (from skipped files)
2325
- results = [res for res in results if res is not None]
2326
- print(results)
2327
- if verbose:
2328
- for result in results:
2329
- filename = result['filename']
2330
- masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
2331
- visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)
2332
-
2333
- fig = plot_comparison_results(results)
2334
- save_results_and_figure(src, fig, results)
2335
- return
2336
-
2337
- def _calculate_similarity(df, features, col_to_compare, val1, val2):
2338
- """
2339
- Calculate similarity scores of each well to the positive and negative controls using various metrics.
2340
-
2341
- Args:
2342
- df (pandas.DataFrame): DataFrame containing the data.
2343
- features (list): List of feature columns to use for similarity calculation.
2344
- col_to_compare (str): Column name to use for comparing groups.
2345
- val1, val2 (str): Values in col_to_compare to create subsets for comparison.
2346
-
2347
- Returns:
2348
- pandas.DataFrame: DataFrame with similarity scores.
2349
- """
2350
- # Separate positive and negative control wells
2351
- pos_control = df[df[col_to_compare] == val1][features].mean()
2352
- neg_control = df[df[col_to_compare] == val2][features].mean()
2353
-
2354
- # Standardize features for Mahalanobis distance
2355
- scaler = StandardScaler()
2356
- scaled_features = scaler.fit_transform(df[features])
2357
-
2358
- # Regularize the covariance matrix to avoid singularity
2359
- cov_matrix = np.cov(scaled_features, rowvar=False)
2360
- inv_cov_matrix = None
2361
- try:
2362
- inv_cov_matrix = np.linalg.inv(cov_matrix)
2363
- except np.linalg.LinAlgError:
2364
- # Add a small value to the diagonal elements for regularization
2365
- epsilon = 1e-5
2366
- inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)
2367
-
2368
- # Calculate similarity scores
2369
- df['similarity_to_pos_euclidean'] = df[features].apply(lambda row: euclidean(row, pos_control), axis=1)
2370
- df['similarity_to_neg_euclidean'] = df[features].apply(lambda row: euclidean(row, neg_control), axis=1)
2371
- df['similarity_to_pos_cosine'] = df[features].apply(lambda row: cosine(row, pos_control), axis=1)
2372
- df['similarity_to_neg_cosine'] = df[features].apply(lambda row: cosine(row, neg_control), axis=1)
2373
- df['similarity_to_pos_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, pos_control, inv_cov_matrix), axis=1)
2374
- df['similarity_to_neg_mahalanobis'] = df[features].apply(lambda row: mahalanobis(row, neg_control, inv_cov_matrix), axis=1)
2375
- df['similarity_to_pos_manhattan'] = df[features].apply(lambda row: cityblock(row, pos_control), axis=1)
2376
- df['similarity_to_neg_manhattan'] = df[features].apply(lambda row: cityblock(row, neg_control), axis=1)
2377
- df['similarity_to_pos_minkowski'] = df[features].apply(lambda row: minkowski(row, pos_control, p=3), axis=1)
2378
- df['similarity_to_neg_minkowski'] = df[features].apply(lambda row: minkowski(row, neg_control, p=3), axis=1)
2379
- df['similarity_to_pos_chebyshev'] = df[features].apply(lambda row: chebyshev(row, pos_control), axis=1)
2380
- df['similarity_to_neg_chebyshev'] = df[features].apply(lambda row: chebyshev(row, neg_control), axis=1)
2381
- df['similarity_to_pos_hamming'] = df[features].apply(lambda row: hamming(row, pos_control), axis=1)
2382
- df['similarity_to_neg_hamming'] = df[features].apply(lambda row: hamming(row, neg_control), axis=1)
2383
- df['similarity_to_pos_jaccard'] = df[features].apply(lambda row: jaccard(row, pos_control), axis=1)
2384
- df['similarity_to_neg_jaccard'] = df[features].apply(lambda row: jaccard(row, neg_control), axis=1)
2385
- df['similarity_to_pos_braycurtis'] = df[features].apply(lambda row: braycurtis(row, pos_control), axis=1)
2386
- df['similarity_to_neg_braycurtis'] = df[features].apply(lambda row: braycurtis(row, neg_control), axis=1)
2387
-
2388
- return df
2389
-
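_calculate_similarity above regularizes the covariance matrix before inversion so that nearly collinear features cannot make it singular. A minimal numpy/scipy sketch of the regularized inverse and one Mahalanobis distance, on synthetic data:

    import numpy as np
    from scipy.spatial.distance import mahalanobis

    X = np.random.rand(100, 5)                      # scaled feature matrix
    cov = np.cov(X, rowvar=False)
    try:
        inv_cov = np.linalg.inv(cov)
    except np.linalg.LinAlgError:
        inv_cov = np.linalg.inv(cov + np.eye(cov.shape[0]) * 1e-5)  # ridge on the diagonal
    d = mahalanobis(X[0], X.mean(axis=0), inv_cov)  # distance of one sample to the mean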
2390
- def find_optimal_threshold(y_true, y_pred_proba):
2391
- """
2392
- Find the optimal threshold for binary classification based on the F1-score.
2393
-
2394
- Args:
2395
- y_true (array-like): True binary labels.
2396
- y_pred_proba (array-like): Predicted probabilities for the positive class.
2397
-
2398
- Returns:
2399
- float: The optimal threshold.
2400
- """
2401
- precision, recall, thresholds = precision_recall_curve(y_true, y_pred_proba)
2402
- f1_scores = 2 * (precision * recall) / (precision + recall)
2403
- optimal_idx = np.argmax(f1_scores)
2404
- optimal_threshold = thresholds[optimal_idx]
2405
- return optimal_threshold
2406
-
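A worked example of the F1-based threshold search on synthetic labels. Note that precision_recall_curve returns one more precision/recall point than thresholds, so the final point is dropped before indexing:

    import numpy as np
    from sklearn.metrics import precision_recall_curve

    y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
    y_proba = np.array([0.10, 0.40, 0.35, 0.80, 0.65, 0.20, 0.90, 0.55])
    precision, recall, thresholds = precision_recall_curve(y_true, y_proba)
    f1 = np.where((precision + recall) > 0, 2 * precision * recall / (precision + recall), 0.0)
    optimal_threshold = thresholds[np.argmax(f1[:-1])]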
2407
- def ml_analysis(df, channel_of_interest=3, location_column='col', positive_control='c2', negative_control='c1', exclude=None, n_repeats=10, top_features=30, n_estimators=100, test_size=0.2, model_type='xgboost', n_jobs=-1, remove_low_variance_features=True, remove_highly_correlated_features=True, verbose=False):
2408
- """
2409
- Calculates permutation importance for numerical features in the dataframe,
2410
- comparing groups based on specified column values and uses the model to predict
2411
- the class for all other rows in the dataframe.
2412
-
2413
- Args:
2414
- df (pandas.DataFrame): The DataFrame containing the data.
2415
- channel_of_interest (int): Channel whose features are used; passed to filter_dataframe_features.
2416
- location_column (str): Column name to use for comparing groups.
2417
- positive_control, negative_control (str): Values in location_column to create subsets for comparison.
2418
- exclude (list or str, optional): Columns to exclude from features.
2419
- n_repeats (int): Number of repeats for permutation importance.
2420
- top_features (int): Number of top features to plot based on permutation importance.
2421
- n_estimators (int): Number of trees in the random forest, gradient boosting, or XGBoost model.
2422
- test_size (float): Proportion of the dataset to include in the test split.
2423
- random_state (int): Random seed for reproducibility.
2424
- model_type (str): Type of model to use ('random_forest', 'logistic_regression', 'gradient_boosting', 'xgboost').
2425
- n_jobs (int): Number of jobs to run in parallel for applicable models.
2426
-
2427
- Returns:
2428
- pandas.DataFrame: The original dataframe with added prediction and data usage columns.
2429
- pandas.DataFrame: DataFrame containing the importances and standard deviations.
2430
- """
2431
-
2432
- from .utils import filter_dataframe_features
2433
- from .plot import plot_permutation, plot_feature_importance
2434
-
2435
- random_state = 42
2436
-
2437
- if 'cells_per_well' in df.columns:
2438
- df = df.drop(columns=['cells_per_well'])
2439
-
2440
- df_metadata = df[[location_column]].copy()
2441
- df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
2442
-
2443
- if verbose:
2444
- print(f'Found {len(features)} numerical features in the dataframe')
2445
- print(f'Features used in training: {features}')
2446
- df = pd.concat([df, df_metadata[location_column]], axis=1)
2447
-
2448
- # Subset the dataframe based on specified column values
2449
- df1 = df[df[location_column] == negative_control].copy()
2450
- df2 = df[df[location_column] == positive_control].copy()
2451
-
2452
- # Create target variable
2453
- df1['target'] = 0 # Negative control
2454
- df2['target'] = 1 # Positive control
2455
-
2456
- # Combine the subsets for analysis
2457
- combined_df = pd.concat([df1, df2])
2458
- combined_df = combined_df.drop(columns=[location_column])
2459
- if verbose:
2460
- print(f'Found {len(df1)} samples for {negative_control} and {len(df2)} samples for {positive_control}. Total: {len(combined_df)}')
2461
-
2462
- X = combined_df[features]
2463
- y = combined_df['target']
2464
-
2465
- # Split the data into training and testing sets
2466
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
2467
-
2468
- # Add data usage labels
2469
- combined_df['data_usage'] = 'train'
2470
- combined_df.loc[X_test.index, 'data_usage'] = 'test'
2471
- df['data_usage'] = 'not_used'
2472
- df.loc[combined_df.index, 'data_usage'] = combined_df['data_usage']
2473
-
2474
- # Initialize the model based on model_type
2475
- if model_type == 'random_forest':
2476
- model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
2477
- elif model_type == 'logistic_regression':
2478
- model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
2479
- elif model_type == 'gradient_boosting':
2480
- model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state) # Supports n_jobs internally
2481
- elif model_type == 'xgboost':
2482
- model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
2483
- else:
2484
- raise ValueError(f"Unsupported model_type: {model_type}")
2485
-
2486
- model.fit(X_train, y_train)
2487
-
2488
- perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)
2489
-
2490
- # Create a DataFrame for permutation importances
2491
- permutation_df = pd.DataFrame({
2492
- 'feature': [features[i] for i in perm_importance.importances_mean.argsort()],
2493
- 'importance_mean': perm_importance.importances_mean[perm_importance.importances_mean.argsort()],
2494
- 'importance_std': perm_importance.importances_std[perm_importance.importances_mean.argsort()]
2495
- }).tail(top_features)
2496
-
2497
- permutation_fig = plot_permutation(permutation_df)
2498
- if verbose:
2499
- permutation_fig.show()
2500
-
2501
- # Feature importance for models that support it
2502
- if model_type in ['random_forest', 'xgboost', 'gradient_boosting']:
2503
- feature_importances = model.feature_importances_
2504
- feature_importance_df = pd.DataFrame({
2505
- 'feature': features,
2506
- 'importance': feature_importances
2507
- }).sort_values(by='importance', ascending=False).head(top_features)
2508
-
2509
- feature_importance_fig = plot_feature_importance(feature_importance_df)
2510
- if verbose:
2511
- feature_importance_fig.show()
2512
-
2513
- else:
2514
- feature_importance_df = pd.DataFrame()
2515
-
2516
- # Predicting the target variable for the test set
2517
- predictions_test = model.predict(X_test)
2518
- combined_df.loc[X_test.index, 'predictions'] = predictions_test
2519
-
2520
- # Get prediction probabilities for the test set
2521
- prediction_probabilities_test = model.predict_proba(X_test)
2522
-
2523
- # Find the optimal threshold
2524
- optimal_threshold = find_optimal_threshold(y_test, prediction_probabilities_test[:, 1])
2525
- if verbose:
2526
- print(f'Optimal threshold: {optimal_threshold}')
2527
-
2528
- # Predicting the target variable for all other rows in the dataframe
2529
- X_all = df[features]
2530
- all_predictions = model.predict(X_all)
2531
- df['predictions'] = all_predictions
2532
-
2533
- # Get prediction probabilities for all rows in the dataframe
2534
- prediction_probabilities = model.predict_proba(X_all)
2535
- for i in range(prediction_probabilities.shape[1]):
2536
- df[f'prediction_probability_class_{i}'] = prediction_probabilities[:, i]
2537
- if verbose:
2538
- print("\nClassification Report:")
2539
- print(classification_report(y_test, predictions_test))
2540
- report_dict = classification_report(y_test, predictions_test, output_dict=True)
2541
- metrics_df = pd.DataFrame(report_dict).transpose()
2542
-
2543
- df = _calculate_similarity(df, features, location_column, positive_control, negative_control)
2544
-
2545
- df['prcfo'] = df.index.astype(str)
2546
- df[['plate', 'row', 'col', 'field', 'object']] = df['prcfo'].str.split('_', expand=True)
2547
- df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']
2548
-
2549
- return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df], [permutation_fig, feature_importance_fig]
2550
-
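A hedged sketch of calling ml_analysis and unpacking its two return values; the DataFrame is assumed to already hold per-object features plus a 'col' metadata column, and the control well values are placeholders:

    output, figs = ml_analysis(df,
                               channel_of_interest=3,
                               location_column='col',
                               positive_control='c2',
                               negative_control='c1',
                               model_type='xgboost',
                               n_jobs=-1,
                               verbose=True)
    scored_df, permutation_df, importance_df, model, X_train, X_test, y_train, y_test, metrics_df = output
    permutation_fig, feature_importance_fig = figs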
2551
- def shap_analysis(model, X_train, X_test):
2552
-
2553
- """
2554
- Performs SHAP analysis on the given model and data.
2555
-
2556
- Args:
2557
- model: The trained model.
2558
- X_train (pandas.DataFrame): Training feature set.
2559
- X_test (pandas.DataFrame): Testing feature set.
2560
- Returns:
2561
- fig: Matplotlib figure object containing the SHAP summary plot.
2562
- """
2563
-
2564
- explainer = shap.Explainer(model, X_train)
2565
- shap_values = explainer(X_test)
2566
- # Create a new figure
2567
- fig, ax = plt.subplots()
2568
- # Summary plot
2569
- shap.summary_plot(shap_values, X_test, show=False)
2570
- # Save the current figure (the one that SHAP just created)
2571
- fig = plt.gcf()
2572
- plt.close(fig) # Close the figure to prevent it from displaying immediately
2573
- return fig
2574
-
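A minimal usage sketch for shap_analysis as removed above, assuming ml_analysis has already been run and its output list is available; the save path is a hypothetical placeholder.
# output = [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test, metrics_df]
shap_fig = shap_analysis(output[3], output[4], output[5])   # model, X_train, X_test
shap_fig.savefig('shap_summary.pdf', format='pdf')          # hypothetical output path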
2575
- def check_index(df, elements=5, split_char='_'):
2576
- problematic_indices = []
2577
- for idx in df.index:
2578
- parts = str(idx).split(split_char)
2579
- if len(parts) != elements:
2580
- problematic_indices.append(idx)
2581
- if problematic_indices:
2582
- print("Indices that cannot be separated into 5 parts:")
2583
- for idx in problematic_indices:
2584
- print(idx)
2585
- raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
2586
-
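An illustrative example of the index format check_index expects, based on the plate_row_col_field_object ('prcfo') convention used later in this file; the index values are hypothetical.
import pandas as pd

df_ok = pd.DataFrame(index=['plate1_r1_c1_f1_o1'])   # splits into 5 parts on '_' -> passes
df_bad = pd.DataFrame(index=['plate1_r1_c1_f1'])     # only 4 parts -> problematic
check_index(df_ok)      # no error
# check_index(df_bad)   # would print the offending index and raise ValueError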
2587
- def generate_ml_scores(src, settings):
2588
-
2589
- from .io import _read_and_merge_data
2590
- from .plot import plot_plates
2591
- from .utils import get_ml_results_paths
2592
- from .settings import set_default_analyze_screen
2593
-
2594
- settings = set_default_analyze_screen(settings)
2595
-
2596
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
2597
- display(settings_df)
2598
-
2599
- db_loc = [src+'/measurements/measurements.db']
2600
- tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
2601
- include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
2602
-
2603
- df, _ = _read_and_merge_data(db_loc,
2604
- tables,
2605
- settings['verbose'],
2606
- include_multinucleated,
2607
- include_multiinfected,
2608
- include_noninfected)
2609
-
2610
- if settings['channel_of_interest'] in [0,1,2,3]:
2611
-
2612
- df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"]/df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
2613
-
2614
- output, figs = ml_analysis(df,
2615
- settings['channel_of_interest'],
2616
- settings['location_column'],
2617
- settings['positive_control'],
2618
- settings['negative_control'],
2619
- settings['exclude'],
2620
- settings['n_repeats'],
2621
- settings['top_features'],
2622
- settings['n_estimators'],
2623
- settings['test_size'],
2624
- settings['model_type_ml'],
2625
- settings['n_jobs'],
2626
- settings['remove_low_variance_features'],
2627
- settings['remove_highly_correlated_features'],
2628
- settings['verbose'])
2629
-
2630
- shap_fig = shap_analysis(output[3], output[4], output[5])
2631
-
2632
- features = output[0].select_dtypes(include=[np.number]).columns.tolist()
2633
-
2634
- if settings['heatmap_feature'] not in features:
2635
- raise ValueError(f"Variable {settings['heatmap_feature']} not found in the dataframe. Please choose one of the following: {features}")
2636
-
2637
- plate_heatmap = plot_plates(df=output[0],
2638
- variable=settings['heatmap_feature'],
2639
- grouping=settings['grouping'],
2640
- min_max=settings['min_max'],
2641
- cmap=settings['cmap'],
2642
- min_count=settings['minimum_cell_count'],
2643
- verbose=settings['verbose'])
2644
-
2645
- data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
2646
- df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
2647
-
2648
- settings_df.to_csv(settings_csv, index=False)
2649
- df.to_csv(data_path, mode='w', encoding='utf-8')
2650
- permutation_df.to_csv(permutation_path, mode='w', encoding='utf-8')
2651
- feature_importance_df.to_csv(feature_importance_path, mode='w', encoding='utf-8')
2652
- metrics_df.to_csv(model_metricks_path, mode='w', encoding='utf-8')
2653
-
2654
- plate_heatmap.savefig(plate_heatmap_path, format='pdf')
2655
- figs[0].savefig(permutation_fig_path, format='pdf')
2656
- figs[1].savefig(feature_importance_fig_path, format='pdf')
2657
- shap_fig.savefig(shap_fig_path, format='pdf')
2658
-
2659
- return [output, plate_heatmap]
2660
-
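A minimal, hypothetical call to generate_ml_scores as removed above. The src path and control labels are placeholders; only keys read in the function body are shown, and set_default_analyze_screen is expected to fill in the remaining defaults.
settings = {
    'channel_of_interest': 1,
    'location_column': 'col',            # placeholder grouping column
    'positive_control': 'c2',            # placeholder control wells
    'negative_control': 'c1',
    'model_type_ml': 'random_forest',
    'heatmap_feature': 'recruitment',
    'verbose': True,
}
output, plate_heatmap = generate_ml_scores('/path/to/experiment', settings)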
2661
- def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
2662
-
2663
- from .io import _read_and_merge_data, _read_db
2664
-
2665
- db_loc = [src+'/measurements/measurements.db']
2666
- loc = src+'/measurements/measurements.db'
2667
- df, _ = _read_and_merge_data(db_loc,
2668
- tables,
2669
- verbose=True,
2670
- include_multinucleated=True,
2671
- include_multiinfected=True,
2672
- include_noninfected=True)
2673
-
2674
- paths_df = _read_db(loc, tables=['png_list'])
2675
-
2676
- merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
2677
-
2678
- return merged_df
2679
-
2680
- def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
2681
- """
2682
- Reads a CSV file and creates a jitter plot of one column grouped by another column.
2683
-
2684
- Args:
2685
- src (str): Path to the source data.
2686
- x_column (str): Name of the column to be used for the x-axis.
2687
- y_column (str): Name of the column to be used for the y-axis.
2688
- plot_title (str): Title of the plot. Default is 'Jitter Plot'.
2689
- output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
2690
-
2691
- Returns:
2692
- pd.DataFrame: The filtered and balanced DataFrame.
2693
- """
2694
- # Read the CSV file into a DataFrame
2695
- df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
2696
-
2697
- # Print column names for debugging
2698
- print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")
2699
- #print("Columns in DataFrame:", df.columns.tolist())
2700
-
2701
- # Replace NaN values with a specific label in x_column
2702
- df[x_column] = df[x_column].fillna('NaN')
2703
-
2704
- # Filter the DataFrame if filter_column and filter_values are provided
2705
- if filter_column is not None:
2706
- if isinstance(filter_column, str):
2707
- df = df[df[filter_column].isin(filter_values)]
2708
- if isinstance(filter_column, list):
2709
- for i,val in enumerate(filter_column):
2710
- print(f'Filtering on {val}: {len(df)} rows before filter')
2711
- df = df[df[val].isin(filter_values[i])]
2712
-
2713
- # Use the correct column names based on your DataFrame
2714
- required_columns = ['plate_x', 'row_x', 'col_x']
2715
- if not all(column in df.columns for column in required_columns):
2716
- raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")
2717
-
2718
- # Filter to retain rows with non-NaN values in x_column and with matching plate, row, col values
2719
- non_nan_df = df[df[x_column] != 'NaN']
2720
- retained_rows = df[df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1).isin(non_nan_df[['plate_x', 'row_x', 'col_x']].apply(tuple, axis=1))]
2721
-
2722
- # Determine the minimum count of examples across all groups in x_column
2723
- min_count = retained_rows[x_column].value_counts().min()
2724
- print(f'Sampling {min_count} images per group (size of the smallest annotated group)')
2725
-
2726
- # Randomly sample min_count examples from each group in x_column
2727
- balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)
2728
-
2729
- # Create the jitter plot
2730
- plt.figure(figsize=(10, 6))
2731
- jitter_plot = sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
2732
- plt.title(plot_title)
2733
- plt.xlabel(x_column)
2734
- plt.ylabel(y_column)
2735
-
2736
- # Customize the x-axis labels
2737
- plt.xticks(rotation=45, ha='right')
2738
-
2739
- # Adjust the position of the x-axis labels to be centered below the data
2740
- ax = plt.gca()
2741
- ax.set_xticklabels(ax.get_xticklabels(), rotation=45, ha='center')
2742
-
2743
- # Save the plot to a file or display it
2744
- if output_path:
2745
- plt.savefig(output_path, bbox_inches='tight')
2746
- print(f"Jitter plot saved to {output_path}")
2747
- else:
2748
- plt.show()
2749
-
2750
- return balanced_df
2751
-
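A hypothetical call to jitterplot_by_annotation as removed above. The source path, annotation column name, and output path are placeholders, and the y-axis column is assumed to exist in the merged measurements/annotation table.
balanced = jitterplot_by_annotation(
    src='/path/to/experiment',
    x_column='test',                                  # placeholder annotation column
    y_column='pathogen_channel_1_mean_intensity',     # assumed measurement column
    plot_title='Recruitment by annotation',
    output_path='/path/to/experiment/jitter.pdf',
)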
2752
409
  def generate_image_umap(settings={}):
2753
410
  """
2754
411
  Generate UMAP or tSNE embedding and visualize the data with clustering.
@@ -2784,7 +441,7 @@ def generate_image_umap(settings={}):
2784
441
  """
2785
442
 
2786
443
  from .io import _read_and_join_tables
2787
- from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis #, generate_umap_from_images
444
+ from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, cluster_feature_analysis, map_condition
2788
445
  from .settings import set_default_umap_image_settings
2789
446
  settings = set_default_umap_image_settings(settings)
2790
447
 
@@ -2933,17 +590,6 @@ def generate_image_umap(settings={}):
2933
590
 
2934
591
  return all_df
2935
592
 
2936
- # Define the mapping function
2937
- def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
2938
- if col_value == neg:
2939
- return 'neg'
2940
- elif col_value == pos:
2941
- return 'pos'
2942
- elif col_value == mix:
2943
- return 'mix'
2944
- else:
2945
- return 'screen'
2946
-
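map_condition is removed here and, per the updated imports earlier in this diff, is now imported from spacr.utils instead. Its behaviour is a simple lookup with the defaults shown above, e.g.:
map_condition('c1')   # -> 'neg'
map_condition('c2')   # -> 'pos'
map_condition('c3')   # -> 'mix'
map_condition('c7')   # -> 'screen' (any unmapped value)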
2947
593
  def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
2948
594
  """
2949
595
  Perform a hyperparameter search for UMAP or tSNE on the given data.
@@ -2970,7 +616,7 @@ def reducer_hyperparameter_search(settings={}, reduction_params=None, dbscan_par
2970
616
  """
2971
617
 
2972
618
  from .io import _read_and_join_tables
2973
- from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors
619
+ from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, map_condition
2974
620
  from .settings import set_default_umap_image_settings
2975
621
 
2976
622
  settings = set_default_umap_image_settings(settings)
@@ -3122,7 +768,8 @@ def generate_mediar_masks(src, settings, object_type):
3122
768
  from .mediar import MEDIARPredictor
3123
769
  from .io import _create_database, _save_object_counts_to_database
3124
770
  from .plot import plot_masks
3125
- from .settings import set_default_settings_preprocess_generate_masks, _get_object_settings
771
+ from .settings import set_default_settings_preprocess_generate_masks
772
+ from .utils import prepare_batch_for_segmentation
3126
773
 
3127
774
  # Clear CUDA cache and check if CUDA is available
3128
775
  gc.collect()
@@ -3197,4 +844,108 @@ def generate_mediar_masks(src, settings, object_type):
3197
844
  gc.collect()
3198
845
  torch.cuda.empty_cache()
3199
846
 
3200
- print("Mask generation completed.")
847
+ print("Mask generation completed.")
848
+
849
+ def generate_screen_graphs(settings):
850
+ """
851
+ Generate screen graphs for different measurements in a given source directory.
852
+
853
+ Args:
854
+ settings (dict): Dictionary of analysis settings. Expected keys include:
+ src (str or list): Path(s) to the source directory or directories.
855
+ tables (list): List of tables to include in the analysis (default: ['cell', 'nucleus', 'pathogen', 'cytoplasm']).
856
+ graph_type (str): Type of graph to generate (default: 'bar').
857
+ summary_func (str or function): Function to summarize data (default: 'mean').
858
+ y_axis_start (float): Starting value for the y-axis (default: 0).
859
+ error_bar_type (str): Type of error bar to use ('std' or 'sem') (default: 'std').
860
+ theme (str): Theme for the graph (default: 'pastel').
861
+ representation (str): Representation for grouping (default: 'well').
862
+
863
+ Returns:
864
+ figs (list): List of generated figures.
865
+ results (list): List of corresponding result DataFrames.
866
+ """
867
+
868
+ from .plot import spacrGraph
869
+ from .io import _read_and_merge_data
870
+ from .utils import annotate_conditions
871
+
872
+ if isinstance(settings['src'], str):
873
+ srcs = [settings['src']]
874
+ else:
875
+ srcs = settings['src']
876
+
877
+ all_df = pd.DataFrame()
878
+ figs = []
879
+ results = []
880
+
881
+ for src in srcs:
882
+ db_loc = [os.path.join(src, 'measurements', 'measurements.db')]
883
+
884
+ # Read and merge data from the database
885
+ df, _ = _read_and_merge_data(db_loc, settings['tables'], verbose=True, nuclei_limit=settings['nuclei_limit'], pathogen_limit=settings['pathogen_limit'], uninfected=settings['uninfected'])
886
+
887
+ # Annotate the data
888
+ df = annotate_conditions(df, cells=settings['cells'], cell_loc=None, pathogens=settings['controls'], pathogen_loc=settings['controls_loc'], treatments=None, treatment_loc=None)
889
+
890
+ # Calculate recruitment metric
891
+ df['recruitment'] = df['pathogen_channel_1_mean_intensity'] / df['cytoplasm_channel_1_mean_intensity']
892
+
893
+ # Combine with the overall DataFrame
894
+ all_df = pd.concat([all_df, df], ignore_index=True)
895
+
896
+ # Generate individual plot
897
+ plotter = spacrGraph(df,
898
+ grouping_column='pathogen',
899
+ data_column='recruitment',
900
+ graph_type=settings['graph_type'],
901
+ summary_func=settings['summary_func'],
902
+ y_axis_start=settings['y_axis_start'],
903
+ error_bar_type=settings['error_bar_type'],
904
+ theme=settings['theme'],
905
+ representation=settings['representation'])
906
+
907
+ plotter.create_plot()
908
+ fig = plotter.get_figure()
909
+ results_df = plotter.get_results()
910
+
911
+ # Append to the lists
912
+ figs.append(fig)
913
+ results.append(results_df)
914
+
915
+ # Generate plot for the combined data (all_df)
916
+ plotter = spacrGraph(all_df,
917
+ grouping_column='pathogen',
918
+ data_column='recruitment',
919
+ graph_type=settings['graph_type'],
920
+ summary_func=settings['summary_func'],
921
+ y_axis_start=settings['y_axis_start'],
922
+ error_bar_type=settings['error_bar_type'],
923
+ theme=settings['theme'],
924
+ representation=settings['representation'])
925
+
926
+ plotter.create_plot()
927
+ fig = plotter.get_figure()
928
+ results_df = plotter.get_results()
929
+
930
+ figs.append(fig)
931
+ results.append(results_df)
932
+
933
+ # Save figures and results
934
+ for i, fig in enumerate(figs):
935
+ res = results[i]
936
+
937
+ if i < len(srcs):
938
+ source = srcs[i]
939
+ else:
940
+ source = srcs[0]
941
+
942
+ # Ensure the destination folder exists
943
+ dst = os.path.join(source, 'results')
944
+ print(f"Savings results to {dst}")
945
+ os.makedirs(dst, exist_ok=True)
946
+
947
+ # Save the figure and results DataFrame
948
+ fig.savefig(os.path.join(dst, f"figure_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.pdf"), format='pdf')
949
+ res.to_csv(os.path.join(dst, f"results_controls_{i}_{settings['representation']}_{settings['summary_func']}_{settings['graph_type']}.csv"), index=False)
950
+
951
+ return figs, results
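A minimal, hypothetical settings dictionary for generate_screen_graphs as added above; only keys read in the function body are shown, and all paths, labels, and well locations are placeholders.
settings = {
    'src': ['/path/to/plate1', '/path/to/plate2'],
    'tables': ['cell', 'nucleus', 'pathogen', 'cytoplasm'],
    'cells': ['HeLa'],                    # placeholder host-cell label
    'controls': ['wt', 'mutant'],         # placeholder pathogen labels
    'controls_loc': [['c1'], ['c2']],     # placeholder well locations
    'nuclei_limit': True,
    'pathogen_limit': True,
    'uninfected': True,
    'graph_type': 'bar',
    'summary_func': 'mean',
    'y_axis_start': 0,
    'error_bar_type': 'std',
    'theme': 'pastel',
    'representation': 'well',
}
figs, results = generate_screen_graphs(settings)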