spacr 0.0.18__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/alpha.py +291 -14
- spacr/annotate_app.py +2 -2
- spacr/core.py +1377 -296
- spacr/foldseek.py +793 -0
- spacr/get_alfafold_structures.py +72 -0
- spacr/graph_learning.py +259 -65
- spacr/graph_learning_lap.py +73 -71
- spacr/gui_classify_app.py +5 -21
- spacr/gui_mask_app.py +36 -30
- spacr/gui_measure_app.py +10 -24
- spacr/gui_utils.py +82 -54
- spacr/io.py +505 -205
- spacr/measure.py +160 -80
- spacr/old_code.py +155 -1
- spacr/plot.py +243 -99
- spacr/sim.py +666 -119
- spacr/timelapse.py +343 -52
- spacr/train.py +18 -10
- spacr/utils.py +252 -151
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/METADATA +32 -27
- spacr-0.0.21.dist-info/RECORD +33 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/WHEEL +1 -1
- spacr/gui_temp.py +0 -212
- spacr/test_annotate_app.py +0 -58
- spacr/test_plot.py +0 -43
- spacr/test_train.py +0 -39
- spacr/test_utils.py +0 -33
- spacr-0.0.18.dist-info/RECORD +0 -36
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/LICENSE +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/entry_points.txt +0 -0
- {spacr-0.0.18.dist-info → spacr-0.0.21.dist-info}/top_level.txt +0 -0
spacr/measure.py
CHANGED
@@ -12,6 +12,8 @@ from scipy.ndimage import binary_dilation
|
|
12
12
|
from skimage.segmentation import find_boundaries
|
13
13
|
from skimage.feature import graycomatrix, graycoprops
|
14
14
|
from mahotas.features import zernike_moments
|
15
|
+
from skimage import morphology, measure, filters
|
16
|
+
from skimage.util import img_as_bool
|
15
17
|
|
16
18
|
from .logger import log_function_call
|
17
19
|
|
@@ -92,6 +94,70 @@ def _calculate_zernike(mask, df, degree=8):
|
|
92
94
|
else:
|
93
95
|
return df
|
94
96
|
|
97
|
+
def _analyze_cytoskeleton(array, mask, channel):
    """
    Measure skeleton length and branch points for every labeled object in a mask.

    For each non-zero label in ``mask`` the selected channel is masked to the
    object, locally thresholded, skeletonized, and summarised.

    Parameters:
    array : numpy array
        Image stack indexed as ``array[:, :, channel]``.
    mask : numpy array
        Label image; 0 is treated as background and every other value as one object.
    channel : int
        Index into the last axis of ``array`` selecting the intensity image to analyse.

    Returns:
    DataFrame
        One row per non-zero label with the columns ``object_label``,
        ``skeleton_length`` (number of skeleton pixels) and
        ``skeleton_branch_points`` (skeleton pixels with three or more
        8-connected skeleton neighbours). Objects with fewer than two
        positive-intensity pixels get zeros for both measurements.
    """
    # Local import: scipy is already a dependency of this module, but
    # `convolve` is only needed here.
    from scipy.ndimage import convolve

    image = array[:, :, channel]
    block_size = 35  # Adjust this based on your object sizes and detail needs
    # 3x3 kernel that counts the 8 neighbours of a pixel (centre excluded).
    neighbour_kernel = np.array([[1, 1, 1],
                                 [1, 0, 1],
                                 [1, 1, 1]], dtype=np.uint8)

    properties_list = []

    for label in np.unique(mask):
        if label == 0:
            continue  # Skip background

        object_region = mask == label
        region_intensity = np.where(object_region, image, 0)
        valid_pixels = region_intensity[region_intensity > 0]

        skeleton_length = 0
        branch_points = 0

        # At least two positive pixels are needed for meaningful percentiles.
        if valid_pixels.size > 1:
            # Adaptive offset from the intensity spread within the object.
            offset = np.percentile(valid_pixels, 90) - np.percentile(valid_pixels, 50)
            local_thresh = filters.threshold_local(region_intensity, block_size=block_size, offset=offset)

            # threshold_local subtracts `offset` from the local mean, so far from
            # the object the threshold is negative and zero-valued background
            # would pass the comparison. Restrict the result to the object.
            cytoskeleton = (region_intensity > local_thresh) & object_region

            skeleton = morphology.skeletonize(cytoskeleton)

            # Total skeleton length in pixels (equals the summed area of all
            # connected skeleton segments).
            skeleton_length = int(np.count_nonzero(skeleton))

            # NOTE(review): the previous code called
            # morphology.skeleton_branch_analysis, which does not appear to be a
            # scikit-image function. Branch points are estimated here as skeleton
            # pixels with >= 3 skeleton neighbours; adjacent junction pixels are
            # each counted, so treat this as an upper-bound heuristic.
            neighbour_count = convolve(skeleton.astype(np.uint8), neighbour_kernel, mode='constant', cval=0)
            branch_points = int(np.count_nonzero(skeleton & (neighbour_count >= 3)))

        properties_list.append({
            "object_label": label,
            "skeleton_length": skeleton_length,
            "skeleton_branch_points": branch_points
        })

    # Explicit columns keep the schema stable even when the mask has no objects.
    return pd.DataFrame(properties_list, columns=["object_label", "skeleton_length", "skeleton_branch_points"])
