spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/measure.py CHANGED
@@ -12,6 +12,8 @@ from scipy.ndimage import binary_dilation
12
12
  from skimage.segmentation import find_boundaries
13
13
  from skimage.feature import graycomatrix, graycoprops
14
14
  from mahotas.features import zernike_moments
15
+ from skimage import morphology, measure, filters
16
+ from skimage.util import img_as_bool
15
17
 
16
18
  from .logger import log_function_call
17
19
 
@@ -92,6 +94,70 @@ def _calculate_zernike(mask, df, degree=8):
92
94
  else:
93
95
  return df
94
96
 
97
+ def _analyze_cytoskeleton(array, mask, channel):
98
+ """
99
+ Analyzes and extracts skeleton properties from labeled objects in a masked image based on microtubule staining intensities.
100
+
101
+ Parameters:
102
+ image : numpy array
103
+ Intensity image where the microtubules are stained.
104
+ mask : numpy array
105
+ Mask where objects are labeled for analysis. Each label corresponds to a unique object.
106
+
107
+ Returns:
108
+ DataFrame
109
+ A pandas DataFrame containing the measured properties of each object's skeleton.
110
+ """
111
+
112
+ image = array[:, :, channel]
113
+
114
+ properties_list = []
115
+
116
+ # Process each object in the mask based on its label
117
+ for label in np.unique(mask):
118
+ if label == 0:
119
+ continue # Skip background
120
+
121
+ # Isolate the object using the label
122
+ object_region = mask == label
123
+ region_intensity = np.where(object_region, image, 0) # Use np.where for more efficient masking
124
+
125
+ # Ensure there are non-zero values to process
126
+ if np.any(region_intensity):
127
+ # Calculate adaptive offset based on intensity percentiles within the object
128
+ valid_pixels = region_intensity[region_intensity > 0]
129
+ if len(valid_pixels) > 1: # Ensure there are enough pixels to compute percentiles
130
+ offset = np.percentile(valid_pixels, 90) - np.percentile(valid_pixels, 50)
131
+ block_size = 35 # Adjust this based on your object sizes and detail needs
132
+ local_thresh = filters.threshold_local(region_intensity, block_size=block_size, offset=offset)
133
+ cytoskeleton = region_intensity > local_thresh
134
+
135
+ # Skeletonize the thresholded cytoskeleton
136
+ skeleton = morphology.skeletonize(img_as_bool(cytoskeleton))
137
+
138
+ # Measure properties of the skeleton
139
+ skeleton_props = measure.regionprops(measure.label(skeleton), intensity_image=image)
140
+ skeleton_length = sum(prop.area for prop in skeleton_props) # Sum of lengths of all skeleton segments
141
+ branch_data = morphology.skeleton_branch_analysis(skeleton)
142
+
143
+ # Store properties
144
+ properties = {
145
+ "object_label": label,
146
+ "skeleton_length": skeleton_length,
147
+ "skeleton_branch_points": len(branch_data['branch_points'])
148
+ }
149
+ properties_list.append(properties)
150
+ else:
151
+ # Handle cases with insufficient pixels
152
+ properties_list.append({
153
+ "object_label": label,
154
+ "skeleton_length": 0,
155
+ "skeleton_branch_points": 0
156
+ })
157
+
158
+ return pd.DataFrame(properties_list)
159
+
160
+ @log_function_call
95
161
  def _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings, zernike=True, degree=8):
96
162
  """
97
163
  Calculate morphological measurements for cells, nucleus, pathogens, and cytoplasms based on the given masks.
@@ -435,6 +501,7 @@ def _estimate_blur(image):
435
501
  # Compute and return the variance of the Laplacian
436
502
  return lap.var()
437
503
 
504
+ @log_function_call
438
505
  def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[3, 6, 12, 24], periphery=True, outside=True):
439
506
  """
440
507
  Calculate various intensity measurements for different regions in the image.
