spacr 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +245 -2494
- spacr/deep_spacr.py +335 -163
- spacr/gui.py +2 -0
- spacr/gui_core.py +85 -65
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +375 -7
- spacr/io.py +680 -141
- spacr/logger.py +28 -9
- spacr/measure.py +108 -133
- spacr/mediar.py +0 -3
- spacr/ml.py +1051 -0
- spacr/openai.py +37 -0
- spacr/plot.py +707 -20
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +181 -50
- spacr/sim.py +0 -2
- spacr/submodules.py +349 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +238 -0
- spacr/utils.py +776 -182
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/METADATA +31 -22
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/LICENSE +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/WHEEL +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.1.dist-info → spacr-0.3.3.dist-info}/top_level.txt +0 -0
spacr/io.py
CHANGED
@@ -1,30 +1,133 @@
|
|
1
|
-
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
|
1
|
+
import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue, tifffile, czifile, atexit, datetime
|
2
2
|
import numpy as np
|
3
3
|
import pandas as pd
|
4
|
-
import tifffile
|
5
4
|
from PIL import Image, ImageOps
|
6
|
-
from collections import defaultdict, Counter
|
5
|
+
from collections import defaultdict, Counter
|
7
6
|
from pathlib import Path
|
8
7
|
from functools import partial
|
9
8
|
from matplotlib.animation import FuncAnimation
|
10
9
|
from IPython.display import display
|
11
10
|
from skimage.util import img_as_uint
|
12
11
|
from skimage.exposure import rescale_intensity
|
13
|
-
from skimage import filters
|
14
12
|
import skimage.measure as measure
|
15
13
|
from skimage import exposure
|
16
14
|
import imageio.v2 as imageio2
|
17
15
|
import matplotlib.pyplot as plt
|
18
16
|
from io import BytesIO
|
19
|
-
from IPython.display import display
|
20
|
-
from multiprocessing import Pool, cpu_count, Process, Queue
|
21
|
-
from torch.utils.data import Dataset, DataLoader
|
17
|
+
from IPython.display import display
|
18
|
+
from multiprocessing import Pool, cpu_count, Process, Queue, Value, Lock
|
19
|
+
from torch.utils.data import Dataset, DataLoader, random_split
|
22
20
|
import matplotlib.pyplot as plt
|
23
21
|
from torchvision.transforms import ToTensor
|
24
|
-
import seaborn as sns
|
25
|
-
import
|
26
|
-
|
27
|
-
from .
|
22
|
+
import seaborn as sns
|
23
|
+
from nd2reader import ND2Reader
|
24
|
+
from torchvision import transforms
|
25
|
+
from sklearn.model_selection import train_test_split
|
26
|
+
|
27
|
+
def process_non_tif_non_2D_images(folder):
|
28
|
+
"""Processes all images in the folder and splits them into grayscale channels, preserving bit depth."""
|
29
|
+
|
30
|
+
# Helper function to save grayscale images
|
31
|
+
def save_grayscale_images(image, base_name, folder, dtype, channel=None, z=None, t=None):
|
32
|
+
"""Save grayscale images with appropriate suffix based on channel, z, and t, preserving bit depth."""
|
33
|
+
suffix = ""
|
34
|
+
if channel is not None:
|
35
|
+
suffix += f"_C{channel}"
|
36
|
+
if z is not None:
|
37
|
+
suffix += f"_Z{z}"
|
38
|
+
if t is not None:
|
39
|
+
suffix += f"_T{t}"
|
40
|
+
|
41
|
+
output_filename = os.path.join(folder, f"{base_name}{suffix}.tif")
|
42
|
+
tifffile.imwrite(output_filename, image.astype(dtype))
|
43
|
+
|
44
|
+
# Function to handle splitting of multi-dimensional images into grayscale channels
|
45
|
+
def split_channels(image, folder, base_name, dtype):
|
46
|
+
"""Splits the image into channels and handles 3D, 4D, and 5D image cases."""
|
47
|
+
if image.ndim == 2:
|
48
|
+
# Grayscale image, already processed separately
|
49
|
+
return
|
50
|
+
|
51
|
+
elif image.ndim == 3:
|
52
|
+
# 3D image: (height, width, channels)
|
53
|
+
for c in range(image.shape[2]):
|
54
|
+
save_grayscale_images(image[..., c], base_name, folder, dtype, channel=c+1)
|
55
|
+
|
56
|
+
elif image.ndim == 4:
|
57
|
+
# 4D image: (height, width, channels, Z-dimension)
|
58
|
+
for z in range(image.shape[3]):
|
59
|
+
for c in range(image.shape[2]):
|
60
|
+
save_grayscale_images(image[..., c, z], base_name, folder, dtype, channel=c+1, z=z+1)
|
61
|
+
|
62
|
+
elif image.ndim == 5:
|
63
|
+
# 5D image: (height, width, channels, Z-dimension, Time)
|
64
|
+
for t in range(image.shape[4]):
|
65
|
+
for z in range(image.shape[3]):
|
66
|
+
for c in range(image.shape[2]):
|
67
|
+
save_grayscale_images(image[..., c, z, t], base_name, folder, dtype, channel=c+1, z=z+1, t=t+1)
|
68
|
+
|
69
|
+
# Function to load images in various formats
|
70
|
+
def load_image(file_path):
|
71
|
+
"""Loads image from various formats and returns it as a numpy array along with its dtype."""
|
72
|
+
ext = os.path.splitext(file_path)[1].lower()
|
73
|
+
|
74
|
+
if ext in ['.tif', '.tiff']:
|
75
|
+
image = tifffile.imread(file_path)
|
76
|
+
return image, image.dtype
|
77
|
+
|
78
|
+
elif ext in ['.png', '.jpg', '.jpeg']:
|
79
|
+
image = Image.open(file_path)
|
80
|
+
return np.array(image), image.mode
|
81
|
+
|
82
|
+
elif ext == '.czi':
|
83
|
+
with czifile.CziFile(file_path) as czi:
|
84
|
+
image = czi.asarray()
|
85
|
+
return image, image.dtype
|
86
|
+
|
87
|
+
elif ext == '.nd2':
|
88
|
+
with ND2Reader(file_path) as nd2:
|
89
|
+
image = np.array(nd2)
|
90
|
+
return image, image.dtype
|
91
|
+
|
92
|
+
else:
|
93
|
+
raise ValueError(f"Unsupported file extension: {ext}")
|
94
|
+
|
95
|
+
# Function to check if an image is grayscale and save it as a TIFF if it isn't already
|
96
|
+
def convert_grayscale_to_tiff(image, filename, folder, dtype):
|
97
|
+
"""Convert grayscale images that are not in TIFF format to TIFF, preserving bit depth."""
|
98
|
+
base_name = os.path.splitext(filename)[0]
|
99
|
+
output_filename = os.path.join(folder, f"{base_name}.tif")
|
100
|
+
tifffile.imwrite(output_filename, image.astype(dtype))
|
101
|
+
print(f"Converted grayscale image {filename} to TIFF with bit depth {dtype}.")
|
102
|
+
|
103
|
+
# Supported formats
|
104
|
+
supported_formats = ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.czi', '.nd2']
|
105
|
+
|
106
|
+
# Loop through all files in the folder
|
107
|
+
for filename in os.listdir(folder):
|
108
|
+
file_path = os.path.join(folder, filename)
|
109
|
+
ext = os.path.splitext(file_path)[1].lower()
|
110
|
+
|
111
|
+
if ext in supported_formats:
|
112
|
+
print(f"Processing {filename}")
|
113
|
+
try:
|
114
|
+
# Load the image and its dtype
|
115
|
+
image, dtype = load_image(file_path)
|
116
|
+
|
117
|
+
# If the image is grayscale (2D), convert it to TIFF if it's not already in TIFF format
|
118
|
+
if image.ndim == 2:
|
119
|
+
if ext not in ['.tif', '.tiff']:
|
120
|
+
convert_grayscale_to_tiff(image, filename, folder, dtype)
|
121
|
+
else:
|
122
|
+
print(f"Image {filename} is already grayscale and in TIFF format, skipping.")
|
123
|
+
continue
|
124
|
+
|
125
|
+
# Otherwise, split channels and save images
|
126
|
+
base_name = os.path.splitext(filename)[0]
|
127
|
+
split_channels(image, folder, base_name, dtype)
|
128
|
+
|
129
|
+
except Exception as e:
|
130
|
+
print(f"Error processing {filename}: {str(e)}")
|
28
131
|
|
29
132
|
def _load_images_and_labels(image_files, label_files, circular=False, invert=False):
|
30
133
|
|
@@ -632,6 +735,20 @@ class TarImageDataset(Dataset):
|
|
632
735
|
img = self.transform(img)
|
633
736
|
|
634
737
|
return img, m.name
|
738
|
+
|
739
|
+
def load_images_from_paths(images_by_key):
|
740
|
+
images_dict = {}
|
741
|
+
|
742
|
+
for key, paths in images_by_key.items():
|
743
|
+
images_dict[key] = []
|
744
|
+
for path in paths:
|
745
|
+
try:
|
746
|
+
with Image.open(path) as img:
|
747
|
+
images_dict[key].append(np.array(img))
|
748
|
+
except Exception as e:
|
749
|
+
print(f"Error loading image from {path}: {e}")
|
750
|
+
|
751
|
+
return images_dict
|
635
752
|
|
636
753
|
#@log_function_call
|
637
754
|
def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
|
@@ -657,15 +774,20 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
|
|
657
774
|
files_processed = 0
|
658
775
|
if not os.path.exists(stack_path) or (os.path.isdir(stack_path) and len(os.listdir(stack_path)) == 0):
|
659
776
|
all_filenames = [filename for filename in os.listdir(src) if filename.endswith(img_format)]
|
660
|
-
print(f'
|
777
|
+
print(f'All files: {len(all_filenames)} in {src}')
|
661
778
|
time_ls = []
|
662
|
-
|
663
|
-
|
779
|
+
image_paths_by_key = _extract_filename_metadata(all_filenames, src, regular_expression, metadata_type, pick_slice, skip_mode)
|
780
|
+
# Convert dictionary keys to a list for batching
|
781
|
+
batching_keys = list(image_paths_by_key.keys())
|
782
|
+
print(f'All unique FOV: {len(image_paths_by_key)} in {src}')
|
783
|
+
for idx in range(0, len(image_paths_by_key), batch_size):
|
664
784
|
start = time.time()
|
665
|
-
|
666
|
-
for
|
667
|
-
|
668
|
-
|
785
|
+
|
786
|
+
# Select batch keys and create a subset of the dictionary for this batch
|
787
|
+
batch_keys = batching_keys[idx:idx+batch_size]
|
788
|
+
batch_images_by_key = {key: image_paths_by_key[key] for key in batch_keys}
|
789
|
+
images_by_key = load_images_from_paths(batch_images_by_key)
|
790
|
+
|
669
791
|
if pick_slice:
|
670
792
|
for i, key in enumerate(images_by_key):
|
671
793
|
plate, well, field, channel, mode = key
|
@@ -682,10 +804,10 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
|
|
682
804
|
files_to_process = len(all_filenames)
|
683
805
|
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')
|
684
806
|
|
685
|
-
#if os.path.exists(output_path):
|
686
|
-
# print(f'WARNING: A file with the same name already exists at location {output_filename}')
|
687
807
|
if not os.path.exists(output_path):
|
688
808
|
mip_image.save(output_path)
|
809
|
+
else:
|
810
|
+
print(f'WARNING: A file with the same name already exists at location {output_filename}')
|
689
811
|
else:
|
690
812
|
for i, (key, images) in enumerate(images_by_key.items()):
|
691
813
|
plate, well, field, channel = key[:4]
|
@@ -702,10 +824,11 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
|
|
702
824
|
files_to_process = len(all_filenames)
|
703
825
|
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')
|
704
826
|
|
705
|
-
#if os.path.exists(output_path):
|
706
|
-
# print(f'WARNING: A file with the same name already exists at location {output_filename}')
|
707
827
|
if not os.path.exists(output_path):
|
708
828
|
mip_image.save(output_path)
|
829
|
+
else:
|
830
|
+
print(f'WARNING: A file with the same name already exists at location {output_filename}')
|
831
|
+
|
709
832
|
images_by_key.clear()
|
710
833
|
|
711
834
|
# Move original images to a new directory
|
@@ -862,47 +985,6 @@ def _move_to_chan_folder(src, regex, timelapse=False, metadata_type=''):
|
|
862
985
|
shutil.move(os.path.join(src, filename), move)
|
863
986
|
return
|
864
987
|
|
865
|
-
def _merge_channels_v2(src, plot=False):
|
866
|
-
from .plot import plot_arrays
|
867
|
-
"""
|
868
|
-
Merge the channels in the given source directory and save the merged files in a 'stack' directory.
|
869
|
-
|
870
|
-
Args:
|
871
|
-
src (str): The path to the source directory containing the channel folders.
|
872
|
-
plot (bool, optional): Whether to plot the merged arrays. Defaults to False.
|
873
|
-
|
874
|
-
Returns:
|
875
|
-
None
|
876
|
-
"""
|
877
|
-
src = Path(src)
|
878
|
-
stack_dir = src / 'stack'
|
879
|
-
chan_dirs = [d for d in src.iterdir() if d.is_dir() and d.name in ['01', '02', '03', '04', '00', '1', '2', '3', '4','0']]
|
880
|
-
|
881
|
-
chan_dirs.sort(key=lambda x: x.name)
|
882
|
-
print(f'List of folders in src: {[d.name for d in chan_dirs]}. Single channel folders.')
|
883
|
-
start_time = time.time()
|
884
|
-
|
885
|
-
# First directory and its files
|
886
|
-
dir_files = list(chan_dirs[0].iterdir())
|
887
|
-
|
888
|
-
# Create the 'stack' directory if it doesn't exist
|
889
|
-
stack_dir.mkdir(exist_ok=True)
|
890
|
-
print(f'generated folder with merged arrays: {stack_dir}')
|
891
|
-
|
892
|
-
if _is_dir_empty(stack_dir):
|
893
|
-
with Pool(max(cpu_count() // 2, 1)) as pool:
|
894
|
-
#with Pool(cpu_count()) as pool:
|
895
|
-
merge_func = partial(_merge_file, chan_dirs, stack_dir)
|
896
|
-
pool.map(merge_func, dir_files)
|
897
|
-
|
898
|
-
avg_time = (time.time() - start_time) / len(dir_files)
|
899
|
-
print(f'Average Time: {avg_time:.3f} sec')
|
900
|
-
|
901
|
-
if plot:
|
902
|
-
plot_arrays(src+'/stack')
|
903
|
-
|
904
|
-
return
|
905
|
-
|
906
988
|
def _merge_channels(src, plot=False):
|
907
989
|
"""
|
908
990
|
Merge the channels in the given source directory and save the merged files in a 'stack' directory without using multiprocessing.
|
@@ -961,9 +1043,7 @@ def _mip_all(src, include_first_chan=True):
|
|
961
1043
|
Returns:
|
962
1044
|
None
|
963
1045
|
"""
|
964
|
-
|
965
|
-
from .utils import normalize_to_dtype
|
966
|
-
|
1046
|
+
|
967
1047
|
#print('========== generating MIPs ==========')
|
968
1048
|
# Iterate over each file in the specified directory (src).
|
969
1049
|
for filename in os.listdir(src):
|
@@ -1337,7 +1417,6 @@ def _get_lists_for_normalization(settings):
|
|
1337
1417
|
return backgrounds, signal_to_noise, signal_thresholds, remove_background
|
1338
1418
|
|
1339
1419
|
def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
|
1340
|
-
from .utils import print_progress
|
1341
1420
|
"""
|
1342
1421
|
Normalize the stack of images.
|
1343
1422
|
|
@@ -1430,7 +1509,6 @@ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False
|
|
1430
1509
|
return print(f'Saved stacks: {output_fldr}')
|
1431
1510
|
|
1432
1511
|
def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):
|
1433
|
-
from .utils import print_progress
|
1434
1512
|
"""
|
1435
1513
|
Normalize the timelapse data by rescaling the intensity values based on percentiles.
|
1436
1514
|
|
@@ -1559,7 +1637,7 @@ def delete_empty_subdirectories(folder_path):
|
|
1559
1637
|
#@log_function_call
|
1560
1638
|
def preprocess_img_data(settings):
|
1561
1639
|
|
1562
|
-
from .plot import plot_arrays
|
1640
|
+
from .plot import plot_arrays
|
1563
1641
|
from .utils import _run_test_mode, _get_regex
|
1564
1642
|
from .settings import set_default_settings_preprocess_img_data
|
1565
1643
|
|
@@ -2054,7 +2132,6 @@ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_di
|
|
2054
2132
|
padded_shapes = [shape + (0,) * (max_tuple_length - len(shape)) for shape in unique_shapes]
|
2055
2133
|
# Now create a NumPy array and find the maximum dimensions
|
2056
2134
|
max_dims = np.max(np.array(padded_shapes), axis=0)
|
2057
|
-
#clear_output(wait=True)
|
2058
2135
|
print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}')
|
2059
2136
|
#print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}', end='\r', flush=True)
|
2060
2137
|
padded_stack_ls = []
|
@@ -2102,7 +2179,7 @@ def _read_db(db_loc, tables):
|
|
2102
2179
|
conn.close()
|
2103
2180
|
return dfs
|
2104
2181
|
|
2105
|
-
def _read_and_merge_data(locs, tables, verbose=False,
|
2182
|
+
def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
|
2106
2183
|
"""
|
2107
2184
|
Read and merge data from SQLite databases and perform data preprocessing.
|
2108
2185
|
|
@@ -2110,9 +2187,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
|
|
2110
2187
|
- locs (list): A list of file paths to the SQLite database files.
|
2111
2188
|
- tables (list): A list of table names to read from the databases.
|
2112
2189
|
- verbose (bool): Whether to print verbose output. Default is False.
|
2113
|
-
-
|
2114
|
-
-
|
2115
|
-
-
|
2190
|
+
- nuclei_limit (bool): Whether to include multinucleated cells. Default is False.
|
2191
|
+
- pathogen_limit (bool): Whether to include cells with multiple infections. Default is False.
|
2192
|
+
- uninfected (bool): Whether to include non-infected cells. Default is False.
|
2116
2193
|
|
2117
2194
|
Returns:
|
2118
2195
|
- merged_df (pandas.DataFrame): The merged and preprocessed dataframe.
|
@@ -2187,7 +2264,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
|
|
2187
2264
|
nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
|
2188
2265
|
nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
|
2189
2266
|
nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
|
2190
|
-
if
|
2267
|
+
if nuclei_limit == False:
|
2191
2268
|
#nucleus = nucleus[~nucleus['prcfo'].duplicated()]
|
2192
2269
|
nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
|
2193
2270
|
nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
|
@@ -2203,9 +2280,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
|
|
2203
2280
|
pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
|
2204
2281
|
pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
|
2205
2282
|
pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
|
2206
|
-
if
|
2283
|
+
if uninfected == False:
|
2207
2284
|
pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
|
2208
|
-
if
|
2285
|
+
if pathogen_limit == False:
|
2209
2286
|
pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
|
2210
2287
|
pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
|
2211
2288
|
print(f'pathogens: {len(pathogens)}')
|
@@ -2267,12 +2344,8 @@ def _results_to_csv(src, df, df_well):
|
|
2267
2344
|
wells.to_csv(wells_loc, index=True, header=True)
|
2268
2345
|
cells.to_csv(cells_loc, index=True, header=True)
|
2269
2346
|
return cells, wells
|
2270
|
-
|
2271
|
-
###################################################
|
2272
|
-
# Classify
|
2273
|
-
###################################################
|
2274
2347
|
|
2275
|
-
def read_plot_model_stats(
|
2348
|
+
def read_plot_model_stats(train_file_path, val_file_path ,save=False):
|
2276
2349
|
|
2277
2350
|
def _plot_and_save(train_df, val_df, column='accuracy', save=False, path=None, dpi=600):
|
2278
2351
|
|
@@ -2301,37 +2374,19 @@ def read_plot_model_stats(file_path ,save=False):
|
|
2301
2374
|
plt.savefig(pdf_path, format='pdf', dpi=dpi)
|
2302
2375
|
else:
|
2303
2376
|
plt.show()
|
2304
|
-
# Read the CSV into a dataframe
|
2305
|
-
df = pd.read_csv(file_path, index_col=0)
|
2306
|
-
|
2307
|
-
# Split the dataframe into train and validation based on the index
|
2308
|
-
train_df = df.filter(like='_train', axis=0).copy()
|
2309
|
-
val_df = df.filter(like='_val', axis=0).copy()
|
2310
|
-
|
2311
|
-
fldr_1 = os.path.dirname(file_path)
|
2312
|
-
|
2313
|
-
train_csv_path = os.path.join(fldr_1, 'train.csv')
|
2314
|
-
val_csv_path = os.path.join(fldr_1, 'validation.csv')
|
2315
2377
|
|
2316
|
-
|
2317
|
-
|
2318
|
-
|
2319
|
-
bn_2 = os.path.basename(fldr_2)
|
2320
|
-
bn_3 = os.path.basename(fldr_3)
|
2321
|
-
model_name = str(f'{bn_1}_{bn_2}_{bn_3}')
|
2378
|
+
# Read the CSVs into DataFrames
|
2379
|
+
train_df = pd.read_csv(train_file_path, index_col=0)
|
2380
|
+
val_df = pd.read_csv(val_file_path, index_col=0)
|
2322
2381
|
|
2323
|
-
#
|
2324
|
-
|
2325
|
-
val_df['epoch'] = [int(idx.split('_')[0]) for idx in val_df.index]
|
2326
|
-
|
2327
|
-
# Save dataframes to a CSV file
|
2328
|
-
train_df.to_csv(train_csv_path)
|
2329
|
-
val_df.to_csv(val_csv_path)
|
2382
|
+
# Get the folder path for saving plots
|
2383
|
+
fldr_1 = os.path.dirname(train_file_path)
|
2330
2384
|
|
2331
2385
|
if save:
|
2332
2386
|
# Setting the style
|
2333
2387
|
sns.set(style="whitegrid")
|
2334
2388
|
|
2389
|
+
# Plot and save the results
|
2335
2390
|
_plot_and_save(train_df, val_df, column='accuracy', save=save, path=fldr_1)
|
2336
2391
|
_plot_and_save(train_df, val_df, column='neg_accuracy', save=save, path=fldr_1)
|
2337
2392
|
_plot_and_save(train_df, val_df, column='pos_accuracy', save=save, path=fldr_1)
|
@@ -2379,50 +2434,53 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_
|
|
2379
2434
|
|
2380
2435
|
return model_path
|
2381
2436
|
|
2382
|
-
def _save_progress(dst,
|
2437
|
+
def _save_progress(dst, train_df, validation_df):
|
2383
2438
|
"""
|
2384
2439
|
Save the progress of the classification model.
|
2385
2440
|
|
2386
2441
|
Parameters:
|
2387
2442
|
dst (str): The destination directory to save the progress.
|
2388
|
-
|
2389
|
-
|
2443
|
+
train_df (pandas.DataFrame): The DataFrame containing training stats.
|
2444
|
+
validation_df (pandas.DataFrame): The DataFrame containing validation stats (if available).
|
2390
2445
|
|
2391
2446
|
Returns:
|
2392
2447
|
None
|
2393
2448
|
"""
|
2449
|
+
|
2450
|
+
def _save_df_to_csv(file_path, df):
|
2451
|
+
"""
|
2452
|
+
Save the given DataFrame to the specified CSV file, either creating a new file or appending to an existing one.
|
2453
|
+
|
2454
|
+
Parameters:
|
2455
|
+
file_path (str): The file path where the CSV will be saved.
|
2456
|
+
df (pandas.DataFrame): The DataFrame to save.
|
2457
|
+
"""
|
2458
|
+
if not os.path.exists(file_path):
|
2459
|
+
with open(file_path, 'w') as f:
|
2460
|
+
df.to_csv(f, index=True, header=True)
|
2461
|
+
f.flush() # Ensure data is written to the file system
|
2462
|
+
else:
|
2463
|
+
with open(file_path, 'a') as f:
|
2464
|
+
df.to_csv(f, index=True, header=False)
|
2465
|
+
f.flush()
|
2466
|
+
|
2394
2467
|
# Save accuracy, loss, PRAUC
|
2395
2468
|
os.makedirs(dst, exist_ok=True)
|
2396
|
-
|
2397
|
-
|
2398
|
-
results_df.to_csv(results_path, index=True, header=True, mode='w')
|
2399
|
-
else:
|
2400
|
-
results_df.to_csv(results_path, index=True, header=False, mode='a')
|
2469
|
+
results_path_train = os.path.join(dst, 'train.csv')
|
2470
|
+
results_path_validation = os.path.join(dst, 'validation.csv')
|
2401
2471
|
|
2402
|
-
|
2403
|
-
|
2404
|
-
return
|
2472
|
+
# Save training data
|
2473
|
+
_save_df_to_csv(results_path_train, train_df)
|
2405
2474
|
|
2406
|
-
|
2407
|
-
|
2408
|
-
|
2475
|
+
# Save validation data if available
|
2476
|
+
if validation_df is not None:
|
2477
|
+
_save_df_to_csv(results_path_validation, validation_df)
|
2409
2478
|
|
2410
|
-
|
2411
|
-
|
2412
|
-
- src (str): The source directory where the settings file will be saved.
|
2479
|
+
# Call read_plot_model_stats after ensuring the files are saved
|
2480
|
+
read_plot_model_stats(results_path_train, results_path_validation, save=True)
|
2413
2481
|
|
2414
|
-
Returns:
|
2415
|
-
None
|
2416
|
-
"""
|
2417
|
-
dst = os.path.join(src,'model')
|
2418
|
-
settings_loc = os.path.join(dst,'settings.csv')
|
2419
|
-
os.makedirs(dst, exist_ok=True)
|
2420
|
-
settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
|
2421
|
-
display(settings_df)
|
2422
|
-
settings_df.to_csv(settings_loc, index=False)
|
2423
2482
|
return
|
2424
2483
|
|
2425
|
-
|
2426
2484
|
def _copy_missclassified(df):
|
2427
2485
|
misclassified = df[df['true_label'] != df['predicted_label']]
|
2428
2486
|
for _, row in misclassified.iterrows():
|
@@ -2448,7 +2506,7 @@ def _read_db(db_loc, tables):
|
|
2448
2506
|
conn.close() # Close the connection
|
2449
2507
|
return dfs
|
2450
2508
|
|
2451
|
-
def _read_and_merge_data(locs, tables, verbose=False,
|
2509
|
+
def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
|
2452
2510
|
|
2453
2511
|
from .utils import _split_data
|
2454
2512
|
|
@@ -2533,7 +2591,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
|
|
2533
2591
|
nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
|
2534
2592
|
nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
|
2535
2593
|
nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
|
2536
|
-
if
|
2594
|
+
if nuclei_limit == False:
|
2537
2595
|
nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
|
2538
2596
|
nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
|
2539
2597
|
if verbose:
|
@@ -2559,20 +2617,30 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
|
|
2559
2617
|
pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
|
2560
2618
|
pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
|
2561
2619
|
pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
|
2562
|
-
|
2620
|
+
|
2621
|
+
print(f"before noninfected: {len(pathogens)}")
|
2622
|
+
if uninfected == False:
|
2563
2623
|
pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
|
2564
|
-
|
2565
|
-
|
2624
|
+
print(f"after noninfected: {len(pathogens)}")
|
2625
|
+
|
2626
|
+
if isinstance(pathogen_limit, bool):
|
2627
|
+
if pathogen_limit == False:
|
2566
2628
|
pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
|
2567
|
-
|
2568
|
-
|
2629
|
+
print(f"after multiinfected Bool: {len(pathogens)}")
|
2630
|
+
if isinstance(pathogen_limit, float):
|
2631
|
+
pathogen_limit = int(pathogen_limit)
|
2632
|
+
if isinstance(pathogen_limit, int):
|
2633
|
+
pathogens = pathogens[pathogens['pathogen_prcfo_count']<=pathogen_limit]
|
2634
|
+
print(f"afer multiinfected Float: {len(pathogens)}")
|
2569
2635
|
if not 'cell' in tables:
|
2570
2636
|
pathogens_g_df, metadata = _split_data(pathogens, 'prcfo', 'cell_id')
|
2571
2637
|
else:
|
2572
2638
|
pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
|
2639
|
+
|
2573
2640
|
if verbose:
|
2574
2641
|
print(f'pathogens: {len(pathogens)}')
|
2575
2642
|
print(f'pathogens grouped: {len(pathogens_g_df)}')
|
2643
|
+
|
2576
2644
|
if len(merged_df) == 0:
|
2577
2645
|
merged_df = pathogens_g_df
|
2578
2646
|
else:
|
@@ -2697,4 +2765,475 @@ def generate_cellpose_train_test(src, test_split=0.1):
|
|
2697
2765
|
shutil.copy(img_path, new_img_path)
|
2698
2766
|
shutil.copy(mask_path, new_mask_path)
|
2699
2767
|
print(f'Copied {idx+1}/{len(ls)} images to {_type} set')#, end='\r', flush=True)
|
2700
|
-
|
2768
|
+
|
2769
|
+
def parse_gz_files(folder_path):
|
2770
|
+
"""
|
2771
|
+
Parses the .fastq.gz files in the specified folder path and returns a dictionary
|
2772
|
+
containing the sample names and their corresponding file paths.
|
2773
|
+
|
2774
|
+
Args:
|
2775
|
+
folder_path (str): The path to the folder containing the .fastq.gz files.
|
2776
|
+
|
2777
|
+
Returns:
|
2778
|
+
dict: A dictionary where the keys are the sample names and the values are
|
2779
|
+
dictionaries containing the file paths for the 'R1' and 'R2' read directions.
|
2780
|
+
"""
|
2781
|
+
files = os.listdir(folder_path)
|
2782
|
+
gz_files = [f for f in files if f.endswith('.fastq.gz')]
|
2783
|
+
|
2784
|
+
samples_dict = {}
|
2785
|
+
for gz_file in gz_files:
|
2786
|
+
parts = gz_file.split('_')
|
2787
|
+
sample_name = parts[0]
|
2788
|
+
read_direction = parts[1]
|
2789
|
+
|
2790
|
+
if sample_name not in samples_dict:
|
2791
|
+
samples_dict[sample_name] = {}
|
2792
|
+
|
2793
|
+
if read_direction == "R1":
|
2794
|
+
samples_dict[sample_name]['R1'] = os.path.join(folder_path, gz_file)
|
2795
|
+
elif read_direction == "R2":
|
2796
|
+
samples_dict[sample_name]['R2'] = os.path.join(folder_path, gz_file)
|
2797
|
+
return samples_dict
|
2798
|
+
|
2799
|
+
def generate_dataset(settings={}):
|
2800
|
+
|
2801
|
+
from .utils import initiate_counter, add_images_to_tar, save_settings, generate_path_list_from_db, correct_paths
|
2802
|
+
from .settings import set_generate_dataset_defaults
|
2803
|
+
|
2804
|
+
settings = set_generate_dataset_defaults(settings)
|
2805
|
+
save_settings(settings, 'generate_dataset', show=True)
|
2806
|
+
|
2807
|
+
if isinstance(settings['src'], str):
|
2808
|
+
settings['src'] = [settings['src']]
|
2809
|
+
if isinstance(settings['src'], list):
|
2810
|
+
all_paths = []
|
2811
|
+
for i, src in enumerate(settings['src']):
|
2812
|
+
db_path = os.path.join(src, 'measurements', 'measurements.db')
|
2813
|
+
if i == 0:
|
2814
|
+
dst = os.path.join(src, 'datasets')
|
2815
|
+
paths = generate_path_list_from_db(db_path, file_metadata=settings['file_metadata'])
|
2816
|
+
correct_paths(paths, src)
|
2817
|
+
all_paths.extend(paths)
|
2818
|
+
if isinstance(settings['sample'], int):
|
2819
|
+
selected_paths = random.sample(all_paths, settings['sample'])
|
2820
|
+
print(f"Random selection of {len(selected_paths)} paths")
|
2821
|
+
elif isinstance(settings['sample'], list):
|
2822
|
+
sample = settings['sample'][i]
|
2823
|
+
selected_paths = random.sample(all_paths, settings['sample'])
|
2824
|
+
print(f"Random selection of {len(selected_paths)} paths")
|
2825
|
+
else:
|
2826
|
+
selected_paths = all_paths
|
2827
|
+
random.shuffle(selected_paths)
|
2828
|
+
print(f"All paths: {len(selected_paths)} paths")
|
2829
|
+
|
2830
|
+
total_images = len(selected_paths)
|
2831
|
+
print(f"Found {total_images} images")
|
2832
|
+
|
2833
|
+
# Create a temp folder in dst
|
2834
|
+
temp_dir = os.path.join(dst, "temp_tars")
|
2835
|
+
os.makedirs(temp_dir, exist_ok=True)
|
2836
|
+
|
2837
|
+
# Chunking the data
|
2838
|
+
num_procs = max(2, cpu_count() - 2)
|
2839
|
+
chunk_size = len(selected_paths) // num_procs
|
2840
|
+
remainder = len(selected_paths) % num_procs
|
2841
|
+
|
2842
|
+
paths_chunks = []
|
2843
|
+
start = 0
|
2844
|
+
for i in range(num_procs):
|
2845
|
+
end = start + chunk_size + (1 if i < remainder else 0)
|
2846
|
+
paths_chunks.append(selected_paths[start:end])
|
2847
|
+
start = end
|
2848
|
+
|
2849
|
+
temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]
|
2850
|
+
|
2851
|
+
print(f"Generating temporary tar files in {dst}")
|
2852
|
+
|
2853
|
+
# Initialize shared counter and lock
|
2854
|
+
counter = Value('i', 0)
|
2855
|
+
lock = Lock()
|
2856
|
+
|
2857
|
+
with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
|
2858
|
+
pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
|
2859
|
+
|
2860
|
+
# Combine the temporary tar files into a final tar
|
2861
|
+
date_name = datetime.date.today().strftime('%y%m%d')
|
2862
|
+
if len(settings['src']) > 1:
|
2863
|
+
date_name = f"{date_name}_combined"
|
2864
|
+
#if not settings['file_metadata'] is None:
|
2865
|
+
# tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
|
2866
|
+
#else:
|
2867
|
+
tar_name = f"{date_name}_{settings['experiment']}.tar"
|
2868
|
+
tar_name = os.path.join(dst, tar_name)
|
2869
|
+
if os.path.exists(tar_name):
|
2870
|
+
number = random.randint(1, 100)
|
2871
|
+
tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
|
2872
|
+
print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
|
2873
|
+
tar_name = os.path.join(dst, tar_name_2)
|
2874
|
+
|
2875
|
+
print(f"Merging temporary files")
|
2876
|
+
|
2877
|
+
with tarfile.open(tar_name, 'w') as final_tar:
|
2878
|
+
for temp_tar_path in temp_tar_files:
|
2879
|
+
with tarfile.open(temp_tar_path, 'r') as temp_tar:
|
2880
|
+
for member in temp_tar.getmembers():
|
2881
|
+
file_obj = temp_tar.extractfile(member)
|
2882
|
+
final_tar.addfile(member, file_obj)
|
2883
|
+
os.remove(temp_tar_path)
|
2884
|
+
|
2885
|
+
# Delete the temp folder
|
2886
|
+
shutil.rmtree(temp_dir)
|
2887
|
+
print(f"\nSaved {total_images} images to {tar_name}")
|
2888
|
+
|
2889
|
+
return tar_name
|
2890
|
+
|
2891
|
+
def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
    """
    Generate data loaders for training and validation/test datasets.

    Parameters:
    - src (str): The source directory; images are read from ``src/train`` (mode='train')
      or ``src/test`` (mode='test').
    - mode (str): The mode of operation. Options are 'train' or 'test'.
    - image_size (int): Side length of the center crop applied to every image.
    - batch_size (int): The batch size for the data loaders.
    - classes (list): The list of classes to consider (passed to ``spacrDataset``).
    - n_jobs (int): Currently not used; the loaders always run with a single worker process.
    - validation_split (float): The fraction of data to use for validation
      (forced to 0.0 in 'test' mode).
    - pin_memory (bool): Whether to pin memory for faster data transfer.
    - normalize (bool): Whether to normalize images with mean 0.5 / std 0.5 per channel.
    - channels (list or str): Channels to retain, given as letters ('r', 'g', 'b') and/or
      1-based indices (1, 2, 3). The default [1, 2, 3] keeps all channels.
    - augment (bool): Whether to augment the training split (only applied when
      validation_split > 0).
    - verbose (bool): Whether to print the selected channels.

    Returns:
    - train_loaders (DataLoader): Loader for the training split, or for the whole dataset
      when no validation split is made.
    - val_loaders (DataLoader or list): Loader for the validation split, or an empty list
      when no validation split is made.
    - train_fig (None): Placeholder; image previews are disabled.

    Prints a message and returns None if ``mode`` is not 'train' or 'test'.
    """

    from .utils import SelectChannels, augment_dataset

    # Normalise channel selection to 1-based indices in R, G, B order. Integer input is
    # accepted as well as letters so the documented default [1, 2, 3] keeps every
    # channel instead of silently selecting none.
    requested = list(channels) if isinstance(channels, str) else channels
    channels = [index for letter, index in (('r', 1), ('g', 2), ('b', 3))
                if letter in requested or index in requested]

    if verbose:
        print(f'Training a network on channels: {channels}')
        print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')

    train_loaders = []
    val_loaders = []

    transform_steps = [
        transforms.ToTensor(),
        transforms.CenterCrop(size=(image_size, image_size)),
        SelectChannels(channels),
    ]
    if normalize:
        transform_steps.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(transform_steps)

    if mode == 'train':
        data_dir = os.path.join(src, 'train')
        shuffle = True
        print('Loading Train and validation datasets')
    elif mode == 'test':
        data_dir = os.path.join(src, 'test')
        val_loaders = []
        validation_split = 0.0
        shuffle = True
        print('Loading test dataset')
    else:
        print(f'mode:{mode} is not valid, use mode = train or test')
        return

    data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)

    # The loaders have always used exactly one persistent worker process; keep that
    # explicit so the progress message below reports the real worker count.
    num_workers = 1
    loader_kwargs = dict(batch_size=batch_size, shuffle=shuffle, num_workers=num_workers,
                         pin_memory=pin_memory, persistent_workers=True)

    if validation_split > 0:
        train_size = int((1 - validation_split) * len(data))
        val_size = len(data) - train_size
        if not augment:
            print(f'Train data:{train_size}, Validation data:{val_size}')
        train_dataset, val_dataset = random_split(data, [train_size, val_size])

        if augment:
            print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
            train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
            print(f'Data after augmentation: Train: {len(train_dataset)}')

        print(f'Generating Dataloader with {num_workers} workers')
        train_loaders = DataLoader(train_dataset, **loader_kwargs)
        val_loaders = DataLoader(val_dataset, **loader_kwargs)
    else:
        train_loaders = DataLoader(data, **loader_kwargs)

    # Image previews of the first batch are disabled; the slot is kept for callers.
    train_fig = None

    return train_loaders, val_loaders, train_fig
|
3009
|
+
|
3010
|
+
def generate_training_dataset(settings):
    """
    Build a train/test image dataset from one or more spacr measurement databases.

    For every source folder in ``settings['src']`` the database at
    ``<src>/measurements/measurements.db`` is read and image paths are selected per class
    according to ``settings['dataset_mode']``:

    - ``'annotation'``: ``png_list`` rows whose ``settings['annotation_column']`` value is
      one of ``settings['annotated_classes']``.
    - ``'metadata'``: ``png_list`` rows whose ``settings['metadata_type_by']`` column is in
      the value list given for each class in ``settings['class_metadata']``.
    - ``'measurement'``: images of objects in the lowest and highest quartile of a
      "recruitment" value (``settings['custom_measurement']`` or the pathogen/cytoplasm
      mean-intensity ratio of ``settings['channel_of_interest']``).

    Selections from all sources are concatenated class by class and copied into
    ``<dst>/train/<class>`` and ``<dst>/test/<class>``.

    Parameters:
    - settings (dict): Settings dictionary; defaults are filled in by
      ``set_generate_training_dataset_defaults``.

    Returns:
    - tuple(str, str): Paths to the train and test directories.

    Raises:
    - ValueError: If ``settings['dataset_mode']`` is not supported, or if
      ``settings['custom_measurement']`` is set but is not a list.
    """

    from .io import _read_and_merge_data, _read_db
    from .utils import get_paths_from_db, annotate_conditions, save_settings
    from .settings import set_generate_training_dataset_defaults

    # Function to filter png_list_df by prcfo present in df without merging
    def filter_png_list(db_path, settings):
        """Return the png_list rows whose 'prcfo' appears in the merged measurement index."""
        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
        df, _ = _read_and_merge_data(locs=[db_path],
                                     tables=tables,
                                     verbose=False,
                                     nuclei_limit=settings['nuclei_limit'],
                                     pathogen_limit=settings['pathogen_limit'],
                                     uninfected=settings['uninfected'])
        [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
        # Explicit copy: callers add columns to the result.
        return png_list_df[png_list_df['prcfo'].isin(df.index)].copy()

    # Function to get the smallest class size based on the dataset mode
    def get_smallest_class_size(df, settings, dataset_mode):
        """Return the size of the smallest class.

        For 'metadata', ``df`` is a DataFrame with a 'metadata_based_class' column;
        for 'annotation', it is a list of per-class path lists.
        """
        if dataset_mode == 'metadata':
            sizes = [len(df[df['metadata_based_class'] == c]) for c in settings['classes']]
        elif dataset_mode == 'annotation':
            sizes = [len(class_paths) for class_paths in df]
        size = min(sizes)
        print(f'Using the smallest class size: {size}')
        return size

    # Measurement-based selection logic
    def measurement_based_selection(settings, db_path):
        """Return [lower-quartile paths, upper-quartile paths] of the recruitment value."""
        class_paths_ls = []
        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
        df, _ = _read_and_merge_data(locs=[db_path],
                                     tables=tables,
                                     verbose=False,
                                     nuclei_limit=settings['nuclei_limit'],
                                     pathogen_limit=settings['pathogen_limit'],
                                     uninfected=settings['uninfected'])

        print('length df 1', len(df))
        # NOTE(review): the cell type 'HeLa' is hard-coded here.
        df = annotate_conditions(df, cells=['HeLa'], pathogens=['pathogen'], treatments=settings['classes'],
                                 treatment_loc=settings['class_metadata'])
        print('length df 2', len(df))

        png_list_df = filter_png_list(db_path, settings)

        custom_measurement = settings['custom_measurement']
        if custom_measurement:
            if not isinstance(custom_measurement, list):
                # Previously this printed and returned None, which crashed later with an
                # unrelated TypeError in the caller.
                raise ValueError("custom_measurement should be a list.")
            if len(custom_measurement) == 2:
                df['recruitment'] = df[f"{custom_measurement[0]}"] / df[f"{custom_measurement[1]}"]
            else:
                df['recruitment'] = df[f"{custom_measurement[0]}"]
        else:
            coi = settings['channel_of_interest']
            df['recruitment'] = df[f"pathogen_channel_{coi}_mean_intensity"] / df[f"cytoplasm_channel_{coi}_mean_intensity"]

        q25 = df['recruitment'].quantile(0.25)
        q75 = df['recruitment'].quantile(0.75)
        df_lower = df[df['recruitment'] <= q25]
        df_upper = df[df['recruitment'] >= q75]

        lower_paths = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])['png_path'].tolist()
        upper_paths = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])['png_path'].tolist()

        # random.sample raises when asked for more items than exist; cap at the smaller
        # quartile so both classes stay balanced.
        n_per_class = min(settings['size'], len(lower_paths), len(upper_paths))
        if n_per_class < settings['size']:
            print(f"Only {n_per_class} images available per class (requested {settings['size']})")
        class_paths_ls.append(random.sample(lower_paths, n_per_class))
        class_paths_ls.append(random.sample(upper_paths, n_per_class))

        return class_paths_ls

    # Metadata-based selection logic
    def metadata_based_selection(db_path, settings):
        """Assign png_list rows to classes by metadata and sample the smallest class size each."""
        class_paths_ls = []
        df = filter_png_list(db_path, settings)

        df['metadata_based_class'] = pd.NA
        for i, class_ in enumerate(settings['classes']):
            ls = settings['class_metadata'][i]
            df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_

        size = get_smallest_class_size(df, settings, 'metadata')
        for class_ in settings['classes']:
            class_temp_df = df[df['metadata_based_class'] == class_]
            print(f'Found {len(class_temp_df)} images for class {class_}')
            class_paths_temp = class_temp_df['png_path'].tolist()

            # Ensure to sample `size` number of images (smallest class size)
            if len(class_paths_temp) > size:
                class_paths_temp = random.sample(class_paths_temp, size)

            class_paths_ls.append(class_paths_temp)

        return class_paths_ls

    # Annotation-based selection logic
    def annotation_based_selection(db_path, dst, settings):
        """Select annotated png paths per class, downsampled to the smallest class size."""
        class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])

        size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
        for i, class_paths in enumerate(class_paths_ls):
            if len(class_paths) > size:
                class_paths_ls[i] = random.sample(class_paths, size)

        return class_paths_ls

    # Set default settings and save
    settings = set_generate_training_dataset_defaults(settings)
    save_settings(settings, 'cv_dataset', show=True)

    # A single path given as a string must not be iterated character by character.
    src_list = [settings['src']] if isinstance(settings['src'], str) else settings['src']

    class_path_list = None

    for i, src in enumerate(src_list):
        db_path = os.path.join(src, 'measurements', 'measurements.db')

        # With several sources, everything goes to the first source's 'training_all' folder.
        if len(src_list) > 1 and i == 0:
            dst = os.path.join(src, 'datasets', 'training_all')
        elif len(src_list) == 1:
            dst = os.path.join(src, 'datasets', 'training')

        # Pick a fresh directory name (<dst>_1, <dst>_2, ...) if dst already exists.
        if os.path.exists(dst):
            base_dst = dst
            for suffix in range(1, 100000):
                dst = f'{base_dst}_{suffix}'
                if not os.path.exists(dst):
                    print(f'Creating new directory for training: {dst}')
                    break

        # Select dataset based on dataset mode
        dataset_mode = settings['dataset_mode']
        if dataset_mode == 'annotation':
            class_paths_ls = annotation_based_selection(db_path, dst, settings)
        elif dataset_mode == 'metadata':
            class_paths_ls = metadata_based_selection(db_path, settings)
        elif dataset_mode == 'measurement':
            class_paths_ls = measurement_based_selection(settings, db_path)
        else:
            raise ValueError(f"dataset_mode must be 'annotation', 'metadata' or 'measurement', got {dataset_mode!r}")

        if class_path_list is None:
            class_path_list = [[] for _ in range(len(class_paths_ls))]

        # Extend each list in class_path_list with the corresponding list from class_paths_ls
        for idx in range(len(class_paths_ls)):
            class_path_list[idx].extend(class_paths_ls[idx])

    # Generate and return training and testing directories
    train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_path_list, classes=settings['classes'], test_split=settings['test_split'])

    return train_class_dir, test_class_dir
