spacr 0.2.1__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95) hide show
  1. spacr/gui.py +2 -1
  2. spacr/gui_core.py +75 -34
  3. spacr/gui_elements.py +323 -59
  4. spacr/gui_utils.py +26 -32
  5. spacr/resources/icons/abort.png +0 -0
  6. spacr/resources/icons/classify.png +0 -0
  7. spacr/resources/icons/make_masks.png +0 -0
  8. spacr/resources/icons/mask.png +0 -0
  9. spacr/resources/icons/measure.png +0 -0
  10. spacr/resources/icons/ml_analyze.png +0 -0
  11. spacr/resources/icons/recruitment.png +0 -0
  12. spacr/resources/icons/regression.png +0 -0
  13. spacr/resources/icons/run.png +0 -0
  14. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  15. spacr/resources/icons/train_cellpose.png +0 -0
  16. spacr/resources/icons/umap.png +0 -0
  17. {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/METADATA +1 -1
  18. spacr-0.2.3.dist-info/RECORD +58 -0
  19. spacr/alpha.py +0 -807
  20. spacr/annotate_app.py +0 -670
  21. spacr/annotate_app_v2.py +0 -670
  22. spacr/app_make_masks_v2.py +0 -686
  23. spacr/classify_app.py +0 -201
  24. spacr/cli.py +0 -41
  25. spacr/foldseek.py +0 -779
  26. spacr/get_alfafold_structures.py +0 -72
  27. spacr/gui_2.py +0 -157
  28. spacr/gui_annotate.py +0 -145
  29. spacr/gui_classify_app.py +0 -201
  30. spacr/gui_make_masks_app.py +0 -927
  31. spacr/gui_make_masks_app_v2.py +0 -688
  32. spacr/gui_mask_app.py +0 -249
  33. spacr/gui_measure_app.py +0 -246
  34. spacr/gui_run.py +0 -58
  35. spacr/gui_sim_app.py +0 -0
  36. spacr/gui_wrappers.py +0 -149
  37. spacr/icons/abort.png +0 -0
  38. spacr/icons/abort.svg +0 -1
  39. spacr/icons/download.png +0 -0
  40. spacr/icons/download.svg +0 -1
  41. spacr/icons/download_for_offline_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
  42. spacr/icons/download_for_offline_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
  43. spacr/icons/logo_spacr.png +0 -0
  44. spacr/icons/make_masks.png +0 -0
  45. spacr/icons/make_masks.svg +0 -1
  46. spacr/icons/map_barcodes.png +0 -0
  47. spacr/icons/map_barcodes.svg +0 -1
  48. spacr/icons/mask.png +0 -0
  49. spacr/icons/mask.svg +0 -1
  50. spacr/icons/measure.png +0 -0
  51. spacr/icons/measure.svg +0 -1
  52. spacr/icons/play_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
  53. spacr/icons/play_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
  54. spacr/icons/run.png +0 -0
  55. spacr/icons/run.svg +0 -1
  56. spacr/icons/sequencing.png +0 -0
  57. spacr/icons/sequencing.svg +0 -1
  58. spacr/icons/settings.png +0 -0
  59. spacr/icons/settings.svg +0 -1
  60. spacr/icons/settings_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
  61. spacr/icons/settings_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
  62. spacr/icons/stop_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.png +0 -0
  63. spacr/icons/stop_circle_100dp_E8EAED_FILL0_wght100_GRAD-25_opsz48.svg +0 -1
  64. spacr/icons/theater_comedy_100dp_E8EAED_FILL0_wght100_GRAD200_opsz48.png +0 -0
  65. spacr/icons/theater_comedy_100dp_E8EAED_FILL0_wght100_GRAD200_opsz48.svg +0 -1
  66. spacr/make_masks_app.py +0 -929
  67. spacr/make_masks_app_v2.py +0 -688
  68. spacr/mask_app.py +0 -249
  69. spacr/measure_app.py +0 -246
  70. spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
  71. spacr/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  72. spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
  73. spacr/old_code.py +0 -358
  74. spacr/resources/icons/abort.svg +0 -1
  75. spacr/resources/icons/annotate.svg +0 -1
  76. spacr/resources/icons/classify.svg +0 -1
  77. spacr/resources/icons/download.svg +0 -1
  78. spacr/resources/icons/icon.psd +0 -0
  79. spacr/resources/icons/make_masks.svg +0 -1
  80. spacr/resources/icons/map_barcodes.svg +0 -1
  81. spacr/resources/icons/mask.svg +0 -1
  82. spacr/resources/icons/measure.svg +0 -1
  83. spacr/resources/icons/run.svg +0 -1
  84. spacr/resources/icons/run_2.png +0 -0
  85. spacr/resources/icons/run_2.svg +0 -1
  86. spacr/resources/icons/sequencing.svg +0 -1
  87. spacr/resources/icons/settings.svg +0 -1
  88. spacr/resources/icons/train_cellpose.svg +0 -1
  89. spacr/test_gui.py +0 -0
  90. spacr-0.2.1.dist-info/RECORD +0 -126
  91. /spacr/resources/icons/{cellpose.png → cellpose_all.png} +0 -0
  92. {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/LICENSE +0 -0
  93. {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/WHEEL +0 -0
  94. {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/entry_points.txt +0 -0
  95. {spacr-0.2.1.dist-info → spacr-0.2.3.dist-info}/top_level.txt +0 -0
spacr/alpha.py DELETED
@@ -1,807 +0,0 @@
1
- from skimage import measure, feature
2
- from skimage.filters import gabor
3
- from skimage.color import rgb2gray
4
- from skimage.util import img_as_ubyte
5
- import numpy as np
6
- import pandas as pd
7
- from scipy.stats import skew, kurtosis, entropy, hmean, gmean, mode
8
- import pywt
9
-
10
- import os
11
- import pandas as pd
12
- from PIL import Image
13
- import torch
14
- import torch.nn as nn
15
- import torch.nn.functional as F
16
- from torch_geometric.data import Data, DataLoader
17
- from torch_geometric.nn import GCNConv, global_mean_pool
18
- from torch.optim import Adam
19
- import os
20
- import shutil
21
- import random
22
-
23
- # Step 1: Data Preparation
24
-
25
- # Load images
26
- def load_images(image_dir):
27
- images = {}
28
- for filename in os.listdir(image_dir):
29
- if filename.endswith(".png"):
30
- img = Image.open(os.path.join(image_dir, filename))
31
- images[filename] = img
32
- return images
33
-
34
- # Load sequencing data
35
- def load_sequencing_data(seq_file):
36
- seq_data = pd.read_csv(seq_file)
37
- return seq_data
38
-
39
- # Step 2: Data Representation
40
-
41
- # Image Representation (Using a simple CNN for feature extraction)
42
- class CNNFeatureExtractor(nn.Module):
43
- def __init__(self):
44
- super(CNNFeatureExtractor, self).__init__()
45
- self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
46
- self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
47
- self.fc = nn.Linear(32 * 8 * 8, 128) # Assuming input images are 64x64
48
-
49
- def forward(self, x):
50
- x = F.relu(self.conv1(x))
51
- x = F.max_pool2d(x, 2)
52
- x = F.relu(self.conv2(x))
53
- x = F.max_pool2d(x, 2)
54
- x = x.view(x.size(0), -1)
55
- x = self.fc(x)
56
- return x
57
-
58
- # Graph Representation
59
- def create_graph(wells, sequencing_data):
60
- nodes = []
61
- edges = []
62
- node_features = []
63
-
64
- for well in wells:
65
- # Add node for each well
66
- nodes.append(well)
67
-
68
- # Get sequencing data for the well
69
- seq_info = sequencing_data[sequencing_data['well'] == well]
70
-
71
- # Create node features based on gene knockouts and abundances
72
- features = torch.tensor(seq_info['abundance'].values, dtype=torch.float)
73
- node_features.append(features)
74
-
75
- # Define edges (for simplicity, assume fully connected graph)
76
- for other_well in wells:
77
- if other_well != well:
78
- edges.append((wells.index(well), wells.index(other_well)))
79
-
80
- edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
81
- x = torch.stack(node_features)
82
-
83
- data = Data(x=x, edge_index=edge_index)
84
- return data
85
-
86
- # Step 3: Model Architecture
87
-
88
- class GraphTransformer(nn.Module):
89
- def __init__(self, in_channels, hidden_channels, out_channels):
90
- super(GraphTransformer, self).__init__()
91
- self.conv1 = GCNConv(in_channels, hidden_channels)
92
- self.conv2 = GCNConv(hidden_channels, hidden_channels)
93
- self.fc = nn.Linear(hidden_channels, out_channels)
94
- self.attention = nn.MultiheadAttention(hidden_channels, num_heads=8)
95
-
96
- def forward(self, x, edge_index, batch):
97
- x = F.relu(self.conv1(x, edge_index))
98
- x = F.relu(self.conv2(x, edge_index))
99
-
100
- # Apply attention mechanism
101
- x, _ = self.attention(x.unsqueeze(1), x.unsqueeze(1), x.unsqueeze(1))
102
- x = x.squeeze(1)
103
-
104
- x = global_mean_pool(x, batch)
105
- x = self.fc(x)
106
- return x
107
-
108
- # Step 4: Training
109
-
110
- # Training Loop
111
- def train(model, data_loader, criterion, optimizer, epochs=10):
112
- model.train()
113
- for epoch in range(epochs):
114
- for data in data_loader:
115
- optimizer.zero_grad()
116
- out = model(data.x, data.edge_index, data.batch)
117
- loss = criterion(out, data.y)
118
- loss.backward()
119
- optimizer.step()
120
- print(f'Epoch {epoch+1}, Loss: {loss.item()}')
121
-
122
- def evaluate(model, data_loader):
123
- model.eval()
124
- correct = 0
125
- total = 0
126
- with torch.no_grad():
127
- for data in data_loader:
128
- out = model(data.x, data.edge_index, data.batch)
129
- _, predicted = torch.max(out, 1)
130
- total += data.y.size(0)
131
- correct += (predicted == data.y).sum().item()
132
- accuracy = correct / total
133
- print(f'Accuracy: {accuracy * 100:.2f}%')
134
-
135
- def spacr_transformer(image_dir, seq_file, nr_grnas=1350, lr=0.001, mode='train'):
136
- images = load_images(image_dir)
137
-
138
- sequencing_data = load_sequencing_data(seq_file)
139
- wells = sequencing_data['well'].unique()
140
- graph_data = create_graph(wells, sequencing_data)
141
- model = GraphTransformer(in_channels=nr_grnas, hidden_channels=128, out_channels=nr_grnas)
142
- criterion = nn.CrossEntropyLoss()
143
- optimizer = Adam(model.parameters(), lr=lr)
144
- data_list = [graph_data]
145
- loader = DataLoader(data_list, batch_size=1, shuffle=True)
146
- if mode == 'train':
147
- train(model, loader, criterion, optimizer)
148
- elif mode == 'eval':
149
- evaluate(model, loader)
150
- else:
151
- raise ValueError('Invalid mode. Use "train" or "eval".')
152
-
153
- from skimage.feature import greycomatrix
154
-
155
- from skimage.feature import greycoprops
156
-
157
- def _calculate_glcm_features(intensity_image):
158
- glcm = greycomatrix(img_as_ubyte(intensity_image), distances=[1, 2, 3, 4], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4], symmetric=True, normed=True)
159
- features = {}
160
- for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM']:
161
- for i, distance in enumerate([1, 2, 3, 4]):
162
- for j, angle in enumerate([0, np.pi/4, np.pi/2, 3*np.pi/4]):
163
- features[f'glcm_{prop}_d{distance}_a{angle}'] = greycoprops(glcm, prop)[i, j]
164
- return features
165
-
166
- from skimage.feature import local_binary_pattern
167
-
168
- def _calculate_lbp_features(intensity_image, P=8, R=1):
169
- lbp = local_binary_pattern(intensity_image, P, R, method='uniform')
170
- lbp_hist, _ = np.histogram(lbp, density=True, bins=np.arange(0, P + 3), range=(0, P + 2))
171
- return {f'lbp_{i}': val for i, val in enumerate(lbp_hist)}
172
-
173
- def _calculate_wavelet_features(intensity_image, wavelet='db1'):
174
- coeffs = pywt.wavedec2(intensity_image, wavelet=wavelet, level=3)
175
- features = {}
176
- for i, coeff in enumerate(coeffs):
177
- if isinstance(coeff, tuple):
178
- for j, subband in enumerate(coeff):
179
- features[f'wavelet_coeff_{i}_{j}_mean'] = np.mean(subband)
180
- features[f'wavelet_coeff_{i}_{j}_std'] = np.std(subband)
181
- features[f'wavelet_coeff_{i}_{j}_energy'] = np.sum(subband**2)
182
- else:
183
- features[f'wavelet_coeff_{i}_mean'] = np.mean(coeff)
184
- features[f'wavelet_coeff_{i}_std'] = np.std(coeff)
185
- features[f'wavelet_coeff_{i}_energy'] = np.sum(coeff**2)
186
- return features
187
-
188
-
189
- from .measure import _estimate_blur, _calculate_correlation_object_level, _calculate_homogeneity, _periphery_intensity, _outside_intensity, _calculate_radial_distribution, _create_dataframe, _extended_regionprops_table, _calculate_correlation_object_level
190
-
191
- def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[3, 6, 12, 24], periphery=True, outside=True):
192
- radial_dist = settings['radial_dist']
193
- calculate_correlation = settings['calculate_correlation']
194
- homogeneity = settings['homogeneity']
195
- distances = settings['homogeneity_distances']
196
-
197
- intensity_props = ["label", "centroid_weighted", "centroid_weighted_local", "max_intensity", "mean_intensity", "min_intensity", "integrated_intensity"]
198
- additional_props = ["standard_deviation_intensity", "median_intensity", "sum_intensity", "intensity_range", "mean_absolute_deviation_intensity", "skewness_intensity", "kurtosis_intensity", "variance_intensity", "mode_intensity", "energy_intensity", "entropy_intensity", "harmonic_mean_intensity", "geometric_mean_intensity"]
199
- col_lables = ['region_label', 'mean', '5_percentile', '10_percentile', '25_percentile', '50_percentile', '75_percentile', '85_percentile', '95_percentile']
200
- cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs = [], [], [], []
201
- ls = ['cell','nucleus','pathogen','cytoplasm']
202
- labels = [cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask]
203
- dfs = [cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs]
204
-
205
- for i in range(0,channel_arrays.shape[-1]):
206
- channel = channel_arrays[:, :, i]
207
- for j, (label, df) in enumerate(zip(labels, dfs)):
208
-
209
- if np.max(label) == 0:
210
- empty_df = pd.DataFrame()
211
- df.append(empty_df)
212
- continue
213
-
214
- mask_intensity_df = _extended_regionprops_table(label, channel, intensity_props)
215
-
216
- # Additional intensity properties
217
- region_props = measure.regionprops_table(label, intensity_image=channel, properties=['label'])
218
- intensity_values = [channel[region.coords[:, 0], region.coords[:, 1]] for region in measure.regionprops(label)]
219
- additional_data = {prop: [] for prop in additional_props}
220
-
221
- for values in intensity_values:
222
- if len(values) == 0:
223
- continue
224
- additional_data["standard_deviation_intensity"].append(np.std(values))
225
- additional_data["median_intensity"].append(np.median(values))
226
- additional_data["sum_intensity"].append(np.sum(values))
227
- additional_data["intensity_range"].append(np.max(values) - np.min(values))
228
- additional_data["mean_absolute_deviation_intensity"].append(np.mean(np.abs(values - np.mean(values))))
229
- additional_data["skewness_intensity"].append(skew(values))
230
- additional_data["kurtosis_intensity"].append(kurtosis(values))
231
- additional_data["variance_intensity"].append(np.var(values))
232
- additional_data["mode_intensity"].append(mode(values)[0][0])
233
- additional_data["energy_intensity"].append(np.sum(values**2))
234
- additional_data["entropy_intensity"].append(entropy(values))
235
- additional_data["harmonic_mean_intensity"].append(hmean(values))
236
- additional_data["geometric_mean_intensity"].append(gmean(values))
237
-
238
- for prop in additional_props:
239
- region_props[prop] = additional_data[prop]
240
-
241
- additional_df = pd.DataFrame(region_props)
242
- mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), additional_df.reset_index(drop=True)], axis=1)
243
-
244
- if homogeneity:
245
- homogeneity_df = _calculate_homogeneity(label, channel, distances)
246
- mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), homogeneity_df], axis=1)
247
-
248
- if periphery:
249
- if ls[j] == 'nucleus' or ls[j] == 'pathogen':
250
- periphery_intensity_stats = _periphery_intensity(label, channel)
251
- mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(periphery_intensity_stats, columns=[f'periphery_{stat}' for stat in col_lables])],axis=1)
252
-
253
- if outside:
254
- if ls[j] == 'nucleus' or ls[j] == 'pathogen':
255
- outside_intensity_stats = _outside_intensity(label, channel)
256
- mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(outside_intensity_stats, columns=[f'outside_{stat}' for stat in col_lables])], axis=1)
257
-
258
- # Adding GLCM features
259
- glcm_features = _calculate_glcm_features(channel)
260
- for k, v in glcm_features.items():
261
- mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
262
-
263
- # Adding LBP features
264
- lbp_features = _calculate_lbp_features(channel)
265
- for k, v in lbp_features.items():
266
- mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
267
-
268
- # Adding Wavelet features
269
- wavelet_features = _calculate_wavelet_features(channel)
270
- for k, v in wavelet_features.items():
271
- mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
272
-
273
- blur_col = [_estimate_blur(channel[label == region_label]) for region_label in mask_intensity_df['label']]
274
- mask_intensity_df[f'{ls[j]}_channel_{i}_blur'] = blur_col
275
-
276
- mask_intensity_df.columns = [f'{ls[j]}_channel_{i}_{col}' if col != 'label' else col for col in mask_intensity_df.columns]
277
- df.append(mask_intensity_df)
278
-
279
- if radial_dist:
280
- if np.max(nucleus_mask) != 0:
281
- nucleus_radial_distributions = _calculate_radial_distribution(cell_mask, nucleus_mask, channel_arrays, num_bins=6)
282
- nucleus_df = _create_dataframe(nucleus_radial_distributions, 'nucleus')
283
- dfs[1].append(nucleus_df)
284
-
285
- if np.max(nucleus_mask) != 0:
286
- pathogen_radial_distributions = _calculate_radial_distribution(cell_mask, pathogen_mask, channel_arrays, num_bins=6)
287
- pathogen_df = _create_dataframe(pathogen_radial_distributions, 'pathogen')
288
- dfs[2].append(pathogen_df)
289
-
290
- if calculate_correlation:
291
- if channel_arrays.shape[-1] >= 2:
292
- for i in range(channel_arrays.shape[-1]):
293
- for j in range(i+1, channel_arrays.shape[-1]):
294
- chan_i = channel_arrays[:, :, i]
295
- chan_j = channel_arrays[:, :, j]
296
- for m, mask in enumerate(labels):
297
- coloc_df = _calculate_correlation_object_level(chan_i, chan_j, mask, settings)
298
- coloc_df.columns = [f'{ls[m]}_channel_{i}_channel_{j}_{col}' for col in coloc_df.columns]
299
- dfs[m].append(coloc_df)
300
-
301
- return pd.concat(cell_dfs, axis=1), pd.concat(nucleus_dfs, axis=1), pd.concat(pathogen_dfs, axis=1), pd.concat(cytoplasm_dfs, axis=1)
302
-
303
- def sample_and_copy_images(folder_list, nr_of_images, dst):
304
-
305
- if isinstance(folder_list, str):
306
- folder_list = [folder_list]
307
-
308
- # Create the destination folder if it does not exist
309
- if not os.path.exists(dst):
310
- os.makedirs(dst)
311
-
312
- # Calculate the number of images to sample from each folder
313
- nr_of_images_per_folder = nr_of_images // len(folder_list)
314
-
315
- print(f"Sampling {nr_of_images_per_folder} images from {len(folder_list)} folders...")
316
- # Initialize a list to hold the paths of the images to be copied
317
- images_to_copy = []
318
-
319
- for folder in folder_list:
320
- # Get a list of all files in the current folder
321
- all_files = [os.path.join(folder, file) for file in os.listdir(folder)]
322
-
323
- # Filter out non-image files
324
- image_files = [file for file in all_files if file.lower().endswith(('.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff', '.tif'))]
325
-
326
- # Sample images randomly from the list
327
- sampled_images = random.sample(image_files, min(nr_of_images_per_folder, len(image_files)))
328
-
329
- # Add the sampled images to the list of images to copy
330
- images_to_copy.extend(sampled_images)
331
-
332
- # Copy the sampled images to the destination folder
333
- for image in images_to_copy:
334
- shutil.copy(image, os.path.join(dst, os.path.basename(image)))
335
-
336
- import os
337
- import torch
338
- import torch.nn as nn
339
- import torch.nn.functional as F
340
- from collections import defaultdict
341
- from torch.utils.data import Dataset, DataLoader
342
- import pandas as pd
343
- import numpy as np
344
- import torch.optim as optim
345
-
346
- def generate_graphs(sequencing, scores, cell_min, gene_min_read):
347
- # Load and preprocess sequencing (gene) data
348
- gene_df = pd.read_csv(sequencing)
349
- gene_df = gene_df.rename(columns={'prc': 'well_id', 'grna': 'gene_id', 'count': 'read_count'})
350
- # Filter out genes with read counts less than gene_min_read
351
- gene_df = gene_df[gene_df['read_count'] >= gene_min_read]
352
- total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
353
- gene_df = gene_df.merge(total_reads_per_well, on='well_id')
354
- gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']
355
-
356
- # Load and preprocess cell score data
357
- cell_df = pd.read_csv(scores)
358
- cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})
359
-
360
- # Create a global mapping of gene IDs to indices
361
- unique_genes = gene_df['gene_id'].unique()
362
- gene_id_to_index = {gene_id: index for index, gene_id in enumerate(unique_genes)}
363
-
364
- graphs = []
365
- for well_id in pd.unique(gene_df['well_id']):
366
- well_genes = gene_df[gene_df['well_id'] == well_id]
367
- well_cells = cell_df[cell_df['well_id'] == well_id]
368
-
369
- # Skip wells with no cells or genes or with fewer cells than threshold
370
- if well_cells.empty or well_genes.empty or len(well_cells) < cell_min:
371
- continue
372
-
373
- # Initialize gene features tensor with zeros for all unique genes
374
- gene_features = torch.zeros((len(gene_id_to_index), 1), dtype=torch.float)
375
-
376
- # Update gene features tensor with well_read_fraction for genes present in this well
377
- for _, row in well_genes.iterrows():
378
- gene_index = gene_id_to_index[row['gene_id']]
379
- gene_features[gene_index] = torch.tensor([[row['well_read_fraction']]])
380
-
381
- # Prepare cell features (scores)
382
- cell_features = torch.tensor(well_cells['score'].values, dtype=torch.float).view(-1, 1)
383
-
384
- num_genes = len(gene_id_to_index)
385
- num_cells = cell_features.size(0)
386
- num_nodes = num_genes + num_cells
387
-
388
- # Create adjacency matrix connecting each cell to all genes in the well
389
- adj = torch.zeros((num_nodes, num_nodes), dtype=torch.float)
390
- for _, row in well_genes.iterrows():
391
- gene_index = gene_id_to_index[row['gene_id']]
392
- adj[num_genes:, gene_index] = 1
393
-
394
- graph = {
395
- 'adjacency_matrix': adj,
396
- 'gene_features': gene_features,
397
- 'cell_features': cell_features,
398
- 'num_cells': num_cells,
399
- 'num_genes': num_genes
400
- }
401
- graphs.append(graph)
402
-
403
- print(f'Generated dataset with {len(graphs)} graphs')
404
- return graphs, gene_id_to_index
405
-
406
- def print_graphs_info(graphs, gene_id_to_index):
407
- # Invert the gene_id_to_index mapping for easy lookup
408
- index_to_gene_id = {v: k for k, v in gene_id_to_index.items()}
409
-
410
- for i, graph in enumerate(graphs, start=1):
411
- print(f"Graph {i}:")
412
- num_genes = graph['num_genes']
413
- num_cells = graph['num_cells']
414
- gene_features = graph['gene_features']
415
- cell_features = graph['cell_features']
416
-
417
- print(f" Number of Genes: {num_genes}")
418
- print(f" Number of Cells: {num_cells}")
419
-
420
- # Identify genes present in the graph based on non-zero feature values
421
- present_genes = [index_to_gene_id[idx] for idx, feature in enumerate(gene_features) if feature.item() > 0]
422
- print(" Genes present in this Graph:", present_genes)
423
-
424
- # Display gene features for genes present in the graph
425
- print(" Gene Features:")
426
- for gene_id in present_genes:
427
- idx = gene_id_to_index[gene_id]
428
- print(f" {gene_id}: {gene_features[idx].item()}")
429
-
430
- # Display a sample of cell features, for brevity
431
- print(" Cell Features (sample):")
432
- for idx, feature in enumerate(cell_features[:min(5, len(cell_features))]):
433
- print(f"Cell {idx+1}: {feature.item()}")
434
-
435
- print("-" * 40)
436
-
437
- class Attention(nn.Module):
438
- def __init__(self, feature_dim, attn_dim, dropout_rate=0.1):
439
- super(Attention, self).__init__()
440
- self.query = nn.Linear(feature_dim, attn_dim)
441
- self.key = nn.Linear(feature_dim, attn_dim)
442
- self.value = nn.Linear(feature_dim, feature_dim)
443
- self.scale = 1.0 / (attn_dim ** 0.5)
444
- self.dropout = nn.Dropout(dropout_rate)
445
-
446
- def forward(self, gene_features, cell_features):
447
- # Queries come from the cell features
448
- q = self.query(cell_features)
449
- # Keys and values come from the gene features
450
- k = self.key(gene_features)
451
- v = self.value(gene_features)
452
-
453
- # Compute attention weights
454
- attn_weights = torch.matmul(q, k.transpose(-2, -1)) * self.scale
455
- attn_weights = F.softmax(attn_weights, dim=-1)
456
- # Apply dropout to attention weights
457
- attn_weights = self.dropout(attn_weights)
458
-
459
- # Apply attention weights to the values
460
- attn_output = torch.matmul(attn_weights, v)
461
-
462
- return attn_output, attn_weights
463
-
464
- class GraphTransformer(nn.Module):
465
- def __init__(self, gene_feature_size, cell_feature_size, hidden_dim, output_dim, attn_dim, dropout_rate=0.1):
466
- super(GraphTransformer, self).__init__()
467
- self.gene_transform = nn.Linear(gene_feature_size, hidden_dim)
468
- self.cell_transform = nn.Linear(cell_feature_size, hidden_dim)
469
- self.dropout = nn.Dropout(dropout_rate)
470
-
471
- # Attention layer to let each cell attend to all genes
472
- self.attention = Attention(hidden_dim, attn_dim)
473
-
474
- # This layer is used to transform the combined features after attention
475
- self.combine_transform = nn.Linear(2 * hidden_dim, hidden_dim)
476
-
477
- # Output layer for predicting cell scores, ensuring it matches the number of cells
478
- self.cell_output = nn.Linear(hidden_dim, output_dim)
479
-
480
- def forward(self, adjacency_matrix, gene_features, cell_features):
481
- # Apply initial transformation to gene and cell features
482
- transformed_gene_features = F.relu(self.gene_transform(gene_features))
483
- transformed_cell_features = F.relu(self.cell_transform(cell_features))
484
-
485
- # Incorporate attention mechanism
486
- attn_output, attn_weights = self.attention(transformed_gene_features, transformed_cell_features)
487
-
488
- # Combine the transformed cell features with the attention output features
489
- combined_cell_features = torch.cat((transformed_cell_features, attn_output), dim=1)
490
-
491
- # Apply dropout here as well
492
- combined_cell_features = self.dropout(combined_cell_features)
493
-
494
- combined_cell_features = F.relu(self.combine_transform(combined_cell_features))
495
-
496
- # Combine gene and cell features for message passing
497
- combined_features = torch.cat((transformed_gene_features, combined_cell_features), dim=0)
498
-
499
- # Apply message passing via adjacency matrix multiplication
500
- message_passed_features = torch.matmul(adjacency_matrix, combined_features)
501
-
502
- # Predict cell scores from the post-message passed cell features
503
- cell_scores = self.cell_output(message_passed_features[-cell_features.size(0):])
504
-
505
- return cell_scores, attn_weights
506
-
507
- def train_graph_transformer(graphs, lr=0.01, dropout_rate=0.1, weight_decay=0.00001, epochs=100, save_fldr='', acc_threshold = 0.1):
508
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
509
- model = GraphTransformer(gene_feature_size=1, cell_feature_size=1, hidden_dim=256, output_dim=1, attn_dim=128, dropout_rate=dropout_rate).to(device)
510
-
511
- criterion = nn.MSELoss()
512
- #optimizer = torch.optim.Adam(model.parameters(), lr=lr)
513
- optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
514
-
515
- training_log = []
516
-
517
- accumulate_grad_batches=1
518
- threshold=acc_threshold
519
-
520
- for epoch in range(epochs):
521
- model.train()
522
- total_loss = 0
523
- total_correct = 0
524
- total_samples = 0
525
- optimizer.zero_grad()
526
- batch_count = 0 # Initialize batch_count
527
-
528
- for graph in graphs:
529
- adjacency_matrix = graph['adjacency_matrix'].to(device)
530
- gene_features = graph['gene_features'].to(device)
531
- cell_features = graph['cell_features'].to(device)
532
- num_cells = graph['num_cells']
533
- predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
534
- predictions = predictions.squeeze()
535
- true_scores = cell_features[:num_cells, 0]
536
- loss = criterion(predictions, true_scores) / accumulate_grad_batches
537
- loss.backward()
538
-
539
- # Calculate "accuracy"
540
- with torch.no_grad():
541
- correct_predictions = (torch.abs(predictions - true_scores) / true_scores <= threshold).sum().item()
542
- total_correct += correct_predictions
543
- total_samples += num_cells
544
-
545
- batch_count += 1 # Increment batch_count
546
- if batch_count % accumulate_grad_batches == 0 or batch_count == len(graphs):
547
- optimizer.step()
548
- optimizer.zero_grad()
549
-
550
- total_loss += loss.item() * accumulate_grad_batches
551
-
552
- accuracy = total_correct / total_samples
553
- training_log.append({"Epoch": epoch+1, "Average Loss": total_loss / len(graphs), "Accuracy": accuracy})
554
- print(f"Epoch {epoch+1}, Loss: {total_loss / len(graphs)}, Accuracy: {accuracy}", end="\r", flush=True)
555
-
556
- # Save the training log and model as before
557
- os.makedirs(save_fldr, exist_ok=True)
558
- log_path = os.path.join(save_fldr, 'training_log.csv')
559
- training_log_df = pd.DataFrame(training_log)
560
- training_log_df.to_csv(log_path, index=False)
561
- print(f"Training log saved to {log_path}")
562
-
563
- model_path = os.path.join(save_fldr, 'model.pth')
564
- torch.save(model.state_dict(), model_path)
565
- print(f"Model saved to {model_path}")
566
-
567
- return model
568
-
569
- def annotate_cells_with_genes(graphs, model, gene_id_to_index):
570
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
571
- model.to(device)
572
- model.eval()
573
- annotated_data = []
574
-
575
- with torch.no_grad():
576
- for graph in graphs:
577
- adjacency_matrix = graph['adjacency_matrix'].to(device)
578
- gene_features = graph['gene_features'].to(device)
579
- cell_features = graph['cell_features'].to(device)
580
-
581
- predictions, attn_weights = model(adjacency_matrix, gene_features, cell_features)
582
- predictions = np.atleast_1d(predictions.squeeze().cpu().numpy())
583
- attn_weights = np.atleast_2d(attn_weights.squeeze().cpu().numpy())
584
-
585
- # This approach assumes all genes in gene_id_to_index are used in the model.
586
- # Create a list of gene IDs present in this specific graph.
587
- present_gene_ids = [key for key, value in gene_id_to_index.items() if value < gene_features.size(0)]
588
-
589
- for cell_idx in range(cell_features.size(0)):
590
- true_score = cell_features[cell_idx, 0].item()
591
- predicted_score = predictions[cell_idx]
592
-
593
- # Find the index of the most probable gene.
594
- most_probable_gene_idx = attn_weights[cell_idx].argmax()
595
-
596
- if len(present_gene_ids) > most_probable_gene_idx: # Ensure index is within the range
597
- most_probable_gene_id = present_gene_ids[most_probable_gene_idx]
598
- most_probable_gene_score = attn_weights[cell_idx, most_probable_gene_idx] if attn_weights.ndim > 1 else attn_weights[most_probable_gene_idx]
599
-
600
- annotated_data.append({
601
- "Cell ID": cell_idx,
602
- "Most Probable Gene": most_probable_gene_id,
603
- "Cell Score": true_score,
604
- "Predicted Cell Score": predicted_score,
605
- "Probability Score for Highest Gene": most_probable_gene_score
606
- })
607
- else:
608
- # Handle the case where the index is out of bounds - this should not happen but is here for robustness
609
- print("Error: Gene index out of bounds. This might indicate a mismatch in the model's output.")
610
-
611
- return pd.DataFrame(annotated_data)
612
-
613
- import torch
614
- import torch.nn as nn
615
- import torch.nn.functional as F
616
- from torch.utils.data import Dataset, DataLoader, TensorDataset
617
-
618
- # Let's assume that the feature embedding part and the dataset loading part
619
- # have already been taken care of, and your data is already in the format
620
- # suitable for PyTorch (i.e., Tensors).
621
-
622
class FeatureEmbedder(nn.Module):
    """Embed categorical feature codes and supply a learned 'visit' embedding.

    One ``nn.Embedding`` table is created per entry of ``vocab_sizes``. Each
    table has ``vocab_size + 1`` rows and uses index ``vocab_size`` as its
    padding index. A separate learnable ``(1, embedding_size)`` vector,
    initialised to zeros, serves as the 'visit' embedding.

    Args:
        vocab_sizes: Mapping of feature name to vocabulary size.
        embedding_size: Dimension of every embedding vector.
    """

    def __init__(self, vocab_sizes, embedding_size):
        super().__init__()
        self.embeddings = nn.ModuleDict({
            key: nn.Embedding(
                num_embeddings=vocab_size + 1,
                embedding_dim=embedding_size,
                padding_idx=vocab_size,
            )
            for key, vocab_size in vocab_sizes.items()
        })
        # nn.ModuleDict only accepts nn.Module values; storing an nn.Parameter
        # under a key raises TypeError, so the 'visit' vector is registered as
        # a parameter of its own.
        self.visit_embedding = nn.Parameter(torch.zeros(1, embedding_size))

    def forward(self, feature_map, max_num_codes):
        """Look up embeddings for every feature tensor in ``feature_map``.

        Args:
            feature_map: Mapping of feature name to an integer-valued tensor
                of codes; keys must match those given in ``vocab_sizes``.
            max_num_codes: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(embeddings, masks)`` of dicts keyed like ``feature_map``
            plus an extra ``'visit'`` entry in each.
        """
        embeddings = {}
        masks = {}
        batch_size = None
        for key, tensor in feature_map.items():
            embeddings[key] = self.embeddings[key](tensor.long())
            # NOTE(review): the mask is all ones, so padded positions are not
            # masked out here -- confirm whether callers expect that.
            masks[key] = torch.ones_like(tensor, dtype=torch.float32).unsqueeze(-1)
            # Assumes feature tensors are (batch, num_codes) -- TODO confirm.
            if batch_size is None and tensor.dim() >= 2:
                batch_size = tensor.size(0)

        # Fall back to a single-item batch when it cannot be inferred.
        if batch_size is None:
            batch_size = 1

        visit = self.visit_embedding.unsqueeze(0)  # (1, 1, embedding_size)
        embeddings['visit'] = visit.expand(batch_size, -1, -1)
        masks['visit'] = torch.ones(batch_size, 1, device=self.visit_embedding.device)

        return embeddings, masks
