spacr 0.0.20__py3-none-any.whl → 0.0.35__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only.
spacr/alpha.py CHANGED
@@ -1,18 +1,295 @@
- def gui_mask():
-     from .cli import get_arg_parser
-     from .version import version_str
+ from skimage import measure, feature
+ from skimage.filters import gabor
+ from skimage.color import rgb2gray
+ from skimage.feature.texture import greycomatrix, greycoprops, local_binary_pattern
+ from skimage.util import img_as_ubyte
+ import numpy as np
+ import pandas as pd
+ from scipy.stats import skew, kurtosis, entropy, hmean, gmean, mode
+ import pywt

-     args = get_arg_parser().parse_args()
+ import os
+ import pandas as pd
+ from PIL import Image
+ from torchvision import transforms
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch_geometric.data import Data, DataLoader
+ from torch_geometric.nn import GCNConv, global_mean_pool
+ from torch.optim import Adam
+
+ # Step 1: Data Preparation
+
+ # Load images
+ def load_images(image_dir):
+     images = {}
+     for filename in os.listdir(image_dir):
+         if filename.endswith(".png"):
+             img = Image.open(os.path.join(image_dir, filename))
+             images[filename] = img
+     return images
+
+ # Load sequencing data
+ def load_sequencing_data(seq_file):
+     seq_data = pd.read_csv(seq_file)
+     return seq_data
+
+ # Step 2: Data Representation
+
+ # Image Representation (Using a simple CNN for feature extraction)
+ class CNNFeatureExtractor(nn.Module):
+     def __init__(self):
+         super(CNNFeatureExtractor, self).__init__()
+         self.conv1 = nn.Conv2d(3, 16, kernel_size=3, stride=1, padding=1)
+         self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1, padding=1)
+         self.fc = nn.Linear(32 * 8 * 8, 128)  # Assuming input images are 64x64
+
+     def forward(self, x):
+         x = F.relu(self.conv1(x))
+         x = F.max_pool2d(x, 2)
+         x = F.relu(self.conv2(x))
+         x = F.max_pool2d(x, 2)
+         x = x.view(x.size(0), -1)
+         x = self.fc(x)
+         return x
+
+ # Graph Representation
+ def create_graph(wells, sequencing_data):
+     nodes = []
+     edges = []
+     node_features = []
+
+     for well in wells:
+         # Add node for each well
+         nodes.append(well)
+
+         # Get sequencing data for the well
+         seq_info = sequencing_data[sequencing_data['well'] == well]
+
+         # Create node features based on gene knockouts and abundances
+         features = torch.tensor(seq_info['abundance'].values, dtype=torch.float)
+         node_features.append(features)
+
+         # Define edges (for simplicity, assume fully connected graph)
+         for other_well in wells:
+             if other_well != well:
+                 edges.append((wells.index(well), wells.index(other_well)))
+
+     edge_index = torch.tensor(edges, dtype=torch.long).t().contiguous()
+     x = torch.stack(node_features)
+
+     data = Data(x=x, edge_index=edge_index)
+     return data
+
+ # Step 3: Model Architecture
+
+ class GraphTransformer(nn.Module):
+     def __init__(self, in_channels, hidden_channels, out_channels):
+         super(GraphTransformer, self).__init__()
+         self.conv1 = GCNConv(in_channels, hidden_channels)
+         self.conv2 = GCNConv(hidden_channels, hidden_channels)
+         self.fc = nn.Linear(hidden_channels, out_channels)
+         self.attention = nn.MultiheadAttention(hidden_channels, num_heads=8)
+
+     def forward(self, x, edge_index, batch):
+         x = F.relu(self.conv1(x, edge_index))
+         x = F.relu(self.conv2(x, edge_index))
+
+         # Apply attention mechanism
+         x, _ = self.attention(x.unsqueeze(1), x.unsqueeze(1), x.unsqueeze(1))
+         x = x.squeeze(1)
+
+         x = global_mean_pool(x, batch)
+         x = self.fc(x)
+         return x
+
+ # Step 4: Training
+
+ # Training Loop
+ def train(model, data_loader, criterion, optimizer, epochs=10):
+     model.train()
+     for epoch in range(epochs):
+         for data in data_loader:
+             optimizer.zero_grad()
+             out = model(data.x, data.edge_index, data.batch)
+             loss = criterion(out, data.y)
+             loss.backward()
+             optimizer.step()
+         print(f'Epoch {epoch+1}, Loss: {loss.item()}')
+
+ def evaluate(model, data_loader):
+     model.eval()
+     correct = 0
+     total = 0
+     with torch.no_grad():
+         for data in data_loader:
+             out = model(data.x, data.edge_index, data.batch)
+             _, predicted = torch.max(out, 1)
+             total += data.y.size(0)
+             correct += (predicted == data.y).sum().item()
+     accuracy = correct / total
+     print(f'Accuracy: {accuracy * 100:.2f}%')
+
+ def spacr_transformer(image_dir, seq_file, nr_grnas=1350, lr=0.001, mode='train'):
+     images = load_images(image_dir)

-     if args.version:
-         print(version_str)
-         return
+     sequencing_data = load_sequencing_data(seq_file)
+     wells = sequencing_data['well'].unique()
+     graph_data = create_graph(wells, sequencing_data)
+     model = GraphTransformer(in_channels=nr_grnas, hidden_channels=128, out_channels=nr_grnas)
+     criterion = nn.CrossEntropyLoss()
+     optimizer = Adam(model.parameters(), lr=lr)
+     data_list = [graph_data]
+     loader = DataLoader(data_list, batch_size=1, shuffle=True)
+     if mode == 'train':
+         train(model, loader, criterion, optimizer)
+     elif mode == 'eval':
+         evaluate(model, loader)
+     else:
+         raise ValueError('Invalid mode. Use "train" or "eval".')

-     if args.headless:
-         settings = {}
-         spacr.core.preprocess_generate_masks(settings['src'], settings=settings, advanced_settings={})
-         return
+ def _calculate_glcm_features(intensity_image):
+     glcm = greycomatrix(img_as_ubyte(intensity_image), distances=[1, 2, 3, 4], angles=[0, np.pi/4, np.pi/2, 3*np.pi/4], symmetric=True, normed=True)
+     features = {}
+     for prop in ['contrast', 'dissimilarity', 'homogeneity', 'energy', 'correlation', 'ASM']:
+         for i, distance in enumerate([1, 2, 3, 4]):
+             for j, angle in enumerate([0, np.pi/4, np.pi/2, 3*np.pi/4]):
+                 features[f'glcm_{prop}_d{distance}_a{angle}'] = greycoprops(glcm, prop)[i, j]
+     return features
+
+ def _calculate_lbp_features(intensity_image, P=8, R=1):
+     lbp = local_binary_pattern(intensity_image, P, R, method='uniform')
+     lbp_hist, _ = np.histogram(lbp, density=True, bins=np.arange(0, P + 3), range=(0, P + 2))
+     return {f'lbp_{i}': val for i, val in enumerate(lbp_hist)}
+
+ def _calculate_wavelet_features(intensity_image, wavelet='db1'):
+     coeffs = pywt.wavedec2(intensity_image, wavelet=wavelet, level=3)
+     features = {}
+     for i, coeff in enumerate(coeffs):
+         if isinstance(coeff, tuple):
+             for j, subband in enumerate(coeff):
+                 features[f'wavelet_coeff_{i}_{j}_mean'] = np.mean(subband)
+                 features[f'wavelet_coeff_{i}_{j}_std'] = np.std(subband)
+                 features[f'wavelet_coeff_{i}_{j}_energy'] = np.sum(subband**2)
+         else:
+             features[f'wavelet_coeff_{i}_mean'] = np.mean(coeff)
+             features[f'wavelet_coeff_{i}_std'] = np.std(coeff)
+             features[f'wavelet_coeff_{i}_energy'] = np.sum(coeff**2)
+     return features
+
+
+ from .measure import _estimate_blur, _calculate_correlation_object_level, _calculate_homogeneity, _periphery_intensity, _outside_intensity, _calculate_radial_distribution, _create_dataframe, _extended_regionprops_table, _calculate_correlation_object_level
+
+ def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[3, 6, 12, 24], periphery=True, outside=True):
+     radial_dist = settings['radial_dist']
+     calculate_correlation = settings['calculate_correlation']
+     homogeneity = settings['homogeneity']
+     distances = settings['homogeneity_distances']
+
+     intensity_props = ["label", "centroid_weighted", "centroid_weighted_local", "max_intensity", "mean_intensity", "min_intensity", "integrated_intensity"]
+     additional_props = ["standard_deviation_intensity", "median_intensity", "sum_intensity", "intensity_range", "mean_absolute_deviation_intensity", "skewness_intensity", "kurtosis_intensity", "variance_intensity", "mode_intensity", "energy_intensity", "entropy_intensity", "harmonic_mean_intensity", "geometric_mean_intensity"]
+     col_lables = ['region_label', 'mean', '5_percentile', '10_percentile', '25_percentile', '50_percentile', '75_percentile', '85_percentile', '95_percentile']
+     cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs = [], [], [], []
+     ls = ['cell','nucleus','pathogen','cytoplasm']
+     labels = [cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask]
+     dfs = [cell_dfs, nucleus_dfs, pathogen_dfs, cytoplasm_dfs]
+
+     for i in range(0,channel_arrays.shape[-1]):
+         channel = channel_arrays[:, :, i]
+         for j, (label, df) in enumerate(zip(labels, dfs)):
+
+             if np.max(label) == 0:
+                 empty_df = pd.DataFrame()
+                 df.append(empty_df)
+                 continue
+
+             mask_intensity_df = _extended_regionprops_table(label, channel, intensity_props)
+
+             # Additional intensity properties
+             region_props = measure.regionprops_table(label, intensity_image=channel, properties=['label'])
+             intensity_values = [channel[region.coords[:, 0], region.coords[:, 1]] for region in measure.regionprops(label)]
+             additional_data = {prop: [] for prop in additional_props}
+
+             for values in intensity_values:
+                 if len(values) == 0:
+                     continue
+                 additional_data["standard_deviation_intensity"].append(np.std(values))
+                 additional_data["median_intensity"].append(np.median(values))
+                 additional_data["sum_intensity"].append(np.sum(values))
+                 additional_data["intensity_range"].append(np.max(values) - np.min(values))
+                 additional_data["mean_absolute_deviation_intensity"].append(np.mean(np.abs(values - np.mean(values))))
+                 additional_data["skewness_intensity"].append(skew(values))
+                 additional_data["kurtosis_intensity"].append(kurtosis(values))
+                 additional_data["variance_intensity"].append(np.var(values))
+                 additional_data["mode_intensity"].append(mode(values)[0][0])
+                 additional_data["energy_intensity"].append(np.sum(values**2))
+                 additional_data["entropy_intensity"].append(entropy(values))
+                 additional_data["harmonic_mean_intensity"].append(hmean(values))
+                 additional_data["geometric_mean_intensity"].append(gmean(values))
+
+             for prop in additional_props:
+                 region_props[prop] = additional_data[prop]
+
+             additional_df = pd.DataFrame(region_props)
+             mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), additional_df.reset_index(drop=True)], axis=1)
+
+             if homogeneity:
+                 homogeneity_df = _calculate_homogeneity(label, channel, distances)
+                 mask_intensity_df = pd.concat([mask_intensity_df.reset_index(drop=True), homogeneity_df], axis=1)
+
+             if periphery:
+                 if ls[j] == 'nucleus' or ls[j] == 'pathogen':
+                     periphery_intensity_stats = _periphery_intensity(label, channel)
+                     mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(periphery_intensity_stats, columns=[f'periphery_{stat}' for stat in col_lables])],axis=1)
+
+             if outside:
+                 if ls[j] == 'nucleus' or ls[j] == 'pathogen':
+                     outside_intensity_stats = _outside_intensity(label, channel)
+                     mask_intensity_df = pd.concat([mask_intensity_df, pd.DataFrame(outside_intensity_stats, columns=[f'outside_{stat}' for stat in col_lables])], axis=1)
+
+             # Adding GLCM features
+             glcm_features = _calculate_glcm_features(channel)
+             for k, v in glcm_features.items():
+                 mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+             # Adding LBP features
+             lbp_features = _calculate_lbp_features(channel)
+             for k, v in lbp_features.items():
+                 mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+             # Adding Wavelet features
+             wavelet_features = _calculate_wavelet_features(channel)
+             for k, v in wavelet_features.items():
+                 mask_intensity_df[f'{ls[j]}_channel_{i}_{k}'] = v
+
+             blur_col = [_estimate_blur(channel[label == region_label]) for region_label in mask_intensity_df['label']]
+             mask_intensity_df[f'{ls[j]}_channel_{i}_blur'] = blur_col
+
+             mask_intensity_df.columns = [f'{ls[j]}_channel_{i}_{col}' if col != 'label' else col for col in mask_intensity_df.columns]
+             df.append(mask_intensity_df)
+
+     if radial_dist:
+         if np.max(nucleus_mask) != 0:
+             nucleus_radial_distributions = _calculate_radial_distribution(cell_mask, nucleus_mask, channel_arrays, num_bins=6)
+             nucleus_df = _create_dataframe(nucleus_radial_distributions, 'nucleus')
+             dfs[1].append(nucleus_df)
+
+         if np.max(nucleus_mask) != 0:
+             pathogen_radial_distributions = _calculate_radial_distribution(cell_mask, pathogen_mask, channel_arrays, num_bins=6)
+             pathogen_df = _create_dataframe(pathogen_radial_distributions, 'pathogen')
+             dfs[2].append(pathogen_df)
+
+     if calculate_correlation:
+         if channel_arrays.shape[-1] >= 2:
+             for i in range(channel_arrays.shape[-1]):
+                 for j in range(i+1, channel_arrays.shape[-1]):
+                     chan_i = channel_arrays[:, :, i]
+                     chan_j = channel_arrays[:, :, j]
+                     for m, mask in enumerate(labels):
+                         coloc_df = _calculate_correlation_object_level(chan_i, chan_j, mask, settings)
+                         coloc_df.columns = [f'{ls[m]}_channel_{i}_channel_{j}_{col}' for col in coloc_df.columns]
+                         dfs[m].append(coloc_df)
+
+     return pd.concat(cell_dfs, axis=1), pd.concat(nucleus_dfs, axis=1), pd.concat(pathogen_dfs, axis=1), pd.concat(cytoplasm_dfs, axis=1)

-     global vars_dict, root
-     root, vars_dict = initiate_mask_root(1000, 1500)
-     root.mainloop()
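
The alpha.py rewrite above replaces the old gui_mask stub with an experimental graph-transformer pipeline (spacr_transformer) plus GLCM, LBP, and wavelet feature extractors. Below is a minimal sketch of how the new entry point might be called; the paths are placeholders, and the CSV layout ('well' and 'abundance' columns, one abundance vector of length nr_grnas per well) is inferred from load_sequencing_data and create_graph above, not a documented API.

# Hypothetical invocation of the new pipeline; paths are placeholders and the
# sequencing-CSV columns are assumptions read off the loaders above.
from spacr.alpha import spacr_transformer

spacr_transformer(
    image_dir='plate1/images',         # folder of per-well .png files
    seq_file='plate1/sequencing.csv',  # assumed 'well' and 'abundance' columns
    nr_grnas=1350,                     # must match each well's abundance-vector length
    lr=0.001,
    mode='train',                      # 'train' fits the GCN; 'eval' reports accuracy
)

Note that create_graph builds a fully connected graph over wells, so the edge count grows quadratically with the number of wells.
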
spacr/annotate_app.py CHANGED
@@ -13,7 +13,7 @@ from ttkthemes import ThemedTk

from .logger import log_function_call

- from .gui_utils import ScrollableFrame, set_default_font, set_dark_style, create_dark_mode
+ from .gui_utils import ScrollableFrame, set_default_font, set_dark_style, create_dark_mode, style_text_boxes, create_menu_bar

class ImageApp:
    """
@@ -38,7 +38,7 @@ class ImageApp:
    - db_update_thread (threading.Thread): A thread for updating the database.
    """

-     def _init_(self, root, db_path, image_type=None, channels=None, grid_rows=None, grid_cols=None, image_size=(200, 200), annotation_column='annotate'):
+     def __init__(self, root, db_path, image_type=None, channels=None, grid_rows=None, grid_cols=None, image_size=(200, 200), annotation_column='annotate'):
        """
        Initializes an instance of the ImageApp class.

@@ -383,7 +383,7 @@ def annotate(db, image_type=None, channels=None, geom="1000x1100", img_size=(200
    root = tk.Tk()
    root.geometry(geom)
    app = ImageApp(root, db, image_type=image_type, channels=channels, image_size=img_size, grid_rows=rows, grid_cols=columns, annotation_column=annotation_column)
-
+     #app = ImageApp()
    next_button = tk.Button(root, text="Next", command=app.next_page)
    next_button.grid(row=app.grid_rows, column=app.grid_cols - 1)
    back_button = tk.Button(root, text="Back", command=app.previous_page)
@@ -425,7 +425,8 @@ def initiate_annotation_app_root(width, height):
    root = ThemedTk(theme=theme)
    style = ttk.Style(root)
    set_dark_style(style)
-     set_default_font(root, font_name="Arial", size=10)
+     style_text_boxes(style)
+     set_default_font(root, font_name="Arial", size=8)
    root.geometry(f"{width}x{height}")
    root.title("Annotation App")

@@ -473,6 +474,7 @@ def initiate_annotation_app_root(width, height):
        new_root = tk.Tk()
        new_root.geometry(f"{width}x{height}")
        new_root.title("Mask Application")
+

        # Start the annotation application in the new root window
        app_instance = annotate(db, image_type, channels, annotation_column, geom, img_size, rows, columns)
@@ -482,7 +484,7 @@ def initiate_annotation_app_root(width, height):
    create_dark_mode(root, style, console_output=None)

    run_button = ttk.Button(scrollable_frame.scrollable_frame, text="Run", command=run_app)
-     run_button.grid(row=row, column=0, columnspan=2, pady=10)
+     run_button.grid(row=row, column=0, columnspan=2, pady=10, padx=10)

    return root

spacr/chris.py ADDED
@@ -0,0 +1,50 @@
+ import pandas as pd
+ import numpy as np
+ from .core import _permutation_importance, _shap_analysis
+
+ def join_measurments_and_annotation(src, tables = ['cell', 'nucleus', 'pathogen','cytoplasm']):
+
+     from .io import _read_and_merge_data, _read_db
+
+     db_loc = [src+'/measurements/measurements.db']
+     loc = src+'/measurements/measurements.db'
+     df, _ = _read_and_merge_data(db_loc,
+                                  tables,
+                                  verbose=True,
+                                  include_multinucleated=True,
+                                  include_multiinfected=True,
+                                  include_noninfected=True)
+
+     paths_df = _read_db(loc, tables=['png_list'])
+
+     merged_df = pd.merge(df, paths_df[0], on='prcfo', how='left')
+
+     return merged_df
+
+ def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c1', neg='c2', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
+     from .io import _read_and_merge_data
+     from .plot import _plot_plates
+
+     db_loc = [src+'/measurements/measurements.db']
+     tables = ['cell', 'nucleus', 'pathogen','cytoplasm']
+     include_multinucleated, include_multiinfected, include_noninfected = True, 2.0, True
+
+     df = join_measurments_and_annotation(src, tables=['cell', 'nucleus', 'pathogen', 'cytoplasm'])
+
+     if not channel_of_interest is None:
+         df['recruitment'] = df[f'pathogen_channel_{channel_of_interest}_mean_intensity']/df[f'cytoplasm_channel_{channel_of_interest}_mean_intensity']
+         feature_string = f'channel_{channel_of_interest}'
+     else:
+         feature_string = None
+
+     output = _permutation_importance(df, feature_string, col_to_compare, pos, neg, exclude, n_repeats, clean, nr_to_plot, n_estimators=n_estimators, random_state=42, model_type=model_type, n_jobs=n_jobs)
+
+     _shap_analysis(output[3], output[4], output[5])
+
+     features = output[0].select_dtypes(include=[np.number]).columns.tolist()
+
+     if not variable in features:
+         raise ValueError(f"Variable {variable} not found in the dataframe. Please choose one of the following: {features}")
+
+     plate_heatmap = _plot_plates(output[0], variable, grouping, min_max, cmap, min_count)
+     return [output, plate_heatmap]
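
The new chris.py module chains measurement merging, permutation importance, SHAP analysis, and plate plotting. A usage sketch follows, assuming src points at a screen folder containing measurements/measurements.db; the path is a placeholder, and the misspelled join_measurments_and_annotation is the function's name as released.

# Hypothetical call against a placeholder screen directory.
from spacr.chris import join_measurments_and_annotation, plate_heatmap

merged = join_measurments_and_annotation('path/to/plate1')

output, fig = plate_heatmap(    # returns the two-element list [output, plate_heatmap]
    'path/to/plate1',
    model_type='xgboost',
    variable='predictions',     # must be a numeric column of the merged data
    channel_of_interest=3,      # drives the pathogen/cytoplasm 'recruitment' ratio
)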