spacr 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/old_code.py CHANGED
@@ -51,7 +51,7 @@ def resize_figure_to_canvas(fig, canvas):
51
51
 
52
52
  return fig
53
53
 
54
- def process_fig_queue_v1():
54
+ def process_fig_queue_v1():
55
55
  global canvas
56
56
  while not fig_queue.empty():
57
57
  try:
@@ -102,3 +102,189 @@ def run_mask_gui(q):
102
102
  preprocess_generate_masks_wrapper(settings['src'], settings=settings, advanced_settings={})
103
103
  except Exception as e:
104
104
  q.put(f"Error during processing: {e}\n")
105
+
106
@log_function_call
def main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label):
    """
    Drain pending status messages and reschedule this poller on the event loop.

    Every queued message is read with ``q.get_nowait()``:

    - messages starting with "Progress" or "Processing" replace the text of
      ``progress_label``;
    - empty messages, bare carriage returns, and messages starting with "/"
      or a backslash are ignored;
    - every other message is printed to stdout.

    Whatever happens while draining, ``root.after(100, ...)`` re-arms the
    poller, so it keeps running roughly every 100 ms.

    Args:
        root: object providing ``after(ms, callback)`` (e.g. the Tk root).
        q: queue-like object with ``empty()`` and ``get_nowait()``; messages
            are expected to be ``str`` (anything else raises inside the loop
            and is reported by the ``except`` handler).
        fig_queue: only passed through to the next scheduled call; figure
            handling is not done here.
        canvas_widget: only passed through to the next scheduled call.
        progress_label: widget updated via ``config(text=...)``.
    """
    try:
        while not q.empty():
            message = q.get_nowait()
            if message.startswith(("Progress", "Processing")):
                progress_label.config(text=message)
            elif message in ("", "\r") or message.startswith(("/", "\\")):
                # Presumably spinner / progress-bar fragments; intentionally dropped.
                pass
            else:
                # Surface everything else. Do not add an empty-prefix test such as
                # startswith("") above: it matches every string and would make
                # this branch unreachable.
                print(message)
    except Exception as e:
        print(f"Error updating GUI canvas: {e}")
    finally:
        # Always reschedule so one bad message does not stop the polling loop.
        root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