650
-
651
class GraphConvolutionalTransformer(nn.Module):
    """Stack of transformer encoder layers followed by a single-logit head.

    Args:
        embedding_size: Model width (``d_model``) of every encoder layer.
        num_attention_heads: Number of attention heads per encoder layer.
        **kwargs: ``num_transformer_stack`` (default 3) sets the number of
            encoder layers; other keys are ignored.
    """

    def __init__(self, embedding_size=128, num_attention_heads=1, **kwargs):
        super().__init__()
        num_layers = kwargs.get('num_transformer_stack', 3)
        # Transformer blocks
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embedding_size,
                nhead=num_attention_heads,
                batch_first=True,
            )
            for _ in range(num_layers)
        ])
        # Output layer for classification
        self.output_layer = nn.Linear(embedding_size, 1)

    def feedforward(self, features, mask=None, training=None):
        """Placeholder; intentionally does nothing and returns ``None``."""
        pass

    def forward(self, embeddings, masks, mask=None, training=False):
        """Encode the sequence and classify from its first position.

        Args:
            embeddings: Either a ``(batch, seq, embedding_size)`` tensor, or a
                dict of such tensors containing a ``'visit'`` entry. A dict is
                concatenated along the sequence axis with ``'visit'`` first,
                followed by the remaining entries in dict order.
            masks: Accepted for interface compatibility; not used.
            mask: Optional tensor multiplied into the features after every
                encoder layer; must broadcast against the features.
            training: Accepted for interface compatibility; not used.

        Returns:
            Tuple ``(logits, attentions)``; ``logits`` is ``(batch, 1)`` and
            ``attentions`` is always an empty list.

        Raises:
            KeyError: If ``embeddings`` is a dict without a ``'visit'`` key.
        """
        if isinstance(embeddings, dict):
            # Position 0 is read out below, so 'visit' must lead the sequence.
            parts = [embeddings['visit']]
            parts.extend(value for key, value in embeddings.items() if key != 'visit')
            features = torch.cat(parts, dim=1)
        else:
            features = embeddings

        attentions = []  # Kept for interface compatibility; never populated.

        # Pass through each transformer block.
        for layer in self.layers:
            features = layer(features)
            if mask is not None:
                features = features * mask

        # The first sequence position (the 'visit' slot) drives classification.
        logits = self.output_layer(features[:, 0, :])
        return logits, attentions
