spacr 0.0.1__py3-none-any.whl → 0.0.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/old_code.py CHANGED
@@ -51,7 +51,7 @@ def resize_figure_to_canvas(fig, canvas):
51
51
 
52
52
  return fig
53
53
 
54
- def process_fig_queue_v1():
54
+ def process_fig_queue_v1():
55
55
  global canvas
56
56
  while not fig_queue.empty():
57
57
  try:
@@ -102,3 +102,257 @@ def run_mask_gui(q):
102
102
  preprocess_generate_masks_wrapper(settings['src'], settings=settings, advanced_settings={})
103
103
  except Exception as e:
104
104
  q.put(f"Error during processing: {e}\n")
105
+
106
#@log_function_call
def main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label):
    """Drain pending text messages from ``q`` and reschedule itself on ``root``.

    Routing of each message taken from ``q``:

    * starts with ``"Progress"`` or ``"Processing"`` -> shown in ``progress_label``;
    * empty string, a lone ``"\\r"``, or starts with ``"/"`` or a backslash -> ignored;
    * anything else -> printed to stdout.

    Any exception while draining is printed, and the function always re-arms
    itself with ``root.after(100, ...)`` so polling continues.

    Args:
        root: Object providing ``after(ms, callback)`` (presumably the Tk root).
        q: Queue of string messages, read via ``empty()`` / ``get_nowait()``.
        fig_queue: Figure queue; passed through to the next call but otherwise
            unused here (figure handling is currently disabled).
        canvas_widget: Canvas widget; passed through but otherwise unused.
        progress_label: Widget providing ``config(text=...)``.
    """
    try:
        while not q.empty():
            message = q.get_nowait()
            if message.startswith(("Progress", "Processing")):
                progress_label.config(text=message)
            elif message in ("", "\r") or message.startswith(("/", "\\")):
                # Deliberately ignored messages.
                pass
            else:
                # Previously unreachable: an always-true `startswith("")` branch
                # swallowed every message before it got here.
                print(message)
    except Exception as e:
        print(f"Error updating GUI canvas: {e}")
    finally:
        # Reschedule unconditionally so the GUI keeps polling the queue.
        root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
