spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff shows the changes between publicly available package versions as released to one of the supported registries. It is provided for informational purposes only.
- spacr/__init__.py +6 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +807 -0
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/cli.py +25 -187
- spacr/core.py +1611 -389
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +320 -0
- spacr/graph_learning_lap.py +84 -0
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +187 -0
- spacr/gui_mask_app.py +149 -174
- spacr/gui_measure_app.py +116 -109
- spacr/gui_sim_app.py +0 -0
- spacr/gui_utils.py +679 -139
- spacr/io.py +620 -469
- spacr/mask_app.py +116 -9
- spacr/measure.py +178 -84
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +255 -1
- spacr/plot.py +263 -100
- spacr/sequencing.py +1130 -0
- spacr/sim.py +634 -122
- spacr/timelapse.py +343 -53
- spacr/train.py +195 -22
- spacr/umap.py +0 -689
- spacr/utils.py +1530 -188
- spacr-0.0.6.dist-info/METADATA +118 -0
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.1.dist-info/METADATA +0 -64
- spacr-0.0.1.dist-info/RECORD +0 -26
- spacr-0.0.1.dist-info/entry_points.txt +0 -5
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/get_alfafold_structures.py
ADDED
@@ -0,0 +1,72 @@

import csv
import os
import requests

def download_alphafold_structures(tsv_location, dst, version="4"):
    # Create the destination directories if they do not exist
    dst_pdb = os.path.join(dst, 'pdb')
    dst_cif = os.path.join(dst, 'cif')
    dst_pae = os.path.join(dst, 'pae')
    for folder in (dst, dst_pdb, dst_cif, dst_pae):
        os.makedirs(folder, exist_ok=True)

    failed_downloads = []  # List to keep track of failed downloads

    # Open the TSV file and read entries
    with open(tsv_location, 'r') as tsv_file:
        reader = csv.DictReader(tsv_file, delimiter='\t')
        for row in reader:
            entry = row['Entry']
            af_link = f"https://alphafold.ebi.ac.uk/files/AF-{entry}-F1-model_v{version}.pdb"
            cif_link = f"https://alphafold.ebi.ac.uk/files/AF-{entry}-F1-model_v{version}.cif"
            pae_link = f"https://alphafold.ebi.ac.uk/files/AF-{entry}-F1-predicted_aligned_error_v{version}.json"

            try:
                response_pdb = requests.get(af_link, stream=True)
                response_cif = requests.get(cif_link, stream=True)
                response_pae = requests.get(pae_link, stream=True)
                # Only the PDB response status is checked; the CIF and PAE
                # downloads are assumed to succeed or fail together with it.
                if response_pdb.status_code == 200:
                    # Save the PDB file
                    with open(os.path.join(dst_pdb, f"AF-{entry}-F1-model_v{version}.pdb"), 'wb') as pdb_file:
                        pdb_file.write(response_pdb.content)
                    print(f"Downloaded: AF-{entry}-F1-model_v{version}.pdb")

                    # Save the CIF file
                    with open(os.path.join(dst_cif, f"AF-{entry}-F1-model_v{version}.cif"), 'wb') as cif_file:
                        cif_file.write(response_cif.content)
                    print(f"Downloaded: AF-{entry}-F1-model_v{version}.cif")

                    # Save the PAE file
                    with open(os.path.join(dst_pae, f"AF-{entry}-F1-predicted_aligned_error_v{version}.json"), 'wb') as pae_file:
                        pae_file.write(response_pae.content)
                    print(f"Downloaded: AF-{entry}-F1-predicted_aligned_error_v{version}.json")
                else:
                    # If the file could not be downloaded, record the entry
                    failed_downloads.append(entry)
                    print(f"Failed to download structure for: {entry}")
            except Exception as e:
                print(f"Error downloading structure for {entry}: {e}")
                failed_downloads.append(entry)

    # Save the list of failed downloads to a CSV file in the destination folder
    if failed_downloads:
        with open(os.path.join(dst, 'failed_downloads.csv'), 'w', newline='') as failed_file:
            writer = csv.writer(failed_file)
            writer.writerow(['Entry'])
            for entry in failed_downloads:
                writer.writerow([entry])
        print(f"Failed download entries saved to: {os.path.join(dst, 'failed_downloads.csv')}")

# Example usage:
tsv_location = '/home/carruthers/Downloads/GT1_proteome/GT1_proteins_uniprot.tsv'  # Replace with the path to your TSV file of UniProt entries
dst_folder = '/home/carruthers/Downloads/GT1_proteome'  # Replace with your destination folder
download_alphafold_structures(tsv_location, dst_folder)
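Note that download_alphafold_structures checks only the PDB response before writing all three files, so a failed CIF or PAE request could be saved as an error page. A minimal, hypothetical helper (not part of the package) that validates each download independently, using only plain requests calls:

import requests

def fetch_if_ok(url, out_path, timeout=30):
    """Download url to out_path; return True only on HTTP 200."""
    try:
        response = requests.get(url, timeout=timeout)
    except requests.exceptions.RequestException:
        return False
    if response.status_code != 200:
        return False
    with open(out_path, 'wb') as fh:
        fh.write(response.content)
    return True

Each of the three AlphaFold URLs could then be routed through this helper, recording the entry in failed_downloads whenever any one of them returns False.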
spacr/graph_learning.py
ADDED
@@ -0,0 +1,320 @@

import os
os.environ['DGLBACKEND'] = 'pytorch'
import torch, dgl
import pandas as pd
import torch.nn as nn
from torchvision import datasets, transforms
from sklearn.preprocessing import StandardScaler
from PIL import Image
import dgl.nn.pytorch as dglnn
from sklearn.datasets import make_classification
from .utils import SelectChannels

# approach outline
#
# 1. Data Preparation:
#    Test Mode: Load MNIST data and generate synthetic gRNA data.
#    Real Data: Load image paths and sequencing data as fractions.
#
# 2. Graph Construction:
#    Each well is represented as a graph.
#    Each graph has cell nodes (with image features) and gRNA nodes (with gRNA fraction features).
#    Each cell node is connected to each gRNA node within the same well.
#
# 3. Model Training:
#    Use an encoder-decoder architecture with the Graph Transformer model.
#    The encoder processes the cell and gRNA nodes.
#    The decoder outputs the phenotype score for each cell node.
#    The model is trained on all wells (including positive and negative controls).
#    The model learns to score the gRNA in column 1 (negative control) as 0 and the gRNA in column 2 (positive control) as 1 based on the cell features.
#
# 4. Model Application:
#    Apply the trained model to all wells to get classification probabilities.
#
# 5. Evaluation:
#    Evaluate the model's performance using the control wells.
#
# 6. Association Analysis:
#    Analyze the association between gRNAs and the classification scores.
#
# The model learns the associations between cell features and phenotype scores based on the controls and then generalizes this learning to the screening wells.

# Load MNIST data for testing
def load_mnist_data():
    transform = transforms.Compose([
        transforms.Resize((28, 28)),
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,))
    ])
    mnist_train = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    mnist_test = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
    return mnist_train, mnist_test

# Generate synthetic gRNA data
def generate_synthetic_grna_data(n_samples, n_features):
    X, y = make_classification(n_samples=n_samples, n_features=n_features, n_informative=5, n_redundant=0, n_classes=2, random_state=42)
    synthetic_data = pd.DataFrame(X, columns=[f'feature_{i}' for i in range(n_features)])
    synthetic_data['label'] = y
    return synthetic_data

# Preprocess a single image: crop, select channels, and optionally normalize
def preprocess_image(image_path, image_size=224, channels=[1, 2, 3], normalize=True):
    if normalize:
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(size=(image_size, image_size)),
            SelectChannels(channels),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
    else:
        preprocess = transforms.Compose([
            transforms.ToTensor(),
            transforms.CenterCrop(size=(image_size, image_size)),
            SelectChannels(channels)])

    image = Image.open(image_path).convert('RGB')
    return preprocess(image)

def extract_metadata_from_path(path):
    """
    Extract metadata from the image path.
    The path format is expected to be plate_well_field_objectnumber.png

    Parameters:
    path (str): The path to the image file.

    Returns:
    dict: A dictionary with the extracted metadata.
    """
    filename = os.path.basename(path)
    name, ext = os.path.splitext(filename)

    # Ensure the file has the correct extension
    if ext.lower() != '.png':
        raise ValueError("Expected a .png file")

    # Split the name by underscores
    parts = name.split('_')
    if len(parts) != 4:
        raise ValueError("Expected filename format: plate_well_field_objectnumber.png")

    plate, well, field, object_number = parts
    return {'plate': plate, 'well': well, 'field': field, 'object_number': object_number}

# Load images
def load_images(image_paths, image_size=224, channels=[1, 2, 3], normalize=True):
    images = []
    metadata_list = []
    for path in image_paths:
        image = preprocess_image(path, image_size, channels, normalize)
        images.append(image)
        metadata = extract_metadata_from_path(path)  # Extract metadata from the image path
        metadata_list.append(metadata)
    return torch.stack(images), metadata_list

# Normalize sequencing data
def normalize_sequencing_data(sequencing_data):
    scaler = StandardScaler()
    sequencing_data.iloc[:, 2:] = scaler.fit_transform(sequencing_data.iloc[:, 2:])
    return sequencing_data

# Construct a graph for one well
def construct_well_graph(images, image_metadata, grna_data):
    cell_nodes = len(images)
    grna_nodes = grna_data.shape[0]

    # Flatten each image to a vector and zero-pad the narrower feature matrix
    # so cell and gRNA nodes can share one node-feature tensor (a simple
    # stand-in for a learned projection to a common width).
    cell_features = torch.stack(images).reshape(cell_nodes, -1)
    grna_features = torch.as_tensor(grna_data, dtype=torch.float32)
    width = max(cell_features.shape[1], grna_features.shape[1])
    cell_features = nn.functional.pad(cell_features, (0, width - cell_features.shape[1]))
    grna_features = nn.functional.pad(grna_features, (0, width - grna_features.shape[1]))
    features = torch.cat([cell_features, grna_features], dim=0)

    # Connect every cell node to every gRNA node, in both directions
    src, dst = [], []
    for i in range(cell_nodes):
        for j in range(cell_nodes, cell_nodes + grna_nodes):
            src.extend([i, j])
            dst.extend([j, i])

    graph = dgl.graph((torch.tensor(src), torch.tensor(dst)), num_nodes=cell_nodes + grna_nodes)
    graph.ndata['features'] = features
    return graph

def create_graphs_for_wells(images, metadata_list, sequencing_data):
    graphs = []
    labels = []

    for well in sequencing_data['well'].unique():
        well_images = [img for img, meta in zip(images, metadata_list) if meta['well'] == well]
        well_metadata = [meta for meta in metadata_list if meta['well'] == well]
        well_grna_data = sequencing_data[sequencing_data['well'] == well].iloc[:, 2:].values

        graph = construct_well_graph(well_images, well_metadata, well_grna_data)
        graphs.append(graph)

        # Metadata is expected to carry a 'column' key identifying control wells
        if well_metadata[0]['column'] == 1:  # Negative control
            labels.append(0)
        elif well_metadata[0]['column'] == 2:  # Positive control
            labels.append(1)
        else:
            labels.append(-1)  # Screen wells, held out for evaluation

    return graphs, labels

# Define Encoder-Decoder Transformer Model
class Encoder(nn.Module):
    def __init__(self, in_feats, hidden_feats):
        super(Encoder, self).__init__()
        self.conv1 = dglnn.GraphConv(in_feats, hidden_feats)
        self.conv2 = dglnn.GraphConv(hidden_feats, hidden_feats)

    def forward(self, g, features):
        x = self.conv1(g, features)
        x = torch.relu(x)
        x = self.conv2(g, x)
        x = torch.relu(x)
        return x

class Decoder(nn.Module):
    def __init__(self, hidden_feats, out_feats):
        super(Decoder, self).__init__()
        self.linear = nn.Linear(hidden_feats, out_feats)

    def forward(self, x):
        return self.linear(x)

class GraphTransformer(nn.Module):
    def __init__(self, in_feats, hidden_feats, out_feats):
        super(GraphTransformer, self).__init__()
        self.encoder = Encoder(in_feats, hidden_feats)
        self.decoder = Decoder(hidden_feats, out_feats)

    def forward(self, g, features):
        x = self.encoder(g, features)
        with g.local_scope():
            g.ndata['h'] = x
            hg = dgl.mean_nodes(g, 'h')
            return self.decoder(hg)

def train(graphs, labels, model, loss_fn, optimizer, epochs=100):
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        correct = 0
        total = 0

        for graph, label in zip(graphs, labels):
            if label == -1:
                continue  # Skip screen wells for training

            features = graph.ndata['features']
            logits = model(graph, features)
            loss = loss_fn(logits, torch.tensor([label]))

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            correct += (predicted == label).sum().item()
            total += 1

        accuracy = correct / total if total > 0 else 0
        print(f'Epoch {epoch}, Loss: {total_loss / total:.4f}, Accuracy: {accuracy * 100:.2f}%')

def apply_model(graphs, model):
    model.eval()
    results = []

    with torch.no_grad():
        for graph in graphs:
            features = graph.ndata['features']
            logits = model(graph, features)
            probabilities = torch.softmax(logits, dim=1)
            results.append(probabilities[:, 1].item())

    return results

def analyze_associations(probabilities, sequencing_data):
    # Analyze associations between gRNAs and classification scores.
    # Assumes one probability per row of `sequencing_data`.
    sequencing_data['positive_prob'] = probabilities
    return sequencing_data.groupby('gRNA').positive_prob.mean().sort_values(ascending=False)

def train_graph_transformer(src, lr=0.01, epochs=100, hidden_feats=128, n_classes=2, row_limit=None, image_size=224, channels=[1,2,3], normalize=True, test_mode=False):
    if test_mode:
        # Load MNIST data
        mnist_train, mnist_test = load_mnist_data()

        # Generate synthetic gRNA data (assumed to carry the same 'well' and
        # 'gRNA' columns as real sequencing data)
        synthetic_grna_data = generate_synthetic_grna_data(len(mnist_train), 10)  # 10 synthetic features
        sequencing_data = synthetic_grna_data

        # Load MNIST images and metadata
        images = []
        metadata_list = []
        for idx, (img, label) in enumerate(mnist_train):
            images.append(img)
            metadata_list.append({'index': idx, 'plate': 'plate1', 'well': idx, 'column': label})
        images = torch.stack(images)

        # Normalize synthetic sequencing data
        sequencing_data = normalize_sequencing_data(sequencing_data)

    else:
        from .io import _read_and_join_tables
        from .utils import get_db_paths, get_sequencing_paths, correct_paths

        db_paths = get_db_paths(src)
        seq_paths = get_sequencing_paths(src)

        if isinstance(src, str):
            src = [src]

        sequencing_data = pd.DataFrame()
        for seq in seq_paths:
            sequencing_df = pd.read_csv(seq)
            sequencing_data = pd.concat([sequencing_data, sequencing_df], axis=0)

        tables = ['png_list']
        all_df = pd.DataFrame()
        image_paths = []
        for i, db_path in enumerate(db_paths):
            df = _read_and_join_tables(db_path, table_names=tables)
            df, image_paths_tmp = correct_paths(df, src[i])
            all_df = pd.concat([all_df, df], axis=0)
            image_paths.extend(image_paths_tmp)

        if row_limit is not None:
            all_df = all_df.sample(n=row_limit, random_state=42)

        images, metadata_list = load_images(image_paths, image_size, channels, normalize)
        sequencing_data = normalize_sequencing_data(sequencing_data)

    # Step 1: Create graphs for each well
    graphs, labels = create_graphs_for_wells(images, metadata_list, sequencing_data)

    # Step 2: Train Graph Transformer model
    in_feats = graphs[0].ndata['features'].shape[1]
    model = GraphTransformer(in_feats, hidden_feats, n_classes)
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Train the model
    train(graphs, labels, model, loss_fn, optimizer, epochs)

    # Step 3: Apply the trained model to the screen wells
    screen_graphs = [graph for graph, label in zip(graphs, labels) if label == -1]
    probabilities = apply_model(screen_graphs, model)

    # Step 4: Analyze associations between gRNAs and classification scores
    associations = analyze_associations(probabilities, sequencing_data)
    print("Top associated gRNAs with positive control phenotype:")
    print(associations.head())

    return model, associations
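Per step 2 of the approach outline, each well graph is fully bipartite-connected between cell and gRNA nodes. A quick smoke test of GraphTransformer on one such toy graph, assuming DGL with the PyTorch backend is installed (the node counts and 8-dimensional features here are made up):

import torch
import dgl

# Toy well: 4 cell nodes (0-3) and 2 gRNA nodes (4-5), connected both ways
src, dst = [], []
for i in range(4):
    for j in (4, 5):
        src.extend([i, j])
        dst.extend([j, i])
g = dgl.graph((torch.tensor(src), torch.tensor(dst)), num_nodes=6)
g.ndata['features'] = torch.randn(6, 8)

model = GraphTransformer(in_feats=8, hidden_feats=16, out_feats=2)
logits = model(g, g.ndata['features'])
print(logits.shape)  # torch.Size([1, 2]): one pair of class scores per well graph

Because the forward pass mean-pools all node embeddings with dgl.mean_nodes, each graph yields a single score vector, which is why train() compares the logits against one label per well.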
spacr/graph_learning_lap.py
ADDED
@@ -0,0 +1,84 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader, TensorDataset

# Let's assume that the feature embedding part and the dataset loading part
# have already been taken care of, and your data is already in a format
# suitable for PyTorch (i.e., Tensors).

class FeatureEmbedder(nn.Module):
    def __init__(self, vocab_sizes, embedding_size):
        super(FeatureEmbedder, self).__init__()
        self.embeddings = nn.ModuleDict({
            key: nn.Embedding(num_embeddings=vocab_size + 1,
                              embedding_dim=embedding_size,
                              padding_idx=vocab_size)
            for key, vocab_size in vocab_sizes.items()
        })
        # The 'visit' embedding is a learnable parameter; it is stored as a
        # plain attribute because nn.ModuleDict only accepts nn.Module values,
        # not nn.Parameter.
        self.visit_embedding = nn.Parameter(torch.zeros(1, embedding_size))

    def forward(self, feature_map, max_num_codes):
        # Implementation will depend on how you want to handle sparse data;
        # this is just a placeholder
        embeddings = {}
        masks = {}
        for key, tensor in feature_map.items():
            embeddings[key] = self.embeddings[key](tensor.long())
            mask = torch.ones_like(tensor, dtype=torch.float32)
            masks[key] = mask.unsqueeze(-1)

        # Batch size hardcoded for simplicity in this example
        batch_size = 1  # Replace with the actual batch size
        embeddings['visit'] = self.visit_embedding.expand(batch_size, -1, -1)
        masks['visit'] = torch.ones(batch_size, 1)

        return embeddings, masks

class GraphConvolutionalTransformer(nn.Module):
    def __init__(self, embedding_size=128, num_attention_heads=1, **kwargs):
        super(GraphConvolutionalTransformer, self).__init__()
        # Transformer blocks
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=embedding_size,
                nhead=num_attention_heads,
                batch_first=True)
            for _ in range(kwargs.get('num_transformer_stack', 3))
        ])
        # Output layer for classification
        self.output_layer = nn.Linear(embedding_size, 1)

    def feedforward(self, features, mask=None, training=None):
        # Implement feedforward logic (placeholder)
        pass

    def forward(self, embeddings, masks, mask=None, training=False):
        # `embeddings` is expected to be a single (batch, seq, embed) tensor,
        # e.g. the per-feature embeddings concatenated along the sequence
        # dimension with the 'visit' token first.
        features = embeddings
        attentions = []  # Storing attentions if needed

        # Pass through each Transformer block
        for layer in self.layers:
            features = layer(features)  # Apply transformer encoding here

        if mask is not None:
            features = features * mask

        logits = self.output_layer(features[:, 0, :])  # Use the 'visit' token for classification
        return logits, attentions

# Usage example
vocab_sizes = {'dx_ints': 3249, 'proc_ints': 2210}
embedding_size = 128
gct_params = {
    'embedding_size': embedding_size,
    'num_transformer_stack': 3,
    'num_attention_heads': 1
}
feature_embedder = FeatureEmbedder(vocab_sizes, embedding_size)
gct_model = GraphConvolutionalTransformer(**gct_params)

# Assume `feature_map` is a dictionary of tensors, and `max_num_codes` is provided
embeddings, masks = feature_embedder(feature_map, max_num_codes)
logits, attentions = gct_model(embeddings, masks)
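The usage example above leaves feature_map undefined and passes the embedding dict straight into the model. A runnable toy wiring, under the assumption that the per-feature embeddings are concatenated along the sequence dimension with the 'visit' token first (the record sizes here are made up):

import torch

vocab_sizes = {'dx_ints': 3249, 'proc_ints': 2210}
feature_embedder = FeatureEmbedder(vocab_sizes, embedding_size=128)
gct_model = GraphConvolutionalTransformer(embedding_size=128,
                                          num_transformer_stack=3,
                                          num_attention_heads=1)

# One toy record with 5 diagnosis codes and 3 procedure codes
feature_map = {'dx_ints': torch.randint(0, 3249, (1, 5)),
               'proc_ints': torch.randint(0, 2210, (1, 3))}
embeddings, masks = feature_embedder(feature_map, max_num_codes=None)

# Concatenate the 'visit' token and per-feature embeddings into one
# (batch, seq, embed) tensor: (1, 1 + 5 + 3, 128)
sequence = torch.cat([embeddings['visit'],
                      embeddings['dx_ints'],
                      embeddings['proc_ints']], dim=1)
logits, attentions = gct_model(sequence, masks)
print(logits.shape)  # torch.Size([1, 1]): one score per record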
spacr/gui.py
ADDED
@@ -0,0 +1,145 @@

import tkinter as tk
from tkinter import ttk
from tkinter import font as tkFont
from PIL import Image, ImageTk
import os
import requests

# Import your GUI apps
from .gui_mask_app import initiate_mask_root
from .gui_measure_app import initiate_measure_root
from .annotate_app import initiate_annotation_app_root
from .mask_app import initiate_mask_app_root
from .gui_classify_app import initiate_classify_root

from .gui_utils import CustomButton, style_text_boxes

class MainApp(tk.Tk):
    def __init__(self):
        super().__init__()
        self.title("SpaCr GUI Collection")
        self.geometry("1100x1500")
        self.configure(bg="black")
        #self.attributes('-fullscreen', True)

        style = ttk.Style()
        style_text_boxes(style)

        self.gui_apps = {
            "Mask": (initiate_mask_root, "Generate cellpose masks for cells, nuclei and pathogen images."),
            "Measure": (initiate_measure_root, "Measure single object intensity and morphological features. Crop and save single object images."),
            "Annotate": (initiate_annotation_app_root, "Annotate single object images on a grid. Annotations are saved to the database."),
            "Make Masks": (initiate_mask_app_root, "Adjust pre-existing Cellpose models to your specific dataset for improved performance."),
            "Classify": (initiate_classify_root, "Train Torch Convolutional Neural Networks (CNNs) or Transformers to classify single object images.")
        }

        self.selected_app = tk.StringVar()
        self.create_widgets()

    def create_widgets(self):
        # Create the menu bar
        #create_menu_bar(self)
        # Create a canvas to hold the selected app and other elements
        self.canvas = tk.Canvas(self, bg="black", highlightthickness=0, width=4000, height=4000)
        self.canvas.grid(row=0, column=0, sticky="nsew")
        self.grid_rowconfigure(0, weight=1)
        self.grid_columnconfigure(0, weight=1)
        # Create a frame inside the canvas to hold the main content
        self.content_frame = tk.Frame(self.canvas, bg="black")
        self.content_frame.pack(fill=tk.BOTH, expand=True)
        # Create startup screen with buttons for each GUI app
        self.create_startup_screen()

    def create_startup_screen(self):
        self.clear_frame(self.content_frame)

        # Create a frame for the logo and description
        logo_frame = tk.Frame(self.content_frame, bg="black")
        logo_frame.pack(pady=20, expand=True)

        # Load the logo image
        if not self.load_logo(logo_frame):
            tk.Label(logo_frame, text="Logo not found", bg="black", fg="white", font=('Arial', 24, tkFont.NORMAL)).pack(padx=10, pady=10)

        # Add SpaCr text below the logo with padding for sharper text
        tk.Label(logo_frame, text="SpaCr", bg="black", fg="#008080", font=('Arial', 24, tkFont.NORMAL)).pack(padx=10, pady=10)

        # Create a frame for the buttons and descriptions
        buttons_frame = tk.Frame(self.content_frame, bg="black")
        buttons_frame.pack(pady=10, expand=True, padx=10)

        for i, (app_name, app_data) in enumerate(self.gui_apps.items()):
            app_func, app_desc = app_data

            # Create custom button with text
            button = CustomButton(buttons_frame, text=app_name, command=lambda app_name=app_name: self.load_app(app_name))
            button.grid(row=i, column=0, pady=10, padx=10, sticky="w")

            description_label = tk.Label(buttons_frame, text=app_desc, bg="black", fg="white", wraplength=800, justify="left", font=('Arial', 10, tkFont.NORMAL))
            description_label.grid(row=i, column=1, pady=10, padx=10, sticky="w")

        # Ensure buttons have a fixed width
        buttons_frame.grid_columnconfigure(0, minsize=150)
        # Ensure descriptions expand as needed
        buttons_frame.grid_columnconfigure(1, weight=1)

    def load_logo(self, frame):
        def download_image(url, save_path):
            try:
                response = requests.get(url, stream=True)
                response.raise_for_status()  # Raise an HTTPError for bad responses
                with open(save_path, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                return True
            except requests.exceptions.RequestException as e:
                print(f"Failed to download image from {url}: {e}")
                return False

        try:
            img_path = os.path.join(os.path.dirname(__file__), 'logo_spacr.png')
            print(f"Trying to load logo from {img_path}")
            logo_image = Image.open(img_path)
        except (FileNotFoundError, Image.UnidentifiedImageError):
            print(f"File {img_path} not found or is not a valid image. Attempting to download from GitHub.")
            if download_image('https://raw.githubusercontent.com/EinarOlafsson/spacr/main/spacr/logo_spacr.png', img_path):
                try:
                    print(f"Downloaded file size: {os.path.getsize(img_path)} bytes")
                    logo_image = Image.open(img_path)
                except Image.UnidentifiedImageError as e:
                    print(f"Downloaded file is not a valid image: {e}")
                    return False
            else:
                return False
        except Exception as e:
            print(f"An error occurred while loading the logo: {e}")
            return False
        try:
            logo_image = logo_image.resize((800, 800), Image.Resampling.LANCZOS)
            logo_photo = ImageTk.PhotoImage(logo_image)
            logo_label = tk.Label(frame, image=logo_photo, bg="black")
            logo_label.image = logo_photo  # Keep a reference to avoid garbage collection
            logo_label.pack()
            return True
        except Exception as e:
            print(f"An error occurred while processing the logo image: {e}")
            return False

    def load_app(self, app_name):
        selected_app_func, _ = self.gui_apps[app_name]
        self.clear_frame(self.content_frame)

        app_frame = tk.Frame(self.content_frame, bg="black")
        app_frame.pack(fill=tk.BOTH, expand=True)
        selected_app_func(app_frame, self.winfo_width(), self.winfo_height())

    def clear_frame(self, frame):
        for widget in frame.winfo_children():
            widget.destroy()

def gui_app():
    app = MainApp()
    app.mainloop()

if __name__ == "__main__":
    gui_app()
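gui_app() is the kind of zero-argument callable that the new spacr-0.0.6.dist-info/entry_points.txt (9 lines, per the file list above) can expose as a console script. A hedged sketch of how such a mapping is declared with setuptools; the spacr-gui script name is illustrative, not taken from the actual package metadata:

# setup.py (sketch): expose spacr.gui:gui_app as a command-line launcher
from setuptools import setup, find_packages

setup(
    name='spacr',
    version='0.0.6',
    packages=find_packages(),
    entry_points={
        'console_scripts': [
            'spacr-gui = spacr.gui:gui_app',  # hypothetical script name
        ],
    },
)

With a mapping like this installed, running spacr-gui from a terminal would call gui_app() and open the launcher window defined in MainApp.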