spacr 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registry.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +2 -2
- spacr/core.py +1301 -426
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/gui_mask_app.py +30 -10
- spacr/gui_utils.py +17 -2
- spacr/io.py +260 -102
- spacr/measure.py +150 -64
- spacr/plot.py +151 -12
- spacr/sim.py +666 -119
- spacr/timelapse.py +139 -9
- spacr/train.py +18 -10
- spacr/utils.py +43 -43
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/METADATA +5 -2
- spacr-0.0.21.dist-info/RECORD +33 -0
- spacr-0.0.20.dist-info/RECORD +0 -31
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/LICENSE +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/WHEEL +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.20.dist-info → spacr-0.0.21.dist-info}/top_level.txt +0 -0
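Since a wheel is a zip archive, a per-file comparison like the one summarized above can be reproduced locally. A minimal Python sketch, assuming both wheel files have already been downloaded into the working directory (for example with `pip download spacr==0.0.20 --no-deps`); everything else is standard-library tooling:

import difflib
import zipfile

def wheel_lines(whl_path, member):
    # Wheels are zip archives; read one file out of the archive as text lines.
    with zipfile.ZipFile(whl_path) as whl:
        return whl.read(member).decode('utf-8', errors='replace').splitlines()

old = wheel_lines('spacr-0.0.20-py3-none-any.whl', 'spacr/alpha.py')
new = wheel_lines('spacr-0.0.21-py3-none-any.whl', 'spacr/alpha.py')
print('\n'.join(difflib.unified_diff(old, new, 'spacr-0.0.20/alpha.py', 'spacr-0.0.21/alpha.py', lineterm='')))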
spacr/alpha.py
CHANGED
@@ -1,18 +1,295 @@
-
-
-
+from skimage import measure, feature
+from skimage.filters import gabor
+from skimage.color import rgb2gray
+from skimage.feature.texture import greycomatrix, greycoprops, local_binary_pattern
+from skimage.util import img_as_ubyte
+import numpy as np
+import pandas as pd
+from scipy.stats import skew, kurtosis, entropy, hmean, gmean, mode
+import pywt
 
-
+import os
+import pandas as pd
+from PIL import Image
+from torchvision import transforms
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch_geometric.data import Data, DataLoader
+from torch_geometric.nn import GCNConv, global_mean_pool
+from torch.optim import Adam
+
+# Step 1: Data Preparation
+
+# Load images
+def load_images(image_dir):
+    images = {}
+    for filename in os.listdir(image_dir):
+        if filename.endswith(".png"):
+            img = Image.open(os.path.join(image_dir, filename))
+            images[filename] = img
+    return images
+
+# Load sequencing data
+def load_sequencing_data(seq_file):
+    seq_data = pd.read_csv(seq_file)
+    return seq_data
+
+# Step 2: Data Representation
+
+# Image Representation (Using a simple CNN for feature extraction)
+class CNNFeatureExtractor(nn.Module):
+    def __init__(self):
+        super(CNNFeatureExtractor, self).__init__()
+        self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+        self.fc = nn.Linear(32 * 8 * 8, 128)  # Assuming input images are 64x64
+
+    def forward(self, x):
+        x = F.relu(self.conv1(x))
+        x = F.max_pool2d(x, 2)
+        x = F.relu(self.conv2(x))
+        x = F.max_pool2d(x, 2)
+        x = x.view(x.size(0), -1)
+        x = self.fc(x)
+        return x
+
+# Graph Representation
+def create_graph(wells, sequencing_data):
+    nodes = []
+    edges = []
+    node_features = []
+
+    for well in wells:
+        # Add node for each well
+        nodes.append(well)
+
+        # Get sequencing data for the well
+        seq_info = sequencing_data[sequencing_data['well'] == well]
+
+        # Create node features based on gene knockouts and abundances
+        features = torch.tensor(seq_info['abundance'].values, dtype=torch.float)
+        node_features.append(features)
+
+        # Define edges (for simplicity, assume fully connected graph)
+        for other_well in wells:
+            if other_well != well:
+                edges.append((wells.index(well), wells.index(other_well)))
+
+    edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
+    x = torch.stack(node_features)
+
+    data = Data(x=x, edge_index=edge_index)
+    return data
+
+# Step 3: Model Architecture
+
+class GraphTransformer(nn.Module):
+    def __init__(self, in_channels, hidden_channels, out_channels):
+        super(GraphTransformer, self).__init__()
+        self.conv1 = GCNConv(in_channels, hidden_channels)
+        self.conv2 = GCNConv(hidden_channels, hidden_channels)
+        self.fc = nn.Linear(hidden_channels, out_channels)
+        self.attention = nn.MultiheadAttention(hidden_channels, num_heads=8)
+
+    def forward(self, x, edge_index, batch):
+        x = F.relu(self.conv1(x, edge_index))
+        x = F.relu(self.conv2(x, edge_index))
+
+        # Apply attention mechanism
+        x, _ = self.attention(x.unsqueeze(1), x.unsqueeze(1), x.unsqueeze(1))
+        x = x.squeeze(1)
+
+        x = global_mean_pool(x, batch)
+        x = self.fc(x)
+        return x
+
+# Step 4: Training
+
+# Training Loop
+def train(model, data_loader, criterion, optimizer, epochs=10):
+    model.train()
+    for epoch in range(epochs):
+        for data in data_loader:
+            optimizer.zero_grad()
+            out = model(data.x, data.edge_index, data.batch)
+            loss = criterion(out, data.y)
+            loss.backward()
+            optimizer.step()
+        print(f'Epoch {epoch+1}, Loss: {loss.item()}')
+
+def evaluate(model, data_loader):
+    model.eval()
+    correct = 0
+    total = 0
+    with torch.no_grad():
+        for data in data_loader:
+            out = model(data.x, data.edge_index, data.batch)
+            _, predicted = torch.max(out, 1)
+            total += data.y.size(0)
+            correct += (predicted == data.y).sum().item()
+    accuracy = correct / total
+    print(f'Accuracy: {accuracy * 100:.2f}%')
+
+def spacr_transformer(image_dir, seq_file, nr_grnas=1350, lr=0.001, mode='train'):
+    images = load_images(image_dir)
 
-
-
-
+    sequencing_data = load_sequencing_data(seq_file)
+    wells = sequencing_data['well'].unique()
+    graph_data = create_graph(wells, sequencing_data)
+    model = GraphTransformer(in_channels=nr_grnas, hidden_channels=128, out_channels=nr_grnas)
+    criterion = nn.CrossEntropyLoss()
+    optimizer = Adam(model.parameters(), lr=lr)
+    data_list = [graph_data]
+    loader = DataLoader(data_list, batch_size=1, shuffle=True)
+    if mode == 'train':
+        train(model, loader, criterion, optimizer)
+    elif mode == 'eval':
+        evaluate(model, loader)
+    else:
+        raise ValueError('Invalid mode. Use "train" or "eval".')
 
-
-
-
-
+def _calculate_glcm_features(intensity_image):
+    glcm = greycomatrix(img_as_ubyte(intensity_image), distances=[1, 2, 3, 4], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4], symmetric=True, normed=True)
+    features = {}
+    for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM']:
+        for i, distance in enumerate([1, 2, 3, 4]):
+            for j, angle in enumerate([0, np.pi/4, np.pi/2, 3*np.pi/4]):
+                features[f'glcm_{prop}_d{distance}_a{angle}'] = greycoprops(glcm, prop)[i, j]
+    return features
+
+def _calculate_lbp_features(intensity_image, P=8, R=1):
+    lbp = local_binary_pattern(intensity_image, P, R, method='uniform')
+    lbp_hist, _ = np.histogram(lbp, density=True, bins=np.arange(0, P + 3), range=(0, P + 2))
+    return {f'lbp_{i}': val for i, val in enumerate(lbp_hist)}
+
+def _calculate_wavelet_features(intensity_image, wavelet='db1'):
+    coeffs = pywt.wavedec2(intensity_image, wavelet=wavelet, level=3)
+    features = {}
+    for i, coeff in enumerate(coeffs):
+        if isinstance(coeff, tuple):
+            for j, subband in enumerate(coeff):
+                features[f'wavelet_coeff_{i}_{j}_mean'] = np.mean(subband)
+                features[f'wavelet_coeff_{i}_{j}_std'] = np.std(subband)
+                features[f'wavelet_coeff_{i}_{j}_energy'] = np.sum(subband**2)
+        else:
+            features[f'wavelet_coeff_{i}_mean'] = np.mean(coeff)
+            features[f'wavelet_coeff_{i}_std'] = np.std(coeff)
+            features[f'wavelet_coeff_{i}_energy'] = np.sum(coeff**2)
+    return features
+
+
+from .measure import _estimate_blur, _calculate_correlation_object_level, _calculate_homogeneity, _periphery_intensity, _outside_intensity, _calculate_radial_distribution, _create_dataframe, _extended_regionprops_table, _calculate_correlation_object_level
+
+def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[3, 6, 12, 24], periphery=True, outside=True):
+    radial_dist = settings['radial_dist']
+    calculate_correlation = settings['calculate_correlation']
+    homogeneity = settings['homogeneity']
+    distances = settings['homogeneity_distances']
+
+    intensity_props = ["label", "centroid_weighted", "centroid_weighted_local", "max_intensity", "mean_intensity", "min_intensity", "integrated_intensity"]
+    additional_props = ["standard_deviation_intensity", "median_intensity", "sum_intensity", "intensity_range", "mean_absolute_deviation_intensity", "skewness_intensity", "kurtosis_intensity", "variance_intensity", "mode_intensity", "energy_intensity", "entropy_intensity", "harmonic_mean_intensity", "geometric_mean_intensity"]
+    col_lables = ['region_label', 'mean', '5_percentile', '10_percentile', '25_percentile', '50_percentile', '75_percentile', '85_percentile', '95_percentile']
+    cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs = [], [], [], []
+    ls = ['cell','nucleus','pathogen','cytoplasm']
+    labels = [cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask]
+    dfs = [cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs]
+
+    for i in range(0,channel_arrays.shape[-1]):
+        channel = channel_arrays[:, :, i]
+        for j, (label, df) in enumerate(zip(labels, dfs)):
+
+            if np.max(label) == 0:
+                empty_df = pd.DataFrame()
+                df.append(empty_df)
+                continue
+
+            mask_intensity_df = _extended_regionprops_table(label, channel, intensity_props)
+
+            # Additional intensity properties
+            region_props = measure.regionprops_table(label, intensity_image=channel, properties=['label'])
+            intensity_values = [channel[region.coords[:, 0], region.coords[:, 1]] for region in measure.regionprops(label)]
+            additional_data = {prop: [] for prop in additional_props}
+
+            for values in intensity_values:
+                if len(values) == 0:
+                    continue
+                additional_data["standard_deviation_intensity"].append(np.std(values))
+                additional_data["median_intensity"].append(np.median(values))
+                additional_data["sum_intensity"].append(np.sum(values))
+                additional_data["intensity_range"].append(np.max(values) - np.min(values))
+                additional_data["mean_absolute_deviation_intensity"].append(np.mean(np.abs(values - np.mean(values))))
+                additional_data["skewness_intensity"].append(skew(values))
+                additional_data["kurtosis_intensity"].append(kurtosis(values))
+                additional_data["variance_intensity"].append(np.var(values))
+                additional_data["mode_intensity"].append(mode(values)[0][0])
+                additional_data["energy_intensity"].append(np.sum(values**2))
+                additional_data["entropy_intensity"].append(entropy(values))
+                additional_data["harmonic_mean_intensity"].append(hmean(values))
+                additional_data["geometric_mean_intensity"].append(gmean(values))
+
+            for prop in additional_props:
+                region_props[prop] = additional_data[prop]
+
+            additional_df = pd.DataFrame(region_props)
+            mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), additional_df.reset_index(drop=True)], axis=1)
+
+            if homogeneity:
+                homogeneity_df = _calculate_homogeneity(label, channel, distances)
+                mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), homogeneity_df], axis=1)
+
+            if periphery:
+                if ls[j] == 'nucleus' or ls[j] == 'pathogen':
+                    periphery_intensity_stats = _periphery_intensity(label, channel)
+                    mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(periphery_intensity_stats, columns=[f'periphery_{stat}' for stat in col_lables])],axis=1)
+
+            if outside:
+                if ls[j] == 'nucleus' or ls[j] == 'pathogen':
+                    outside_intensity_stats = _outside_intensity(label, channel)
+                    mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(outside_intensity_stats, columns=[f'outside_{stat}' for stat in col_lables])], axis=1)
+
+            # Adding GLCM features
+            glcm_features = _calculate_glcm_features(channel)
+            for k, v in glcm_features.items():
+                mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+            # Adding LBP features
+            lbp_features = _calculate_lbp_features(channel)
+            for k, v in lbp_features.items():
+                mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+            # Adding Wavelet features
+            wavelet_features = _calculate_wavelet_features(channel)
+            for k, v in wavelet_features.items():
+                mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+            blur_col = [_estimate_blur(channel[label == region_label]) for region_label in mask_intensity_df['label']]
+            mask_intensity_df[f'{ls[j]}_channel_{i}_blur'] = blur_col
+
+            mask_intensity_df.columns = [f'{ls[j]}_channel_{i}_{col}' if col != 'label' else col for col in mask_intensity_df.columns]
+            df.append(mask_intensity_df)
+
+    if radial_dist:
+        if np.max(nucleus_mask) != 0:
+            nucleus_radial_distributions = _calculate_radial_distribution(cell_mask, nucleus_mask, channel_arrays, num_bins=6)
+            nucleus_df = _create_dataframe(nucleus_radial_distributions, 'nucleus')
+            dfs[1].append(nucleus_df)
+
+        if np.max(nucleus_mask) != 0:
+            pathogen_radial_distributions = _calculate_radial_distribution(cell_mask, pathogen_mask, channel_arrays, num_bins=6)
+            pathogen_df = _create_dataframe(pathogen_radial_distributions, 'pathogen')
+            dfs[2].append(pathogen_df)
+
+    if calculate_correlation:
+        if channel_arrays.shape[-1] >= 2:
+            for i in range(channel_arrays.shape[-1]):
+                for j in range(i+1, channel_arrays.shape[-1]):
+                    chan_i = channel_arrays[:, :, i]
+                    chan_j = channel_arrays[:, :, j]
+                    for m, mask in enumerate(labels):
+                        coloc_df = _calculate_correlation_object_level(chan_i, chan_j, mask, settings)
+                        coloc_df.columns = [f'{ls[m]}_channel_{i}_channel_{j}_{col}' for col in coloc_df.columns]
+                        dfs[m].append(coloc_df)
+
+    return pd.concat(cell_dfs, axis=1), pd.concat(nucleus_dfs, axis=1), pd.concat(pathogen_dfs, axis=1), pd.concat(cytoplasm_dfs, axis=1)
 
-global vars_dict, root
-root, vars_dict = initiate_mask_root(1000, 1500)
-root.mainloop()
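The Step 1–4 pipeline added above can be exercised on toy data without any images, since create_graph and GraphTransformer only consume the sequencing table. A minimal sketch, assuming spacr 0.0.21 with torch and torch_geometric installed and the module's own imports resolving; the well/abundance values below are made up:

import pandas as pd
import torch
from spacr.alpha import create_graph, GraphTransformer

# Toy sequencing table: three wells with two gRNA abundances each, shaped the
# way load_sequencing_data/create_graph expect ('well' and 'abundance' columns).
seq = pd.DataFrame({
    'well':      ['A1', 'A1', 'A2', 'A2', 'A3', 'A3'],
    'abundance': [0.6, 0.4, 0.1, 0.9, 0.5, 0.5],
})

wells = list(seq['well'].unique())             # create_graph calls wells.index(), so pass a list
graph = create_graph(wells, seq)               # fully connected 3-node graph
print(graph.x.shape, graph.edge_index.shape)   # torch.Size([3, 2]) torch.Size([2, 6])

# Two abundance values per well -> in_channels=2 for this toy run.
model = GraphTransformer(in_channels=2, hidden_channels=128, out_channels=2)
batch = torch.zeros(3, dtype=torch.long)       # all nodes belong to one graph
out = model(graph.x, graph.edge_index, batch)
print(out.shape)                               # torch.Size([1, 2])

Note that create_graph stacks one feature vector per well, so every well must contribute the same number of abundance rows; spacr_transformer's default in_channels=nr_grnas=1350 accordingly implies 1,350 abundance values per well.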
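One portability caveat in the new imports: greycomatrix and greycoprops are imported from skimage.feature.texture, but scikit-image 0.19 renamed them to graycomatrix/graycoprops (exposed from skimage.feature) and later removed the old spellings, so this import fails on current scikit-image releases. A version-tolerant import sketch, not part of the package, followed by a toy call mirroring _calculate_glcm_features:

import numpy as np
from skimage.util import img_as_ubyte

try:
    # scikit-image >= 0.19: gray* names in skimage.feature
    from skimage.feature import graycomatrix as greycomatrix, graycoprops as greycoprops
except ImportError:
    # Older scikit-image: legacy grey* spellings
    from skimage.feature.texture import greycomatrix, greycoprops

rng = np.random.default_rng(0)
toy = rng.random((64, 64))                   # toy intensity image in [0, 1]
glcm = greycomatrix(img_as_ubyte(toy), distances=[1], angles=[0], symmetric=True, normed=True)
print(greycoprops(glcm, 'contrast')[0, 0])   # one value per (distance, angle) pair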
spacr/annotate_app.py
CHANGED
@@ -38,7 +38,7 @@ class ImageApp:
         - db_update_thread (threading.Thread): A thread for updating the database.
     """
 
-    def
+    def __init__(self, root, db_path, image_type=None, channels=None, grid_rows=None, grid_cols=None, image_size=(200, 200), annotation_column='annotate'):
         """
         Initializes an instance of the ImageApp class.
 
@@ -383,7 +383,7 @@ def annotate(db, image_type=None, channels=None, geom="1000x1100", img_size=(200
     root = tk.Tk()
     root.geometry(geom)
     app = ImageApp(root, db, image_type=image_type, channels=channels, image_size=img_size, grid_rows=rows, grid_cols=columns, annotation_column=annotation_column)
-
+    #app = ImageApp()
     next_button = tk.Button(root, text="Next", command=app.next_page)
     next_button.grid(row=app.grid_rows, column=app.grid_cols - 1)
     back_button = tk.Button(root, text="Back", command=app.previous_page)