|
159
|
+
|
160
|
+
@log_function_call
|
95
161
|
def _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings, zernike=True, degree=8):
|
96
162
|
"""
|
97
163
|
Calculate morphological measurements for cells, nucleus, pathogens, and cytoplasms based on the given masks.
|
@@ -435,6 +501,7 @@ def _estimate_blur(image):
|
|
435
501
|
# Compute and return the variance of the Laplacian
|
436
502
|
return lap.var()
|
437
503
|
|
504
|
+
@log_function_call
|
438
505
|
def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[3, 6, 12, 24], periphery=True, outside=True):
|
439
506
|
"""
|
440
507
|
Calculate various intensity measurements for different regions in the image.
|
@@ -524,6 +591,7 @@ def _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_ma
|
|
524
591
|
|
525
592
|
@log_function_call
|
526
593
|
def _measure_crop_core(index, time_ls, file, settings):
|
594
|
+
|
527
595
|
"""
|
528
596
|
Measure and crop the images based on specified settings.
|
529
597
|
|
@@ -622,9 +690,8 @@ def _measure_crop_core(index, time_ls, file, settings):
|
|
622
690
|
if settings['cytoplasm_min_size'] is not None and settings['cytoplasm_min_size'] != 0:
|
623
691
|
cytoplasm_mask = _filter_object(cytoplasm_mask, settings['cytoplasm_min_size'])
|
624
692
|
|
625
|
-
if settings['cell_mask_dim'] is not None
|
626
|
-
|
627
|
-
cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=False)
|
693
|
+
if settings['cell_mask_dim'] is not None:
|
694
|
+
cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask = _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=settings['include_uninfected'])
|
628
695
|
|
629
696
|
# Update data with the new masks
|
630
697
|
if settings['cell_mask_dim'] is not None:
|
@@ -643,6 +710,10 @@ def _measure_crop_core(index, time_ls, file, settings):
|
|
643
710
|
|
644
711
|
cell_df, nucleus_df, pathogen_df, cytoplasm_df = _morphological_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, settings)
|
645
712
|
|
713
|
+
#if settings['skeleton']:
|
714
|
+
#skeleton_df = _analyze_cytoskeleton(image=channel_arrays, mask=cell_mask, channel=1)
|
715
|
+
#merge skeleton_df with cell_df here
|
716
|
+
|
646
717
|
cell_intensity_df, nucleus_intensity_df, pathogen_intensity_df, cytoplasm_intensity_df = _intensity_measurements(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, channel_arrays, settings, sizes=[1, 2, 3, 4, 5], periphery=True, outside=True)
|
647
718
|
if settings['cell_mask_dim'] is not None:
|
648
719
|
cell_merged_df = _merge_and_save_to_database(cell_df, cell_intensity_df, 'cell', source_folder, file_name, settings['experiment'], settings['timelapse'])
|
@@ -656,7 +727,6 @@ def _measure_crop_core(index, time_ls, file, settings):
|
|
656
727
|
if settings['cytoplasm']:
|
657
728
|
cytoplasm_merged_df = _merge_and_save_to_database(cytoplasm_df, cytoplasm_intensity_df, 'cytoplasm', source_folder, file_name, settings['experiment'], settings['timelapse'])
|
658
729
|
|
659
|
-
|
660
730
|
if settings['save_png'] or settings['save_arrays'] or settings['plot']:
|
661
731
|
|
662
732
|
if isinstance(settings['dialate_pngs'], bool):
|
@@ -676,7 +746,6 @@ def _measure_crop_core(index, time_ls, file, settings):
|
|
676
746
|
crop_ls = settings['crop_mode']
|
677
747
|
size_ls = settings['png_size']
|
678
748
|
for crop_idx, crop_mode in enumerate(crop_ls):
|
679
|
-
print(crop_idx, crop_mode)
|
680
749
|
width, height = size_ls[crop_idx]
|
681
750
|
if crop_mode == 'cell':
|
682
751
|
crop_mask = cell_mask.copy()
|
@@ -730,7 +799,7 @@ def _measure_crop_core(index, time_ls, file, settings):
|
|
730
799
|
png_channels = data[:, :, settings['png_dims']].astype(data_type)
|
731
800
|
|
732
801
|
if settings['normalize_by'] == 'fov':
|
733
|
-
percentiles_list = _get_percentiles(png_channels, settings['
|
802
|
+
percentiles_list = _get_percentiles(png_channels, settings['normalize'][0],q2=settings['normalize'][1])
|
734
803
|
|
735
804
|
png_channels = _crop_center(png_channels, region, new_width=width, new_height=height)
|
736
805
|
|
@@ -787,6 +856,7 @@ def _measure_crop_core(index, time_ls, file, settings):
|
|
787
856
|
conn.commit()
|
788
857
|
except sqlite3.OperationalError as e:
|
789
858
|
print(f"SQLite error: {e}", flush=True)
|
859
|
+
traceback.print_exc()
|
790
860
|
|
791
861
|
if settings['plot']:
|
792
862
|
_plot_cropped_arrays(png_channels)
|
@@ -818,37 +888,31 @@ def _measure_crop_core(index, time_ls, file, settings):
|
|
818
888
|
return average_time, cells
|
819
889
|
|
820
890
|
@log_function_call
|
821
|
-
def measure_crop(settings
|
891
|
+
def measure_crop(settings):
|
892
|
+
|
822
893
|
"""
|
823
894
|
Measure the crop of an image based on the provided settings.
|
824
895
|
|
825
896
|
Args:
|
826
897
|
settings (dict): The settings for measuring the crop.
|
827
|
-
annotation_settings (dict): The annotation settings.
|
828
|
-
advanced_settings (dict): The advanced settings.
|
829
898
|
|
830
899
|
Returns:
|
831
900
|
None
|
832
901
|
"""
|
902
|
+
|
903
|
+
if settings.get('test_mode', False):
|
904
|
+
if not os.basename(settings['src']) == 'test':
|
905
|
+
src = os.path.join(src, 'test')
|
906
|
+
settings['src'] = src
|
907
|
+
print(f'Changed source folder to {src} for test mode')
|
908
|
+
else:
|
909
|
+
print(f'Test mode enabled, using source folder {settings["src"]}')
|
833
910
|
|
834
911
|
from .io import _save_settings_to_db
|
835
912
|
from .timelapse import _timelapse_masks_to_gif, _scmovie
|
836
913
|
from .plot import _save_scimg_plot
|
837
914
|
from .utils import _list_endpoint_subdirectories, _generate_representative_images
|
838
915
|
|
839
|
-
settings = {**settings, **annotation_settings, **advanced_settings}
|
840
|
-
|
841
|
-
dirname = os.path.dirname(settings['input_folder'])
|
842
|
-
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
843
|
-
settings_csv = os.path.join(dirname,'settings','measure_crop_settings.csv')
|
844
|
-
os.makedirs(os.path.join(dirname,'settings'), exist_ok=True)
|
845
|
-
settings_df.to_csv(settings_csv, index=False)
|
846
|
-
|
847
|
-
if settings['timelapse_objects'] == 'nucleus':
|
848
|
-
if not settings['cell_mask_dim'] is None:
|
849
|
-
tlo = settings['timelapse_objects']
|
850
|
-
print(f'timelapse object:{tlo}, cells will be relabeled to nucleus labels to track cells.')
|
851
|
-
|
852
916
|
#general settings
|
853
917
|
settings['merge_edge_pathogen_cells'] = True
|
854
918
|
settings['radial_dist'] = True
|
@@ -857,6 +921,26 @@ def measure_crop(settings, annotation_settings, advanced_settings):
|
|
857
921
|
settings['homogeneity'] = True
|
858
922
|
settings['homogeneity_distances'] = [8,16,32]
|
859
923
|
settings['save_arrays'] = False
|
924
|
+
|
925
|
+
settings['dialate_pngs'] = False
|
926
|
+
settings['dialate_png_ratios'] = [0.2]
|
927
|
+
settings['timelapse'] = False
|
928
|
+
settings['representative_images'] = False
|
929
|
+
settings['timelapse_objects'] = 'cell'
|
930
|
+
settings['max_workers'] = os.cpu_count()-2
|
931
|
+
settings['experiment'] = 'test'
|
932
|
+
settings['cells'] = 'HeLa'
|
933
|
+
settings['cell_loc'] = None
|
934
|
+
settings['pathogens'] = ['ME49Dku80WT', 'ME49Dku80dgra8:GRA8', 'ME49Dku80dgra8', 'ME49Dku80TKO']
|
935
|
+
settings['pathogen_loc'] = [['c1', 'c2', 'c3', 'c4', 'c5', 'c6'], ['c7', 'c8', 'c9', 'c10', 'c11', 'c12'], ['c13', 'c14', 'c15', 'c16', 'c17', 'c18'], ['c19', 'c20', 'c21', 'c22', 'c23', 'c24']]
|
936
|
+
settings['treatments'] = ['BR1', 'BR2', 'BR3']
|
937
|
+
settings['treatment_loc'] = [['c1', 'c2', 'c7', 'c8', 'c13', 'c14', 'c19', 'c20'], ['c3', 'c4', 'c9', 'c10', 'c15', 'c16', 'c21', 'c22'], ['c5', 'c6', 'c11', 'c12', 'c17', 'c18', 'c23', 'c24']]
|
938
|
+
settings['channel_of_interest'] = 2
|
939
|
+
settings['compartments'] = ['pathogen', 'cytoplasm']
|
940
|
+
settings['measurement'] = 'mean_intensity'
|
941
|
+
settings['nr_imgs'] = 32
|
942
|
+
settings['um_per_pixel'] = 0.1
|
943
|
+
settings['center_crop'] = True
|
860
944
|
|
861
945
|
if settings['cell_mask_dim'] is None:
|
862
946
|
settings['include_uninfected'] = True
|
@@ -869,7 +953,18 @@ def measure_crop(settings, annotation_settings, advanced_settings):
|
|
869
953
|
else:
|
870
954
|
settings['cytoplasm'] = False
|
871
955
|
|
872
|
-
settings
|
956
|
+
#settings = {**settings, **annotation_settings, **advanced_settings}
|
957
|
+
|
958
|
+
dirname = os.path.dirname(settings['input_folder'])
|
959
|
+
settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
|
960
|
+
settings_csv = os.path.join(dirname,'settings','measure_crop_settings.csv')
|
961
|
+
os.makedirs(os.path.join(dirname,'settings'), exist_ok=True)
|
962
|
+
settings_df.to_csv(settings_csv, index=False)
|
963
|
+
|
964
|
+
if settings['timelapse_objects'] == 'nucleus':
|
965
|
+
if not settings['cell_mask_dim'] is None:
|
966
|
+
tlo = settings['timelapse_objects']
|
967
|
+
print(f'timelapse object:{tlo}, cells will be relabeled to nucleus labels to track cells.')
|
873
968
|
|
874
969
|
int_setting_keys = ['cell_mask_dim', 'nucleus_mask_dim', 'pathogen_mask_dim', 'cell_min_size', 'nucleus_min_size', 'pathogen_min_size', 'cytoplasm_min_size']
|
875
970
|
|
@@ -913,64 +1008,49 @@ def measure_crop(settings, annotation_settings, advanced_settings):
|
|
913
1008
|
time_left = (((files_to_process-files_processed)*average_time)/max_workers)/60
|
914
1009
|
print(f'Progress: {files_processed}/{files_to_process} Time/img {average_time:.3f}sec, Time Remaining {time_left:.3f} min.', end='\r', flush=True)
|
915
1010
|
result.get()
|
916
|
-
|
917
|
-
|
918
|
-
|
919
|
-
|
920
|
-
|
921
|
-
|
922
|
-
|
923
|
-
|
924
|
-
|
925
|
-
|
926
|
-
|
927
|
-
|
928
|
-
|
929
|
-
|
930
|
-
|
931
|
-
|
932
|
-
|
933
|
-
|
934
|
-
|
935
|
-
|
936
|
-
for i, well_src in enumerate(sc_img_fldrs):
|
937
|
-
if len(os.listdir(well_src)) < 16:
|
938
|
-
nr_imgs = len(os.listdir(well_src))
|
939
|
-
standardize = False
|
940
|
-
else:
|
941
|
-
nr_imgs = 16
|
942
|
-
standardize = True
|
943
|
-
try:
|
944
|
-
all_folders = len(sc_img_fldrs)
|
945
|
-
_save_scimg_plot(src=well_src, nr_imgs=nr_imgs, channel_indices=settings['png_dims'], um_per_pixel=0.1, scale_bar_length_um=10, standardize=standardize, fontsize=12, show_filename=True, channel_names=['red','green','blue'], dpi=300, plot=False, i=i, all_folders=all_folders)
|
946
|
-
|
947
|
-
except Exception as e:
|
948
|
-
print(f"Unable to generate figure for folder {well_src}: {e}", end='\r', flush=True)
|
949
|
-
#traceback.print_exc()
|
1011
|
+
|
1012
|
+
if settings['representative_images']:
|
1013
|
+
if settings['save_png']:
|
1014
|
+
img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
|
1015
|
+
sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
|
1016
|
+
|
1017
|
+
for i, well_src in enumerate(sc_img_fldrs):
|
1018
|
+
if len(os.listdir(well_src)) < 16:
|
1019
|
+
nr_imgs = len(os.listdir(well_src))
|
1020
|
+
standardize = False
|
1021
|
+
else:
|
1022
|
+
nr_imgs = 16
|
1023
|
+
standardize = True
|
1024
|
+
try:
|
1025
|
+
all_folders = len(sc_img_fldrs)
|
1026
|
+
_save_scimg_plot(src=well_src, nr_imgs=nr_imgs, channel_indices=settings['png_dims'], um_per_pixel=0.1, scale_bar_length_um=10, standardize=standardize, fontsize=12, show_filename=True, channel_names=['red','green','blue'], dpi=300, plot=False, i=i, all_folders=all_folders)
|
1027
|
+
|
1028
|
+
except Exception as e:
|
1029
|
+
print(f"Unable to generate figure for folder {well_src}: {e}", end='\r', flush=True)
|
1030
|
+
#traceback.print_exc()
|
950
1031
|
|
951
1032
|
if settings['save_measurements']:
|
952
|
-
|
953
|
-
|
954
|
-
|
955
|
-
|
956
|
-
|
957
|
-
|
958
|
-
|
959
|
-
|
960
|
-
|
961
|
-
|
962
|
-
|
963
|
-
|
964
|
-
|
965
|
-
|
966
|
-
|
967
|
-
|
968
|
-
|
969
|
-
|
970
|
-
|
971
|
-
|
972
|
-
|
973
|
-
channel_names=None)
|
1033
|
+
db_path = os.path.join(os.path.dirname(settings['input_folder']), 'measurements', 'measurements.db')
|
1034
|
+
channel_indices = settings['png_dims']
|
1035
|
+
channel_indices = [min(value, 2) for value in channel_indices]
|
1036
|
+
_generate_representative_images(db_path,
|
1037
|
+
cells=settings['cells'],
|
1038
|
+
cell_loc=settings['cell_loc'],
|
1039
|
+
pathogens=settings['pathogens'],
|
1040
|
+
pathogen_loc=settings['pathogen_loc'],
|
1041
|
+
treatments=settings['treatments'],
|
1042
|
+
treatment_loc=settings['treatment_loc'],
|
1043
|
+
channel_of_interest=settings['channel_of_interest'],
|
1044
|
+
compartments = settings['compartments'],
|
1045
|
+
measurement = settings['measurement'],
|
1046
|
+
nr_imgs=settings['nr_imgs'],
|
1047
|
+
channel_indices=channel_indices,
|
1048
|
+
um_per_pixel=settings['um_per_pixel'],
|
1049
|
+
scale_bar_length_um=10,
|
1050
|
+
plot=False,
|
1051
|
+
fontsize=12,
|
1052
|
+
show_filename=True,
|
1053
|
+
channel_names=None)
|
974
1054
|
|
975
1055
|
if settings['timelapse']:
|
976
1056
|
if settings['timelapse_objects'] == 'nucleus':
|
@@ -979,7 +1059,7 @@ def measure_crop(settings, annotation_settings, advanced_settings):
|
|
979
1059
|
object_types = ['nucleus','pathogen','cell']
|
980
1060
|
_timelapse_masks_to_gif(folder_path, mask_channels, object_types)
|
981
1061
|
|
982
|
-
if settings['save_png']:
|
1062
|
+
#if settings['save_png']:
|
983
1063
|
img_fldr = os.path.join(os.path.dirname(settings['input_folder']), 'data')
|
984
1064
|
sc_img_fldrs = _list_endpoint_subdirectories(img_fldr)
|
985
1065
|
_scmovie(sc_img_fldrs)
|
spacr/old_code.py
CHANGED
@@ -133,4 +133,158 @@ def main_thread_update_function(root, q, fig_queue, canvas_widget, progress_labe
|
|
133
133
|
#except Exception as e:
|
134
134
|
# print(f"Error updating GUI figure: {e}")
|
135
135
|
finally:
|
136
|
-
root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
|
136
|
+
root.after(100, lambda: main_thread_update_function(root, q, fig_queue, canvas_widget, progress_label))
|
137
|
+
|
138
|
+
class MPNN(MessagePassing):
    """
    Message-passing layer with mean aggregation.

    Messages are produced by an MLP applied to the concatenation of the
    neighbour's node features and the edge attributes; the aggregated
    messages are then passed through a second MLP.
    """

    def __init__(self, node_in_features, edge_in_features, out_features):
        # 'mean' aggregation.
        super().__init__(aggr='mean')

        message_input_size = node_in_features + edge_in_features

        # Builds a message from [neighbour features | edge attributes].
        self.message_mlp = Sequential(
            Linear(message_input_size, 128),
            ReLU(),
            Linear(128, out_features),
        )
        # Transforms the aggregated messages of each node.
        self.update_mlp = Sequential(
            Linear(out_features, out_features),
            ReLU(),
            Linear(out_features, out_features),
        )

    def forward(self, x, edge_index, edge_attr):
        """
        Run one round of message passing.

        x: Node features [N, node_in_features]
        edge_index: Graph connectivity [2, E]
        edge_attr: Edge attributes/features [E, edge_in_features]
        """
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_j, edge_attr):
        """
        Build the message for every edge.

        x_j: Input features of neighbors [E, node_in_features]
        edge_attr: Edge attributes [E, edge_in_features]
        """
        neighbour_and_edge = torch.cat([x_j, edge_attr], dim=-1)
        return self.message_mlp(neighbour_and_edge)

    def update(self, aggr_out):
        """
        Transform the aggregated messages.

        aggr_out: Aggregated messages [N, out_features]
        """
        return self.update_mlp(aggr_out)
|
167
|
+
|
168
|
+
def weighted_mse_loss(output, target, score_threshold=0.8, high_score_weight=10):
    """
    Mean squared error in which high-scoring targets are up-weighted.

    Elements whose target is at least ``score_threshold`` have their squared
    error multiplied by ``high_score_weight``; all other elements keep a
    weight of one. The mean is taken over all elements.
    """
    # Start from unit weights, then raise the weight of high-score targets.
    per_element_weight = torch.ones_like(target)
    per_element_weight[target >= score_threshold] = high_score_weight

    squared_error = (output - target) ** 2
    return (squared_error * per_element_weight).mean()
|
174
|
+
|
175
|
+
def generate_single_graph(sequencing, scores):
    """
    Build a single bipartite cell-gene graph from sequencing counts and cell scores.

    Parameters:
    sequencing : path or file-like
        CSV with the columns ``prc`` (well), ``grna`` (gene) and ``count`` (reads).
    scores : path or file-like
        CSV with the columns ``prcfo`` (cell), ``prc`` (well) and ``pred`` (score).

    Returns:
    tuple
        ``(data, gene_id_to_index, n_genes)`` where ``data`` is the graph
        object, ``gene_id_to_index`` maps each gene id to its node index and
        ``n_genes`` is the number of gene nodes.

    Node layout: genes occupy indices ``0 .. n_genes-1`` and cells follow at
    ``n_genes .. n_genes+n_cells-1``. Every cell is connected to every gene
    sequenced in the same well; the edge attribute is the gene's read fraction
    within that well.
    """
    # Load and preprocess sequencing data
    gene_df = pd.read_csv(sequencing)
    gene_df = gene_df.rename(columns={"prc": "well_id", "grna": "gene_id", "count": "read_count"})
    total_reads_per_well = gene_df.groupby('well_id')['read_count'].sum().reset_index(name='total_reads')
    gene_df = gene_df.merge(total_reads_per_well, on='well_id')
    gene_df['well_read_fraction'] = gene_df['read_count'] / gene_df['total_reads']

    # Load and preprocess cell score data
    cell_df = pd.read_csv(scores)
    cell_df = cell_df[['prcfo', 'prc', 'pred']].rename(columns={'prcfo': 'cell_id', 'prc': 'well_id', 'pred': 'score'})

    # Node index mappings: genes first, then cells.
    gene_id_to_index = {gene: i for i, gene in enumerate(gene_df['gene_id'].unique())}
    n_genes = len(gene_id_to_index)
    cell_id_to_index = {cell: i + n_genes for i, cell in enumerate(cell_df['cell_id'].unique())}

    # Group cells by well once instead of rescanning cell_df for every well.
    cells_by_well = {
        well_id: well_cells['cell_id'].map(cell_id_to_index).values
        for well_id, well_cells in cell_df.groupby('well_id')
    }

    edge_index = []
    edge_attr = []

    # Associate each cell with all genes in the same well
    for well_id, group in gene_df.groupby('well_id'):
        cell_indices = cells_by_well.get(well_id)
        if cell_indices is None:
            continue  # No scored cells in this well

        gene_indices = group['gene_id'].map(gene_id_to_index).values
        fractions = group['well_read_fraction'].values

        for cell_idx in cell_indices:
            for gene_idx, fraction in zip(gene_indices, fractions):
                edge_index.append([cell_idx, gene_idx])
                edge_attr.append([fraction])

    # Convert lists to PyTorch tensors. The reshape keeps the shapes [2, E]
    # and [E, 1] valid even when no edges were produced.
    edge_index = torch.tensor(edge_index, dtype=torch.long).reshape(-1, 2).t().contiguous()
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).reshape(-1, 1)
    # NOTE(review): one score per row of cell_df; this lines up with the cell
    # nodes only if cell ids are unique in the scores file - confirm.
    cell_scores = torch.tensor(cell_df['score'].values, dtype=torch.float)

    # One-hot encoding for genes, and zero features for cells (could be replaced with real features if available)
    gene_features = torch.eye(n_genes)
    cell_features = torch.zeros(len(cell_id_to_index), gene_features.size(1))

    # Row order must follow the node indices assigned above (genes, then
    # cells); the previous cells-first order gave gene nodes the cell rows.
    x = torch.cat([gene_features, cell_features], dim=0)

    # Create the graph data object
    data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, y=cell_scores)

    return data, gene_id_to_index, n_genes
|
223
|
+
|
224
|
+
# in _normalize_and_outline
|
225
|
+
outlines = []
|
226
|
+
|
227
|
+
overlayed_image = rgb_image.copy()
|
228
|
+
for i, mask_dim in enumerate(mask_dims):
|
229
|
+
mask = np.take(image, mask_dim, axis=2)
|
230
|
+
outline = np.zeros_like(mask)
|
231
|
+
# Find the contours of the objects in the mask
|
232
|
+
for j in np.unique(mask)[1:]:
|
233
|
+
contours = find_contours(mask == j, 0.5)
|
234
|
+
for contour in contours:
|
235
|
+
contour = contour.astype(int)
|
236
|
+
outline[contour[:, 0], contour[:, 1]] = j
|
237
|
+
# Make the outline thicker
|
238
|
+
outline = dilation(outline, square(outline_thickness))
|
239
|
+
outlines.append(outline)
|
240
|
+
# Overlay the outlines onto the RGB image
|
241
|
+
for j in np.unique(outline)[1:]:
|
242
|
+
overlayed_image[outline == j] = outline_colors[i % len(outline_colors)]
|
243
|
+
|
244
|
+
def _extract_filename_metadata(filenames, src, images_by_key, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
|
245
|
+
for filename in filenames:
|
246
|
+
match = regular_expression.match(filename)
|
247
|
+
if match:
|
248
|
+
try:
|
249
|
+
try:
|
250
|
+
plate = match.group('plateID')
|
251
|
+
except:
|
252
|
+
plate = os.path.basename(src)
|
253
|
+
|
254
|
+
well = match.group('wellID')
|
255
|
+
field = match.group('fieldID')
|
256
|
+
channel = match.group('chanID')
|
257
|
+
mode = None
|
258
|
+
|
259
|
+
if well[0].isdigit():
|
260
|
+
well = str(_safe_int_convert(well))
|
261
|
+
if field[0].isdigit():
|
262
|
+
field = str(_safe_int_convert(field))
|
263
|
+
if channel[0].isdigit():
|
264
|
+
channel = str(_safe_int_convert(channel))
|
265
|
+
|
266
|
+
if metadata_type =='cq1':
|
267
|
+
orig_wellID = wellID
|
268
|
+
wellID = _convert_cq1_well_id(wellID)
|
269
|
+
clear_output(wait=True)
|
270
|
+
print(f'Converted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
|
271
|
+
|
272
|
+
if pick_slice:
|
273
|
+
try:
|
274
|
+
mode = match.group('AID')
|
275
|
+
except IndexError:
|
276
|
+
sliceid = '00'
|
277
|
+
|
278
|
+
if mode == skip_mode:
|
279
|
+
continue
|
280
|
+
|
281
|
+
key = (plate, well, field, channel, mode)
|
282
|
+
with Image.open(os.path.join(src, filename)) as img:
|
283
|
+
images_by_key[key].append(np.array(img))
|
284
|
+
except IndexError:
|
285
|
+
print(f"Could not extract information from filename {filename} using provided regex")
|
286
|
+
else:
|
287
|
+
print(f"Filename {filename} did not match provided regex")
|
288
|
+
continue
|
289
|
+
|
290
|
+
return images_by_key
|