137
+
138
class MPNN(MessagePassing):
    """Message-passing layer whose messages depend on node and edge features.

    Every edge produces a message by running ``message_mlp`` on the
    concatenation of the neighbour's node features and the edge's attributes.
    Messages are averaged per node (``aggr='mean'``) and the result is passed
    through ``update_mlp`` to give the node output.

    Args:
        node_in_features: Width of the node feature vectors.
        edge_in_features: Width of the edge attribute vectors.
        out_features: Width of the output node vectors.
    """

    def __init__(self, node_in_features, edge_in_features, out_features):
        super().__init__(aggr='mean')
        hidden = 128
        # The attribute names below define the module's state_dict keys; keep them stable.
        self.message_mlp = Sequential(
            Linear(node_in_features + edge_in_features, hidden),
            ReLU(),
            Linear(hidden, out_features),
        )
        self.update_mlp = Sequential(
            Linear(out_features, out_features),
            ReLU(),
            Linear(out_features, out_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of propagation.

        Expected shapes: ``x`` is [N, node_in_features], ``edge_index`` is
        [2, E], and ``edge_attr`` is [E, edge_in_features].
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """Turn each (neighbour features, edge attributes) pair into a message."""
        joined = torch.cat((x_j, edge_attr), dim=-1)
        return self.message_mlp(joined)

    def update(self, aggr_out):
        """Map the mean-aggregated messages ([N, out_features]) to node outputs."""
        return self.update_mlp(aggr_out)
167
+
168
def weighted_mse_loss(output, target, score_threshold=0.8, high_score_weight=10):
    """Squared-error loss that emphasises samples with high target scores.

    Each squared error ``(output - target) ** 2`` is weighted by 1, or by
    ``high_score_weight`` where ``target >= score_threshold``. The weighted
    errors are then averaged with a plain ``mean()``, not normalised by the
    sum of the weights.

    Args:
        output: Predicted scores (tensor).
        target: True scores (tensor); the weights take its shape and dtype.
        score_threshold: Targets at or above this value count as "high".
        high_score_weight: Weight applied to high-target errors.

    Returns:
        Scalar tensor holding the weighted mean squared error.
    """
    is_high = target >= score_threshold
    weights = torch.ones_like(target).masked_fill(is_high, high_score_weight)
    squared_error = (output - target) ** 2
    return (squared_error * weights).mean()
174
+
175
def generate_single_graph(sequencing, scores):
    """Build a cell–gene graph from per-well gRNA read counts and per-cell scores.

    Every cell is connected to every gene found in the same well. The edge
    attribute is that gene's fraction of the well's total reads.

    Node layout: genes occupy indices ``0 .. G-1`` and cells ``G .. G+C-1``.
    The rows of ``x`` follow the same order: G one-hot gene rows first, then
    C all-zero cell rows.

    Args:
        sequencing: CSV path/buffer with columns ``prc`` (well), ``grna`` (gene)
            and ``count`` (reads).
        scores: CSV path/buffer with columns ``prcfo`` (cell), ``prc`` (well)
            and ``pred`` (score).

    Returns:
        ``(data, gene_id_to_index, num_genes)``. Here ``data`` is a ``Data``
        object with ``x``, ``edge_index`` [2, E], ``edge_attr`` [E, 1] and
        ``y`` (cell scores in ``scores`` row order).
    """
    # Load and preprocess sequencing data
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Load and preprocess cell score data
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Node index mappings: genes first, then cells offset by the gene count.
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    num_genes = len(gene_id_to_index)
    cell_id_to_index = {cell: i + num_genes for i, cell in enumerate(cell_df['cell_id'].unique())}

    # Group cells by well once, instead of re-filtering cell_df for every well.
    cells_by_well = {
        well_id: group['cell_id'].map(cell_id_to_index).values
        for well_id, group in cell_df.groupby('well_id')
    }

    edge_index = []
    edge_attr = []
    # NOTE(review): edges are cell -> gene only. Under source-to-target message
    # flow, cell nodes would receive no messages; confirm the intended direction.
    for well_id, group in gene_df.groupby('well_id'):
        cell_indices = cells_by_well.get(well_id)
        if cell_indices is None:
            continue  # no scored cells in this well
        gene_indices = group['gene_id'].map(gene_id_to_index).values
        fractions = group['well_read_fraction'].values
        for cell_idx in cell_indices:
            for gene_idx, fraction in zip(gene_indices, fractions):
                edge_index.append([cell_idx, gene_idx])
                edge_attr.append([fraction])

    # reshape keeps the [2, E] / [E, 1] shapes even when no edges were built.
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).reshape(-1, 1)
    # NOTE(review): y has one entry per row of cell_df, not one per node; consumers
    # presumably offset by num_genes to align it with the cell nodes.
    cell_scores = torch.tensor(cell_df['score'].values, dtype=torch.float)

    # One-hot encoding for genes, and zero features for cells (could be replaced with real features if available)
    gene_features = torch.eye(num_genes)
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))

    # Genes first, then cells, matching the index ranges assigned above.
    # (The previous cells-first order made edge indices point at the wrong rows.)
    x = torch.cat([gene_features, cell_features], dim=0)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    return data, gene_id_to_index, num_genes
223
+
224
# in _normalize_and_outline
# NOTE(review): loose snippet copied from `_normalize_and_outline`. It sits at
# module level and uses `rgb_image`, `mask_dims`, `image`, `outline_thickness`
# and `outline_colors`, none of which are defined in this view. Executing it
# would raise NameError unless those names exist elsewhere in the module.
# Purpose: for each selected mask channel, trace object contours, thicken them,
# and paint them onto a copy of `rgb_image`.
outlines = []

overlayed_image = rgb_image.copy()
for i, mask_dim in enumerate(mask_dims):
    # Take one mask plane from `image` along axis 2.
    mask = np.take(image, mask_dim, axis=2)
    outline = np.zeros_like(mask)
    # Find the contours of the objects in the mask
    # ([1:] skips the smallest label value, presumably background 0 - confirm.)
    for j in np.unique(mask)[1:]:
        contours = find_contours(mask == j, 0.5)
        for contour in contours:
            # Contour coordinates are truncated to integer pixel indices.
            contour = contour.astype(int)
            outline[contour[:, 0], contour[:, 1]] = j
    # Make the outline thicker
    outline = dilation(outline, square(outline_thickness))
    outlines.append(outline)
    # Overlay the outlines onto the RGB image
    # (colour cycles through outline_colors by mask-channel position)
    for j in np.unique(outline)[1:]:
        overlayed_image[outline == j] = outline_colors[i % len(outline_colors)]
