spacr 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +19 -3
- spacr/cellpose.py +311 -0
- spacr/core.py +142 -2495
- spacr/deep_spacr.py +151 -29
- spacr/gui.py +1 -0
- spacr/gui_core.py +74 -63
- spacr/gui_elements.py +110 -5
- spacr/gui_utils.py +346 -6
- spacr/io.py +631 -51
- spacr/logger.py +28 -9
- spacr/measure.py +107 -95
- spacr/mediar.py +0 -5
- spacr/ml.py +964 -0
- spacr/openai.py +37 -0
- spacr/plot.py +281 -16
- spacr/resources/data/lopit.csv +3833 -0
- spacr/resources/data/toxoplasma_metadata.csv +8843 -0
- spacr/resources/icons/convert.png +0 -0
- spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
- spacr/sequencing.py +241 -1311
- spacr/settings.py +129 -43
- spacr/sim.py +0 -2
- spacr/submodules.py +348 -0
- spacr/timelapse.py +0 -2
- spacr/toxo.py +233 -0
- spacr/utils.py +275 -173
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
- spacr/chris.py +0 -50
- spacr/graph_learning.py +0 -340
- spacr/resources/MEDIAR/.git +0 -1
- spacr/resources/MEDIAR_weights/.DS_Store +0 -0
- spacr/resources/icons/.DS_Store +0 -0
- spacr/resources/icons/spacr_logo_rotation.gif +0 -0
- spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
- spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
- spacr/sim_app.py +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
- {spacr-0.3.0.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/io.py
CHANGED
@@ -1,30 +1,132 @@
-import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
+import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue, tifffile, czifile, atexit, datetime
 import numpy as np
 import pandas as pd
-import tifffile
 from PIL import Image, ImageOps
-from collections import defaultdict, Counter
+from collections import defaultdict, Counter
 from pathlib import Path
 from functools import partial
 from matplotlib.animation import FuncAnimation
 from IPython.display import display
 from skimage.util import img_as_uint
 from skimage.exposure import rescale_intensity
-from skimage import filters
 import skimage.measure as measure
 from skimage import exposure
 import imageio.v2 as imageio2
 import matplotlib.pyplot as plt
 from io import BytesIO
-from IPython.display import display
-from multiprocessing import Pool, cpu_count, Process, Queue
-from torch.utils.data import Dataset, DataLoader
+from IPython.display import display
+from multiprocessing import Pool, cpu_count, Process, Queue, Value, Lock
+from torch.utils.data import Dataset, DataLoader, random_split
 import matplotlib.pyplot as plt
 from torchvision.transforms import ToTensor
-import seaborn as sns
-import
-
-
+import seaborn as sns
+from nd2reader import ND2Reader
+from torchvision import transforms
+
+def process_non_tif_non_2D_images(folder):
+    """Processes all images in the folder and splits them into grayscale channels, preserving bit depth."""
+
+    # Helper function to save grayscale images
+    def save_grayscale_images(image, base_name, folder, dtype, channel=None, z=None, t=None):
+        """Save grayscale images with appropriate suffix based on channel, z, and t, preserving bit depth."""
+        suffix = ""
+        if channel is not None:
+            suffix += f"_C{channel}"
+        if z is not None:
+            suffix += f"_Z{z}"
+        if t is not None:
+            suffix += f"_T{t}"
+
+        output_filename = os.path.join(folder, f"{base_name}{suffix}.tif")
+        tifffile.imwrite(output_filename, image.astype(dtype))
+
+    # Function to handle splitting of multi-dimensional images into grayscale channels
+    def split_channels(image, folder, base_name, dtype):
+        """Splits the image into channels and handles 3D, 4D, and 5D image cases."""
+        if image.ndim == 2:
+            # Grayscale image, already processed separately
+            return
+
+        elif image.ndim == 3:
+            # 3D image: (height, width, channels)
+            for c in range(image.shape[2]):
+                save_grayscale_images(image[..., c], base_name, folder, dtype, channel=c+1)
+
+        elif image.ndim == 4:
+            # 4D image: (height, width, channels, Z-dimension)
+            for z in range(image.shape[3]):
+                for c in range(image.shape[2]):
+                    save_grayscale_images(image[..., c, z], base_name, folder, dtype, channel=c+1, z=z+1)
+
+        elif image.ndim == 5:
+            # 5D image: (height, width, channels, Z-dimension, Time)
+            for t in range(image.shape[4]):
+                for z in range(image.shape[3]):
+                    for c in range(image.shape[2]):
+                        save_grayscale_images(image[..., c, z, t], base_name, folder, dtype, channel=c+1, z=z+1, t=t+1)
+
+    # Function to load images in various formats
+    def load_image(file_path):
+        """Loads image from various formats and returns it as a numpy array along with its dtype."""
+        ext = os.path.splitext(file_path)[1].lower()
+
+        if ext in ['.tif', '.tiff']:
+            image = tifffile.imread(file_path)
+            return image, image.dtype
+
+        elif ext in ['.png', '.jpg', '.jpeg']:
+            image = Image.open(file_path)
+            return np.array(image), image.mode
+
+        elif ext == '.czi':
+            with czifile.CziFile(file_path) as czi:
+                image = czi.asarray()
+            return image, image.dtype
+
+        elif ext == '.nd2':
+            with ND2Reader(file_path) as nd2:
+                image = np.array(nd2)
+            return image, image.dtype
+
+        else:
+            raise ValueError(f"Unsupported file extension: {ext}")
+
+    # Function to check if an image is grayscale and save it as a TIFF if it isn't already
+    def convert_grayscale_to_tiff(image, filename, folder, dtype):
+        """Convert grayscale images that are not in TIFF format to TIFF, preserving bit depth."""
+        base_name = os.path.splitext(filename)[0]
+        output_filename = os.path.join(folder, f"{base_name}.tif")
+        tifffile.imwrite(output_filename, image.astype(dtype))
+        print(f"Converted grayscale image {filename} to TIFF with bit depth {dtype}.")
+
+    # Supported formats
+    supported_formats = ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.czi', '.nd2']
+
+    # Loop through all files in the folder
+    for filename in os.listdir(folder):
+        file_path = os.path.join(folder, filename)
+        ext = os.path.splitext(file_path)[1].lower()
+
+        if ext in supported_formats:
+            print(f"Processing {filename}")
+            try:
+                # Load the image and its dtype
+                image, dtype = load_image(file_path)
+
+                # If the image is grayscale (2D), convert it to TIFF if it's not already in TIFF format
+                if image.ndim == 2:
+                    if ext not in ['.tif', '.tiff']:
+                        convert_grayscale_to_tiff(image, filename, folder, dtype)
+                    else:
+                        print(f"Image {filename} is already grayscale and in TIFF format, skipping.")
+                    continue
+
+                # Otherwise, split channels and save images
+                base_name = os.path.splitext(filename)[0]
+                split_channels(image, folder, base_name, dtype)
+
+            except Exception as e:
+                print(f"Error processing {filename}: {str(e)}")
 
 def _load_images_and_labels(image_files, label_files, circular=False, invert=False):
 
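The new process_non_tif_non_2D_images helper converts mixed-format microscopy images (.czi, .nd2, .png/.jpg, and multi-dimensional .tif) in a folder into per-channel grayscale TIFFs, preserving bit depth. A minimal usage sketch, assuming the function is importable from spacr.io in 0.3.2 (the import path and the example folder are assumptions, not shown by this diff):

    # Hypothetical usage of the new helper added above.
    from spacr.io import process_non_tif_non_2D_images

    # Splits e.g. sample.czi into sample_C1_Z1.tif, sample_C2_Z1.tif, ...
    # and rewrites non-TIFF grayscale images as .tif next to the originals.
    process_non_tif_non_2D_images("/path/to/raw_images")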
@@ -632,6 +734,20 @@ class TarImageDataset(Dataset):
             img = self.transform(img)
 
         return img, m.name
+
+def load_images_from_paths(images_by_key):
+    images_dict = {}
+
+    for key, paths in images_by_key.items():
+        images_dict[key] = []
+        for path in paths:
+            try:
+                with Image.open(path) as img:
+                    images_dict[key].append(np.array(img))
+            except Exception as e:
+                print(f"Error loading image from {path}: {e}")
+
+    return images_dict
 
 #@log_function_call
 def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
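load_images_from_paths turns a mapping of metadata keys to file paths into a mapping of the same keys to lists of numpy arrays, skipping unreadable files. A small illustrative sketch (the key tuple and file names below are hypothetical):

    # Hypothetical input; keys follow the (plate, well, field, channel, ...) pattern used in io.py.
    images_by_key = {
        ('plate1', 'A01', '1', 'DAPI'): ['/data/plate1_A01_1_DAPI_z1.tif',
                                         '/data/plate1_A01_1_DAPI_z2.tif'],
    }
    images_dict = load_images_from_paths(images_by_key)
    # Each value is now a list of numpy arrays, ready for e.g. np.max(np.stack(images), axis=0)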
@@ -653,20 +769,24 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
     from .utils import _extract_filename_metadata, print_progress
 
     regular_expression = re.compile(regex)
-    images_by_key = defaultdict(list)
     stack_path = os.path.join(src, 'stack')
     files_processed = 0
     if not os.path.exists(stack_path) or (os.path.isdir(stack_path) and len(os.listdir(stack_path)) == 0):
         all_filenames = [filename for filename in os.listdir(src) if filename.endswith(img_format)]
-        print(f'
+        print(f'All files: {len(all_filenames)} in {src}')
         time_ls = []
-
-
+        image_paths_by_key = _extract_filename_metadata(all_filenames, src, regular_expression, metadata_type, pick_slice, skip_mode)
+        # Convert dictionary keys to a list for batching
+        batching_keys = list(image_paths_by_key.keys())
+        print(f'All unique FOV: {len(image_paths_by_key)} in {src}')
+        for idx in range(0, len(image_paths_by_key), batch_size):
             start = time.time()
-
-            for
-
-
+
+            # Select batch keys and create a subset of the dictionary for this batch
+            batch_keys = batching_keys[idx:idx+batch_size]
+            batch_images_by_key = {key: image_paths_by_key[key] for key in batch_keys}
+            images_by_key = load_images_from_paths(batch_images_by_key)
+
             if pick_slice:
                 for i, key in enumerate(images_by_key):
                     plate, well, field, channel, mode = key
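The reworked loop batches by unique field-of-view keys instead of raw filenames: metadata extraction builds image_paths_by_key once, then each batch of keys is loaded with load_images_from_paths before MIP generation, so only one batch of images is in memory at a time. A minimal sketch of the same key-chunking pattern on a plain dictionary (variable names are illustrative):

    image_paths_by_key = {}            # key -> list of file paths, filled elsewhere
    batch_size = 100
    batching_keys = list(image_paths_by_key.keys())

    for idx in range(0, len(batching_keys), batch_size):
        batch_keys = batching_keys[idx:idx + batch_size]
        batch = {key: image_paths_by_key[key] for key in batch_keys}
        # load and process this batch, then let it go out of scope before the next one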
@@ -683,16 +803,16 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
                     files_to_process = len(all_filenames)
                     print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')
 
-                    if os.path.exists(output_path):
-                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
-                    else:
+                    if not os.path.exists(output_path):
                         mip_image.save(output_path)
+                    else:
+                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
             else:
                 for i, (key, images) in enumerate(images_by_key.items()):
-                    mip = np.max(np.stack(images), axis=0)
-                    mip_image = Image.fromarray(mip)
                     plate, well, field, channel = key[:4]
                     output_dir = os.path.join(src, channel)
+                    mip = np.max(np.stack(images), axis=0)
+                    mip_image = Image.fromarray(mip)
                     os.makedirs(output_dir, exist_ok=True)
                     output_filename = f'{plate}_{well}_{field}.tif'
                     output_path = os.path.join(output_dir, output_filename)
@@ -703,10 +823,11 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
                     files_to_process = len(all_filenames)
                     print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')
 
-                    if os.path.exists(output_path):
-                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
-                    else:
+                    if not os.path.exists(output_path):
                         mip_image.save(output_path)
+                    else:
+                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
+
             images_by_key.clear()
 
     # Move original images to a new directory
@@ -962,9 +1083,7 @@ def _mip_all(src, include_first_chan=True):
     Returns:
     None
     """
-
-    from .utils import normalize_to_dtype
-
+
     #print('========== generating MIPs ==========')
     # Iterate over each file in the specified directory (src).
     for filename in os.listdir(src):
@@ -972,8 +1091,8 @@ def _mip_all(src, include_first_chan=True):
         if filename.endswith('.npy'):
             # Load the array from the file.
             array = np.load(os.path.join(src, filename))
-            # Normalize the array
-            array = normalize_to_dtype(array, q1=
+            # Normalize the array
+            #array = normalize_to_dtype(array, q1=0, q2=99, percentiles=None)
 
             if array.ndim != 3: # Check if the array is not 3-dimensional.
                 # Log a message indicating a zero array will be generated due to unexpected dimensions.
@@ -1338,7 +1457,6 @@ def _get_lists_for_normalization(settings):
     return backgrounds, signal_to_noise, signal_thresholds, remove_background
 
 def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
-    from .utils import print_progress
     """
     Normalize the stack of images.
 
@@ -1431,7 +1549,6 @@ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False
     return print(f'Saved stacks: {output_fldr}')
 
 def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):
-    from .utils import print_progress
     """
     Normalize the timelapse data by rescaling the intensity values based on percentiles.
 
@@ -1560,7 +1677,7 @@ def delete_empty_subdirectories(folder_path):
 #@log_function_call
 def preprocess_img_data(settings):
 
-    from .plot import plot_arrays
+    from .plot import plot_arrays
     from .utils import _run_test_mode, _get_regex
     from .settings import set_default_settings_preprocess_img_data
 
@@ -1671,6 +1788,7 @@ def preprocess_img_data(settings):
     if plot:
         print(f'plotting {nr} images from {src}/stack')
         plot_arrays(src+'/stack', figuresize, cmap, nr=nr, normalize=normalize)
+
     if all_to_mip:
         _mip_all(src+'/stack')
         if plot:
@@ -2054,7 +2172,6 @@ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_di
             padded_shapes = [shape + (0,) * (max_tuple_length - len(shape)) for shape in unique_shapes]
             # Now create a NumPy array and find the maximum dimensions
             max_dims = np.max(np.array(padded_shapes), axis=0)
-            #clear_output(wait=True)
             print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}')
             #print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}', end='\r', flush=True)
             padded_stack_ls = []
@@ -2102,7 +2219,7 @@ def _read_db(db_loc, tables):
     conn.close()
     return dfs
 
-def _read_and_merge_data(locs, tables, verbose=False,
+def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
     """
     Read and merge data from SQLite databases and perform data preprocessing.
 
@@ -2110,9 +2227,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
     - locs (list): A list of file paths to the SQLite database files.
     - tables (list): A list of table names to read from the databases.
     - verbose (bool): Whether to print verbose output. Default is False.
-    -
-    -
-    -
+    - nuclei_limit (bool): Whether to include multinucleated cells. Default is False.
+    - pathogen_limit (bool): Whether to include cells with multiple infections. Default is False.
+    - uninfected (bool): Whether to include non-infected cells. Default is False.
 
     Returns:
     - merged_df (pandas.DataFrame): The merged and preprocessed dataframe.
@@ -2187,7 +2304,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
         nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
         nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
         nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
-        if
+        if nuclei_limit == False:
             #nucleus = nucleus[~nucleus['prcfo'].duplicated()]
             nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
         nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
@@ -2203,9 +2320,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
         pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
         pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
         pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
-        if
+        if uninfected == False:
             pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
-        if
+        if pathogen_limit == False:
             pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
         pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
         print(f'pathogens: {len(pathogens)}')
@@ -2448,7 +2565,7 @@ def _read_db(db_loc, tables):
     conn.close() # Close the connection
     return dfs
 
-def _read_and_merge_data(locs, tables, verbose=False,
+def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
 
     from .utils import _split_data
 
@@ -2533,7 +2650,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
         nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
         nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
         nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
-        if
+        if nuclei_limit == False:
             nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
         nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
         if verbose:
@@ -2559,20 +2676,30 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
         pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
         pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
         pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
-
+
+        print(f"before noninfected: {len(pathogens)}")
+        if uninfected == False:
             pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
-
-
+            print(f"after noninfected: {len(pathogens)}")
+
+        if isinstance(pathogen_limit, bool):
+            if pathogen_limit == False:
                 pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
-
-
+                print(f"after multiinfected Bool: {len(pathogens)}")
+        if isinstance(pathogen_limit, float):
+            pathogen_limit = int(pathogen_limit)
+        if isinstance(pathogen_limit, int):
+            pathogens = pathogens[pathogens['pathogen_prcfo_count']<=pathogen_limit]
+            print(f"afer multiinfected Float: {len(pathogens)}")
         if not 'cell' in tables:
             pathogens_g_df, metadata = _split_data(pathogens, 'prcfo', 'cell_id')
         else:
             pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
+
         if verbose:
             print(f'pathogens: {len(pathogens)}')
             print(f'pathogens grouped: {len(pathogens_g_df)}')
+
         if len(merged_df) == 0:
             merged_df = pathogens_g_df
         else:
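With this change pathogen_limit is no longer only a boolean: False keeps only singly-infected cells, while an int or float acts as a cutoff on the number of pathogen objects per cell. A toy pandas sketch of the numeric path (column and variable names mirror the diff above; the data is made up):

    import pandas as pd

    pathogens = pd.DataFrame({'prcfo': ['a', 'b', 'c'],
                              'pathogen_prcfo_count': [1, 2, 5]})
    pathogen_limit = 2.0

    # Float cutoffs are cast to int, then applied as an upper bound on infections per cell.
    if isinstance(pathogen_limit, float):
        pathogen_limit = int(pathogen_limit)
    if isinstance(pathogen_limit, int):
        pathogens = pathogens[pathogens['pathogen_prcfo_count'] <= pathogen_limit]
    # keeps rows 'a' and 'b' (counts 1 and 2), drops 'c' (count 5)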
@@ -2697,4 +2824,457 @@ def generate_cellpose_train_test(src, test_split=0.1):
             shutil.copy(img_path, new_img_path)
             shutil.copy(mask_path, new_mask_path)
             print(f'Copied {idx+1}/{len(ls)} images to {_type} set')#, end='\r', flush=True)
-
+
+def parse_gz_files(folder_path):
+    """
+    Parses the .fastq.gz files in the specified folder path and returns a dictionary
+    containing the sample names and their corresponding file paths.
+
+    Args:
+        folder_path (str): The path to the folder containing the .fastq.gz files.
+
+    Returns:
+        dict: A dictionary where the keys are the sample names and the values are
+        dictionaries containing the file paths for the 'R1' and 'R2' read directions.
+    """
+    files = os.listdir(folder_path)
+    gz_files = [f for f in files if f.endswith('.fastq.gz')]
+
+    samples_dict = {}
+    for gz_file in gz_files:
+        parts = gz_file.split('_')
+        sample_name = parts[0]
+        read_direction = parts[1]
+
+        if sample_name not in samples_dict:
+            samples_dict[sample_name] = {}
+
+        if read_direction == "R1":
+            samples_dict[sample_name]['R1'] = os.path.join(folder_path, gz_file)
+        elif read_direction == "R2":
+            samples_dict[sample_name]['R2'] = os.path.join(folder_path, gz_file)
+    return samples_dict
+
+def generate_dataset(settings={}):
+
+    from .utils import initiate_counter, add_images_to_tar, save_settings, generate_path_list_from_db, correct_paths
+    from .settings import set_generate_dataset_defaults
+
+    settings = set_generate_dataset_defaults(settings)
+    save_settings(settings, 'generate_dataset', show=True)
+
+    if isinstance(settings['src'], str):
+        settings['src'] = [settings['src']]
+    if isinstance(settings['src'], list):
+        all_paths = []
+        for i, src in enumerate(settings['src']):
+            db_path = os.path.join(src, 'measurements', 'measurements.db')
+            dst = os.path.join(src, 'datasets')
+            paths = generate_path_list_from_db(db_path, file_metadata=settings['file_metadata'])
+            correct_paths(paths, src)
+            all_paths.extend(paths)
+    if isinstance(settings['sample'], int):
+        selected_paths = random.sample(all_paths, settings['sample'])
+        print(f"Random selection of {len(selected_paths)} paths")
+    elif isinstance(settings['sample'], list):
+        sample = settings['sample'][i]
+        selected_paths = random.sample(all_paths, settings['sample'])
+        print(f"Random selection of {len(selected_paths)} paths")
+    else:
+        selected_paths = all_paths
+        random.shuffle(selected_paths)
+        print(f"All paths: {len(selected_paths)} paths")
+
+    total_images = len(selected_paths)
+    print(f"Found {total_images} images")
+
+    # Create a temp folder in dst
+    temp_dir = os.path.join(dst, "temp_tars")
+    os.makedirs(temp_dir, exist_ok=True)
+
+    # Chunking the data
+    num_procs = max(2, cpu_count() - 2)
+    chunk_size = len(selected_paths) // num_procs
+    remainder = len(selected_paths) % num_procs
+
+    paths_chunks = []
+    start = 0
+    for i in range(num_procs):
+        end = start + chunk_size + (1 if i < remainder else 0)
+        paths_chunks.append(selected_paths[start:end])
+        start = end
+
+    temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]
+
+    print(f"Generating temporary tar files in {dst}")
+
+    # Initialize shared counter and lock
+    counter = Value('i', 0)
+    lock = Lock()
+
+    with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
+        pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
+
+    # Combine the temporary tar files into a final tar
+    date_name = datetime.date.today().strftime('%y%m%d')
+    if not settings['file_metadata'] is None:
+        tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
+    else:
+        tar_name = f"{date_name}_{settings['experiment']}.tar"
+    tar_name = os.path.join(dst, tar_name)
+    if os.path.exists(tar_name):
+        number = random.randint(1, 100)
+        tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
+        print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
+        tar_name = os.path.join(dst, tar_name_2)
+
+    print(f"Merging temporary files")
+
+    with tarfile.open(tar_name, 'w') as final_tar:
+        for temp_tar_path in temp_tar_files:
+            with tarfile.open(temp_tar_path, 'r') as temp_tar:
+                for member in temp_tar.getmembers():
+                    file_obj = temp_tar.extractfile(member)
+                    final_tar.addfile(member, file_obj)
+            os.remove(temp_tar_path)
+
+    # Delete the temp folder
+    shutil.rmtree(temp_dir)
+    print(f"\nSaved {total_images} images to {tar_name}")
+
+    return tar_name
+
+def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
+
+    """
+    Generate data loaders for training and validation/test datasets.
+
+    Parameters:
+    - src (str): The source directory containing the data.
+    - mode (str): The mode of operation. Options are 'train' or 'test'.
+    - image_size (int): The size of the input images.
+    - batch_size (int): The batch size for the data loaders.
+    - classes (list): The list of classes to consider.
+    - n_jobs (int): The number of worker threads for data loading.
+    - validation_split (float): The fraction of data to use for validation.
+    - pin_memory (bool): Whether to pin memory for faster data transfer.
+    - normalize (bool): Whether to normalize the input images.
+    - verbose (bool): Whether to print additional information and show images.
+    - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
+
+    Returns:
+    - train_loaders (list): List of data loaders for training datasets.
+    - val_loaders (list): List of data loaders for validation datasets.
+    """
+
+    from .io import spacrDataset
+    from .utils import SelectChannels, augment_dataset
+
+    chans = []
+
+    if 'r' in channels:
+        chans.append(1)
+    if 'g' in channels:
+        chans.append(2)
+    if 'b' in channels:
+        chans.append(3)
+
+    channels = chans
+
+    if verbose:
+        print(f'Training a network on channels: {channels}')
+        print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
+
+    train_loaders = []
+    val_loaders = []
+
+    if normalize:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels),
+            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+    else:
+        transform = transforms.Compose([
+            transforms.ToTensor(),
+            transforms.CenterCrop(size=(image_size, image_size)),
+            SelectChannels(channels)])
+
+    if mode == 'train':
+        data_dir = os.path.join(src, 'train')
+        shuffle = True
+        print('Loading Train and validation datasets')
+    elif mode == 'test':
+        data_dir = os.path.join(src, 'test')
+        val_loaders = []
+        validation_split = 0.0
+        shuffle = True
+        print('Loading test dataset')
+    else:
+        print(f'mode:{mode} is not valid, use mode = train or test')
+        return
+
+    data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+    num_workers = n_jobs if n_jobs is not None else 0
+
+    if validation_split > 0:
+        train_size = int((1 - validation_split) * len(data))
+        val_size = len(data) - train_size
+        if not augment:
+            print(f'Train data:{train_size}, Validation data:{val_size}')
+        train_dataset, val_dataset = random_split(data, [train_size, val_size])
+
+        if augment:
+
+            print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+            train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+            print(f'Data after augmentation: Train: {len(train_dataset)}')
+
+        print(f'Generating Dataloader with {n_jobs} workers')
+        train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+        val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+    else:
+        train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+
+    #dataset (Dataset) – dataset from which to load the data.
+    #batch_size (int, optional) – how many samples per batch to load (default: 1).
+    #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
+    #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
+    #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
+    #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
+    #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
+    #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
+    #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
+    #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
+    #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
+    #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
+    #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
+    #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
+    #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
+    #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
+
+    #images, labels, filenames = next(iter(train_loaders))
+    #images = images.cpu()
+    #label_strings = [str(label.item()) for label in labels]
+    #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
+    #if verbose:
+    #    plt.show()
+
+    train_fig = None
+
+    return train_loaders, val_loaders, train_fig
+
+def generate_training_dataset(settings):
+
+    from .io import _read_and_merge_data, _read_db
+    from .utils import get_paths_from_db, annotate_conditions, save_settings
+    from .settings import set_generate_training_dataset_defaults
+
+    # Function to filter png_list_df by prcfo present in df without merging
+    def filter_png_list(db_path, settings):
+        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
+        df, _ = _read_and_merge_data(locs=[db_path],
+                                     tables=tables,
+                                     verbose=False,
+                                     nuclei_limit=settings['nuclei_limit'],
+                                     pathogen_limit=settings['pathogen_limit'],
+                                     uninfected=settings['uninfected'])
+        [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
+        filtered_png_list_df = png_list_df[png_list_df['prcfo'].isin(df.index)]
+        return filtered_png_list_df
+
+    # Function to get the smallest class size based on the dataset mode
+    def get_smallest_class_size(df, settings, dataset_mode):
+        if dataset_mode == 'metadata':
+            sizes = [len(df[df['metadata_based_class'] == c]) for c in settings['classes']]
+        elif dataset_mode == 'annotation':
+            sizes = [len(class_paths) for class_paths in df]
+        size = min(sizes)
+        print(f'Using the smallest class size: {size}')
+        return size
+
+    # Measurement-based selection logic
+    def measurement_based_selection(settings, db_path):
+        class_paths_ls = []
+        tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
+        df, _ = _read_and_merge_data(locs=[db_path],
+                                     tables=tables,
+                                     verbose=False,
+                                     nuclei_limit=settings['nuclei_limit'],
+                                     pathogen_limit=settings['pathogen_limit'],
+                                     uninfected=settings['uninfected'])
+
+        print('length df 1', len(df))
+        df = annotate_conditions(df, cells=['HeLa'], pathogens=['pathogen'], treatments=settings['classes'],
+                                 treatment_loc=settings['class_metadata'])#, types=settings['metadata_type_by'])
+        print('length df 2', len(df))
+
+        png_list_df = filter_png_list(db_path, settings)
+
+        if settings['custom_measurement']:
+            if isinstance(settings['custom_measurement'], list):
+                if len(settings['custom_measurement']) == 2:
+                    df['recruitment'] = df[f"{settings['custom_measurement'][0]}"] / df[f"{settings['custom_measurement'][1]}"]
+                else:
+                    df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]
+            else:
+                print("custom_measurement should be a list.")
+                return
+
+        else:
+            df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"] / df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
+
+        q25 = df['recruitment'].quantile(0.25)
+        q75 = df['recruitment'].quantile(0.75)
+        df_lower = df[df['recruitment'] <= q25]
+        df_upper = df[df['recruitment'] >= q75]
+
+        class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])
+        class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), settings['size'])
+        class_paths_ls.append(class_paths_lower)
+
+        class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
+        class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), settings['size'])
+        class_paths_ls.append(class_paths_upper)
+
+        return class_paths_ls
+
+    # Metadata-based selection logic
+    def metadata_based_selection(db_path, settings):
+        class_paths_ls = []
+        df = filter_png_list(db_path, settings)
+
+        df['metadata_based_class'] = pd.NA
+        for i, class_ in enumerate(settings['classes']):
+            ls = settings['class_metadata'][i]
+            df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_
+
+        size = get_smallest_class_size(df, settings, 'metadata')
+        for class_ in settings['classes']:
+            class_temp_df = df[df['metadata_based_class'] == class_]
+            print(f'Found {len(class_temp_df)} images for class {class_}')
+            class_paths_temp = class_temp_df['png_path'].tolist()
+
+            # Ensure to sample `size` number of images (smallest class size)
+            if len(class_paths_temp) > size:
+                class_paths_temp = random.sample(class_paths_temp, size)
+
+            class_paths_ls.append(class_paths_temp)
+
+        return class_paths_ls
+
+    # Annotation-based selection logic
+    def annotation_based_selection(db_path, dst, settings):
+        class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
+
+        size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
+        for i, class_paths in enumerate(class_paths_ls):
+            if len(class_paths) > size:
+                class_paths_ls[i] = random.sample(class_paths, size)
+
+        return class_paths_ls
+
+    # Set default settings and save
+    settings = set_generate_training_dataset_defaults(settings)
+    save_settings(settings, 'cv_dataset', show=True)
+
+    db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
+    dst = os.path.join(settings['src'], 'datasets', 'training')
+
+    # Create a new directory for training data if necessary
+    if os.path.exists(dst):
+        for i in range(1, 100000):
+            dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
+            if not os.path.exists(dst):
+                print(f'Creating new directory for training: {dst}')
+                break
+
+    # Select dataset based on dataset mode
+    if settings['dataset_mode'] == 'annotation':
+        class_paths_ls = annotation_based_selection(db_path, dst, settings)
+
+    elif settings['dataset_mode'] == 'metadata':
+        class_paths_ls = metadata_based_selection(db_path, settings)
+
+    elif settings['dataset_mode'] == 'measurement':
+        class_paths_ls = measurement_based_selection(settings, db_path)
+
+    # Generate and return training and testing directories
+    train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])
+
+    return train_class_dir, test_class_dir
+
+def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
+    all_paths = []
+
+    # Connect to the database and retrieve the image paths and annotations
+    print(f'Reading DataBase: {db_path}')
+    with sqlite3.connect(db_path) as conn:
+        cursor = conn.cursor()
+        # Prepare the query with parameterized placeholders for annotated_classes
+        placeholders = ','.join('?' * len(annotated_classes))
+        query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
+        cursor.execute(query, annotated_classes)
+
+        while True:
+            rows = cursor.fetchmany(1000)
+            if not rows:
+                break
+            for row in rows:
+                all_paths.append(row)
+
+    # Filter paths based on annotation
+    class_paths = []
+    for class_ in annotated_classes:
+        class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
+        class_paths.append(class_paths_temp)
+
+    print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
+    return class_paths
+
+def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
+    from .utils import print_progress
+    from .deep_spacr import train_test_split
+    # Make sure that the length of class_data matches the length of classes
+    if len(class_data) != len(classes):
+        raise ValueError("class_data and classes must have the same length.")
+
+    total_files = sum(len(data) for data in class_data)
+    processed_files = 0
+    time_ls = []
+
+    for cls, data in zip(classes, class_data):
+        # Create directories
+        train_class_dir = os.path.join(dst, f'train/{cls}')
+        test_class_dir = os.path.join(dst, f'test/{cls}')
+        os.makedirs(train_class_dir, exist_ok=True)
+        os.makedirs(test_class_dir, exist_ok=True)
+
+        # Split the data
+        train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
+
+        # Copy train files
+        for path in train_data:
+            start = time.time()
+            shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
+            duration = time.time() - start
+            time_ls.append(duration)
+            print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
+            processed_files += 1
+
+        # Copy test files
+        for path in test_data:
+            start = time.time()
+            shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
+            duration = time.time() - start
+            time_ls.append(duration)
+            print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
+            processed_files += 1
+
+    # Print summary
+    for cls in classes:
+        train_class_dir = os.path.join(dst, f'train/{cls}')
+        test_class_dir = os.path.join(dst, f'test/{cls}')
+        print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
+
+    return os.path.join(dst, 'train'), os.path.join(dst, 'test')