spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +7 -5
- spacr/chris.py +50 -0
- spacr/core.py +1301 -426
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/gui.py +144 -0
- spacr/gui_classify_app.py +65 -74
- spacr/gui_mask_app.py +110 -87
- spacr/gui_measure_app.py +104 -81
- spacr/gui_utils.py +276 -31
- spacr/io.py +261 -102
- spacr/mask_app.py +6 -3
- spacr/measure.py +150 -64
- spacr/plot.py +151 -12
- spacr/sim.py +666 -119
- spacr/timelapse.py +139 -9
- spacr/train.py +18 -10
- spacr/utils.py +43 -49
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/METADATA +5 -2
- spacr-0.0.35.dist-info/RECORD +35 -0
- spacr-0.0.35.dist-info/entry_points.txt +8 -0
- spacr-0.0.20.dist-info/RECORD +0 -31
- spacr-0.0.20.dist-info/entry_points.txt +0 -7
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/LICENSE +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/WHEEL +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.35.dist-info}/top_level.txt +0 -0
spacr/alpha.py
CHANGED
@@ -1,18 +1,295 @@
|
|
1
|
-
|
2
|
-
|
3
|
-
|
1
|
+
from skimage import measure, feature
|
2
|
+
from skimage.filters import gabor
|
3
|
+
from skimage.color import rgb2gray
|
4
|
+
from skimage.feature.texture import greycomatrix, greycoprops, local_binary_pattern
|
5
|
+
from skimage.util import img_as_ubyte
|
6
|
+
import numpy as np
|
7
|
+
import pandas as pd
|
8
|
+
from scipy.stats import skew, kurtosis, entropy, hmean, gmean, mode
|
9
|
+
import pywt
|
4
10
|
|
5
|
-
|
11
|
+
import os
|
12
|
+
import pandas as pd
|
13
|
+
from PIL import Image
|
14
|
+
from torchvision import transforms
|
15
|
+
import torch
|
16
|
+
import torch.nn as nn
|
17
|
+
import torch.nn.functional as F
|
18
|
+
from torch_geometric.data import Data, DataLoader
|
19
|
+
from torch_geometric.nn import GCNConv, global_mean_pool
|
20
|
+
from torch.optim import Adam
|
21
|
+
|
22
|
+
# Step 1: Data Preparation
|
23
|
+
|
24
|
+
# Load images
|
25
|
+
def load_images(image_dir):
    """Load every ``.png`` file in *image_dir* (non-recursive) into memory.

    Args:
        image_dir: Directory to scan.

    Returns:
        dict mapping file name -> PIL image with its pixel data already loaded.
        Files are visited in sorted order so the dict order is deterministic
        (``os.listdir`` order is arbitrary).
    """
    images = {}
    for filename in sorted(os.listdir(image_dir)):
        if not filename.endswith(".png"):
            continue
        img = Image.open(os.path.join(image_dir, filename))
        # Image.open is lazy and keeps the file handle open until pixels are read;
        # load() reads them now so Pillow can release the handle instead of leaking
        # one open file per image.
        img.load()
        images[filename] = img
    return images
|
32
|
+
|
33
|
+
# Load sequencing data
|
34
|
+
def load_sequencing_data(seq_file):
    """Read the sequencing table stored as CSV at *seq_file*.

    Args:
        seq_file: Path (or file-like object) accepted by ``pandas.read_csv``.

    Returns:
        pandas.DataFrame with the file's contents.
    """
    return pd.read_csv(seq_file)
