spacr 0.2.53__tar.gz → 0.2.56__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122) hide show
  1. {spacr-0.2.53/spacr.egg-info → spacr-0.2.56}/PKG-INFO +2 -1
  2. {spacr-0.2.53 → spacr-0.2.56}/setup.py +2 -1
  3. {spacr-0.2.53 → spacr-0.2.56}/spacr/core.py +282 -10
  4. {spacr-0.2.53 → spacr-0.2.56}/spacr/deep_spacr.py +101 -41
  5. {spacr-0.2.53 → spacr-0.2.56}/spacr/gui.py +1 -1
  6. {spacr-0.2.53 → spacr-0.2.56}/spacr/gui_core.py +8 -10
  7. {spacr-0.2.53 → spacr-0.2.56}/spacr/gui_elements.py +70 -0
  8. {spacr-0.2.53 → spacr-0.2.56}/spacr/gui_utils.py +30 -10
  9. {spacr-0.2.53 → spacr-0.2.56}/spacr/io.py +12 -4
  10. {spacr-0.2.53 → spacr-0.2.56}/spacr/sequencing.py +443 -643
  11. {spacr-0.2.53 → spacr-0.2.56}/spacr/settings.py +176 -44
  12. {spacr-0.2.53 → spacr-0.2.56}/spacr/utils.py +13 -5
  13. {spacr-0.2.53 → spacr-0.2.56/spacr.egg-info}/PKG-INFO +2 -1
  14. {spacr-0.2.53 → spacr-0.2.56}/spacr.egg-info/requires.txt +1 -0
  15. {spacr-0.2.53 → spacr-0.2.56}/LICENSE +0 -0
  16. {spacr-0.2.53 → spacr-0.2.56}/MANIFEST.in +0 -0
  17. {spacr-0.2.53 → spacr-0.2.56}/README.rst +0 -0
  18. {spacr-0.2.53 → spacr-0.2.56}/setup.cfg +0 -0
  19. {spacr-0.2.53 → spacr-0.2.56}/spacr/__init__.py +0 -0
  20. {spacr-0.2.53 → spacr-0.2.56}/spacr/__main__.py +0 -0
  21. {spacr-0.2.53 → spacr-0.2.56}/spacr/app_annotate.py +0 -0
  22. {spacr-0.2.53 → spacr-0.2.56}/spacr/app_classify.py +0 -0
  23. {spacr-0.2.53 → spacr-0.2.56}/spacr/app_make_masks.py +0 -0
  24. {spacr-0.2.53 → spacr-0.2.56}/spacr/app_mask.py +0 -0
  25. {spacr-0.2.53 → spacr-0.2.56}/spacr/app_measure.py +0 -0
  26. {spacr-0.2.53 → spacr-0.2.56}/spacr/app_sequencing.py +0 -0
  27. {spacr-0.2.53 → spacr-0.2.56}/spacr/app_umap.py +0 -0
  28. {spacr-0.2.53 → spacr-0.2.56}/spacr/chris.py +0 -0
  29. {spacr-0.2.53 → spacr-0.2.56}/spacr/graph_learning.py +0 -0
  30. {spacr-0.2.53 → spacr-0.2.56}/spacr/logger.py +0 -0
  31. {spacr-0.2.53 → spacr-0.2.56}/spacr/measure.py +0 -0
  32. {spacr-0.2.53 → spacr-0.2.56}/spacr/plot.py +0 -0
  33. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/OFL.txt +0 -0
  34. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/OpenSans-Italic-VariableFont_wdth,wght.ttf +0 -0
  35. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/OpenSans-VariableFont_wdth,wght.ttf +0 -0
  36. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/README.txt +0 -0
  37. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-Bold.ttf +0 -0
  38. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-BoldItalic.ttf +0 -0
  39. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-ExtraBold.ttf +0 -0
  40. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-ExtraBoldItalic.ttf +0 -0
  41. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-Italic.ttf +0 -0
  42. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-Light.ttf +0 -0
  43. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-LightItalic.ttf +0 -0
  44. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-Medium.ttf +0 -0
  45. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-MediumItalic.ttf +0 -0
  46. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-Regular.ttf +0 -0
  47. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-SemiBold.ttf +0 -0
  48. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans-SemiBoldItalic.ttf +0 -0
  49. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-Bold.ttf +0 -0
  50. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-BoldItalic.ttf +0 -0
  51. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBold.ttf +0 -0
  52. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBoldItalic.ttf +0 -0
  53. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-Italic.ttf +0 -0
  54. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-Light.ttf +0 -0
  55. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-LightItalic.ttf +0 -0
  56. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-Medium.ttf +0 -0
  57. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-MediumItalic.ttf +0 -0
  58. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-Regular.ttf +0 -0
  59. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBold.ttf +0 -0
  60. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBoldItalic.ttf +0 -0
  61. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Bold.ttf +0 -0
  62. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-BoldItalic.ttf +0 -0
  63. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBold.ttf +0 -0
  64. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBoldItalic.ttf +0 -0
  65. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Italic.ttf +0 -0
  66. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Light.ttf +0 -0
  67. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-LightItalic.ttf +0 -0
  68. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Medium.ttf +0 -0
  69. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-MediumItalic.ttf +0 -0
  70. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Regular.ttf +0 -0
  71. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBold.ttf +0 -0
  72. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBoldItalic.ttf +0 -0
  73. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/abort.png +0 -0
  74. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/annotate.png +0 -0
  75. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/cellpose_all.png +0 -0
  76. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/cellpose_masks.png +0 -0
  77. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/classify.png +0 -0
  78. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/default.png +0 -0
  79. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/download.png +0 -0
  80. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/logo.pdf +0 -0
  81. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/logo_spacr.png +0 -0
  82. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/logo_spacr_1.png +0 -0
  83. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/make_masks.png +0 -0
  84. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/map_barcodes.png +0 -0
  85. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/mask.png +0 -0
  86. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/measure.png +0 -0
  87. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/ml_analyze.png +0 -0
  88. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/recruitment.png +0 -0
  89. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/regression.png +0 -0
  90. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/run.png +0 -0
  91. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/sequencing.png +0 -0
  92. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/settings.png +0 -0
  93. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/spacr_logo_rotation.gif +0 -0
  94. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/train_cellpose.png +0 -0
  95. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/icons/umap.png +0 -0
  96. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model +0 -0
  97. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/models/cp/toxo_plaque_cyto_e25000_X1120_Y1120.CP_model_settings.csv +0 -0
  98. {spacr-0.2.53 → spacr-0.2.56}/spacr/resources/models/cp/toxo_pv_lumen.CP_model +0 -0
  99. {spacr-0.2.53 → spacr-0.2.56}/spacr/sim.py +0 -0
  100. {spacr-0.2.53 → spacr-0.2.56}/spacr/sim_app.py +0 -0
  101. {spacr-0.2.53 → spacr-0.2.56}/spacr/timelapse.py +0 -0
  102. {spacr-0.2.53 → spacr-0.2.56}/spacr/version.py +0 -0
  103. {spacr-0.2.53 → spacr-0.2.56}/spacr.egg-info/SOURCES.txt +0 -0
  104. {spacr-0.2.53 → spacr-0.2.56}/spacr.egg-info/dependency_links.txt +0 -0
  105. {spacr-0.2.53 → spacr-0.2.56}/spacr.egg-info/entry_points.txt +0 -0
  106. {spacr-0.2.53 → spacr-0.2.56}/spacr.egg-info/top_level.txt +0 -0
  107. {spacr-0.2.53 → spacr-0.2.56}/tests/test_annotate_app.py +0 -0
  108. {spacr-0.2.53 → spacr-0.2.56}/tests/test_core.py +0 -0
  109. {spacr-0.2.53 → spacr-0.2.56}/tests/test_gui_classify_app.py +0 -0
  110. {spacr-0.2.53 → spacr-0.2.56}/tests/test_gui_mask_app.py +0 -0
  111. {spacr-0.2.53 → spacr-0.2.56}/tests/test_gui_measure_app.py +0 -0
  112. {spacr-0.2.53 → spacr-0.2.56}/tests/test_gui_sim_app.py +0 -0
  113. {spacr-0.2.53 → spacr-0.2.56}/tests/test_gui_utils.py +0 -0
  114. {spacr-0.2.53 → spacr-0.2.56}/tests/test_io.py +0 -0
  115. {spacr-0.2.53 → spacr-0.2.56}/tests/test_mask_app.py +0 -0
  116. {spacr-0.2.53 → spacr-0.2.56}/tests/test_measure.py +0 -0
  117. {spacr-0.2.53 → spacr-0.2.56}/tests/test_plot.py +0 -0
  118. {spacr-0.2.53 → spacr-0.2.56}/tests/test_sim.py +0 -0
  119. {spacr-0.2.53 → spacr-0.2.56}/tests/test_timelapse.py +0 -0
  120. {spacr-0.2.53 → spacr-0.2.56}/tests/test_train.py +0 -0
  121. {spacr-0.2.53 → spacr-0.2.56}/tests/test_umap.py +0 -0
  122. {spacr-0.2.53 → spacr-0.2.56}/tests/test_utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: spacr