@@ -524,6 +591,7 @@ def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_ma
524
591
 
525
592
  @log_function_call
526
593
  def _measure_crop_core(index, time_ls, file, settings):
594
+
527
595
  """
528
596
  Measure and crop the images based on specified settings.
529
597
 
@@ -622,9 +690,8 @@ def _measure_crop_core(index, time_ls, file, settings):
622
690
  if settings['cytoplasm_min_size'] is not None and settings['cytoplasm_min_size'] != 0:
623
691
  cytoplasm_mask = _filter_object(cytoplasm_mask, settings['cytoplasm_min_size'])
624
692
 
625
- if settings['cell_mask_dim'] is not None and settings['pathogen_mask_dim'] is not None:
626
- if settings['include_uninfected'] == False:
627
- cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=False)
693
+ if settings['cell_mask_dim'] is not None:
694
+ cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=settings['include_uninfected'])
628
695
 
629
696
  # Update data with the new masks
630
697
  if settings['cell_mask_dim'] is not None:
@@ -643,6 +710,10 @@ def _measure_crop_core(index, time_ls, file, settings):
643
710
 
644
711
  cell_df, nucleus_df, pathogen_df, cytoplasm_df = _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings)
645
712
 
713
+ #if settings['skeleton']:
714
+ #skeleton_df = _analyze_cytoskeleton(image=channel_arrays, mask=cell_mask, channel=1)
715
+ #merge skeleton_df with cell_df here
716
+
646
717
  cell_intensity_df, nucleus_intensity_df, pathogen_intensity_df, cytoplasm_intensity_df = _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[1, 2, 3, 4, 5], periphery=True, outside=True)
647
718
  if settings['cell_mask_dim'] is not None:
648
719
  cell_merged_df = _merge_and_save_to_database(cell_df, cell_intensity_df, 'cell', source_folder, file_name, settings['experiment'], settings['timelapse'])
@@ -656,7 +727,6 @@ def _measure_crop_core(index, time_ls, file, settings):
656
727
  if settings['cytoplasm']:
657
728
  cytoplasm_merged_df = _merge_and_save_to_database(cytoplasm_df, cytoplasm_intensity_df, 'cytoplasm', source_folder, file_name, settings['experiment'], settings['timelapse'])
658
729
 
659
-
660
730
  if settings['save_png'] or settings['save_arrays'] or settings['plot']:
661
731
 
662
732
  if isinstance(settings['dialate_pngs'], bool):
@@ -676,7 +746,6 @@ def _measure_crop_core(index, time_ls, file, settings):
676
746
  crop_ls = settings['crop_mode']
677
747
  size_ls = settings['png_size']
678
748
  for crop_idx, crop_mode in enumerate(crop_ls):
679
- print(crop_idx, crop_mode)
680
749
  width, height = size_ls[crop_idx]
681
750
  if crop_mode == 'cell':
682
751
  crop_mask = cell_mask.copy()
@@ -730,7 +799,7 @@ def _measure_crop_core(index, time_ls, file, settings):
730
799
  png_channels = data[:, :, settings['png_dims']].astype(data_type)
731
800
 
732
801
  if settings['normalize_by'] == 'fov':
733
- percentiles_list = _get_percentiles(png_channels, settings['normalize_percentiles'][0],q2=settings['normalize_percentiles'][1])
802
+ percentiles_list = _get_percentiles(png_channels, settings['normalize'][0],q2=settings['normalize'][1])
734
803
 
735
804
  png_channels = _crop_center(png_channels, region, new_width=width, new_height=height)
736
805
 
@@ -787,6 +856,7 @@ def _measure_crop_core(index, time_ls, file, settings):
787
856
  conn.commit()
788
857
  except sqlite3.OperationalError as e:
789
858
  print(f"SQLite error: {e}", flush=True)
859
+ traceback.print_exc()
790
860
 
791
861
  if settings['plot']:
792
862
  _plot_cropped_arrays(png_channels)
@@ -818,37 +888,31 @@ def _measure_crop_core(index, time_ls, file, settings):
818
888
  return average_time, cells
819
889
 
820
890
  @log_function_call
821
- def measure_crop(settings, annotation_settings, advanced_settings):
891
+ def measure_crop(settings):
892
+
822
893
  """
