spacr 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/gui.py +2 -1
- spacr/gui_core.py +75 -34
- spacr/gui_elements.py +323 -59
- spacr/gui_utils.py +26 -32
- spacr/resources/icons/abort.png +0 -0
- spacr/resources/icons/classify.png +0 -0
- spacr/resources/icons/make_masks.png +0 -0
- spacr/resources/icons/mask.png +0 -0
- spacr/resources/icons/measure.png +0 -0
- spacr/resources/icons/ml_analyze.png +0 -0
- spacr/resources/icons/recruitment.png +0 -0
- spacr/resources/icons/regression.png +0 -0
- spacr/resources/icons/run.png +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/icons/train_cellpose.png +0 -0
- spacr/resources/icons/umap.png +0 -0
- {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/METADATA +1 -1
- spacr-0.2.3.dist-info/RECORD +58 -0
- spacr/alpha.py +0 -807
- spacr/annotate_app.py +0 -670
- spacr/annotate_app_v2.py +0 -670
- spacr/app_make_masks_v2.py +0 -686
- spacr/classify_app.py +0 -201
- spacr/cli.py +0 -41
- spacr/foldseek.py +0 -779
- spacr/get_alfafold_structures.py +0 -72
- spacr/gui_2.py +0 -157
- spacr/gui_annotate.py +0 -145
- spacr/gui_classify_app.py +0 -201
- spacr/gui_make_masks_app.py +0 -927
- spacr/gui_make_masks_app_v2.py +0 -688
- spacr/gui_mask_app.py +0 -249
- spacr/gui_measure_app.py +0 -246
- spacr/gui_run.py +0 -58
- spacr/gui_sim_app.py +0 -0
- spacr/gui_wrappers.py +0 -149
- spacr/icons/abort.png +0 -0
- spacr/icons/abort.svg +0 -1
- spacr/icons/download.png +0 -0
- spacr/icons/download.svg +0 -1
- spacr/icons/download_for_offline_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
- spacr/icons/download_for_offline_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
- spacr/icons/logo_spacr.png +0 -0
- spacr/icons/make_masks.png +0 -0
- spacr/icons/make_masks.svg +0 -1
- spacr/icons/map_barcodes.png +0 -0
- spacr/icons/map_barcodes.svg +0 -1
- spacr/icons/mask.png +0 -0
- spacr/icons/mask.svg +0 -1
- spacr/icons/measure.png +0 -0
- spacr/icons/measure.svg +0 -1
- spacr/icons/play_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
- spacr/icons/play_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
- spacr/icons/run.png +0 -0
- spacr/icons/run.svg +0 -1
- spacr/icons/sequencing.png +0 -0
- spacr/icons/sequencing.svg +0 -1
- spacr/icons/settings.png +0 -0
- spacr/icons/settings.svg +0 -1
- spacr/icons/settings_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
- spacr/icons/settings_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
- spacr/icons/stop_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
- spacr/icons/stop_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
- spacr/icons/theater_comedy_100dp_E8EAED_FILL0_wght100_GRAD200_opsz48.png +0 -0
- spacr/icons/theater_comedy_100dp_E8EAED_FILL0_wght100_GRAD200_opsz48.svg +0 -1
- spacr/make_masks_app.py +0 -929
- spacr/make_masks_app_v2.py +0 -688
- spacr/mask_app.py +0 -249
- spacr/measure_app.py +0 -246
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
- spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +0 -358
- spacr/resources/icons/abort.svg +0 -1
- spacr/resources/icons/annotate.svg +0 -1
- spacr/resources/icons/classify.svg +0 -1
- spacr/resources/icons/download.svg +0 -1
- spacr/resources/icons/icon.psd +0 -0
- spacr/resources/icons/make_masks.svg +0 -1
- spacr/resources/icons/map_barcodes.svg +0 -1
- spacr/resources/icons/mask.svg +0 -1
- spacr/resources/icons/measure.svg +0 -1
- spacr/resources/icons/run.svg +0 -1
- spacr/resources/icons/run_2.png +0 -0
- spacr/resources/icons/run_2.svg +0 -1
- spacr/resources/icons/sequencing.svg +0 -1
- spacr/resources/icons/settings.svg +0 -1
- spacr/resources/icons/train_cellpose.svg +0 -1
- spacr/test_gui.py +0 -0
- spacr-0.2.1.dist-info/RECORD +0 -126
- /spacr/resources/icons/{cellpose.png → cellpose_all.png} +0 -0
- {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/LICENSE +0 -0
- {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/WHEEL +0 -0
- {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/entry_points.txt +0 -0
- {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/top_level.txt +0 -0
spacr/alpha.py
DELETED
@@ -1,807 +0,0 @@
-from skimage import measure, feature
-from skimage.filters import gabor
-from skimage.color import rgb2gray
-from skimage.util import img_as_ubyte
-import numpy as np
-import pandas as pd
-from scipy.stats import skew, kurtosis, entropy, hmean, gmean, mode
-import pywt
-
-import os
-import pandas as pd
-from PIL import Image
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch_geometric.data import Data, DataLoader
-from torch_geometric.nn import GCNConv, global_mean_pool
-from torch.optim import Adam
-import os
-import shutil
-import random
-
-# Step 1: Data Preparation
-
-# Load images
-def load_images(image_dir):
-    images = {}
-    for filename in os.listdir(image_dir):
-        if filename.endswith(".png"):
-            img = Image.open(os.path.join(image_dir, filename))
-            images[filename] = img
-    return images
-
-# Load sequencing data
-def load_sequencing_data(seq_file):
-    seq_data = pd.read_csv(seq_file)
-    return seq_data
-
-# Step 2: Data Representation
-
-# Image Representation (Using a simple CNN for feature extraction)
-class CNNFeatureExtractor(nn.Module):
-    def __init__(self):
-        super(CNNFeatureExtractor, self).__init__()
-        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
-        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
-        self.fc = nn.Linear(32 * 8 * 8, 128)  # Assuming input images are 64x64
-
-    def forward(self, x):
-        x = F.relu(self.conv1(x))
-        x = F.max_pool2d(x, 2)
-        x = F.relu(self.conv2(x))
-        x = F.max_pool2d(x, 2)
-        x = x.view(x.size(0), -1)
-        x = self.fc(x)
-        return x
-
-# Graph Representation
-def create_graph(wells, sequencing_data):
-    nodes = []
-    edges = []
-    node_features = []
-
-    for well in wells:
-        # Add node for each well
-        nodes.append(well)
-
-        # Get sequencing data for the well
-        seq_info = sequencing_data[sequencing_data['well'] == well]
-
-        # Create node features based on gene knockouts and abundances
-        features = torch.tensor(seq_info['abundance'].values, dtype=torch.float)
-        node_features.append(features)
-
-        # Define edges (for simplicity, assume fully connected graph)
-        for other_well in wells:
-            if other_well != well:
-                edges.append((wells.index(well), wells.index(other_well)))
-
-    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
-    x = torch.stack(node_features)
-
-    data = Data(x=x, edge_index=edge_index)
-    return data
-
-# Step 3: Model Architecture
-
-class GraphTransformer(nn.Module):
-    def __init__(self, in_channels, hidden_channels, out_channels):
-        super(GraphTransformer, self).__init__()
-        self.conv1 = GCNConv(in_channels, hidden_channels)
-        self.conv2 = GCNConv(hidden_channels, hidden_channels)
-        self.fc = nn.Linear(hidden_channels, out_channels)
-        self.attention = nn.MultiheadAttention(hidden_channels, num_heads=8)
-
-    def forward(self, x, edge_index, batch):
-        x = F.relu(self.conv1(x, edge_index))
-        x = F.relu(self.conv2(x, edge_index))
-
-        # Apply attention mechanism
-        x, _ = self.attention(x.unsqueeze(1), x.unsqueeze(1), x.unsqueeze(1))
-        x = x.squeeze(1)
-
-        x = global_mean_pool(x, batch)
-        x = self.fc(x)
-        return x
-
-# Step 4: Training
-
-# Training Loop
-def train(model, data_loader, criterion, optimizer, epochs=10):
-    model.train()
-    for epoch in range(epochs):
-        for data in data_loader:
-            optimizer.zero_grad()
-            out = model(data.x, data.edge_index, data.batch)
-            loss = criterion(out, data.y)
-            loss.backward()
-            optimizer.step()
-        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
-
-def evaluate(model, data_loader):
-    model.eval()
-    correct = 0
-    total = 0
-    with torch.no_grad():
-        for data in data_loader:
-            out = model(data.x, data.edge_index, data.batch)
-            _, predicted = torch.max(out, 1)
-            total += data.y.size(0)
-            correct += (predicted == data.y).sum().item()
-    accuracy = correct / total
-    print(f'Accuracy: {accuracy * 100:.2f}%')
-
-def spacr_transformer(image_dir, seq_file, nr_grnas=1350, lr=0.001, mode='train'):
-    images = load_images(image_dir)
-
-    sequencing_data = load_sequencing_data(seq_file)
-    wells = sequencing_data['well'].unique()
-    graph_data = create_graph(wells, sequencing_data)
-    model = GraphTransformer(in_channels=nr_grnas, hidden_channels=128, out_channels=nr_grnas)
-    criterion = nn.CrossEntropyLoss()
-    optimizer = Adam(model.parameters(), lr=lr)
-    data_list = [graph_data]
-    loader = DataLoader(data_list, batch_size=1, shuffle=True)
-    if mode == 'train':
-        train(model, loader, criterion, optimizer)
-    elif mode == 'eval':
-        evaluate(model, loader)
-    else:
-        raise ValueError('Invalid mode. Use "train" or "eval".')
-
-from skimage.feature import greycomatrix
-
-from skimage.feature import greycoprops
-
-def _calculate_glcm_features(intensity_image):
-    glcm = greycomatrix(img_as_ubyte(intensity_image), distances=[1, 2, 3, 4], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4], symmetric=True, normed=True)
-    features = {}
-    for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM']:
-        for i, distance in enumerate([1, 2, 3, 4]):
-            for j, angle in enumerate([0, np.pi/4, np.pi/2, 3*np.pi/4]):
-                features[f'glcm_{prop}_d{distance}_a{angle}'] = greycoprops(glcm, prop)[i, j]
-    return features
-
-from skimage.feature import local_binary_pattern
-
-def _calculate_lbp_features(intensity_image, P=8, R=1):
-    lbp = local_binary_pattern(intensity_image, P, R, method='uniform')
-    lbp_hist, _ = np.histogram(lbp, density=True, bins=np.arange(0, P + 3), range=(0, P + 2))
-    return {f'lbp_{i}': val for i, val in enumerate(lbp_hist)}
-
-def _calculate_wavelet_features(intensity_image, wavelet='db1'):
-    coeffs = pywt.wavedec2(intensity_image, wavelet=wavelet, level=3)
-    features = {}
-    for i, coeff in enumerate(coeffs):
-        if isinstance(coeff, tuple):
-            for j, subband in enumerate(coeff):
-                features[f'wavelet_coeff_{i}_{j}_mean'] = np.mean(subband)
-                features[f'wavelet_coeff_{i}_{j}_std'] = np.std(subband)
-                features[f'wavelet_coeff_{i}_{j}_energy'] = np.sum(subband**2)
-        else:
-            features[f'wavelet_coeff_{i}_mean'] = np.mean(coeff)
-            features[f'wavelet_coeff_{i}_std'] = np.std(coeff)
-            features[f'wavelet_coeff_{i}_energy'] = np.sum(coeff**2)
-    return features
-
-
-from .measure import _estimate_blur, _calculate_correlation_object_level, _calculate_homogeneity, _periphery_intensity, _outside_intensity, _calculate_radial_distribution, _create_dataframe, _extended_regionprops_table, _calculate_correlation_object_level
-
-def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[3, 6, 12, 24], periphery=True, outside=True):
-    radial_dist = settings['radial_dist']
-    calculate_correlation = settings['calculate_correlation']
-    homogeneity = settings['homogeneity']
-    distances = settings['homogeneity_distances']
-
-    intensity_props = ["label", "centroid_weighted", "centroid_weighted_local", "max_intensity", "mean_intensity", "min_intensity", "integrated_intensity"]
-    additional_props = ["standard_deviation_intensity", "median_intensity", "sum_intensity", "intensity_range", "mean_absolute_deviation_intensity", "skewness_intensity", "kurtosis_intensity", "variance_intensity", "mode_intensity", "energy_intensity", "entropy_intensity", "harmonic_mean_intensity", "geometric_mean_intensity"]
-    col_lables = ['region_label', 'mean', '5_percentile', '10_percentile', '25_percentile', '50_percentile', '75_percentile', '85_percentile', '95_percentile']
-    cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs = [], [], [], []
-    ls = ['cell','nucleus','pathogen','cytoplasm']
-    labels = [cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask]
-    dfs = [cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs]
-
-    for i in range(0,channel_arrays.shape[-1]):
-        channel = channel_arrays[:, :, i]
-        for j, (label, df) in enumerate(zip(labels, dfs)):
-
-            if np.max(label) == 0:
-                empty_df = pd.DataFrame()
-                df.append(empty_df)
-                continue
-
-            mask_intensity_df = _extended_regionprops_table(label, channel, intensity_props)
-
-            # Additional intensity properties
-            region_props = measure.regionprops_table(label, intensity_image=channel, properties=['label'])
-            intensity_values = [channel[region.coords[:, 0], region.coords[:, 1]] for region in measure.regionprops(label)]
-            additional_data = {prop: [] for prop in additional_props}
-
-            for values in intensity_values:
-                if len(values) == 0:
-                    continue
-                additional_data["standard_deviation_intensity"].append(np.std(values))
-                additional_data["median_intensity"].append(np.median(values))
-                additional_data["sum_intensity"].append(np.sum(values))
-                additional_data["intensity_range"].append(np.max(values) - np.min(values))
-                additional_data["mean_absolute_deviation_intensity"].append(np.mean(np.abs(values - np.mean(values))))
-                additional_data["skewness_intensity"].append(skew(values))
-                additional_data["kurtosis_intensity"].append(kurtosis(values))
-                additional_data["variance_intensity"].append(np.var(values))
-                additional_data["mode_intensity"].append(mode(values)[0][0])
-                additional_data["energy_intensity"].append(np.sum(values**2))
-                additional_data["entropy_intensity"].append(entropy(values))
-                additional_data["harmonic_mean_intensity"].append(hmean(values))
-                additional_data["geometric_mean_intensity"].append(gmean(values))
-
-            for prop in additional_props:
-                region_props[prop] = additional_data[prop]
-
-            additional_df = pd.DataFrame(region_props)
-            mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), additional_df.reset_index(drop=True)], axis=1)
-
-            if homogeneity:
-                homogeneity_df = _calculate_homogeneity(label, channel, distances)
-                mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), homogeneity_df], axis=1)
-
-            if periphery:
-                if ls[j] == 'nucleus' or ls[j] == 'pathogen':
-                    periphery_intensity_stats = _periphery_intensity(label, channel)
-                    mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(periphery_intensity_stats, columns=[f'periphery_{stat}' for stat in col_lables])],axis=1)
-
-            if outside:
-                if ls[j] == 'nucleus' or ls[j] == 'pathogen':
-                    outside_intensity_stats = _outside_intensity(label, channel)
-                    mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(outside_intensity_stats, columns=[f'outside_{stat}' for stat in col_lables])], axis=1)
-
-            # Adding GLCM features
-            glcm_features = _calculate_glcm_features(channel)
-            for k, v in glcm_features.items():
-                mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
-
-            # Adding LBP features
-            lbp_features = _calculate_lbp_features(channel)
-            for k, v in lbp_features.items():
-                mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
-
-            # Adding Wavelet features
-            wavelet_features = _calculate_wavelet_features(channel)
-            for k, v in wavelet_features.items():
-                mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
-
-            blur_col = [_estimate_blur(channel[label == region_label]) for region_label in mask_intensity_df['label']]
-            mask_intensity_df[f'{ls[j]}_channel_{i}_blur'] = blur_col
-
-            mask_intensity_df.columns = [f'{ls[j]}_channel_{i}_{col}' if col != 'label' else col for col in mask_intensity_df.columns]
-            df.append(mask_intensity_df)
-
-    if radial_dist:
-        if np.max(nucleus_mask) != 0:
-            nucleus_radial_distributions = _calculate_radial_distribution(cell_mask, nucleus_mask, channel_arrays, num_bins=6)
-            nucleus_df = _create_dataframe(nucleus_radial_distributions, 'nucleus')
-            dfs[1].append(nucleus_df)
-
-        if np.max(nucleus_mask) != 0:
-            pathogen_radial_distributions = _calculate_radial_distribution(cell_mask, pathogen_mask, channel_arrays, num_bins=6)
-            pathogen_df = _create_dataframe(pathogen_radial_distributions, 'pathogen')
-            dfs[2].append(pathogen_df)
-
-    if calculate_correlation:
-        if channel_arrays.shape[-1] >= 2:
-            for i in range(channel_arrays.shape[-1]):
-                for j in range(i+1, channel_arrays.shape[-1]):
-                    chan_i = channel_arrays[:, :, i]
-                    chan_j = channel_arrays[:, :, j]
-                    for m, mask in enumerate(labels):
-                        coloc_df = _calculate_correlation_object_level(chan_i, chan_j, mask, settings)
-                        coloc_df.columns = [f'{ls[m]}_channel_{i}_channel_{j}_{col}' for col in coloc_df.columns]
-                        dfs[m].append(coloc_df)
-
-    return pd.concat(cell_dfs, axis=1), pd.concat(nucleus_dfs, axis=1), pd.concat(pathogen_dfs, axis=1), pd.concat(cytoplasm_dfs, axis=1)
-
-def sample_and_copy_images(folder_list, nr_of_images, dst):
-
-    if isinstance(folder_list, str):
-        folder_list = [folder_list]
-
-    # Create the destination folder if it does not exist
-    if not os.path.exists(dst):
-        os.makedirs(dst)
-
-    # Calculate the number of images to sample from each folder
-    nr_of_images_per_folder = nr_of_images // len(folder_list)
-
-    print(f"Sampling {nr_of_images_per_folder} images from {len(folder_list)} folders...")
-    # Initialize a list to hold the paths of the images to be copied
-    images_to_copy = []
-
-    for folder in folder_list:
-        # Get a list of all files in the current folder
-        all_files = [os.path.join(folder, file) for file in os.listdir(folder)]
-
-        # Filter out non-image files
-        image_files = [file for file in all_files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.tif'))]
-
-        # Sample images randomly from the list
-        sampled_images = random.sample(image_files, min(nr_of_images_per_folder, len(image_files)))
-
-        # Add the sampled images to the list of images to copy
-        images_to_copy.extend(sampled_images)
-
-    # Copy the sampled images to the destination folder
-    for image in images_to_copy:
-        shutil.copy(image, os.path.join(dst, os.path.basename(image)))
-
-import os
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from collections import defaultdict
-from torch.utils.data import Dataset, DataLoader
-import pandas as pd
-import numpy as np
-import torch.optim as optim
-
-def generate_graphs(sequencing, scores, cell_min, gene_min_read):
-    # Load and preprocess sequencing (gene) data
-    gene_df = pd.read_csv(sequencing)
-    gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
-    # Filter out genes with read counts less than gene_min_read
-    gene_df = gene_df[gene_df['read_count'] >= gene_min_read]
-    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
-    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
-    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']
-
-    # Load and preprocess cell score data
-    cell_df = pd.read_csv(scores)
-    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})
-
-    # Create a global mapping of gene IDs to indices
-    unique_genes = gene_df['gene_id'].unique()
-    gene_id_to_index = {gene_id: index for index, gene_id in enumerate(unique_genes)}
-
-    graphs = []
-    for well_id in pd.unique(gene_df['well_id']):
-        well_genes = gene_df[gene_df['well_id'] == well_id]
-        well_cells = cell_df[cell_df['well_id'] == well_id]
-
-        # Skip wells with no cells or genes or with fewer cells than threshold
-        if well_cells.empty or well_genes.empty or len(well_cells) < cell_min:
-            continue
-
-        # Initialize gene features tensor with zeros for all unique genes
-        gene_features = torch.zeros((len(gene_id_to_index), 1), dtype=torch.float)
-
-        # Update gene features tensor with well_read_fraction for genes present in this well
-        for _, row in well_genes.iterrows():
-            gene_index = gene_id_to_index[row['gene_id']]
-            gene_features[gene_index] = torch.tensor([[row['well_read_fraction']]])
-
-        # Prepare cell features (scores)
-        cell_features = torch.tensor(well_cells['score'].values, dtype=torch.float).view(-1, 1)
-
-        num_genes = len(gene_id_to_index)
-        num_cells = cell_features.size(0)
-        num_nodes = num_genes + num_cells
-
-        # Create adjacency matrix connecting each cell to all genes in the well
-        adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
-        for _, row in well_genes.iterrows():
-            gene_index = gene_id_to_index[row['gene_id']]
-            adj[num_genes:, gene_index] = 1
-
-        graph = {
-            'adjacency_matrix': adj,
-            'gene_features': gene_features,
-            'cell_features': cell_features,
-            'num_cells': num_cells,
-            'num_genes': num_genes
-        }
-        graphs.append(graph)
-
-    print(f'Generated dataset with {len(graphs)} graphs')
-    return graphs, gene_id_to_index
-
-def print_graphs_info(graphs, gene_id_to_index):
-    # Invert the gene_id_to_index mapping for easy lookup
-    index_to_gene_id = {v: k for k, v in gene_id_to_index.items()}
-
-    for i, graph in enumerate(graphs, start=1):
-        print(f"Graph {i}:")
-        num_genes = graph['num_genes']
-        num_cells = graph['num_cells']
-        gene_features = graph['gene_features']
-        cell_features = graph['cell_features']
-
-        print(f" Number of Genes: {num_genes}")
-        print(f" Number of Cells: {num_cells}")
-
-        # Identify genes present in the graph based on non-zero feature values
-        present_genes = [index_to_gene_id[idx] for idx, feature in enumerate(gene_features) if feature.item() > 0]
-        print(" Genes present in this Graph:", present_genes)
-
-        # Display gene features for genes present in the graph
-        print(" Gene Features:")
-        for gene_id in present_genes:
-            idx = gene_id_to_index[gene_id]
-            print(f" {gene_id}: {gene_features[idx].item()}")
-
-        # Display a sample of cell features, for brevity
-        print(" Cell Features (sample):")
-        for idx, feature in enumerate(cell_features[:min(5, len(cell_features))]):
-            print(f"Cell {idx+1}: {feature.item()}")
-
-        print("-" * 40)
-
-class Attention(nn.Module):
-    def __init__(self, feature_dim, attn_dim, dropout_rate=0.1):
-        super(Attention, self).__init__()
-        self.query = nn.Linear(feature_dim, attn_dim)
-        self.key = nn.Linear(feature_dim, attn_dim)
-        self.value = nn.Linear(feature_dim, feature_dim)
-        self.scale = 1.0 / (attn_dim ** 0.5)
-        self.dropout = nn.Dropout(dropout_rate)
-
-    def forward(self, gene_features, cell_features):
-        # Queries come from the cell features
-        q = self.query(cell_features)
-        # Keys and values come from the gene features
-        k = self.key(gene_features)
-        v = self.value(gene_features)
-
-        # Compute attention weights
-        attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
-        attn_weights = F.softmax(attn_weights, dim=-1)
-        # Apply dropout to attention weights
-        attn_weights = self.dropout(attn_weights)
-
-        # Apply attention weights to the values
-        attn_output = torch.matmul(attn_weights, v)
-
-        return attn_output, attn_weights
-
-class GraphTransformer(nn.Module):
-    def __init__(self, gene_feature_size, cell_feature_size, hidden_dim, output_dim, attn_dim, dropout_rate=0.1):
-        super(GraphTransformer, self).__init__()
-        self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
-        self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
-        self.dropout = nn.Dropout(dropout_rate)
-
-        # Attention layer to let each cell attend to all genes
-        self.attention = Attention(hidden_dim, attn_dim)
-
-        # This layer is used to transform the combined features after attention
-        self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)
-
-        # Output layer for predicting cell scores, ensuring it matches the number of cells
-        self.cell_output = nn.Linear(hidden_dim, output_dim)
-
-    def forward(self, adjacency_matrix, gene_features, cell_features):
-        # Apply initial transformation to gene and cell features
-        transformed_gene_features = F.relu(self.gene_transform(gene_features))
-        transformed_cell_features = F.relu(self.cell_transform(cell_features))
-
-        # Incorporate attention mechanism
-        attn_output, attn_weights = self.attention(transformed_gene_features, transformed_cell_features)
-
-        # Combine the transformed cell features with the attention output features
-        combined_cell_features = torch.cat((transformed_cell_features, attn_output), dim=1)
-
-        # Apply dropout here as well
-        combined_cell_features = self.dropout(combined_cell_features)
-
-        combined_cell_features = F.relu(self.combine_transform(combined_cell_features))
-
-        # Combine gene and cell features for message passing
-        combined_features = torch.cat((transformed_gene_features, combined_cell_features), dim=0)
-
-        # Apply message passing via adjacency matrix multiplication
-        message_passed_features = torch.matmul(adjacency_matrix, combined_features)
-
-        # Predict cell scores from the post-message passed cell features
-        cell_scores = self.cell_output(message_passed_features[-cell_features.size(0):])
-
-        return cell_scores, attn_weights
-
-def train_graph_transformer(graphs, lr=0.01, dropout_rate=0.1, weight_decay=0.00001, epochs=100, save_fldr='', acc_threshold = 0.1):
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model = GraphTransformer(gene_feature_size=1, cell_feature_size=1, hidden_dim=256, output_dim=1, attn_dim=128, dropout_rate=dropout_rate).to(device)
-
-    criterion = nn.MSELoss()
-    #optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
-
-    training_log = []
-
-    accumulate_grad_batches=1
-    threshold=acc_threshold
-
-    for epoch in range(epochs):
-        model.train()
-        total_loss = 0
-        total_correct = 0
-        total_samples = 0
-        optimizer.zero_grad()
-        batch_count = 0  # Initialize batch_count
-
-        for graph in graphs:
-            adjacency_matrix = graph['adjacency_matrix'].to(device)
-            gene_features = graph['gene_features'].to(device)
-            cell_features = graph['cell_features'].to(device)
-            num_cells = graph['num_cells']
-            predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
-            predictions = predictions.squeeze()
-            true_scores = cell_features[:num_cells, 0]
-            loss = criterion(predictions, true_scores) / accumulate_grad_batches
-            loss.backward()
-
-            # Calculate "accuracy"
-            with torch.no_grad():
-                correct_predictions = (torch.abs(predictions - true_scores) / true_scores <= threshold).sum().item()
-                total_correct += correct_predictions
-                total_samples += num_cells
-
-            batch_count += 1  # Increment batch_count
-            if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
-                optimizer.step()
-                optimizer.zero_grad()
-
-            total_loss += loss.item() * accumulate_grad_batches
-
-        accuracy = total_correct / total_samples
-        training_log.append({"Epoch": epoch+1, "Average Loss": total_loss / len(graphs), "Accuracy": accuracy})
-        print(f"Epoch {epoch+1}, Loss: {total_loss / len(graphs)}, Accuracy: {accuracy}", end="\r", flush=True)
-
-    # Save the training log and model as before
-    os.makedirs(save_fldr, exist_ok=True)
-    log_path = os.path.join(save_fldr, 'training_log.csv')
-    training_log_df = pd.DataFrame(training_log)
-    training_log_df.to_csv(log_path, index=False)
-    print(f"Training log saved to {log_path}")
-
-    model_path = os.path.join(save_fldr, 'model.pth')
-    torch.save(model.state_dict(), model_path)
-    print(f"Model saved to {model_path}")
-
-    return model
-
-def annotate_cells_with_genes(graphs, model, gene_id_to_index):
-    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-    model.to(device)
-    model.eval()
-    annotated_data = []
-
-    with torch.no_grad():
-        for graph in graphs:
-            adjacency_matrix = graph['adjacency_matrix'].to(device)
-            gene_features = graph['gene_features'].to(device)
-            cell_features = graph['cell_features'].to(device)
-
-            predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
-            predictions = np.atleast_1d(predictions.squeeze().cpu().numpy())
-            attn_weights = np.atleast_2d(attn_weights.squeeze().cpu().numpy())
-
-            # This approach assumes all genes in gene_id_to_index are used in the model.
-            # Create a list of gene IDs present in this specific graph.
-            present_gene_ids = [key for key, value in gene_id_to_index.items() if value < gene_features.size(0)]
-
-            for cell_idx in range(cell_features.size(0)):
-                true_score = cell_features[cell_idx, 0].item()
-                predicted_score = predictions[cell_idx]
-
-                # Find the index of the most probable gene.
-                most_probable_gene_idx = attn_weights[cell_idx].argmax()
-
-                if len(present_gene_ids) > most_probable_gene_idx:  # Ensure index is within the range
-                    most_probable_gene_id = present_gene_ids[most_probable_gene_idx]
-                    most_probable_gene_score = attn_weights[cell_idx, most_probable_gene_idx] if attn_weights.ndim > 1 else attn_weights[most_probable_gene_idx]
-
-                    annotated_data.append({
-                        "Cell ID": cell_idx,
-                        "Most Probable Gene": most_probable_gene_id,
-                        "Cell Score": true_score,
-                        "Predicted Cell Score": predicted_score,
-                        "Probability Score for Highest Gene": most_probable_gene_score
-                    })
-                else:
-                    # Handle the case where the index is out of bounds - this should not happen but is here for robustness
-                    print("Error: Gene index out of bounds. This might indicate a mismatch in the model's output.")
-
-    return pd.DataFrame(annotated_data)
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from torch.utils.data import Dataset, DataLoader, TensorDataset
-
-# Let's assume that the feature embedding part and the dataset loading part
-# has already been taken care of, and your data is already in the format
-# suitable for PyTorch (i.e., Tensors).
-
-class FeatureEmbedder(nn.Module):
-    def __init__(self, vocab_sizes, embedding_size):
-        super(FeatureEmbedder, self).__init__()
-        self.embeddings = nn.ModuleDict({
-            key: nn.Embedding(num_embeddings=vocab_size+1,
-                              embedding_dim=embedding_size,
-                              padding_idx=vocab_size)
-            for key, vocab_size in vocab_sizes.items()
-        })
-        # Adding the 'visit' embedding
-        self.embeddings['visit'] = nn.Parameter(torch.zeros(1, embedding_size))
-
-    def forward(self, feature_map, max_num_codes):
-        # Implementation will depend on how you want to handle sparse data
-        # This is just a placeholder
-        embeddings = {}
-        masks = {}
-        for key, tensor in feature_map.items():
-            embeddings[key] = self.embeddings[key](tensor.long())
-            mask = torch.ones_like(tensor, dtype=torch.float32)
-            masks[key] = mask.unsqueeze(-1)
-
-        # Batch size hardcoded for simplicity in example
-        batch_size = 1  # Replace with actual batch size
-        embeddings['visit'] = self.embeddings['visit'].expand(batch_size, -1, -1)
-        masks['visit'] = torch.ones(batch_size, 1)
-
-        return embeddings, masks
-
-class GraphConvolutionalTransformer(nn.Module):
-    def __init__(self, embedding_size=128, num_attention_heads=1, **kwargs):
-        super(GraphConvolutionalTransformer, self).__init__()
-        # Transformer Blocks
-        self.layers = nn.ModuleList([
-            nn.TransformerEncoderLayer(
-                d_model=embedding_size,
-                nhead=num_attention_heads,
-                batch_first=True)
-            for _ in range(kwargs.get('num_transformer_stack', 3))
-        ])
-        # Output Layer for Classification
-        self.output_layer = nn.Linear(embedding_size, 1)
-
-    def feedforward(self, features, mask=None, training=None):
-        # Implement feedforward logic (placeholder)
-        pass
-
-    def forward(self, embeddings, masks, mask=None, training=False):
-        features = embeddings
-        attentions = []  # Storing attentions if needed
-
-        # Pass through each Transformer block
-        for layer in self.layers:
-            features = layer(features)  # Apply transformer encoding here
-
-        if mask is not None:
-            features = features * mask
-
-        logits = self.output_layer(features[:, 0, :])  # Using the 'visit' embedding for classification
-        return logits, attentions
-
-# Usage Example
-#vocab_sizes = {'dx_ints':3249, 'proc_ints':2210}
-#embedding_size = 128
-#gct_params = {
-#    'embedding_size': embedding_size,
-#    'num_transformer_stack': 3,
-#    'num_attention_heads': 1
-#}
-#feature_embedder = FeatureEmbedder(vocab_sizes, embedding_size)
-#gct_model = GraphConvolutionalTransformer(**gct_params)
-#
-# Assume `feature_map` is a dictionary of tensors, and `max_num_codes` is provided
-#embeddings, masks = feature_embedder(feature_map, max_num_codes)
-#logits, attentions = gct_model(embeddings, masks)
-
-import torch
-import torchvision.transforms as transforms
-from torchvision.models import resnet50
-from PIL import Image
-import numpy as np
-import umap
-import pandas as pd
-from sklearn.ensemble import RandomForestClassifier
-from sklearn.preprocessing import StandardScaler
-from scipy.stats import f_oneway, kruskal
-from sklearn.cluster import KMeans
-from scipy import stats
-
-def load_image(image_path):
-    """Load and preprocess an image."""
-    transform = transforms.Compose([
-        transforms.Resize((224, 224)),
-        transforms.ToTensor(),
-        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
-    ])
-    image = Image.open(image_path).convert('RGB')
-    image = transform(image).unsqueeze(0)
-    return image
-
-def extract_features(image_paths, resnet=resnet50):
-    """Extract features from images using a pre-trained ResNet model."""
-    model = resnet(pretrained=True)
-    model = model.eval()
-    model = torch.nn.Sequential(*list(model.children())[:-1])  # Remove the last classification layer
-
-    features = []
-    for image_path in image_paths:
-        image = load_image(image_path)
-        with torch.no_grad():
-            feature = model(image).squeeze().numpy()
-        features.append(feature)
-
-    return np.array(features)
-
-def check_normality(series):
-    """Helper function to check if a feature is normally distributed."""
-    k2, p = stats.normaltest(series)
-    alpha = 0.05
-    if p < alpha:  # null hypothesis: x comes from a normal distribution
-        return False
-    return True
-
-def random_forest_feature_importance(all_df, cluster_col='cluster'):
-    """Random Forest feature importance."""
-    numeric_features = all_df.select_dtypes(include=[np.number]).columns.tolist()
-    if cluster_col in numeric_features:
-        numeric_features.remove(cluster_col)
-
-    X = all_df[numeric_features]
-    y = all_df[cluster_col]
-
-    scaler = StandardScaler()
-    X_scaled = scaler.fit_transform(X)
-
-    model = RandomForestClassifier(n_estimators=100, random_state=42)
-    model.fit(X_scaled, y)
-
-    feature_importances = model.feature_importances_
-
-    importance_df = pd.DataFrame({
-        'Feature': numeric_features,
-        'Importance': feature_importances
-    }).sort_values(by='Importance', ascending=False)
-
-    return importance_df
-
-def perform_statistical_tests(all_df, cluster_col='cluster'):
-    """Perform ANOVA or Kruskal-Wallis tests depending on normality of features."""
-    numeric_features = all_df.select_dtypes(include=[np.number]).columns.tolist()
-    if cluster_col in numeric_features:
-        numeric_features.remove(cluster_col)
-
-    anova_results = []
-    kruskal_results = []
-
-    for feature in numeric_features:
-        groups = [all_df[all_df[cluster_col] == label][feature] for label in np.unique(all_df[cluster_col])]
-
-        if check_normality(all_df[feature]):
-            stat, p = f_oneway(*groups)
-            anova_results.append((feature, stat, p))
-        else:
-            stat, p = kruskal(*groups)
-            kruskal_results.append((feature, stat, p))

-    anova_df = pd.DataFrame(anova_results, columns=['Feature', 'ANOVA_Statistic', 'ANOVA_pValue'])
-    kruskal_df = pd.DataFrame(kruskal_results, columns=['Feature', 'Kruskal_Statistic', 'Kruskal_pValue'])
-
-    return anova_df, kruskal_df
-
-def combine_results(rf_df, anova_df, kruskal_df):
-    """Combine the results into a single DataFrame."""
-    combined_df = rf_df.merge(anova_df, on='Feature', how='left')
-    combined_df = combined_df.merge(kruskal_df, on='Feature', how='left')
-    return combined_df
-
-def cluster_feature_analysis(all_df, cluster_col='cluster'):
-    """
-    Perform Random Forest feature importance, ANOVA for normally distributed features,
-    and Kruskal-Wallis for non-normally distributed features. Combine results into a single DataFrame.
-    """
-    rf_df = random_forest_feature_importance(all_df, cluster_col)
-    anova_df, kruskal_df = perform_statistical_tests(all_df, cluster_col)
-    combined_df = combine_results(rf_df, anova_df, kruskal_df)
-    return combined_df