3
- Version: 0.2.53
3
+ Version: 0.2.56
4
4
  Summary: Spatial phenotype analysis of crisp screens (SpaCr)
5
5
  Home-page: https://github.com/EinarOlafsson/spacr
6
6
  Author: Einar Birnir Olafsson
@@ -44,6 +44,7 @@ Requires-Dist: gputil<2.0,>=1.4.0
44
44
  Requires-Dist: gpustat<2.0,>=1.1.1
45
45
  Requires-Dist: pyautogui<1.0,>=0.9.54
46
46
  Requires-Dist: tables<4.0,>=3.8.0
47
+ Requires-Dist: rapidfuzz<4.0,>=3.9
47
48
  Requires-Dist: huggingface-hub<0.25,>=0.24.0
48
49
  Provides-Extra: dev
49
50
  Requires-Dist: pytest<3.11,>=3.9; extra == "dev"
@@ -50,12 +50,13 @@ dependencies = [
50
50
  'gpustat>=1.1.1,<2.0',
51
51
  'pyautogui>=0.9.54,<1.0',
52
52
  'tables>=3.8.0,<4.0',
53
+ 'rapidfuzz>=3.9, <4.0',
53
54
  'huggingface-hub>=0.24.0,<0.25'
54
55
  ]
55
56
 