682
-
683
- # Usage Example
684
- #vocab_sizes = {'dx_ints':3249, 'proc_ints':2210}
685
- #embedding_size = 128
686
- #gct_params = {
687
- # 'embedding_size': embedding_size,
688
- # 'num_transformer_stack': 3,
689
- # 'num_attention_heads': 1
690
- #}
691
- #feature_embedder = FeatureEmbedder(vocab_sizes, embedding_size)
692
- #gct_model = GraphConvolutionalTransformer(**gct_params)
693
- #
694
- # Assume `feature_map` is a dictionary of tensors, and `max_num_codes` is provided
695
- #embeddings, masks = feature_embedder(feature_map, max_num_codes)
696
- #logits, attentions = gct_model(embeddings, masks)
697
-
698
- import torch
699
- import torchvision.transforms as transforms
700
- from torchvision.models import resnet50
701
- from PIL import Image
702
- import numpy as np
703
- import umap
704
- import pandas as pd
705
- from sklearn.ensemble import RandomForestClassifier
706
- from sklearn.preprocessing import StandardScaler
707
- from scipy.stats import f_oneway, kruskal
708
- from sklearn.cluster import KMeans
709
- from scipy import stats
710
-
711
def load_image(image_path):
    """Load an image file and return it as a preprocessed, batched tensor.

    The image is converted to RGB, resized to 224x224, turned into a tensor,
    normalized per channel, and given a leading batch dimension of size 1.

    Args:
        image_path: Path (or file-like object) accepted by ``PIL.Image.open``.

    Returns:
        The transformed image tensor with an extra leading dimension.
    """
    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        # Presumably the standard ImageNet channel statistics -- confirm.
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    # Image.open is lazy and may keep the underlying file open (notably for
    # multi-frame formats); the context manager guarantees it is closed.
    # convert() returns an independent copy that stays valid afterwards.
    with Image.open(image_path) as raw_image:
        image = raw_image.convert('RGB')
    return transform(image).unsqueeze(0)