243
+
244
+ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
245
+ for filename in filenames:
246
+ match = regular_expression.match(filename)
247
+ if match:
248
+ try:
249
+ try:
250
+ plate = match.group('plateID')
251
+ except:
252
+ plate = os.path.basename(src)
253
+
254
+ well = match.group('wellID')
255
+ field = match.group('fieldID')
256
+ channel = match.group('chanID')
257
+ mode = None
258
+
259
+ if well[0].isdigit():
260
+ well = str(_safe_int_convert(well))
261
+ if field[0].isdigit():
262
+ field = str(_safe_int_convert(field))
263
+ if channel[0].isdigit():
264
+ channel = str(_safe_int_convert(channel))
265
+
266
+ if metadata_type =='cq1':
267
+ orig_wellID = wellID
268
+ wellID = _convert_cq1_well_id(wellID)
269
+ clear_output(wait=True)
270
+ print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
271
+
272
+ if pick_slice:
273
+ try:
274
+ mode = match.group('AID')
275
+ except IndexError:
276
+ sliceid = '00'
277
+
278
+ if mode == skip_mode:
279
+ continue
280
+
281
+ key = (plate, well, field, channel, mode)
282
+ with Image.open(os.path.join(src, filename)) as img:
283
+ images_by_key[key].append(np.array(img))
284
+ except IndexError:
285
+ print(f"Could not extract information from filename {filename} using provided regex")
286
+ else:
287
+ print(f"Filename {filename} did not match provided regex")
288
+ continue
289
+
290
+ return images_by_key
291
+
292
def compare_cellpose_masks_v1(src, verbose=False, save=False):
    """Compare segmentation masks of the same images across condition folders.

    Each immediate subdirectory of ``src`` is one condition. For every file
    name present in all subdirectories, the masks are read and each unordered
    pair of conditions is scored with Jaccard index, boundary F1 and
    segmentation average precision.

    Args:
        src: Directory whose subdirectories hold one mask set per condition.
        verbose: If true, visualise the masks of each file side by side.
        save: Forwarded to ``visualize_cellpose_masks``.

    Returns:
        ``(results, fig)``. ``results`` is a list of per-file dicts with keys
        ``'filename'`` and ``'<metric>_<cond_i>_<cond_j>'``. ``fig`` is the
        output of ``plot_comparison_results``.

    Raises:
        ValueError: If ``src`` contains no subdirectories.
    """
    import os

    from .io import _read_mask
    from .plot import plot_comparison_results, visualize_cellpose_masks
    from .utils import boundary_f1_score, compute_segmentation_ap, jaccard_index

    # Collect all subdirectories in src (sorted for a stable condition order).
    dirs = sorted(
        os.path.join(src, d) for d in os.listdir(src) if os.path.isdir(os.path.join(src, d))
    )
    if not dirs:
        raise ValueError(f"No subdirectories found in {src}; expected one folder per condition")

    # Only files present in every condition folder can be compared.
    common_files = set(os.listdir(dirs[0]))
    for d in dirs[1:]:
        common_files.intersection_update(os.listdir(d))
    # Set iteration order varies between runs; sort for reproducible results.
    common_files = sorted(common_files)

    results = []
    conditions = [os.path.basename(d) for d in dirs]

    for index, filename in enumerate(common_files):
        print(f'Processing image {index+1}/{len(common_files)}', end='\r', flush=True)
        paths = [os.path.join(d, filename) for d in dirs]

        # Redundant with the intersection above, but guards against files removed after listing.
        if not all(os.path.exists(path) for path in paths):
            print(f'Skipping {filename} as it is not present in all directories.')
            continue

        masks = [_read_mask(path) for path in paths]

        if verbose:
            visualize_cellpose_masks(masks, titles=conditions, comparison_title=f"Masks Comparison for {filename}", save=save, src=src)

        file_results = {'filename': filename}

        # Compare every unordered pair of conditions (i < j).
        for i, (condition_i, mask_i) in enumerate(zip(conditions, masks)):
            for condition_j, mask_j in zip(conditions[i + 1:], masks[i + 1:]):
                boundary_f1 = boundary_f1_score(mask_i, mask_j)
                jaccard = jaccard_index(mask_i, mask_j)
                average_precision = compute_segmentation_ap(mask_i, mask_j)

                pair = f'{condition_i}_{condition_j}'
                file_results[f'jaccard_{pair}'] = jaccard
                file_results[f'boundary_f1_{pair}'] = boundary_f1
                file_results[f'average_precision_{pair}'] = average_precision

        results.append(file_results)

    fig = plot_comparison_results(results)

    # NOTE(review): save_results_and_figure is not imported here; presumably it is
    # defined at module level elsewhere in this file - confirm.
    save_results_and_figure(src, fig, results)

    return results, fig