137
+
138
class MPNN(MessagePassing):
    """
    Edge-conditioned message-passing layer with mean aggregation.

    A message is built by concatenating the neighbour's node features with the
    edge features and passing the result through a two-layer MLP (hidden width
    128). The mean of the incoming messages is then refined by a second
    two-layer MLP that keeps the width at ``out_features``.
    """

    def __init__(self, node_in_features, edge_in_features, out_features):
        super().__init__(aggr='mean')
        hidden_width = 128
        message_in = node_in_features + edge_in_features
        # Sub-modules are created in the same order as before, so parameter
        # initialisation consumes the RNG identically.
        self.message_mlp = Sequential(
            Linear(message_in, hidden_width),
            ReLU(),
            Linear(hidden_width, out_features),
        )
        self.update_mlp = Sequential(
            Linear(out_features, out_features),
            ReLU(),
            Linear(out_features, out_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing.

        x: node features [N, node_in_features];
        edge_index: connectivity [2, E];
        edge_attr: edge features [E, edge_in_features].
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    # Parameter names of message()/update() are kept as-is: MessagePassing
    # (PyTorch Geometric) resolves their inputs by argument name.
    def message(self, x_j, edge_attr):
        """Build one message per edge from neighbour features [E, node_in] and edge features [E, edge_in]."""
        return self.message_mlp(torch.cat((x_j, edge_attr), dim=-1))

    def update(self, aggr_out):
        """Refine the aggregated messages [N, out_features]."""
        return self.update_mlp(aggr_out)
167
+
168
def weighted_mse_loss(output, target, score_threshold=0.8, high_score_weight=10):
    """
    Mean squared error that up-weights samples with a high true score.

    Args:
        output: predicted scores (tensor broadcastable with ``target``).
        target: true scores.
        score_threshold: targets ``>= score_threshold`` count as high-scoring.
        high_score_weight: weight applied to high-scoring samples; all other
            samples have weight 1.

    Returns:
        Scalar tensor: mean of ``weight * (output - target) ** 2``.
    """
    is_high = target >= score_threshold
    # full_like/ones_like give weights the same dtype/device as target.
    weights = torch.where(is_high, torch.full_like(target, high_score_weight), torch.ones_like(target))
    squared_error = (output - target) ** 2
    return (squared_error * weights).mean()
174
+
175
def generate_single_graph(sequencing, scores):
    """
    Build a bipartite gene-cell graph from sequencing counts and cell scores.

    Node layout: indices ``0..G-1`` are genes (one-hot features) and indices
    ``G..G+C-1`` are cells (all-zero features of the same width ``G``). Every
    cell is linked to every gene sequenced in the same well; the edge
    attribute is that gene's fraction of the well's total reads.

    Args:
        sequencing: CSV path/buffer with columns ``prc`` (well), ``grna``
            (gene) and ``count`` (reads).
        scores: CSV path/buffer with columns ``prcfo`` (cell), ``prc`` (well)
            and ``pred`` (score).

    Returns:
        tuple: ``(data, gene_id_to_index, n_genes)`` where ``data`` is a
        ``Data`` object with ``x`` [G + C, G], ``edge_index`` [2, E]
        (cell -> gene), ``edge_attr`` [E, 1] and ``y`` holding the scores in
        ``scores`` row order (NOTE: one entry per row, so it only lines up with
        the cell nodes if cell ids are unique — confirm with callers).
    """
    # Load sequencing data and compute each gene's share of its well's reads.
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Load per-cell scores.
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Genes take node indices 0..G-1, cells follow at G..G+C-1.
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    n_genes = len(gene_id_to_index)
    cell_id_to_index = {cell: i + n_genes for i, cell in enumerate(cell_df['cell_id'].unique())}

    # Group cell node indices by well once instead of re-filtering cell_df per well.
    cell_indices_by_well = {
        well_id: group['cell_id'].map(cell_id_to_index).values
        for well_id, group in cell_df.groupby('well_id')
    }

    # Connect each cell with all genes sequenced in the same well.
    edge_index = []
    edge_attr = []
    for well_id, group in gene_df.groupby('well_id'):
        cell_indices = cell_indices_by_well.get(well_id)
        if cell_indices is None:
            continue
        gene_indices = group['gene_id'].map(gene_id_to_index).values
        fractions = group['well_read_fraction'].values
        for cell_idx in cell_indices:
            for gene_idx, fraction in zip(gene_indices, fractions):
                edge_index.append([cell_idx, gene_idx])
                edge_attr.append([fraction])

    # reshape keeps the [2, E] / [E, 1] layout even when no edges were found.
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).reshape(-1, 1)
    cell_scores = torch.tensor(cell_df['score'].values, dtype=torch.float)

    # One-hot features for genes, zero features for cells.
    gene_features = torch.eye(n_genes)
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))

    # Gene rows must come first so that row i of x describes node i in edge_index.
    x = torch.cat([gene_features, cell_features], dim=0)

    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    return data, gene_id_to_index, n_genes
223
+
224
# in _normalize_and_outline
# NOTE(review): this looks like a fragment pasted out of an
# _normalize_and_outline body (the same statements are removed from
# spacr/plot.py in this release). rgb_image, image, mask_dims, outline_colors
# and outline_thickness are not defined in the visible snippet; if these
# statements execute at module level they raise NameError on import.
# Confirm whether this scratch code should be removed or wrapped in a function.
outlines = []

