spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +6 -2
- spacr/__main__.py +0 -2
- spacr/alpha.py +807 -0
- spacr/annotate_app.py +118 -120
- spacr/chris.py +50 -0
- spacr/cli.py +25 -187
- spacr/core.py +1611 -389
- spacr/deep_spacr.py +696 -0
- spacr/foldseek.py +779 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +320 -0
- spacr/graph_learning_lap.py +84 -0
- spacr/gui.py +145 -0
- spacr/gui_2.py +90 -0
- spacr/gui_classify_app.py +187 -0
- spacr/gui_mask_app.py +149 -174
- spacr/gui_measure_app.py +116 -109
- spacr/gui_sim_app.py +0 -0
- spacr/gui_utils.py +679 -139
- spacr/io.py +620 -469
- spacr/mask_app.py +116 -9
- spacr/measure.py +178 -84
- spacr/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/old_code.py +255 -1
- spacr/plot.py +263 -100
- spacr/sequencing.py +1130 -0
- spacr/sim.py +634 -122
- spacr/timelapse.py +343 -53
- spacr/train.py +195 -22
- spacr/umap.py +0 -689
- spacr/utils.py +1530 -188
- spacr-0.0.6.dist-info/METADATA +118 -0
- spacr-0.0.6.dist-info/RECORD +39 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/WHEEL +1 -1
- spacr-0.0.6.dist-info/entry_points.txt +9 -0
- spacr-0.0.1.dist-info/METADATA +0 -64
- spacr-0.0.1.dist-info/RECORD +0 -26
- spacr-0.0.1.dist-info/entry_points.txt +0 -5
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/LICENSE +0 -0
- {spacr-0.0.1.dist-info → spacr-0.0.6.dist-info}/top_level.txt +0 -0
spacr/old_code.py
CHANGED
@@ -51,7 +51,7 @@ def resize_figure_to_canvas(fig, canvas):
|
|
51
51
|
|
52
52
|
return fig
|
53
53
|
|
54
|
-
|
54
|
+
def process_fig_queue_v1():
|
55
55
|
global canvas
|
56
56
|
while not fig_queue.empty():
|
57
57
|
try:
|
@@ -102,3 +102,257 @@ def run_mask_gui(q):
|
|
102
102
|
preprocess_generate_masks_wrapper(settings['src'], settings=settings, advanced_settings={})
|
103
103
|
except Exception as e:
|
104
104
|
q.put(f"Error during processing: {e}\n")
|
105
|
+
|
106
|
+
#@log_function_call
|
107
|
+
def main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label):
    """Drain pending worker messages and reschedule this function.

    Every message currently waiting in ``q`` is handled once:

    * messages starting with ``"Progress"`` or ``"Processing"`` become the
      text of ``progress_label``;
    * empty strings, a lone carriage return, and messages starting with
      ``/`` or ``\\`` are discarded;
    * anything else (for example error reports) is printed.

    Args:
        root: Object exposing ``after(delay_ms, callback)``; used to run this
            function again in 100 ms.
        q: Queue-like object exposing ``empty()`` and ``get_nowait()`` that
            yields string messages.
        fig_queue: Not read here; only forwarded to the rescheduled call.
        canvas_widget: Not read here; only forwarded to the rescheduled call.
        progress_label: Object exposing ``config(text=...)``.
    """
    try:
        while not q.empty():
            message = q.get_nowait()
            if message.startswith(("Progress", "Processing")):
                progress_label.config(text=message)
            elif message in ("", "\r") or message.startswith(("/", "\\")):
                # Blank/carriage-return noise and path-like lines are dropped.
                continue
            else:
                # No catch-all ``startswith("")`` test here: it matches every
                # string and would swallow all remaining messages.
                print(message)
    except Exception as e:
        # Best-effort polling: report the problem and keep the loop alive.
        print(f"Error updating GUI canvas: {e}")
    finally:
        # Always reschedule, even after an error, so polling never stops.
        root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