56
57
  setup(
57
58
  name="spacr",
58
- version="0.2.53",
59
+ version="0.2.56",
59
60
  author="Einar Birnir Olafsson",
60
61
  author_email="olafsson@med.umich.com",
61
62
  description="Spatial phenotype analysis of crisp screens (SpaCr)",
@@ -877,7 +877,106 @@ def annotate_results(pred_loc):
877
877
  display(df)
878
878
  return df
879
879
 
880
- def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample=None):
880
def generate_dataset(settings={}):
    """
    Collect PNG paths from the measurements database and pack them into one tar.

    Paths are read from the ``png_list`` table of
    ``<src>/measurements/measurements.db``, optionally filtered and
    sub-sampled, split into chunks that worker processes write to temporary
    tar files, and finally merged into a single archive in ``<src>/datasets``.

    Args:
        settings (dict): Configuration. Keys read here:
            'src' (str): experiment folder (required).
            'file_metadata' (str or None): substring that ``png_path`` must
                contain; falsy means no filtering. Defaults to None.
            'experiment' (str): used in the archive name. Defaults to
                'TSG101_screen' (the default of the previous signature).
            'sample' (int or None): number of paths to draw at random; any
                non-int value keeps all paths. Defaults to None.
            The default ``{}`` is never mutated by this function.

    Returns:
        str or None: path of the final tar archive, or None when the database
        could not be read, ``file_metadata`` has an unsupported type, or no
        image paths were found.
    """
    from .utils import initiate_counter, add_images_to_tar

    src = settings['src']
    file_metadata = settings.get('file_metadata')
    experiment = settings.get('experiment', 'TSG101_screen')
    sample = settings.get('sample')

    db_path = os.path.join(src, 'measurements', 'measurements.db')
    dst = os.path.join(src, 'datasets')
    all_paths = []

    # Connect to the database and retrieve the image paths
    print(f"Reading DataBase: {db_path}")
    try:
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            if not file_metadata:
                cursor.execute("SELECT png_path FROM png_list")
            elif isinstance(file_metadata, str):
                cursor.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", (f"%{file_metadata}%",))
            else:
                # Previously no query was executed in this case, which
                # silently produced an empty dataset.
                raise TypeError(f"file_metadata must be a string or None, got {type(file_metadata).__name__}")

            while True:
                rows = cursor.fetchmany(1000)
                if not rows:
                    break
                all_paths.extend(row[0] for row in rows)
        finally:
            # `with sqlite3.connect(...)` only manages transactions; close
            # the connection explicitly.
            conn.close()

    except sqlite3.Error as e:
        print(f"Database error: {e}")
        return
    except Exception as e:
        print(f"Error: {e}")
        return

    if isinstance(sample, int) and not isinstance(sample, bool):
        # random.sample raises if asked for more items than exist.
        n_sample = min(sample, len(all_paths))
        if n_sample < sample:
            print(f"Requested {sample} paths but only {len(all_paths)} available; using all of them")
        selected_paths = random.sample(all_paths, n_sample)
        print(f"Random selection of {len(selected_paths)} paths")
    else:
        selected_paths = all_paths
        random.shuffle(selected_paths)
        print(f"All paths: {len(selected_paths)} paths")

    total_images = len(selected_paths)
    print(f"Found {total_images} images")

    if total_images == 0:
        # Nothing to archive; avoid spawning workers on empty chunks.
        print("No image paths selected, no dataset generated")
        return

    # Create a temp folder in dst
    temp_dir = os.path.join(dst, "temp_tars")
    os.makedirs(temp_dir, exist_ok=True)

    # Chunking the data; never use more workers than there are images so
    # that every worker receives a non-empty chunk.
    num_procs = min(max(2, cpu_count() - 2), total_images)
    chunk_size, remainder = divmod(total_images, num_procs)

    paths_chunks = []
    start = 0
    for i in range(num_procs):
        # The first `remainder` chunks take one extra path each.
        end = start + chunk_size + (1 if i < remainder else 0)
        paths_chunks.append(selected_paths[start:end])
        start = end

    temp_tar_files = [os.path.join(temp_dir, f"temp_{i}.tar") for i in range(num_procs)]

    print(f"Generating temporary tar files in {dst}")

    # Initialize shared counter and lock
    counter = Value('i', 0)
    lock = Lock()

    # NOTE(review): add_images_to_tar is assumed to write each chunk to its
    # temp tar path; it is defined in .utils and not visible here.
    with Pool(processes=num_procs, initializer=initiate_counter, initargs=(counter, lock)) as pool:
        pool.starmap(add_images_to_tar, [(chunk, tar_path, total_images) for chunk, tar_path in zip(paths_chunks, temp_tar_files)])

    # Combine the temporary tar files into a final tar
    date_name = datetime.date.today().strftime('%y%m%d')
    base_name = f"{date_name}_{experiment}"
    if file_metadata:
        base_name = f"{base_name}_{file_metadata}"
    tar_name = os.path.join(dst, f"{base_name}.tar")
    if os.path.exists(tar_name):
        # Pick the first free numeric suffix instead of a random one, which
        # could collide with (and overwrite) an existing archive.
        number = 1
        while os.path.exists(os.path.join(dst, f"{base_name}_{number}.tar")):
            number += 1
        tar_name_2 = f"{base_name}_{number}.tar"
        print(f"Warning: {os.path.basename(tar_name)} exists, saving as {os.path.basename(tar_name_2)} ")
        tar_name = os.path.join(dst, tar_name_2)

    print(f"Merging temporary files")

    with tarfile.open(tar_name, 'w') as final_tar:
        for temp_tar_path in temp_tar_files:
            with tarfile.open(temp_tar_path, 'r') as temp_tar:
                for member in temp_tar.getmembers():
                    file_obj = temp_tar.extractfile(member)
                    final_tar.addfile(member, file_obj)
            os.remove(temp_tar_path)

    # Delete the temp folder
    shutil.rmtree(temp_dir)
    print(f"\nSaved {total_images} images to {tar_name}")

    return tar_name
978
+
979
+ def generate_dataset_v1(src, file_metadata=None, experiment='TSG101_screen', sample=None):
881
980
 
882
981
  from .utils import initiate_counter, add_images_to_tar
883
982
 
@@ -974,7 +1073,7 @@ def generate_dataset(src, file_metadata=None, experiment='TSG101_screen', sample
974
1073
  shutil.rmtree(temp_dir)
975
1074
  print(f"\nSaved {total_images} images to {tar_name}")
976
1075
 
977
- def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_jobs=10, threshold=0.5, verbose=False):
1076
+ def apply_model_to_tar_v1(tar_path, model_path, file_type='cell_png', image_size=224, batch_size=64, normalize=True, preload='images', n_jobs=10, threshold=0.5, verbose=False):
978
1077
 
979
1078
  from .io import TarImageDataset
980
1079
  from .utils import process_vision_results, print_progress