# Draw on a copy so rgb_image itself is left untouched.
overlayed_image = rgb_image.copy()
for i, mask_dim in enumerate(mask_dims):
    # Take channel `mask_dim` of `image` as a label mask.
    mask = np.take(image, mask_dim, axis=2)
    outline = np.zeros_like(mask)
    # Find the contours of the objects in the mask
    # ([1:] skips the smallest label, presumably 0 = background — confirm).
    for j in np.unique(mask)[1:]:
        contours = find_contours(mask == j, 0.5)
        for contour in contours:
            contour = contour.astype(int)
            outline[contour[:, 0], contour[:, 1]] = j
    # Make the outline thicker
    outline = dilation(outline, square(outline_thickness))
    outlines.append(outline)
    # Overlay the outlines onto the RGB image, cycling through outline_colors per mask.
    for j in np.unique(outline)[1:]:
        overlayed_image[outline == j] = outline_colors[i % len(outline_colors)]
243
+
244
+ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
245
+ for filename in filenames:
246
+ match = regular_expression.match(filename)
247
+ if match:
248
+ try:
249
+ try:
250
+ plate = match.group('plateID')
251
+ except:
252
+ plate = os.path.basename(src)
253
+
254
+ well = match.group('wellID')
255
+ field = match.group('fieldID')
256
+ channel = match.group('chanID')
257
+ mode = None
258
+
259
+ if well[0].isdigit():
260
+ well = str(_safe_int_convert(well))
261
+ if field[0].isdigit():
262
+ field = str(_safe_int_convert(field))
263
+ if channel[0].isdigit():
264
+ channel = str(_safe_int_convert(channel))
265
+
266
+ if metadata_type =='cq1':
267
+ orig_wellID = wellID
268
+ wellID = _convert_cq1_well_id(wellID)
269
+ clear_output(wait=True)
270
+ print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
271
+
272
+ if pick_slice:
273
+ try:
274
+ mode = match.group('AID')
275
+ except IndexError:
276
+ sliceid = '00'
277
+
278
+ if mode == skip_mode:
279
+ continue
280
+
281
+ key = (plate, well, field, channel, mode)
282
+ with Image.open(os.path.join(src, filename)) as img:
283
+ images_by_key[key].append(np.array(img))
284
+ except IndexError:
285
+ print(f"Could not extract information from filename {filename} using provided regex")
286
+ else:
287
+ print(f"Filename {filename} did not match provided regex")
288
+ continue
289
+
290
+ return images_by_key
spacr/plot.py CHANGED
@@ -220,10 +220,11 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
220
220
  if not filter_min_max is None:
221
221
  min_max = filter_min_max[i]
222
222
  else:
223
- min_max = [0, 100000]
223
+ min_max = [0, 100000000]
224
224
 
225
225
  mask = np.take(stack, mask_dim, axis=2)
226
226
  props = measure.regionprops_table(mask, properties=['label', 'area'])
227
+ #props = measure.regionprops_table(mask, intensity_image=intensity_image, properties=['label', 'area', 'mean_intensity'])
227
228
  avg_size_before = np.mean(props['area'])
228
229
  total_count_before = len(props['label'])
229
230
 