721
-
722
def extract_features(image_paths, resnet=resnet50):
    """Compute one feature vector per image with a pre-trained ResNet backbone.

    Args:
        image_paths: Iterable of image paths, each passed to ``load_image``.
        resnet: Model constructor called as ``resnet(pretrained=True)``.

    Returns:
        ``np.ndarray`` built from the per-image feature arrays, in input order.
    """
    backbone = resnet(pretrained=True).eval()
    # Keep every child module except the last one (the classification layer).
    encoder = torch.nn.Sequential(*list(backbone.children())[:-1])

    collected = []
    with torch.no_grad():
        for path in image_paths:
            batch = load_image(path)
            collected.append(encoder(batch).squeeze().numpy())

    return np.array(collected)
736
-
737
def check_normality(series):
    """Return True unless ``stats.normaltest`` rejects normality at the 5% level.

    The null hypothesis of the test is that ``series`` comes from a normal
    distribution; it is rejected (False is returned) when p < 0.05.
    """
    significance = 0.05
    _, p_value = stats.normaltest(series)
    # Written as a negated "<" so that a NaN p-value still yields True.
    return not (p_value < significance)
744
-
745
def random_forest_feature_importance(all_df, cluster_col='cluster'):
    """Rank numeric columns by Random Forest importance for predicting the cluster.

    Args:
        all_df: DataFrame holding the features and the cluster column.
        cluster_col: Name of the column used as the classification target.

    Returns:
        DataFrame with columns ``'Feature'`` and ``'Importance'``, sorted by
        importance in descending order.
    """
    predictors = all_df.select_dtypes(include=[np.number]).columns.tolist()
    if cluster_col in predictors:
        # The target must not be used as one of its own predictors.
        predictors.remove(cluster_col)

    standardized = StandardScaler().fit_transform(all_df[predictors])
    target = all_df[cluster_col]

    forest = RandomForestClassifier(n_estimators=100, random_state=42)
    forest.fit(standardized, target)

    ranking = pd.DataFrame({
        'Feature': predictors,
        'Importance': forest.feature_importances_,
    })
    return ranking.sort_values(by='Importance', ascending=False)