823
894
  Measure the crop of an image based on the provided settings.
824
895
 
825
896
  Args:
826
897
  settings (dict): The settings for measuring the crop.
827
- annotation_settings (dict): The annotation settings.
828
- advanced_settings (dict): The advanced settings.
829
898
 
830
899
  Returns:
831
900
  None
832
901
  """
902
+
903
+ if settings.get('test_mode', False):
904
+ if not os.basename(settings['src']) == 'test':
905
+ src = os.path.join(src, 'test')
906
+ settings['src'] = src
907
+ print(f'Changed source folder to {src} for test mode')
908
+ else:
909
+ print(f'Test mode enabled, using source folder {settings["src"]}')
833
910
 
834
911
  from .io import _save_settings_to_db
835
912
  from .timelapse import _timelapse_masks_to_gif, _scmovie
836
913
  from .plot import _save_scimg_plot
837
914
  from .utils import _list_endpoint_subdirectories, _generate_representative_images
838
915
 
839
- settings = {**settings, **annotation_settings, **advanced_settings}
840
-
841
- dirname = os.path.dirname(settings['input_folder'])
842
- settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
843
- settings_csv = os.path.join(dirname,'settings','measure_crop_settings.csv')
844
- os.makedirs(os.path.join(dirname,'settings'), exist_ok=True)
845
- settings_df.to_csv(settings_csv, index=False)
846
-
847
- if settings['timelapse_objects'] == 'nucleus':
848
- if not settings['cell_mask_dim'] is None:
849
- tlo = settings['timelapse_objects']
850
- print(f'timelapse object:{tlo}, cells will be relabeled to nucleus labels to track cells.')
851
-
852
916
  #general settings
853
917
  settings['merge_edge_pathogen_cells'] = True
854
918
  settings['radial_dist'] = True
@@ -857,6 +921,26 @@ def measure_crop(settings, annotation_settings, advanced_settings):
857
921
  settings['homogeneity'] = True
858
922
  settings['homogeneity_distances'] = [8,16,32]
859
923
  settings['save_arrays'] = False
924
+
925
+ settings['dialate_pngs'] = False
926
+ settings['dialate_png_ratios'] = [0.2]
927
+ settings['timelapse'] = False
928
+ settings['representative_images'] = False
929
+ settings['timelapse_objects'] = 'cell'
930
+ settings['max_workers'] = os.cpu_count()-2
931
+ settings['experiment'] = 'test'
932
+ settings['cells'] = 'HeLa'
933
+ settings['cell_loc'] = None
934
+ settings['pathogens'] = ['ME49Dku80WT', 'ME49Dku80dgra8:GRA8', 'ME49Dku80dgra8', 'ME49Dku80TKO']
935
+ settings['pathogen_loc'] = [['c1', 'c2', 'c3', 'c4', 'c5', 'c6'], ['c7', 'c8', 'c9', 'c10', 'c11', 'c12'], ['c13', 'c14', 'c15', 'c16', 'c17', 'c18'], ['c19', 'c20', 'c21', 'c22', 'c23', 'c24']]
936
+ settings['treatments'] = ['BR1', 'BR2', 'BR3']
937
+ settings['treatment_loc'] = [['c1', 'c2', 'c7', 'c8', 'c13', 'c14', 'c19', 'c20'], ['c3', 'c4', 'c9', 'c10', 'c15', 'c16', 'c21', 'c22'], ['c5', 'c6', 'c11', 'c12', 'c17', 'c18', 'c23', 'c24']]
938
+ settings['channel_of_interest'] = 2
939
+ settings['compartments'] = ['pathogen', 'cytoplasm']
940
+ settings['measurement'] = 'mean_intensity'
941
+ settings['nr_imgs'] = 32
942
+ settings['um_per_pixel'] = 0.1
943
+ settings['center_crop'] = True
860
944
 
861
945
  if settings['cell_mask_dim'] is None:
862
946
  settings['include_uninfected'] = True
@@ -869,7 +953,18 @@ def measure_crop(settings, annotation_settings, advanced_settings):
869
953
  else:
870
954
  settings['cytoplasm'] = False
871
955
 
872
- settings['center_crop'] = True
956
+ #settings = {**settings, **annotation_settings, **advanced_settings}
957
+
958
+ dirname = os.path.dirname(settings['input_folder'])
959
+ settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
960
+ settings_csv = os.path.join(dirname,'settings','measure_crop_settings.csv')
961
+ os.makedirs(os.path.join(dirname,'settings'), exist_ok=True)
962
+ settings_df.to_csv(settings_csv, index=False)
963
+
964
+ if settings['timelapse_objects'] == 'nucleus':
965
+ if not settings['cell_mask_dim'] is None:
966
+ tlo = settings['timelapse_objects']
967
+ print(f'timelapse object:{tlo}, cells will be relabeled to nucleus labels to track cells.')
873
968
 
874
969
  int_setting_keys = ['cell_mask_dim', 'nucleus_mask_dim', 'pathogen_mask_dim', 'cell_min_size', 'nucleus_min_size', 'pathogen_min_size', 'cytoplasm_min_size']
875
970
 
@@ -913,64 +1008,49 @@ def measure_crop(settings, annotation_settings, advanced_settings):
913
1008
  time_left = (((files_to_process-files_processed)*average_time)/max_workers)/60
914
1009
  print(f'Progress: {files_processed}/{files_to_process} Time/img {average_time:.3f}sec, Time Remaining {time_left:.3f} min.', end='\r', flush=True)
915
1010
  result.get()
916
-
917
- #if settings['save_png']:
918
- # img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
919
- # sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
920
- # for well_src in sc_img_fldrs:
921
- # if len(os.listdir(well_src)) < 16:
922
- # nr_imgs = len(os.listdir(well_src))
923
- # standardize = False
924
- # else:
925
- # nr_imgs = 16
926
- # standardize = True
927
- # try:
928
- # _save_scimg_plot(src=well_src, nr_imgs=nr_imgs, channel_indices=settings['png_dims'], um_per_pixel=0.1, scale_bar_length_um=10, standardize=standardize, fontsize=12, show_filename=True, channel_names=['red','green','blue'], dpi=300, plot=False)
929
- # except Exception as e: # Consider catching a more specific exception if possible
930
- # print(f"Unable to generate figure for folder {well_src}: {e}", flush=True)
931
-
932
- if settings['save_png']:
933
- img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
934
- sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
935
-
936
- for i, well_src in enumerate(sc_img_fldrs):
937
- if len(os.listdir(well_src)) < 16:
938
- nr_imgs = len(os.listdir(well_src))
939
- standardize = False
940
- else:
941
- nr_imgs = 16
942
- standardize = True
943
- try:
944
- all_folders = len(sc_img_fldrs)
945
- _save_scimg_plot(src=well_src, nr_imgs=nr_imgs, channel_indices=settings['png_dims'], um_per_pixel=0.1, scale_bar_length_um=10, standardize=standardize, fontsize=12, show_filename=True, channel_names=['red','green','blue'], dpi=300, plot=False, i=i, all_folders=all_folders)
946
-
947
- except Exception as e:
948
- print(f"Unable to generate figure for folder {well_src}: {e}", end='\r', flush=True)
949
- #traceback.print_exc()
1011
+
1012
+ if settings['representative_images']:
1013
+ if settings['save_png']:
1014
+ img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
1015
+ sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
1016
+
1017
+ for i, well_src in enumerate(sc_img_fldrs):
1018
+ if len(os.listdir(well_src)) < 16:
1019
+ nr_imgs = len(os.listdir(well_src))
1020
+ standardize = False
1021
+ else:
1022
+ nr_imgs = 16
1023
+ standardize = True
1024
+ try:
1025
+ all_folders = len(sc_img_fldrs)
1026
+ _save_scimg_plot(src=well_src, nr_imgs=nr_imgs, channel_indices=settings['png_dims'], um_per_pixel=0.1, scale_bar_length_um=10, standardize=standardize, fontsize=12, show_filename=True, channel_names=['red','green','blue'], dpi=300, plot=False, i=i, all_folders=all_folders)
1027
+
1028
+ except Exception as e:
1029
+ print(f"Unable to generate figure for folder {well_src}: {e}", end='\r', flush=True)
1030
+ #traceback.print_exc()
950
1031
 
951
1032
  if settings['save_measurements']:
952
- if settings['representative_images']:
953
- db_path = os.path.join(os.path.dirname(settings['input_folder']), 'measurements', 'measurements.db')
954
- channel_indices = settings['png_dims']
955
- channel_indices = [min(value, 2) for value in channel_indices]
956
- _generate_representative_images(db_path,
957
- cells=settings['cells'],
958
- cell_loc=settings['cell_loc'],
959
- pathogens=settings['pathogens'],
960
- pathogen_loc=settings['pathogen_loc'],
961
- treatments=settings['treatments'],
962
- treatment_loc=settings['treatment_loc'],
963
- channel_of_interest=settings['channel_of_interest'],
964
- compartments = settings['compartments'],
965
- measurement = settings['measurement'],
966
- nr_imgs=settings['nr_imgs'],
967
- channel_indices=channel_indices,
968
- um_per_pixel=settings['um_per_pixel'],
969
- scale_bar_length_um=10,
970
- plot=False,
971
- fontsize=12,
972
- show_filename=True,
973
- channel_names=None)
1033
+ db_path = os.path.join(os.path.dirname(settings['input_folder']), 'measurements', 'measurements.db')
1034
+ channel_indices = settings['png_dims']
1035
+ channel_indices = [min(value, 2) for value in channel_indices]
1036
+ _generate_representative_images(db_path,
1037
+ cells=settings['cells'],
1038
+ cell_loc=settings['cell_loc'],
1039
+ pathogens=settings['pathogens'],
1040
+ pathogen_loc=settings['pathogen_loc'],
1041
+ treatments=settings['treatments'],
1042
+ treatment_loc=settings['treatment_loc'],
1043
+ channel_of_interest=settings['channel_of_interest'],
1044
+ compartments = settings['compartments'],
1045
+ measurement = settings['measurement'],
1046
+ nr_imgs=settings['nr_imgs'],
1047
+ channel_indices=channel_indices,
1048
+ um_per_pixel=settings['um_per_pixel'],
1049
+ scale_bar_length_um=10,
1050
+ plot=False,
1051
+ fontsize=12,
1052
+ show_filename=True,
1053
+ channel_names=None)
974
1054
 
975
1055
  if settings['timelapse']:
976
1056
  if settings['timelapse_objects'] == 'nucleus':
@@ -979,7 +1059,7 @@ def measure_crop(settings, annotation_settings, advanced_settings):
979
1059
  object_types = ['nucleus','pathogen','cell']
980
1060
  _timelapse_masks_to_gif(folder_path, mask_channels, object_types)
981
1061
 
982
- if settings['save_png']:
1062
+ #if settings['save_png']:
983
1063
  img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
984
1064
  sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
985
1065
  _scmovie(sc_img_fldrs)
spacr/old_code.py CHANGED
@@ -133,4 +133,158 @@ def main_thread_update_function(root, q, fig_queue, canvas_widget, progress_labe
133
133
  #except Exception as e:
134
134
  # print(f"Error updating GUI figure: {e}")
135
135
  finally:
136
- root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
136
+ root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
137
+
138
class MPNN(MessagePassing):
    """
    Message-passing layer that combines neighbour and edge features.

    Each message is an MLP applied to the concatenation of a neighbour's node
    features and the connecting edge's attributes; messages are aggregated
    with the mean and passed through a second MLP.
    """

    def __init__(self, node_in_features, edge_in_features, out_features):
        super().__init__(aggr='mean')  # 'mean' aggregation.
        hidden_width = 128
        message_input_width = node_in_features + edge_in_features

        # Builds one message per edge from [neighbour features | edge attributes].
        message_layers = [
            Linear(message_input_width, hidden_width),
            ReLU(),
            Linear(hidden_width, out_features),
        ]
        self.message_mlp = Sequential(*message_layers)

        # Transforms the aggregated messages of each node.
        update_layers = [
            Linear(out_features, out_features),
            ReLU(),
            Linear(out_features, out_features),
        ]
        self.update_mlp = Sequential(*update_layers)

    def forward(self, x, edge_index, edge_attr):
        """
        Run one round of message passing.

        x: node features [N, node_in_features]
        edge_index: graph connectivity [2, E]
        edge_attr: edge attributes/features [E, edge_in_features]
        """
        propagated = self.propagate(edge_index, x=x, edge_attr=edge_attr)
        return propagated

    def message(self, x_j, edge_attr):
        """
        Build the per-edge message.

        x_j: input features of neighbours [E, node_in_features]
        edge_attr: edge attributes [E, edge_in_features]
        """
        neighbour_and_edge = torch.cat((x_j, edge_attr), dim=-1)
        return self.message_mlp(neighbour_and_edge)

    def update(self, aggr_out):
        """Transform the aggregated messages, aggr_out: [N, out_features]."""
        return self.update_mlp(aggr_out)
167
+
168
def weighted_mse_loss(output, target, score_threshold=0.8, high_score_weight=10):
    """
    Mean squared error with extra weight on high-scoring targets.

    Elements whose target is at or above ``score_threshold`` have their squared
    error multiplied by ``high_score_weight``; all other elements use weight 1.
    ``output`` and ``target`` are the predicted and true scores, respectively.
    """
    is_high_score = target >= score_threshold
    per_element_weight = torch.where(
        is_high_score,
        torch.full_like(target, high_score_weight),
        torch.ones_like(target),
    )
    squared_error = (output - target) ** 2
    weighted_error = squared_error * per_element_weight
    return weighted_error.mean()
174
+
175
def generate_single_graph(sequencing, scores):
    """
    Build a bipartite cell-gene graph from sequencing counts and cell scores.

    Parameters:
        sequencing : path or buffer accepted by ``pd.read_csv``
            Table with the columns ``prc`` (well), ``grna`` (gene) and ``count``.
        scores : path or buffer accepted by ``pd.read_csv``
            Table with the columns ``prcfo`` (cell), ``prc`` (well) and ``pred``.

    Returns:
        tuple
            ``(data, gene_id_to_index, n_genes)`` where ``data`` is the graph
            object, ``gene_id_to_index`` maps gene IDs to node indices and
            ``n_genes`` is the number of gene nodes.

    Node layout: gene nodes occupy indices ``0..n_genes-1`` and cell nodes
    follow. Each cell is connected to every gene found in its well, with the
    gene's read fraction in that well as the edge attribute.
    """
    # Load and preprocess sequencing data
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Load and preprocess cell score data
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Node index mappings: genes first, then cells.
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    n_genes = len(gene_id_to_index)
    cell_id_to_index = {cell: i + n_genes for i, cell in enumerate(cell_df['cell_id'].unique())}

    # Group the cell node indices by well once instead of rescanning the
    # whole cell table for every well.
    cell_nodes_by_well = {
        well_id: group['cell_id'].map(cell_id_to_index).values
        for well_id, group in cell_df.groupby('well_id')
    }

    edge_index = []
    edge_attr = []

    # Associate each cell with all genes in the same well
    for well_id, group in gene_df.groupby('well_id'):
        cell_indices = cell_nodes_by_well.get(well_id)
        if cell_indices is None:
            continue  # No scored cells in this well.

        gene_indices = group['gene_id'].map(gene_id_to_index).values
        fractions = group['well_read_fraction'].values

        for cell_idx in cell_indices:
            for gene_idx, fraction in zip(gene_indices, fractions):
                edge_index.append([cell_idx, gene_idx])
                edge_attr.append([fraction])

    # Convert lists to PyTorch tensors. The reshapes keep the documented
    # [2, E] / [E, 1] layouts even when no edge was created.
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).reshape(-1, 1)
    cell_scores = torch.tensor(cell_df['score'].values, dtype=torch.float)

    # One-hot encoding for genes, and zero features for cells (could be replaced with real features if available)
    gene_features = torch.eye(n_genes)
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))

    # Feature rows must follow the node index layout used in edge_index:
    # gene rows first, then cell rows.
    x = torch.cat([gene_features, cell_features], dim=0)

    # Create the graph data object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    return data, gene_id_to_index, n_genes
223
+
224
# in _normalize_and_outline
# Fragment: draws a thickened outline for every labeled object of each mask
# plane and paints it onto a copy of the RGB image. Relies on image,
# rgb_image, mask_dims, outline_thickness and outline_colors being defined
# by the surrounding code.
overlayed_image = rgb_image.copy()
outlines = []

for i, mask_dim in enumerate(mask_dims):
    mask = np.take(image, mask_dim, axis=2)
    outline = np.zeros_like(mask)

    # Trace the contour of every object (the first unique value is skipped)
    object_labels = np.unique(mask)[1:]
    for j in object_labels:
        for contour in find_contours(mask == j, 0.5):
            contour = contour.astype(int)
            contour_rows = contour[:, 0]
            contour_cols = contour[:, 1]
            outline[contour_rows, contour_cols] = j

    # Make the outline thicker
    footprint = square(outline_thickness)
    outline = dilation(outline, footprint)
    outlines.append(outline)

    # Overlay the outlines onto the RGB image, one color per mask plane
    for j in np.unique(outline)[1:]:
        color_position = i % len(outline_colors)
        overlayed_image[outline == j] = outline_colors[color_position]
243
+
244
+ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
245
+ for filename in filenames:
246
+ match = regular_expression.match(filename)
247
+ if match:
248
+ try:
249
+ try:
250
+ plate = match.group('plateID')
251
+ except:
252
+ plate = os.path.basename(src)
253
+
254
+ well = match.group('wellID')
255
+ field = match.group('fieldID')
256
+ channel = match.group('chanID')
257
+ mode = None
258
+
259
+ if well[0].isdigit():
260
+ well = str(_safe_int_convert(well))
261
+ if field[0].isdigit():
262
+ field = str(_safe_int_convert(field))
263
+ if channel[0].isdigit():
264
+ channel = str(_safe_int_convert(channel))
265
+
266
+ if metadata_type =='cq1':
267
+ orig_wellID = wellID
268
+ wellID = _convert_cq1_well_id(wellID)
269
+ clear_output(wait=True)
270
+ print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
271
+
272
+ if pick_slice:
273
+ try:
274
+ mode = match.group('AID')
275
+ except IndexError:
276
+ sliceid = '00'
277
+
278
+ if mode == skip_mode:
279
+ continue
280
+
281
+ key = (plate, well, field, channel, mode)
282
+ with Image.open(os.path.join(src, filename)) as img:
283
+ images_by_key[key].append(np.array(img))
284
+ except IndexError:
285
+ print(f"Could not extract information from filename {filename} using provided regex")
286
+ else:
287
+ print(f"Filename {filename} did not match provided regex")
288
+ continue
289
+
290
+ return images_by_key