spacr 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (41)
  1. spacr/__init__.py +19 -3
  2. spacr/cellpose.py +311 -0
  3. spacr/core.py +140 -2493
  4. spacr/deep_spacr.py +151 -29
  5. spacr/gui.py +1 -0
  6. spacr/gui_core.py +74 -63
  7. spacr/gui_elements.py +110 -5
  8. spacr/gui_utils.py +346 -6
  9. spacr/io.py +624 -44
  10. spacr/logger.py +28 -9
  11. spacr/measure.py +107 -95
  12. spacr/mediar.py +0 -3
  13. spacr/ml.py +964 -0
  14. spacr/openai.py +37 -0
  15. spacr/plot.py +280 -15
  16. spacr/resources/data/lopit.csv +3833 -0
  17. spacr/resources/data/toxoplasma_metadata.csv +8843 -0
  18. spacr/resources/icons/convert.png +0 -0
  19. spacr/resources/{models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model → icons/dna_matrix.mp4} +0 -0
  20. spacr/sequencing.py +241 -1311
  21. spacr/settings.py +129 -43
  22. spacr/sim.py +0 -2
  23. spacr/submodules.py +348 -0
  24. spacr/timelapse.py +0 -2
  25. spacr/toxo.py +233 -0
  26. spacr/utils.py +271 -171
  27. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/METADATA +7 -1
  28. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/RECORD +32 -33
  29. spacr/chris.py +0 -50
  30. spacr/graph_learning.py +0 -340
  31. spacr/resources/MEDIAR/.git +0 -1
  32. spacr/resources/MEDIAR_weights/.DS_Store +0 -0
  33. spacr/resources/icons/.DS_Store +0 -0
  34. spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  35. spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -23
  36. spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  37. spacr/sim_app.py +0 -0
  38. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/LICENSE +0 -0
  39. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/WHEEL +0 -0
  40. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/entry_points.txt +0 -0
  41. {spacr-0.3.1.dist-info → spacr-0.3.2.dist-info}/top_level.txt +0 -0
spacr/io.py CHANGED
@@ -1,30 +1,132 @@
1
- import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
1
+ import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue, tifffile, czifile, atexit, datetime
2
2
  import numpy as np
3
3
  import pandas as pd
4
- import tifffile
5
4
  from PIL import Image, ImageOps
6
- from collections import defaultdict, Counter, deque
5
+ from collections import defaultdict, Counter
7
6
  from pathlib import Path
8
7
  from functools import partial
9
8
  from matplotlib.animation import FuncAnimation
10
9
  from IPython.display import display
11
10
  from skimage.util import img_as_uint
12
11
  from skimage.exposure import rescale_intensity
13
- from skimage import filters
14
12
  import skimage.measure as measure
15
13
  from skimage import exposure
16
14
  import imageio.v2 as imageio2
17
15
  import matplotlib.pyplot as plt
18
16
  from io import BytesIO
19
- from IPython.display import display, clear_output
20
- from multiprocessing import Pool, cpu_count, Process, Queue
21
- from torch.utils.data import Dataset, DataLoader
17
+ from IPython.display import display
18
+ from multiprocessing import Pool, cpu_count, Process, Queue, Value, Lock
19
+ from torch.utils.data import Dataset, DataLoader, random_split
22
20
  import matplotlib.pyplot as plt
23
21
  from torchvision.transforms import ToTensor
24
- import seaborn as sns
25
- import atexit
26
-
27
- from .logger import log_function_call
22
+ import seaborn as sns
23
+ from nd2reader import ND2Reader
24
+ from torchvision import transforms
25
+
26
+ def process_non_tif_non_2D_images(folder):
27
+ """Processes all images in the folder and splits them into grayscale channels, preserving bit depth."""
28
+
29
+ # Helper function to save grayscale images
30
+ def save_grayscale_images(image, base_name, folder, dtype, channel=None, z=None, t=None):
31
+ """Save grayscale images with appropriate suffix based on channel, z, and t, preserving bit depth."""
32
+ suffix = ""
33
+ if channel is not None:
34
+ suffix += f"_C{channel}"
35
+ if z is not None:
36
+ suffix += f"_Z{z}"
37
+ if t is not None:
38
+ suffix += f"_T{t}"
39
+
40
+ output_filename = os.path.join(folder, f"{base_name}{suffix}.tif")
41
+ tifffile.imwrite(output_filename, image.astype(dtype))
42
+
43
+ # Function to handle splitting of multi-dimensional images into grayscale channels
44
+ def split_channels(image, folder, base_name, dtype):
45
+ """Splits the image into channels and handles 3D, 4D, and 5D image cases."""
46
+ if image.ndim == 2:
47
+ # Grayscale image, already processed separately
48
+ return
49
+
50
+ elif image.ndim == 3:
51
+ # 3D image: (height, width, channels)
52
+ for c in range(image.shape[2]):
53
+ save_grayscale_images(image[..., c], base_name, folder, dtype, channel=c+1)
54
+
55
+ elif image.ndim == 4:
56
+ # 4D image: (height, width, channels, Z-dimension)
57
+ for z in range(image.shape[3]):
58
+ for c in range(image.shape[2]):
59
+ save_grayscale_images(image[..., c, z], base_name, folder, dtype, channel=c+1, z=z+1)
60
+
61
+ elif image.ndim == 5:
62
+ # 5D image: (height, width, channels, Z-dimension, Time)
63
+ for t in range(image.shape[4]):
64
+ for z in range(image.shape[3]):
65
+ for c in range(image.shape[2]):
66
+ save_grayscale_images(image[..., c, z, t], base_name, folder, dtype, channel=c+1, z=z+1, t=t+1)
67
+
68
+ # Function to load images in various formats
69
+ def load_image(file_path):
70
+ """Loads image from various formats and returns it as a numpy array along with its dtype."""
71
+ ext = os.path.splitext(file_path)[1].lower()
72
+
73
+ if ext in ['.tif', '.tiff']:
74
+ image = tifffile.imread(file_path)
75
+ return image, image.dtype
76
+
77
+ elif ext in ['.png', '.jpg', '.jpeg']:
78
+ image = Image.open(file_path)
79
+ return np.array(image), image.mode
80
+
81
+ elif ext == '.czi':
82
+ with czifile.CziFile(file_path) as czi:
83
+ image = czi.asarray()
84
+ return image, image.dtype
85
+
86
+ elif ext == '.nd2':
87
+ with ND2Reader(file_path) as nd2:
88
+ image = np.array(nd2)
89
+ return image, image.dtype
90
+
91
+ else:
92
+ raise ValueError(f"Unsupported file extension: {ext}")
93
+
94
+ # Function to check if an image is grayscale and save it as a TIFF if it isn't already
95
+ def convert_grayscale_to_tiff(image, filename, folder, dtype):
96
+ """Convert grayscale images that are not in TIFF format to TIFF, preserving bit depth."""
97
+ base_name = os.path.splitext(filename)[0]
98
+ output_filename = os.path.join(folder, f"{base_name}.tif")
99
+ tifffile.imwrite(output_filename, image.astype(dtype))
100
+ print(f"Converted grayscale image {filename} to TIFF with bit depth {dtype}.")
101
+
102
+ # Supported formats
103
+ supported_formats = ['.tif', '.tiff', '.png', '.jpg', '.jpeg', '.czi', '.nd2']
104
+
105
+ # Loop through all files in the folder
106
+ for filename in os.listdir(folder):
107
+ file_path = os.path.join(folder, filename)
108
+ ext = os.path.splitext(file_path)[1].lower()
109
+
110
+ if ext in supported_formats:
111
+ print(f"Processing {filename}")
112
+ try:
113
+ # Load the image and its dtype
114
+ image, dtype = load_image(file_path)
115
+
116
+ # If the image is grayscale (2D), convert it to TIFF if it's not already in TIFF format
117
+ if image.ndim == 2:
118
+ if ext not in ['.tif', '.tiff']:
119
+ convert_grayscale_to_tiff(image, filename, folder, dtype)
120
+ else:
121
+ print(f"Image {filename} is already grayscale and in TIFF format, skipping.")
122
+ continue
123
+
124
+ # Otherwise, split channels and save images
125
+ base_name = os.path.splitext(filename)[0]
126
+ split_channels(image, folder, base_name, dtype)
127
+
128
+ except Exception as e:
129
+ print(f"Error processing {filename}: {str(e)}")
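
A minimal usage sketch for the new helper above (the folder path is hypothetical; the function scans the folder in place and writes per-channel grayscale TIFFs next to the originals, keeping the source bit depth):

```python
# Illustrative only: split every supported image in a folder into grayscale
# per-channel TIFFs; .czi, .nd2, .png/.jpg and multi-dimensional .tif inputs
# are handled by the load_image/split_channels helpers defined above.
from spacr.io import process_non_tif_non_2D_images

process_non_tif_non_2D_images("/data/experiment1/images")  # hypothetical path
# For a 4-D array read as (height, width, channels, z), the outputs are named
# like sample_C1_Z1.tif, sample_C2_Z1.tif, ..., following save_grayscale_images.
```
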
28
130
 
29
131
  def _load_images_and_labels(image_files, label_files, circular=False, invert=False):
30
132
 
@@ -632,6 +734,20 @@ class TarImageDataset(Dataset):
632
734
  img = self.transform(img)
633
735
 
634
736
  return img, m.name
737
+
738
+ def load_images_from_paths(images_by_key):
739
+ images_dict = {}
740
+
741
+ for key, paths in images_by_key.items():
742
+ images_dict[key] = []
743
+ for path in paths:
744
+ try:
745
+ with Image.open(path) as img:
746
+ images_dict[key].append(np.array(img))
747
+ except Exception as e:
748
+ print(f"Error loading image from {path}: {e}")
749
+
750
+ return images_dict
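
For reference, a sketch of the dictionary shape `load_images_from_paths` works on, mirroring how `_rename_and_organize_image_files` calls it below (keys and paths are hypothetical):

```python
# Assumes the helper is importable from spacr.io as added in this version.
from spacr.io import load_images_from_paths

# Keys are the metadata tuples produced by _extract_filename_metadata,
# values are the file paths grouped under that key.
images_by_key = {
    ("plate1", "A01", "f01", "c1"): [
        "/data/plate1_A01_f01_c1_z1.tif",  # hypothetical paths
        "/data/plate1_A01_f01_c1_z2.tif",
    ],
}
images_dict = load_images_from_paths(images_by_key)
# -> same keys, each mapped to a list of numpy arrays (one per readable file)
```
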
635
751
 
636
752
  #@log_function_call
637
753
  def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=False, skip_mode='01', metadata_type='', img_format='.tif'):
@@ -657,15 +773,20 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
657
773
  files_processed = 0
658
774
  if not os.path.exists(stack_path) or (os.path.isdir(stack_path) and len(os.listdir(stack_path)) == 0):
659
775
  all_filenames = [filename for filename in os.listdir(src) if filename.endswith(img_format)]
660
- print(f'All_files: {len(all_filenames)} in {src}')
776
+ print(f'All files: {len(all_filenames)} in {src}')
661
777
  time_ls = []
662
-
663
- for idx in range(0, len(all_filenames), batch_size):
778
+ image_paths_by_key = _extract_filename_metadata(all_filenames, src, regular_expression, metadata_type, pick_slice, skip_mode)
779
+ # Convert dictionary keys to a list for batching
780
+ batching_keys = list(image_paths_by_key.keys())
781
+ print(f'All unique FOV: {len(image_paths_by_key)} in {src}')
782
+ for idx in range(0, len(image_paths_by_key), batch_size):
664
783
  start = time.time()
665
- batch_filenames = all_filenames[idx:idx+batch_size]
666
- for filename in batch_filenames:
667
- images_by_key = _extract_filename_metadata(batch_filenames, src, regular_expression, metadata_type, pick_slice, skip_mode)
668
-
784
+
785
+ # Select batch keys and create a subset of the dictionary for this batch
786
+ batch_keys = batching_keys[idx:idx+batch_size]
787
+ batch_images_by_key = {key: image_paths_by_key[key] for key in batch_keys}
788
+ images_by_key = load_images_from_paths(batch_images_by_key)
789
+
669
790
  if pick_slice:
670
791
  for i, key in enumerate(images_by_key):
671
792
  plate, well, field, channel, mode = key
@@ -682,10 +803,10 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
682
803
  files_to_process = len(all_filenames)
683
804
  print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')
684
805
 
685
- #if os.path.exists(output_path):
686
- # print(f'WARNING: A file with the same name already exists at location {output_filename}')
687
806
  if not os.path.exists(output_path):
688
807
  mip_image.save(output_path)
808
+ else:
809
+ print(f'WARNING: A file with the same name already exists at location {output_filename}')
689
810
  else:
690
811
  for i, (key, images) in enumerate(images_by_key.items()):
691
812
  plate, well, field, channel = key[:4]
@@ -702,10 +823,11 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
702
823
  files_to_process = len(all_filenames)
703
824
  print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')
704
825
 
705
- #if os.path.exists(output_path):
706
- # print(f'WARNING: A file with the same name already exists at location {output_filename}')
707
826
  if not os.path.exists(output_path):
708
827
  mip_image.save(output_path)
828
+ else:
829
+ print(f'WARNING: A file with the same name already exists at location {output_filename}')
830
+
709
831
  images_by_key.clear()
710
832
 
711
833
  # Move original images to a new directory
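
The reworked loop above batches by unique field-of-view keys instead of raw filenames, so all slices that belong to one (plate, well, field, channel) group are loaded together before the MIP is taken. A stripped-down sketch of that pattern with generic names (not the spacr API):

```python
def iter_key_batches(paths_by_key, batch_size):
    """Yield sub-dicts of paths_by_key, batch_size keys at a time."""
    keys = list(paths_by_key.keys())
    for idx in range(0, len(keys), batch_size):
        batch_keys = keys[idx:idx + batch_size]
        yield {key: paths_by_key[key] for key in batch_keys}

# usage sketch: two keys, two z-slices each
paths_by_key = {
    ("p1", "A01", "f01", "c1"): ["a_z1.tif", "a_z2.tif"],
    ("p1", "A01", "f01", "c2"): ["b_z1.tif", "b_z2.tif"],
}
for batch in iter_key_batches(paths_by_key, batch_size=100):
    print(len(batch))  # spacr passes each batch to load_images_from_paths()
```
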
@@ -961,9 +1083,7 @@ def _mip_all(src, include_first_chan=True):
961
1083
  Returns:
962
1084
  None
963
1085
  """
964
-
965
- from .utils import normalize_to_dtype
966
-
1086
+
967
1087
  #print('========== generating MIPs ==========')
968
1088
  # Iterate over each file in the specified directory (src).
969
1089
  for filename in os.listdir(src):
@@ -1337,7 +1457,6 @@ def _get_lists_for_normalization(settings):
1337
1457
  return backgrounds, signal_to_noise, signal_thresholds, remove_background
1338
1458
 
1339
1459
  def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False, False, False], lower_percentile=2, save_dtype=np.float32, signal_to_noise=[5, 5, 5], signal_thresholds=[1000, 1000, 1000]):
1340
- from .utils import print_progress
1341
1460
  """
1342
1461
  Normalize the stack of images.
1343
1462
 
@@ -1430,7 +1549,6 @@ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False
1430
1549
  return print(f'Saved stacks: {output_fldr}')
1431
1550
 
1432
1551
  def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):
1433
- from .utils import print_progress
1434
1552
  """
1435
1553
  Normalize the timelapse data by rescaling the intensity values based on percentiles.
1436
1554
 
@@ -1559,7 +1677,7 @@ def delete_empty_subdirectories(folder_path):
1559
1677
  #@log_function_call
1560
1678
  def preprocess_img_data(settings):
1561
1679
 
1562
- from .plot import plot_arrays, _plot_4D_arrays
1680
+ from .plot import plot_arrays
1563
1681
  from .utils import _run_test_mode, _get_regex
1564
1682
  from .settings import set_default_settings_preprocess_img_data
1565
1683
 
@@ -2054,7 +2172,6 @@ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_di
2054
2172
  padded_shapes = [shape + (0,) * (max_tuple_length - len(shape)) for shape in unique_shapes]
2055
2173
  # Now create a NumPy array and find the maximum dimensions
2056
2174
  max_dims = np.max(np.array(padded_shapes), axis=0)
2057
- #clear_output(wait=True)
2058
2175
  print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimensions {max_dims}')
2059
2176
  #print(f'Warning: arrays with multiple shapes found. Padding arrays to max X,Y dimentions {max_dims}', end='\r', flush=True)
2060
2177
  padded_stack_ls = []
@@ -2102,7 +2219,7 @@ def _read_db(db_loc, tables):
2102
2219
  conn.close()
2103
2220
  return dfs
2104
2221
 
2105
- def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=False, include_multiinfected=False, include_noninfected=False):
2222
+ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
2106
2223
  """
2107
2224
  Read and merge data from SQLite databases and perform data preprocessing.
2108
2225
 
@@ -2110,9 +2227,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
2110
2227
  - locs (list): A list of file paths to the SQLite database files.
2111
2228
  - tables (list): A list of table names to read from the databases.
2112
2229
  - verbose (bool): Whether to print verbose output. Default is False.
2113
- - include_multinucleated (bool): Whether to include multinucleated cells. Default is False.
2114
- - include_multiinfected (bool): Whether to include cells with multiple infections. Default is False.
2115
- - include_noninfected (bool): Whether to include non-infected cells. Default is False.
2230
+ - nuclei_limit (bool): Whether to include multinucleated cells. Default is False.
2231
+ - pathogen_limit (bool): Whether to include cells with multiple infections. Default is False.
2232
+ - uninfected (bool): Whether to include non-infected cells. Default is False.
2116
2233
 
2117
2234
  Returns:
2118
2235
  - merged_df (pandas.DataFrame): The merged and preprocessed dataframe.
@@ -2187,7 +2304,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
2187
2304
  nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2188
2305
  nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2189
2306
  nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
2190
- if include_multinucleated == False:
2307
+ if nuclei_limit == False:
2191
2308
  #nucleus = nucleus[~nucleus['prcfo'].duplicated()]
2192
2309
  nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
2193
2310
  nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
@@ -2203,9 +2320,9 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
2203
2320
  pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2204
2321
  pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2205
2322
  pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
2206
- if include_noninfected == False:
2323
+ if uninfected == False:
2207
2324
  pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
2208
- if include_multiinfected == False:
2325
+ if pathogen_limit == False:
2209
2326
  pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
2210
2327
  pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
2211
2328
  print(f'pathogens: {len(pathogens)}')
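
The former `include_multinucleated`, `include_multiinfected` and `include_noninfected` flags are renamed to `nuclei_limit`, `pathogen_limit` and `uninfected` throughout this file. A hedged sketch of a call using the new names (the database path and table list are illustrative; `_read_and_merge_data` is a private helper, shown here only to document the signature change):

```python
from spacr.io import _read_and_merge_data

merged_df, _ = _read_and_merge_data(
    locs=["/data/exp1/measurements/measurements.db"],  # hypothetical path
    tables=["cell", "nucleus", "pathogen", "cytoplasm"],
    verbose=False,
    nuclei_limit=False,    # False keeps only single-nucleus cells
    pathogen_limit=False,  # False keeps at most one pathogen per cell
    uninfected=False,      # intended to drop cells without a pathogen object
)
```
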
@@ -2448,7 +2565,7 @@ def _read_db(db_loc, tables):
2448
2565
  conn.close() # Close the connection
2449
2566
  return dfs
2450
2567
 
2451
- def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=False, include_multiinfected=False, include_noninfected=False):
2568
+ def _read_and_merge_data(locs, tables, verbose=False, nuclei_limit=False, pathogen_limit=False, uninfected=False):
2452
2569
 
2453
2570
  from .utils import _split_data
2454
2571
 
@@ -2533,7 +2650,7 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
2533
2650
  nucleus = nucleus.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2534
2651
  nucleus = nucleus.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2535
2652
  nucleus['nucleus_prcfo_count'] = nucleus.groupby('prcfo')['prcfo'].transform('count')
2536
- if include_multinucleated == False:
2653
+ if nuclei_limit == False:
2537
2654
  nucleus = nucleus[nucleus['nucleus_prcfo_count']==1]
2538
2655
  nucleus_g_df, _ = _split_data(nucleus, 'prcfo', 'cell_id')
2539
2656
  if verbose:
@@ -2559,20 +2676,30 @@ def _read_and_merge_data(locs, tables, verbose=False, include_multinucleated=Fal
2559
2676
  pathogens = pathogens.assign(cell_id=lambda x: 'o' + x['cell_id'].astype(int).astype(str))
2560
2677
  pathogens = pathogens.assign(prcfo = lambda x: x['prcf'] + '_' + x['cell_id'])
2561
2678
  pathogens['pathogen_prcfo_count'] = pathogens.groupby('prcfo')['prcfo'].transform('count')
2562
- if include_noninfected == False:
2679
+
2680
+ print(f"before noninfected: {len(pathogens)}")
2681
+ if uninfected == False:
2563
2682
  pathogens = pathogens[pathogens['pathogen_prcfo_count']>=1]
2564
- if isinstance(include_multiinfected, bool):
2565
- if include_multiinfected == False:
2683
+ print(f"after noninfected: {len(pathogens)}")
2684
+
2685
+ if isinstance(pathogen_limit, bool):
2686
+ if pathogen_limit == False:
2566
2687
  pathogens = pathogens[pathogens['pathogen_prcfo_count']<=1]
2567
- if isinstance(include_multiinfected, float):
2568
- pathogens = pathogens[pathogens['pathogen_prcfo_count']<=include_multiinfected]
2688
+ print(f"after multiinfected Bool: {len(pathogens)}")
2689
+ if isinstance(pathogen_limit, float):
2690
+ pathogen_limit = int(pathogen_limit)
2691
+ if isinstance(pathogen_limit, int):
2692
+ pathogens = pathogens[pathogens['pathogen_prcfo_count']<=pathogen_limit]
2693
+ print(f"after multiinfected Float: {len(pathogens)}")
2569
2694
  if not 'cell' in tables:
2570
2695
  pathogens_g_df, metadata = _split_data(pathogens, 'prcfo', 'cell_id')
2571
2696
  else:
2572
2697
  pathogens_g_df, _ = _split_data(pathogens, 'prcfo', 'cell_id')
2698
+
2573
2699
  if verbose:
2574
2700
  print(f'pathogens: {len(pathogens)}')
2575
2701
  print(f'pathogens grouped: {len(pathogens_g_df)}')
2702
+
2576
2703
  if len(merged_df) == 0:
2577
2704
  merged_df = pathogens_g_df
2578
2705
  else:
@@ -2697,4 +2824,457 @@ def generate_cellpose_train_test(src, test_split=0.1):
2697
2824
  shutil.copy(img_path, new_img_path)
2698
2825
  shutil.copy(mask_path, new_mask_path)
2699
2826
  print(f'Copied {idx+1}/{len(ls)} images to {_type} set')#, end='\r', flush=True)
2700
-
2827
+
2828
+ def parse_gz_files(folder_path):
2829
+ """
2830
+ Parses the .fastq.gz files in the specified folder path and returns a dictionary
2831
+ containing the sample names and their corresponding file paths.
2832
+
2833
+ Args:
2834
+ folder_path (str): The path to the folder containing the .fastq.gz files.
2835
+
2836
+ Returns:
2837
+ dict: A dictionary where the keys are the sample names and the values are
2838
+ dictionaries containing the file paths for the 'R1' and 'R2' read directions.
2839
+ """
2840
+ files = os.listdir(folder_path)
2841
+ gz_files = [f for f in files if f.endswith('.fastq.gz')]
2842
+
2843
+ samples_dict = {}
2844
+ for gz_file in gz_files:
2845
+ parts = gz_file.split('_')
2846
+ sample_name = parts[0]
2847
+ read_direction = parts[1]
2848
+
2849
+ if sample_name not in samples_dict:
2850
+ samples_dict[sample_name] = {}
2851
+
2852
+ if read_direction == "R1":
2853
+ samples_dict[sample_name]['R1'] = os.path.join(folder_path, gz_file)
2854
+ elif read_direction == "R2":
2855
+ samples_dict[sample_name]['R2'] = os.path.join(folder_path, gz_file)
2856
+ return samples_dict
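
A short sketch of the naming convention `parse_gz_files` assumes (sample name and read direction separated by underscores) and the mapping it returns (folder and file names are hypothetical):

```python
from spacr.io import parse_gz_files

# Folder containing e.g.:
#   sampleA_R1_001.fastq.gz
#   sampleA_R2_001.fastq.gz
#   sampleB_R1_001.fastq.gz
samples = parse_gz_files("/data/sequencing/run1")
# samples == {
#   "sampleA": {"R1": ".../sampleA_R1_001.fastq.gz", "R2": ".../sampleA_R2_001.fastq.gz"},
#   "sampleB": {"R1": ".../sampleB_R1_001.fastq.gz"},
# }
```
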
2857
+
2858
+ def generate_dataset(settings={}):
2859
+
2860
+ from .utils import initiate_counter, add_images_to_tar, save_settings, generate_path_list_from_db, correct_paths
2861
+ from .settings import set_generate_dataset_defaults
2862
+
2863
+ settings = set_generate_dataset_defaults(settings)
2864
+ save_settings(settings, 'generate_dataset', show=True)
2865
+
2866
+ if isinstance(settings['src'], str):
2867
+ settings['src'] = [settings['src']]
2868
+ if isinstance(settings['src'], list):
2869
+ all_paths = []
2870
+ for i, src in enumerate(settings['src']):
2871
+ db_path = os.path.join(src, 'measurements', 'measurements.db')
2872
+ dst = os.path.join(src, 'datasets')
2873
+ paths = generate_path_list_from_db(db_path, file_metadata=settings['file_metadata'])
2874
+ correct_paths(paths, src)
2875
+ all_paths.extend(paths)
2876
+ if isinstance(settings['sample'], int):
2877
+ selected_paths = random.sample(all_paths, settings['sample'])
2878
+ print(f"Random selection of {len(selected_paths)} paths")
2879
+ elif isinstance(settings['sample'], list):
2880
+ sample = settings['sample'][i]
2881
+ selected_paths = random.sample(all_paths, sample)
2882
+ print(f"Random selection of {len(selected_paths)} paths")
2883
+ else:
2884
+ selected_paths = all_paths
2885
+ random.shuffle(selected_paths)
2886
+ print(f"All paths: {len(selected_paths)} paths")
2887
+
2888
+ total_images = len(selected_paths)
2889
+ print(f"Found {total_images} images")
2890
+
2891
+ # Create a temp folder in dst
2892
+ temp_dir = os.path.join(dst, "temp_tars")
2893
+ os.makedirs(temp_dir, exist_ok=True)
2894
+
2895
+ # Chunking the data
2896
+ num_procs = max(2, cpu_count() - 2)
2897
+ chunk_size = len(selected_paths) // num_procs
2898
+ remainder = len(selected_paths) % num_procs
2899
+
2900
+ paths_chunks = []
2901
+ start = 0
2902
+ for i in range(num_procs):
2903
+ end = start + chunk_size + (1 if i < remainder else 0)
2904
+ paths_chunks.append(selected_paths[start:end])
2905
+ start = end
2906
+
2907
+ temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]
2908
+
2909
+ print(f"Generating temporary tar files in {dst}")
2910
+
2911
+ # Initialize shared counter and lock
2912
+ counter = Value('i', 0)
2913
+ lock = Lock()
2914
+
2915
+ with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
2916
+ pool.starmap(add_images_to_tar, [(paths_chunks[i], temp_tar_files[i], total_images) for i in range(num_procs)])
2917
+
2918
+ # Combine the temporary tar files into a final tar
2919
+ date_name = datetime.date.today().strftime('%y%m%d')
2920
+ if not settings['file_metadata'] is None:
2921
+ tar_name = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}.tar"
2922
+ else:
2923
+ tar_name = f"{date_name}_{settings['experiment']}.tar"
2924
+ tar_name = os.path.join(dst, tar_name)
2925
+ if os.path.exists(tar_name):
2926
+ number = random.randint(1, 100)
2927
+ tar_name_2 = f"{date_name}_{settings['experiment']}_{settings['file_metadata']}_{number}.tar"
2928
+ print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
2929
+ tar_name = os.path.join(dst, tar_name_2)
2930
+
2931
+ print(f"Merging temporary files")
2932
+
2933
+ with tarfile.open(tar_name, 'w') as final_tar:
2934
+ for temp_tar_path in temp_tar_files:
2935
+ with tarfile.open(temp_tar_path, 'r') as temp_tar:
2936
+ for member in temp_tar.getmembers():
2937
+ file_obj = temp_tar.extractfile(member)
2938
+ final_tar.addfile(member, file_obj)
2939
+ os.remove(temp_tar_path)
2940
+
2941
+ # Delete the temp folder
2942
+ shutil.rmtree(temp_dir)
2943
+ print(f"\nSaved {total_images} images to {tar_name}")
2944
+
2945
+ return tar_name
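
`generate_dataset` splits the selected paths into one chunk per worker process and spreads the remainder one path at a time over the first chunks. A self-contained sketch of that chunking arithmetic (generic helper, not part of spacr):

```python
def chunk_evenly(items, num_chunks):
    """Split items into num_chunks pieces whose sizes differ by at most one."""
    chunk_size, remainder = divmod(len(items), num_chunks)
    chunks, start = [], 0
    for i in range(num_chunks):
        end = start + chunk_size + (1 if i < remainder else 0)
        chunks.append(items[start:end])
        start = end
    return chunks

print([len(c) for c in chunk_evenly(list(range(10)), 4)])  # [3, 3, 2, 2]
```
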
2946
+
2947
+ def generate_loaders(src, mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
2948
+
2949
+ """
2950
+ Generate data loaders for training and validation/test datasets.
2951
+
2952
+ Parameters:
2953
+ - src (str): The source directory containing the data.
2954
+ - mode (str): The mode of operation. Options are 'train' or 'test'.
2955
+ - image_size (int): The size of the input images.
2956
+ - batch_size (int): The batch size for the data loaders.
2957
+ - classes (list): The list of classes to consider.
2958
+ - n_jobs (int): The number of worker threads for data loading.
2959
+ - validation_split (float): The fraction of data to use for validation.
2960
+ - pin_memory (bool): Whether to pin memory for faster data transfer.
2961
+ - normalize (bool): Whether to normalize the input images.
2962
+ - verbose (bool): Whether to print additional information and show images.
2963
+ - channels (list): The list of channels to retain. Options are [1, 2, 3] for all channels, [1, 2] for blue and green, etc.
2964
+
2965
+ Returns:
2966
+ - train_loaders (list): List of data loaders for training datasets.
2967
+ - val_loaders (list): List of data loaders for validation datasets.
2968
+ """
2969
+
2970
+ from .io import spacrDataset
2971
+ from .utils import SelectChannels, augment_dataset
2972
+
2973
+ chans = []
2974
+
2975
+ if 'r' in channels:
2976
+ chans.append(1)
2977
+ if 'g' in channels:
2978
+ chans.append(2)
2979
+ if 'b' in channels:
2980
+ chans.append(3)
2981
+
2982
+ channels = chans
2983
+
2984
+ if verbose:
2985
+ print(f'Training a network on channels: {channels}')
2986
+ print(f'Channel 1: Red, Channel 2: Green, Channel 3: Blue')
2987
+
2988
+ train_loaders = []
2989
+ val_loaders = []
2990
+
2991
+ if normalize:
2992
+ transform = transforms.Compose([
2993
+ transforms.ToTensor(),
2994
+ transforms.CenterCrop(size=(image_size, image_size)),
2995
+ SelectChannels(channels),
2996
+ transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))])
2997
+ else:
2998
+ transform = transforms.Compose([
2999
+ transforms.ToTensor(),
3000
+ transforms.CenterCrop(size=(image_size, image_size)),
3001
+ SelectChannels(channels)])
3002
+
3003
+ if mode == 'train':
3004
+ data_dir = os.path.join(src, 'train')
3005
+ shuffle = True
3006
+ print('Loading Train and validation datasets')
3007
+ elif mode == 'test':
3008
+ data_dir = os.path.join(src, 'test')
3009
+ val_loaders = []
3010
+ validation_split = 0.0
3011
+ shuffle = True
3012
+ print('Loading test dataset')
3013
+ else:
3014
+ print(f'mode:{mode} is not valid, use mode = train or test')
3015
+ return
3016
+
3017
+ data = spacrDataset(data_dir, classes, transform=transform, shuffle=shuffle, pin_memory=pin_memory)
3018
+ num_workers = n_jobs if n_jobs is not None else 0
3019
+
3020
+ if validation_split > 0:
3021
+ train_size = int((1 - validation_split) * len(data))
3022
+ val_size = len(data) - train_size
3023
+ if not augment:
3024
+ print(f'Train data:{train_size}, Validation data:{val_size}')
3025
+ train_dataset, val_dataset = random_split(data, [train_size, val_size])
3026
+
3027
+ if augment:
3028
+
3029
+ print(f'Data before augmentation: Train: {len(train_dataset)}, Validation: {len(val_dataset)}')
3030
+ train_dataset = augment_dataset(train_dataset, is_grayscale=(len(channels) == 1))
3031
+ print(f'Data after augmentation: Train: {len(train_dataset)}')
3032
+
3033
+ print(f'Generating Dataloader with {n_jobs} workers')
3034
+ train_loaders = DataLoader(train_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
3035
+ val_loaders = DataLoader(val_dataset, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
3036
+ else:
3037
+ train_loaders = DataLoader(data, batch_size=batch_size, shuffle=shuffle, num_workers=1, pin_memory=pin_memory, persistent_workers=True)
3038
+
3039
+ #dataset (Dataset) – dataset from which to load the data.
3040
+ #batch_size (int, optional) – how many samples per batch to load (default: 1).
3041
+ #shuffle (bool, optional) – set to True to have the data reshuffled at every epoch (default: False).
3042
+ #sampler (Sampler or Iterable, optional) – defines the strategy to draw samples from the dataset. Can be any Iterable with __len__ implemented. If specified, shuffle must not be specified.
3043
+ #batch_sampler (Sampler or Iterable, optional) – like sampler, but returns a batch of indices at a time. Mutually exclusive with batch_size, shuffle, sampler, and drop_last.
3044
+ #num_workers (int, optional) – how many subprocesses to use for data loading. 0 means that the data will be loaded in the main process. (default: 0)
3045
+ #collate_fn (Callable, optional) – merges a list of samples to form a mini-batch of Tensor(s). Used when using batched loading from a map-style dataset.
3046
+ #pin_memory (bool, optional) – If True, the data loader will copy Tensors into device/CUDA pinned memory before returning them. If your data elements are a custom type, or your collate_fn returns a batch that is a custom type, see the example below.
3047
+ #drop_last (bool, optional) – set to True to drop the last incomplete batch, if the dataset size is not divisible by the batch size. If False and the size of dataset is not divisible by the batch size, then the last batch will be smaller. (default: False)
3048
+ #timeout (numeric, optional) – if positive, the timeout value for collecting a batch from workers. Should always be non-negative. (default: 0)
3049
+ #worker_init_fn (Callable, optional) – If not None, this will be called on each worker subprocess with the worker id (an int in [0, num_workers - 1]) as input, after seeding and before data loading. (default: None)
3050
+ #multiprocessing_context (str or multiprocessing.context.BaseContext, optional) – If None, the default multiprocessing context of your operating system will be used. (default: None)
3051
+ #generator (torch.Generator, optional) – If not None, this RNG will be used by RandomSampler to generate random indexes and multiprocessing to generate base_seed for workers. (default: None)
3052
+ #prefetch_factor (int, optional, keyword-only arg) – Number of batches loaded in advance by each worker. 2 means there will be a total of 2 * num_workers batches prefetched across all workers. (default value depends on the set value for num_workers. If value of num_workers=0 default is None. Otherwise, if value of num_workers > 0 default is 2).
3053
+ #persistent_workers (bool, optional) – If True, the data loader will not shut down the worker processes after a dataset has been consumed once. This allows to maintain the workers Dataset instances alive. (default: False)
3054
+ #pin_memory_device (str, optional) – the device to pin_memory to if pin_memory is True.
3055
+
3056
+ #images, labels, filenames = next(iter(train_loaders))
3057
+ #images = images.cpu()
3058
+ #label_strings = [str(label.item()) for label in labels]
3059
+ #train_fig = _imshow_gpu(images, label_strings, nrow=20, fontsize=12)
3060
+ #if verbose:
3061
+ # plt.show()
3062
+
3063
+ train_fig = None
3064
+
3065
+ return train_loaders, val_loaders, train_fig
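
A usage sketch for `generate_loaders` (paths are hypothetical, and it assumes `src` contains `train/<class>/` and `test/<class>/` folders of single-cell images). Note that although the default is `channels=[1, 2, 3]`, the body keys on the strings 'r', 'g' and 'b' and maps them to channel indices 1, 2 and 3, so passing the letters is the safe choice:

```python
from spacr.io import generate_loaders

train_loader, val_loader, _ = generate_loaders(
    src="/data/exp1/datasets/training",  # hypothetical path
    mode="train",
    image_size=224,
    batch_size=32,
    classes=["nc", "pc"],
    validation_split=0.1,
    channels=["r", "g", "b"],  # mapped internally to channels 1, 2, 3
    augment=False,
    verbose=True,
)
# With validation_split > 0 the first two return values are single DataLoader
# objects (the plural names are kept from the source); the third is a placeholder.
```
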
3066
+
3067
+ def generate_training_dataset(settings):
3068
+
3069
+ from .io import _read_and_merge_data, _read_db
3070
+ from .utils import get_paths_from_db, annotate_conditions, save_settings
3071
+ from .settings import set_generate_training_dataset_defaults
3072
+
3073
+ # Function to filter png_list_df by prcfo present in df without merging
3074
+ def filter_png_list(db_path, settings):
3075
+ tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
3076
+ df, _ = _read_and_merge_data(locs=[db_path],
3077
+ tables=tables,
3078
+ verbose=False,
3079
+ nuclei_limit=settings['nuclei_limit'],
3080
+ pathogen_limit=settings['pathogen_limit'],
3081
+ uninfected=settings['uninfected'])
3082
+ [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])
3083
+ filtered_png_list_df = png_list_df[png_list_df['prcfo'].isin(df.index)]
3084
+ return filtered_png_list_df
3085
+
3086
+ # Function to get the smallest class size based on the dataset mode
3087
+ def get_smallest_class_size(df, settings, dataset_mode):
3088
+ if dataset_mode == 'metadata':
3089
+ sizes = [len(df[df['metadata_based_class'] == c]) for c in settings['classes']]
3090
+ elif dataset_mode == 'annotation':
3091
+ sizes = [len(class_paths) for class_paths in df]
3092
+ size = min(sizes)
3093
+ print(f'Using the smallest class size: {size}')
3094
+ return size
3095
+
3096
+ # Measurement-based selection logic
3097
+ def measurement_based_selection(settings, db_path):
3098
+ class_paths_ls = []
3099
+ tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']
3100
+ df, _ = _read_and_merge_data(locs=[db_path],
3101
+ tables=tables,
3102
+ verbose=False,
3103
+ nuclei_limit=settings['nuclei_limit'],
3104
+ pathogen_limit=settings['pathogen_limit'],
3105
+ uninfected=settings['uninfected'])
3106
+
3107
+ print('length df 1', len(df))
3108
+ df = annotate_conditions(df, cells=['HeLa'], pathogens=['pathogen'], treatments=settings['classes'],
3109
+ treatment_loc=settings['class_metadata'])#, types=settings['metadata_type_by'])
3110
+ print('length df 2', len(df))
3111
+
3112
+ png_list_df = filter_png_list(db_path, settings)
3113
+
3114
+ if settings['custom_measurement']:
3115
+ if isinstance(settings['custom_measurement'], list):
3116
+ if len(settings['custom_measurement']) == 2:
3117
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}"] / df[f"{settings['custom_measurement'][1]}"]
3118
+ else:
3119
+ df['recruitment'] = df[f"{settings['custom_measurement'][0]}"]
3120
+ else:
3121
+ print("custom_measurement should be a list.")
3122
+ return
3123
+
3124
+ else:
3125
+ df['recruitment'] = df[f"pathogen_channel_{settings['channel_of_interest']}_mean_intensity"] / df[f"cytoplasm_channel_{settings['channel_of_interest']}_mean_intensity"]
3126
+
3127
+ q25 = df['recruitment'].quantile(0.25)
3128
+ q75 = df['recruitment'].quantile(0.75)
3129
+ df_lower = df[df['recruitment'] <= q25]
3130
+ df_upper = df[df['recruitment'] >= q75]
3131
+
3132
+ class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])
3133
+ class_paths_lower = random.sample(class_paths_lower['png_path'].tolist(), settings['size'])
3134
+ class_paths_ls.append(class_paths_lower)
3135
+
3136
+ class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
3137
+ class_paths_upper = random.sample(class_paths_upper['png_path'].tolist(), settings['size'])
3138
+ class_paths_ls.append(class_paths_upper)
3139
+
3140
+ return class_paths_ls
3141
+
3142
+ # Metadata-based selection logic
3143
+ def metadata_based_selection(db_path, settings):
3144
+ class_paths_ls = []
3145
+ df = filter_png_list(db_path, settings)
3146
+
3147
+ df['metadata_based_class'] = pd.NA
3148
+ for i, class_ in enumerate(settings['classes']):
3149
+ ls = settings['class_metadata'][i]
3150
+ df.loc[df[settings['metadata_type_by']].isin(ls), 'metadata_based_class'] = class_
3151
+
3152
+ size = get_smallest_class_size(df, settings, 'metadata')
3153
+ for class_ in settings['classes']:
3154
+ class_temp_df = df[df['metadata_based_class'] == class_]
3155
+ print(f'Found {len(class_temp_df)} images for class {class_}')
3156
+ class_paths_temp = class_temp_df['png_path'].tolist()
3157
+
3158
+ # Ensure to sample `size` number of images (smallest class size)
3159
+ if len(class_paths_temp) > size:
3160
+ class_paths_temp = random.sample(class_paths_temp, size)
3161
+
3162
+ class_paths_ls.append(class_paths_temp)
3163
+
3164
+ return class_paths_ls
3165
+
3166
+ # Annotation-based selection logic
3167
+ def annotation_based_selection(db_path, dst, settings):
3168
+ class_paths_ls = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
3169
+
3170
+ size = get_smallest_class_size(class_paths_ls, settings, 'annotation')
3171
+ for i, class_paths in enumerate(class_paths_ls):
3172
+ if len(class_paths) > size:
3173
+ class_paths_ls[i] = random.sample(class_paths, size)
3174
+
3175
+ return class_paths_ls
3176
+
3177
+ # Set default settings and save
3178
+ settings = set_generate_training_dataset_defaults(settings)
3179
+ save_settings(settings, 'cv_dataset', show=True)
3180
+
3181
+ db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
3182
+ dst = os.path.join(settings['src'], 'datasets', 'training')
3183
+
3184
+ # Create a new directory for training data if necessary
3185
+ if os.path.exists(dst):
3186
+ for i in range(1, 100000):
3187
+ dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
3188
+ if not os.path.exists(dst):
3189
+ print(f'Creating new directory for training: {dst}')
3190
+ break
3191
+
3192
+ # Select dataset based on dataset mode
3193
+ if settings['dataset_mode'] == 'annotation':
3194
+ class_paths_ls = annotation_based_selection(db_path, dst, settings)
3195
+
3196
+ elif settings['dataset_mode'] == 'metadata':
3197
+ class_paths_ls = metadata_based_selection(db_path, settings)
3198
+
3199
+ elif settings['dataset_mode'] == 'measurement':
3200
+ class_paths_ls = measurement_based_selection(settings, db_path)
3201
+
3202
+ # Generate and return training and testing directories
3203
+ train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])
3204
+
3205
+ return train_class_dir, test_class_dir
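
A hedged sketch of a settings dictionary for the metadata-driven branch of `generate_training_dataset`. The keys are taken from the code above; the values are purely illustrative, and anything omitted is filled in by `set_generate_training_dataset_defaults`:

```python
from spacr.io import generate_training_dataset

settings = {
    "src": "/data/exp1",                  # expects measurements/measurements.db below this folder
    "dataset_mode": "metadata",           # 'annotation', 'metadata' or 'measurement'
    "classes": ["nc", "pc"],
    "metadata_type_by": "columnID",       # png_list column used to assign classes (illustrative)
    "class_metadata": [["c1", "c2"], ["c23", "c24"]],  # per-class value lists (illustrative)
    "nuclei_limit": False,
    "pathogen_limit": False,
    "uninfected": False,
    "test_split": 0.1,
}
train_dir, test_dir = generate_training_dataset(settings)
```
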
3206
+
3207
+ def training_dataset_from_annotation(db_path, dst, annotation_column='test', annotated_classes=(1, 2)):
3208
+ all_paths = []
3209
+
3210
+ # Connect to the database and retrieve the image paths and annotations
3211
+ print(f'Reading DataBase: {db_path}')
3212
+ with sqlite3.connect(db_path) as conn:
3213
+ cursor = conn.cursor()
3214
+ # Prepare the query with parameterized placeholders for annotated_classes
3215
+ placeholders = ','.join('?' * len(annotated_classes))
3216
+ query = f"SELECT png_path, {annotation_column} FROM png_list WHERE {annotation_column} IN ({placeholders})"
3217
+ cursor.execute(query, annotated_classes)
3218
+
3219
+ while True:
3220
+ rows = cursor.fetchmany(1000)
3221
+ if not rows:
3222
+ break
3223
+ for row in rows:
3224
+ all_paths.append(row)
3225
+
3226
+ # Filter paths based on annotation
3227
+ class_paths = []
3228
+ for class_ in annotated_classes:
3229
+ class_paths_temp = [path for path, annotation in all_paths if annotation == class_]
3230
+ class_paths.append(class_paths_temp)
3231
+
3232
+ print(f'Generated a list of lists from annotation of {len(class_paths)} classes')
3233
+ return class_paths
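
Usage sketch for the annotation helper above. The `annotated_classes` values are bound into the SQL query as parameters, and each returned sub-list holds the png paths labelled with one class (paths are hypothetical):

```python
from spacr.io import training_dataset_from_annotation

class_paths = training_dataset_from_annotation(
    db_path="/data/exp1/measurements/measurements.db",
    dst="/data/exp1/datasets/training",
    annotation_column="test",      # png_list column holding the manual labels
    annotated_classes=(1, 2),
)
# class_paths[0] -> paths annotated as 1, class_paths[1] -> paths annotated as 2
```
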
3234
+
3235
+ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
3236
+ from .utils import print_progress
3237
+ from .deep_spacr import train_test_split
3238
+ # Make sure that the length of class_data matches the length of classes
3239
+ if len(class_data) != len(classes):
3240
+ raise ValueError("class_data and classes must have the same length.")
3241
+
3242
+ total_files = sum(len(data) for data in class_data)
3243
+ processed_files = 0
3244
+ time_ls = []
3245
+
3246
+ for cls, data in zip(classes, class_data):
3247
+ # Create directories
3248
+ train_class_dir = os.path.join(dst, f'train/{cls}')
3249
+ test_class_dir = os.path.join(dst, f'test/{cls}')
3250
+ os.makedirs(train_class_dir, exist_ok=True)
3251
+ os.makedirs(test_class_dir, exist_ok=True)
3252
+
3253
+ # Split the data
3254
+ train_data, test_data = train_test_split(data, test_size=test_split, shuffle=True, random_state=42)
3255
+
3256
+ # Copy train files
3257
+ for path in train_data:
3258
+ start = time.time()
3259
+ shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
3260
+ duration = time.time() - start
3261
+ time_ls.append(duration)
3262
+ print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
3263
+ processed_files += 1
3264
+
3265
+ # Copy test files
3266
+ for path in test_data:
3267
+ start = time.time()
3268
+ shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
3269
+ duration = time.time() - start
3270
+ time_ls.append(duration)
3271
+ print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
3272
+ processed_files += 1
3273
+
3274
+ # Print summary
3275
+ for cls in classes:
3276
+ train_class_dir = os.path.join(dst, f'train/{cls}')
3277
+ test_class_dir = os.path.join(dst, f'test/{cls}')
3278
+ print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
3279
+
3280
+ return os.path.join(dst, 'train'), os.path.join(dst, 'test')
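
Finally, a sketch of how per-class path lists feed `generate_dataset_from_lists`, which copies them into `train/<class>` and `test/<class>` folders under `dst` (paths are illustrative):

```python
from spacr.io import generate_dataset_from_lists

class_data = [
    ["/data/exp1/pngs/cell_001.png", "/data/exp1/pngs/cell_002.png"],  # class 'nc'
    ["/data/exp1/pngs/cell_101.png", "/data/exp1/pngs/cell_102.png"],  # class 'pc'
]
train_dir, test_dir = generate_dataset_from_lists(
    dst="/data/exp1/datasets/training",
    class_data=class_data,
    classes=["nc", "pc"],
    test_split=0.1,
)
```
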