@@ -1044,6 +1143,76 @@ def apply_model_to_tar(tar_path, model_path, file_type='cell_png', image_size=22
1044
1143
  torch.cuda.memory.empty_cache()
1045
1144
  return df
1046
1145
 
1146
def apply_model_to_tar(settings={}):
    """
    Run a saved torch model over every image in a tar archive and save scores.

    Images are read through ``TarImageDataset``, center-cropped (and
    optionally normalized), passed through the model, and the sigmoid of the
    model output is stored per file. The resulting table is post-processed by
    ``process_vision_results`` and written as CSV next to the tar archive.

    Args:
        settings (dict): Configuration. Keys read here: 'tar_path',
            'model_path', 'image_size', 'batch_size', 'normalize', 'n_jobs',
            'score_threshold', 'verbose'. The default ``{}`` is never mutated.

    Returns:
        The object returned by ``process_vision_results`` (it is written with
        ``to_csv``, so presumably a pandas DataFrame).
    """
    from .io import TarImageDataset
    from .utils import process_vision_results, print_progress

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Same pipeline for both modes; normalization is appended when requested.
    image_size = settings['image_size']
    transform_steps = [
        transforms.ToTensor(),
        transforms.CenterCrop(size=(image_size, image_size))]
    if settings['normalize']:
        transform_steps.append(transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)))
    transform = transforms.Compose(transform_steps)

    if settings['verbose']:
        print(f"Loading model from {settings['model_path']}")
        print(f"Loading dataset from {settings['tar_path']}")

    # map_location lets a model saved on GPU load on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load models from trusted sources.
    model = torch.load(settings['model_path'], map_location=device)

    dataset = TarImageDataset(settings['tar_path'], transform=transform)
    data_loader = DataLoader(dataset, batch_size=settings['batch_size'], shuffle=True, num_workers=settings['n_jobs'], pin_memory=True)

    model_name = os.path.splitext(os.path.basename(settings['model_path']))[0]
    dataset_name = os.path.splitext(os.path.basename(settings['tar_path']))[0]
    date_name = datetime.date.today().strftime('%y%m%d')
    dst = os.path.dirname(settings['tar_path'])
    # os.path.join keeps the result relative when tar_path has no directory
    # part; the previous f'{dst}/...' form pointed at the filesystem root.
    result_loc = os.path.join(dst, f'{date_name}_{dataset_name}_{model_name}_result.csv')

    model.eval()
    model = model.to(device)

    if settings['verbose']:
        print(model)
        print(f'Generated dataset with {len(dataset)} images')
        print(f'Generating loader from {len(data_loader)} batches')
        print(f'Results wil be saved in: {result_loc}')
        print(f'Model is in eval mode')
        print(f'Model loaded to device')

    prediction_pos_probs = []
    filenames_list = []
    time_ls = []
    total_files = len(dataset)
    gc.collect()
    with torch.no_grad():
        for batch_idx, (batch_images, filenames) in enumerate(data_loader, start=1):
            start = time.time()
            images = batch_images.to(torch.float).to(device)
            outputs = model(images)
            batch_prediction_pos_prob = torch.sigmoid(outputs).cpu().numpy()
            prediction_pos_probs.extend(batch_prediction_pos_prob.tolist())
            filenames_list.extend(filenames)
            duration = time.time() - start
            time_ls.append(duration)
            # Report progress in images on both sides; the total used to be
            # the number of batches while the processed count was in images.
            # The last batch may be partial, hence the clamp.
            files_processed = min(batch_idx * settings['batch_size'], total_files)
            print_progress(files_processed, total_files, n_jobs=settings['n_jobs'], time_ls=time_ls, batch_size=settings['batch_size'], operation_type="Tar dataset")

    data = {'path': filenames_list, 'pred': prediction_pos_probs}
    df = pd.DataFrame(data, index=None)
    df = process_vision_results(df, settings['score_threshold'])

    df.to_csv(result_loc, index=True, header=True, mode='w')
    # torch.cuda.memory.empty_cache is the same function; one call suffices.
    torch.cuda.empty_cache()
    return df
1215
+
1047
1216
  def apply_model(src, model_path, image_size=224, batch_size=64, normalize=True, n_jobs=10):
1048
1217
 
1049
1218
  from .io import NoClassDataset
@@ -1206,19 +1375,19 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
1206
1375
  for path in train_data:
1207
1376
  start = time.time()
1208
1377
  shutil.copy(path, os.path.join(train_class_dir, os.path.basename(path)))
1209
- processed_files += 1
1210
1378
  duration = time.time() - start
1211
1379
  time_ls.append(duration)
1212
1380
  print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Train dataset")
1381
+ processed_files += 1
1213
1382
 
1214
1383
  # Copy test files
1215
1384
  for path in test_data:
1216
1385
  start = time.time()
1217
1386
  shutil.copy(path, os.path.join(test_class_dir, os.path.basename(path)))
1218
- processed_files += 1
1219
1387
  duration = time.time() - start
1220
1388
  time_ls.append(duration)
1221
1389
  print_progress(processed_files, total_files, n_jobs=1, time_ls=None, batch_size=None, operation_type="Copying files for Test dataset")
1390
+ processed_files += 1
1222
1391
 
1223
1392
  # Print summary
1224
1393
  for cls in classes:
@@ -1226,9 +1395,9 @@ def generate_dataset_from_lists(dst, class_data, classes, test_split=0.1):
1226
1395
  test_class_dir = os.path.join(dst, f'test/{cls}')
1227
1396
  print(f'Train class {cls}: {len(os.listdir(train_class_dir))}, Test class {cls}: {len(os.listdir(test_class_dir))}')
1228
1397
 
1229
- return
1398
+ return os.path.join(dst, 'train'), os.path.join(dst, 'test')
1230
1399
 