768
-
769
def perform_statistical_tests(all_df, cluster_col='cluster'):
    """Compare each numeric column across clusters with ANOVA or Kruskal-Wallis.

    Normality is assessed with ``check_normality`` on the whole column (all
    clusters pooled). Columns judged normal are tested with ``f_oneway``;
    all others with ``kruskal``.

    Args:
        all_df: DataFrame holding the features and the cluster column.
        cluster_col: Name of the column that defines the groups.

    Returns:
        Tuple ``(anova_df, kruskal_df)``; each has a ``'Feature'`` column plus
        the statistic and p-value columns of the respective test.
    """
    columns_to_test = all_df.select_dtypes(include=[np.number]).columns.tolist()
    if cluster_col in columns_to_test:
        columns_to_test.remove(cluster_col)

    parametric_rows = []
    rank_based_rows = []

    for column in columns_to_test:
        # One sample of this column per distinct cluster label.
        samples = []
        for label in np.unique(all_df[cluster_col]):
            in_cluster = all_df[cluster_col] == label
            samples.append(all_df[in_cluster][column])

        if check_normality(all_df[column]):
            statistic, p_value = f_oneway(*samples)
            parametric_rows.append((column, statistic, p_value))
        else:
            statistic, p_value = kruskal(*samples)
            rank_based_rows.append((column, statistic, p_value))

    anova_df = pd.DataFrame(
        parametric_rows,
        columns=['Feature', 'ANOVA_Statistic', 'ANOVA_pValue'],
    )
    kruskal_df = pd.DataFrame(
        rank_based_rows,
        columns=['Feature', 'Kruskal_Statistic', 'Kruskal_pValue'],
    )
    return anova_df, kruskal_df
792
-
793
def combine_results(rf_df, anova_df, kruskal_df):
    """Left-join the ANOVA and Kruskal-Wallis tables onto the importance table.

    All three frames are matched on their ``'Feature'`` column; every row of
    ``rf_df`` is kept, and test columns are NaN where no match exists.
    """
    merged = rf_df
    for test_table in (anova_df, kruskal_df):
        merged = merged.merge(test_table, on='Feature', how='left')
    return merged
798
-
799
def cluster_feature_analysis(all_df, cluster_col='cluster'):
    """Run the full per-feature cluster analysis and return one combined table.

    Computes Random Forest feature importances, then ANOVA / Kruskal-Wallis
    test results, and merges them via ``combine_results``.

    Args:
        all_df: DataFrame holding the features and the cluster column.
        cluster_col: Name of the column that defines the clusters.

    Returns:
        The DataFrame produced by ``combine_results``.
    """
    importance_table = random_forest_feature_importance(all_df, cluster_col)
    test_tables = perform_statistical_tests(all_df, cluster_col)
    return combine_results(importance_table, *test_tables)