@@ -264,7 +265,55 @@ def _filter_objects_in_plot(stack, cell_mask_dim, nucleus_mask_dim, pathogen_mas
264
265
 
265
266
  return stack
266
267
 
267
- def _normalize_and_outline(image, remove_background, backgrounds, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
268
def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
    """
    Plot randomly selected arrays from a given directory.

    Each selected ``.npy`` file is loaded and, optionally, normalised with
    ``normalize_to_dtype``. It is then shown in its own figure: 2-D arrays as
    a single panel, arrays with a third axis as one panel per index along
    that axis.

    Parameters:
    - src (str): The directory path containing the arrays.
    - figuresize (int): The size of the figure (default: 50).
    - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
    - nr (int): The number of arrays to plot (default: 1); capped at the number
      of ``.npy`` files found in ``src``.
    - normalize (bool): Whether to normalize the arrays (default: True).
    - q1 (int): The lower percentile for normalization (default: 1).
    - q2 (int): The upper percentile for normalization (default: 99).

    Returns:
        None
    """
    if normalize:
        from .utils import normalize_to_dtype

    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npy')]
    # random.sample raises if asked for more items than exist.
    paths = random.sample(paths, min(nr, len(paths)))
    for path in paths:
        print(f'Image path:{path}')
        img = np.load(path)
        if normalize:
            img = normalize_to_dtype(array=img, q1=q1, q2=q2)
        if img.ndim > 2:
            array_nr = img.shape[2]
            # squeeze=False keeps a 2-D grid of Axes even for a single channel.
            fig, axs = plt.subplots(1, array_nr, figsize=(figuresize, figuresize), squeeze=False)
            for channel in range(array_nr):
                ax = axs[0, channel]
                ax.imshow(np.take(img, [channel], axis=2), cmap=plt.get_cmap(cmap))
                ax.set_title('Channel ' + str(channel), size=24)
                ax.axis('off')
        else:
            fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
            ax.imshow(img, cmap=plt.get_cmap(cmap))
            ax.set_title('Channel 0', size=24)
            ax.axis('off')
        fig.tight_layout()
        plt.show()
    return
315
+
316
+ def _normalize_and_outline(image, remove_background, normalize, normalization_percentiles, overlay, overlay_chans, mask_dims, outline_colors, outline_thickness):
268
317
  """
269
318
  Normalize and outline an image.
270
319
 
@@ -283,43 +332,34 @@ def _normalize_and_outline(image, remove_background, backgrounds, normalize, nor
283
332
  Returns:
284
333
  tuple: A tuple containing the overlayed image, the original image, and a list of outlines.
285
334
  """
286
- from .utils import normalize_to_dtype
287
-
288
- outlines = []
335
+ from .utils import normalize_to_dtype, _outline_and_overlay, _gen_rgb_image
336
+
289
337
  if remove_background:
290
- for chan_index, channel in enumerate(range(image.shape[-1])):
291
- single_channel = image[:, :, channel] # Extract the specific channel
292
- background = backgrounds[chan_index]
293
- single_channel[single_channel < background] = 0
294
- image[:, :, channel] = single_channel
338
+ backgrounds = np.percentile(image, 1, axis=(0, 1))
339
+ backgrounds = backgrounds[:, np.newaxis, np.newaxis]
340
+ mask = np.zeros_like(image, dtype=bool)
341
+ for chan_index in range(image.shape[-1]):
342
+ if chan_index not in mask_dims:
343
+ mask[:, :, chan_index] = image[:, :, chan_index] < backgrounds[chan_index]
344
+ image[mask] = 0
345
+
295
346
  if normalize:
296
347
  image = normalize_to_dtype(array=image, q1=normalization_percentiles[0], q2=normalization_percentiles[1])
297
- rgb_image = np.take(image, overlay_chans, axis=-1)
298
- rgb_image = rgb_image.astype(float)
299
- rgb_image -= rgb_image.min()
300
- rgb_image /= rgb_image.max()
348
+
349
+ rgb_image = _gen_rgb_image(image, cahnnels=overlay_chans)
350
+
301
351
  if overlay:
302
- overlayed_image = rgb_image.copy()
303
- for i, mask_dim in enumerate(mask_dims):
304
- mask = np.take(image, mask_dim, axis=2)
305
- outline = np.zeros_like(mask)
306
- # Find the contours of the objects in the mask
307
- for j in np.unique(mask)[1:]:
308
- contours = find_contours(mask == j, 0.5)
309
- for contour in contours:
310
- contour = contour.astype(int)
311
- outline[contour[:, 0], contour[:, 1]] = j
312
- # Make the outline thicker
313
- outline = dilation(outline, square(outline_thickness))
314
- outlines.append(outline)
315
- # Overlay the outlines onto the RGB image
316
- for j in np.unique(outline)[1:]:
317
- overlayed_image[outline == j] = outline_colors[i % len(outline_colors)]
352
+ overlayed_image, outlines, image = _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness)
353
+
318
354
  return overlayed_image, image, outlines
319
355
  else:
356
+ # Remove mask_dims from image
357
+ channels_to_keep = [i for i in range(image.shape[-1]) if i not in mask_dims]
358
+ image = np.take(image, channels_to_keep, axis=-1)
320
359
  return [], image, []
321
360
 
322
361
  def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_image, outlines, cmap, outline_colors, print_object_number):
