spacr 0.0.20__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/alpha.py CHANGED
@@ -1,18 +1,295 @@
- def gui_mask():
-     from .cli import get_arg_parser
-     from .version import version_str
+ from skimage import measure, feature
+ from skimage.filters import gabor
+ from skimage.color import rgb2gray
+ from skimage.feature.texture import greycomatrix, greycoprops, local_binary_pattern
+ from skimage.util import img_as_ubyte
+ import numpy as np
+ import pandas as pd
+ from scipy.stats import skew, kurtosis, entropy, hmean, gmean, mode
+ import pywt

-     args = get_arg_parser().parse_args()
+ import os
+ import pandas as pd
+ from PIL import Image
+ from torchvision import transforms
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_geometric.data import Data, DataLoader
+ from torch_geometric.nn import GCNConv, global_mean_pool
+ from torch.optim import Adam
+
+ # Step 1: Data Preparation
+
+ # Load images
+ def load_images(image_dir):
+     images = {}
+     for filename in os.listdir(image_dir):
+         if filename.endswith(".png"):
+             img = Image.open(os.path.join(image_dir, filename))
+             images[filename] = img
+     return images
+
+ # Load sequencing data
+ def load_sequencing_data(seq_file):
+     seq_data = pd.read_csv(seq_file)
+     return seq_data
+
+ # Step 2: Data Representation
+
+ # Image Representation (Using a simple CNN for feature extraction)
+ class CNNFeatureExtractor(nn.Module):
+     def __init__(self):
+         super(CNNFeatureExtractor, self).__init__()
+         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+         self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+         self.fc = nn.Linear(32 * 8 * 8, 128)  # Assuming input images are 64x64
+
+     def forward(self, x):
+         x = F.relu(self.conv1(x))
+         x = F.max_pool2d(x, 2)
+         x = F.relu(self.conv2(x))
+         x = F.max_pool2d(x, 2)
+         x = x.view(x.size(0), -1)
+         x = self.fc(x)
+         return x
+
+ # Graph Representation
+ def create_graph(wells, sequencing_data):
+     nodes = []
+     edges = []
+     node_features = []
+
+     for well in wells:
+         # Add node for each well
+         nodes.append(well)
+
+         # Get sequencing data for the well
+         seq_info = sequencing_data[sequencing_data['well'] == well]
+
+         # Create node features based on gene knockouts and abundances
+         features = torch.tensor(seq_info['abundance'].values, dtype=torch.float)
+         node_features.append(features)
+
+         # Define edges (for simplicity, assume fully connected graph)
+         for other_well in wells:
+             if other_well != well:
+                 edges.append((wells.index(well), wells.index(other_well)))
+
+     edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
+     x = torch.stack(node_features)
+
+     data = Data(x=x, edge_index=edge_index)
+     return data
+
+ # Step 3: Model Architecture
+
+ class GraphTransformer(nn.Module):
+     def __init__(self, in_channels, hidden_channels, out_channels):
+         super(GraphTransformer, self).__init__()
+         self.conv1 = GCNConv(in_channels, hidden_channels)
+         self.conv2 = GCNConv(hidden_channels, hidden_channels)
+         self.fc = nn.Linear(hidden_channels, out_channels)
+         self.attention = nn.MultiheadAttention(hidden_channels, num_heads=8)
+
+     def forward(self, x, edge_index, batch):
+         x = F.relu(self.conv1(x, edge_index))
+         x = F.relu(self.conv2(x, edge_index))
+
+         # Apply attention mechanism
+         x, _ = self.attention(x.unsqueeze(1), x.unsqueeze(1), x.unsqueeze(1))
+         x = x.squeeze(1)
+
+         x = global_mean_pool(x, batch)
+         x = self.fc(x)
+         return x
+
+ # Step 4: Training
+
+ # Training Loop
+ def train(model, data_loader, criterion, optimizer, epochs=10):
+     model.train()
+     for epoch in range(epochs):
+         for data in data_loader:
+             optimizer.zero_grad()
+             out = model(data.x, data.edge_index, data.batch)
+             loss = criterion(out, data.y)
+             loss.backward()
+             optimizer.step()
+         print(f'Epoch {epoch+1}, Loss: {loss.item()}')
+
+ def evaluate(model, data_loader):
+     model.eval()
+     correct = 0
+     total = 0
+     with torch.no_grad():
+         for data in data_loader:
+             out = model(data.x, data.edge_index, data.batch)
+             _, predicted = torch.max(out, 1)
+             total += data.y.size(0)
+             correct += (predicted == data.y).sum().item()
+     accuracy = correct / total
+     print(f'Accuracy: {accuracy * 100:.2f}%')
+
+ def spacr_transformer(image_dir, seq_file, nr_grnas=1350, lr=0.001, mode='train'):
+     images = load_images(image_dir)

-     if args.version:
-         print(version_str)
-         return
+     sequencing_data = load_sequencing_data(seq_file)
+     wells = sequencing_data['well'].unique()
+     graph_data = create_graph(wells, sequencing_data)
+     model = GraphTransformer(in_channels=nr_grnas, hidden_channels=128, out_channels=nr_grnas)
+     criterion = nn.CrossEntropyLoss()
+     optimizer = Adam(model.parameters(), lr=lr)
+     data_list = [graph_data]
+     loader = DataLoader(data_list, batch_size=1, shuffle=True)
+     if mode == 'train':
+         train(model, loader, criterion, optimizer)
+     elif mode == 'eval':
+         evaluate(model, loader)
+     else:
+         raise ValueError('Invalid mode. Use "train" or "eval".')

-     if args.headless:
-         settings = {}
-         spacr.core.preprocess_generate_masks(settings['src'], settings=settings, advanced_settings={})
-         return
+ def _calculate_glcm_features(intensity_image):
+     glcm = greycomatrix(img_as_ubyte(intensity_image), distances=[1, 2, 3, 4], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4], symmetric=True, normed=True)
+     features = {}
+     for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM']:
+         for i, distance in enumerate([1, 2, 3, 4]):
+             for j, angle in enumerate([0, np.pi/4, np.pi/2, 3*np.pi/4]):
+                 features[f'glcm_{prop}_d{distance}_a{angle}'] = greycoprops(glcm, prop)[i, j]
+     return features
+
+ def _calculate_lbp_features(intensity_image, P=8, R=1):
+     lbp = local_binary_pattern(intensity_image, P, R, method='uniform')
+     lbp_hist, _ = np.histogram(lbp, density=True, bins=np.arange(0, P + 3), range=(0, P + 2))
+     return {f'lbp_{i}': val for i, val in enumerate(lbp_hist)}
+
+ def _calculate_wavelet_features(intensity_image, wavelet='db1'):
+     coeffs = pywt.wavedec2(intensity_image, wavelet=wavelet, level=3)
+     features = {}
+     for i, coeff in enumerate(coeffs):
+         if isinstance(coeff, tuple):
+             for j, subband in enumerate(coeff):
+                 features[f'wavelet_coeff_{i}_{j}_mean'] = np.mean(subband)
+                 features[f'wavelet_coeff_{i}_{j}_std'] = np.std(subband)
+                 features[f'wavelet_coeff_{i}_{j}_energy'] = np.sum(subband**2)
+         else:
+             features[f'wavelet_coeff_{i}_mean'] = np.mean(coeff)
+             features[f'wavelet_coeff_{i}_std'] = np.std(coeff)
+             features[f'wavelet_coeff_{i}_energy'] = np.sum(coeff**2)
+     return features
+
+
+ from .measure import _estimate_blur, _calculate_correlation_object_level, _calculate_homogeneity, _periphery_intensity, _outside_intensity, _calculate_radial_distribution, _create_dataframe, _extended_regionprops_table, _calculate_correlation_object_level
+
+ def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[3, 6, 12, 24], periphery=True, outside=True):
+     radial_dist = settings['radial_dist']
+     calculate_correlation = settings['calculate_correlation']
+     homogeneity = settings['homogeneity']
+     distances = settings['homogeneity_distances']
+
+     intensity_props = ["label", "centroid_weighted", "centroid_weighted_local", "max_intensity", "mean_intensity", "min_intensity", "integrated_intensity"]
+     additional_props = ["standard_deviation_intensity", "median_intensity", "sum_intensity", "intensity_range", "mean_absolute_deviation_intensity", "skewness_intensity", "kurtosis_intensity", "variance_intensity", "mode_intensity", "energy_intensity", "entropy_intensity", "harmonic_mean_intensity", "geometric_mean_intensity"]
+     col_lables = ['region_label', 'mean', '5_percentile', '10_percentile', '25_percentile', '50_percentile', '75_percentile', '85_percentile', '95_percentile']
+     cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs = [], [], [], []
+     ls = ['cell','nucleus','pathogen','cytoplasm']
+     labels = [cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask]
+     dfs = [cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs]
+
+     for i in range(0,channel_arrays.shape[-1]):
+         channel = channel_arrays[:, :, i]
+         for j, (label, df) in enumerate(zip(labels, dfs)):
+
+             if np.max(label) == 0:
+                 empty_df = pd.DataFrame()
+                 df.append(empty_df)
+                 continue
+
+             mask_intensity_df = _extended_regionprops_table(label, channel, intensity_props)
+
+             # Additional intensity properties
+             region_props = measure.regionprops_table(label, intensity_image=channel, properties=['label'])
+             intensity_values = [channel[region.coords[:, 0], region.coords[:, 1]] for region in measure.regionprops(label)]
+             additional_data = {prop: [] for prop in additional_props}
+
+             for values in intensity_values:
+                 if len(values) == 0:
+                     continue
+                 additional_data["standard_deviation_intensity"].append(np.std(values))
+                 additional_data["median_intensity"].append(np.median(values))
+                 additional_data["sum_intensity"].append(np.sum(values))
+                 additional_data["intensity_range"].append(np.max(values) - np.min(values))
+                 additional_data["mean_absolute_deviation_intensity"].append(np.mean(np.abs(values - np.mean(values))))
+                 additional_data["skewness_intensity"].append(skew(values))
+                 additional_data["kurtosis_intensity"].append(kurtosis(values))
+                 additional_data["variance_intensity"].append(np.var(values))
+                 additional_data["mode_intensity"].append(mode(values)[0][0])
+                 additional_data["energy_intensity"].append(np.sum(values**2))
+                 additional_data["entropy_intensity"].append(entropy(values))
+                 additional_data["harmonic_mean_intensity"].append(hmean(values))
+                 additional_data["geometric_mean_intensity"].append(gmean(values))
+
+             for prop in additional_props:
+                 region_props[prop] = additional_data[prop]
+
+             additional_df = pd.DataFrame(region_props)
+             mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), additional_df.reset_index(drop=True)], axis=1)
+
+             if homogeneity:
+                 homogeneity_df = _calculate_homogeneity(label, channel, distances)
+                 mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), homogeneity_df], axis=1)
+
+             if periphery:
+                 if ls[j] == 'nucleus' or ls[j] == 'pathogen':
+                     periphery_intensity_stats = _periphery_intensity(label, channel)
+                     mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(periphery_intensity_stats, columns=[f'periphery_{stat}' for stat in col_lables])],axis=1)
+
+             if outside:
+                 if ls[j] == 'nucleus' or ls[j] == 'pathogen':
+                     outside_intensity_stats = _outside_intensity(label, channel)
+                     mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(outside_intensity_stats, columns=[f'outside_{stat}' for stat in col_lables])], axis=1)
+
+             # Adding GLCM features
+             glcm_features = _calculate_glcm_features(channel)
+             for k, v in glcm_features.items():
+                 mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+             # Adding LBP features
+             lbp_features = _calculate_lbp_features(channel)
+             for k, v in lbp_features.items():
+                 mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+             # Adding Wavelet features
+             wavelet_features = _calculate_wavelet_features(channel)
+             for k, v in wavelet_features.items():
+                 mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+             blur_col = [_estimate_blur(channel[label == region_label]) for region_label in mask_intensity_df['label']]
+             mask_intensity_df[f'{ls[j]}_channel_{i}_blur'] = blur_col
+
+             mask_intensity_df.columns = [f'{ls[j]}_channel_{i}_{col}' if col != 'label' else col for col in mask_intensity_df.columns]
+             df.append(mask_intensity_df)
+
+     if radial_dist:
+         if np.max(nucleus_mask) != 0:
+             nucleus_radial_distributions = _calculate_radial_distribution(cell_mask, nucleus_mask, channel_arrays, num_bins=6)
+             nucleus_df = _create_dataframe(nucleus_radial_distributions, 'nucleus')
+             dfs[1].append(nucleus_df)
+
+         if np.max(nucleus_mask) != 0:
+             pathogen_radial_distributions = _calculate_radial_distribution(cell_mask, pathogen_mask, channel_arrays, num_bins=6)
+             pathogen_df = _create_dataframe(pathogen_radial_distributions, 'pathogen')
+             dfs[2].append(pathogen_df)
+
+     if calculate_correlation:
+         if channel_arrays.shape[-1] >= 2:
+             for i in range(channel_arrays.shape[-1]):
+                 for j in range(i+1, channel_arrays.shape[-1]):
+                     chan_i = channel_arrays[:, :, i]
+                     chan_j = channel_arrays[:, :, j]
+                     for m, mask in enumerate(labels):
+                         coloc_df = _calculate_correlation_object_level(chan_i, chan_j, mask, settings)
+                         coloc_df.columns = [f'{ls[m]}_channel_{i}_channel_{j}_{col}' for col in coloc_df.columns]
+                         dfs[m].append(coloc_df)
+
+     return pd.concat(cell_dfs, axis=1), pd.concat(nucleus_dfs, axis=1), pd.concat(pathogen_dfs, axis=1), pd.concat(cytoplasm_dfs, axis=1)

-     global vars_dict, root
-     root, vars_dict = initiate_mask_root(1000, 1500)
-     root.mainloop()
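
The alpha-stage API added above is driven through spacr_transformer(), which loads per-well .png images, reads a sequencing CSV with 'well' and 'abundance' columns, builds a fully connected well graph, and then trains or evaluates the GraphTransformer. A minimal sketch of the intended call, with placeholder paths (illustrative only, not a tested workflow shipped with the package):

    from spacr.alpha import spacr_transformer

    # Placeholder locations: the image directory holds per-well .png files and the
    # CSV provides 'well' and 'abundance' columns, as read by load_sequencing_data/create_graph.
    spacr_transformer(image_dir='plate1/images', seq_file='plate1/sequencing.csv',
                      nr_grnas=1350, lr=0.001, mode='train')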
spacr/annotate_app.py CHANGED
@@ -38,7 +38,7 @@ class ImageApp:
      - db_update_thread (threading.Thread): A thread for updating the database.
      """

-     def _init_(self, root, db_path, image_type=None, channels=None, grid_rows=None, grid_cols=None, image_size=(200, 200), annotation_column='annotate'):
+     def __init__(self, root, db_path, image_type=None, channels=None, grid_rows=None, grid_cols=None, image_size=(200, 200), annotation_column='annotate'):
          """
          Initializes an instance of the ImageApp class.

@@ -383,7 +383,7 @@ def annotate(db, image_type=None, channels=None, geom="1000x1100", img_size=(200
      root = tk.Tk()
      root.geometry(geom)
      app = ImageApp(root, db, image_type=image_type, channels=channels, image_size=img_size, grid_rows=rows, grid_cols=columns, annotation_column=annotation_column)
-
+     #app = ImageApp()
      next_button = tk.Button(root, text="Next", command=app.next_page)
      next_button.grid(row=app.grid_rows, column=app.grid_cols - 1)
      back_button = tk.Button(root, text="Back", command=app.previous_page)
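
The annotate_app.py change renames the constructor from _init_ to __init__, so ImageApp(root, db, ...) now actually runs the initializer instead of falling back to object.__init__, which rejects the extra arguments. A minimal sketch of constructing the widget directly, assuming a placeholder annotation database and grid settings (illustrative only):

    import tkinter as tk
    from spacr.annotate_app import ImageApp

    root = tk.Tk()
    root.geometry("1000x1100")
    # 'measurements.db', the grid shape and image_type are placeholders; use the
    # annotation database and settings produced by the spacr measurement pipeline.
    app = ImageApp(root, 'measurements.db', image_type='png', channels=None,
                   grid_rows=5, grid_cols=5, image_size=(200, 200), annotation_column='annotate')
    root.mainloop()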