|
37
|
+
|
38
|
+
# Step 2: Data Representation
|
39
|
+
|
40
|
+
# Image Representation (Using a simple CNN for feature extraction)
|
41
|
+
class CNNFeatureExtractor(nn.Module):
    """Small two-layer CNN that maps an image batch to a fixed-size embedding.

    Args:
        in_channels: Number of input image channels (default 3).
        embedding_dim: Length of the output feature vector (default 128).

    Input is a ``(batch, in_channels, H, W)`` tensor; output is
    ``(batch, embedding_dim)``. Any spatial size is accepted.
    """

    def __init__(self, in_channels=3, embedding_dim=128):
        super(CNNFeatureExtractor, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 16, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
        # fc expects a 32x8x8 map, which is exactly what a 32x32 input yields after
        # two 2x max-pools. Adaptive pooling guarantees that size for any input
        # resolution (e.g. 64x64), and is the identity for 32x32 inputs. It has no
        # parameters, so existing state_dicts still load.
        self.pool = nn.AdaptiveAvgPool2d((8, 8))
        self.fc = nn.Linear(32 * 8 * 8, embedding_dim)

    def forward(self, x):
        """Return the embedding of image batch *x*."""
        x = F.max_pool2d(F.relu(self.conv1(x)), 2)
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = self.pool(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)
|
56
|
+
|
57
|
+
# Graph Representation
|
58
|
+
def create_graph(wells, sequencing_data):
    """Build a fully connected graph with one node per well.

    Args:
        wells: Iterable of well identifiers. A list or a NumPy array (e.g. the
            result of ``Series.unique()``) both work.
        sequencing_data: DataFrame with at least ``'well'`` and ``'abundance'``
            columns.

    Returns:
        ``Data`` whose ``x`` stacks, per well, the ``abundance`` values of that
        well's rows (in DataFrame order). Its ``edge_index`` has shape ``(2, E)``
        and connects every ordered pair of distinct wells by position in *wells*.

    Note:
        ``torch.stack`` requires every well to contribute the same number of
        abundance rows; otherwise it raises.
    """
    # Materialise as a list. NumPy arrays have no .index(), and positional indices
    # from range() replace the O(n) lookups the nested loop used to do.
    wells = list(wells)

    node_features = [
        torch.tensor(sequencing_data.loc[sequencing_data['well'] == well, 'abundance'].values, dtype=torch.float)
        for well in wells
    ]

    n_nodes = len(wells)
    edges = [(src, dst) for src in range(n_nodes) for dst in range(n_nodes) if src != dst]
    if edges:
        edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
    else:
        # A single well has no edges; keep the (2, 0) shape expected for edge_index.
        edge_index = torch.empty((2, 0), dtype=torch.long)

    x = torch.stack(node_features)
    return Data(x=x, edge_index=edge_index)
|
84
|
+
|
85
|
+
# Step 3: Model Architecture
|
86
|
+
|
87
|
+
class GraphTransformer(nn.Module):
    """Two GCN layers, node self-attention and a graph-level linear head.

    Args:
        in_channels: Size of each input node feature vector.
        hidden_channels: Width of the GCN layers and of the attention embedding;
            must be divisible by ``num_heads``.
        out_channels: Size of the per-graph output vector.
        num_heads: Number of attention heads (default 8).
    """

    def __init__(self, in_channels, hidden_channels, out_channels, num_heads=8):
        super(GraphTransformer, self).__init__()
        self.conv1 = GCNConv(in_channels, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.fc = nn.Linear(hidden_channels, out_channels)
        self.attention = nn.MultiheadAttention(hidden_channels, num_heads=num_heads)

    def forward(self, x, edge_index, batch):
        """Return one ``out_channels`` vector per graph.

        Args:
            x: Node feature matrix ``(num_nodes, in_channels)``.
            edge_index: Graph connectivity passed to the GCN layers.
            batch: Per-node graph index used for pooling. It may be None, in which
                case no attention mask is applied.
        """
        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))

        # With batch_first=False, (num_nodes, 1, hidden) is read as a single
        # sequence of all nodes. Mask out node pairs from different graphs so each
        # node only attends within its own graph (True = not allowed). For a
        # single graph nothing is masked.
        seq = x.unsqueeze(1)
        attn_mask = None
        if batch is not None:
            attn_mask = batch.unsqueeze(0) != batch.unsqueeze(1)
        x, _ = self.attention(seq, seq, seq, attn_mask=attn_mask)
        x = x.squeeze(1)

        x = global_mean_pool(x, batch)
        return self.fc(x)