|
3166
|
+
|
3167
|
+
def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
    """
    Collect image paths from the ``png_list`` table grouped by annotation value.

    Parameters:
    - db_path (str): Path to the SQLite measurements database.
    - dst (str): Unused; kept for call compatibility.
    - annotation_column (str): Column of ``png_list`` holding the annotation values.
    - annotated_classes (iterable): Annotation values to keep, one output class each.

    Returns:
    - list of list of str: For each value in ``annotated_classes`` (same order), the
      ``png_path`` entries annotated with that value.

    Raises:
    - sqlite3.OperationalError: If ``png_list`` or ``annotation_column`` does not exist.
    """
    annotated_classes = tuple(annotated_classes)
    column = str(annotation_column)

    # Connect to the database and retrieve the image paths and annotations
    print(f'Reading DataBase: {db_path}')
    # The sqlite3 connection context manager only commits/rolls back; close explicitly.
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()

        # Identifiers cannot be bound as parameters, so check the column against the
        # schema before quoting it into the query (SQLite treats an unknown quoted
        # identifier as a string literal instead of raising).
        known_columns = {row[1].lower() for row in cursor.execute("PRAGMA table_info(png_list)")}
        if not known_columns:
            raise sqlite3.OperationalError('no such table: png_list')
        if column.lower() not in known_columns:
            raise sqlite3.OperationalError(f'no such column: {column}')
        quoted_column = '"' + column.replace('"', '""') + '"'

        # Prepare the query with parameterized placeholders for annotated_classes
        placeholders = ','.join('?' * len(annotated_classes))
        query = f"SELECT png_path, {quoted_column} FROM png_list WHERE {quoted_column} IN ({placeholders})"
        cursor.execute(query, annotated_classes)
        rows = cursor.fetchall()
    finally:
        conn.close()

    # Group paths by annotation in one pass, then emit them in annotated_classes order
    paths_by_annotation = {}
    for path, annotation in rows:
        paths_by_annotation.setdefault(annotation, []).append(path)
    class_paths = [list(paths_by_annotation.get(class_, [])) for class_ in annotated_classes]

    print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
    return class_paths
