spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +6 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +807 -0
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/cli.py +25 -187
- spacr/core.py +1611 -389
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +320 -0
- spacr/graph_learning_lap.py +84 -0
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +187 -0
- spacr/gui_mask_app.py +149 -174
- spacr/gui_measure_app.py +116 -109
- spacr/gui_sim_app.py +0 -0
- spacr/gui_utils.py +679 -139
- spacr/io.py +620 -469
- spacr/mask_app.py +116 -9
- spacr/measure.py +178 -84
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +255 -1
- spacr/plot.py +263 -100
- spacr/sequencing.py +1130 -0
- spacr/sim.py +634 -122
- spacr/timelapse.py +343 -53
- spacr/train.py +195 -22
- spacr/umap.py +0 -689
- spacr/utils.py +1530 -188
- spacr-0.0.6.dist-info/METADATA +118 -0
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.1.dist-info/METADATA +0 -64
- spacr-0.0.1.dist-info/RECORD +0 -26
- spacr-0.0.1.dist-info/entry_points.txt +0 -5
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/core.py
CHANGED
@@ -1,12 +1,10 @@
|
|
1
|
-
import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime
|
1
|
+
import os, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, datetime, shap
|
2
2
|
|
3
|
-
# image and array processing
|
4
3
|
import numpy as np
|
5
4
|
import pandas as pd
|
6
5
|
|
7
|
-
import
|
6
|
+
from cellpose import train
|
8
7
|
from cellpose import models as cp_models
|
9
|
-
from cellpose import denoise
|
10
8
|
|
11
9
|
import statsmodels.formula.api as smf
|
12
10
|
import statsmodels.api as sm
|
@@ -15,31 +13,37 @@ from IPython.display import display
|
|
15
13
|
from multiprocessing import Pool, cpu_count, Value, Lock
|
16
14
|
|
17
15
|
import seaborn as sns
|
18
|
-
|
16
|
+
|
19
17
|
from skimage.measure import regionprops, label
|
20
|
-
|
18
|
+
from skimage.morphology import square
|
21
19
|
from skimage.transform import resize as resizescikit
|
22
|
-
from sklearn.model_selection import train_test_split
|
23
20
|
from collections import defaultdict
|
24
|
-
import multiprocessing
|
25
21
|
from torch.utils.data import DataLoader, random_split
|
26
|
-
import
|
27
|
-
|
22
|
+
from sklearn.cluster import KMeans
|
23
|
+
from sklearn.decomposition import PCA
|
28
24
|
|
29
|
-
|
25
|
+
from skimage import measure
|
30
26
|
from sklearn.model_selection import train_test_split
|
31
|
-
from sklearn.ensemble import IsolationForest
|
27
|
+
from sklearn.ensemble import IsolationForest, RandomForestClassifier, HistGradientBoostingClassifier
|
28
|
+
from sklearn.linear_model import LogisticRegression
|
29
|
+
from sklearn.inspection import permutation_importance
|
30
|
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, classification_report
|
31
|
+
from sklearn.preprocessing import StandardScaler
|
32
32
|
|
33
|
-
from .
|
33
|
+
from scipy.ndimage import binary_dilation
|
34
|
+
from scipy.spatial.distance import cosine, euclidean, mahalanobis, cityblock, minkowski, chebyshev, hamming, jaccard, braycurtis
|
35
|
+
|
36
|
+
import torchvision.transforms as transforms
|
37
|
+
from xgboost import XGBClassifier
|
38
|
+
import shap
|
39
|
+
|
40
|
+
import matplotlib.pyplot as plt
|
41
|
+
import matplotlib
|
42
|
+
matplotlib.use('Agg')
|
43
|
+
#import matplotlib.pyplot as plt
|
34
44
|
|
35
|
-
|
36
|
-
#from .plot import plot_merged, plot_arrays, _plot_controls, _plot_recruitment, _imshow, _plot_histograms_and_stats, _reg_v_plot, visualize_masks, plot_comparison_results
|
37
|
-
#from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient, _object_filter
|
38
|
-
#from .utils import resize_images_and_labels, generate_fraction_map, MLR, fishers_odds, lasso_reg, model_metrics, _map_wells_png, check_multicollinearity, init_globals, add_images_to_tar
|
39
|
-
#from .utils import get_paths_from_db, pick_best_model, test_model_performance, evaluate_model_performance, compute_irm_penalty
|
40
|
-
#from .utils import _pivot_counts_table, _generate_masks, _get_cellpose_channels, annotate_conditions, _calculate_recruitment, calculate_loss, _group_by_well, choose_model
|
45
|
+
from .logger import log_function_call
|
41
46
|
|
42
|
-
@log_function_call
|
43
47
|
def analyze_plaques(folder):
|
44
48
|
summary_data = []
|
45
49
|
details_data = []
|
@@ -76,171 +80,95 @@ def analyze_plaques(folder):
|
|
76
80
|
|
77
81
|
print(f"Analysis completed and saved to database '{db_name}'.")
|
78
82
|
|
79
|
-
@log_function_call
|
80
|
-
def compare_masks(dir1, dir2, dir3, verbose=False):
|
81
|
-
|
82
|
-
from .io import _read_mask
|
83
|
-
from .plot import visualize_masks, plot_comparison_results
|
84
|
-
from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index, dice_coefficient
|
85
|
-
|
86
|
-
filenames = os.listdir(dir1)
|
87
|
-
results = []
|
88
|
-
cond_1 = os.path.basename(dir1)
|
89
|
-
cond_2 = os.path.basename(dir2)
|
90
|
-
cond_3 = os.path.basename(dir3)
|
91
|
-
for index, filename in enumerate(filenames):
|
92
|
-
print(f'Processing image:{index+1}', end='\r', flush=True)
|
93
|
-
path1, path2, path3 = os.path.join(dir1, filename), os.path.join(dir2, filename), os.path.join(dir3, filename)
|
94
|
-
if os.path.exists(path2) and os.path.exists(path3):
|
95
|
-
|
96
|
-
mask1, mask2, mask3 = _read_mask(path1), _read_mask(path2), _read_mask(path3)
|
97
|
-
boundary_true1, boundary_true2, boundary_true3 = extract_boundaries(mask1), extract_boundaries(mask2), extract_boundaries(mask3)
|
98
|
-
|
99
|
-
|
100
|
-
true_masks, pred_masks = [mask1], [mask2, mask3] # Assuming mask1 is the ground truth for simplicity
|
101
|
-
true_labels, pred_labels_1, pred_labels_2 = label(mask1), label(mask2), label(mask3)
|
102
|
-
average_precision_0, average_precision_1 = compute_segmentation_ap(mask1, mask2), compute_segmentation_ap(mask1, mask3)
|
103
|
-
ap_scores = [average_precision_0, average_precision_1]
|
104
|
-
|
105
|
-
if verbose:
|
106
|
-
unique_values1, unique_values2, unique_values3 = np.unique(mask1), np.unique(mask2), np.unique(mask3)
|
107
|
-
print(f"Unique values in mask 1: {unique_values1}, mask 2: {unique_values2}, mask 3: {unique_values3}")
|
108
|
-
visualize_masks(boundary_true1, boundary_true2, boundary_true3, title=f"Boundaries - {filename}")
|
109
|
-
|
110
|
-
boundary_f1_12, boundary_f1_13, boundary_f1_23 = boundary_f1_score(mask1, mask2), boundary_f1_score(mask1, mask3), boundary_f1_score(mask2, mask3)
|
111
|
-
|
112
|
-
if (np.unique(mask1).size == 1 and np.unique(mask1)[0] == 0) and \
|
113
|
-
(np.unique(mask2).size == 1 and np.unique(mask2)[0] == 0) and \
|
114
|
-
(np.unique(mask3).size == 1 and np.unique(mask3)[0] == 0):
|
115
|
-
continue
|
116
|
-
|
117
|
-
if verbose:
|
118
|
-
unique_values4, unique_values5, unique_values6 = np.unique(boundary_f1_12), np.unique(boundary_f1_13), np.unique(boundary_f1_23)
|
119
|
-
print(f"Unique values in boundary mask 1: {unique_values4}, mask 2: {unique_values5}, mask 3: {unique_values6}")
|
120
|
-
visualize_masks(mask1, mask2, mask3, title=filename)
|
121
|
-
|
122
|
-
jaccard12 = jaccard_index(mask1, mask2)
|
123
|
-
dice12 = dice_coefficient(mask1, mask2)
|
124
|
-
jaccard13 = jaccard_index(mask1, mask3)
|
125
|
-
dice13 = dice_coefficient(mask1, mask3)
|
126
|
-
jaccard23 = jaccard_index(mask2, mask3)
|
127
|
-
dice23 = dice_coefficient(mask2, mask3)
|
128
|
-
|
129
|
-
results.append({
|
130
|
-
f'filename': filename,
|
131
|
-
f'jaccard_{cond_1}_{cond_2}': jaccard12,
|
132
|
-
f'dice_{cond_1}_{cond_2}': dice12,
|
133
|
-
f'jaccard_{cond_1}_{cond_3}': jaccard13,
|
134
|
-
f'dice_{cond_1}_{cond_3}': dice13,
|
135
|
-
f'jaccard_{cond_2}_{cond_3}': jaccard23,
|
136
|
-
f'dice_{cond_2}_{cond_3}': dice23,
|
137
|
-
f'boundary_f1_{cond_1}_{cond_2}': boundary_f1_12,
|
138
|
-
f'boundary_f1_{cond_1}_{cond_3}': boundary_f1_13,
|
139
|
-
f'boundary_f1_{cond_2}_{cond_3}': boundary_f1_23,
|
140
|
-
f'average_precision_{cond_1}_{cond_2}': ap_scores[0],
|
141
|
-
f'average_precision_{cond_1}_{cond_3}': ap_scores[1]
|
142
|
-
})
|
143
|
-
else:
|
144
|
-
print(f'Cannot find {path1} or {path2} or {path3}')
|
145
|
-
fig = plot_comparison_results(results)
|
146
|
-
return results, fig
|
147
|
-
|
148
|
-
def generate_cp_masks(settings):
|
149
|
-
|
150
|
-
src = settings['src']
|
151
|
-
model_name = settings['model_name']
|
152
|
-
channels = settings['channels']
|
153
|
-
diameter = settings['diameter']
|
154
|
-
regex = '.tif'
|
155
|
-
#flow_threshold = 30
|
156
|
-
cellprob_threshold = settings['cellprob_threshold']
|
157
|
-
figuresize = 25
|
158
|
-
cmap = 'inferno'
|
159
|
-
verbose = settings['verbose']
|
160
|
-
plot = settings['plot']
|
161
|
-
save = settings['save']
|
162
|
-
custom_model = settings['custom_model']
|
163
|
-
signal_thresholds = 1000
|
164
|
-
normalize = settings['normalize']
|
165
|
-
resize = settings['resize']
|
166
|
-
target_height = settings['width_height'][1]
|
167
|
-
target_width = settings['width_height'][0]
|
168
|
-
rescale = settings['rescale']
|
169
|
-
resample = settings['resample']
|
170
|
-
net_avg = settings['net_avg']
|
171
|
-
invert = settings['invert']
|
172
|
-
circular = settings['circular']
|
173
|
-
percentiles = settings['percentiles']
|
174
|
-
overlay = settings['overlay']
|
175
|
-
grayscale = settings['grayscale']
|
176
|
-
flow_threshold = settings['flow_threshold']
|
177
|
-
batch_size = settings['batch_size']
|
178
|
-
|
179
|
-
dst = os.path.join(src,'masks')
|
180
|
-
os.makedirs(dst, exist_ok=True)
|
181
|
-
|
182
|
-
identify_masks(src, dst, model_name, channels, diameter, batch_size, flow_threshold, cellprob_threshold, figuresize, cmap, verbose, plot, save, custom_model, signal_thresholds, normalize, resize, target_height, target_width, rescale, resample, net_avg, invert, circular, percentiles, overlay, grayscale)
|
183
|
-
|
184
|
-
@log_function_call
|
185
83
|
def train_cellpose(settings):
|
186
84
|
|
187
85
|
from .io import _load_normalized_images_and_labels, _load_images_and_labels
|
188
86
|
from .utils import resize_images_and_labels
|
189
87
|
|
190
88
|
img_src = settings['img_src']
|
191
|
-
mask_src=
|
192
|
-
secondary_image_dir = None
|
193
|
-
model_name = settings['model_name']
|
194
|
-
model_type = settings['model_type']
|
195
|
-
learning_rate = settings['learning_rate']
|
196
|
-
weight_decay = settings['weight_decay']
|
197
|
-
batch_size = settings['batch_size']
|
198
|
-
n_epochs = settings['n_epochs']
|
199
|
-
verbose = settings['verbose']
|
200
|
-
signal_thresholds = settings['signal_thresholds']
|
201
|
-
channels = settings['channels']
|
202
|
-
from_scratch = settings['from_scratch']
|
203
|
-
diameter = settings['diameter']
|
204
|
-
resize = settings['resize']
|
205
|
-
rescale = settings['rescale']
|
206
|
-
normalize = settings['normalize']
|
207
|
-
target_height = settings['width_height'][1]
|
208
|
-
target_width = settings['width_height'][0]
|
209
|
-
circular = settings['circular']
|
210
|
-
invert = settings['invert']
|
211
|
-
percentiles = settings['percentiles']
|
212
|
-
grayscale = settings['grayscale']
|
89
|
+
mask_src = os.path.join(img_src, 'masks')
|
213
90
|
|
91
|
+
model_name = settings.setdefault( 'model_name', '')
|
92
|
+
|
93
|
+
model_name = settings.setdefault('model_name', 'model_name')
|
94
|
+
|
95
|
+
model_type = settings.setdefault( 'model_type', 'cyto')
|
96
|
+
learning_rate = settings.setdefault( 'learning_rate', 0.01)
|
97
|
+
weight_decay = settings.setdefault( 'weight_decay', 1e-05)
|
98
|
+
batch_size = settings.setdefault( 'batch_size', 50)
|
99
|
+
n_epochs = settings.setdefault( 'n_epochs', 100)
|
100
|
+
from_scratch = settings.setdefault( 'from_scratch', False)
|
101
|
+
diameter = settings.setdefault( 'diameter', 40)
|
102
|
+
|
103
|
+
remove_background = settings.setdefault( 'remove_background', False)
|
104
|
+
background = settings.setdefault( 'background', 100)
|
105
|
+
Signal_to_noise = settings.setdefault( 'Signal_to_noise', 10)
|
106
|
+
verbose = settings.setdefault( 'verbose', False)
|
107
|
+
|
108
|
+
|
109
|
+
channels = settings.setdefault( 'channels', [0,0])
|
110
|
+
normalize = settings.setdefault( 'normalize', True)
|
111
|
+
percentiles = settings.setdefault( 'percentiles', None)
|
112
|
+
circular = settings.setdefault( 'circular', False)
|
113
|
+
invert = settings.setdefault( 'invert', False)
|
114
|
+
resize = settings.setdefault( 'resize', False)
|
115
|
+
|
116
|
+
if resize:
|
117
|
+
target_height = settings['width_height'][1]
|
118
|
+
target_width = settings['width_height'][0]
|
119
|
+
|
120
|
+
grayscale = settings.setdefault( 'grayscale', True)
|
121
|
+
rescale = settings.setdefault( 'channels', False)
|
122
|
+
test = settings.setdefault( 'test', False)
|
123
|
+
|
124
|
+
if test:
|
125
|
+
test_img_src = os.path.join(os.path.dirname(img_src), 'test')
|
126
|
+
test_mask_src = os.path.join(test_img_src, 'mask')
|
127
|
+
|
128
|
+
test_images, test_masks, test_image_names, test_mask_names = None,None,None,None,
|
214
129
|
print(settings)
|
215
130
|
|
216
131
|
if from_scratch:
|
217
132
|
model_name=f'scratch_{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
|
218
133
|
else:
|
219
|
-
|
134
|
+
if resize:
|
135
|
+
model_name=f'{model_name}_{model_type}_e{n_epochs}_X{target_width}_Y{target_height}.CP_model'
|
136
|
+
else:
|
137
|
+
model_name=f'{model_name}_{model_type}_e{n_epochs}.CP_model'
|
220
138
|
|
221
139
|
model_save_path = os.path.join(mask_src, 'models', 'cellpose_model')
|
222
|
-
|
140
|
+
print(model_save_path)
|
141
|
+
os.makedirs(model_save_path, exist_ok=True)
|
223
142
|
|
224
143
|
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
225
144
|
settings_csv = os.path.join(model_save_path,f'{model_name}_settings.csv')
|
226
145
|
settings_df.to_csv(settings_csv, index=False)
|
227
146
|
|
228
|
-
if
|
229
|
-
|
230
|
-
|
231
|
-
else:
|
232
|
-
model = cp_models.CellposeModel(gpu=True, model_type=model_type, net_avg=False, diam_mean=diameter, pretrained_model=None)
|
233
|
-
if model_type !='cyto':
|
147
|
+
if from_scratch:
|
148
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_type, diam_mean=diameter, pretrained_model=None)
|
149
|
+
else:
|
234
150
|
model = cp_models.CellposeModel(gpu=True, model_type=model_type)
|
235
151
|
|
236
|
-
|
237
|
-
|
238
|
-
|
239
|
-
|
152
|
+
if normalize:
|
153
|
+
|
154
|
+
image_files = [os.path.join(img_src, f) for f in os.listdir(img_src) if f.endswith('.tif')]
|
155
|
+
label_files = [os.path.join(mask_src, f) for f in os.listdir(mask_src) if f.endswith('.tif')]
|
156
|
+
images, masks, image_names, mask_names = _load_normalized_images_and_labels(image_files, label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
|
240
157
|
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
158
|
+
|
159
|
+
if test:
|
160
|
+
test_image_files = [os.path.join(test_img_src, f) for f in os.listdir(test_img_src) if f.endswith('.tif')]
|
161
|
+
test_label_files = [os.path.join(test_mask_src, f) for f in os.listdir(test_mask_src) if f.endswith('.tif')]
|
162
|
+
test_images, test_masks, test_image_names, test_mask_names = _load_normalized_images_and_labels(test_image_files, test_label_files, channels, percentiles, circular, invert, verbose, remove_background, background, Signal_to_noise)
|
163
|
+
test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
|
164
|
+
|
241
165
|
else:
|
242
166
|
images, masks, image_names, mask_names = _load_images_and_labels(img_src, mask_src, circular, invert)
|
243
167
|
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
168
|
+
|
169
|
+
if test:
|
170
|
+
test_images, test_masks, test_image_names, test_mask_names = _load_images_and_labels(img_src=test_img_src, mask_src=test_mask_src, circular=circular, invert=invert)
|
171
|
+
test_images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in test_images]
|
244
172
|
|
245
173
|
if resize:
|
246
174
|
images, masks = resize_images_and_labels(images, masks, target_height, target_width, show_example=True)
|
@@ -259,29 +187,44 @@ def train_cellpose(settings):
|
|
259
187
|
|
260
188
|
print(f'image shape: {images[0].shape}, image type: images[0].shape mask shape: {masks[0].shape}, image type: masks[0].shape')
|
261
189
|
save_every = int(n_epochs/10)
|
262
|
-
|
263
|
-
|
264
|
-
|
265
|
-
model.
|
266
|
-
|
267
|
-
|
268
|
-
|
269
|
-
|
270
|
-
|
271
|
-
|
272
|
-
|
273
|
-
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
190
|
+
if save_every < 10:
|
191
|
+
save_every = n_epochs
|
192
|
+
|
193
|
+
train.train_seg(model.net,
|
194
|
+
train_data=images,
|
195
|
+
train_labels=masks,
|
196
|
+
train_files=image_names,
|
197
|
+
train_labels_files=mask_names,
|
198
|
+
train_probs=None,
|
199
|
+
test_data=test_images,
|
200
|
+
test_labels=test_masks,
|
201
|
+
test_files=test_image_names,
|
202
|
+
test_labels_files=test_mask_names,
|
203
|
+
test_probs=None,
|
204
|
+
load_files=True,
|
205
|
+
batch_size=batch_size,
|
206
|
+
learning_rate=learning_rate,
|
207
|
+
n_epochs=n_epochs,
|
208
|
+
weight_decay=weight_decay,
|
209
|
+
momentum=0.9,
|
210
|
+
SGD=False,
|
211
|
+
channels=cp_channels,
|
212
|
+
channel_axis=None,
|
213
|
+
#rgb=False,
|
214
|
+
normalize=False,
|
215
|
+
compute_flows=False,
|
216
|
+
save_path=model_save_path,
|
217
|
+
save_every=save_every,
|
218
|
+
nimg_per_epoch=None,
|
219
|
+
nimg_test_per_epoch=None,
|
220
|
+
rescale=rescale,
|
221
|
+
#scale_range=None,
|
222
|
+
#bsize=224,
|
223
|
+
min_train_masks=1,
|
224
|
+
model_name=model_name)
|
281
225
|
|
282
226
|
return print(f"Model saved at: {model_save_path}/{model_name}")
|
283
227
|
|
284
|
-
@log_function_call
|
285
228
|
def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', transform=None, min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, min_frequency=0.0,remove_outlier_genes=False, refine_model=False,by_plate=False, regression_type='mlr', alpha_value=0.01, fishers=False, fisher_threshold=0.9):
|
286
229
|
|
287
230
|
from .plot import _reg_v_plot
|
@@ -430,7 +373,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', dv_col='pred', t
|
|
430
373
|
|
431
374
|
return result
|
432
375
|
|
433
|
-
@log_function_call
|
434
376
|
def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=50, min_reads=100, min_wells=2, max_wells=1000, remove_outlier_genes=False, refine_model=False, by_plate=False, threshold=0.5, fishers=False):
|
435
377
|
|
436
378
|
from .plot import _reg_v_plot
|
@@ -609,7 +551,6 @@ def analyze_data_reg(sequencing_loc, dv_loc, agg_type = 'mean', min_cell_count=5
|
|
609
551
|
|
610
552
|
return max_effects, max_effects_pvalues, model, df
|
611
553
|
|
612
|
-
@log_function_call
|
613
554
|
def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wells=0, model_type = 'mlr', min_cells=100, transform='logit', min_frequency=0.05, gene_column='gene', effect_size_threshold=0.25, fishers=True, clean_regression=False, VIF_threshold=10):
|
614
555
|
|
615
556
|
from .utils import generate_fraction_map, fishers_odds, model_metrics, check_multicollinearity
|
@@ -777,7 +718,6 @@ def regression_analasys(dv_df,sequencing_loc, min_reads=75, min_wells=2, max_wel
|
|
777
718
|
|
778
719
|
return
|
779
720
|
|
780
|
-
@log_function_call
|
781
721
|
def merge_pred_mes(src,
|
782
722
|
pred_loc,
|
783
723
|
target='protein of interest',
|
@@ -846,15 +786,6 @@ def merge_pred_mes(src,
|
|
846
786
|
|
847
787
|
if verbose:
|
848
788
|
_plot_histograms_and_stats(df=joined_df)
|
849
|
-
|
850
|
-
#dv = joined_df.copy()
|
851
|
-
#if 'prc' not in dv.columns:
|
852
|
-
#dv['prc'] = dv['plate'] + '_' + dv['row'] + '_' + dv['col']
|
853
|
-
#dv = dv[['pred']].groupby('prc').mean()
|
854
|
-
#dv.set_index('prc', inplace=True)
|
855
|
-
|
856
|
-
#loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/dv.csv'
|
857
|
-
#dv.to_csv(loc, index=True, header=True, mode='w')
|
858
789
|
|
859
790
|
return joined_df
|
860
791
|
|
@@ -941,30 +872,38 @@ def annotate_results(pred_loc):
|
|
941
872
|
display(df)
|
942
873
|
return df
|
943
874
|
|
944
|
-
def generate_dataset(src,
|
875
|
+
def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
|
945
876
|
|
946
|
-
from .utils import
|
947
|
-
|
948
|
-
db_path = os.path.join(src, 'measurements','measurements.db')
|
877
|
+
from .utils import initiate_counter, add_images_to_tar
|
878
|
+
|
879
|
+
db_path = os.path.join(src, 'measurements', 'measurements.db')
|
949
880
|
dst = os.path.join(src, 'datasets')
|
950
|
-
|
951
|
-
global total_images
|
952
881
|
all_paths = []
|
953
|
-
|
882
|
+
|
954
883
|
# Connect to the database and retrieve the image paths
|
955
884
|
print(f'Reading DataBase: {db_path}')
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
885
|
+
try:
|
886
|
+
with sqlite3.connect(db_path) as conn:
|
887
|
+
cursor = conn.cursor()
|
888
|
+
if file_metadata:
|
889
|
+
if isinstance(file_metadata, str):
|
890
|
+
cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
|
891
|
+
else:
|
892
|
+
cursor.execute("SELECT png_path FROM png_list")
|
893
|
+
|
894
|
+
while True:
|
895
|
+
rows = cursor.fetchmany(1000)
|
896
|
+
if not rows:
|
897
|
+
break
|
898
|
+
all_paths.extend([row[0] for row in rows])
|
899
|
+
|
900
|
+
except sqlite3.Error as e:
|
901
|
+
print(f"Database error: {e}")
|
902
|
+
return
|
903
|
+
except Exception as e:
|
904
|
+
print(f"Error: {e}")
|
905
|
+
return
|
906
|
+
|
968
907
|
if isinstance(sample, int):
|
969
908
|
selected_paths = random.sample(all_paths, sample)
|
970
909
|
print(f'Random selection of {len(selected_paths)} paths')
|
@@ -972,23 +911,18 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
|
|
972
911
|
selected_paths = all_paths
|
973
912
|
random.shuffle(selected_paths)
|
974
913
|
print(f'All paths: {len(selected_paths)} paths')
|
975
|
-
|
914
|
+
|
976
915
|
total_images = len(selected_paths)
|
977
|
-
print(f'
|
978
|
-
|
916
|
+
print(f'Found {total_images} images')
|
917
|
+
|
979
918
|
# Create a temp folder in dst
|
980
919
|
temp_dir = os.path.join(dst, "temp_tars")
|
981
920
|
os.makedirs(temp_dir, exist_ok=True)
|
982
921
|
|
983
922
|
# Chunking the data
|
984
|
-
|
985
|
-
|
986
|
-
|
987
|
-
remainder = len(selected_paths) % num_procs
|
988
|
-
else:
|
989
|
-
num_procs = 2
|
990
|
-
chunk_size = len(selected_paths) // 2
|
991
|
-
remainder = 0
|
923
|
+
num_procs = max(2, cpu_count() - 2)
|
924
|
+
chunk_size = len(selected_paths) // num_procs
|
925
|
+
remainder = len(selected_paths) % num_procs
|
992
926
|
|
993
927
|
paths_chunks = []
|
994
928
|
start = 0
|
@@ -998,45 +932,43 @@ def generate_dataset(src, file_type=None, experiment='TSG101_screen', sample=Non
|
|
998
932
|
start = end
|
999
933
|
|
1000
934
|
temp_tar_files = [os.path.join(temp_dir, f'temp_{i}.tar') for i in range(num_procs)]
|
1001
|
-
|
1002
|
-
# Initialize the shared objects
|
1003
|
-
counter_ = Value('i', 0)
|
1004
|
-
lock_ = Lock()
|
1005
935
|
|
1006
|
-
ctx = multiprocessing.get_context('spawn')
|
1007
|
-
|
1008
936
|
print(f'Generating temporary tar files in {dst}')
|
1009
|
-
|
937
|
+
|
938
|
+
# Initialize shared counter and lock
|
939
|
+
counter = Value('i', 0)
|
940
|
+
lock = Lock()
|
941
|
+
|
942
|
+
with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
|
943
|
+
pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
|
944
|
+
|
1010
945
|
# Combine the temporary tar files into a final tar
|
1011
946
|
date_name = datetime.date.today().strftime('%y%m%d')
|
1012
|
-
|
947
|
+
if not file_metadata is None:
|
948
|
+
tar_name = f'{date_name}_{experiment}_{file_metadata}.tar'
|
949
|
+
else:
|
950
|
+
tar_name = f'{date_name}_{experiment}.tar'
|
951
|
+
tar_name = os.path.join(dst, tar_name)
|
1013
952
|
if os.path.exists(tar_name):
|
1014
953
|
number = random.randint(1, 100)
|
1015
|
-
tar_name_2 = f'{date_name}_{experiment}_{
|
1016
|
-
print(f'Warning: {os.path.basename(tar_name)} exists saving as {os.path.basename(tar_name_2)} ')
|
1017
|
-
tar_name = tar_name_2
|
1018
|
-
|
1019
|
-
# Add the counter and lock to the arguments for pool.map
|
954
|
+
tar_name_2 = f'{date_name}_{experiment}_{file_metadata}_{number}.tar'
|
955
|
+
print(f'Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ')
|
956
|
+
tar_name = os.path.join(dst, tar_name_2)
|
957
|
+
|
1020
958
|
print(f'Merging temporary files')
|
1021
|
-
#with Pool(processes=num_procs, initializer=init_globals, initargs=(counter_, lock_)) as pool:
|
1022
|
-
# results = pool.map(add_images_to_tar, zip(paths_chunks, temp_tar_files))
|
1023
959
|
|
1024
|
-
with
|
1025
|
-
|
1026
|
-
|
1027
|
-
|
1028
|
-
|
1029
|
-
|
1030
|
-
|
1031
|
-
t.extract(member, path=dst)
|
1032
|
-
final_tar.add(os.path.join(dst, member.name), arcname=member.name)
|
1033
|
-
os.remove(os.path.join(dst, member.name))
|
1034
|
-
os.remove(tar_path)
|
960
|
+
with tarfile.open(tar_name, 'w') as final_tar:
|
961
|
+
for temp_tar_path in temp_tar_files:
|
962
|
+
with tarfile.open(temp_tar_path, 'r') as temp_tar:
|
963
|
+
for member in temp_tar.getmembers():
|
964
|
+
file_obj = temp_tar.extractfile(member)
|
965
|
+
final_tar.addfile(member, file_obj)
|
966
|
+
os.remove(temp_tar_path)
|
1035
967
|
|
1036
968
|
# Delete the temp folder
|
1037
969
|
shutil.rmtree(temp_dir)
|
1038
|
-
print(f"\nSaved {total_images} images to {
|
1039
|
-
|
970
|
+
print(f"\nSaved {total_images} images to {tar_name}")
|
971
|
+
|
1040
972
|
def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', num_workers=10, verbose=False):
|
1041
973
|
|
1042
974
|
from .io import TarImageDataset, DataLoader
|
@@ -1088,7 +1020,7 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
|
|
1088
1020
|
batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
|
1089
1021
|
prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
|
1090
1022
|
filenames_list.extend(filenames)
|
1091
|
-
print(f'
|
1023
|
+
print(f'batch: {batch_idx}/{len(data_loader)}', end='\r', flush=True)
|
1092
1024
|
|
1093
1025
|
data = {'path':filenames_list, 'pred':prediction_pos_probs}
|
1094
1026
|
df = pd.DataFrame(data, index=None)
|
@@ -1143,7 +1075,6 @@ def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True,
|
|
1143
1075
|
torch.cuda.memory.empty_cache()
|
1144
1076
|
return df
|
1145
1077
|
|
1146
|
-
|
1147
1078
|
def generate_training_data_file_list(src,
|
1148
1079
|
target='protein of interest',
|
1149
1080
|
cell_dim=4,
|
@@ -1272,7 +1203,14 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1272
1203
|
|
1273
1204
|
db_path = os.path.join(src, 'measurements','measurements.db')
|
1274
1205
|
dst = os.path.join(src, 'datasets', 'training')
|
1275
|
-
|
1206
|
+
|
1207
|
+
if os.path.exists(dst):
|
1208
|
+
for i in range(1, 1000):
|
1209
|
+
dst = os.path.join(src, 'datasets', f'training_{i}')
|
1210
|
+
if not os.path.exists(dst):
|
1211
|
+
print(f'Creating new directory for training: {dst}')
|
1212
|
+
break
|
1213
|
+
|
1276
1214
|
if mode == 'annotation':
|
1277
1215
|
class_paths_ls_2 = []
|
1278
1216
|
class_paths_ls = training_dataset_from_annotation(db_path, dst, annotation_column, annotated_classes=annotated_classes)
|
@@ -1283,6 +1221,7 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1283
1221
|
|
1284
1222
|
elif mode == 'metadata':
|
1285
1223
|
class_paths_ls = []
|
1224
|
+
class_len_ls = []
|
1286
1225
|
[df] = _read_db(db_loc=db_path, tables=['png_list'])
|
1287
1226
|
df['metadata_based_class'] = pd.NA
|
1288
1227
|
for i, class_ in enumerate(classes):
|
@@ -1290,7 +1229,18 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1290
1229
|
df.loc[df[metadata_type_by].isin(ls), 'metadata_based_class'] = class_
|
1291
1230
|
|
1292
1231
|
for class_ in classes:
|
1232
|
+
if size == None:
|
1233
|
+
c_s = []
|
1234
|
+
for c in classes:
|
1235
|
+
c_s_t_df = df[df['metadata_based_class'] == c]
|
1236
|
+
c_s.append(len(c_s_t_df))
|
1237
|
+
print(f'Found {len(c_s_t_df)} images for class {c}')
|
1238
|
+
size = min(c_s)
|
1239
|
+
print(f'Using the smallest class size: {size}')
|
1240
|
+
|
1293
1241
|
class_temp_df = df[df['metadata_based_class'] == class_]
|
1242
|
+
class_len_ls.append(len(class_temp_df))
|
1243
|
+
print(f'Found {len(class_temp_df)} images for class {class_}')
|
1294
1244
|
class_paths_temp = random.sample(class_temp_df['png_path'].tolist(), size)
|
1295
1245
|
class_paths_ls.append(class_paths_temp)
|
1296
1246
|
|
@@ -1347,7 +1297,8 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
|
|
1347
1297
|
|
1348
1298
|
return
|
1349
1299
|
|
1350
|
-
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, verbose=False):
|
1300
|
+
def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], num_workers=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], verbose=False):
|
1301
|
+
|
1351
1302
|
"""
|
1352
1303
|
Generate data loaders for training and validation/test datasets.
|
1353
1304
|
|
@@ -1364,16 +1315,40 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1364
1315
|
- pin_memory (bool): Whether to pin memory for faster data transfer.
|
1365
1316
|
- normalize (bool): Whether to normalize the input images.
|
1366
1317
|
- verbose (bool): Whether to print additional information and show images.
|
1318
|
+
- channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
|
1367
1319
|
|
1368
1320
|
Returns:
|
1369
1321
|
- train_loaders (list): List of data loaders for training datasets.
|
1370
1322
|
- val_loaders (list): List of data loaders for validation datasets.
|
1371
1323
|
- plate_names (list): List of plate names (only applicable when train_mode is 'irm').
|
1372
1324
|
"""
|
1373
|
-
|
1325
|
+
|
1374
1326
|
from .io import MyDataset
|
1375
1327
|
from .plot import _imshow
|
1376
|
-
|
1328
|
+
from torchvision import transforms
|
1329
|
+
from torch.utils.data import DataLoader, random_split
|
1330
|
+
from collections import defaultdict
|
1331
|
+
import os
|
1332
|
+
import random
|
1333
|
+
from PIL import Image
|
1334
|
+
from torchvision.transforms import ToTensor
|
1335
|
+
from .utils import SelectChannels
|
1336
|
+
|
1337
|
+
chans = []
|
1338
|
+
|
1339
|
+
if 'r' in channels:
|
1340
|
+
chans.append(1)
|
1341
|
+
if 'g' in channels:
|
1342
|
+
chans.append(2)
|
1343
|
+
if 'b' in channels:
|
1344
|
+
chans.append(3)
|
1345
|
+
|
1346
|
+
channels = chans
|
1347
|
+
|
1348
|
+
if verbose:
|
1349
|
+
print(f'Training a network on channels: {channels}')
|
1350
|
+
print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
|
1351
|
+
|
1377
1352
|
plate_to_filenames = defaultdict(list)
|
1378
1353
|
plate_to_labels = defaultdict(list)
|
1379
1354
|
train_loaders = []
|
@@ -1384,31 +1359,30 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1384
1359
|
transform = transforms.Compose([
|
1385
1360
|
transforms.ToTensor(),
|
1386
1361
|
transforms.CenterCrop(size=(image_size, image_size)),
|
1362
|
+
SelectChannels(channels),
|
1387
1363
|
transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
|
1388
1364
|
else:
|
1389
1365
|
transform = transforms.Compose([
|
1390
1366
|
transforms.ToTensor(),
|
1391
|
-
transforms.CenterCrop(size=(image_size, image_size))
|
1392
|
-
|
1367
|
+
transforms.CenterCrop(size=(image_size, image_size)),
|
1368
|
+
SelectChannels(channels)])
|
1369
|
+
|
1393
1370
|
if mode == 'train':
|
1394
1371
|
data_dir = os.path.join(src, 'train')
|
1395
1372
|
shuffle = True
|
1396
|
-
print(
|
1397
|
-
|
1373
|
+
print('Generating Train and validation datasets')
|
1398
1374
|
elif mode == 'test':
|
1399
1375
|
data_dir = os.path.join(src, 'test')
|
1400
1376
|
val_loaders = []
|
1401
|
-
validation_split=0.0
|
1377
|
+
validation_split = 0.0
|
1402
1378
|
shuffle = True
|
1403
|
-
print(
|
1404
|
-
|
1379
|
+
print('Generating test dataset')
|
1405
1380
|
else:
|
1406
1381
|
print(f'mode:{mode} is not valid, use mode = train or test')
|
1407
1382
|
return
|
1408
|
-
|
1383
|
+
|
1409
1384
|
if train_mode == 'erm':
|
1410
1385
|
data = MyDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
|
1411
|
-
#train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers if num_workers is not None else 0, pin_memory=pin_memory)
|
1412
1386
|
if validation_split > 0:
|
1413
1387
|
train_size = int((1 - validation_split) * len(data))
|
1414
1388
|
val_size = len(data) - train_size
|
@@ -1465,7 +1439,6 @@ def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_
|
|
1465
1439
|
images = images.cpu()
|
1466
1440
|
label_strings = [str(label.item()) for label in labels]
|
1467
1441
|
_imshow(images, label_strings, nrow=20, fontsize=12)
|
1468
|
-
|
1469
1442
|
elif train_mode == 'irm':
|
1470
1443
|
for plate_name, train_loader in zip(plate_names, train_loaders):
|
1471
1444
|
print(f'Plate: {plate_name} with {len(train_loader.dataset)} images')
|
@@ -1584,15 +1557,30 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
|
|
1584
1557
|
df = df.dropna(subset=['condition'])
|
1585
1558
|
print(f'After dropping non-annotated wells: {len(df)} rows')
|
1586
1559
|
files = df['file_name'].tolist()
|
1560
|
+
print(f'found: {len(files)} files')
|
1587
1561
|
files = [item + '.npy' for item in files]
|
1588
1562
|
random.shuffle(files)
|
1589
|
-
|
1563
|
+
|
1564
|
+
_max = 10**100
|
1565
|
+
|
1566
|
+
if cell_size_range is None and nucleus_size_range is None and pathogen_size_range is None:
|
1567
|
+
filter_min_max = None
|
1568
|
+
else:
|
1569
|
+
if cell_size_range is None:
|
1570
|
+
cell_size_range = [0,_max]
|
1571
|
+
if nucleus_size_range is None:
|
1572
|
+
nucleus_size_range = [0,_max]
|
1573
|
+
if pathogen_size_range is None:
|
1574
|
+
pathogen_size_range = [0,_max]
|
1575
|
+
|
1576
|
+
filter_min_max = [[cell_size_range[0],cell_size_range[1]],[nucleus_size_range[0],nucleus_size_range[1]],[pathogen_size_range[0],pathogen_size_range[1]]]
|
1577
|
+
|
1590
1578
|
if plot:
|
1591
1579
|
plot_settings = {'include_noninfected':include_noninfected,
|
1592
1580
|
'include_multiinfected':include_multiinfected,
|
1593
1581
|
'include_multinucleated':include_multinucleated,
|
1594
1582
|
'remove_background':remove_background,
|
1595
|
-
'filter_min_max':
|
1583
|
+
'filter_min_max':filter_min_max,
|
1596
1584
|
'channel_dims':channel_dims,
|
1597
1585
|
'backgrounds':backgrounds,
|
1598
1586
|
'cell_mask_dim':mask_dims[0],
|
@@ -1649,19 +1637,225 @@ def analyze_recruitment(src, metadata_settings, advanced_settings):
|
|
1649
1637
|
cells,wells = _results_to_csv(src, df, df_well)
|
1650
1638
|
return [cells,wells]
|
1651
1639
|
|
1652
|
-
|
1653
|
-
|
1640
|
+
def _merge_cells_sharing_object(labeled_cells, object_pixels, overlap_threshold, require_all):
    """Merge, in place, the cells of ``labeled_cells`` that are covered by one object.

    Args:
        labeled_cells (ndarray): Integer label image of cells; modified in place.
        object_pixels (ndarray): Boolean image selecting one parasite or nucleus.
        overlap_threshold (float): Percentage of the object's pixels a cell must cover.
        require_all (bool): If True every touched cell must exceed the threshold,
            otherwise a single cell exceeding it is enough.
    """
    cell_ids, pixel_counts = np.unique(labeled_cells[object_pixels], return_counts=True)
    on_cell = cell_ids != 0
    cell_ids, pixel_counts = cell_ids[on_cell], pixel_counts[on_cell]
    if cell_ids.size < 2:
        # The object sits on at most one cell: nothing to merge.
        return

    # Share of the object's area that falls on each touched cell.
    overlap_percentages = pixel_counts / np.sum(object_pixels) * 100
    above = overlap_percentages > overlap_threshold
    if above.all() if require_all else above.any():
        # Every touched cell takes the smallest touched label.
        labeled_cells[np.isin(labeled_cells, cell_ids[1:])] = cell_ids[0]


def _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge cells that share a parasite or a nucleus, then absorb nucleus-free cells.

    Three passes are applied to the cell labels:

    1. For every parasite touching more than one cell, all touched cells are
       merged if at least one of them holds more than ``overlap_threshold``
       percent of the parasite's pixels.
    2. For every nucleus touching more than one cell, all touched cells are
       merged if each of them holds more than ``overlap_threshold`` percent of
       the nucleus' pixels.
    3. Every cell that contains no nucleus pixel is merged into the neighbour
       with the longest shared border, provided that border exceeds
       ``perimeter_threshold`` percent of the cell's perimeter.

    All merging happens in the label space returned by ``label(cell_mask)``.
    The previous implementation read IDs from that relabelled image but wrote
    them into the raw ``cell_mask``, whose IDs are unrelated, so the wrong
    cells could be merged. ``cell_mask`` is no longer modified in place.

    Args:
        parasite_mask (ndarray): Mask of parasites.
        cell_mask (ndarray): Mask of cells.
        nuclei_mask (ndarray): Mask of nuclei.
        overlap_threshold (float): The percentage threshold for merging cells based on parasite/nucleus overlap.
        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.

    Returns:
        ndarray: A new, relabelled cell mask with unique labels.
    """
    merged = label(cell_mask)

    # Passes 1 and 2. Overlaps are evaluated against the *current* state of
    # ``merged`` so that successive merges compose (A+B, then B+C => A+B+C).
    # NOTE(review): the parasite pass merges every touched cell as soon as one
    # cell exceeds the threshold (original behaviour) - confirm this is intended.
    for object_mask, require_all in ((parasite_mask, False), (nuclei_mask, True)):
        labeled_objects = label(object_mask)
        for object_id in range(1, int(np.max(labeled_objects)) + 1):
            _merge_cells_sharing_object(merged, labeled_objects == object_id, overlap_threshold, require_all)

    # Pass 3. Relabel first so that regions and IDs are consistent again.
    labeled_cells = label(merged)
    result = labeled_cells.copy()
    for region in regionprops(labeled_cells):
        cell_pixels = labeled_cells == region.label
        if np.any(nuclei_mask[cell_pixels] != 0):
            continue  # the cell owns at least one nucleus pixel

        # A one-pixel dilation reaches into directly adjacent cells; the number
        # of reached pixels per neighbour is used as the shared border length.
        dilated_cell = binary_dilation(cell_pixels, structure=square(3))
        neighbor_ids, shared_borders = np.unique(labeled_cells[dilated_cell], return_counts=True)
        is_neighbor = (neighbor_ids != 0) & (neighbor_ids != region.label)
        neighbor_ids, shared_borders = neighbor_ids[is_neighbor], shared_borders[is_neighbor]
        if neighbor_ids.size == 0:
            continue

        best = int(np.argmax(shared_borders))
        # Cross-multiplied form of "shared / perimeter * 100 > threshold" so a
        # zero perimeter (e.g. single-pixel cell) cannot divide by zero.
        if shared_borders[best] * 100 > perimeter_threshold * region.perimeter:
            result[cell_pixels] = neighbor_ids[best]

    # Final relabel gives every merged cell one unique, consecutive label.
    return label(result)
|
1730
|
+
|
1731
|
+
def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30):
    """
    Merge and relabel the cell masks of every ``.npy`` file in the given folders.

    Files are matched by name across the three folders. Each cell mask is
    passed through ``_merge_cells_based_on_parasite_overlap`` together with its
    parasite and nuclei masks, and the result overwrites the cell mask file.

    All file triplets are validated before the first file is written, so a
    missing counterpart can no longer leave the cell folder half processed.

    Args:
        parasite_folder (str): Path to the folder containing parasite masks.
        cell_folder (str): Path to the folder containing cell masks.
        nuclei_folder (str): Path to the folder containing nuclei masks.
        overlap_threshold (float): The percentage threshold for merging cells based on parasite overlap.
        perimeter_threshold (float): The percentage threshold for merging cells based on shared perimeter.

    Raises:
        ValueError: If the folders hold different numbers of ``.npy`` files, or
            if a parasite mask has no cell or nuclei mask of the same name.
    """

    def _list_npy(folder):
        # Sorted so processing order is deterministic.
        return sorted(f for f in os.listdir(folder) if f.endswith('.npy'))

    parasite_files = _list_npy(parasite_folder)
    cell_files = _list_npy(cell_folder)
    nuclei_files = _list_npy(nuclei_folder)

    # Ensure there are matching files in all folders
    if not (len(parasite_files) == len(cell_files) == len(nuclei_files)):
        raise ValueError("The number of files in the folders do not match.")

    # Validation pass: resolve and check every triplet before any overwrite.
    triplets = []
    for file_name in parasite_files:
        parasite_path = os.path.join(parasite_folder, file_name)
        cell_path = os.path.join(cell_folder, file_name)
        nuclei_path = os.path.join(nuclei_folder, file_name)
        if not (os.path.exists(cell_path) and os.path.exists(nuclei_path)):
            raise ValueError(f"Corresponding cell or nuclei mask file for {file_name} not found.")
        triplets.append((parasite_path, cell_path, nuclei_path))

    # Processing pass: only reached when every triplet is complete.
    for parasite_path, cell_path, nuclei_path in triplets:
        parasite_mask = np.load(parasite_path)
        cell_mask = np.load(cell_path)
        nuclei_mask = np.load(nuclei_path)
        merged_cell_mask = _merge_cells_based_on_parasite_overlap(parasite_mask, cell_mask, nuclei_mask, overlap_threshold, perimeter_threshold)
        # Overwrite the original cell mask file with the merged result
        np.save(cell_path, merged_cell_mask)
|
1768
|
+
|
1769
|
+
def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
    """
    Cluster all mask objects by morphology/intensity and drop minority-cluster objects.

    Every object in every ``.npy`` mask of ``mask_folder`` is described by its
    area, mean intensity (read from ``channel`` of the same-named file in
    ``image_folder``), perimeter and eccentricity. The descriptors of all
    objects are clustered together with KMeans. Each mask file is then
    rewritten keeping only the objects that belong to the cluster that is most
    frequent within that mask.

    Fixes relative to the previous implementation:
      * cluster labels are matched to objects by position, not by
        ``region.label - 1``, so masks with non-contiguous label IDs work;
      * masks without objects are skipped instead of crashing on
        ``argmax`` of an empty array;
      * when no object is found at all the function returns before clustering.

    Args:
        mask_folder (str): Folder with integer label masks saved as ``.npy``; files are overwritten.
        image_folder (str): Folder with image arrays saved under the same file names.
        channel (int): Index into the last axis of each image array.
        batch_size (int): Number of files loaded per batch while measuring.
        n_clusters (int): Number of KMeans clusters.
        plot (bool): If True, show a 2-component PCA scatter of the clustered objects.

    Returns:
        None
    """

    def read_files_in_batches(folder, batch_size=50):
        # Sorted so both passes visit the files in the same order.
        files = sorted(f for f in os.listdir(folder) if f.endswith('.npy'))
        for start in range(0, len(files), batch_size):
            yield files[start:start + batch_size]

    def measure_morphology_and_intensity(mask, image):
        return [{'area': p.area, 'mean_intensity': p.mean_intensity, 'perimeter': p.perimeter, 'eccentricity': p.eccentricity}
                for p in measure.regionprops(mask, intensity_image=image)]

    def to_feature_matrix(properties):
        return np.array([[p['area'], p['mean_intensity'], p['perimeter'], p['eccentricity']] for p in properties])

    def plot_clusters(properties, labels):
        data_2d = PCA(n_components=2).fit_transform(to_feature_matrix(properties))
        plt.scatter(data_2d[:, 0], data_2d[:, 1], c=labels, cmap='viridis')
        plt.xlabel('PCA Component 1')
        plt.ylabel('PCA Component 2')
        plt.title('Object Clustering')
        plt.show()

    # Step 1: Accumulate properties over all files
    all_properties = []
    for batch in read_files_in_batches(mask_folder, batch_size):
        masks = [np.load(os.path.join(mask_folder, file)) for file in batch]
        images = [np.load(os.path.join(image_folder, file))[:, :, channel] for file in batch]
        for mask, image in zip(masks, images):
            all_properties.extend(measure_morphology_and_intensity(mask, image))

    if not all_properties:
        # Nothing to cluster; KMeans would fail on an empty feature matrix.
        print(f'No objects found in {mask_folder}; masks left unchanged')
        return

    # Step 2: Perform clustering on accumulated properties
    kmeans = KMeans(n_clusters=n_clusters, random_state=0).fit(to_feature_matrix(all_properties))
    labels = kmeans.labels_

    if plot:
        # Step 3: Plot clusters using PCA
        plot_clusters(all_properties, labels)

    # Step 4: Remove objects not in the largest cluster and overwrite files in batches.
    # ``labels`` is ordered exactly like the objects met in step 1, so a running
    # offset slices out the cluster IDs belonging to each mask.
    # NOTE(review): "largest cluster" is evaluated per mask (original
    # behaviour), not over the whole data set - confirm this is intended.
    label_index = 0
    for batch in read_files_in_batches(mask_folder, batch_size):
        for file in batch:
            mask_path = os.path.join(mask_folder, file)
            mask = np.load(mask_path)
            regions = measure.regionprops(mask)
            mask_clusters = labels[label_index:label_index + len(regions)]
            label_index += len(regions)
            if len(regions) == 0:
                continue  # empty mask: nothing to remove, file stays as is

            largest_cluster_label = np.bincount(mask_clusters).argmax()
            cleaned_mask = np.zeros_like(mask)
            # regionprops yields regions in the same order as in step 1, so
            # the i-th region pairs with the i-th cluster ID.
            for region, cluster_id in zip(regions, mask_clusters):
                if cluster_id == largest_cluster_label:
                    cleaned_mask[mask == region.label] = region.label
            np.save(mask_path, cleaned_mask)
|
1841
|
+
|
1842
|
+
def preprocess_generate_masks(src, settings={}):
|
1654
1843
|
|
1655
1844
|
from .io import preprocess_img_data, _load_and_concatenate_arrays
|
1656
1845
|
from .plot import plot_merged, plot_arrays
|
1657
|
-
from .utils import _pivot_counts_table
|
1658
|
-
|
1659
|
-
settings =
|
1660
|
-
|
1846
|
+
from .utils import _pivot_counts_table, set_default_settings_preprocess_generate_masks, set_default_plot_merge_settings, check_mask_folder
|
1847
|
+
|
1848
|
+
settings = set_default_settings_preprocess_generate_masks(src, settings)
|
1849
|
+
|
1661
1850
|
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
1662
1851
|
settings_csv = os.path.join(src,'settings','preprocess_generate_masks_settings.csv')
|
1663
1852
|
os.makedirs(os.path.join(src,'settings'), exist_ok=True)
|
1664
1853
|
settings_df.to_csv(settings_csv, index=False)
|
1854
|
+
|
1855
|
+
if not settings['pathogen_channel'] is None:
|
1856
|
+
custom_model_ls = ['toxo_pv_lumen','toxo_cyto']
|
1857
|
+
if settings['pathogen_model'] not in custom_model_ls:
|
1858
|
+
ValueError(f'Pathogen model must be {custom_model_ls} or None')
|
1665
1859
|
|
1666
1860
|
if settings['timelapse']:
|
1667
1861
|
settings['randomize'] = False
|
@@ -1670,24 +1864,50 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
|
1670
1864
|
if not settings['masks']:
|
1671
1865
|
print(f'WARNING: channels for mask generation are defined when preprocess = True')
|
1672
1866
|
|
1673
|
-
if isinstance(settings['merge'], bool):
|
1674
|
-
settings['merge'] = [settings['merge']]*3
|
1675
1867
|
if isinstance(settings['save'], bool):
|
1676
1868
|
settings['save'] = [settings['save']]*3
|
1677
1869
|
|
1870
|
+
if settings['verbose']:
|
1871
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
1872
|
+
settings_df['setting_value'] = settings_df['setting_value'].apply(str)
|
1873
|
+
display(settings_df)
|
1874
|
+
|
1875
|
+
if settings['test_mode']:
|
1876
|
+
print(f'Starting Test mode ...')
|
1877
|
+
|
1678
1878
|
if settings['preprocess']:
|
1679
|
-
preprocess_img_data(settings)
|
1879
|
+
settings, src = preprocess_img_data(settings)
|
1680
1880
|
|
1681
1881
|
if settings['masks']:
|
1682
1882
|
mask_src = os.path.join(src, 'norm_channel_stack')
|
1683
1883
|
if settings['cell_channel'] != None:
|
1684
|
-
|
1884
|
+
if check_mask_folder(src, 'cell_mask_stack'):
|
1885
|
+
generate_cellpose_masks(mask_src, settings, 'cell')
|
1685
1886
|
|
1686
1887
|
if settings['nucleus_channel'] != None:
|
1687
|
-
|
1888
|
+
if check_mask_folder(src, 'nucleus_mask_stack'):
|
1889
|
+
generate_cellpose_masks(mask_src, settings, 'nucleus')
|
1688
1890
|
|
1689
1891
|
if settings['pathogen_channel'] != None:
|
1690
|
-
|
1892
|
+
if check_mask_folder(src, 'pathogen_mask_stack'):
|
1893
|
+
generate_cellpose_masks(mask_src, settings, 'pathogen')
|
1894
|
+
|
1895
|
+
if settings['adjust_cells']:
|
1896
|
+
if settings['pathogen_channel'] != None and settings['cell_channel'] != None and settings['nucleus_channel'] != None:
|
1897
|
+
|
1898
|
+
start = time.time()
|
1899
|
+
cell_folder = os.path.join(mask_src, 'cell_mask_stack')
|
1900
|
+
nuclei_folder = os.path.join(mask_src, 'nucleus_mask_stack')
|
1901
|
+
parasite_folder = os.path.join(mask_src, 'pathogen_mask_stack')
|
1902
|
+
#image_folder = os.path.join(src, 'stack')
|
1903
|
+
|
1904
|
+
#process_masks(cell_folder, image_folder, settings['cell_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
|
1905
|
+
#process_masks(nuclei_folder, image_folder, settings['nucleus_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
|
1906
|
+
#process_masks(parasite_folder, image_folder, settings['pathogen_channel'], settings['batch_size'], n_clusters=2, plot=settings['plot'])
|
1907
|
+
|
1908
|
+
adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30)
|
1909
|
+
stop = time.time()
|
1910
|
+
print(f'Cell mask adjustment: {stop-start} seconds')
|
1691
1911
|
|
1692
1912
|
if os.path.exists(os.path.join(src,'measurements')):
|
1693
1913
|
_pivot_counts_table(db_path=os.path.join(src,'measurements', 'measurements.db'))
|
@@ -1716,59 +1936,110 @@ def preprocess_generate_masks(src, settings={},advanced_settings={}):
|
|
1716
1936
|
overlay_channels = [settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel']]
|
1717
1937
|
overlay_channels = [element for element in overlay_channels if element is not None]
|
1718
1938
|
|
1719
|
-
plot_settings =
|
1720
|
-
|
1721
|
-
|
1722
|
-
|
1723
|
-
|
1724
|
-
|
1725
|
-
|
1726
|
-
|
1727
|
-
|
1728
|
-
|
1729
|
-
|
1730
|
-
'outline_thickness':3,
|
1731
|
-
'outline_color':'gbr',
|
1732
|
-
'overlay_chans':overlay_channels,
|
1733
|
-
'overlay':True,
|
1734
|
-
'normalization_percentiles':[1,99],
|
1735
|
-
'normalize':True,
|
1736
|
-
'print_object_number':True,
|
1737
|
-
'nr':settings['examples_to_plot'],
|
1738
|
-
'figuresize':20,
|
1739
|
-
'cmap':'inferno',
|
1740
|
-
'verbose':False}
|
1939
|
+
plot_settings = set_default_plot_merge_settings()
|
1940
|
+
plot_settings['channel_dims'] = settings['channels']
|
1941
|
+
plot_settings['cell_mask_dim'] = cell_mask_dim
|
1942
|
+
plot_settings['nucleus_mask_dim'] = nucleus_mask_dim
|
1943
|
+
plot_settings['pathogen_mask_dim'] = pathogen_mask_dim
|
1944
|
+
plot_settings['overlay_chans'] = overlay_channels
|
1945
|
+
plot_settings['nr'] = settings['examples_to_plot']
|
1946
|
+
|
1947
|
+
if settings['test_mode'] == True:
|
1948
|
+
plot_settings['nr'] = len(os.path.join(src,'merged'))
|
1949
|
+
|
1741
1950
|
try:
|
1742
1951
|
fig = plot_merged(src=os.path.join(src,'merged'), settings=plot_settings)
|
1743
1952
|
except Exception as e:
|
1744
1953
|
print(f'Failed to plot image mask overly. Error: {e}')
|
1745
1954
|
else:
|
1746
|
-
plot_arrays(src=os.path.join(src,'merged'), figuresize=
|
1955
|
+
plot_arrays(src=os.path.join(src,'merged'), figuresize=settings['figuresize'], cmap=settings['cmap'], nr=settings['examples_to_plot'], normalize=settings['normalize'], q1=1, q2=99)
|
1747
1956
|
|
1748
1957
|
torch.cuda.empty_cache()
|
1749
1958
|
gc.collect()
|
1959
|
+
print("Successfully completed run")
|
1750
1960
|
return
|
1751
1961
|
|
1752
|
-
def identify_masks_finetune(
|
1962
|
+
def identify_masks_finetune(settings):
|
1753
1963
|
|
1754
1964
|
from .plot import print_mask_and_flows
|
1755
1965
|
from .utils import get_files_from_dir, resize_images_and_labels
|
1756
1966
|
from .io import _load_normalized_images_and_labels, _load_images_and_labels
|
1757
1967
|
|
1968
|
+
#User defined settings
|
1969
|
+
src=settings['src']
|
1970
|
+
dst=settings['dst']
|
1971
|
+
|
1972
|
+
|
1973
|
+
settings.setdefault('model_name', 'cyto')
|
1974
|
+
settings.setdefault('custom_model', None)
|
1975
|
+
settings.setdefault('channels', [0,0])
|
1976
|
+
settings.setdefault('background', 100)
|
1977
|
+
settings.setdefault('remove_background', False)
|
1978
|
+
settings.setdefault('Signal_to_noise', 10)
|
1979
|
+
settings.setdefault('CP_prob', 0)
|
1980
|
+
settings.setdefault('diameter', 30)
|
1981
|
+
settings.setdefault('batch_size', 50)
|
1982
|
+
settings.setdefault('flow_threshold', 0.4)
|
1983
|
+
settings.setdefault('save', False)
|
1984
|
+
settings.setdefault('verbose', False)
|
1985
|
+
settings.setdefault('normalize', True)
|
1986
|
+
settings.setdefault('percentiles', None)
|
1987
|
+
settings.setdefault('circular', False)
|
1988
|
+
settings.setdefault('invert', False)
|
1989
|
+
settings.setdefault('resize', False)
|
1990
|
+
settings.setdefault('target_height', None)
|
1991
|
+
settings.setdefault('target_width', None)
|
1992
|
+
settings.setdefault('rescale', False)
|
1993
|
+
settings.setdefault('resample', False)
|
1994
|
+
settings.setdefault('grayscale', True)
|
1995
|
+
|
1996
|
+
|
1997
|
+
model_name=settings['model_name']
|
1998
|
+
custom_model=settings['custom_model']
|
1999
|
+
channels = settings['channels']
|
2000
|
+
background = settings['background']
|
2001
|
+
remove_background=settings['remove_background']
|
2002
|
+
Signal_to_noise = settings['Signal_to_noise']
|
2003
|
+
CP_prob = settings['CP_prob']
|
2004
|
+
diameter=settings['diameter']
|
2005
|
+
batch_size=settings['batch_size']
|
2006
|
+
flow_threshold=settings['flow_threshold']
|
2007
|
+
save=settings['save']
|
2008
|
+
verbose=settings['verbose']
|
2009
|
+
|
2010
|
+
# static settings
|
2011
|
+
normalize = settings['normalize']
|
2012
|
+
percentiles = settings['percentiles']
|
2013
|
+
circular = settings['circular']
|
2014
|
+
invert = settings['invert']
|
2015
|
+
resize = settings['resize']
|
2016
|
+
|
2017
|
+
if resize:
|
2018
|
+
target_height = settings['target_height']
|
2019
|
+
target_width = settings['target_width']
|
2020
|
+
|
2021
|
+
rescale = settings['rescale']
|
2022
|
+
resample = settings['resample']
|
2023
|
+
grayscale = settings['grayscale']
|
2024
|
+
|
2025
|
+
os.makedirs(dst, exist_ok=True)
|
2026
|
+
|
2027
|
+
if not custom_model is None:
|
2028
|
+
if not os.path.exists(custom_model):
|
2029
|
+
print(f'Custom model not found: {custom_model}')
|
2030
|
+
return
|
2031
|
+
|
1758
2032
|
if not torch.cuda.is_available():
|
1759
2033
|
print(f'Torch CUDA is not available, using CPU')
|
1760
2034
|
|
1761
2035
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1762
2036
|
|
1763
2037
|
if custom_model == None:
|
1764
|
-
|
1765
|
-
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
if custom_model != None:
|
1770
|
-
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device, net_avg=False) #Assuming diameter is defined elsewhere
|
1771
|
-
print(f'loaded custom model:{custom_model}')
|
2038
|
+
model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
|
2039
|
+
print(f'Loaded model: {model_name}')
|
2040
|
+
else:
|
2041
|
+
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=custom_model, diam_mean=diameter, device=device)
|
2042
|
+
print("Pretrained Model Loaded:", model.pretrained_model)
|
1772
2043
|
|
1773
2044
|
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [1,0] if model_name == 'cyto' else [2, 0]
|
1774
2045
|
|
@@ -1778,16 +2049,18 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1778
2049
|
print(f'Using channels: {chans} for model of type {model_name}')
|
1779
2050
|
|
1780
2051
|
if verbose == True:
|
1781
|
-
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{
|
2052
|
+
print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{CP_prob}')
|
1782
2053
|
|
1783
|
-
all_image_files =
|
2054
|
+
all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
|
2055
|
+
|
1784
2056
|
random.shuffle(all_image_files)
|
1785
2057
|
|
1786
2058
|
time_ls = []
|
1787
2059
|
for i in range(0, len(all_image_files), batch_size):
|
1788
2060
|
image_files = all_image_files[i:i+batch_size]
|
2061
|
+
|
1789
2062
|
if normalize:
|
1790
|
-
images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None,
|
2063
|
+
images, _, image_names, _ = _load_normalized_images_and_labels(image_files=image_files, label_files=None, channels=channels, percentiles=percentiles, circular=circular, invert=invert, visualize=verbose, remove_background=remove_background, background=background, Signal_to_noise=Signal_to_noise)
|
1791
2064
|
images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
|
1792
2065
|
orig_dims = [(image.shape[0], image.shape[1]) for image in images]
|
1793
2066
|
else:
|
@@ -1805,11 +2078,10 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1805
2078
|
channel_axis=3,
|
1806
2079
|
diameter=diameter,
|
1807
2080
|
flow_threshold=flow_threshold,
|
1808
|
-
cellprob_threshold=
|
2081
|
+
cellprob_threshold=CP_prob,
|
1809
2082
|
rescale=rescale,
|
1810
2083
|
resample=resample,
|
1811
|
-
|
1812
|
-
progress=False)
|
2084
|
+
progress=True)
|
1813
2085
|
|
1814
2086
|
if len(output) == 4:
|
1815
2087
|
mask, flows, _, _ = output
|
@@ -1827,17 +2099,17 @@ def identify_masks_finetune(src, dst, model_name, channels, diameter, batch_size
|
|
1827
2099
|
time_ls.append(duration)
|
1828
2100
|
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
1829
2101
|
print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
|
1830
|
-
if
|
2102
|
+
if verbose:
|
1831
2103
|
if resize:
|
1832
2104
|
stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
|
1833
|
-
print_mask_and_flows(stack, mask, flows, overlay=
|
2105
|
+
print_mask_and_flows(stack, mask, flows, overlay=True)
|
1834
2106
|
if save:
|
2107
|
+
os.makedirs(dst, exist_ok=True)
|
1835
2108
|
output_filename = os.path.join(dst, image_names[file_index])
|
1836
2109
|
cv2.imwrite(output_filename, mask)
|
1837
2110
|
return
|
1838
2111
|
|
1839
|
-
|
1840
|
-
def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
|
2112
|
+
def identify_masks(src, object_type, model_name, batch_size, channels, diameter, minimum_size, maximum_size, filter_intensity, flow_threshold=30, cellprob_threshold=1, figuresize=25, cmap='inferno', refine_masks=True, filter_size=True, filter_dimm=True, remove_border_objects=False, verbose=False, plot=False, merge=False, save=True, start_at=0, file_type='.npz', net_avg=True, resample=True, timelapse=False, timelapse_displacement=None, timelapse_frame_limits=None, timelapse_memory=3, timelapse_remove_transient=False, timelapse_mode='btrack', timelapse_objects='cell'):
|
1841
2113
|
"""
|
1842
2114
|
Identify masks from the source images.
|
1843
2115
|
|
@@ -1885,13 +2157,13 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1885
2157
|
|
1886
2158
|
#Note add logic that handles batches of size 1 as these will break the code batches must all be > 2 images
|
1887
2159
|
gc.collect()
|
1888
|
-
#print('========== generating masks ==========')
|
1889
2160
|
|
1890
2161
|
if not torch.cuda.is_available():
|
1891
2162
|
print(f'Torch CUDA is not available, using CPU')
|
1892
2163
|
|
1893
2164
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
1894
|
-
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
|
2165
|
+
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device)
|
2166
|
+
|
1895
2167
|
if file_type == '.npz':
|
1896
2168
|
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
|
1897
2169
|
else:
|
@@ -1918,9 +2190,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1918
2190
|
|
1919
2191
|
average_sizes = []
|
1920
2192
|
time_ls = []
|
1921
|
-
moving_avg_q1 = 0
|
1922
|
-
moving_avg_q3 = 0
|
1923
|
-
moving_count = 0
|
1924
2193
|
for file_index, path in enumerate(paths):
|
1925
2194
|
|
1926
2195
|
name = os.path.basename(path)
|
@@ -1961,7 +2230,8 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1961
2230
|
if not plot:
|
1962
2231
|
batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
|
1963
2232
|
if batch.size == 0:
|
1964
|
-
print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}'
|
2233
|
+
print(f'Processing: {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}')
|
2234
|
+
#print(f'Processing {file_index}/{len(paths)}: Images/N100pz {batch.shape[0]}', end='\r', flush=True)
|
1965
2235
|
continue
|
1966
2236
|
if batch.max() > 1:
|
1967
2237
|
batch = batch / batch.max()
|
@@ -1976,8 +2246,6 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1976
2246
|
stitch_threshold=0.0
|
1977
2247
|
|
1978
2248
|
cellpose_batch_size = _get_cellpose_batch_size()
|
1979
|
-
|
1980
|
-
model = cellpose.denoise.DenoiseModel(model_type=f"denoise_{model_name}", gpu=True)
|
1981
2249
|
|
1982
2250
|
masks, flows, _, _ = model.eval(x=batch,
|
1983
2251
|
batch_size=cellpose_batch_size,
|
@@ -1989,9 +2257,9 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
1989
2257
|
cellprob_threshold=cellprob_threshold,
|
1990
2258
|
rescale=None,
|
1991
2259
|
resample=resample,
|
1992
|
-
#net_avg=net_avg,
|
1993
2260
|
stitch_threshold=stitch_threshold,
|
1994
2261
|
progress=None)
|
2262
|
+
|
1995
2263
|
print('Masks shape',masks.shape)
|
1996
2264
|
if timelapse:
|
1997
2265
|
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
|
@@ -2015,7 +2283,7 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
2015
2283
|
|
2016
2284
|
else:
|
2017
2285
|
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
|
2018
|
-
mask_stack = _filter_cp_masks(masks, flows,
|
2286
|
+
mask_stack = _filter_cp_masks(masks, flows, filter_size, filter_intensity, minimum_size, maximum_size, remove_border_objects, merge, batch, plot, figuresize)
|
2019
2287
|
_save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
|
2020
2288
|
|
2021
2289
|
if not np.any(mask_stack):
|
@@ -2032,7 +2300,8 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
2032
2300
|
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
2033
2301
|
time_in_min = average_time/60
|
2034
2302
|
time_per_mask = average_time/batch_size
|
2035
|
-
print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2'
|
2303
|
+
print(f'Processing: {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
|
2304
|
+
#print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2', end='\r', flush=True)
|
2036
2305
|
if not timelapse:
|
2037
2306
|
if plot:
|
2038
2307
|
plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap=cmap, nr=batch_size, file_type='.npz')
|
@@ -2046,10 +2315,25 @@ def identify_masks(src, object_type, model_name, batch_size, channels, diameter,
|
|
2046
2315
|
gc.collect()
|
2047
2316
|
return
|
2048
2317
|
|
2049
|
-
|
2318
|
+
def all_elements_match(list1, list2):
    """Return True when every element of ``list1`` is also present in ``list2``.

    Args:
        list1: Elements to look for. A bare string is treated as ONE element
            rather than as a sequence of characters.
        list2: Collection of allowed elements. A bare string is likewise
            treated as a single allowed element.

    Returns:
        bool: True if ``list1`` is a subset of ``list2`` (vacuously True for an
        empty ``list1``), otherwise False.
    """
    # Iterating a string yields its characters, so 'cell' would be compared as
    # 'c', 'e', 'l', 'l' and never match ['cell', 'nucleus', 'pathogen'].
    if isinstance(list1, str):
        list1 = [list1]
    if isinstance(list2, str):
        list2 = [list2]
    return all(element in list2 for element in list1)
|
2321
|
+
|
2322
|
+
def prepare_batch_for_cellpose(batch):
    """Return a float32 copy of ``batch`` with each image scaled to a maximum of 1.

    Every entry along the first axis is treated as one image. An image whose
    maximum exceeds 1 is divided by that maximum; images already within
    ``[.., 1]`` are left untouched.

    Args:
        batch: Array-like whose first axis indexes the images.

    Returns:
        numpy.ndarray: A new float32 array; the input is never modified.
    """
    # np.array(...) always copies, so the caller's data is left intact even
    # when it already is float32 (astype-on-demand would alias it and the
    # normalisation below would then write into the caller's array).
    batch = np.array(batch, dtype=np.float32)

    # Normalize each image independently by its own maximum.
    for index, image in enumerate(batch):
        if image.size == 0:
            # Nothing to scale; .max() would raise on an empty image.
            continue
        peak = image.max()
        if peak > 1:
            batch[index] = image / peak

    return batch
|
2333
|
+
|
2050
2334
|
def generate_cellpose_masks(src, settings, object_type):
|
2051
2335
|
|
2052
|
-
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels
|
2336
|
+
from .utils import _masks_to_masks_stack, _filter_cp_masks, _get_cellpose_batch_size, _get_object_settings, _get_cellpose_channels, _choose_model, mask_object_count, set_default_settings_preprocess_generate_masks
|
2053
2337
|
from .io import _create_database, _save_object_counts_to_database, _check_masks, _get_avg_object_size
|
2054
2338
|
from .timelapse import _npz_to_movie, _btrack_track_cells, _trackpy_track_cells
|
2055
2339
|
from .plot import plot_masks
|
@@ -2057,6 +2341,13 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2057
2341
|
gc.collect()
|
2058
2342
|
if not torch.cuda.is_available():
|
2059
2343
|
print(f'Torch CUDA is not available, using CPU')
|
2344
|
+
|
2345
|
+
settings = set_default_settings_preprocess_generate_masks(src, settings)
|
2346
|
+
|
2347
|
+
if settings['verbose']:
|
2348
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
2349
|
+
settings_df['setting_value'] = settings_df['setting_value'].apply(str)
|
2350
|
+
display(settings_df)
|
2060
2351
|
|
2061
2352
|
figuresize=25
|
2062
2353
|
timelapse = settings['timelapse']
|
@@ -2071,21 +2362,26 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2071
2362
|
|
2072
2363
|
batch_size = settings['batch_size']
|
2073
2364
|
cellprob_threshold = settings[f'{object_type}_CP_prob']
|
2074
|
-
|
2075
|
-
|
2365
|
+
|
2366
|
+
flow_threshold = settings[f'{object_type}_FT']
|
2367
|
+
|
2076
2368
|
object_settings = _get_object_settings(object_type, settings)
|
2077
2369
|
model_name = object_settings['model_name']
|
2078
2370
|
|
2079
|
-
cellpose_channels = _get_cellpose_channels(settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
|
2371
|
+
cellpose_channels = _get_cellpose_channels(src, settings['nucleus_channel'], settings['pathogen_channel'], settings['cell_channel'])
|
2372
|
+
if settings['verbose']:
|
2373
|
+
print(cellpose_channels)
|
2374
|
+
|
2080
2375
|
channels = cellpose_channels[object_type]
|
2081
2376
|
cellpose_batch_size = _get_cellpose_batch_size()
|
2082
|
-
|
2083
2377
|
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
2084
|
-
model = cp_models.Cellpose(gpu=True, model_type=model_name, device=device) #net_avg=net_avg
|
2085
|
-
#dn = denoise.CellposeDenoiseModel(model_type=f"denoise_{model_name}", gpu=True, device=device)
|
2086
2378
|
|
2379
|
+
if object_type == 'pathogen' and not settings['pathogen_model'] is None:
|
2380
|
+
model_name = settings['pathogen_model']
|
2381
|
+
|
2382
|
+
model = _choose_model(model_name, device, object_type=object_type, restore_type=None, object_settings=object_settings)
|
2383
|
+
|
2087
2384
|
chans = [2, 1] if model_name == 'cyto2' else [0,0] if model_name == 'nucleus' else [2,0] if model_name == 'cyto' else [2, 0] if model_name == 'cyto3' else [2, 0]
|
2088
|
-
|
2089
2385
|
paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npz')]
|
2090
2386
|
|
2091
2387
|
count_loc = os.path.dirname(src)+'/measurements/measurements.db'
|
@@ -2094,10 +2390,6 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2094
2390
|
|
2095
2391
|
average_sizes = []
|
2096
2392
|
time_ls = []
|
2097
|
-
moving_avg_q1 = 0
|
2098
|
-
moving_avg_q3 = 0
|
2099
|
-
moving_count = 0
|
2100
|
-
|
2101
2393
|
for file_index, path in enumerate(paths):
|
2102
2394
|
name = os.path.basename(path)
|
2103
2395
|
name, ext = os.path.splitext(name)
|
@@ -2107,17 +2399,31 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2107
2399
|
with np.load(path) as data:
|
2108
2400
|
stack = data['data']
|
2109
2401
|
filenames = data['filenames']
|
2402
|
+
|
2403
|
+
for i, filename in enumerate(filenames):
|
2404
|
+
output_path = os.path.join(output_folder, filename)
|
2405
|
+
|
2406
|
+
if os.path.exists(output_path):
|
2407
|
+
print(f"File {filename} already exists in the output folder. Skipping...")
|
2408
|
+
continue
|
2409
|
+
|
2110
2410
|
if settings['timelapse']:
|
2411
|
+
|
2412
|
+
trackable_objects = ['cell','nucleus','pathogen']
|
2413
|
+
if not all_elements_match(settings['timelapse_objects'], trackable_objects):
|
2414
|
+
print(f'timelapse_objects {settings["timelapse_objects"]} must be a subset of {trackable_objects}')
|
2415
|
+
return
|
2416
|
+
|
2111
2417
|
if len(stack) != batch_size:
|
2112
2418
|
print(f'Changed batch_size:{batch_size} to {len(stack)}, data length:{len(stack)}')
|
2113
|
-
settings['
|
2419
|
+
settings['timelapse_batch_size'] = len(stack)
|
2114
2420
|
batch_size = len(stack)
|
2115
2421
|
if isinstance(timelapse_frame_limits, list):
|
2116
2422
|
if len(timelapse_frame_limits) >= 2:
|
2117
2423
|
stack = stack[timelapse_frame_limits[0]: timelapse_frame_limits[1], :, :, :].astype(stack.dtype)
|
2118
2424
|
filenames = filenames[timelapse_frame_limits[0]: timelapse_frame_limits[1]]
|
2119
2425
|
batch_size = len(stack)
|
2120
|
-
print(f'Cut batch
|
2426
|
+
print(f'Cut batch at indecies: {timelapse_frame_limits}, New batch_size: {batch_size} ')
|
2121
2427
|
|
2122
2428
|
for i in range(0, stack.shape[0], batch_size):
|
2123
2429
|
mask_stack = []
|
@@ -2133,37 +2439,53 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2133
2439
|
if not settings['plot']:
|
2134
2440
|
batch, batch_filenames = _check_masks(batch, batch_filenames, output_folder)
|
2135
2441
|
if batch.size == 0:
|
2136
|
-
print(f'Processing {file_index}/{len(paths)}: Images/
|
2442
|
+
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
|
2137
2443
|
continue
|
2138
|
-
|
2139
|
-
|
2444
|
+
|
2445
|
+
batch = prepare_batch_for_cellpose(batch)
|
2140
2446
|
|
2141
2447
|
if timelapse:
|
2142
|
-
stitch_threshold=100.0
|
2143
2448
|
movie_path = os.path.join(os.path.dirname(src), 'movies')
|
2144
2449
|
os.makedirs(movie_path, exist_ok=True)
|
2145
2450
|
save_path = os.path.join(movie_path, f'timelapse_{object_type}_{name}.mp4')
|
2146
2451
|
_npz_to_movie(batch, batch_filenames, save_path, fps=2)
|
2452
|
+
|
2453
|
+
if settings['verbose']:
|
2454
|
+
print(f'Processing {file_index}/{len(paths)}: Images/npz {batch.shape[0]}')
|
2455
|
+
|
2456
|
+
#cellpose_normalize_dict = {'lowhigh':[0.0,1.0], #pass in normalization values for 0.0 and 1.0 as list [low, high] if None all other keys ignored
|
2457
|
+
# 'sharpen':object_settings['diameter']/4, #recommended to be 1/4-1/8 diameter of cells in pixels
|
2458
|
+
# 'normalize':True, #(if False, all following parameters ignored)
|
2459
|
+
# 'percentile':[2,98], #[perc_low, perc_high]
|
2460
|
+
# 'tile_norm':224, #normalize by tile set to e.g. 100 for normailize window to be 100 px
|
2461
|
+
# 'norm3D':True} #compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
|
2462
|
+
|
2463
|
+
output = model.eval(x=batch,
|
2464
|
+
batch_size=cellpose_batch_size,
|
2465
|
+
normalize=False,
|
2466
|
+
channels=chans,
|
2467
|
+
channel_axis=3,
|
2468
|
+
diameter=object_settings['diameter'],
|
2469
|
+
flow_threshold=flow_threshold,
|
2470
|
+
cellprob_threshold=cellprob_threshold,
|
2471
|
+
rescale=None,
|
2472
|
+
resample=object_settings['resample'])
|
2473
|
+
|
2474
|
+
if len(output) == 4:
|
2475
|
+
masks, flows, _, _ = output
|
2476
|
+
elif len(output) == 3:
|
2477
|
+
masks, flows, _ = output
|
2147
2478
|
else:
|
2148
|
-
|
2149
|
-
#print(batch.shape)
|
2150
|
-
#batch, _, _, _ = dn.eval(x=batch, channels=chans, diameter=object_settings['diameter'])
|
2151
|
-
#batch = np.stack((batch, batch), axis=-1)
|
2152
|
-
#print(f'object: {object_type} chans : {chans} channels : {channels} model: {model_name}')
|
2153
|
-
masks, flows, _, _ = model.eval(x=batch,
|
2154
|
-
batch_size=cellpose_batch_size,
|
2155
|
-
normalize=False,
|
2156
|
-
channels=chans,
|
2157
|
-
channel_axis=3,
|
2158
|
-
diameter=object_settings['diameter'],
|
2159
|
-
flow_threshold=flow_threshold,
|
2160
|
-
cellprob_threshold=cellprob_threshold,
|
2161
|
-
rescale=None,
|
2162
|
-
resample=object_settings['resample'],
|
2163
|
-
stitch_threshold=stitch_threshold)
|
2164
|
-
#progress=None)
|
2479
|
+
raise ValueError(f"Unexpected number of return values from model.eval(). Expected 3 or 4, got {len(output)}")
|
2165
2480
|
|
2166
2481
|
if timelapse:
|
2482
|
+
if settings['plot']:
|
2483
|
+
for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
|
2484
|
+
if idx == 0:
|
2485
|
+
num_objects = mask_object_count(mask)
|
2486
|
+
print(f'Number of objects: {num_objects}')
|
2487
|
+
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2488
|
+
|
2167
2489
|
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_timelapse')
|
2168
2490
|
if object_type in timelapse_objects:
|
2169
2491
|
if timelapse_mode == 'btrack':
|
@@ -2192,35 +2514,54 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2192
2514
|
name=name,
|
2193
2515
|
batch_filenames=batch_filenames,
|
2194
2516
|
object_type=object_type,
|
2195
|
-
|
2517
|
+
masks=masks,
|
2196
2518
|
timelapse_displacement=timelapse_displacement,
|
2197
2519
|
timelapse_memory=timelapse_memory,
|
2198
2520
|
timelapse_remove_transient=timelapse_remove_transient,
|
2199
2521
|
plot=settings['plot'],
|
2200
2522
|
save=settings['save'],
|
2201
|
-
|
2523
|
+
mode=timelapse_mode)
|
2202
2524
|
else:
|
2203
2525
|
mask_stack = _masks_to_masks_stack(masks)
|
2204
|
-
|
2205
2526
|
else:
|
2206
2527
|
_save_object_counts_to_database(masks, object_type, batch_filenames, count_loc, added_string='_before_filtration')
|
2207
|
-
|
2208
|
-
|
2209
|
-
|
2210
|
-
|
2211
|
-
|
2212
|
-
|
2213
|
-
|
2214
|
-
|
2215
|
-
|
2216
|
-
|
2217
|
-
|
2218
|
-
|
2219
|
-
|
2220
|
-
|
2221
|
-
|
2222
|
-
|
2528
|
+
if object_settings['merge'] and not settings['filter']:
|
2529
|
+
mask_stack = _filter_cp_masks(masks=masks,
|
2530
|
+
flows=flows,
|
2531
|
+
filter_size=False,
|
2532
|
+
filter_intensity=False,
|
2533
|
+
minimum_size=object_settings['minimum_size'],
|
2534
|
+
maximum_size=object_settings['maximum_size'],
|
2535
|
+
remove_border_objects=False,
|
2536
|
+
merge=object_settings['merge'],
|
2537
|
+
batch=batch,
|
2538
|
+
plot=settings['plot'],
|
2539
|
+
figuresize=figuresize)
|
2540
|
+
|
2541
|
+
if settings['filter']:
|
2542
|
+
mask_stack = _filter_cp_masks(masks=masks,
|
2543
|
+
flows=flows,
|
2544
|
+
filter_size=object_settings['filter_size'],
|
2545
|
+
filter_intensity=object_settings['filter_intensity'],
|
2546
|
+
minimum_size=object_settings['minimum_size'],
|
2547
|
+
maximum_size=object_settings['maximum_size'],
|
2548
|
+
remove_border_objects=object_settings['remove_border_objects'],
|
2549
|
+
merge=object_settings['merge'],
|
2550
|
+
batch=batch,
|
2551
|
+
plot=settings['plot'],
|
2552
|
+
figuresize=figuresize)
|
2553
|
+
|
2554
|
+
_save_object_counts_to_database(mask_stack, object_type, batch_filenames, count_loc, added_string='_after_filtration')
|
2555
|
+
else:
|
2556
|
+
mask_stack = _masks_to_masks_stack(masks)
|
2223
2557
|
|
2558
|
+
if settings['plot']:
|
2559
|
+
for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
|
2560
|
+
if idx == 0:
|
2561
|
+
num_objects = mask_object_count(mask)
|
2562
|
+
print(f'Number of objects, : {num_objects}')
|
2563
|
+
plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
|
2564
|
+
|
2224
2565
|
if not np.any(mask_stack):
|
2225
2566
|
average_obj_size = 0
|
2226
2567
|
else:
|
@@ -2235,7 +2576,7 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2235
2576
|
average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
|
2236
2577
|
time_in_min = average_time/60
|
2237
2578
|
time_per_mask = average_time/batch_size
|
2238
|
-
print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2'
|
2579
|
+
print(f'Processing {len(paths)} files with {batch_size} imgs: {(file_index+1)*(batch_size+1)}/{(len(paths))*(batch_size+1)}: Time/batch {time_in_min:.3f} min: Time/mask {time_per_mask:.3f}sec: {object_type} size: {overall_average_size:.3f} px2')
|
2239
2580
|
if not timelapse:
|
2240
2581
|
if settings['plot']:
|
2241
2582
|
plot_masks(batch, mask_stack, flows, figuresize=figuresize, cmap='inferno', nr=batch_size)
|
@@ -2247,4 +2588,885 @@ def generate_cellpose_masks(src, settings, object_type):
|
|
2247
2588
|
batch_filenames = []
|
2248
2589
|
gc.collect()
|
2249
2590
|
torch.cuda.empty_cache()
|
2591
|
+
return
|
2592
|
+
|
2593
|
+
def generate_masks_from_imgs(src, model, model_name, batch_size, diameter, cellprob_threshold, flow_threshold, grayscale, save, normalize, channels, percentiles, circular, invert, plot, resize, target_height, target_width, remove_background, background, Signal_to_noise, verbose):
    """Segment every ``.tif`` image in ``src`` with ``model`` and optionally save the masks.

    Masks are written to ``<src>/<model_name>/`` under the name returned by the
    image loader. Images are processed in shuffled order, ``batch_size`` files
    being loaded at a time, and each image is passed to ``model.eval``
    individually.

    Args:
        src (str): Folder containing the ``.tif`` images.
        model: Object exposing a Cellpose-style ``eval`` method.
        model_name (str): Name used for the output folder and to pick the
            Cellpose ``channels`` pair.
        batch_size (int): Number of files loaded per iteration.
        diameter, cellprob_threshold, flow_threshold: Forwarded to ``model.eval``.
        grayscale (bool): If true, ``[0, 0]`` is used as the Cellpose channel pair.
        save (bool): Write each mask with ``cv2.imwrite``.
        normalize (bool): Load images through the normalising loader.
        channels, percentiles, remove_background, background, Signal_to_noise:
            Forwarded to the normalising loader.
        circular, invert: Forwarded to whichever loader is used.
        plot (bool): Show image/mask/flow overlays (also forwarded to the
            normalising loader).
        resize (bool): Resize images to ``target_height`` x ``target_width``
            before segmentation and resize masks back afterwards.
        target_height, target_width: Target size used when ``resize`` is true.
        verbose (bool): Print the Cellpose settings before processing.

    Raises:
        ValueError: If ``model.eval`` returns neither 3 nor 4 values.
    """

    from .io import _load_images_and_labels, _load_normalized_images_and_labels
    from .utils import resize_images_and_labels, resizescikit
    from .plot import print_mask_and_flows

    dst = os.path.join(src, model_name)
    os.makedirs(dst, exist_ok=True)

    # Cellpose channel pair per model. The nuclear model is called 'nuclei' by
    # check_cellpose_models, so both spellings must select the nuclear setting.
    if grayscale:
        chans = [0, 0]
    elif model_name == 'cyto2':
        chans = [2, 1]
    elif model_name in ('nucleus', 'nuclei'):
        chans = [0, 0]
    elif model_name == 'cyto':
        chans = [1, 0]
    else:
        chans = [2, 0]

    all_image_files = [os.path.join(src, f) for f in os.listdir(src) if f.endswith('.tif')]
    random.shuffle(all_image_files)

    if verbose:
        print(f'Cellpose settings: Model: {model_name}, channels: {channels}, cellpose_chans: {chans}, diameter:{diameter}, flow_threshold:{flow_threshold}, cellprob_threshold:{cellprob_threshold}')

    time_ls = []
    for i in range(0, len(all_image_files), batch_size):
        image_files = all_image_files[i:i+batch_size]

        if normalize:
            images, _, image_names, _ = _load_normalized_images_and_labels(image_files, None, channels, percentiles, circular, invert, plot, remove_background, background, Signal_to_noise)
        else:
            images, _, image_names, _ = _load_images_and_labels(image_files, None, circular, invert)

        # Drop a trailing singleton channel and remember the native size so
        # masks can be mapped back after an optional resize.
        images = [np.squeeze(img) if img.shape[-1] == 1 else img for img in images]
        orig_dims = [(image.shape[0], image.shape[1]) for image in images]

        if resize:
            images, _ = resize_images_and_labels(images, None, target_height, target_width, True)

        for file_index, stack in enumerate(images):
            start = time.time()
            # NOTE(review): channel_axis=3 is passed although each `stack` is a
            # single image; confirm against the Cellpose eval documentation.
            output = model.eval(x=stack,
                                normalize=False,
                                channels=chans,
                                channel_axis=3,
                                diameter=diameter,
                                flow_threshold=flow_threshold,
                                cellprob_threshold=cellprob_threshold,
                                rescale=False,
                                resample=False,
                                progress=False)

            # eval returns either 3 or 4 values depending on the model object.
            if len(output) == 4:
                mask, flows, _, _ = output
            elif len(output) == 3:
                mask, flows, _ = output
            else:
                raise ValueError("Unexpected number of return values from model.eval()")

            if resize:
                dims = orig_dims[file_index]
                # order=0 (nearest neighbour) keeps label values intact.
                mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False).astype(mask.dtype)

            stop = time.time()
            duration = (stop - start)
            time_ls.append(duration)
            average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
            print(f'Processing {file_index+1}/{len(images)} images : Time/image {average_time:.3f} sec', end='\r', flush=True)
            if plot:
                if resize:
                    stack = resizescikit(stack, dims, preserve_range=True, anti_aliasing=False).astype(stack.dtype)
                print_mask_and_flows(stack, mask, flows, overlay=True)
            if save:
                output_filename = os.path.join(dst, image_names[file_index])
                cv2.imwrite(output_filename, mask)
|
2664
|
+
|
2665
|
+
|
2666
|
+
def check_cellpose_models(settings):
    """Run each stock Cellpose model over the images in ``settings['src']``.

    Missing entries of ``settings`` are filled in place with defaults, then
    ``generate_masks_from_imgs`` is called once per model ('cyto', 'nuclei',
    'cyto2', 'cyto3').

    Args:
        settings (dict): Must contain ``'src'``; every other key is optional
            and receives the default listed below when absent.

    Returns:
        None
    """

    src = settings['src']

    # Defaults are applied in this order so the key order of `settings`
    # (visible in the verbose table) stays the same.
    fallback_values = (
        ('batch_size', 10),
        ('CP_prob', 0),
        ('flow_threshold', 0.4),
        ('save', True),
        ('normalize', True),
        ('channels', [0,0]),
        ('percentiles', None),
        ('circular', False),
        ('invert', False),
        ('plot', True),
        ('diameter', 40),
        ('grayscale', True),
        ('remove_background', False),
        ('background', 100),
        ('Signal_to_noise', 5),
        ('verbose', False),
        ('resize', False),
        ('target_height', None),
        ('target_width', None),
    )
    for key, fallback in fallback_values:
        settings.setdefault(key, fallback)

    if settings['verbose']:
        overview = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
        overview['setting_value'] = overview['setting_value'].apply(str)
        display(overview)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Settings forwarded positionally, in the parameter order of
    # generate_masks_from_imgs after (src, model, model_name).
    forwarded_keys = (
        'batch_size', 'diameter', 'CP_prob', 'flow_threshold', 'grayscale',
        'save', 'normalize', 'channels', 'percentiles', 'circular', 'invert',
        'plot', 'resize', 'target_height', 'target_width',
        'remove_background', 'background', 'Signal_to_noise', 'verbose',
    )

    for model_name in ['cyto', 'nuclei', 'cyto2', 'cyto3']:
        model = cp_models.CellposeModel(gpu=True, model_type=model_name, device=device)
        print(f'Using {model_name}')
        forwarded = [settings[key] for key in forwarded_keys]
        generate_masks_from_imgs(src, model, model_name, *forwarded)

    return
|
2704
|
+
|
2705
|
+
def save_results_and_figure(src, fig, results):
    """Write ``results`` as CSV and ``fig`` as PDF into ``<src>/results``.

    Args:
        src (str): Parent folder; the ``results`` sub-folder is created if needed.
        fig: Object with a ``savefig(path, format=...)`` method.
        results: A pandas DataFrame, or anything ``pd.DataFrame`` accepts.

    Returns:
        None
    """

    table = results if isinstance(results, pd.DataFrame) else pd.DataFrame(results)

    out_dir = os.path.join(src, 'results')
    os.makedirs(out_dir, exist_ok=True)

    results_path = os.path.join(out_dir, 'results.csv')
    fig_path = os.path.join(out_dir, 'model_comparison_plot.pdf')

    table.to_csv(results_path, index=False)
    fig.savefig(fig_path, format='pdf')

    print(f'Saved figure to {fig_path} and results to {results_path}')
|
2717
|
+
|
2718
|
+
def compare_mask(args):
    """Compare the masks stored under one file name across several folders.

    Args:
        args (tuple): ``(src, filename, dirs, conditions)``. ``dirs`` are the
            folders to look in and ``conditions`` the label used for each
            folder in the result keys; ``src`` is unpacked but not used here.

    Returns:
        dict or None: ``None`` when the file is missing from any folder.
        Otherwise a dict holding ``'filename'`` plus, for every pair of
        folders, ``jaccard_<a>_<b>``, ``boundary_f1_<a>_<b>`` and ``ap_<a>_<b>``.
    """
    src, filename, dirs, conditions = args
    paths = [os.path.join(folder, filename) for folder in dirs]

    # Skip files that are not present in every folder.
    for candidate in paths:
        if not os.path.exists(candidate):
            return None

    from .io import _read_mask  # Import here to avoid issues in multiprocessing
    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index
    from .plot import plot_comparison_results

    masks = [_read_mask(path) for path in paths]
    file_results = {'filename': filename}

    # Visit every unordered pair (first < second) exactly once.
    for first, mask_a in enumerate(masks):
        for second in range(first + 1, len(masks)):
            mask_b = masks[second]
            pair = f'{conditions[first]}_{conditions[second]}'

            f1_score = boundary_f1_score(mask_a, mask_b)
            jac_index = jaccard_index(mask_a, mask_b)
            ap_score = compute_segmentation_ap(mask_a, mask_b)

            file_results[f'jaccard_{pair}'] = jac_index
            file_results[f'boundary_f1_{pair}'] = f1_score
            file_results[f'ap_{pair}'] = ap_score

    return file_results
|
2746
|
+
|
2747
|
+
def compare_cellpose_masks(src, verbose=False, processes=None, save=True):
    """Compare masks that share a file name across the sub-folders of ``src``.

    Every sub-folder of ``src`` except ``results`` is treated as one condition.
    Files present in all of them are compared pairwise with ``compare_mask``
    in a process pool, then the metrics are plotted and saved through
    ``save_results_and_figure``.

    Args:
        src (str): Folder whose sub-folders hold the masks to compare.
        verbose (bool): Also visualise the masks of every compared file.
        processes (int, optional): Worker count passed to ``Pool``.
        save (bool): Forwarded to ``visualize_cellpose_masks``.

    Returns:
        None
    """
    from .plot import visualize_cellpose_masks, plot_comparison_results
    from .io import _read_mask

    dirs = sorted(
        os.path.join(src, d)
        for d in os.listdir(src)
        if os.path.isdir(os.path.join(src, d)) and d != 'results'
    )
    if not dirs:
        # Without this guard dirs[0] below raises IndexError.
        print(f'No mask directories found in {src}, nothing to compare')
        return
    conditions = [os.path.basename(d) for d in dirs]

    # Get common files in all directories
    common_files = set(os.listdir(dirs[0]))
    for d in dirs[1:]:
        common_files.intersection_update(os.listdir(d))
    # Set iteration order is not stable between runs; sort so the rows of the
    # saved results come out in a reproducible order.
    common_files = sorted(common_files)

    args = [(src, filename, dirs, conditions) for filename in common_files]

    # Create a pool of workers
    with Pool(processes=processes) as pool:
        results = pool.map(compare_mask, args)

    # Filter out None results (from skipped files)
    results = [res for res in results if res is not None]

    if verbose:
        for result in results:
            filename = result['filename']
            masks = [_read_mask(os.path.join(d, filename)) for d in dirs]
            visualize_cellpose_masks(masks, titles=conditions, filename=filename, save=save, src=src)

    fig = plot_comparison_results(results)
    save_results_and_figure(src, fig, results)
    return
|
2778
|
+
|
2779
|
+
def _calculate_similarity(df, features, col_to_compare, val1, val2):
    """
    Calculate similarity scores of each well to the positive and negative controls using various metrics.

    The scores are SciPy distances between each row and the mean feature
    vector of the respective control group. Columns named
    ``similarity_to_pos_<metric>`` / ``similarity_to_neg_<metric>`` are added
    to ``df`` in place.

    The Mahalanobis distance is evaluated entirely in standardized feature
    space (rows, control means and covariance alike); every other metric uses
    the raw feature values.

    Args:
        df (pandas.DataFrame): DataFrame containing the data.
        features (list): List of feature columns to use for similarity calculation.
        col_to_compare (str): Column name to use for comparing groups.
        val1, val2 (str): Values in col_to_compare to create subsets for comparison
            (``val1`` = positive control, ``val2`` = negative control).

    Returns:
        pandas.DataFrame: DataFrame with similarity scores.
    """
    feature_values = df[features]
    is_pos = (df[col_to_compare] == val1).to_numpy()
    is_neg = (df[col_to_compare] == val2).to_numpy()

    # Separate positive and negative control wells
    pos_control = feature_values[is_pos].mean()
    neg_control = feature_values[is_neg].mean()

    # Standardize features for Mahalanobis distance. The inverse covariance
    # below describes the standardized space, so the vectors compared with it
    # must come from that same space.
    scaler = StandardScaler()
    scaled_values = pd.DataFrame(scaler.fit_transform(feature_values),
                                 index=df.index,
                                 columns=feature_values.columns)
    scaled_pos = scaled_values[is_pos].mean()
    scaled_neg = scaled_values[is_neg].mean()

    # np.cov returns a 0-d array for a single feature; keep it 2-D so that
    # inversion and the regularisation fallback both work.
    cov_matrix = np.atleast_2d(np.cov(scaled_values.to_numpy(), rowvar=False))
    try:
        inv_cov_matrix = np.linalg.inv(cov_matrix)
    except np.linalg.LinAlgError:
        # Add a small value to the diagonal elements for regularization
        epsilon = 1e-5
        inv_cov_matrix = np.linalg.inv(cov_matrix + np.eye(cov_matrix.shape[0]) * epsilon)

    def _mahalanobis(u, v):
        return mahalanobis(u, v, inv_cov_matrix)

    def _minkowski_p3(u, v):
        return minkowski(u, v, p=3)

    # (column suffix, distance function, rows, positive reference, negative reference)
    metrics = [
        ('euclidean', euclidean, feature_values, pos_control, neg_control),
        ('cosine', cosine, feature_values, pos_control, neg_control),
        ('mahalanobis', _mahalanobis, scaled_values, scaled_pos, scaled_neg),
        ('manhattan', cityblock, feature_values, pos_control, neg_control),
        ('minkowski', _minkowski_p3, feature_values, pos_control, neg_control),
        ('chebyshev', chebyshev, feature_values, pos_control, neg_control),
        ('hamming', hamming, feature_values, pos_control, neg_control),
        ('jaccard', jaccard, feature_values, pos_control, neg_control),
        ('braycurtis', braycurtis, feature_values, pos_control, neg_control),
    ]

    # Calculate similarity scores
    for name, distance, rows, pos_ref, neg_ref in metrics:
        df[f'similarity_to_pos_{name}'] = rows.apply(lambda row: distance(row, pos_ref), axis=1)
        df[f'similarity_to_neg_{name}'] = rows.apply(lambda row: distance(row, neg_ref), axis=1)

    return df
|
2831
|
+
|
2832
|
+
def _permutation_importance(df, feature_string='channel_3', col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=30, n_estimators=100, test_size=0.2, random_state=42, model_type='xgboost', n_jobs=-1):
    """
    Train a classifier on two control groups and report feature importances.

    Rows where ``col_to_compare`` equals ``pos`` (target 0) or ``neg`` (target 1)
    are split into train/test sets and used to fit the selected model.
    Permutation importance is computed on the training split, the model's
    built-in feature importances are plotted when the fitted model exposes
    them, and the model is then used to predict a class for every row of ``df``.

    Args:
        df (pandas.DataFrame): The DataFrame containing the data.
        feature_string (str or None): 'channel_0' ... 'channel_3' selects that
            channel; any other value (including None and 'morphology') selects
            the 'morphology' feature set.
        col_to_compare (str): Column name to use for comparing groups.
        pos, neg (str): Values in col_to_compare defining the two groups.
        exclude (list or str, optional): Columns to exclude from features
            (forwarded to ``filter_dataframe_features``).
        n_repeats (int): Number of repeats for permutation importance.
        clean (bool): Accepted for interface compatibility; not used here.
        nr_to_plot (int): Number of top features to keep and plot.
        n_estimators (int): Number of trees / boosting iterations.
        test_size (float): Proportion of the control rows used as test split.
        random_state (int): Random seed for reproducibility.
        model_type (str): 'random_forest', 'logistic_regression',
            'gradient_boosting' or 'xgboost'.
        n_jobs (int): Number of parallel jobs for models that accept it.

    Returns:
        list: ``[df, permutation_df, feature_importance_df, model, X_train,
        X_test, y_train, y_test]``. ``df`` has one row per input row with a
        ``predictions`` column and a ``data_usage`` column ('train'/'test' for
        control rows, NaN otherwise). ``feature_importance_df`` is empty when
        the fitted model has no ``feature_importances_`` attribute.

    Raises:
        ValueError: If ``model_type`` is not one of the supported values.
    """

    from .utils import filter_dataframe_features

    if 'cells_per_well' in df.columns:
        df = df.drop(columns=['cells_per_well'])

    # Subset the dataframe based on specified column values
    df1 = df[df[col_to_compare] == pos].copy()
    df2 = df[df[col_to_compare] == neg].copy()

    # Create target variable
    df1['target'] = 0
    df2['target'] = 1

    # Combine the subsets for analysis
    combined_df = pd.concat([df1, df2])

    # Anything that is not an explicit channel falls back to morphology, so
    # channel_of_interest is always bound (the former identity comparison
    # against the literal 'morphology' could leave it undefined).
    if feature_string in ['channel_0', 'channel_1', 'channel_2', 'channel_3']:
        channel_of_interest = int(feature_string.split('_')[-1])
    else:
        channel_of_interest = 'morphology'

    _, features = filter_dataframe_features(combined_df, channel_of_interest, exclude)

    X = combined_df[features]
    y = combined_df['target']

    # Split the data into training and testing sets
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # Label the control rows with the split they ended up in
    combined_df['data_usage'] = 'train'
    combined_df.loc[X_test.index, 'data_usage'] = 'test'

    # Initialize the model based on model_type
    if model_type == 'random_forest':
        model = RandomForestClassifier(n_estimators=n_estimators, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'logistic_regression':
        model = LogisticRegression(max_iter=1000, random_state=random_state, n_jobs=n_jobs)
    elif model_type == 'gradient_boosting':
        # n_jobs is not passed to this estimator
        model = HistGradientBoostingClassifier(max_iter=n_estimators, random_state=random_state)
    elif model_type == 'xgboost':
        model = XGBClassifier(n_estimators=n_estimators, random_state=random_state, nthread=n_jobs, use_label_encoder=False, eval_metric='logloss')
    else:
        raise ValueError(f"Unsupported model_type: {model_type}")

    model.fit(X_train, y_train)

    perm_importance = permutation_importance(model, X_train, y_train, n_repeats=n_repeats, random_state=random_state, n_jobs=n_jobs)

    # Sort once (ascending) and keep the nr_to_plot most important features
    order = perm_importance.importances_mean.argsort()
    permutation_df = pd.DataFrame({
        'feature': [features[i] for i in order],
        'importance_mean': perm_importance.importances_mean[order],
        'importance_std': perm_importance.importances_std[order]
    }).tail(nr_to_plot)

    # Plotting
    fig, ax = plt.subplots()
    ax.barh(permutation_df['feature'], permutation_df['importance_mean'], xerr=permutation_df['importance_std'], color="teal", align="center", alpha=0.6)
    ax.set_xlabel('Permutation Importance')
    plt.tight_layout()
    plt.show()

    # Built-in feature importance, only for fitted models that expose it.
    # HistGradientBoostingClassifier and LogisticRegression do not, so they
    # yield an empty frame instead of raising AttributeError.
    feature_importances = getattr(model, 'feature_importances_', None)
    if feature_importances is not None:
        feature_importance_df = pd.DataFrame({
            'feature': features,
            'importance': feature_importances
        }).sort_values(by='importance', ascending=False).head(nr_to_plot)

        # Plotting feature importance
        fig, ax = plt.subplots()
        ax.barh(feature_importance_df['feature'], feature_importance_df['importance'], color="blue", align="center", alpha=0.6)
        ax.set_xlabel('Feature Importance')
        plt.tight_layout()
        plt.show()
    else:
        feature_importance_df = pd.DataFrame()

    # Predicting the target variable for the test set
    predictions_test = model.predict(X_test)
    combined_df.loc[X_test.index, 'predictions'] = predictions_test

    # Predicting the target variable for the training set
    predictions_train = model.predict(X_train)
    combined_df.loc[X_train.index, 'predictions'] = predictions_train

    # Predicting the target variable for every row in the dataframe
    X_all = df[features]
    all_predictions = model.predict(X_all)
    df['predictions'] = all_predictions

    # Attach the train/test label by index. Joining only the data_usage column
    # keeps one output row per input row; stacking it with the predictions
    # first would give every control index two matches and duplicate the row.
    df = df.join(combined_df[['data_usage']], how='left', rsuffix='_model')

    # Calculating and printing the accuracy metrics
    accuracy = accuracy_score(y_test, predictions_test)
    precision = precision_score(y_test, predictions_test)
    recall = recall_score(y_test, predictions_test)
    f1 = f1_score(y_test, predictions_test)
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1 Score: {f1}")

    # Printing class-specific accuracy metrics
    print("\nClassification Report:")
    print(classification_report(y_test, predictions_test))

    df = _calculate_similarity(df, features, col_to_compare, pos, neg)

    return [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
|
2973
|
+
|
2974
|
+
def _shap_analysis(model, X_train, X_test):
    """
    Run a SHAP explanation of a fitted model and draw the summary plot.

    Args:
        model: The trained model to explain.
        X_train (pandas.DataFrame): Training features, passed to
            ``shap.Explainer`` as its second argument.
        X_test (pandas.DataFrame): Features the SHAP values are computed for
            and plotted against.
    """
    shap_explainer = shap.Explainer(model, X_train)
    test_shap_values = shap_explainer(X_test)

    # Summary plot of the SHAP values for the test features
    shap.summary_plot(test_shap_values, X_test)
|
2990
|
+
|
2991
|
+
def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
    """
    Classify objects from a measurements database and plot a per-plate heatmap.

    Reads the cell, nucleus, pathogen and cytoplasm tables from
    ``<src>/measurements/measurements.db``, runs ``_permutation_importance``
    and ``_shap_analysis`` on the merged data, and passes the scored dataframe
    to ``_plot_plates`` for ``variable``.

    Args:
        src (str): Folder containing ``measurements/measurements.db``.
        model_type (str): Model type forwarded to ``_permutation_importance``.
        variable (str): Numeric column of the scored dataframe to plot.
        grouping, min_max, cmap, min_count: Forwarded to ``_plot_plates``.
        channel_of_interest (int or None): When not None, a ``recruitment``
            column (pathogen / cytoplasm mean intensity of that channel) is
            added and ``channel_<n>`` features are requested; when None the
            feature string passed on is None.
        n_estimators, col_to_compare, pos, neg, exclude, n_repeats, clean,
        nr_to_plot, n_jobs: Forwarded to ``_permutation_importance``.
        verbose (bool): Forwarded to ``_read_and_merge_data``.

    Returns:
        list: ``[output, heatmap]`` where ``output`` is the list returned by
        ``_permutation_importance`` and ``heatmap`` is the ``_plot_plates`` result.

    Raises:
        ValueError: If ``variable`` is not a numeric column of the scored dataframe.
    """
    from .io import _read_and_merge_data
    from .plot import _plot_plates

    merged, _ = _read_and_merge_data(
        [src+'/measurements/measurements.db'],
        ['cell', 'nucleus', 'pathogen','cytoplasm'],
        verbose=verbose,
        include_multinucleated=True,
        include_multiinfected=2.0,
        include_noninfected=True,
    )

    if channel_of_interest is None:
        feature_string = None
    else:
        pathogen_intensity = merged[f'pathogen_channel_{channel_of_interest}_mean_intensity']
        cytoplasm_intensity = merged[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
        merged['recruitment'] = pathogen_intensity/cytoplasm_intensity
        feature_string = f'channel_{channel_of_interest}'

    output = _permutation_importance(merged, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)

    # output layout: [df, permutation_df, feature_importance_df, model, X_train, X_test, y_train, y_test]
    scored_df, fitted_model, train_features, test_features = output[0], output[3], output[4], output[5]
    _shap_analysis(fitted_model, train_features, test_features)

    features = scored_df.select_dtypes(include=[np.number]).columns.tolist()
    if variable not in features:
        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")

    heatmap = _plot_plates(scored_df, variable, grouping, min_max, cmap, min_count)
    return [output, heatmap]
|
3023
|
+
|
3024
|
+
def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
    """
    Merge object measurements with the entries of the ``png_list`` table.

    Args:
        src (str): Folder containing ``measurements/measurements.db``.
        tables (list): Measurement tables handed to ``_read_and_merge_data``.

    Returns:
        pandas.DataFrame: The merged measurements left-joined on ``prcfo``
        with the first table returned by ``_read_db`` for ``png_list``.
    """
    from .io import _read_and_merge_data, _read_db

    measurements_db = src+'/measurements/measurements.db'

    measurements, _ = _read_and_merge_data(
        [measurements_db],
        tables,
        verbose=True,
        include_multinucleated=True,
        include_multiinfected=True,
        include_noninfected=True,
    )

    # _read_db returns one entry per requested table; only png_list is requested
    png_tables = _read_db(measurements_db, tables=['png_list'])

    return pd.merge(measurements, png_tables[0], on='prcfo', how='left')
|
3042
|
+
|
3043
|
+
def jitterplot_by_annotation(src, x_column, y_column, plot_title='Jitter Plot', output_path=None, filter_column=None, filter_values=None):
    """
    Create a jitter plot of one measurement column grouped by an annotation column.

    The data come from ``join_measurments_and_annotation(src)`` (measurements
    joined with the ``png_list`` table), not from a CSV file. Missing values in
    ``x_column`` are labelled 'NaN'. Only rows from wells (``plate_x``,
    ``row_x``, ``col_x``) that contain at least one non-'NaN' ``x_column``
    value are kept, and every ``x_column`` group is down-sampled to the size
    of the smallest group before plotting.

    Args:
        src (str): Path to the source data.
        x_column (str): Name of the column to be used for the x-axis (grouping).
        y_column (str): Name of the column to be used for the y-axis.
        plot_title (str): Title of the plot. Default is 'Jitter Plot'.
        output_path (str): Path to save the plot image. If None, the plot will be displayed. Default is None.
        filter_column (str or list, optional): Column, or list of columns, to filter on.
        filter_values (list, optional): Allowed values for ``filter_column``.
            When ``filter_column`` is a list this must hold one list of
            allowed values per column, in the same order.

    Returns:
        pd.DataFrame: The filtered and balanced DataFrame.

    Raises:
        KeyError: If ``plate_x``, ``row_x`` or ``col_x`` is missing from the data.
    """
    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])

    print(f"Generated dataframe with: {df.shape[1]} columns and {df.shape[0]} rows")

    # Replace NaN values with a specific label in x_column
    df[x_column] = df[x_column].fillna('NaN')

    # Filter the DataFrame if filter_column and filter_values are provided
    if filter_column is not None:
        if isinstance(filter_column, str):
            df = df[df[filter_column].isin(filter_values)]
        elif isinstance(filter_column, list):
            for i, column in enumerate(filter_column):
                df = df[df[column].isin(filter_values[i])]

    # The merge with png_list suffixes the well identifiers with _x
    required_columns = ['plate_x', 'row_x', 'col_x']
    if not all(column in df.columns for column in required_columns):
        raise KeyError(f"DataFrame does not contain the necessary columns: {required_columns}")

    # Keep every row from a well that has at least one annotated (non-'NaN') row.
    # The per-row well key is built once and reused for both sides of the test.
    well_keys = df[required_columns].apply(tuple, axis=1)
    annotated_wells = well_keys[df[x_column] != 'NaN']
    retained_rows = df[well_keys.isin(annotated_wells)]

    # Determine the minimum count of examples across all groups in x_column
    min_count = retained_rows[x_column].value_counts().min()
    print(f'Found {min_count} annotated images')

    # Randomly sample min_count examples from each group in x_column
    balanced_df = retained_rows.groupby(x_column).apply(lambda x: x.sample(min_count, random_state=42)).reset_index(drop=True)

    # Create the jitter plot
    plt.figure(figsize=(10, 6))
    sns.stripplot(data=balanced_df, x=x_column, y=y_column, hue=x_column, jitter=True, palette='viridis', dodge=False)
    plt.title(plot_title)
    plt.xlabel(x_column)
    plt.ylabel(y_column)

    # Rotate the x-axis labels and centre them below the data
    plt.xticks(rotation=45, ha='center')

    # Save the plot to a file or display it
    if output_path:
        plt.savefig(output_path, bbox_inches='tight')
        print(f"Jitter plot saved to {output_path}")
    else:
        plt.show()

    return balanced_df
|
3114
|
+
|
3115
|
+
def generate_image_umap(settings=None):
    """
    Generate UMAP or tSNE embedding and visualize the data with clustering.

    Parameters:
    settings (dict, optional): Settings dictionary, completed by
        ``get_umap_image_settings``. A new dictionary is used when None is
        given. Keys include:
        src (str or list): Source directory (or directories) containing the data.
        row_limit (int): Limit the number of rows to process; capped at the number of available rows.
        tables (list): List of table names to read from the database.
        visualize (str): Visualization type.
        image_nr (int): Number of images to display.
        dot_size (int): Size of dots in the scatter plot.
        n_neighbors (int): Number of neighbors for UMAP.
        figuresize (int): Size of the figure.
        black_background (bool): Whether to use a black background.
        remove_image_canvas (bool): Whether to remove the image canvas.
        plot_outlines (bool): Whether to plot outlines.
        plot_points (bool): Whether to plot points.
        smooth_lines (bool): Whether to smooth lines.
        verbose (bool): Whether to print verbose output.
        embedding_by_controls (bool): Whether to use embedding from controls.
        col_to_compare (str): Column to compare for control-based embedding.
        pos (str): Positive control value.
        neg (str): Negative control value.
        clustering (str): Clustering method ('DBSCAN' or 'KMeans').
        exclude (list): List of columns to exclude from the analysis.
        plot_images (bool): Whether to plot images.
        reduction_method (str): Dimensionality reduction method ('UMAP' or 'tSNE').
        save_figure (bool): Whether to save the figure as a PDF.
        Further keys read here: color_by, remove_cluster_noise, mix,
        exclude_conditions, filter_by, remove_highly_correlated, log_data,
        min_dist, metric, eps, min_samples, n_jobs, resnet_features, img_zoom,
        plot_by_cluster, plot_cluster_grids, analyze_clusters.

    Returns:
    pd.DataFrame: DataFrame with the original data and an additional column 'cluster' containing the cluster identity.
    """

    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, reduction_and_clustering, remove_noise, generate_colors, correct_paths, plot_embedding, plot_clusters_grid, get_umap_image_settings
    from .alpha import cluster_feature_analysis, generate_umap_from_images

    # A fresh dict per call: the settings are modified below, so a shared
    # default dictionary would leak state from one call into the next.
    settings = get_umap_image_settings({} if settings is None else settings)

    if isinstance(settings['src'], str):
        settings['src'] = [settings['src']]

    if settings['plot_images'] is False:
        settings['black_background'] = False

    if settings['color_by']:
        settings['remove_cluster_noise'] = False
        settings['plot_outlines'] = False
        settings['smooth_lines'] = False

    # Persist the effective settings next to the first source folder
    settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
    settings_dir = os.path.join(settings['src'][0],'settings')
    settings_csv = os.path.join(settings_dir,'embedding_settings.csv')
    os.makedirs(settings_dir, exist_ok=True)
    settings_df.to_csv(settings_csv, index=False)
    display(settings_df)

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables'] + ['png_list']

    # Collect the per-database frames and concatenate once
    frames = []
    for i, db_path in enumerate(db_paths):
        df = _read_and_join_tables(db_path, table_names=tables)
        df, _ = correct_paths(df, settings['src'][i])
        frames.append(df)
    all_df = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        # sample() raises when asked for more rows than exist, so cap the limit
        n_rows = min(settings['row_limit'], len(all_df))
        all_df = all_df.sample(n=n_rows, random_state=42)

    image_paths = all_df['png_path'].to_list()

    if settings['embedding_by_controls']:
        compare_col = settings['col_to_compare']

        # Positional copy of the comparison column, to line up with the 0..n-1 index used below
        col_to_compare = all_df[compare_col].reset_index(drop=True)

        # Preprocess the data to obtain numeric data and attach the comparison column
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        numeric_data_df = pd.DataFrame(numeric_data).reset_index(drop=True)
        numeric_data_df[compare_col] = col_to_compare

        # Subset the dataframe based on specified column values for controls (positive rows first)
        positive_control_df = numeric_data_df[numeric_data_df[compare_col] == settings['pos']].copy()
        negative_control_df = numeric_data_df[numeric_data_df[compare_col] == settings['neg']].copy()
        control_numeric_data_df = pd.concat([positive_control_df, negative_control_df])
        control_numeric_data = control_numeric_data_df.drop(columns=[compare_col]).values

        # Train the reducer on control data
        _, _, reducer = reduction_and_clustering(control_numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode='fit', model=False)

        # Apply the trained reducer to the entire dataset
        numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
        embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'], mode=None, model=reducer)

    else:
        if settings['resnet_features']:
            numeric_data, embedding, labels = generate_umap_from_images(image_paths, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['clustering'], settings['eps'], settings['min_samples'], settings['n_jobs'], settings['verbose'])
        else:
            # Fit the reducer on the entire dataset
            numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])
            embedding, labels, _ = reduction_and_clustering(numeric_data, settings['n_neighbors'], settings['min_dist'], settings['metric'], settings['eps'], settings['min_samples'], settings['clustering'], settings['reduction_method'], settings['verbose'], n_jobs=settings['n_jobs'])

    if settings['remove_cluster_noise']:
        # Remove noise from the clusters (removes -1 labels from DBSCAN)
        # NOTE(review): if remove_noise drops points, embedding/labels no longer
        # have one entry per row of all_df/image_paths - confirm against remove_noise.
        embedding, labels = remove_noise(embedding, labels)

    # Colour by a dataframe column instead of by cluster when requested
    if settings['color_by']:
        labels = all_df[settings['color_by']]

    # Generate colors for the clusters
    colors = generate_colors(len(np.unique(labels)), settings['black_background'])

    # Plot the embedding
    plot_grid = settings['plot_cluster_grids'] and settings['plot_images']
    umap_plt = plot_embedding(embedding, image_paths, labels, settings['image_nr'], settings['img_zoom'], colors, settings['plot_by_cluster'], settings['plot_outlines'], settings['plot_points'], settings['plot_images'], settings['smooth_lines'], settings['black_background'], settings['figuresize'], settings['dot_size'], settings['remove_image_canvas'], settings['verbose'])
    if plot_grid:
        grid_plt = plot_clusters_grid(embedding, labels, settings['image_nr'], image_paths, colors, settings['figuresize'], settings['black_background'], settings['verbose'])

    results_dir = os.path.join(settings['src'][0], 'results')
    os.makedirs(results_dir, exist_ok=True)

    # Save figure as PDF if required
    if settings['save_figure']:
        reduction_method = settings['reduction_method'].upper()
        embedding_path = os.path.join(results_dir, f'{reduction_method}_embedding.pdf')
        umap_plt.savefig(embedding_path, format='pdf')
        print(f'Saved {reduction_method} embedding to {embedding_path}')
        if plot_grid:
            grid_path = os.path.join(results_dir, f'{reduction_method}_grid.pdf')
            grid_plt.savefig(grid_path, format='pdf')
            print(f'Saved {reduction_method} grid to {grid_path}')

    # Add cluster labels to the dataframe
    all_df['cluster'] = labels

    # Save the results to a CSV file
    results_csv = os.path.join(results_dir,'embedding_results.csv')
    all_df.to_csv(results_csv, index=False)
    print(f'Results saved to {results_csv}')

    if settings['analyze_clusters']:
        combined_results = cluster_feature_analysis(all_df)
        cluster_results_csv = os.path.join(results_dir,'cluster_results.csv')
        combined_results.to_csv(cluster_results_csv, index=False)
        print(f'Cluster results saved to {cluster_results_csv}')

    return all_df
|
3296
|
+
|
3297
|
+
# Define the mapping function
|
3298
|
+
def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
    """
    Translate a column value into its condition label.

    The value is compared against ``neg``, ``pos`` and ``mix`` in that order;
    the first match gives 'neg', 'pos' or 'mix'. Any other value is 'screen'.
    """
    for control_value, label in ((neg, 'neg'), (pos, 'pos'), (mix, 'mix')):
        if col_value == control_value:
            return label
    return 'screen'
|
3307
|
+
|
3308
|
+
def reducer_hyperparameter_search(settings=None, reduction_params=None, dbscan_params=None, kmeans_params=None, save=False):
    """
    Perform a hyperparameter search for UMAP or tSNE combined with clustering.

    Every combination of one reduction parameter set and one clustering
    parameter set is embedded, clustered and drawn into its own subplot
    (rows = reduction parameter sets, columns = clustering parameter sets).

    Parameters:
    settings (dict or None): Settings dictionary; it is passed through
        get_umap_image_settings before use (presumably to fill in defaults).
        None is treated as an empty dictionary. Keys read by this function:
        src (str): Source directory; databases are located via get_db_paths
            and the plot is saved below '<src>/results'.
        tables (list): Table names to read from each database.
        neg, pos, mix: Values of the 'col' column mapped to the conditions
            'neg', 'pos' and 'mix' (everything else becomes 'screen').
        exclude_conditions (str or list): Conditions to drop; falsy to keep all.
        row_limit (int or None): Maximum number of rows to sample
            (random_state=42); None to use all rows.
        filter_by, remove_highly_correlated, log_data, exclude: Forwarded
            to preprocess_data.
        metric, n_jobs: Forwarded to search_reduction_and_clustering.
        reduction_method (str): 'umap' or 'tsne'; overwritten when it
            disagrees with the method implied by reduction_params.
        color_by (str): Column used to color the points; falsy to color
            by cluster label instead.
        dot_size (int): Size of the points in the scatter plots.
        verbose (bool): Whether to print verbose output.
    reduction_params (dict or list of dict): Hyperparameter sets for the
        reduction method. Sets containing 'n_neighbors' select UMAP, sets
        containing 'perplexity' select tSNE. A float 'n_neighbors' or
        'perplexity' is interpreted as a fraction of the number of rows.
    dbscan_params (dict or list of dict): DBSCAN hyperparameter sets to test.
    kmeans_params (dict or list of dict): KMeans hyperparameter sets to test.
    save (bool): Save the figure as 'hyperparameter_search.pdf' instead of
        showing it.

    Returns:
    None

    Raises:
    ValueError: If reduction_params is empty, if it does not identify
        exactly one reduction method, or if no clustering parameters are
        given.
    """

    from .io import _read_and_join_tables
    from .utils import get_db_paths, preprocess_data, search_reduction_and_clustering, generate_colors, get_umap_image_settings

    # A fresh dict per call avoids sharing state through a mutable default.
    settings = get_umap_image_settings({} if settings is None else settings)
    pointsize = settings['dot_size']

    if isinstance(dbscan_params, dict):
        dbscan_params = [dbscan_params]

    if isinstance(kmeans_params, dict):
        kmeans_params = [kmeans_params]

    if isinstance(reduction_params, dict):
        reduction_params = [reduction_params]

    if not reduction_params:
        raise ValueError("reduction_params must contain at least one dictionary of hyperparameters.")

    # Determine reduction method based on the keys in reduction_params.
    # Both or neither key present is ambiguous, so it is rejected up front.
    has_umap_keys = any('n_neighbors' in param for param in reduction_params)
    has_tsne_keys = any('perplexity' in param for param in reduction_params)
    if has_umap_keys == has_tsne_keys:
        raise ValueError("Reduction parameters must include 'n_neighbors' for UMAP or 'perplexity' for tSNE, not both.")
    reduction_method = 'umap' if has_umap_keys else 'tsne'

    if settings['reduction_method'].lower() != reduction_method:
        settings['reduction_method'] = reduction_method
        print(f'Changed reduction method to {reduction_method} based on the provided parameters.')

    if settings['verbose']:
        display(pd.DataFrame(list(settings.items()), columns=['Key', 'Value']))

    db_paths = get_db_paths(settings['src'])

    tables = settings['tables']
    # Concatenate once instead of growing the frame inside the loop.
    frames = [_read_and_join_tables(db_path, table_names=tables) for db_path in db_paths]
    all_df = pd.concat(frames, axis=0) if frames else pd.DataFrame()

    all_df['cond'] = all_df['col'].apply(map_condition, neg=settings['neg'], pos=settings['pos'], mix=settings['mix'])

    if settings['exclude_conditions']:
        if isinstance(settings['exclude_conditions'], str):
            settings['exclude_conditions'] = [settings['exclude_conditions']]
        row_count_before = len(all_df)
        all_df = all_df[~all_df['cond'].isin(settings['exclude_conditions'])]
        if settings['verbose']:
            print(f'Excluded {row_count_before - len(all_df)} rows after excluding: {settings["exclude_conditions"]}, rows left: {len(all_df)}')

    if settings['row_limit'] is not None:
        # Cap at the available rows; sample() raises when n exceeds the population.
        all_df = all_df.sample(n=min(settings['row_limit'], len(all_df)), random_state=42)

    numeric_data = preprocess_data(all_df, settings['filter_by'], settings['remove_highly_correlated'], settings['log_data'], settings['exclude'])

    # Combine DBSCAN and KMeans parameters. Each set is copied so the
    # caller's dictionaries are not modified by the added 'method' key.
    clustering_params = []
    if dbscan_params:
        clustering_params.extend({**param, 'method': 'dbscan'} for param in dbscan_params)
    if kmeans_params:
        clustering_params.extend({**param, 'method': 'kmeans'} for param in kmeans_params)

    if not clustering_params:
        raise ValueError("At least one of dbscan_params or kmeans_params must be provided.")

    print('Testing paramiters:', reduction_params)
    print('Testing clustering paramiters:', clustering_params)

    # Calculate the grid size
    grid_rows = len(reduction_params)
    grid_cols = len(clustering_params)

    fig_width = grid_cols*10
    fig_height = grid_rows*10

    # squeeze=False always yields a 2D array of axes, whatever the grid shape.
    fig, axs = plt.subplots(grid_rows, grid_cols, figsize=(fig_width, fig_height), squeeze=False)

    # Axis decorations are switched off when the grid has a single row or column.
    hide_axes = grid_rows == 1 or grid_cols == 1

    color_by = settings['color_by']

    # Iterate through the Cartesian product of reduction and clustering hyperparameters
    for i, reduction_param in enumerate(reduction_params):
        for j, clustering_param in enumerate(clustering_params):
            ax = axs[i, j]
            if hide_axes:
                ax.axis('off')

            # Perform dimensionality reduction and clustering
            if settings['reduction_method'].lower() == 'umap':
                n_neighbors = reduction_param.get('n_neighbors', 15)

                # A float is a fraction of the number of samples.
                if isinstance(n_neighbors, float):
                    n_neighbors = int(n_neighbors * len(numeric_data))

                min_dist = reduction_param.get('min_dist', 0.1)
                embedding, labels = search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, settings['metric'],
                                                                    clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
                                                                    clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])

            elif settings['reduction_method'].lower() == 'tsne':
                perplexity = reduction_param.get('perplexity', 30)

                # A float is a fraction of the number of samples.
                if isinstance(perplexity, float):
                    perplexity = int(perplexity * len(numeric_data))

                embedding, labels = search_reduction_and_clustering(numeric_data, perplexity, 0.1, settings['metric'],
                                                                    clustering_param.get('eps', 0.5), clustering_param.get('min_samples', 5),
                                                                    clustering_param['method'], settings['reduction_method'], settings['verbose'], reduction_param, n_jobs=settings['n_jobs'])

            else:
                raise ValueError(f"Unsupported reduction method: {settings['reduction_method']}. Supported methods are 'UMAP' and 'tSNE'")

            # Plot the results
            if color_by:
                # NOTE(review): assumes the embedding rows line up one-to-one
                # with all_df rows - confirm preprocess_data drops no rows.
                group_values = all_df[color_by]
                unique_groups = group_values.unique()
                colors = generate_colors(len(unique_groups), False)
                for group, color in zip(unique_groups, colors):
                    indices = (group_values == group).to_numpy()
                    ax.scatter(embedding[indices, 0], embedding[indices, 1], s=pointsize, label=f"{group}", color=color)
            else:
                unique_labels = np.unique(labels)
                colors = generate_colors(len(unique_labels), False)
                for label, color in zip(unique_labels, colors):
                    members = labels == label
                    ax.scatter(embedding[members, 0], embedding[members, 1], s=pointsize, label=f"Cluster {label}", color=color)

            ax.set_title(f"{settings['reduction_method']} {reduction_param}\n{clustering_param['method']} {clustering_param}")
            ax.legend()

    plt.tight_layout()
    if save:
        results_dir = os.path.join(settings['src'], 'results')
        os.makedirs(results_dir, exist_ok=True)
        plt.savefig(os.path.join(results_dir, 'hyperparameter_search.pdf'))
    else:
        plt.show()

    return
|