|
137
|
+
|
138
|
+
class MPNN(MessagePassing):
    """Message-passing layer that mixes neighbour features with edge attributes.

    Messages are built by an MLP applied to the concatenation of a neighbour's
    features and the attributes of the connecting edge, aggregated with
    ``aggr='mean'``, and finally transformed by a second MLP.

    Args:
        node_in_features: Width of each node feature vector.
        edge_in_features: Width of each edge attribute vector.
        out_features: Width of the message and of the layer output.
    """

    def __init__(self, node_in_features, edge_in_features, out_features):
        super().__init__(aggr='mean')  # mean aggregation of incoming messages

        # Message network: [node features | edge attributes] -> out_features
        message_in = node_in_features + edge_in_features
        self.message_mlp = Sequential(
            Linear(message_in, 128),
            ReLU(),
            Linear(128, out_features),
        )

        # Update network applied to the aggregated messages
        self.update_mlp = Sequential(
            Linear(out_features, out_features),
            ReLU(),
            Linear(out_features, out_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """Propagate messages along ``edge_index``.

        x: node features [N, node_in_features]
        edge_index: graph connectivity [2, E]
        edge_attr: edge attributes [E, edge_in_features]
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Build one message per edge.

        x_j: features of the neighbour on each edge [E, node_in_features]
        edge_attr: edge attributes [E, edge_in_features]
        """
        # Join neighbour features and edge attributes along the feature axis
        joined = torch.cat((x_j, edge_attr), dim=-1)
        return self.message_mlp(joined)

    def update(self, aggr_out):
        """Transform the aggregated messages [N, out_features]."""
        return self.update_mlp(aggr_out)
|
167
|
+
|
168
|
+
def weighted_mse_loss(output, target, score_threshold=0.8, high_score_weight=10):
    """Return the mean squared error with up-weighted high-score targets.

    Args:
        output: Tensor of predicted scores.
        target: Tensor of true scores.
        score_threshold: Targets greater than or equal to this value count as
            high scores.
        high_score_weight: Weight applied to the squared error of high-score
            targets; all other entries have weight 1.

    Returns:
        Scalar tensor holding the mean of the weighted squared errors.
    """
    is_high_score = target >= score_threshold

    # Start from unit weights and boost only the high-score positions.
    weights = torch.ones_like(target)
    weights[is_high_score] = high_score_weight

    squared_error = (output - target) ** 2
    return (squared_error * weights).mean()
|
174
|
+
|
175
|
+
def generate_single_graph(sequencing, scores):
    """Build a single gene/cell graph from sequencing counts and cell scores.

    Node numbering: genes occupy indices ``0 .. G-1`` and cells occupy
    ``G .. G+C-1`` (``G`` unique genes, ``C`` unique cells). The rows of the
    node feature matrix follow the same order, so ``edge_index`` entries
    address the matching feature rows.

    Every cell is linked to every gene observed in the same well. The single
    edge attribute is that gene's fraction of the well's total read count.

    Args:
        sequencing: Anything ``pd.read_csv`` accepts; the table must provide
            the columns ``prc`` (well), ``grna`` (gene) and ``count`` (reads).
        scores: Anything ``pd.read_csv`` accepts; the table must provide the
            columns ``prcfo`` (cell), ``prc`` (well) and ``pred`` (score).

    Returns:
        Tuple ``(data, gene_id_to_index, n_genes)`` where ``data`` is the
        ``Data`` graph object, ``gene_id_to_index`` maps gene ids to node
        indices and ``n_genes`` is ``len(gene_id_to_index)``.
    """
    # Load and preprocess sequencing data
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Load and preprocess cell score data
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Node index mappings: genes first, cells offset by the number of genes
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    n_genes = len(gene_id_to_index)
    cell_id_to_index = {cell: i + n_genes for i, cell in enumerate(cell_df['cell_id'].unique())}

    # Group the cells by well once instead of re-scanning cell_df per well
    cells_by_well = {
        well_id: cells.map(cell_id_to_index).to_numpy()
        for well_id, cells in cell_df.groupby('well_id')['cell_id']
    }

    edge_index = []
    edge_attr = []

    # Associate each cell with all genes in the same well
    for well_id, group in gene_df.groupby('well_id'):
        cell_indices = cells_by_well.get(well_id)
        if cell_indices is None:
            continue  # no scored cells in this well

        gene_indices = group['gene_id'].map(gene_id_to_index).to_numpy()
        fractions = group['well_read_fraction'].to_numpy()

        for cell_idx in cell_indices:
            for gene_idx, fraction in zip(gene_indices, fractions):
                edge_index.append([cell_idx, gene_idx])
                edge_attr.append([fraction])

    # Convert lists to tensors. The reshapes are no-ops for non-empty input
    # and keep the [2, E] / [E, 1] layout when there are no edges at all.
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).reshape(-1, 1)
    # One score per row of the scores table, in table order
    cell_scores = torch.tensor(cell_df['score'].values, dtype=torch.float)

    # One-hot features for genes, all-zero features for cells
    # (could be replaced with real features if available)
    gene_features = torch.eye(n_genes)
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))

    # Gene rows must come first so that row k belongs to node index k
    x = torch.cat([gene_features, cell_features], dim=0)

    # Create the graph data object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    return data, gene_id_to_index, n_genes