1231
- def generate_training_dataset(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
1400
+ def generate_training_dataset_v1(src, mode='annotation', annotation_column='test', annotated_classes=[1,2], classes=['nc','pc'], size=200, test_split=0.1, class_metadata=[['c1'],['c2']], metadata_type_by='col', channel_of_interest=3, custom_measurement=None, tables=None, png_type='cell_png'):
1232
1401
 
1233
1402
  from .io import _read_and_merge_data, _read_db
1234
1403
  from .utils import get_paths_from_db, annotate_conditions
@@ -1329,6 +1498,110 @@ def generate_training_dataset(src, mode='annotation', annotation_column='test',
1329
1498
 
1330
1499
  return
1331
1500
 
1501
def _sample_class_paths(paths, size, label):
    """Return `size` random items from `paths`, or all of them if fewer exist or `size` is None."""
    if size is None:
        size = len(paths)
    elif size > len(paths):
        print(f'Requested {size} images for {label} but only {len(paths)} available; using all of them')
        size = len(paths)
    return random.sample(paths, size)


def generate_training_dataset(settings):
    """
    Build train/test image folders for classifier training.

    Image paths are grouped into classes according to ``settings['dataset_mode']``:

    * 'annotation'  - classes come from ``training_dataset_from_annotation``.
    * 'metadata'    - rows of the ``png_list`` table are assigned a class when
                      their ``metadata_type_by`` column value is listed in the
                      matching entry of ``class_metadata``.
    * 'recruitment' - a 'recruitment' value is computed per object (custom
                      measurement(s), or pathogen/cytoplasm mean intensity of
                      ``channel_of_interest``); the lower and upper quartiles
                      form the two classes, in that order.

    Each class is randomly sub-sampled and handed to
    ``generate_dataset_from_lists``.

    Args:
        settings (dict): Passed through
            ``set_generate_training_dataset_defaults`` first. Keys read here:
            'src', 'dataset_mode', 'annotation_column', 'annotated_classes',
            'classes', 'size', 'test_split', 'class_metadata',
            'metadata_type_by', 'channel_of_interest', 'custom_measurement',
            'tables', 'png_type'.

    Returns:
        tuple or None: ``(train_dir, test_dir)`` as returned by
        ``generate_dataset_from_lists``; None when ``custom_measurement`` is
        not a list of one or two column names.

    Raises:
        ValueError: if ``dataset_mode`` is not one of the supported modes.
    """
    from .io import _read_and_merge_data, _read_db
    from .utils import get_paths_from_db, annotate_conditions
    from .settings import set_generate_training_dataset_defaults

    settings = set_generate_training_dataset_defaults(settings)

    db_path = os.path.join(settings['src'], 'measurements', 'measurements.db')
    dst = os.path.join(settings['src'], 'datasets', 'training')

    # Never reuse an existing training folder; pick the first free suffix.
    if os.path.exists(dst):
        for i in range(1, 1000):
            dst = os.path.join(settings['src'], 'datasets', f'training_{i}')
            if not os.path.exists(dst):
                print(f'Creating new directory for training: {dst}')
                break

    mode = settings['dataset_mode']
    size = settings['size']

    if mode == 'annotation':
        annotated_paths = training_dataset_from_annotation(db_path, dst, settings['annotation_column'], annotated_classes=settings['annotated_classes'])
        class_paths_ls = [
            _sample_class_paths(class_paths, size, f'annotated class {idx}')
            for idx, class_paths in enumerate(annotated_paths)]

    elif mode == 'metadata':
        [df] = _read_db(db_loc=db_path, tables=['png_list'])
        df['metadata_based_class'] = pd.NA
        for class_, class_values in zip(settings['classes'], settings['class_metadata']):
            df.loc[df[settings['metadata_type_by']].isin(class_values), 'metadata_based_class'] = class_

        class_dfs = []
        for class_ in settings['classes']:
            class_temp_df = df[df['metadata_based_class'] == class_]
            print(f'Found {len(class_temp_df)} images for class {class_}')
            class_dfs.append(class_temp_df)

        if size is None:
            # Balance classes on the smallest one. This value used to be
            # computed but then ignored in favour of settings['size'] (None).
            size = min((len(class_df) for class_df in class_dfs), default=0)
            print(f'Using the smallest class size: {size}')

        class_paths_ls = [
            _sample_class_paths(class_df['png_path'].tolist(), size, f'class {class_}')
            for class_, class_df in zip(settings['classes'], class_dfs)]

    elif mode == 'recruitment':
        class_paths_ls = []
        # `tables` used to be unbound when settings['tables'] was already a list.
        if isinstance(settings['tables'], list):
            tables = settings['tables']
        else:
            tables = ['cell', 'nucleus', 'pathogen', 'cytoplasm']

        df, _ = _read_and_merge_data(locs=[db_path],
                                     tables=tables,
                                     verbose=False,
                                     include_multinucleated=True,
                                     include_multiinfected=True,
                                     include_noninfected=True)

        print('length df 1', len(df))

        df = annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['pathogen'], pathogen_loc=None, treatments=settings['classes'], treatment_loc=settings['class_metadata'], types=settings['metadata_type_by'])
        print('length df 2', len(df))
        [png_list_df] = _read_db(db_loc=db_path, tables=['png_list'])

        custom_measurement = settings['custom_measurement']
        if custom_measurement is not None:

            if not isinstance(custom_measurement, list):
                print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
                return

            if len(custom_measurement) == 2:
                numerator, denominator = custom_measurement
                print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({numerator}/{denominator})")
                # The original f-string had mismatched quotes, which made this
                # a lookup of one bogus column name instead of a division.
                df['recruitment'] = df[f"{numerator}"] / df[f"{denominator}"]
            elif len(custom_measurement) == 1:
                print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment ({custom_measurement[0]})")
                df['recruitment'] = df[f"{custom_measurement[0]}"]
            else:
                print(f'custom_measurement should be a list, add [ measurement_1, measurement_2 ] or [ measurement ]')
                return
        else:
            channel = settings['channel_of_interest']
            print(f"Classes will be defined by the Q1 and Q3 quantiles of recruitment (pathogen/cytoplasm for channel {channel})")
            # Same quoting defect as above, fixed here as well.
            df['recruitment'] = df[f"pathogen_channel_{channel}_mean_intensity"] / df[f"cytoplasm_channel_{channel}_mean_intensity"]

        q25 = df['recruitment'].quantile(0.25)
        q75 = df['recruitment'].quantile(0.75)
        df_lower = df[df['recruitment'] <= q25]
        df_upper = df[df['recruitment'] >= q75]

        # Lower quartile first, upper quartile second: this order must match
        # the order of settings['classes'].
        class_paths_lower = get_paths_from_db(df=df_lower, png_df=png_list_df, image_type=settings['png_type'])
        class_paths_ls.append(_sample_class_paths(class_paths_lower['png_path'].tolist(), size, 'lower quartile'))

        class_paths_upper = get_paths_from_db(df=df_upper, png_df=png_list_df, image_type=settings['png_type'])
        class_paths_ls.append(_sample_class_paths(class_paths_upper['png_path'].tolist(), size, 'upper quartile'))

    else:
        raise ValueError(f"Unknown dataset_mode: {mode!r}, expected 'annotation', 'metadata' or 'recruitment'")

    train_class_dir, test_class_dir = generate_dataset_from_lists(dst, class_data=class_paths_ls, classes=settings['classes'], test_split=settings['test_split'])

    return train_class_dir, test_class_dir
1604
+
1332
1605
  def generate_loaders(src, train_mode='erm', mode='train', image_size=224, batch_size=32, classes=['nc','pc'], n_jobs=None, validation_split=0.0, max_show=2, pin_memory=False, normalize=False, channels=[1, 2, 3], augment=False, verbose=False):
1333
1606
 
1334
1607
  """
@@ -2497,7 +2770,6 @@ def ml_analysis(df, channel_of_interest=3, location_column='col', positive_contr
2497
2770
  df_metadata = df[[location_column]].copy()
2498
2771
  df, features = filter_dataframe_features(df, channel_of_interest, exclude, remove_low_variance_features, remove_highly_correlated_features, verbose)
2499
2772
 
2500
-
2501
2773
  if verbose:
2502
2774
  print(f'Found {len(features)} numerical features in the dataframe')
2503
2775
  print(f'Features used in training: {features}')
@@ -2642,7 +2914,6 @@ def check_index(df, elements=5, split_char='_'):
2642
2914
  print(idx)
2643
2915
  raise ValueError(f"Found {len(problematic_indices)} problematic indices that do not split into {elements} parts.")
2644
2916
 
2645
- #def plate_heatmap(src, model_type='xgboost', variable='predictions', grouping='mean', min_max='allq', cmap='viridis', channel_of_interest=3, min_count=25, n_estimators=100, col_to_compare='col', pos='c2', neg='c1', exclude=None, n_repeats=10, clean=True, nr_to_plot=20, verbose=False, n_jobs=-1):
2646
2917
  def generate_ml_scores(src, settings):
2647
2918
 
2648
2919
  from .io import _read_and_merge_data
@@ -2680,7 +2951,7 @@ def generate_ml_scores(src, settings):
2680
2951
  settings['top_features'],
2681
2952
  settings['n_estimators'],
2682
2953
  settings['test_size'],
2683
- settings['model_type'],
2954
+ settings['model_type_ml'],
2684
2955
  settings['n_jobs'],
2685
2956
  settings['remove_low_variance_features'],
2686
2957
  settings['remove_highly_correlated_features'],
@@ -2701,7 +2972,7 @@ def generate_ml_scores(src, settings):
2701
2972
  min_count=settings['minimum_cell_count'],
2702
2973
  verbose=settings['verbose'])
2703
2974
 
2704
- data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type'], settings['channel_of_interest'])
2975
+ data_path, permutation_path, feature_importance_path, model_metricks_path, permutation_fig_path, feature_importance_fig_path, shap_fig_path, plate_heatmap_path, settings_csv = get_ml_results_paths(src, settings['model_type_ml'], settings['channel_of_interest'])
2705
2976
  df, permutation_df, feature_importance_df, _, _, _, _, _, metrics_df = output
2706
2977
 
2707
2978
  settings_df.to_csv(settings_csv, index=False)
@@ -2858,6 +3129,7 @@ def generate_image_umap(settings={}):
2858
3129
  settings['plot_outlines'] = False
2859
3130
  settings['smooth_lines'] = False
2860
3131
 
3132
+ print(f'Generating Image UMAP ...')
2861
3133
  settings_df = pd.DataFrame(list(settings.items()), columns=['Key', 'Value'])
2862
3134
  settings_dir = os.path.join(settings['src'][0],'settings')
2863
3135
  settings_csv = os.path.join(settings_dir,'embedding_settings.csv')
@@ -196,7 +196,7 @@ def test_model_performance(loaders, model, loader_name_list, epoch, train_mode,
196
196
  test_time = end_time - start_time
197
197
  return result, results_df
198
198
 
199
- def train_test_model(src, settings, custom_model=False, custom_model_path=None):
199
+ def train_test_model(settings):
200
200
 
201
201
  from .io import _save_settings, _copy_missclassified
202
202
  from .utils import pick_best_model
@@ -208,7 +208,10 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
208
208
  gc.collect()
209
209
 
210
210
  settings = set_default_train_test_model(settings)
211
- channels_str = ''.join(settings['channels'])
211
+
212
+ src = settings['src']
213
+
214
+ channels_str = ''.join(settings['train_channels'])
212
215
  dst = os.path.join(src,'model', settings['model_type'], channels_str, str(f"epochs_{settings['epochs']}"))
213
216
  os.makedirs(dst, exist_ok=True)
214
217
  settings['src'] = src
@@ -217,8 +220,8 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
217
220
  settings_csv = os.path.join(dst,'train_test_model_settings.csv')
218
221
  settings_df.to_csv(settings_csv, index=False)
219
222
 
220
- if custom_model:
221
- model = torch.load(custom_model_path)
223
+ if settings['custom_model']:
224
+ model = torch.load(settings['custom_model_path'])
222
225
 
223
226
  if settings['train']:
224
227
  _save_settings(settings, src)
@@ -234,7 +237,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
234
237
  validation_split=settings['val_split'],
235
238
  pin_memory=settings['pin_memory'],
236
239
  normalize=settings['normalize'],
237
- channels=settings['channels'],
240
+ channels=settings['train_channels'],
238
241
  augment=settings['augment'],
239
242
  verbose=settings['verbose'])
240
243
 
@@ -242,28 +245,28 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
242
245
  train_fig.savefig(train_batch_1_figure, format='pdf', dpi=600)
243
246
 
244
247
  if settings['train']:
245
- model = train_model(dst = settings['dst'],
246
- model_type=settings['model_type'],
247
- train_loaders = train,
248
- train_loader_names = plate_names,
249
- train_mode = settings['train_mode'],
250
- epochs = settings['epochs'],
251
- learning_rate = settings['learning_rate'],
252
- init_weights = settings['init_weights'],
253
- weight_decay = settings['weight_decay'],
254
- amsgrad = settings['amsgrad'],
255
- optimizer_type = settings['optimizer_type'],
256
- use_checkpoint = settings['use_checkpoint'],
257
- dropout_rate = settings['dropout_rate'],
258
- n_jobs = settings['n_jobs'],
259
- val_loaders = val,
260
- test_loaders = None,
261
- intermedeate_save = settings['intermedeate_save'],
262
- schedule = settings['schedule'],
263
- loss_type=settings['loss_type'],
264
- gradient_accumulation=settings['gradient_accumulation'],
265
- gradient_accumulation_steps=settings['gradient_accumulation_steps'],
266
- channels=settings['channels'])
248
+ model, model_path = train_model(dst = settings['dst'],
249
+ model_type=settings['model_type'],
250
+ train_loaders = train,
251
+ train_loader_names = plate_names,
252
+ train_mode = settings['train_mode'],
253
+ epochs = settings['epochs'],
254
+ learning_rate = settings['learning_rate'],
255
+ init_weights = settings['init_weights'],
256
+ weight_decay = settings['weight_decay'],
257
+ amsgrad = settings['amsgrad'],
258
+ optimizer_type = settings['optimizer_type'],
259
+ use_checkpoint = settings['use_checkpoint'],
260
+ dropout_rate = settings['dropout_rate'],
261
+ n_jobs = settings['n_jobs'],
262
+ val_loaders = val,
263
+ test_loaders = None,
264
+ intermedeate_save = settings['intermedeate_save'],
265
+ schedule = settings['schedule'],
266
+ loss_type=settings['loss_type'],
267
+ gradient_accumulation=settings['gradient_accumulation'],
268
+ gradient_accumulation_steps=settings['gradient_accumulation_steps'],
269
+ channels=settings['train_channels'])
267
270
 
268
271
  torch.cuda.empty_cache()
269
272
  torch.cuda.memory.empty_cache()
@@ -280,7 +283,7 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
280
283
  validation_split=0.0,
281
284
  pin_memory=settings['pin_memory'],
282
285
  normalize=settings['normalize'],
283
- channels=settings['channels'],
286
+ channels=settings['train_channels'],
284
287
  augment=False,
285
288
  verbose=settings['verbose'])
286
289
  if model == None:
@@ -314,6 +317,8 @@ def train_test_model(src, settings, custom_model=False, custom_model_path=None):
314
317
  torch.cuda.empty_cache()
315
318
  torch.cuda.memory.empty_cache()
316
319
  gc.collect()
320
+
321
+ return model_path
317
322
 
318
323
  def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='erm', epochs=100, learning_rate=0.0001, weight_decay=0.05, amsgrad=False, optimizer_type='adamw', use_checkpoint=False, dropout_rate=0, n_jobs=20, val_loaders=None, test_loaders=None, init_weights='imagenet', intermedeate_save=None, chan_dict=None, schedule = None, loss_type='binary_cross_entropy_with_logits', gradient_accumulation=False, gradient_accumulation_steps=4, channels=['r','g','b']):
319
324
  """
@@ -348,7 +353,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
348
353
  """
349
354
 
350
355
  from .io import _save_model, _save_progress
351
- from .utils import compute_irm_penalty, calculate_loss, choose_model
356
+ from .utils import compute_irm_penalty, calculate_loss, choose_model, print_progress
352
357
 
353
358
  print(f'Train batches:{len(train_loaders)}, Validation batches:{len(val_loaders)}')
354
359
 
@@ -386,6 +391,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
386
391
  else:
387
392
  scheduler = None
388
393
 
394
+ time_ls = []
389
395
  if train_mode == 'erm':
390
396
  for epoch in range(1, epochs+1):
391
397
  model.train()
@@ -412,7 +418,13 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
412
418
  optimizer.zero_grad()
413
419
 
414
420
  avg_loss = running_loss / batch_idx
415
- print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)
421
+ #print(f'\rTrain: epoch: {epoch} batch: {batch_idx}/{len(train_loaders)} avg_loss: {avg_loss:.5f} time: {(time.time()-start_time):.5f}', end='\r', flush=True)
422
+
423
+ batch_size = len(train_loaders)
424
+ duration = time.time() - start_time
425
+ time_ls.append(duration)
426
+ metricks = f"Loss: {avg_loss:.5f}"
427
+ print_progress(files_processed=epoch, files_to_process=epochs, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type=f"Training {model_type} model", metricks=metricks)
416
428
 
417
429
  end_time = time.time()
418
430
  train_time = end_time - start_time
@@ -421,6 +433,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
421
433
  train_names = 'train'
422
434
  results_df, train_test_time = evaluate_model_performance(train_loaders, model, train_names, epoch, train_mode='erm', loss_type=loss_type)
423
435
  train_metrics_df['train_test_time'] = train_test_time
436
+
424
437
  if val_loaders != None:
425
438
  val_names = 'val'
426
439
  result, val_time = evaluate_model_performance(val_loaders, model, val_names, epoch, train_mode='erm', loss_type=loss_type)
@@ -430,6 +443,7 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
430
443
 
431
444
  results_df = pd.concat([results_df, result])
432
445
  train_metrics_df['val_time'] = val_time
446
+
433
447
  if test_loaders != None:
434
448
  test_names = 'test'
435
449
  result, test_test_time = evaluate_model_performance(test_loaders, model, test_names, epoch, train_mode='erm', loss_type=loss_type)
@@ -444,9 +458,30 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
444
458
  scheduler.step()
445
459
 
446
460
  _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
447
- clear_output(wait=True)
448
- display(results_df)
449
- _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
461
+ #clear_output(wait=True)
462
+ #display(results_df)
463
+
464
+ train_idx = f"{epoch}_train"
465
+ val_idx = f"{epoch}_val"
466
+ train_acc = results_df.loc[train_idx, 'accuracy']
467
+ neg_train_acc = results_df.loc[train_idx, 'neg_accuracy']
468
+ pos_train_acc = results_df.loc[train_idx, 'pos_accuracy']
469
+ val_acc = results_df.loc[val_idx, 'accuracy']
470
+ neg_val_acc = results_df.loc[val_idx, 'neg_accuracy']
471
+ pos_val_acc = results_df.loc[val_idx, 'pos_accuracy']
472
+ train_loss = results_df.loc[train_idx, 'loss']
473
+ train_prauc = results_df.loc[train_idx, 'prauc']
474
+ val_loss = results_df.loc[val_idx, 'loss']
475
+ val_prauc = results_df.loc[val_idx, 'prauc']
476
+
477
+ metricks = f"Train Acc: {train_acc:.5f} Val Acc: {val_acc:.5f} Train Loss: {train_loss:.5f} Val Loss: {val_loss:.5f} Train PRAUC: {train_prauc:.5f} Val PRAUC: {val_prauc:.5f}, Nc Train Acc: {neg_train_acc:.5f} Nc Val Acc: {neg_val_acc:.5f} Pc Train Acc: {pos_train_acc:.5f} Pc Val Acc: {pos_val_acc:.5f}"
478
+
479
+ batch_size = len(train_loaders)
480
+ duration = time.time() - start_time
481
+ time_ls.append(duration)
482
+ print_progress(files_processed=epoch, files_to_process=epochs, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type=f"Training {model_type} model", metricks=metricks)
483
+
484
+ model_path = _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94], channels=channels)
450
485
 
451
486
  if train_mode == 'irm':
452
487
  dummy_w = torch.nn.Parameter(torch.Tensor([1.0])).to(device)
@@ -517,9 +552,10 @@ def train_model(dst, model_type, train_loaders, train_loader_names, train_mode='
517
552
  clear_output(wait=True)
518
553
  display(results_df)
519
554
  _save_progress(dst, results_df, train_metrics_df, epoch, epochs)
520
- _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
521
- print(f'Saved model: {dst}')
522
- return model
555
+ model_path = _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_save=[0.99,0.98,0.95,0.94])
556
+ print(f'Saved model: {model_path}')
557
+
558
+ return model, model_path
523
559
 
524
560
  def visualize_saliency_map(src, model_type='maxvit', model_path='', image_size=224, channels=[1,2,3], normalize=True, class_names=None, save_saliency=False, save_dir='saliency_maps'):
525
561
 
@@ -778,8 +814,32 @@ def visualize_smooth_grad(src, model_path, target_label_idx, image_size=224, cha
778
814
  smooth_grad_image = Image.fromarray((smooth_grad_map * 255).astype(np.uint8))
779
815
  smooth_grad_image.save(os.path.join(save_dir, f'smooth_grad_{file}'))
780
816
 
781
- # Usage
782
- #src = '/path/to/images'
783
- #model_path = '/path/to/model.pth'
784
- #target_label_idx = 0 # Change this to the target class index
785
- #visualize_smooth_grad(src, model_path, target_label_idx)
817
def deep_spacr(settings=None):
    """Run the deep-learning pipeline steps that are enabled in ``settings``.

    The settings are first passed through ``deep_spacr_defaults``. Then,
    depending on the flags it contains, the function

    1. generates train/test datasets (``generate_training_dataset``),
    2. trains a model (``train_test_model``) and stores the returned path
       under ``settings['model_path']``,
    3. generates a tar dataset if ``settings['tar_path']`` does not exist
       (``generate_dataset``) and applies the model to it
       (``apply_model_to_tar``).

    Args:
        settings (dict | None): Pipeline settings. ``None`` (the default) is
            treated as an empty dict so that no state is shared between
            calls through a mutable default argument.

    Returns:
        None
    """
    from .settings import deep_spacr_defaults
    from .core import generate_training_dataset, generate_dataset, apply_model_to_tar

    # NOTE(review): block nesting below was reconstructed from a listing whose
    # indentation was lost -- confirm against the packaged module.
    settings = deep_spacr_defaults({} if settings is None else settings)

    # Remember the caller's source folder; dataset generation temporarily
    # repoints settings['src'] at the generated training data.
    src = settings['src']

    if settings['train'] or settings['test']:
        if settings['generate_training_dataset']:
            print("Generating train and test datasets ...")
            train_path, test_path = generate_training_dataset(settings)
            print(f'Generated Train set: {train_path}')
            # Was mislabelled as "Train set" for the test path.
            print(f'Generated Test set: {test_path}')
            settings['src'] = os.path.dirname(train_path)

    if settings['train_DL_model']:
        print("Training model ...")
        settings['model_path'] = train_test_model(settings)
        settings['src'] = src

    if settings['apply_model_to_dataset']:
        # Guard against an unset (None/empty) path: os.path.exists(None)
        # raises TypeError instead of returning False.
        tar_path = settings['tar_path']
        if not tar_path or not os.path.exists(tar_path):
            print("Generating dataset ...")
            settings['tar_path'] = generate_dataset(settings)

        model_path = settings['model_path']
        if model_path and os.path.exists(model_path):
            apply_model_to_tar(settings)
        else:
            # Previously this case was skipped without any feedback.
            print(f'Model not found, skipping model application: {model_path}')
@@ -27,7 +27,7 @@ class MainApp(tk.Tk):
27
27
  }
28
28
 
29
29
  self.additional_gui_apps = {
30
- "Sequencing": (lambda frame: initiate_root(self, 'sequencing'), "Analyze sequencing data."),
30
+ #"Sequencing": (lambda frame: initiate_root(self, 'sequencing'), "Analyze sequencing data."),
31
31
  "Umap": (lambda frame: initiate_root(self, 'umap'), "Generate UMAP embeddings with datapoints represented as images."),
32
32
  "Train Cellpose": (lambda frame: initiate_root(self, 'train_cellpose'), "Train custom Cellpose models."),
33
33
  "ML Analyze": (lambda frame: initiate_root(self, 'ml_analyze'), "Machine learning analysis of data."),