spacr 0.3.1__py3-none-any.whl → 0.3.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +245 -2494
  4. spacr/deep_spacr.py +316 -48
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +680 -141
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +1051 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +707 -20
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +134 -47
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +349 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +238 -0
  26. spacr/utils.py +419 -180
  27. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/METADATA +31 -22
  28. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.22.dist-info}/top_level.txt +0 -0
spacr/io.py CHANGED
@@ -1,30 +1,133 @@
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue, tifffile, czifile, atexit, datetime
  import numpy as np
  import pandas as pd
- import tifffile
  from PIL import Image, ImageOps
- from collections import defaultdict, Counter, deque
+ from collections import defaultdict, Counter
  from pathlib import Path
  from functools import partial
  from matplotlib.animation import FuncAnimation
  from IPython.display import display
  from skimage.util import img_as_uint
  from skimage.exposure import rescale_intensity
- from skimage import filters
  import skimage.measure as measure
  from skimage import exposure
  import imageio.v2 as imageio2
  import matplotlib.pyplot as plt
  from io import BytesIO
- from IPython.display import display, clear_output
- from multiprocessing import Pool, cpu_count, Process, Queue
- from torch.utils.data import Dataset, DataLoader
+ from IPython.display import display
+ from multiprocessing import Pool, cpu_count, Process, Queue, Value, Lock
+ from torch.utils.data import Dataset, DataLoader, random_split
  import matplotlib.pyplot as plt
  from torchvision.transforms import ToTensor
- import seaborn as sns
- import atexit
-
- from .logger import log_function_call
+ import seaborn as sns
+ from nd2reader import ND2Reader
+ from torchvision import transforms
+ from sklearn.model_selection import train_test_split
+
+ def process_non_tif_non_2D_images(folder):
+ """Processes all images in the folder and splits them into grayscale channels, preserving bit depth."""
+
+ # Helper function to save grayscale images
+ def save_grayscale_images(image, base_name, folder, dtype, channel=None, z=None, t=None):
+ """Save grayscale images with appropriate suffix based on channel, z, and t, preserving bit depth."""
+ suffix = ""
+ if channel is not None:
+ suffix += f"_C{channel}"
+ if z is not None:
+ suffix += f"_Z{z}"
+ if t is not None:
+ suffix += f"_T{t}"
+
+ output_filename = os.path.join(folder, f"{base_name}{suffix}.tif")
+ tifffile.imwrite(output_filename, image.astype(dtype))
+
+ # Function to handle splitting of multi-dimensional images into grayscale channels
+ def split_channels(image, folder, base_name, dtype):
+ """Splits the image into channels and handles 3D, 4D, and 5D image cases."""
+ if image.ndim == 2:
+ # Grayscale image, already processed separately
+ return
+
+ elif image.ndim == 3:
+ # 3D image: (height, width, channels)
+ for c in range(image.shape[2]):
+ save_grayscale_images(image[..., c], base_name, folder, dtype, channel=c+1)
+
+ elif image.ndim == 4:
+ # 4D image: (height, width, channels, Z-dimension)
+ for z in range(image.shape[3]):
+ for c in range(image.shape[2]):
+ save_grayscale_images(image[..., c, z], base_name, folder, dtype, channel=c+1, z=z+1)
+
+ elif image.ndim == 5:
+ # 5D image: (height, width, channels, Z-dimension, Time)
+ for t in range(image.shape[4]):
+ for z in range(image.shape[3]):
+ for c in range(image.shape[2]):
+ save_grayscale_images(image[..., c, z, t], base_name, folder, dtype, channel=c+1, z=z+1, t=t+1)
+
+ # Function to load images in various formats
+ def load_image(file_path):
+ """Loads image from various formats and returns it as a numpy array along with its dtype."""
+ ext = os.path.splitext(file_path)[1].lower()
+
+ if ext in ['.tif', '.tiff']:
+ image = tifffile.imread(file_path)
+ return image, image.dtype
+
+ elif ext in ['.png', '.jpg', '.jpeg']:
+ image = Image.open(file_path)
+ return np.array(image), image.mode
+
+ elif ext == '.czi':
+ with czifile.CziFile(file_path) as czi:
+ image = czi.asarray()
+ return image, image.dtype
+
+ elif ext == '.nd2':
+ with ND2Reader(file_path) as nd2:
+ image = np.array(nd2)
+ return image, image.dtype
+
+ else:
+ raise ValueError(f"Unsupported file extension: {ext}")
+
+ # Function to check if an image is grayscale and save it as a TIFF if it isn't already
+ def convert_grayscale_to_tiff(image, filename, folder, dtype):
+ """Convert grayscale images that are not in TIFF format to TIFF, preserving bit depth."""
+ base_name = os.path.splitext(filename)[0]
+ output_filename = os.path.join(folder, f"{base_name}.tif")
+ tifffile.imwrite(output_filename, image.astype(dtype))
+ print(f"Converted grayscale image {filename} to TIFF with bit depth {dtype}.")
+
+ # Supported formats
+ supported_formats = ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.czi', '.nd2']
+
+ # Loop through all files in the folder
+ for filename in os.listdir(folder):
+ file_path = os.path.join(folder, filename)
+ ext = os.path.splitext(file_path)[1].lower()
+
+ if ext in supported_formats:
+ print(f"Processing {filename}")
+ try:
+ # Load the image and its dtype
+ image, dtype = load_image(file_path)
+
+ # If the image is grayscale (2D), convert it to TIFF if it's not already in TIFF format
+ if image.ndim == 2:
+ if ext not in ['.tif', '.tiff']:
+ convert_grayscale_to_tiff(image, filename, folder, dtype)
+ else:
+ print(f"Image {filename} is already grayscale and in TIFF format, skipping.")
+ continue
+
+ # Otherwise, split channels and save images
+ base_name = os.path.splitext(filename)[0]
+ split_channels(image, folder, base_name, dtype)
+
+ except Exception as e:
+ print(f"Error processing {filename}: {str(e)}")

  def _load_images_and_labels(image_files, label_files, circular=False, invert=False):

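The new process_non_tif_non_2D_images helper added above converts mixed-format microscopy images (.tif/.tiff, .png/.jpg/.jpeg, .czi, .nd2) into per-channel grayscale TIFFs while preserving bit depth. A minimal usage sketch, assuming the function is imported from spacr.io and the folder path is hypothetical:

from spacr.io import process_non_tif_non_2D_images

# Writes one grayscale .tif per channel/Z/T back into the same folder
process_non_tif_non_2D_images("/data/experiment1/raw_images")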
@@ -632,6 +735,20 @@ class TarImageDataset(Dataset):
  img = self.transform(img)

  return img, m.name
+
+ def load_images_from_paths(images_by_key):
+ images_dict = {}
+
+ for key, paths in images_by_key.items():
+ images_dict[key] = []
+ for path in paths:
+ try:
+ with Image.open(path) as img:
+ images_dict[key].append(np.array(img))
+ except Exception as e:
+ print(f"Error loading image from {path}: {e}")
+
+ return images_dict

  #@log_function_call
  def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
@@ -657,15 +774,20 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
  files_processed = 0
  if not os.path.exists(stack_path) or (os.path.isdir(stack_path) and len(os.listdir(stack_path)) == 0):
  all_filenames = [filename for filename in os.listdir(src) if filename.endswith(img_format)]
- print(f'All_files: {len(all_filenames)} in {src}')
+ print(f'All files: {len(all_filenames)} in {src}')
  time_ls = []
-
- for idx in range(0, len(all_filenames), batch_size):
+ image_paths_by_key = _extract_filename_metadata(all_filenames, src, regular_expression, metadata_type, pick_slice, skip_mode)
+ # Convert dictionary keys to a list for batching
+ batching_keys = list(image_paths_by_key.keys())
+ print(f'All unique FOV: {len(image_paths_by_key)} in {src}')
+ for idx in range(0, len(image_paths_by_key), batch_size):
  start = time.time()
- batch_filenames = all_filenames[idx:idx+batch_size]
- for filename in batch_filenames:
- images_by_key = _extract_filename_metadata(batch_filenames, src, regular_expression, metadata_type, pick_slice, skip_mode)
-
+
+ # Select batch keys and create a subset of the dictionary for this batch
+ batch_keys = batching_keys[idx:idx+batch_size]
+ batch_images_by_key = {key: image_paths_by_key[key] for key in batch_keys}
+ images_by_key = load_images_from_paths(batch_images_by_key)
+
  if pick_slice:
  for i, key in enumerate(images_by_key):
  plate, well, field, channel, mode = key
@@ -682,10 +804,10 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
  files_to_process = len(all_filenames)
  print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')

- #if os.path.exists(output_path):
- # print(f'WARNING: A file with the same name already exists at location {output_filename}')
  if not os.path.exists(output_path):
  mip_image.save(output_path)
+ else:
+ print(f'WARNING: A file with the same name already exists at location {output_filename}')
  else:
  for i, (key, images) in enumerate(images_by_key.items()):
  plate, well, field, channel = key[:4]
@@ -702,10 +824,11 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
  files_to_process = len(all_filenames)
  print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')

- #if os.path.exists(output_path):
- # print(f'WARNING: A file with the same name already exists at location {output_filename}')
  if not os.path.exists(output_path):
  mip_image.save(output_path)
+ else:
+ print(f'WARNING: A file with the same name already exists at location {output_filename}')
+
  images_by_key.clear()

  # Move original images to a new directory
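The reworked preprocessing above extracts filename metadata once, groups image paths by FOV key, and only loads one batch of images at a time via load_images_from_paths. A standalone sketch of that batching pattern, with hypothetical dictionary contents:

# Hypothetical mapping of (plate, well, field, channel) keys to file paths
image_paths_by_key = {('p1', 'A01', 'f01', '01'): ['/tmp/a.tif', '/tmp/b.tif']}
batch_size = 100
batching_keys = list(image_paths_by_key.keys())
for idx in range(0, len(batching_keys), batch_size):
    batch_keys = batching_keys[idx:idx + batch_size]
    batch_images_by_key = {key: image_paths_by_key[key] for key in batch_keys}
    # images for this batch would then be loaded with load_images_from_paths(batch_images_by_key)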
@@ -862,47 +985,6 @@ def _move_to_chan_folder(src, regex, timelapse=False, metadata_type=''):
  shutil.move(os.path.join(src, filename), move)
  return

- def _merge_channels_v2(src, plot=False):
- from .plot import plot_arrays
- """
- Merge the channels in the given source directory and save the merged files in a 'stack' directory.
-
- Args:
- src (str): The path to the source directory containing the channel folders.
- plot (bool, optional): Whether to plot the merged arrays. Defaults to False.
-
- Returns:
- None
- """
- src = Path(src)
- stack_dir = src / 'stack'
- chan_dirs = [d for d in src.iterdir() if d.is_dir() and d.name in ['01', '02', '03', '04', '00', '1', '2', '3', '4','0']]
-
- chan_dirs.sort(key=lambda x: x.name)
- print(f'List of folders in src: {[d.name for d in chan_dirs]}. Single channel folders.')
- start_time = time.time()
-
- # First directory and its files
- dir_files = list(chan_dirs[0].iterdir())
-
- # Create the 'stack' directory if it doesn't exist
- stack_dir.mkdir(exist_ok=True)
- print(f'generated folder with merged arrays: {stack_dir}')
-
- if _is_dir_empty(stack_dir):
- with Pool(max(cpu_count() // 2, 1)) as pool:
- #with Pool(cpu_count()) as pool:
- merge_func = partial(_merge_file, chan_dirs, stack_dir)
- pool.map(merge_func, dir_files)
-
- avg_time = (time.time() - start_time) / len(dir_files)
- print(f'Average Time: {avg_time:.3f} sec')
-
- if plot:
- plot_arrays(src+'/stack')
-
- return
-
  def _merge_channels(src, plot=False):
  """
  Merge the channels in the given source directory and save the merged files in a 'stack' directory without using multiprocessing.
@@ -961,9 +1043,7 @@ def _mip_all(src, include_first_chan=True):
  Returns:
  None
  """
-
- from .utils import normalize_to_dtype
-
+
  #print('========== generating MIPs ==========')
  # Iterate over each file in the specified directory (src).
  for filename in os.listdir(src):
@@ -1337,7 +1417,6 @@ def _get_lists_for_normalization(settings):
  return backgrounds, signal_to_noise, signal_thresholds, remove_background

  def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
- from .utils import print_progress
  """
  Normalize the stack of images.

@@ -1430,7 +1509,6 @@ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False
  return print(f'Saved stacks: {output_fldr}')

  def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):
- from .utils import print_progress
  """
  Normalize the timelapse data by rescaling the intensity values based on percentiles.

@@ -1559,7 +1637,7 @@ def delete_empty_subdirectories(folder_path):
  #@log_function_call
  def preprocess_img_data(settings):

- from .plot import plot_arrays, _plot_4D_arrays
+ from .plot import plot_arrays
  from .utils import _run_test_mode, _get_regex
  from .settings import set_default_settings_preprocess_img_data

@@ -2054,7 +2132,6 @@ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_di
  padded_shapes = [shape + (0,) * (max_tuple_length - len(shape)) for shape in unique_shapes]
  # Now create a NumPy array and find the maximum dimensions
  max_dims = np.max(np.array(padded_shapes), axis=0)
- #clear_output(wait=True)
  print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}')
  #print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}', end='\r', flush=True)
  padded_stack_ls = []
@@ -2102,7 +2179,7 @@ def _read_db(db_loc, tables):
  conn.close()
  return dfs

- def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=False, include_multiinfected=False, include_noninfected=False):
+ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
  """
  Read and merge data from SQLite databases and perform data preprocessing.

@@ -2110,9 +2187,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
  - locs (list): A list of file paths to the SQLite database files.
  - tables (list): A list of table names to read from the databases.
  - verbose (bool): Whether to print verbose output. Default is False.
- - include_multinucleated (bool): Whether to include multinucleated cells. Default is False.
- - include_multiinfected (bool): Whether to include cells with multiple infections. Default is False.
- - include_noninfected (bool): Whether to include non-infected cells. Default is False.
+ - nuclei_limit (bool): Whether to include multinucleated cells. Default is False.
+ - pathogen_limit (bool): Whether to include cells with multiple infections. Default is False.
+ - uninfected (bool): Whether to include non-infected cells. Default is False.

  Returns:
  - merged_df (pandas.DataFrame): The merged and preprocessed dataframe.
@@ -2187,7 +2264,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
  nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
  nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
  nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
- if include_multinucleated == False:
+ if nuclei_limit == False:
  #nucleus = nucleus[~nucleus['prcfo'].duplicated()]
  nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
  nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
@@ -2203,9 +2280,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
  pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
  pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
  pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
- if include_noninfected == False:
+ if uninfected == False:
  pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
- if include_multiinfected == False:
+ if pathogen_limit == False:
  pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
  pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
  print(f'pathogens: {len(pathogens)}')
@@ -2267,12 +2344,8 @@ def _results_to_csv(src, df, df_well):
  wells.to_csv(wells_loc, index=True, header=True)
  cells.to_csv(cells_loc, index=True, header=True)
  return cells, wells
-
- ###################################################
- # Classify
- ###################################################

- def read_plot_model_stats(file_path ,save=False):
+ def read_plot_model_stats(train_file_path, val_file_path ,save=False):

  def _plot_and_save(train_df, val_df, column='accuracy', save=False, path=None, dpi=600):

@@ -2301,37 +2374,19 @@ def read_plot_model_stats(file_path ,save=False):
  plt.savefig(pdf_path, format='pdf', dpi=dpi)
  else:
  plt.show()
- # Read the CSV into a dataframe
- df = pd.read_csv(file_path, index_col=0)
-
- # Split the dataframe into train and validation based on the index
- train_df = df.filter(like='_train', axis=0).copy()
- val_df = df.filter(like='_val', axis=0).copy()
-
- fldr_1 = os.path.dirname(file_path)
-
- train_csv_path = os.path.join(fldr_1, 'train.csv')
- val_csv_path = os.path.join(fldr_1, 'validation.csv')

- fldr_2 = os.path.dirname(fldr_1)
- fldr_3 = os.path.dirname(fldr_2)
- bn_1 = os.path.basename(fldr_1)
- bn_2 = os.path.basename(fldr_2)
- bn_3 = os.path.basename(fldr_3)
- model_name = str(f'{bn_1}_{bn_2}_{bn_3}')
+ # Read the CSVs into DataFrames
+ train_df = pd.read_csv(train_file_path, index_col=0)
+ val_df = pd.read_csv(val_file_path, index_col=0)

- # Extract epochs from index
- train_df['epoch'] = [int(idx.split('_')[0]) for idx in train_df.index]
- val_df['epoch'] = [int(idx.split('_')[0]) for idx in val_df.index]
-
- # Save dataframes to a CSV file
- train_df.to_csv(train_csv_path)
- val_df.to_csv(val_csv_path)
+ # Get the folder path for saving plots
+ fldr_1 = os.path.dirname(train_file_path)

  if save:
  # Setting the style
  sns.set(style="whitegrid")

+ # Plot and save the results
  _plot_and_save(train_df, val_df, column='accuracy', save=save, path=fldr_1)
  _plot_and_save(train_df, val_df, column='neg_accuracy', save=save, path=fldr_1)
  _plot_and_save(train_df, val_df, column='pos_accuracy', save=save, path=fldr_1)
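read_plot_model_stats now reads separate per-epoch train and validation CSVs instead of splitting a single file by index suffix. A hedged call sketch (the paths are hypothetical and assume the train.csv/validation.csv written by _save_progress below):

from spacr.io import read_plot_model_stats

# Plots the accuracy/loss-style curves recorded during training
read_plot_model_stats('/models/run1/train.csv', '/models/run1/validation.csv', save=True)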
@@ -2379,50 +2434,53 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_

  return model_path

- def _save_progress(dst, results_df, result_type='train'):
+ def _save_progress(dst, train_df, validation_df):
  """
  Save the progress of the classification model.

  Parameters:
  dst (str): The destination directory to save the progress.
- results_df (pandas.DataFrame): The DataFrame containing accuracy, loss, and PRAUC.
- train_metrics_df (pandas.DataFrame): The DataFrame containing training metrics.
+ train_df (pandas.DataFrame): The DataFrame containing training stats.
+ validation_df (pandas.DataFrame): The DataFrame containing validation stats (if available).

  Returns:
  None
  """
+
+ def _save_df_to_csv(file_path, df):
+ """
+ Save the given DataFrame to the specified CSV file, either creating a new file or appending to an existing one.
+
+ Parameters:
+ file_path (str): The file path where the CSV will be saved.
+ df (pandas.DataFrame): The DataFrame to save.
+ """
+ if not os.path.exists(file_path):
+ with open(file_path, 'w') as f:
+ df.to_csv(f, index=True, header=True)
+ f.flush() # Ensure data is written to the file system
+ else:
+ with open(file_path, 'a') as f:
+ df.to_csv(f, index=True, header=False)
+ f.flush()
+
  # Save accuracy, loss, PRAUC
  os.makedirs(dst, exist_ok=True)
- results_path = os.path.join(dst, f'{result_type}.csv')
- if not os.path.exists(results_path):
- results_df.to_csv(results_path, index=True, header=True, mode='w')
- else:
- results_df.to_csv(results_path, index=True, header=False, mode='a')
+ results_path_train = os.path.join(dst, 'train.csv')
+ results_path_validation = os.path.join(dst, 'validation.csv')

- if result_type == 'train':
- read_plot_model_stats(results_path, save=True)
- return
+ # Save training data
+ _save_df_to_csv(results_path_train, train_df)

- def _save_settings(settings, src):
- """
- Save the settings dictionary to a CSV file.
+ # Save validation data if available
+ if validation_df is not None:
+ _save_df_to_csv(results_path_validation, validation_df)

- Parameters:
- - settings (dict): A dictionary containing the settings.
- - src (str): The source directory where the settings file will be saved.
+ # Call read_plot_model_stats after ensuring the files are saved
+ read_plot_model_stats(results_path_train, results_path_validation, save=True)

- Returns:
- None
- """
- dst = os.path.join(src,'model')
- settings_loc = os.path.join(dst,'settings.csv')
- os.makedirs(dst, exist_ok=True)
- settings_df = pd.DataFrame(list(settings.items()), columns=['setting_key', 'setting_value'])
- display(settings_df)
- settings_df.to_csv(settings_loc, index=False)
  return

-
  def _copy_missclassified(df):
  misclassified = df[df['true_label'] != df['predicted_label']]
  for _, row in misclassified.iterrows():
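_save_progress now takes separate train and validation DataFrames and appends them to train.csv and validation.csv before re-plotting via read_plot_model_stats. A sketch of the call shape only, with hypothetical one-epoch frames (real frames carry the full set of metric columns that read_plot_model_stats plots, and the '<epoch>_train'/'<epoch>_val' index convention is an assumption):

import pandas as pd
from spacr.io import _save_progress  # internal helper

train_df = pd.DataFrame({'accuracy': [0.81], 'loss': [0.42]}, index=['1_train'])  # hypothetical stats
val_df = pd.DataFrame({'accuracy': [0.78], 'loss': [0.47]}, index=['1_val'])
_save_progress('/models/run1', train_df, val_df)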
@@ -2448,7 +2506,7 @@ def _read_db(db_loc, tables):
  conn.close() # Close the connection
  return dfs

- def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=False, include_multiinfected=False, include_noninfected=False):
+ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):

  from .utils import _split_data

@@ -2533,7 +2591,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
  nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
  nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
  nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
- if include_multinucleated == False:
+ if nuclei_limit == False:
  nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
  nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
  if verbose:
@@ -2559,20 +2617,30 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
  pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
  pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
  pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
- if include_noninfected == False:
+
+ print(f"before noninfected: {len(pathogens)}")
+ if uninfected == False:
  pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
- if isinstance(include_multiinfected, bool):
- if include_multiinfected == False:
+ print(f"after noninfected: {len(pathogens)}")
+
+ if isinstance(pathogen_limit, bool):
+ if pathogen_limit == False:
  pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
- if isinstance(include_multiinfected, float):
- pathogens = pathogens[pathogens['pathogen_prcfo_count']<=include_multiinfected]
+ print(f"after multiinfected Bool: {len(pathogens)}")
+ if isinstance(pathogen_limit, float):
+ pathogen_limit = int(pathogen_limit)
+ if isinstance(pathogen_limit, int):
+ pathogens = pathogens[pathogens['pathogen_prcfo_count']<=pathogen_limit]
+ print(f"afer multiinfected Float: {len(pathogens)}")
  if not 'cell' in tables:
  pathogens_g_df, metadata = _split_data(pathogens, 'prcfo', 'cell_id')
  else:
  pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
+
  if verbose:
  print(f'pathogens: {len(pathogens)}')
  print(f'pathogens grouped: {len(pathogens_g_df)}')
+
  if len(merged_df) == 0:
  merged_df = pathogens_g_df
  else:
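The renamed pathogen_limit argument now accepts either a bool (False keeps only singly infected cells) or a numeric cap on pathogen objects per cell, with floats cast to int. A standalone sketch of that filtering on a hypothetical DataFrame (the else-guard around the numeric branch is added here for clarity, since a bool would otherwise also match isinstance(..., int)):

import pandas as pd

pathogens = pd.DataFrame({'pathogen_prcfo_count': [1, 2, 3, 5]})  # hypothetical per-cell counts
pathogen_limit = 2.0
if isinstance(pathogen_limit, bool):
    if pathogen_limit is False:
        pathogens = pathogens[pathogens['pathogen_prcfo_count'] <= 1]
else:
    if isinstance(pathogen_limit, float):
        pathogen_limit = int(pathogen_limit)
    pathogens = pathogens[pathogens['pathogen_prcfo_count'] <= pathogen_limit]  # keeps counts 1 and 2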
@@ -2697,4 +2765,475 @@ def generate_cellpose_train_test(src, test_split=0.1):
  shutil.copy(img_path, new_img_path)
  shutil.copy(mask_path, new_mask_path)
  print(f'Copied {idx+1}/{len(ls)} images to {_type} set')#, end='\r', flush=True)
-
+
+ def parse_gz_files(folder_path):
+ """
+ Parses the .fastq.gz files in the specified folder path and returns a dictionary
+ containing the sample names and their corresponding file paths.
+
+ Args:
+ folder_path (str): The path to the folder containing the .fastq.gz files.
+
+ Returns:
+ dict: A dictionary where the keys are the sample names and the values are
+ dictionaries containing the file paths for the 'R1' and 'R2' read directions.
+ """
+ files = os.listdir(folder_path)
+ gz_files = [f for f in files if f.endswith('.fastq.gz')]
+
+ samples_dict = {}
+ for gz_file in gz_files:
+ parts = gz_file.split('_')
+ sample_name = parts[0]
+ read_direction = parts[1]
+
+ if sample_name not in samples_dict:
+ samples_dict[sample_name] = {}
+
+ if read_direction == "R1":
+ samples_dict[sample_name]['R1'] = os.path.join(folder_path, gz_file)
+ elif read_direction == "R2":
+ samples_dict[sample_name]['R2'] = os.path.join(folder_path, gz_file)
+ return samples_dict
+
+ def generate_dataset(settings={}):
+
+ from .utils import initiate_counter, add_images_to_tar, save_settings, generate_path_list_from_db, correct_paths
+ from .settings import set_generate_dataset_defaults
+
+ settings = set_generate_dataset_defaults(settings)
+ save_settings(settings, 'generate_dataset', show=True)
+
+ if isinstance(settings['src'], str):
+ settings['src'] = [settings['src']]
+ if isinstance(settings['src'], list):
+ all_paths = []
+ for i, src in enumerate(settings['src']):
+ db_path = os.path.join(src, 'measurements', 'measurements.db')
+ if i == 0:
+ dst = os.path.join(src, 'datasets')
+ paths = generate_path_list_from_db(db_path, file_metadata=settings['file_metadata'])
+ correct_paths(paths, src)
+ all_paths.extend(paths)
+ if isinstance(settings['sample'], int):
+ selected_paths = random.sample(all_paths, settings['sample'])
+ print(f"Random selection of {len(selected_paths)} paths")
+ elif isinstance(settings['sample'], list):
+ sample = settings['sample'][i]
+ selected_paths = random.sample(all_paths, settings['sample'])
+ print(f"Random selection of {len(selected_paths)} paths")
+ else:
+ selected_paths = all_paths
+ random.shuffle(selected_paths)
+ print(f"All paths: {len(selected_paths)} paths")
+
+ total_images = len(selected_paths)
+ print(f"Found {total_images} images")
+
+ # Create a temp folder in dst
+ temp_dir = os.path.join(dst, "temp_tars")
+ os.makedirs(temp_dir, exist_ok=True)
+
+ # Chunking the data
+ num_procs = max(2, cpu_count() - 2)
+ chunk_size = len(selected_paths) // num_procs
+ remainder = len(selected_paths) % num_procs
+
+ paths_chunks = []
+ start = 0
+ for i in range(num_procs):
+ end = start + chunk_size + (1 if i < remainder else 0)
+ paths_chunks.append(selected_paths[start:end])
+ start = end
+
+ temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]
+
+ print(f"Generating temporary tar files in {dst}")
+
+ # Initialize shared counter and lock
+ counter = Value('i', 0)
+ lock = Lock()
+
+ with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
+ pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
+
+ # Combine the temporary tar files into a final tar
+ date_name = datetime.date.today().strftime('%y%m%d')
+ if len(settings['src']) > 1:
+ date_name = f"{date_name}_combined"
+ if not settings['file_metadata'] is None:
+ tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
+ else:
+ tar_name = f"{date_name}_{settings['experiment']}.tar"
+ tar_name = os.path.join(dst, tar_name)
+ if os.path.exists(tar_name):
+ number = random.randint(1, 100)
+ tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
+ print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
+ tar_name = os.path.join(dst, tar_name_2)
+
+ print(f"Merging temporary files")
+
+ with tarfile.open(tar_name, 'w') as final_tar:
+ for temp_tar_path in temp_tar_files:
+ with tarfile.open(temp_tar_path, 'r') as temp_tar:
+ for member in temp_tar.getmembers():
+ file_obj = temp_tar.extractfile(member)
+ final_tar.addfile(member, file_obj)
+ os.remove(temp_tar_path)
+
+ # Delete the temp folder
+ shutil.rmtree(temp_dir)
+ print(f"\nSaved {total_images} images to {tar_name}")
+
+ return tar_name
+
+ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
+
+ """
+ Generate data loaders for training and validation/test datasets.
+
+ Parameters:
+ - src (str): The source directory containing the data.
+ - mode (str): The mode of operation. Options are 'train' or 'test'.
+ - image_size (int): The size of the input images.
+ - batch_size (int): The batch size for the data loaders.
+ - classes (list): The list of classes to consider.
+ - n_jobs (int): The number of worker threads for data loading.
+ - validation_split (float): The fraction of data to use for validation.
+ - pin_memory (bool): Whether to pin memory for faster data transfer.
+ - normalize (bool): Whether to normalize the input images.
+ - verbose (bool): Whether to print additional information and show images.
+ - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
+
+ Returns:
+ - train_loaders (list): List of data loaders for training datasets.
+ - val_loaders (list): List of data loaders for validation datasets.
+ """
+
+ from .utils import SelectChannels, augment_dataset
+
+ chans = []
+
+ if 'r' in channels:
+ chans.append(1)
+ if 'g' in channels:
+ chans.append(2)
+ if 'b' in channels:
+ chans.append(3)
+
+ channels = chans
+
+ if verbose:
+ print(f'Training a network on channels: {channels}')
+ print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
+
+ train_loaders = []
+ val_loaders = []
+
+ if normalize:
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.CenterCrop(size=(image_size, image_size)),
+ SelectChannels(channels),
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
+ else:
+ transform = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.CenterCrop(size=(image_size, image_size)),
+ SelectChannels(channels)])
+
+ if mode == 'train':
+ data_dir = os.path.join(src, 'train')
+ shuffle = True
+ print('Loading Train and validation datasets')
+ elif mode == 'test':
+ data_dir = os.path.join(src, 'test')
+ val_loaders = []
+ validation_split = 0.0
+ shuffle = True
+ print('Loading test dataset')
+ else:
+ print(f'mode:{mode} is not valid, use mode = train or test')
+ return
+
+ data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
+ num_workers = n_jobs if n_jobs is not None else 0
+
+ if validation_split > 0:
+ train_size = int((1 - validation_split) * len(data))
+ val_size = len(data) - train_size
+ if not augment:
+ print(f'Train data:{train_size}, Validation data:{val_size}')
+ train_dataset, val_dataset = random_split(data, [train_size, val_size])
+
+ if augment:
+
+ print(f'Data before augmentation: Train: {len(train_dataset)}, Validataion:{len(val_dataset)}')
+ train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
+ print(f'Data after augmentation: Train: {len(train_dataset)}')
+
+ print(f'Generating Dataloader with {n_jobs} workers')
+ train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+ else:
+ train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
+
+ #dataset (Dataset) – dataset from which to load the data.
+ #batch_size (int, optional) – how many samples per batch to load (default: 1).
+ #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
+ #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
+ #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
+ #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
+ #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
+ #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
+ #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
+ #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
+ #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
+ #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
+ #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
+ #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
+ #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
+ #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
+
+ #images, labels, filenames = next(iter(train_loaders))
+ #images = images.cpu()
+ #label_strings = [str(label.item()) for label in labels]
+ #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
+ #if verbose:
+ # plt.show()
+
+ train_fig = None
+
+ return train_loaders, val_loaders, train_fig
+
+ def generate_training_dataset(settings):
+
+ # Function to filter png_list_df by prcfo present in df without merging
+ def filter_png_list(db_path, settings):
+ tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
+ df, _ = _read_and_merge_data(locs=[db_path],
+ tables=tables,
+ verbose=False,
+ nuclei_limit=settings['nuclei_limit'],
+ pathogen_limit=settings['pathogen_limit'],
+ uninfected=settings['uninfected'])
+ [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
+ filtered_png_list_df = png_list_df[png_list_df['prcfo'].isin(df.index)]
+ return filtered_png_list_df
+
+ # Function to get the smallest class size based on the dataset mode
+ def get_smallest_class_size(df, settings, dataset_mode):
+ if dataset_mode == 'metadata':
+ sizes = [len(df[df['metadata_based_class'] == c]) for c in settings['classes']]
+ elif dataset_mode == 'annotation':
+ sizes = [len(class_paths) for class_paths in df]
+ size = min(sizes)
+ print(f'Using the smallest class size: {size}')
+ return size
+
+ # Measurement-based selection logic
+ def measurement_based_selection(settings, db_path):
+ class_paths_ls = []
+ tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
+ df, _ = _read_and_merge_data(locs=[db_path],
+ tables=tables,
+ verbose=False,
+ nuclei_limit=settings['nuclei_limit'],
+ pathogen_limit=settings['pathogen_limit'],
+ uninfected=settings['uninfected'])
+
+ print('length df 1', len(df))
+ df = annotate_conditions(df, cells=['HeLa'], pathogens=['pathogen'], treatments=settings['classes'],
+ treatment_loc=settings['class_metadata'])#, types=settings['metadata_type_by'])
+ print('length df 2', len(df))
+
+ png_list_df = filter_png_list(db_path, settings)
+
+ if settings['custom_measurement']:
+ if isinstance(settings['custom_measurement'], list):
+ if len(settings['custom_measurement']) == 2:
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}"] / df[f"{settings['custom_measurement'][1]}"]
+ else:
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]
+ else:
+ print("custom_measurement should be a list.")
+ return
+
+ else:
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"] / df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
+
+ q25 = df['recruitment'].quantile(0.25)
+ q75 = df['recruitment'].quantile(0.75)
+ df_lower = df[df['recruitment'] <= q25]
+ df_upper = df[df['recruitment'] >= q75]
+
+ class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])
+ class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), settings['size'])
+ class_paths_ls.append(class_paths_lower)
+
+ class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
+ class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), settings['size'])
+ class_paths_ls.append(class_paths_upper)
+
+ return class_paths_ls
+
+ # Metadata-based selection logic
+ def metadata_based_selection(db_path, settings):
+ class_paths_ls = []
+ df = filter_png_list(db_path, settings)
+
+ df['metadata_based_class'] = pd.NA
+ for i, class_ in enumerate(settings['classes']):
+ ls = settings['class_metadata'][i]
+ df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_
+
+ size = get_smallest_class_size(df, settings, 'metadata')
+ for class_ in settings['classes']:
+ class_temp_df = df[df['metadata_based_class'] == class_]
+ print(f'Found {len(class_temp_df)} images for class {class_}')
+ class_paths_temp = class_temp_df['png_path'].tolist()
+
+ # Ensure to sample `size` number of images (smallest class size)
+ if len(class_paths_temp) > size:
+ class_paths_temp = random.sample(class_paths_temp, size)
+
+ class_paths_ls.append(class_paths_temp)
+
+ return class_paths_ls
+
+ # Annotation-based selection logic
+ def annotation_based_selection(db_path, dst, settings):
+ class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
+
+ size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
+ for i, class_paths in enumerate(class_paths_ls):
+ if len(class_paths) > size:
+ class_paths_ls[i] = random.sample(class_paths, size)
+
+ return class_paths_ls
+
+ from .io import _read_and_merge_data, _read_db
+ from .utils import get_paths_from_db, annotate_conditions, save_settings
+ from .settings import set_generate_training_dataset_defaults
+
+ # Set default settings and save
+ settings = set_generate_training_dataset_defaults(settings)
+ save_settings(settings, 'cv_dataset', show=True)
+
+ class_path_list = None
+
+ if isinstance(settings['src'], str):
+ src = [settings['src']]
+
+ for i, src in enumerate(settings['src']):
+ db_path = os.path.join(src, 'measurements', 'measurements.db')
+
+ if len(settings['src']) > 1 and i == 0:
+ dst = os.path.join(src, 'datasets', 'training_all')
+ elif len(settings['src']) == 1:
+ dst = os.path.join(src, 'datasets', 'training')
+
+ # Create a new directory for training data if necessary
+ if os.path.exists(dst):
+ for i in range(1, 100000):
+ dst = dst + f'_{i}'
+ if not os.path.exists(dst):
+ print(f'Creating new directory for training: {dst}')
+ break
+
+ # Select dataset based on dataset mode
+ if settings['dataset_mode'] == 'annotation':
+ class_paths_ls = annotation_based_selection(db_path, dst, settings)
+
+ elif settings['dataset_mode'] == 'metadata':
+ class_paths_ls = metadata_based_selection(db_path, settings)
+
+ elif settings['dataset_mode'] == 'measurement':
+ class_paths_ls = measurement_based_selection(settings, db_path)
+
+ if class_path_list is None:
+ class_path_list = [[] for _ in range(len(class_paths_ls))]
+
+ # Extend each list in class_path_list with the corresponding list from class_paths_ls
+ for idx in range(len(class_paths_ls)):
+ class_path_list[idx].extend(class_paths_ls[idx])
+
+ # Generate and return training and testing directories
+ train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_path_list, classes=settings['classes'], test_split=settings['test_split'])
+
+ return train_class_dir, test_class_dir
+
+ def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
+ all_paths = []
+
+ # Connect to the database and retrieve the image paths and annotations
+ print(f'Reading DataBase: {db_path}')
+ with sqlite3.connect(db_path) as conn:
+ cursor = conn.cursor()
+ # Prepare the query with parameterized placeholders for annotated_classes
+ placeholders = ','.join('?' * len(annotated_classes))
+ query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
+ cursor.execute(query, annotated_classes)
+
+ while True:
+ rows = cursor.fetchmany(1000)
+ if not rows:
+ break
+ for row in rows:
+ all_paths.append(row)
+
+ # Filter paths based on annotation
+ class_paths = []
+ for class_ in annotated_classes:
+ class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
+ class_paths.append(class_paths_temp)
+
+ print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
+ return class_paths
+
+ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
+ from .utils import print_progress
+ # Make sure that the length of class_data matches the length of classes
+ if len(class_data) != len(classes):
+ raise ValueError("class_data and classes must have the same length.")
+
+ total_files = sum(len(data) for data in class_data)
+ processed_files = 0
+ time_ls = []
+
+ for cls, data in zip(classes, class_data):
+ # Create directories
+ train_class_dir = os.path.join(dst, f'train/{cls}')
+ test_class_dir = os.path.join(dst, f'test/{cls}')
+ os.makedirs(train_class_dir, exist_ok=True)
+ os.makedirs(test_class_dir, exist_ok=True)
+
+ # Split the data
+ train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
+
+ # Copy train files
+ for path in train_data:
+ start = time.time()
+ shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
+ duration = time.time() - start
+ time_ls.append(duration)
+ print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
+ processed_files += 1
+
+ # Copy test files
+ for path in test_data:
+ start = time.time()
+ shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
+ duration = time.time() - start
+ time_ls.append(duration)
+ print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
+ processed_files += 1
+
+ # Print summary
+ for cls in classes:
+ train_class_dir = os.path.join(dst, f'train/{cls}')
+ test_class_dir = os.path.join(dst, f'test/{cls}')
+ print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
+
+ return os.path.join(dst, 'train'), os.path.join(dst, 'test')
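Together, these additions give spacr.io an end-to-end path from annotated measurements to PyTorch DataLoaders. A hedged sketch of that pipeline, assuming a hypothetical screen folder and only the settings keys visible in the diff (the remaining keys are filled in by set_generate_training_dataset_defaults); src is passed as a list because generate_training_dataset iterates settings['src'] directly:

import os
from spacr.io import generate_training_dataset, generate_loaders

settings = {
    'src': ['/data/screen1'],          # hypothetical project folder with measurements/measurements.db
    'dataset_mode': 'annotation',      # 'annotation', 'metadata', or 'measurement'
    'annotation_column': 'test',
    'annotated_classes': (1, 2),
    'classes': ['nc', 'pc'],
    'test_split': 0.1,
}
train_dir, test_dir = generate_training_dataset(settings)

# Wrap the resulting train/test folders in DataLoaders; channels are given as 'r'/'g'/'b'
train_loaders, val_loaders, _ = generate_loaders(src=os.path.dirname(train_dir), mode='train',
                                                 image_size=224, batch_size=32, classes=['nc', 'pc'],
                                                 validation_split=0.1, channels=['r', 'g', 'b'])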