362
+
323
363
  """
324
364
  Plot the merged plot with overlay, image channels, and masks.
325
365
 
@@ -338,6 +378,7 @@ def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_im
338
378
  Returns:
339
379
  fig (Figure): The generated matplotlib figure.
340
380
  """
381
+
341
382
  if overlay:
342
383
  fig, ax = plt.subplots(1, image.shape[-1] + len(mask_dims) + 1, figsize=(4 * figuresize, figuresize))
343
384
  ax[0].imshow(overlayed_image) #_imshow
@@ -378,60 +419,12 @@ def _plot_merged_plot(overlay, image, stack, mask_dims, figuresize, overlayed_im
378
419
  plt.show()
379
420
  return fig
380
421
 
381
def plot_arrays(src, figuresize=50, cmap='inferno', nr=1, normalize=True, q1=1, q2=99):
    """
    Plot randomly selected arrays from a given directory.

    Each selected ``.npy`` file is loaded and, optionally, normalised with
    ``normalize_to_dtype``. It is then shown in its own figure: 2-D arrays as
    a single panel, arrays with a third axis as one panel per index along
    that axis.

    Parameters:
    - src (str): The directory path containing the arrays.
    - figuresize (int): The size of the figure (default: 50).
    - cmap (str): The colormap to use for displaying the arrays (default: 'inferno').
    - nr (int): The number of arrays to plot (default: 1); capped at the number
      of ``.npy`` files found in ``src``.
    - normalize (bool): Whether to normalize the arrays (default: True).
    - q1 (int): The lower percentile for normalization (default: 1).
    - q2 (int): The upper percentile for normalization (default: 99).

    Returns:
        None
    """
    if normalize:
        from .utils import normalize_to_dtype

    paths = [os.path.join(src, file) for file in os.listdir(src) if file.endswith('.npy')]
    # random.sample raises if asked for more items than exist.
    paths = random.sample(paths, min(nr, len(paths)))
    for path in paths:
        print(f'Image path:{path}')
        img = np.load(path)
        if normalize:
            img = normalize_to_dtype(array=img, q1=q1, q2=q2)
        if img.ndim > 2:
            array_nr = img.shape[2]
            # squeeze=False keeps a 2-D grid of Axes even for a single channel.
            fig, axs = plt.subplots(1, array_nr, figsize=(figuresize, figuresize), squeeze=False)
            for channel in range(array_nr):
                ax = axs[0, channel]
                ax.imshow(np.take(img, [channel], axis=2), cmap=plt.get_cmap(cmap))
                ax.set_title('Channel ' + str(channel), size=24)
                ax.axis('off')
        else:
            fig, ax = plt.subplots(1, 1, figsize=(figuresize, figuresize))
            ax.imshow(img, cmap=plt.get_cmap(cmap))
            ax.set_title('Channel 0', size=24)
            ax.axis('off')
        fig.tight_layout()
        plt.show()
    return
428
-
429
422
  def plot_merged(src, settings):
430
423
  """
431
424
  Plot the merged images after applying various filters and modifications.
432
425
 
433
426
  Args:
434
- src (ndarray): The source images.
427
+ src (path): Path to folder with images.
435
428
  settings (dict): The settings for the plot.
436
429
 
437
430
  Returns:
@@ -463,13 +456,27 @@ def plot_merged(src, settings):
463
456
  if settings['include_multiinfected'] is not True or settings['include_multinucleated'] is not True or settings['filter_min_max'] is not None:
464
457
  stack = _filter_objects_in_plot(stack, settings['cell_mask_dim'], settings['nucleus_mask_dim'], settings['pathogen_mask_dim'], mask_dims, settings['filter_min_max'], settings['include_multinucleated'], settings['include_multiinfected'])
465
458
 
466
- #image = np.take(stack, settings['channel_dims'], axis=2)
467
- print('stack.shape', stack.shape)
468
- overlayed_image, image, outlines = _normalize_and_outline(stack, settings['remove_background'], settings['backgrounds'], settings['normalize'], settings['normalization_percentiles'], settings['overlay'], settings['overlay_chans'], mask_dims, outline_colors, settings['outline_thickness'])
469
-
459
+ overlayed_image, image, outlines = _normalize_and_outline(image=stack,
460
+ remove_background=settings['remove_background'],
461
+ normalize=settings['normalize'],
462
+ normalization_percentiles=settings['normalization_percentiles'],
463
+ overlay=settings['overlay'],
464
+ overlay_chans=settings['overlay_chans'],
465
+ mask_dims=mask_dims,
466
+ outline_colors=outline_colors,
467
+ outline_thickness=settings['outline_thickness'])
470
468
  if index < settings['nr']:
471
469
  index += 1
472
- fig = _plot_merged_plot(settings['overlay'], image, stack, mask_dims, settings['figuresize'], overlayed_image, outlines, settings['cmap'], outline_colors, settings['print_object_number'])
470
+ fig = _plot_merged_plot(overlay=settings['overlay'],
471
+ image=image,
472
+ stack=stack,
473
+ mask_dims=mask_dims,
474
+ figuresize=settings['figuresize'],
475
+ overlayed_image=overlayed_image,
476
+ outlines=outlines,
477
+ cmap=settings['cmap'],
478
+ outline_colors=outline_colors,
479
+ print_object_number=settings['print_object_number'])
473
480
  else:
474
481
  return fig
475
482
 
@@ -700,7 +707,7 @@ def _visualize_and_save_timelapse_stack_with_tracks(masks, tracks_df, save, src,
700
707
  interactive (bool, optional): Flag indicating whether to display the timelapse stack interactively. Defaults to False.
701
708
  """
702
709
 
703
- from .timelapse import _save_mask_timelapse_as_gif
710
+ from .io import _save_mask_timelapse_as_gif
704
711
 
705
712
  highest_label = max(np.max(mask) for mask in masks)
706
713
  # Generate random colors for each label, including the background
spacr/sim.py CHANGED
@@ -35,6 +35,7 @@ def generate_gene_list(number_of_genes, number_of_all_genes):
35
35
 
36
36
  # plate_map is a table with a row for each well, containing well metadata: plate_id, row_id, and column_id
37
37
  def generate_plate_map(nr_plates):
38
+ print('nr_plates',nr_plates)
38
39
  """
39
40
  Generate a plate map based on the number of plates.
40
41
 
@@ -1023,6 +1024,8 @@ def save_plot(fig, src, variable, i):
1023
1024
  return
1024
1025
 
1025
1026
  def run_and_save(i, settings, time_ls, total_sims):
1027
+
1028
+
1026
1029
  """
1027
1030
  Run the simulation and save the results.
1028
1031
 
@@ -1035,6 +1038,9 @@ def run_and_save(i, settings, time_ls, total_sims):
1035
1038
  Returns:
1036
1039
  tuple: A tuple containing the simulation index, simulation time, and None.
1037
1040
  """
1041
+ #print(f'Runnings simulation with the following paramiters')
1042
+ #print(settings)
1043
+
1038
1044
  if settings['random_seed']:
1039
1045
  random.seed(42) # sims will be too similar with random seed
1040
1046
  src = settings['src']
@@ -1184,4 +1190,5 @@ def run_multiple_simulations(settings):
1184
1190
  average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
1185
1191
  time_left = (((total_sims - sims_processed) * average_time) / max_workers) / 60
1186
1192
  print(f'Progress: {sims_processed}/{total_sims} Time/simulation {average_time:.3f}sec Time Remaining {time_left:.3f} min.', end='\r', flush=True)
1187
- result.get()
1193
+ result.get()
1194
+