|
223
|
+
|
224
|
+
# in _normalize_and_outline
# NOTE(review): excerpt only. It relies on names that are not defined in this
# snippet (rgb_image, image, mask_dims, outline_thickness, outline_colors,
# find_contours, dilation, square) - presumably locals/imports of
# _normalize_and_outline; confirm before reusing.
outlines = []

overlayed_image = rgb_image.copy()
for i, mask_dim in enumerate(mask_dims):
    # Pull the mask plane for this dimension out of the stacked image
    mask = np.take(image, mask_dim, axis=2)
    outline = np.zeros_like(mask)

    # Trace every labelled object (the first unique value is skipped)
    object_labels = np.unique(mask)[1:]
    for j in object_labels:
        for contour in find_contours(mask == j, 0.5):
            rows, cols = contour.astype(int).T
            outline[rows, cols] = j

    # Thicken the traced outline and keep it
    outline = dilation(outline, square(outline_thickness))
    outlines.append(outline)

    # Paint the outline pixels onto the RGB copy, one colour per mask dimension
    for j in np.unique(outline)[1:]:
        overlayed_image[outline == j] = outline_colors[i % len(outline_colors)]
|
243
|
+
|
244
|
+
def _extract_filename_metadata(filenames, src, images_by_key, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
|
245
|
+
for filename in filenames:
|
246
|
+
match = regular_expression.match(filename)
|
247
|
+
if match:
|
248
|
+
try:
|
249
|
+
try:
|
250
|
+
plate = match.group('plateID')
|
251
|
+
except:
|
252
|
+
plate = os.path.basename(src)
|
253
|
+
|
254
|
+
well = match.group('wellID')
|
255
|
+
field = match.group('fieldID')
|
256
|
+
channel = match.group('chanID')
|
257
|
+
mode = None
|
258
|
+
|
259
|
+
if well[0].isdigit():
|
260
|
+
well = str(_safe_int_convert(well))
|
261
|
+
if field[0].isdigit():
|
262
|
+
field = str(_safe_int_convert(field))
|
263
|
+
if channel[0].isdigit():
|
264
|
+
channel = str(_safe_int_convert(channel))
|
265
|
+
|
266
|
+
if metadata_type =='cq1':
|
267
|
+
orig_wellID = wellID
|
268
|
+
wellID = _convert_cq1_well_id(wellID)
|
269
|
+
clear_output(wait=True)
|
270
|
+
print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
|
271
|
+
|
272
|
+
if pick_slice:
|
273
|
+
try:
|
274
|
+
mode = match.group('AID')
|
275
|
+
except IndexError:
|
276
|
+
sliceid = '00'
|
277
|
+
|
278
|
+
if mode == skip_mode:
|
279
|
+
continue
|
280
|
+
|
281
|
+
key = (plate, well, field, channel, mode)
|
282
|
+
with Image.open(os.path.join(src, filename)) as img:
|
283
|
+
images_by_key[key].append(np.array(img))
|
284
|
+
except IndexError:
|
285
|
+
print(f"Could not extract information from filename {filename} using provided regex")
|
286
|
+
else:
|
287
|
+
print(f"Filename {filename} did not match provided regex")
|
288
|
+
continue
|
289
|
+
|
290
|
+
return images_by_key
|
291
|
+
|
292
|
+
def compare_cellpose_masks_v1(src, verbose=False, save=False):
    """Compare masks that share a filename across the sub-folders of ``src``.

    Each immediate sub-directory of ``src`` is treated as one condition. For
    every filename present in all of them, the masks are read and every pair
    of conditions is scored with boundary F1, Jaccard index and segmentation
    average precision.

    Args:
        src: Directory whose sub-directories hold the masks to compare.
        verbose: When true, the masks of each file are visualised.
        save: Forwarded to ``visualize_cellpose_masks``.

    Returns:
        Tuple ``(results, fig)``: ``results`` is a list with one dict per
        file (``'filename'`` plus one entry per metric and condition pair) and
        ``fig`` is whatever ``plot_comparison_results`` returns.
    """
    from .io import _read_mask
    from .plot import visualize_masks, plot_comparison_results, visualize_cellpose_masks
    from .utils import extract_boundaries, boundary_f1_score, compute_segmentation_ap, jaccard_index

    import os
    import numpy as np
    from skimage.measure import label

    # Every immediate sub-directory of src is one condition
    dirs = []
    for entry in os.listdir(src):
        candidate = os.path.join(src, entry)
        if os.path.isdir(candidate):
            dirs.append(candidate)
    dirs.sort()

    # Keep only the filenames that occur in every condition folder
    common_files = set(os.listdir(dirs[0]))
    for other_dir in dirs[1:]:
        common_files.intersection_update(os.listdir(other_dir))
    common_files = list(common_files)

    results = []
    conditions = [os.path.basename(d) for d in dirs]

    for index, filename in enumerate(common_files):
        print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
        paths = [os.path.join(d, filename) for d in dirs]

        # Guard against a file missing from one of the folders
        if any(not os.path.exists(path) for path in paths):
            print(f'Skipping {filename} as it is not present in all directories.')
            continue

        masks = [_read_mask(path) for path in paths]
        boundaries = [extract_boundaries(mask) for mask in masks]

        if verbose:
            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}", save=save, src=src)

        file_results = {'filename': filename}

        # Score every unordered pair of conditions exactly once
        for i, (condition_i, mask_i) in enumerate(zip(conditions, masks)):
            for condition_j, mask_j in zip(conditions[i + 1:], masks[i + 1:]):
                boundary_f1 = boundary_f1_score(mask_i, mask_j)
                jaccard = jaccard_index(mask_i, mask_j)
                average_precision = compute_segmentation_ap(mask_i, mask_j)

                file_results[f'jaccard_{condition_i}_{condition_j}'] = jaccard
                file_results[f'boundary_f1_{condition_i}_{condition_j}'] = boundary_f1
                file_results[f'average_precision_{condition_i}_{condition_j}'] = average_precision

        results.append(file_results)

    fig = plot_comparison_results(results)

    # NOTE(review): save_results_and_figure is not imported in this function;
    # presumably it is available at module scope - confirm.
    save_results_and_figure(src, fig, results)

    return results, fig
|