|
106
|
+
|
107
|
+
# Step 4: Training
|
108
|
+
|
109
|
+
# Training Loop
|
110
|
+
def train(model, data_loader, criterion, optimizer, epochs=10):
    """Train *model* for *epochs* passes over *data_loader*.

    Each batch must expose ``x``, ``edge_index``, ``batch`` and ``y``. The model is
    called as ``model(x, edge_index, batch)`` and scored with
    ``criterion(out, y)``.

    Returns:
        list with the mean batch loss of every epoch (``nan`` for an epoch that
        saw no batches). The same value is printed after each epoch.
    """
    model.train()
    history = []
    for epoch in range(epochs):
        total_loss, n_batches = 0.0, 0
        for data in data_loader:
            optimizer.zero_grad()
            out = model(data.x, data.edge_index, data.batch)
            loss = criterion(out, data.y)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        # Average over the whole epoch rather than reporting only the last batch,
        # and don't reference an unbound `loss` when the loader is empty.
        epoch_loss = total_loss / n_batches if n_batches else float('nan')
        history.append(epoch_loss)
        print(f'Epoch {epoch+1}, Loss: {epoch_loss}')
    return history
|
120
|
+
|
121
|
+
def evaluate(model, data_loader):
    """Compute, print and return the classification accuracy of *model*.

    Each batch must expose ``x``, ``edge_index``, ``batch`` and ``y``. The
    predicted class is the argmax over dimension 1 of the model output.

    Returns:
        Accuracy as a fraction in ``[0, 1]``, or ``nan`` when the loader yields no
        samples (previously a ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data in data_loader:
            out = model(data.x, data.edge_index, data.batch)
            _, predicted = torch.max(out, 1)
            total += data.y.size(0)
            correct += (predicted == data.y).sum().item()
    if total == 0:
        print('Accuracy: no samples to evaluate')
        return float('nan')
    accuracy = correct / total
    print(f'Accuracy: {accuracy * 100:.2f}%')
    return accuracy
|
133
|
+
|
134
|
+
def spacr_transformer(image_dir, seq_file, nr_grnas=1350, lr=0.001, mode='train'):
    """Build a well graph from sequencing data, then train or evaluate a GraphTransformer.

    Args:
        image_dir: Directory of ``.png`` images (loaded, but not yet used by the model).
        seq_file: CSV with at least ``well`` and ``abundance`` columns.
        nr_grnas: Output size of the model.
        lr: Adam learning rate.
        mode: ``'train'`` or ``'eval'``.

    Returns:
        The GraphTransformer instance (trained when ``mode == 'train'``).

    Raises:
        ValueError: if *mode* is invalid, or if the graph carries no ``y`` labels
            (both ``train`` and ``evaluate`` read ``data.y``).
    """
    # Validate before doing any file I/O.
    if mode not in ('train', 'eval'):
        raise ValueError('Invalid mode. Use "train" or "eval".')

    # NOTE(review): images are loaded but not fed to the model yet.
    images = load_images(image_dir)
    sequencing_data = load_sequencing_data(seq_file)
    # tolist(): create_graph indexes wells by position; a plain list works with
    # any implementation, whereas a NumPy array has no .index().
    wells = sequencing_data['well'].unique().tolist()
    graph_data = create_graph(wells, sequencing_data)

    if getattr(graph_data, 'y', None) is None:
        raise ValueError('Graph data has no labels (y); cannot train or evaluate.')

    # Use the real node-feature width instead of assuming it equals nr_grnas.
    model = GraphTransformer(in_channels=graph_data.x.size(-1), hidden_channels=128, out_channels=nr_grnas)
    criterion = nn.CrossEntropyLoss()
    optimizer = Adam(model.parameters(), lr=lr)
    loader = DataLoader([graph_data], batch_size=1, shuffle=True)

    if mode == 'train':
        train(model, loader, criterion, optimizer)
    else:
        evaluate(model, loader)
    return model
|
10
151
|
|
11
|
-
|
12
|
-
|
13
|
-
|
14
|
-
|
152
|
+
def _calculate_glcm_features(intensity_image, distances=(1, 2, 3, 4), angles=(0, np.pi/4, np.pi/2, 3*np.pi/4)):
    """Compute whole-image grey-level co-occurrence (GLCM) texture properties.

    Args:
        intensity_image: 2-D image; converted with ``img_as_ubyte`` before the GLCM.
        distances: Pixel-pair distances (default 1-4).
        angles: Pixel-pair angles in radians (default 0, pi/4, pi/2, 3pi/4).

    Returns:
        dict keyed ``glcm_{prop}_d{distance}_a{angle}`` for each of contrast,
        dissimilarity, homogeneity, energy, correlation and ASM.

    NOTE(review): scikit-image renamed greycomatrix/greycoprops to
    graycomatrix/graycoprops (0.19), and newer releases no longer ship the old
    names. The module-level import may need updating.
    """
    glcm = greycomatrix(img_as_ubyte(intensity_image), distances=list(distances), angles=list(angles), symmetric=True, normed=True)
    features = {}
    for prop in ('contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM'):
        # greycoprops yields a (len(distances), len(angles)) array. Compute it once
        # per property instead of once per (distance, angle) cell.
        values = greycoprops(glcm, prop)
        for i, distance in enumerate(distances):
            for j, angle in enumerate(angles):
                features[f'glcm_{prop}_d{distance}_a{angle}'] = values[i, j]
    return features
|
160
|
+
|
161
|
+
def _calculate_lbp_features(intensity_image, P=8, R=1):
    """Return the normalised histogram of uniform local binary pattern codes.

    Args:
        intensity_image: 2-D image.
        P: Number of circular neighbours.
        R: Radius of the neighbour circle.

    Returns:
        dict ``{'lbp_0': ..., ..., 'lbp_{P+1}': ...}``: the density over the
        ``P + 2`` unit-width bins spanning 0 .. P + 2.
    """
    codes = local_binary_pattern(intensity_image, P, R, method='uniform')
    bin_edges = np.arange(0, P + 3)
    hist = np.histogram(codes, density=True, bins=bin_edges, range=(0, P + 2))[0]
    return dict((f'lbp_{idx}', freq) for idx, freq in enumerate(hist))
|
165
|
+
|
166
|
+
def _calculate_wavelet_features(intensity_image, wavelet='db1', level=3):
    """Summarise a multi-level 2-D discrete wavelet decomposition of an image.

    Args:
        intensity_image: 2-D image array.
        wavelet: PyWavelets wavelet name (default ``'db1'``).
        level: Number of decomposition levels (default 3, the previous fixed value).

    Returns:
        dict with the mean, standard deviation and energy (sum of squares) of every
        coefficient array. Non-tuple entries of the decomposition are keyed
        ``wavelet_coeff_<i>_<stat>``. Tuple entries are keyed
        ``wavelet_coeff_<i>_<j>_<stat>``, where ``j`` indexes the sub-band.
    """
    features = {}

    def _summarise(prefix, arr):
        # Single place for the three statistics so both key families stay in sync.
        features[f'{prefix}_mean'] = np.mean(arr)
        features[f'{prefix}_std'] = np.std(arr)
        features[f'{prefix}_energy'] = np.sum(arr**2)

    coeffs = pywt.wavedec2(intensity_image, wavelet=wavelet, level=level)
    for i, coeff in enumerate(coeffs):
        if isinstance(coeff, tuple):
            for j, subband in enumerate(coeff):
                _summarise(f'wavelet_coeff_{i}_{j}', subband)
        else:
            _summarise(f'wavelet_coeff_{i}', coeff)
    return features
|
180
|
+
|
181
|
+
|
182
|
+
from .measure import _estimate_blur, _calculate_correlation_object_level, _calculate_homogeneity, _periphery_intensity, _outside_intensity, _calculate_radial_distribution, _create_dataframe, _extended_regionprops_table, _calculate_correlation_object_level
|
183
|
+
|
184
|
+
def _additional_intensity_stats(values):
    """Return the extra intensity statistics for the pixels of one labelled object.

    Args:
        values: 1-D array of the object's pixel intensities.

    Returns:
        dict keyed by the ``additional_props`` names used in ``_intensity_measurements``.
    """
    # Work in float64: squaring an integer (e.g. uint8/uint16) array in its own
    # dtype silently wraps around and corrupted energy_intensity.
    values = np.asarray(values, dtype=np.float64)
    centred = values - np.mean(values)
    return {
        "standard_deviation_intensity": np.std(values),
        "median_intensity": np.median(values),
        "sum_intensity": np.sum(values),
        "intensity_range": np.max(values) - np.min(values),
        "mean_absolute_deviation_intensity": np.mean(np.abs(centred)),
        "skewness_intensity": skew(values),
        "kurtosis_intensity": kurtosis(values),
        "variance_intensity": np.var(values),
        # Recent SciPy returns a scalar mode (older versions a length-1 array), so
        # mode(values)[0][0] raises there. ravel() handles both shapes.
        "mode_intensity": np.ravel(mode(values)[0])[0],
        "energy_intensity": np.sum(values ** 2),
        "entropy_intensity": entropy(values),
        "harmonic_mean_intensity": hmean(values),
        "geometric_mean_intensity": gmean(values),
    }


def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=(3, 6, 12, 24), periphery=True, outside=True):
    """Measure per-object intensity, texture and colocalisation features.

    Args:
        cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask: 2-D label images
            (0 = background). A mask whose maximum is 0 contributes an empty table.
        channel_arrays: Image stack indexed as ``[:, :, channel]``.
        settings: dict. Reads ``radial_dist``, ``calculate_correlation``,
            ``homogeneity`` and ``homogeneity_distances``, and is forwarded to
            ``_calculate_correlation_object_level``.
        sizes: Unused; kept for signature compatibility.
        periphery, outside: For nucleus and pathogen objects, add the statistics
            returned by ``_periphery_intensity`` / ``_outside_intensity``.

    Returns:
        Tuple of four DataFrames (cell, nucleus, pathogen, cytoplasm). Each is the
        column-wise concatenation of that class's per-channel tables, plus any
        radial-distribution and colocalisation tables. Per-channel feature columns
        are named ``<class>_channel_<i>_<feature>``; ``label`` stays unprefixed.
    """
    radial_dist = settings['radial_dist']
    calculate_correlation = settings['calculate_correlation']
    homogeneity = settings['homogeneity']
    distances = settings['homogeneity_distances']

    intensity_props = ["label", "centroid_weighted", "centroid_weighted_local", "max_intensity", "mean_intensity", "min_intensity", "integrated_intensity"]
    additional_props = ["standard_deviation_intensity", "median_intensity", "sum_intensity", "intensity_range", "mean_absolute_deviation_intensity", "skewness_intensity", "kurtosis_intensity", "variance_intensity", "mode_intensity", "energy_intensity", "entropy_intensity", "harmonic_mean_intensity", "geometric_mean_intensity"]
    col_lables = ['region_label', 'mean', '5_percentile', '10_percentile', '25_percentile', '50_percentile', '75_percentile', '85_percentile', '95_percentile']
    cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs = [], [], [], []
    ls = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
    labels = [cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask]
    dfs = [cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs]

    for i in range(channel_arrays.shape[-1]):
        channel = channel_arrays[:, :, i]
        # GLCM/LBP/wavelet features describe the whole channel image, not single
        # objects, so they are identical for every mask. Compute them at most once
        # per channel, lazily, so channels whose masks are all empty skip the work.
        texture_features = None

        for j, (label, df) in enumerate(zip(labels, dfs)):
            if np.max(label) == 0:
                df.append(pd.DataFrame())
                continue

            mask_intensity_df = _extended_regionprops_table(label, channel, intensity_props)

            # Extra per-object statistics, in regionprops (label) order. 'label' is
            # not added again: a second 'label' column turned
            # mask_intensity_df['label'] into a DataFrame and broke the blur step.
            additional_df = pd.DataFrame(
                [_additional_intensity_stats(channel[region.coords[:, 0], region.coords[:, 1]])
                 for region in measure.regionprops(label)],
                columns=additional_props,
            )
            mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), additional_df], axis=1)

            if homogeneity:
                homogeneity_df = _calculate_homogeneity(label, channel, distances)
                mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), homogeneity_df], axis=1)

            if ls[j] in ('nucleus', 'pathogen'):
                if periphery:
                    periphery_intensity_stats = _periphery_intensity(label, channel)
                    mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(periphery_intensity_stats, columns=[f'periphery_{stat}' for stat in col_lables])], axis=1)
                if outside:
                    outside_intensity_stats = _outside_intensity(label, channel)
                    mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(outside_intensity_stats, columns=[f'outside_{stat}' for stat in col_lables])], axis=1)

            if texture_features is None:
                texture_features = {
                    **_calculate_glcm_features(channel),
                    **_calculate_lbp_features(channel),
                    **_calculate_wavelet_features(channel),
                }
            # Broadcast the whole-image texture values to every object row. Bare
            # names are used because the renaming step below adds the
            # '<class>_channel_<i>_' prefix; adding it here too doubled it.
            mask_intensity_df = mask_intensity_df.assign(**texture_features)

            object_labels = mask_intensity_df['label']
            if isinstance(object_labels, pd.DataFrame):
                # Defensive: a helper table may also carry a 'label' column.
                object_labels = object_labels.iloc[:, 0]
            mask_intensity_df['blur'] = [_estimate_blur(channel[label == region_label]) for region_label in object_labels]

            mask_intensity_df.columns = [f'{ls[j]}_channel_{i}_{col}' if col != 'label' else col for col in mask_intensity_df.columns]
            df.append(mask_intensity_df)

    if radial_dist:
        if np.max(nucleus_mask) != 0:
            nucleus_radial_distributions = _calculate_radial_distribution(cell_mask, nucleus_mask, channel_arrays, num_bins=6)
            dfs[1].append(_create_dataframe(nucleus_radial_distributions, 'nucleus'))

        # Guard on the pathogen mask itself (this previously re-checked nucleus_mask).
        if np.max(pathogen_mask) != 0:
            pathogen_radial_distributions = _calculate_radial_distribution(cell_mask, pathogen_mask, channel_arrays, num_bins=6)
            dfs[2].append(_create_dataframe(pathogen_radial_distributions, 'pathogen'))

    n_channels = channel_arrays.shape[-1]
    if calculate_correlation and n_channels >= 2:
        for ci in range(n_channels):
            for cj in range(ci + 1, n_channels):
                chan_i = channel_arrays[:, :, ci]
                chan_j = channel_arrays[:, :, cj]
                for m, mask in enumerate(labels):
                    coloc_df = _calculate_correlation_object_level(chan_i, chan_j, mask, settings)
                    coloc_df.columns = [f'{ls[m]}_channel_{ci}_channel_{cj}_{col}' for col in coloc_df.columns]
                    dfs[m].append(coloc_df)

    return pd.concat(cell_dfs, axis=1), pd.concat(nucleus_dfs, axis=1), pd.concat(pathogen_dfs, axis=1), pd.concat(cytoplasm_dfs, axis=1)
|
15
295
|
|
16
|
-
global vars_dict, root
|
17
|
-
root, vars_dict = initiate_mask_root(1000, 1500)
|
18
|
-
root.mainloop()
|
spacr/annotate_app.py
CHANGED
@@ -13,7 +13,7 @@ from ttkthemes import ThemedTk
|
|
13
13
|
|
14
14
|
from .logger import log_function_call
|
15
15
|
|
16
|
-
from .gui_utils import ScrollableFrame, set_default_font, set_dark_style, create_dark_mode
|
16
|
+
from .gui_utils import ScrollableFrame, set_default_font, set_dark_style, create_dark_mode, style_text_boxes, create_menu_bar
|
17
17
|
|
18
18
|
class ImageApp:
|
19
19
|
"""
|
@@ -38,7 +38,7 @@ class ImageApp:
|
|
38
38
|
- db_update_thread (threading.Thread): A thread for updating the database.
|
39
39
|
"""
|
40
40
|
|
41
|
-
def
|
41
|
+
def __init__(self, root, db_path, image_type=None, channels=None, grid_rows=None, grid_cols=None, image_size=(200, 200), annotation_column='annotate'):
|
42
42
|
"""
|
43
43
|
Initializes an instance of the ImageApp class.
|
44
44
|
|
@@ -383,7 +383,7 @@ def annotate(db, image_type=None, channels=None, geom="1000x1100", img_size=(200
|
|
383
383
|
root = tk.Tk()
|
384
384
|
root.geometry(geom)
|
385
385
|
app = ImageApp(root, db, image_type=image_type, channels=channels, image_size=img_size, grid_rows=rows, grid_cols=columns, annotation_column=annotation_column)
|
386
|
-
|
386
|
+
#app = ImageApp()
|
387
387
|
next_button = tk.Button(root, text="Next", command=app.next_page)
|
388
388
|
next_button.grid(row=app.grid_rows, column=app.grid_cols - 1)
|
389
389
|
back_button = tk.Button(root, text="Back", command=app.previous_page)
|
@@ -425,7 +425,8 @@ def initiate_annotation_app_root(width, height):
|
|
425
425
|
root = ThemedTk(theme=theme)
|
426
426
|
style = ttk.Style(root)
|
427
427
|
set_dark_style(style)
|
428
|
-
|
428
|
+
style_text_boxes(style)
|
429
|
+
set_default_font(root, font_name="Arial", size=8)
|
429
430
|
root.geometry(f"{width}x{height}")
|
430
431
|
root.title("Annotation App")
|
431
432
|
|
@@ -473,6 +474,7 @@ def initiate_annotation_app_root(width, height):
|
|
473
474
|
new_root = tk.Tk()
|
474
475
|
new_root.geometry(f"{width}x{height}")
|
475
476
|
new_root.title("Mask Application")
|
477
|
+
|
476
478
|
|
477
479
|
# Start the annotation application in the new root window
|
478
480
|
app_instance = annotate(db, image_type, channels, annotation_column, geom, img_size, rows, columns)
|
@@ -482,7 +484,7 @@ def initiate_annotation_app_root(width, height):
|
|
482
484
|
create_dark_mode(root, style, console_output=None)
|
483
485
|
|
484
486
|
run_button = ttk.Button(scrollable_frame.scrollable_frame, text="Run", command=run_app)
|
485
|
-
run_button.grid(row=row, column=0, columnspan=2, pady=10)
|
487
|
+
run_button.grid(row=row, column=0, columnspan=2, pady=10, padx=10)
|
486
488
|
|
487
489
|
return root
|
488
490
|
|
spacr/chris.py
ADDED
@@ -0,0 +1,50 @@
|
|
1
|
+
import pandas as pd
|
2
|
+
import numpy as np
|
3
|
+
from .core import _permutation_importance, _shap_analysis
|
4
|
+
|
5
|
+
def join_measurments_and_annotation(src, tables=None):
    """Merge per-object measurements with the ``png_list`` table of an experiment.

    Args:
        src: Experiment folder containing ``measurements/measurements.db``.
        tables: Measurement tables to read and merge. Defaults to
            ``['cell', 'nucleus', 'pathogen', 'cytoplasm']``. A fresh list is
            built on each call instead of sharing a mutable default.

    Returns:
        DataFrame of merged measurements, left-joined with ``png_list`` on ``prcfo``.
    """
    from .io import _read_and_merge_data, _read_db

    if tables is None:
        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']

    loc = src + '/measurements/measurements.db'
    df, _ = _read_and_merge_data([loc],
                                 tables,
                                 verbose=True,
                                 include_multinucleated=True,
                                 include_multiinfected=True,
                                 include_noninfected=True)

    # _read_db returns one DataFrame per requested table.
    paths_df = _read_db(loc, tables=['png_list'])
    return pd.merge(df, paths_df[0], on='prcfo', how='left')
|
23
|
+
|
24
|
+
def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
    """Fit a permutation-importance model on an experiment and plot a plate heatmap.

    Args:
        src: Experiment folder containing ``measurements/measurements.db``.
        variable: Numeric column of the model output DataFrame to plot.
        channel_of_interest: If not None, adds a ``recruitment`` column (pathogen /
            cytoplasm mean intensity for that channel). Features are then
            restricted via ``feature_string = 'channel_<n>'``.
        verbose: Currently unused.
        Remaining arguments are forwarded to ``_permutation_importance`` and
        ``_plot_plates``.

    Returns:
        ``[output, plot]``, where ``output`` is the tuple returned by
        ``_permutation_importance`` and ``plot`` is whatever ``_plot_plates``
        returns.

    Raises:
        ValueError: if *variable* is not a numeric column of ``output[0]``.
    """
    from .plot import _plot_plates

    df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])

    if channel_of_interest is not None:
        df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity'] / df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
        feature_string = f'channel_{channel_of_interest}'
    else:
        feature_string = None

    output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)

    # NOTE(review): output[3:6] are assumed to be the inputs _shap_analysis expects;
    # confirm against _permutation_importance's return order.
    _shap_analysis(output[3], output[4], output[5])

    features = output[0].select_dtypes(include=[np.number]).columns.tolist()
    if variable not in features:
        raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")

    # Named plate_plot so the local does not shadow this function's own name.
    plate_plot = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
    return [output, plate_plot]
|