spacr 0.2.4__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. spacr/__init__.py +1 -11
  2. spacr/core.py +277 -349
  3. spacr/deep_spacr.py +248 -269
  4. spacr/gui.py +58 -54
  5. spacr/gui_core.py +689 -535
  6. spacr/gui_elements.py +1002 -153
  7. spacr/gui_utils.py +452 -107
  8. spacr/io.py +158 -91
  9. spacr/measure.py +199 -151
  10. spacr/plot.py +159 -47
  11. spacr/resources/font/open_sans/OFL.txt +93 -0
  12. spacr/resources/font/open_sans/OpenSans-Italic-VariableFont_wdth,wght.ttf +0 -0
  13. spacr/resources/font/open_sans/OpenSans-VariableFont_wdth,wght.ttf +0 -0
  14. spacr/resources/font/open_sans/README.txt +100 -0
  15. spacr/resources/font/open_sans/static/OpenSans-Bold.ttf +0 -0
  16. spacr/resources/font/open_sans/static/OpenSans-BoldItalic.ttf +0 -0
  17. spacr/resources/font/open_sans/static/OpenSans-ExtraBold.ttf +0 -0
  18. spacr/resources/font/open_sans/static/OpenSans-ExtraBoldItalic.ttf +0 -0
  19. spacr/resources/font/open_sans/static/OpenSans-Italic.ttf +0 -0
  20. spacr/resources/font/open_sans/static/OpenSans-Light.ttf +0 -0
  21. spacr/resources/font/open_sans/static/OpenSans-LightItalic.ttf +0 -0
  22. spacr/resources/font/open_sans/static/OpenSans-Medium.ttf +0 -0
  23. spacr/resources/font/open_sans/static/OpenSans-MediumItalic.ttf +0 -0
  24. spacr/resources/font/open_sans/static/OpenSans-Regular.ttf +0 -0
  25. spacr/resources/font/open_sans/static/OpenSans-SemiBold.ttf +0 -0
  26. spacr/resources/font/open_sans/static/OpenSans-SemiBoldItalic.ttf +0 -0
  27. spacr/resources/font/open_sans/static/OpenSans_Condensed-Bold.ttf +0 -0
  28. spacr/resources/font/open_sans/static/OpenSans_Condensed-BoldItalic.ttf +0 -0
  29. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBold.ttf +0 -0
  30. spacr/resources/font/open_sans/static/OpenSans_Condensed-ExtraBoldItalic.ttf +0 -0
  31. spacr/resources/font/open_sans/static/OpenSans_Condensed-Italic.ttf +0 -0
  32. spacr/resources/font/open_sans/static/OpenSans_Condensed-Light.ttf +0 -0
  33. spacr/resources/font/open_sans/static/OpenSans_Condensed-LightItalic.ttf +0 -0
  34. spacr/resources/font/open_sans/static/OpenSans_Condensed-Medium.ttf +0 -0
  35. spacr/resources/font/open_sans/static/OpenSans_Condensed-MediumItalic.ttf +0 -0
  36. spacr/resources/font/open_sans/static/OpenSans_Condensed-Regular.ttf +0 -0
  37. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBold.ttf +0 -0
  38. spacr/resources/font/open_sans/static/OpenSans_Condensed-SemiBoldItalic.ttf +0 -0
  39. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Bold.ttf +0 -0
  40. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-BoldItalic.ttf +0 -0
  41. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBold.ttf +0 -0
  42. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-ExtraBoldItalic.ttf +0 -0
  43. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Italic.ttf +0 -0
  44. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Light.ttf +0 -0
  45. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-LightItalic.ttf +0 -0
  46. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Medium.ttf +0 -0
  47. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-MediumItalic.ttf +0 -0
  48. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-Regular.ttf +0 -0
  49. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBold.ttf +0 -0
  50. spacr/resources/font/open_sans/static/OpenSans_SemiCondensed-SemiBoldItalic.ttf +0 -0
  51. spacr/resources/icons/logo.pdf +2786 -6
  52. spacr/resources/icons/logo_spacr.png +0 -0
  53. spacr/resources/icons/logo_spacr_1.png +0 -0
  54. spacr/sequencing.py +477 -587
  55. spacr/settings.py +217 -144
  56. spacr/utils.py +46 -46
  57. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/METADATA +46 -35
  58. spacr-0.2.8.dist-info/RECORD +100 -0
  59. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/WHEEL +1 -1
  60. spacr-0.2.4.dist-info/RECORD +0 -58
  61. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/LICENSE +0 -0
  62. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/entry_points.txt +0 -0
  63. {spacr-0.2.4.dist-info → spacr-0.2.8.dist-info}/top_level.txt +0 -0
spacr/io.py CHANGED
@@ -1,9 +1,9 @@
-import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob
+import os, re, sqlite3, gc, torch, time, random, shutil, cv2, tarfile, cellpose, glob, queue
 import numpy as np
 import pandas as pd
 import tifffile
-from PIL import Image
-from collections import defaultdict, Counter
+from PIL import Image, ImageOps
+from collections import defaultdict, Counter, deque
 from pathlib import Path
 from functools import partial
 from matplotlib.animation import FuncAnimation
@@ -17,12 +17,12 @@ import imageio.v2 as imageio2
 import matplotlib.pyplot as plt
 from io import BytesIO
 from IPython.display import display, clear_output
-from multiprocessing import Pool, cpu_count
-from torch.utils.data import Dataset
+from multiprocessing import Pool, cpu_count, Process, Queue
+from torch.utils.data import Dataset, DataLoader
 import matplotlib.pyplot as plt
 from torchvision.transforms import ToTensor
 import seaborn as sns
-
+import atexit

 from .logger import log_function_call

@@ -444,20 +444,7 @@ class NoClassDataset(Dataset):
         # Return both the image and its filename
         return img, self.filenames[index]

-class MyDataset(Dataset):
-    """
-    A custom dataset class for loading and processing image data.
-
-    Args:
-        data_dir (str): The directory path where the image data is stored.
-        loader_classes (list): A list of class names for the dataset.
-        transform (callable, optional): A function/transform to apply to the image data. Default is None.
-        shuffle (bool, optional): Whether to shuffle the dataset. Default is True.
-        pin_memory (bool, optional): Whether to pin the loaded images to memory. Default is False.
-        specific_files (list, optional): A list of specific file paths to include in the dataset. Default is None.
-        specific_labels (list, optional): A list of specific labels corresponding to the specific files. Default is None.
-    """
-
+class spacrDataset(Dataset):
     def __init__(self, data_dir, loader_classes, transform=None, shuffle=True, pin_memory=False, specific_files=None, specific_labels=None):
         self.data_dir = data_dir
         self.classes = loader_classes
@@ -466,7 +453,7 @@ class MyDataset(Dataset):
         self.pin_memory = pin_memory
         self.filenames = []
         self.labels = []
-
+
         if specific_files and specific_labels:
             self.filenames = specific_files
             self.labels = specific_labels
@@ -479,33 +466,113 @@ class MyDataset(Dataset):

         if self.shuffle:
             self.shuffle_dataset()
-
+
         if self.pin_memory:
-            self.images = [self.load_image(f) for f in self.filenames]
-
+            # Use multiprocessing to load images in parallel
+            with Pool(processes=cpu_count()) as pool:
+                self.images = pool.map(self.load_image, self.filenames)
+        else:
+            self.images = None
+
     def load_image(self, img_path):
         img = Image.open(img_path).convert('RGB')
+        img = ImageOps.exif_transpose(img)  # Handle image orientation
         return img
-
+
     def __len__(self):
         return len(self.filenames)
-
+
     def shuffle_dataset(self):
         combined = list(zip(self.filenames, self.labels))
         random.shuffle(combined)
         self.filenames, self.labels = zip(*combined)
-
+
     def get_plate(self, filepath):
-        filename = os.path.basename(filepath)  # Get just the filename from the full path
+        filename = os.path.basename(filepath)
         return filename.split('_')[0]
-
+
     def __getitem__(self, index):
+        if self.pin_memory:
+            img = self.images[index]
+        else:
+            img = self.load_image(self.filenames[index])
         label = self.labels[index]
         filename = self.filenames[index]
-        img = self.load_image(filename)
         if self.transform:
             img = self.transform(img)
         return img, label, filename
+
+class spacrDataLoader(DataLoader):
+    def __init__(self, *args, preload_batches=1, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.preload_batches = preload_batches
+        self.batch_queue = Queue(maxsize=preload_batches)
+        self.process = None
+        self.current_batch_index = 0
+        self._stop_event = False
+        self.pin_memory = kwargs.get('pin_memory', False)
+        atexit.register(self.cleanup)
+
+    def _preload_next_batches(self):
+        try:
+            for _ in range(self.preload_batches):
+                if self._stop_event:
+                    break
+                batch = next(self._iterator)
+                if self.pin_memory:
+                    batch = self._pin_memory_batch(batch)
+                self.batch_queue.put(batch)
+        except StopIteration:
+            pass
+
+    def _start_preloading(self):
+        if self.process is None or not self.process.is_alive():
+            self._iterator = iter(super().__iter__())
+            if not self.pin_memory:
+                self.process = Process(target=self._preload_next_batches)
+                self.process.start()
+            else:
+                self._preload_next_batches()  # Directly load if pin_memory is True
+
+    def _pin_memory_batch(self, batch):
+        if isinstance(batch, (list, tuple)):
+            return [b.pin_memory() if isinstance(b, torch.Tensor) else b for b in batch]
+        elif isinstance(batch, torch.Tensor):
+            return batch.pin_memory()
+        else:
+            return batch
+
+    def __iter__(self):
+        self._start_preloading()
+        return self
+
+    def __next__(self):
+        if self.process and not self.process.is_alive() and self.batch_queue.empty():
+            raise StopIteration
+
+        try:
+            if self.pin_memory:
+                next_batch = self.batch_queue.get(timeout=60)
+            else:
+                next_batch = self.batch_queue.get(timeout=60)
+            self.current_batch_index += 1
+
+            # Start preloading the next batches
+            if self.batch_queue.qsize() < self.preload_batches:
+                self._start_preloading()
+
+            return next_batch
+        except queue.Empty:
+            raise StopIteration
+
+    def cleanup(self):
+        self._stop_event = True
+        if self.process and self.process.is_alive():
+            self.process.terminate()
+            self.process.join()
+
+    def __del__(self):
+        self.cleanup()

 class NoClassDataset(Dataset):
     def __init__(self, data_dir, transform=None, shuffle=True, load_to_memory=False):
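The new spacrDataset/spacrDataLoader pair replaces MyDataset: the dataset can preload every image through a multiprocessing Pool when pin_memory is set, and the loader stages upcoming batches in a Queue filled by a background Process. A minimal usage sketch, assuming both classes are importable from spacr.io as defined above; the directory path and class names are hypothetical:

from torchvision import transforms
from spacr.io import spacrDataset, spacrDataLoader

dataset = spacrDataset(
    data_dir='data/train',            # hypothetical image directory
    loader_classes=['nc', 'pc'],      # hypothetical class names
    transform=transforms.ToTensor(),
    shuffle=True,
    pin_memory=False,                 # True preloads all images via a Pool
)

# preload_batches controls how many batches are staged ahead of consumption.
loader = spacrDataLoader(dataset, batch_size=32, preload_batches=2)
for imgs, labels, filenames in loader:
    pass  # training or inference step goes here

Since __getitem__ returns (img, label, filename), each batch unpacks into collated images, labels, and a tuple of filenames.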
@@ -588,20 +655,20 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
     regular_expression = re.compile(regex)
     images_by_key = defaultdict(list)
     stack_path = os.path.join(src, 'stack')
+    files_processed = 0
     if not os.path.exists(stack_path) or (os.path.isdir(stack_path) and len(os.listdir(stack_path)) == 0):
         all_filenames = [filename for filename in os.listdir(src) if filename.endswith(img_format)]
-        print(f'All_files:{len(all_filenames)} in {src}')
+        print(f'All_files: {len(all_filenames)} in {src}')
         time_ls = []
-        processed = 0
-        for i in range(0, len(all_filenames), batch_size):
+
+        for idx in range(0, len(all_filenames), batch_size):
             start = time.time()
-            batch_filenames = all_filenames[i:i+batch_size]
-            processed += len(batch_filenames)
+            batch_filenames = all_filenames[idx:idx+batch_size]
             for filename in batch_filenames:
                 images_by_key = _extract_filename_metadata(batch_filenames, src, images_by_key, regular_expression, metadata_type, pick_slice, skip_mode)
-
+
             if pick_slice:
-                for key in images_by_key:
+                for i, key in enumerate(images_by_key):
                     plate, well, field, channel, mode = key
                     max_intensity_slice = max(images_by_key[key], key=lambda x: np.percentile(x, 90))
                     mip_image = Image.fromarray(max_intensity_slice)
@@ -609,21 +676,19 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
                     os.makedirs(output_dir, exist_ok=True)
                     output_filename = f'{plate}_{well}_{field}.tif'
                     output_path = os.path.join(output_dir, output_filename)
-
-                    if os.path.exists(output_path):
-                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
-                    else:
-                        mip_image.save(output_path)
-
+                    files_processed += 1
                     stop = time.time()
                     duration = stop - start
                     time_ls.append(duration)
-                    files_processed = processed
                     files_to_process = len(all_filenames)
                     print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')

+                    if os.path.exists(output_path):
+                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
+                    else:
+                        mip_image.save(output_path)
             else:
-                for key, images in images_by_key.items():
+                for i, (key, images) in enumerate(images_by_key.items()):
                     mip = np.max(np.stack(images), axis=0)
                     mip_image = Image.fromarray(mip)
                     plate, well, field, channel = key[:4]
@@ -631,18 +696,17 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
                     os.makedirs(output_dir, exist_ok=True)
                     output_filename = f'{plate}_{well}_{field}.tif'
                     output_path = os.path.join(output_dir, output_filename)
-
-                    if os.path.exists(output_path):
-                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
-                    else:
-                        mip_image.save(output_path)
+                    files_processed += 1
                     stop = time.time()
                     duration = stop - start
                     time_ls.append(duration)
-                    files_processed = processed
                     files_to_process = len(all_filenames)
                     print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type='Preprocessing filenames')

+                    if os.path.exists(output_path):
+                        print(f'WARNING: A file with the same name already exists at location {output_filename}')
+                    else:
+                        mip_image.save(output_path)
             images_by_key.clear()

     # Move original images to a new directory
@@ -656,6 +720,7 @@ def _rename_and_organize_image_files(src, regex, batch_size=100, pick_slice=Fals
                 print(f'WARNING: A file with the same name already exists at location {move}')
             else:
                 shutil.move(os.path.join(src, filename), move)
+    files_processed = 0
     return

 def _merge_file(chan_dirs, stack_dir, file_name):
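Both save branches above reduce a group of slices to one image before writing {plate}_{well}_{field}.tif: with pick_slice=True the single brightest slice is kept (ranked by its 90th-percentile intensity), otherwise a maximum-intensity projection is computed across slices. The two reductions in isolation, sketched with synthetic arrays:

import numpy as np
from PIL import Image

# Hypothetical stand-in for images_by_key[key]: 2D slices for one field/channel
slices = [np.random.randint(0, 65535, (64, 64), dtype=np.uint16) for _ in range(5)]

# pick_slice=True: keep the slice with the highest 90th-percentile intensity
brightest = max(slices, key=lambda x: np.percentile(x, 90))

# pick_slice=False: per-pixel maximum across all slices (a MIP)
mip = np.max(np.stack(slices), axis=0)

Image.fromarray(mip).save('plate1_A01_f1.tif')  # hypothetical output name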
@@ -975,7 +1040,7 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
             time_ls.append(duration)
             files_processed = i+1
             files_to_process = time_stack_path_lists
-            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=None, operation_type="Concatinating")
+            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type="Concatinating")
         stack = np.stack(stack_region)
         save_loc = os.path.join(channel_stack_loc, f'{name}.npz')
         np.savez(save_loc, data=stack, filenames=filenames_region)
@@ -1006,7 +1071,7 @@ def _concatenate_channel(src, channels, randomize=True, timelapse=False, batch_s
             time_ls.append(duration)
             files_processed = i+1
             files_to_process = nr_files
-            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=None, operation_type="Concatinating")
+            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=batch_size, operation_type="Concatinating")
             if (i+1) % batch_size == 0 or i+1 == nr_files:
                 unique_shapes = {arr.shape[:-1] for arr in stack_ls}
                 if len(unique_shapes) > 1:
@@ -1104,7 +1169,7 @@ def _normalize_img_batch(stack, channels, save_dtype, settings):
         time_ls.append(duration)
         files_processed = i+1
         files_to_process = len(channels)
-        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=None, operation_type=f"Normalizing: Channel: {channel}")
+        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f"Normalizing")

     return normalized_stack.astype(save_dtype)

@@ -1151,7 +1216,6 @@ def concatenate_and_normalize(src, channels, save_dtype=np.float32, settings={})
             parts = file.split('_')
             name = parts[0] + '_' + parts[1] + '_' + parts[2]
             array = np.load(path)
-            #array = np.take(array, channels, axis=2)
             stack_region.append(array)
             filenames_region.append(os.path.basename(path))
             stop = time.time()
@@ -1159,7 +1223,7 @@ def concatenate_and_normalize(src, channels, save_dtype=np.float32, settings={})
             time_ls.append(duration)
             files_processed = i+1
             files_to_process = len(time_stack_path_lists)
-            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=None, operation_type="Concatinating")
+            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Concatinating")
         stack = np.stack(stack_region)

         normalized_stack = _normalize_img_batch(stack=stack,
@@ -1188,18 +1252,18 @@ def concatenate_and_normalize(src, channels, save_dtype=np.float32, settings={})
         stack_ls = []
         filenames_batch = []
         time_ls = []
+        files_processed = 0
         for i, path in enumerate(paths):
             start = time.time()
             array = np.load(path)
-            #array = np.take(array, channels, axis=2)
             stack_ls.append(array)
             filenames_batch.append(os.path.basename(path))
             stop = time.time()
             duration = stop - start
             time_ls.append(duration)
-            files_processed = i+1
+            files_processed += 1
             files_to_process = nr_files
-            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=None, batch_size=None, operation_type="Concatinating")
+            print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Concatinating")

             if (i + 1) % settings['batch_size'] == 0 or i + 1 == nr_files:
                 unique_shapes = {arr.shape[:-1] for arr in stack_ls}
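The recurring fix in these hunks is that print_progress now receives the accumulated time_ls (and, where relevant, the real batch_size) instead of None, so the progress line can average actual per-file durations. The standardized loop skeleton, with a hypothetical process(path) step; print_progress is assumed to live in spacr.utils with the signature used above:

import time
from spacr.utils import print_progress  # assumed location of the helper

def process(path):
    pass  # hypothetical per-file work

paths = ['f1.npz', 'f2.npz', 'f3.npz']  # hypothetical inputs
time_ls = []
files_processed = 0
for path in paths:
    start = time.time()
    process(path)
    time_ls.append(time.time() - start)  # record every file's duration
    files_processed += 1
    # Passing time_ls (not None) lets the helper report a per-file average.
    print_progress(files_processed, len(paths), n_jobs=1, time_ls=time_ls,
                   batch_size=None, operation_type="Concatinating")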
@@ -1350,12 +1414,12 @@ def _normalize_stack(src, backgrounds=[100, 100, 100], remove_backgrounds=[False
                 average_time = np.mean(time_ls) if len(time_ls) > 0 else 0
                 print(f'channels:{chan_index}/{stack.shape[-1] - 1}, arrays:{array_index + 1}/{single_channel.shape[0]}, Signal:{upper:.1f}, noise:{lower:.1f}, Signal-to-noise:{average_stnr:.1f}, Time/channel:{average_time:.2f}sec')

-        stop = time.time()
-        duration = stop - start
-        time_ls.append(duration)
-        files_processed = file_index + 1
-        files_to_process = len(paths)
-        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Normalizing")
+        #stop = time.time()
+        #duration = stop - start
+        #time_ls.append(duration)
+        #files_processed = file_index + 1
+        #files_to_process = len(paths)
+        #print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Normalizing")

         normalized_stack[:, :, :, channel] = arr_2d_normalized

@@ -1405,12 +1469,12 @@ def _normalize_timelapse(src, lower_percentile=2, save_dtype=np.float32):

             print(f'channels:{chan_index+1}/{stack.shape[-1]}, arrays:{array_index+1}/{single_channel.shape[0]}', end='\r')

-        stop = time.time()
-        duration = stop - start
-        time_ls.append(duration)
-        files_processed = file_index+1
-        files_to_process = len(paths)
-        print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Normalizing")
+        #stop = time.time()
+        #duration = stop - start
+        #time_ls.append(duration)
+        #files_processed = file_index+1
+        #files_to_process = len(paths)
+        #print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Normalizing")

         save_loc = os.path.join(output_fldr, f'{name}_norm_timelapse.npz')
         np.savez(save_loc, data=normalized_stack, filenames=filenames)
@@ -1620,8 +1684,8 @@ def preprocess_img_data(settings):
                                            save_dtype=np.float32,
                                            settings=settings)

-    if plot:
-        _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)
+    #if plot:
+    #    _plot_4D_arrays(src+'/norm_channel_stack', nr_npz=1, nr=nr)

     return settings, src

@@ -1951,7 +2015,7 @@ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_di
     all_imgs = len(os.listdir(reference_folder))
     time_ls = []
     # Iterate through each file in the reference folder
-    for filename in os.listdir(reference_folder):
+    for idx, filename in enumerate(os.listdir(reference_folder)):
         start = time.time()
         stack_ls = []
         if filename.endswith('.npy'):
@@ -2012,7 +2076,7 @@ def _load_and_concatenate_arrays(src, channels, cell_chann_dim, nucleus_chann_di
         stop = time.time()
         duration = stop - start
         time_ls.append(duration)
-        files_processed = count
+        files_processed = idx+1
         files_to_process = all_imgs
         print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type="Merging Arrays")

@@ -2295,18 +2359,27 @@ def _save_model(model, model_type, results_df, dst, epoch, epochs, intermedeate_

     def save_model_at_threshold(threshold, epoch, suffix=""):
         percentile = str(threshold * 100)
-        print(f'\rfound: {percentile}% accurate model')#, end='\r', flush=True)
-        torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth')
+        print(f'Found: {percentile}% accurate model')
+        model_path = f'{dst}/{model_type}_epoch_{str(epoch)}{suffix}_acc_{percentile}_channels_{channels_str}.pth'
+        torch.save(model, model_path)
+        return model_path

     if epoch % 100 == 0 or epoch == epochs:
-        torch.save(model, f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth')
+        model_path = f'{dst}/{model_type}_epoch_{str(epoch)}_channels_{channels_str}.pth'
+        torch.save(model, model_path)
+        return model_path

     for threshold in intermedeate_save:
-        if results_df['neg_accuracy'].dropna().mean() >= threshold and results_df['pos_accuracy'].dropna().mean() >= threshold:
-            save_model_at_threshold(threshold, epoch)
-            break # Ensure we only save for the highest matching threshold
+        if results_df['neg_accuracy'] >= threshold and results_df['pos_accuracy'] >= threshold:
+            print(f"Nc class accuracy: {results_df['neg_accuracy']} Pc class Accuracy: {results_df['pos_accuracy']}")
+            model_path = save_model_at_threshold(threshold, epoch)
+            break
+    else:
+        model_path = None
+
+    return model_path

-def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
+def _save_progress(dst, results_df, result_type='train'):
     """
     Save the progress of the classification model.

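_save_model now returns the path of whichever checkpoint it wrote (or None), and the threshold loop compares results_df['neg_accuracy'] directly to each threshold rather than averaging a column, which assumes the caller now supplies scalar per-epoch accuracies. A sketch of the for/else selection with hypothetical values standing in for results_df:

intermedeate_save = [0.99, 0.98, 0.95]  # thresholds, highest first (name as in the diff)
results = {'neg_accuracy': 0.97, 'pos_accuracy': 0.96}  # hypothetical scalar accuracies

for threshold in intermedeate_save:
    if results['neg_accuracy'] >= threshold and results['pos_accuracy'] >= threshold:
        model_path = f'model_acc_{threshold * 100}.pth'  # stands in for save_model_at_threshold
        break
else:
    model_path = None  # for/else: runs only when no threshold matched

print(model_path)  # model_acc_95.0.pth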
@@ -2320,18 +2393,13 @@ def _save_progress(dst, results_df, train_metrics_df, epoch, epochs):
     """
     # Save accuracy, loss, PRAUC
     os.makedirs(dst, exist_ok=True)
-    results_path = os.path.join(dst, 'acc_loss_prauc.csv')
+    results_path = os.path.join(dst, f'{result_type}.csv')
     if not os.path.exists(results_path):
         results_df.to_csv(results_path, index=True, header=True, mode='w')
     else:
         results_df.to_csv(results_path, index=True, header=False, mode='a')
-
-    training_metrics_path = os.path.join(dst, 'training_metrics.csv')
-    if not os.path.exists(training_metrics_path):
-        train_metrics_df.to_csv(training_metrics_path, index=True, header=True, mode='w')
-    else:
-        train_metrics_df.to_csv(training_metrics_path, index=True, header=False, mode='a')
-    if epoch == epochs:
+
+    if result_type == 'train':
         read_plot_model_stats(results_path, save=True)
     return

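The simplified _save_progress writes a single {result_type}.csv, emitting the header only when the file is first created and appending header-less rows afterwards. The pattern in isolation, with a hypothetical metrics frame:

import os
import pandas as pd

def append_results(dst, results_df, result_type='train'):
    # Header once on creation; later calls append rows without repeating it.
    os.makedirs(dst, exist_ok=True)
    results_path = os.path.join(dst, f'{result_type}.csv')
    if not os.path.exists(results_path):
        results_df.to_csv(results_path, index=True, header=True, mode='w')
    else:
        results_df.to_csv(results_path, index=True, header=False, mode='a')
    return results_path

epoch_df = pd.DataFrame({'loss': [0.31], 'accuracy': [0.88]})  # hypothetical metrics
append_results('results', epoch_df)  # creates results/train.csv with a header
append_results('results', epoch_df)  # appends a second row, no header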
@@ -2550,7 +2618,6 @@ def _read_mask(mask_path):
     mask = img_as_uint(mask)
     return mask

-
 def convert_numpy_to_tiff(folder_path, limit=None):
     """
     Converts all numpy files in a folder to TIFF format and saves them in a subdirectory 'tiff'.