|
3194
|
+
|
3195
|
+
def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
    """
    Copy per-class image files into ``dst/train/<class>`` and ``dst/test/<class>``.

    Parameters:
    - dst (str): Destination root directory; class subfolders are created as needed.
    - class_data (list of list of str): Image paths for each class, in the same order
      as ``classes``.
    - classes (list): Class names, used as subfolder names.
    - test_split (float): ``test_size`` passed to ``train_test_split`` for each class
      (split with ``random_state=42``). An empty class, or ``test_split == 0``, puts all of
      that class's files in the train folder.

    Returns:
    - tuple(str, str): Paths to ``dst/train`` and ``dst/test``.

    Raises:
    - ValueError: If ``class_data`` and ``classes`` differ in length.
    """
    from .utils import print_progress

    # Make sure that the length of class_data matches the length of classes
    if len(class_data) != len(classes):
        raise ValueError("class_data and classes must have the same length.")

    total_files = sum(len(data) for data in class_data)
    processed_files = 0

    for cls, data in zip(classes, class_data):
        # Create directories
        train_class_dir = os.path.join(dst, 'train', str(cls))
        test_class_dir = os.path.join(dst, 'test', str(cls))
        os.makedirs(train_class_dir, exist_ok=True)
        os.makedirs(test_class_dir, exist_ok=True)

        # Split the data; train_test_split rejects empty input and a zero test size.
        if len(data) == 0 or test_split == 0:
            train_data, test_data = list(data), []
        else:
            train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)

        splits = ((train_data, train_class_dir, "Copying files for Train dataset"),
                  (test_data, test_class_dir, "Copying files for Test dataset"))
        for split_paths, split_dir, operation in splits:
            for path in split_paths:
                # NOTE: files are stored under their basename, so identically named files
                # from different sources overwrite each other.
                shutil.copy(path, os.path.join(split_dir, os.path.basename(path)))
                # Count the file before reporting so progress reaches total_files.
                processed_files += 1
                print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type=operation)

    # Print summary
    for cls in classes:
        train_class_dir = os.path.join(dst, 'train', str(cls))
        test_class_dir = os.path.join(dst, 'test', str(cls))
        print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')

    return os.path.join(dst, 'train'), os.path.join(dst, 'test')
|