spacr 1.0.7__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/utils.py CHANGED
@@ -77,7 +77,7 @@ def _generate_mask_random_cmap(mask):
  """
  Generate a random colormap based on the unique labels in the given mask.

- Parameters:
+ Args:
  mask (ndarray): The mask array containing unique labels.

  Returns:
@@ -92,7 +92,18 @@ def _generate_mask_random_cmap(mask):
  return random_cmap

  def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
+ """
+ Insert image paths and metadata into a SQLite database.
+
+ Args:
+ img_paths (list of str): Image file paths to insert.
+ settings (dict): Configuration dictionary. Must contain a 'timelapse' key (bool).
+ source_folder (str): Folder containing the SQLite database.
+ crop_mode (str): One of 'cell', 'nucleus', 'pathogen', or 'cytoplasm'.

+ Raises:
+ sqlite3.OperationalError: On database connection or write errors.
+ """
  png_df = pd.DataFrame(img_paths, columns=['png_path'])

  png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
@@ -129,6 +140,17 @@ def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
  traceback.print_exc()

  def activation_maps_to_database(img_paths, source_folder, settings):
+ """
+ Store image paths and metadata in a SQLite database.
+
+ Args:
+ img_paths (list of str): Paths to image files.
+ source_folder (str): Folder where the database is stored.
+ settings (dict): Must include 'dataset' and 'cam_type' keys.
+
+ Raises:
+ sqlite3.OperationalError: On database errors.
+ """
  from .io import _create_database

  png_df = pd.DataFrame(img_paths, columns=['png_path'])
@@ -152,6 +174,18 @@ def activation_maps_to_database(img_paths, source_folder, settings):
  traceback.print_exc()

  def activation_correlations_to_database(df, img_paths, source_folder, settings):
+ """
+ Save activation correlation data to a SQLite database.
+
+ Args:
+ df (pd.DataFrame): DataFrame with correlation data. Must include 'file_name'.
+ img_paths (list): List of image file paths.
+ source_folder (str): Folder where the database is stored.
+ settings (dict): Must include 'dataset' and 'cam_type' keys.
+
+ Raises:
+ sqlite3.OperationalError: On database errors.
+ """
  from .io import _create_database

  png_df = pd.DataFrame(img_paths, columns=['png_path'])
@@ -341,6 +375,39 @@ def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_
  return result_dict

  def save_settings(settings, name='settings', show=False):
+ """
+ Save a dictionary of settings to a CSV file.
+
+ This function takes a dictionary of settings, processes it, and saves it
+ as a CSV file in a specified directory. It also provides an option to
+ display the settings as a DataFrame.
+
+ Args:
+ settings (dict): A dictionary containing the settings to be saved.
+ name (str, optional): The base name for the output CSV file. Defaults to 'settings'.
+ show (bool, optional): If True, displays the settings as a DataFrame. Defaults to False.
+
+ Raises:
+ KeyError: If the 'src' key is not present in the settings dictionary.
+
+ Notes:
+ - If the 'src' key in the settings dictionary is a list, the first element
+ is used as the source directory, and the file name is appended with '_list'.
+ - If the 'test_mode' key exists in the settings dictionary, it is set to False.
+ - If the 'plot' key exists in the settings dictionary, it is set to False.
+ - A directory named 'settings' is created inside the source directory if it does not exist.
+ - The settings are saved as a CSV file in the 'settings' directory.
+ - The file path where the settings are saved is printed.
+
+ Example:
+ >>> settings = {
+ ... 'src': '/path/to/source',
+ ... 'test_mode': True,
+ ... 'plot': True,
+ ... 'param1': 42
+ ... }
+ >>> save_settings(settings, name='experiment1', show=True)
+ """

  settings_2 = settings.copy()

@@ -367,6 +434,33 @@ def save_settings(settings, name='settings', show=False):
  settings_df.to_csv(settings_csv, index=False)

  def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batch_size=None, operation_type=""):
+ """
+ Prints the progress of a file processing operation along with estimated time information.
+
+ Args:
+ files_processed (int or list): The number of files processed so far. If a list is provided,
+ its unique length will be used.
+ files_to_process (int or list): The total number of files to process. If a list is provided,
+ its unique length will be used.
+ n_jobs (int): The number of parallel jobs being used for processing.
+ time_ls (list, optional): A list of time durations for processing batches or files. Used to
+ calculate average time and estimated time left. Defaults to None.
+ batch_size (int or list, optional): The size of each batch being processed. If a list is
+ provided, its length will be used. Defaults to None.
+ operation_type (str, optional): A string describing the type of operation being performed.
+ Defaults to an empty string.
+
+ Returns:
+ None: This function prints the progress and time information to the console.
+
+ Notes:
+ - If `time_ls` is provided, the function calculates the average time per batch or file and
+ estimates the time remaining for the operation.
+ - If `batch_size` is provided, the function calculates the average time per image within
+ a batch.
+ - Handles cases where inputs are lists or non-integer types by converting them to integers
+ or calculating their lengths.
+ """
  if isinstance(files_processed, list):
  files_processed = len(set(files_processed))
  if isinstance(files_to_process, list):
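Note: the docstring above describes the ETA arithmetic only loosely. A minimal sketch of one plausible implementation (hypothetical values and formula, not the exact code in spacr):

    time_ls = [2.0, 2.2, 1.8]                 # seconds per processed batch
    n_jobs, batch_size = 4, 16                # parallel jobs, images per batch
    files_processed, files_to_process = 48, 160

    avg_time = sum(time_ls) / len(time_ls)    # average time per batch
    batches_left = (files_to_process - files_processed) / (batch_size * n_jobs)
    time_left = batches_left * avg_time       # estimated seconds remaining
    print(f'Progress: {files_processed}/{files_to_process}, '
          f'avg {avg_time / batch_size:.3f}s/image, ETA {time_left:.1f}s')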
@@ -399,6 +493,20 @@ def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batc
  print(f'Progress: {files_processed}/{files_to_process}, operation_type: {operation_type}, {time_info}')

  def reset_mp():
+ """
+ Resets the multiprocessing start method based on the operating system.
+ On Windows, the start method is set to 'spawn' if it is not already set.
+ On Linux and macOS (Darwin), the start method is set to 'fork' if it is not already set.
+ This function ensures compatibility with the multiprocessing module by
+ enforcing the appropriate start method for the current platform.
+ Note:
+ - The function uses `get_start_method` to retrieve the current start method.
+ - The `set_start_method` function is called with `force=True` to override
+ the existing start method if necessary.
+ Raises:
+ ValueError: If an invalid start method is encountered or if the start
+ method cannot be set for some reason.
+ """
  current_method = get_start_method()
  system = platform.system()
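The docstring spells out the platform rules; a self-contained sketch of the behaviour it describes (the function body continues in the next hunk):

    import platform
    from multiprocessing import get_start_method, set_start_method

    def reset_mp_sketch():
        current = get_start_method(allow_none=True)
        system = platform.system()
        if system == 'Windows' and current != 'spawn':
            set_start_method('spawn', force=True)   # Windows requires 'spawn'
        elif system in ('Linux', 'Darwin') and current != 'fork':
            set_start_method('fork', force=True)    # POSIX platforms use 'fork' here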
@@ -410,6 +518,23 @@ def reset_mp():
  set_start_method('fork', force=True)

  def is_multiprocessing_process(process):
+ """
+ Check if the given process is a multiprocessing process.
+
+ This function examines the command-line arguments of the provided process
+ to determine if it is associated with Python's multiprocessing module.
+
+ Args:
+ process (psutil.Process): A psutil Process object representing the process to check.
+
+ Returns:
+ bool: True if the process is a multiprocessing process, False otherwise.
+
+ Raises:
+ psutil.NoSuchProcess: If the process no longer exists.
+ psutil.AccessDenied: If access to the process information is denied.
+ psutil.ZombieProcess: If the process is a zombie process.
+ """
  """ Check if the process is a multiprocessing process. """
  try:
  for cmd in process.cmdline():
@@ -420,6 +545,24 @@ def is_multiprocessing_process(process):
  return False

  def close_file_descriptors():
+ """
+ Closes all open file descriptors starting from 3 up to the soft limit
+ of the maximum number of file descriptors allowed for the process.
+
+ This function is useful for cleaning up resources by ensuring that
+ no unnecessary file descriptors remain open. It skips standard input,
+ output, and error (file descriptors 0, 1, and 2).
+
+ Exceptions during the closing of file descriptors are caught and ignored.
+
+ Note:
+ The function uses the `resource` module to retrieve the soft limit
+ for the maximum number of file descriptors.
+
+ Raises:
+ OSError: If an error occurs while closing a file descriptor, it is
+ caught and ignored.
+ """
  """ Close file descriptors and shared memory objects. """
  import resource
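For reference, the pattern the docstring describes (close descriptors 3 up to the soft limit, ignoring errors) looks like this sketch:

    import os
    import resource

    def close_fds_sketch():
        soft_limit, _ = resource.getrlimit(resource.RLIMIT_NOFILE)
        for fd in range(3, soft_limit):     # skip stdin/stdout/stderr (0, 1, 2)
            try:
                os.close(fd)
            except OSError:
                pass                        # descriptor was not open; ignore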
@@ -431,6 +574,23 @@ def close_file_descriptors():
  pass

  def close_multiprocessing_processes():
+ """
+ Close all multiprocessing processes and clean up associated resources.
+ This function iterates through all running processes and terminates any
+ that are identified as multiprocessing processes, excluding the current
+ process. It waits for up to 5 seconds for each process to terminate and
+ logs the termination status. Additionally, it handles exceptions that may
+ occur during the termination process, such as access denial or the process
+ no longer existing.
+ After terminating the processes, it ensures that any open file descriptors
+ are properly closed.
+ Note:
+ This function relies on the `psutil` library to inspect and manage
+ processes, and assumes the existence of helper functions:
+ - `is_multiprocessing_process(proc)`: Determines if a process is a
+ multiprocessing process.
+ - `close_file_descriptors()`: Closes any open file descriptors.
+ """
  """ Close all multiprocessing processes. """
  current_pid = os.getpid()
  for proc in psutil.process_iter(['pid', 'cmdline']):
@@ -452,6 +612,27 @@ def close_multiprocessing_processes():
  close_file_descriptors()

  def check_mask_folder(src,mask_fldr):
+ """
+ Checks the status of a mask folder within a given source directory.
+ This function verifies whether the specified mask folder exists and
+ whether the number of mask files matches the number of stack files
+ in the corresponding stack folder. It returns a boolean indicating
+ whether further processing is needed.
+ Args:
+ src (str): The source directory containing the 'masks' and 'stack' subdirectories.
+ mask_fldr (str): The name of the mask folder to check within the 'masks' directory.
+ Returns:
+ bool:
+ - True if the mask folder does not exist or if the number of mask files
+ does not match the number of stack files.
+ - False if the mask folder exists and all masks have been generated.
+ Side Effects:
+ Prints a message if all masks have been generated for the specified mask folder.
+ Example:
+ >>> check_mask_folder('/home/user/data', 'experiment1')
+ All masks have been generated for experiment1
+ False
+ """

  mask_folder = os.path.join(src,'masks',mask_fldr)
  stack_folder = os.path.join(src,'stack')
@@ -469,6 +651,17 @@ def check_mask_folder(src,mask_fldr):
  return True

  def smooth_hull_lines(cluster_data):
+ """
+ Smooths the convex hull of a set of 2D points using spline interpolation.
+
+ Args:
+ cluster_data (numpy.ndarray): A 2D array of shape (n_points, 2) representing
+ the coordinates of the points in the cluster.
+
+ Returns:
+ tuple: Two 1D numpy arrays representing the x and y coordinates of the smoothed
+ convex hull, respectively. Each array contains 100 points.
+ """
  hull = ConvexHull(cluster_data)
  # Extract vertices of the hull
  vertices = hull.points[hull.vertices]
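A runnable sketch of the smoothing the docstring promises, assuming periodic spline interpolation (scipy's splprep/splev) over the closed hull:

    import numpy as np
    from scipy.spatial import ConvexHull
    from scipy.interpolate import splprep, splev

    def smooth_hull_sketch(points):
        hull = ConvexHull(points)
        verts = points[hull.vertices]
        closed = np.vstack([verts, verts[:1]])        # close the polygon
        tck, _ = splprep([closed[:, 0], closed[:, 1]], s=0, per=True)
        u = np.linspace(0, 1, 100)                    # 100 smoothed points
        return splev(u, tck)                          # (x, y) arrays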
@@ -498,6 +691,27 @@ def _gen_rgb_image(image, channels):
  return rgb_image

  def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
+ """
+ Generate outlines for specified mask dimensions and overlay them onto an RGB image.
+
+ This function processes a multi-dimensional mask to extract contours for each specified
+ dimension, creates outlines with specified thickness, and overlays these outlines onto
+ the provided RGB image using the given colors.
+
+ Args:
+ image (numpy.ndarray): The input image containing multi-dimensional masks.
+ rgb_image (numpy.ndarray): The RGB image onto which the outlines will be overlaid.
+ mask_dims (list[int]): A list of dimensions (indices) in the mask to process.
+ outline_colors (list[tuple[int, int, int]]): A list of RGB color tuples for the outlines.
+ outline_thickness (int): The thickness of the outlines to be drawn.
+
+ Returns:
+ tuple:
+ - overlayed_image (numpy.ndarray): The RGB image with outlines overlaid.
+ - outlines (list[numpy.ndarray]): A list of binary masks representing the outlines
+ for each processed dimension.
+ - image (numpy.ndarray): The original input image (unchanged).
+ """
  outlines = []
  overlayed_image = rgb_image.copy()

@@ -554,6 +768,20 @@ def _convert_cq1_well_id(well_id):
  return well_format

  def _get_cellpose_batch_size():
+ """
+ Determines the appropriate batch size for Cellpose based on the available GPU's VRAM.
+
+ This function checks if CUDA is available and retrieves the VRAM (Video RAM) of the first GPU device.
+ Based on the VRAM size, it assigns a batch size for Cellpose processing. If CUDA is not available
+ or an error occurs, a default batch size of 8 is returned.
+
+ Returns:
+ int: The batch size for Cellpose. Possible values are:
+ - 8 (default or if VRAM < 8 GB)
+ - 16 (if 8 GB <= VRAM < 12 GB)
+ - 48 (if 12 GB <= VRAM < 24 GB)
+ - 96 (if VRAM >= 24 GB)
+ """
  try:
  # Check if CUDA is available
  if torch.cuda.is_available():
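The VRAM-to-batch-size table in the docstring, written out as a sketch (thresholds taken from the docstring; the helper name is illustrative):

    import torch

    def cellpose_batch_size_sketch(default=8):
        try:
            if torch.cuda.is_available():
                vram_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
                if vram_gb < 8:
                    return 8
                elif vram_gb < 12:
                    return 16
                elif vram_gb < 24:
                    return 48
                return 96
        except Exception:
            pass
        return default      # no CUDA device, or the query failed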
@@ -576,6 +804,32 @@ def _get_cellpose_batch_size():
  return 8

  def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager'):
+ """
+ Extracts metadata from a list of filenames based on a provided regular expression
+ and organizes the files into a dictionary grouped by metadata keys.
+ Args:
+ filenames (list of str): A list of filenames to process.
+ src (str): The source directory containing the files.
+ regular_expression (re.Pattern): A compiled regular expression used to extract metadata
+ from the filenames. The regex should define named groups such as 'plateID', 'wellID',
+ 'fieldID', 'chanID', 'timeID', and 'sliceID' as applicable.
+ metadata_type (str, optional): The type of metadata to process. Defaults to 'cellvoyager'.
+ If set to 'cq1', well IDs will be converted using `_convert_cq1_well_id`.
+ Returns:
+ defaultdict: A dictionary where keys are tuples of metadata values
+ (plate, well, field, channel, timeID, sliceID), and values are lists of file paths
+ corresponding to those metadata keys.
+ Raises:
+ IndexError: If the filename matches the regex but does not contain the expected groups.
+ Prints warnings for filenames that do not match the regex or fail to extract metadata.
+ Notes:
+ - The function assumes that the provided regular expression includes named groups
+ for extracting metadata.
+ - If a metadata group (e.g., 'timeID' or 'sliceID') is not present in the regex,
+ its value will be set to `None`.
+ - The function handles numeric metadata values by converting them to integers
+ where applicable.
+ """

  images_by_key = defaultdict(list)
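A hypothetical regex with the named groups the docstring lists (the pattern and filename below are illustrative, not spacr's shipped defaults):

    import re

    regex = re.compile(
        r'(?P<plateID>plate\d+)_(?P<wellID>[A-P]\d{2})_T(?P<timeID>\d+)'
        r'F(?P<fieldID>\d+).*C(?P<chanID>\d+)\.tif'
    )
    m = regex.match('plate1_C03_T0001F001L01A01Z01C02.tif')
    if m:
        key = (m.group('plateID'), m.group('wellID'),
               int(m.group('fieldID')), int(m.group('chanID')))
        print(key)    # ('plate1', 'C03', 1, 2)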
@@ -635,7 +889,7 @@ def mask_object_count(mask):
  """
  Counts the number of objects in a given mask.

- Parameters:
+ Args:
  - mask: numpy.ndarray. The mask containing object labels.

  Returns:
@@ -777,7 +1031,7 @@ def normalize_to_dtype(array, p1=2, p2=98, percentile_list=None, new_dtype=None)
  """
  Normalize each image in the stack to its own percentiles.

- Parameters:
+ Args:
  - array: numpy array
  The input stack to be normalized.
  - p1: int, optional
@@ -895,7 +1149,7 @@ def _find_bounding_box(crop_mask, _id, buffer=10):
  """
  Find the bounding box coordinates for a given object ID in a crop mask.

- Parameters:
+ Args:
  crop_mask (ndarray): The crop mask containing object IDs.
  _id (int): The object ID to find the bounding box for.
  buffer (int, optional): The buffer size to add to the bounding box coordinates. Defaults to 10.
@@ -1093,7 +1347,7 @@ def _get_percentiles(array, p1=2, p2=98):
  """
  Calculate the percentiles of each image in the given array.

- Parameters:
+ Args:
  - array: numpy.ndarray
  The input array containing the images.
  - q1: float, optional
@@ -1125,7 +1379,7 @@ def _crop_center(img, cell_mask, new_width, new_height):
  """
  Crop the image around the center of the cell mask.

- Parameters:
+ Args:
  - img: numpy.ndarray
  The input image.
  - cell_mask: numpy.ndarray
@@ -1195,6 +1449,35 @@ def _get_diam(mag, obj):
  return int(diamiter)

  def _get_object_settings(object_type, settings):
+ """
+ Generate and return a dictionary of settings for a specific object type.
+ This function configures object-specific settings based on the provided
+ `object_type` and `settings`. It supports three object types: 'cell',
+ 'nucleus', and 'pathogen'. If an unsupported object type is provided,
+ a message is printed, and no settings are returned.
+ Args:
+ object_type (str): The type of object to configure settings for.
+ Supported values are 'cell', 'nucleus', and 'pathogen'.
+ settings (dict): A dictionary containing global settings, which may include:
+ - 'magnification' (float): Magnification level used to calculate diameter.
+ - 'nucleus_channel' (optional): Specifies if a nucleus channel is present.
+ - 'cell_restore_type' (optional): Restore type for cell objects.
+ - 'nucleus_restore_type' (optional): Restore type for nucleus objects.
+ - 'pathogen_restore_type' (optional): Restore type for pathogen objects.
+ - 'merge_pathogens' (bool): Whether to merge pathogens.
+ - 'verbose' (bool): Whether to print the generated settings.
+ Returns:
+ dict: A dictionary containing the configured settings for the specified object type.
+ The dictionary includes keys such as 'diameter', 'minimum_size', 'maximum_size',
+ 'merge', 'resample', 'remove_border_objects', 'model_name', 'filter_size',
+ 'filter_intensity', and 'restore_type'.
+ Notes:
+ - The 'diameter' is calculated using the `_get_diam` function based on the
+ magnification level and object type.
+ - The 'minimum_size' and 'maximum_size' are derived from the diameter.
+ - The 'model_name' varies depending on the object type and other settings.
+ - If `settings['verbose']` is True, the generated settings are printed.
+ """
  object_settings = {}

  object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
@@ -1242,7 +1525,7 @@ def _pivot_counts_table(db_path):
  """
  Read a table from an SQLite database into a pandas DataFrame.

- Parameters:
+ Args:
  - db_path (str): The path to the SQLite database file.
  - table_name (str): The name of the table to read. Default is 'object_counts'.

@@ -1287,6 +1570,38 @@ def _pivot_counts_table(db_path):
  conn.close()

  def _get_cellpose_channels_v2(src, nucleus_channel, pathogen_channel, cell_channel):
+ """
+ Generate a dictionary of Cellpose channels based on the provided input channels.
+
+ This function constructs a mapping of Cellpose channels for nucleus, pathogen,
+ and cell segmentation. It also checks for the existence of precomputed mask
+ files and issues a warning if any channel is `None` while masks already exist.
+
+ Args:
+ src (str): The source directory containing the 'masks' folder with precomputed
+ mask files ('cell_mask_stack', 'nucleus_mask_stack', 'pathogen_mask_stack').
+ nucleus_channel (int or None): The channel index for the nucleus. If `None`,
+ the nucleus channel is not included.
+ pathogen_channel (int or None): The channel index for the pathogen. If `None`,
+ the pathogen channel is not included.
+ cell_channel (int or None): The channel index for the cell. If `None`, the cell
+ channel is not included.
+
+ Returns:
+ dict: A dictionary where keys are 'nucleus', 'pathogen', and 'cell', and values
+ are lists of two integers representing the channels to be used for each
+ type. If a channel is not provided (`None`), it is excluded from the dictionary.
+
+ Warnings:
+ Prints a warning if any of the channels (`nucleus_channel`, `pathogen_channel`,
+ `cell_channel`) is `None` while precomputed mask files already exist in the
+ specified source directory.
+
+ Notes:
+ - The nucleus and pathogen channels are always duplicated as [channel, channel].
+ - The cell channel prefers the nucleus channel as the first channel if available;
+ otherwise, it duplicates the cell channel as [cell_channel, cell_channel].
+ """
  cell_mask_path = os.path.join(src, 'masks', 'cell_mask_stack')
  nucleus_mask_path = os.path.join(src, 'masks', 'nucleus_mask_stack')
  pathogen_mask_path = os.path.join(src, 'masks', 'pathogen_mask_stack')
@@ -1315,6 +1630,34 @@ def _get_cellpose_channels_v2(src, nucleus_channel, pathogen_channel, cell_chann
  return cellpose_channels

  def _get_cellpose_channels_v1(src, nucleus_channel, pathogen_channel, cell_channel):
+ """
+ Generates a dictionary mapping object types (nucleus, pathogen, cell) to their respective
+ Cellpose channel configurations based on the provided input channels.
+ This function checks for the existence of pre-generated Cellpose masks in the specified
+ source directory and warns the user if any of the input channels are set to None while
+ masks already exist. The function then constructs a dictionary of Cellpose channel
+ configurations for the specified object types.
+ Args:
+ src (str): The source directory path where the Cellpose masks are stored.
+ nucleus_channel (int or None): The channel index for the nucleus. If None, the nucleus
+ is not included in the configuration.
+ pathogen_channel (int or None): The channel index for the pathogen. If None, the pathogen
+ is not included in the configuration.
+ cell_channel (int or None): The channel index for the cell. If None, the cell is not
+ included in the configuration.
+ Returns:
+ dict: A dictionary where keys are object types ('nucleus', 'pathogen', 'cell') and
+ values are lists specifying the Cellpose channel configuration for each object type.
+ Warnings:
+ Prints a warning if any of the input channels are set to None while Cellpose masks
+ already exist in the specified source directory.
+ Notes:
+ - The function assumes that the Cellpose masks are stored in subdirectories named
+ 'masks/cell_mask_stack', 'masks/nucleus_mask_stack', and 'masks/pathogen_mask_stack'
+ within the source directory.
+ - The channel configuration is represented as a list of two integers, where the first
+ integer is always 0, and the second integer depends on the presence of other channels.
+ """

  cell_mask_path = os.path.join(src, 'masks', 'cell_mask_stack')
  nucleus_mask_path = os.path.join(src, 'masks', 'nucleus_mask_stack')
@@ -1347,6 +1690,33 @@ def _get_cellpose_channels_v1(src, nucleus_channel, pathogen_channel, cell_chann
  return cellpose_channels

  def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel):
+ """
+ Generate a mapping of channels for Cellpose segmentation based on the provided input channels.
+
+ This function determines the logical order of channels for nucleus, cell, and pathogen,
+ removes duplicates while preserving the first occurrence, and maps the channels to indices
+ capped at a maximum value of 2. It then assigns these indices to the respective Cellpose
+ segmentation categories.
+
+ Args:
+ src (str): The source directory containing the mask files.
+ nucleus_channel (int or None): The channel index for the nucleus. If None, the nucleus
+ is not included in the mapping.
+ pathogen_channel (int or None): The channel index for the pathogen. If None, the pathogen
+ is not included in the mapping.
+ cell_channel (int or None): The channel index for the cell. If None, the cell is not
+ included in the mapping.
+
+ Returns:
+ dict: A dictionary mapping segmentation categories ('nucleus', 'cell', 'pathogen') to
+ their respective channel indices. Each category is assigned a list of two indices,
+ which are either identical or derived from the logical order of the input channels.
+
+ Warnings:
+ If any of the mask files already exist in the `src` directory and any of the input
+ channels (`nucleus_channel`, `pathogen_channel`, `cell_channel`) are None, a warning
+ is printed indicating potential unexpected behavior.
+ """
  cell_mask_path = os.path.join(src, 'masks', 'cell_mask_stack')
  nucleus_mask_path = os.path.join(src, 'masks', 'nucleus_mask_stack')
  pathogen_mask_path = os.path.join(src, 'masks', 'pathogen_mask_stack')
@@ -1474,7 +1844,7 @@ def _split_data(df, group_by, object_type):
  Splits the input dataframe into numeric and non-numeric parts, groups them by the specified column,
  and returns the grouped dataframes with conditional aggregation.

- Parameters:
+ Args:
  df (pandas.DataFrame): The input dataframe.
  group_by (str): The column name to group the dataframes by.
  object_type (str): The column name to concatenate with 'prcf' to create a new column 'prcfo'.
@@ -1574,7 +1944,7 @@ def _group_by_well(df):
  Group the DataFrame by well coordinates (plate, row, col) and apply mean function to numeric columns
  and select the first value for non-numeric columns.

- Parameters:
+ Args:
  df (DataFrame): The input DataFrame to be grouped.

  Returns:
@@ -1792,22 +2162,68 @@ class SpatialAttention(nn.Module):

  # Multi-Scale Block with Attention
  class MultiScaleBlockWithAttention(nn.Module):
+ """
+ A PyTorch module implementing a multi-scale convolutional block with spatial attention.
+
+ This module applies a dilated convolution followed by a spatial attention mechanism
+ to enhance input feature maps.
+ """
  def __init__(self, in_channels, out_channels):
+ """
+ Initialize the MultiScaleBlockWithAttention module.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels after convolution.
+ """
  super(MultiScaleBlockWithAttention, self).__init__()
  self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
  self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)

  def custom_forward(self, x):
+ """
+ Apply dilated convolution followed by spatial attention.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (N, C, H, W).
+
+ Returns:
+ torch.Tensor: Output tensor after attention-enhanced feature transformation.
+ """
  x1 = F.relu(self.dilated_conv1(x), inplace=True)
  x = self.spatial_attention(x1)
  return x

  def forward(self, x):
+ """
+ Standard forward pass that delegates to `custom_forward`.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output tensor.
+ """
  return self.custom_forward(x)

  # Final Classifier
  class CustomCellClassifier(nn.Module):
+ """
+ A custom neural network module for single-cell classification using early fusion and multi-scale attention.
+
+ This architecture supports optional gradient checkpointing for reduced memory usage during training.
+ """
  def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
+ """
+ Initialize the classifier with early fusion and attention blocks.
+
+ Args:
+ num_classes (int): Number of classification categories.
+ pathogen_channel (int): Unused; reserved for future feature fusion.
+ use_attention (bool): Unused; reserved for attention gating.
+ use_checkpoint (bool): Whether to use checkpointing to reduce memory.
+ dropout_rate (float): Unused; reserved for future regularization.
+ """
  super(CustomCellClassifier, self).__init__()
  self.early_fusion = EarlyFusion(in_channels=3)

@@ -1820,6 +2236,15 @@ class CustomCellClassifier(nn.Module):
  param.requires_grad = True

  def custom_forward(self, x):
+ """
+ Perform a standard forward pass without gradient checkpointing.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (N, C, H, W).
+
+ Returns:
+ torch.Tensor: Logits for each class.
+ """
  x.requires_grad = True
  x = self.early_fusion(x)
  x = self.multi_scale_block_1(x)
@@ -1828,6 +2253,15 @@ class CustomCellClassifier(nn.Module):
  return x

  def forward(self, x):
+ """
+ Forward pass with optional gradient checkpointing.
+
+ Args:
+ x (torch.Tensor): Input tensor.
+
+ Returns:
+ torch.Tensor: Output logits.
+ """
  if self.use_checkpoint:
  x.requires_grad = True
  return checkpoint(self.custom_forward, x)
@@ -1836,7 +2270,22 @@ class CustomCellClassifier(nn.Module):

  #CNN and Transformer class, pick any Torch model.
  class TorchModel(nn.Module):
+ """
+ A PyTorch wrapper for pretrained torchvision models with a custom SPACR classifier head.
+
+ This class supports custom dropout insertion and optional gradient checkpointing
+ for memory efficiency during training.
+ """
  def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
+ """
+ Initialize the TorchModel with optional dropout and checkpointing.
+
+ Args:
+ model_name (str): The model architecture to load from torchvision.models.
+ pretrained (bool): Whether to initialize with pretrained weights.
+ dropout_rate (float or None): Dropout rate for the classifier head.
+ use_checkpoint (bool): Whether to enable gradient checkpointing.
+ """
  super(TorchModel, self).__init__()
  self.model_name = model_name
  self.use_checkpoint = use_checkpoint
@@ -1853,13 +2302,27 @@ class TorchModel(nn.Module):
  self.init_spacr_classifier(dropout_rate)

  def apply_dropout_rate(self, model, dropout_rate):
- """Apply dropout rate to all dropout layers in the model."""
+ """
+ Recursively set dropout probability for all nn.Dropout layers in the model.
+
+ Args:
+ model (nn.Module): The model to modify.
+ dropout_rate (float): New dropout probability.
+ """
  for module in model.modules():
  if isinstance(module, nn.Dropout):
  module.p = dropout_rate

  def init_base_model(self, pretrained):
- """Initialize the base model from torchvision.models."""
+ """
+ Load the base model from torchvision.models.
+
+ Args:
+ pretrained (bool): Whether to load pretrained weights.
+
+ Returns:
+ nn.Module: The base feature extractor.
+ """
  model_func = models.__dict__.get(self.model_name, None)
  if not model_func:
  raise ValueError(f"Model {self.model_name} is not recognized.")
@@ -1870,7 +2333,12 @@ class TorchModel(nn.Module):
  return model_func(pretrained=pretrained)

  def get_weight_choice(self):
- """Get weight choice if it exists for the model."""
+ """
+ Get the default weights enum for the selected model.
+
+ Returns:
+ Optional[torchvision.models.WeightsEnum]: Default weights or None.
+ """
  weight_enum = None
  for attr_name in dir(models):
  if attr_name.lower() == f"{self.model_name}_weights".lower():
@@ -1879,7 +2347,12 @@ class TorchModel(nn.Module):
  return weight_enum.DEFAULT if weight_enum else None

  def get_num_ftrs(self):
- """Determine the number of features output by the base model."""
+ """
+ Determine output feature dimensionality from the base model.
+
+ Returns:
+ int: Feature vector size.
+ """
  if hasattr(self.base_model, 'fc'):
  self.base_model.fc = nn.Identity()
  elif hasattr(self.base_model, 'classifier'):
@@ -1892,14 +2365,27 @@ class TorchModel(nn.Module):
  return output.size(1)

  def init_spacr_classifier(self, dropout_rate):
- """Initialize the SPACR classifier."""
+ """
+ Create the final classification layer and optional dropout.
+
+ Args:
+ dropout_rate (float or None): Dropout probability. If None, dropout is skipped.
+ """
  self.use_dropout = dropout_rate is not None
  if self.use_dropout:
  self.dropout = nn.Dropout(dropout_rate)
  self.spacr_classifier = nn.Linear(self.num_ftrs, 1)

  def forward(self, x):
- """Define the forward pass of the model."""
+ """
+ Forward pass through base model, optional dropout, and final classifier.
+
+ Args:
+ x (torch.Tensor): Input image tensor of shape (N, 3, H, W).
+
+ Returns:
+ torch.Tensor: Output logits of shape (N,).
+ """
  if self.use_checkpoint:
  x = checkpoint(self.base_model, x)
  else:
@@ -1910,19 +2396,62 @@ class TorchModel(nn.Module):
  return logits

  class FocalLossWithLogits(nn.Module):
+ """
+ Focal Loss with logits for binary classification.
+
+ This loss function is especially useful for addressing class imbalance by focusing more
+ on hard-to-classify examples.
+
+ Args:
+ alpha (float): Balancing factor for positive/negative examples. Default is 1.
+ gamma (float): Focusing parameter that down-weights easy examples. Default is 2.
+ """
+
  def __init__(self, alpha=1, gamma=2):
+ """
+ Initialize the focal loss.
+
+ Args:
+ alpha (float): Balancing factor for positive/negative examples. Default is 1.
+ gamma (float): Focusing parameter to down-weight well-classified examples. Default is 2.
+ """
  super(FocalLossWithLogits, self).__init__()
  self.alpha = alpha
  self.gamma = gamma

  def forward(self, logits, target):
+ """
+ Compute the focal loss between logits and targets.
+
+ Args:
+ logits (torch.Tensor): Predicted unnormalized scores (logits).
+ target (torch.Tensor): Ground truth binary labels (same shape as logits).
+
+ Returns:
+ torch.Tensor: Scalar focal loss value.
+ """
  BCE_loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
  pt = torch.exp(-BCE_loss)
  focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
  return focal_loss.mean()

  class ResNet(nn.Module):
+ """
+ A wrapper around torchvision ResNet models with optional dropout, checkpointing,
+ and a custom classifier head.
+
+ Supported ResNet variants: resnet18, resnet34, resnet50, resnet101, resnet152.
+ """
  def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
+ """
+ Initialize the ResNet model wrapper.
+
+ Args:
+ resnet_type (str): Which ResNet variant to use. Options: 'resnet18', 'resnet34', etc.
+ dropout_rate (float or None): Dropout rate to apply before the final layer.
+ use_checkpoint (bool): Whether to enable gradient checkpointing.
+ init_weights (str): Either 'imagenet' to load pretrained weights or 'none'.
+ """
  super(ResNet, self).__init__()

  resnet_map = {
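Aside on the FocalLossWithLogits hunk above: with BCE the per-element binary cross-entropy, the code computes pt = exp(-BCE), which is exactly the predicted probability of the true class, so the loss is the standard focal loss FL = alpha * (1 - pt)^gamma * BCE, averaged over the batch. With gamma = 2, a well-classified example (pt = 0.9) is down-weighted by (1 - 0.9)^2 = 0.01, while a hard example (pt = 0.5) keeps weight 0.25.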
@@ -1939,6 +2468,15 @@ class ResNet(nn.Module):
  self.initialize_base(resnet_map[resnet_type], dropout_rate, use_checkpoint, init_weights)

  def initialize_base(self, base_model_dict, dropout_rate, use_checkpoint, init_weights):
+ """
+ Initialize the base model and classifier layers.
+
+ Args:
+ base_model_dict (dict): Contains model constructor and weight enum.
+ dropout_rate (float or None): Dropout rate to use.
+ use_checkpoint (bool): Whether to use gradient checkpointing.
+ init_weights (str): Weight initialization mode.
+ """
  if init_weights == 'imagenet':
  self.resnet = base_model_dict['func'](weights=base_model_dict['weights'])
  elif init_weights == 'none':
@@ -1956,6 +2494,15 @@ class ResNet(nn.Module):
  self.fc2 = nn.Linear(500, 1)

  def forward(self, x):
+ """
+ Forward pass through ResNet and classification layers.
+
+ Args:
+ x (torch.Tensor): Input tensor of shape (N, 3, H, W).
+
+ Returns:
+ torch.Tensor: Logits of shape (N,).
+ """
  x.requires_grad = True # Ensure that the tensor has requires_grad set to True

  if self.use_checkpoint:
@@ -1995,7 +2542,7 @@ def classification_metrics(all_labels, prediction_pos_probs, loss, epoch):
  """
  Calculate classification metrics for binary classification.

- Parameters:
+ Args:
  - all_labels (list): List of true labels.
  - prediction_pos_probs (list): List of predicted positive probabilities.
  - loader_name (str): Name of the data loader.
@@ -2129,6 +2676,23 @@ def choose_model(model_type, device, init_weights=True, dropout_rate=0, use_chec
  return base_model

  def calculate_loss(output, target, loss_type='binary_cross_entropy_with_logits'):
+ """
+ Calculates the loss between the model output and the target based on the specified loss type.
+
+ Args:
+ output (Tensor): The predicted output from the model.
+ target (Tensor): The ground truth target values.
+ loss_type (str, optional): The type of loss function to use.
+ Supported values are:
+ - 'binary_cross_entropy_with_logits': Uses binary cross-entropy loss with logits.
+ - 'focal_loss': Uses focal loss with logits. Defaults to 'binary_cross_entropy_with_logits'.
+
+ Returns:
+ Tensor: The computed loss value.
+
+ Raises:
+ ValueError: If an unsupported loss_type is provided.
+ """
  if loss_type == 'binary_cross_entropy_with_logits':
  loss = F.binary_cross_entropy_with_logits(output, target)
  elif loss_type == 'focal_loss':
@@ -2137,6 +2701,22 @@ def calculate_loss(output, target, loss_type='binary_cross_entropy_with_logits')
  return loss

  def pick_best_model(src):
+ """
+ Selects the best model file from a given directory based on accuracy and epoch.
+ This function scans the specified directory for files with a `.pth` extension,
+ extracts accuracy and epoch information from their filenames using a predefined
+ pattern, and selects the file with the highest accuracy. If multiple files have
+ the same accuracy, the one with the highest epoch is selected.
+ Args:
+ src (str): The path to the directory containing the model files.
+ Returns:
+ str: The full path to the best model file based on accuracy and epoch.
+ Notes:
+ - The filenames are expected to follow the pattern `_epoch_<epoch>_acc_<accuracy>.pth`,
+ where `<epoch>` is an integer and `<accuracy>` is a float.
+ - If no files match the pattern, the function may raise an IndexError when
+ attempting to access the first element of the sorted list.
+ """
  all_files = os.listdir(src)
  pth_files = [f for f in all_files if f.endswith('.pth')]
  pattern = re.compile(r'_epoch_(\d+)_acc_(\d+(?:\.\d+)?)')
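How the pattern ranks checkpoints, on hypothetical filenames:

    import re

    pattern = re.compile(r'_epoch_(\d+)_acc_(\d+(?:\.\d+)?)')
    files = ['m_epoch_3_acc_0.91.pth', 'm_epoch_7_acc_0.91.pth', 'm_epoch_5_acc_0.88.pth']
    best = max(files, key=lambda f: (float(pattern.search(f).group(2)),
                                     int(pattern.search(f).group(1))))
    print(best)    # m_epoch_7_acc_0.91.pth  (highest accuracy, then highest epoch)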
@@ -2153,18 +2733,67 @@ def pick_best_model(src):
  return os.path.join(src, best_model)

  def get_paths_from_db(df, png_df, image_type='cell_png'):
+ """
+ Filters and retrieves paths from a DataFrame based on specified criteria.
+
+ Args:
+ df (pd.DataFrame): A DataFrame whose index contains the objects of interest.
+ png_df (pd.DataFrame): A DataFrame containing a 'png_path' column and a 'prcfo' column.
+ image_type (str, optional): A string to filter the 'png_path' column. Defaults to 'cell_png'.
+
+ Returns:
+ pd.DataFrame: A filtered DataFrame containing rows from `png_df` where the 'png_path'
+ column contains the `image_type` string and the 'prcfo' column matches
+ the index of `df`.
+ """
  objects = df.index.tolist()
  filtered_df = png_df[png_df['png_path'].str.contains(image_type) & png_df['prcfo'].isin(objects)]
  return filtered_df

  def save_file_lists(dst, data_set, ls):
+ """
+ Saves a list of file paths or data entries to a CSV file.
+
+ Args:
+ dst (str): The destination directory where the CSV file will be saved.
+ data_set (str): The name of the dataset, which will also be used as the column name in the CSV file and the filename.
+ ls (list): A list of file paths or data entries to be saved.
+
+ Returns:
+ None
+ """
  df = pd.DataFrame(ls, columns=[data_set])
  df.to_csv(f'{dst}/{data_set}.csv', index=False)
  return

  def augment_single_image(args):
- img_path, dst = args
- img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
+ """
+ Augment a single image by applying various transformations and saving the results.
+
+ This function reads an image from the specified file path, applies a series of
+ transformations (original, rotations, and flips), and saves the transformed images
+ to the destination directory with appropriate filenames.
+
+ Args:
+ args (tuple): A tuple containing:
+ img_path (str): The file path to the input image.
+ dst (str): The destination directory where the augmented images will be saved.
+
+ Notes:
+ The following transformations are applied to the input image:
+ - Original image (no transformation)
+ - 90-degree clockwise rotation
+ - 180-degree rotation
+ - 270-degree clockwise rotation
+ - Horizontal flip
+ - Vertical flip
+
+ Side Effects:
+ Saves the augmented images to the specified destination directory.
+ Filenames indicate the type of transformation applied.
+ """
+ img_path, dst = args
+ img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
  filename = os.path.basename(img_path).split('.')[0]

  # Original Image
@@ -2191,6 +2820,17 @@ def augment_single_image(args):
  cv2.imwrite(os.path.join(dst, f"{filename}_flip_ver.png"), img_flip_ver)

  def augment_images(file_paths, dst):
+ """
+ Augments a list of images and saves the augmented images to the specified destination directory.
+
+ Args:
+ file_paths (list of str): A list of file paths to the images to be augmented.
+ dst (str): The destination directory where the augmented images will be saved.
+ If the directory does not exist, it will be created.
+
+ Returns:
+ None
+ """
  if not os.path.exists(dst):
  os.makedirs(dst)

@@ -2200,6 +2840,51 @@ def augment_images(file_paths, dst):
  pool.map(augment_single_image, args_list)

  def augment_classes(dst, nc, pc, generate=True,move=True):
+ """
+ Augments and organizes image datasets into training and testing directories.
+
+ Args:
+ -----------
+ dst : str
+ The destination directory where augmented data will be stored.
+ nc : list
+ List of file paths for the "negative class" images to be augmented.
+ pc : list
+ List of file paths for the "positive class" images to be augmented.
+ generate : bool, optional
+ If True, generates augmented images for both classes and saves them in
+ separate directories (`aug_nc` for negative class and `aug_pc` for positive class).
+ Default is True.
+ move : bool, optional
+ If True, splits the augmented images into training and testing datasets,
+ and organizes them into subdirectories under `aug/train` and `aug/test`.
+ Default is True.
+
+ Returns:
+ --------
+ None
+
+ Notes:
+ ------
+ - The function uses `train_test_split` to split the augmented images into
+ training (90%) and testing (10%) datasets.
+ - The progress of moving files is displayed in the console.
+ - The final counts of training and testing images for each class are printed.
+
+ Example Directory Structure:
+ ----------------------------
+ After execution, the directory structure will look like this:
+ dst/
+ ├── aug_nc/            # Augmented negative class images
+ ├── aug_pc/            # Augmented positive class images
+ └── aug/
+     ├── train/
+     │   ├── nc/        # Training negative class images
+     │   └── pc/        # Training positive class images
+     └── test/
+         ├── nc/        # Testing negative class images
+         └── pc/        # Testing positive class images
+ """
  aug_nc = os.path.join(dst,'aug_nc')
  aug_pc = os.path.join(dst,'aug_pc')
  all_ = len(nc)+len(pc)
@@ -2251,6 +2937,21 @@ def augment_classes(dst, nc, pc, generate=True,move=True):
  return

  def annotate_predictions(csv_loc):
+ """
+ Reads a CSV file containing image metadata, processes the data to extract
+ additional information, and assigns a condition label to each row based on
+ specific rules.
+ Args:
+ csv_loc (str): The file path to the CSV file containing the metadata.
+ Returns:
+ pandas.DataFrame: A DataFrame with the following additional columns:
+ - 'filename': Extracted filename from the 'path' column.
+ - 'plateID': Plate ID extracted from the filename.
+ - 'well': Well information extracted from the filename.
+ - 'fieldID': Field ID extracted from the filename.
+ - 'object': Object ID extracted from the filename (with '.png' removed).
+ - 'cond': Assigned condition label ('screen', 'pc', 'nc', or '') based on rules.
+ """
  df = pd.read_csv(csv_loc)
  df['filename'] = df['path'].apply(lambda x: x.split('/')[-1])
  df[['plateID', 'well', 'fieldID', 'object']] = df['filename'].str.split('_', expand=True)
@@ -2274,11 +2975,47 @@ def annotate_predictions(csv_loc):
  return df

  def initiate_counter(counter_, lock_):
+ """
+ Initializes global variables `counter` and `lock` with the provided arguments.
+
+ This function sets the global variables `counter` and `lock` to the values
+ passed as `counter_` and `lock_`, respectively. It is typically used to
+ share a counter and a lock object across multiple threads or processes.
+
+ Args:
+ counter_ (Any): The counter object to be assigned to the global `counter`.
+ lock_ (Any): The lock object to be assigned to the global `lock`.
+ """
  global counter, lock
  counter = counter_
  lock = lock_

  def add_images_to_tar(paths_chunk, tar_path, total_images):
+ """
+ Adds a chunk of image files to a tar archive.
+
+ Args:
+ paths_chunk (list of str): A list of file paths to the images to be added to the tar archive.
+ tar_path (str): The path where the tar archive will be created or overwritten.
+ total_images (int): The total number of images being processed, used for progress tracking.
+
+ Behavior:
+ - Opens a tar archive at the specified `tar_path` in write mode.
+ - Iterates through the provided `paths_chunk` and adds each image to the tar archive.
+ - Tracks progress using a shared counter and prints progress updates every 10 images.
+ - Handles missing files gracefully by printing a warning message if a file is not found.
+
+ Notes:
+ - This function assumes the existence of a global `lock` object for thread-safe counter updates.
+ - The `counter` object is expected to be a shared multiprocessing.Value or similar.
+ - The `print_progress` function is used to display progress updates.
+
+ Exceptions:
+ - Prints a warning message if a file in `paths_chunk` is not found, but continues processing other files.
+
+ Example:
+ add_images_to_tar(['/path/to/image1.jpg', '/path/to/image2.jpg'], '/path/to/archive.tar', 100)
+ """
  with tarfile.open(tar_path, 'w') as tar:
  for i, img_path in enumerate(paths_chunk):
  arcname = os.path.basename(img_path)
@@ -2294,6 +3031,29 @@ def add_images_to_tar(paths_chunk, tar_path, total_images):
  print(f"File not found: {img_path}")

  def generate_fraction_map(df, gene_column, min_frequency=0.0):
+ """
+ Generates a fraction map from a given DataFrame and writes it to a CSV file.
+
+ This function calculates the fraction of counts for each gene and well,
+ organizes the data into a pivot table-like structure, and filters out
+ columns based on a minimum frequency threshold. The resulting DataFrame
+ is saved as a CSV file.
+
+ Args:
+ df (pd.DataFrame): Input DataFrame containing the data. It must include
+ the columns 'count', 'well_read_sum', 'prc', and the specified
+ `gene_column`.
+ gene_column (str): The name of the column in `df` that contains gene
+ identifiers.
+ min_frequency (float, optional): The minimum frequency threshold for
+ filtering columns. Columns with a maximum value below this
+ threshold are dropped. Defaults to 0.0.
+
+ Returns:
+ pd.DataFrame: A DataFrame containing the fraction map, with wells as
+ rows and genes as columns. Cells contain the fraction values, and
+ missing values are filled with 0.0.
+ """
  df['fraction'] = df['count']/df['well_read_sum']
  genes = df[gene_column].unique().tolist()
  wells = df['prc'].unique().tolist()
@@ -2319,6 +3079,17 @@ def generate_fraction_map(df, gene_column, min_frequency=0.0):
  return independent_variables

  def fishers_odds(df, threshold=0.5, phenotyp_col='mean_pred'):
+ """
+ Perform Fisher's exact test to evaluate the association between mutants and a binned phenotype score.
+
+ Args:
+ df (pandas.DataFrame): DataFrame containing binary mutant indicators and a phenotype score column.
+ threshold (float, optional): Threshold to bin the phenotype score. Defaults to 0.5.
+ phenotyp_col (str, optional): Name of the column containing phenotype scores. Defaults to 'mean_pred'.
+
+ Returns:
+ pandas.DataFrame: DataFrame with columns 'Mutant', 'OddsRatio', 'PValue', and 'AdjustedPValue'.
+ """
  # Binning based on phenotype score (e.g., above 0.8 as high)
  df['high_phenotype'] = df[phenotyp_col] < threshold
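For each mutant this reduces to a 2x2 contingency test; a sketch with scipy (hypothetical counts):

    from scipy.stats import fisher_exact

    #                 high phenotype   not high
    table = [[12, 88],     # mutant present
             [30, 870]]    # mutant absent
    odds_ratio, p_value = fisher_exact(table)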
@@ -2359,6 +3130,37 @@ def fishers_odds(df, threshold=0.5, phenotyp_col='mean_pred'):
  return filtered_results_df

  def model_metrics(model):
+ """
+ Calculate and display additional metrics and generate diagnostic plots for a given model.
+
+ Args:
+ -----------
+ model : statsmodels.regression.linear_model.RegressionResultsWrapper
+ A fitted regression model object from the statsmodels library.
+
+ Metrics Calculated:
+ --------------------
+ - Root Mean Squared Error (RMSE): Measures the standard deviation of residuals.
+ - Mean Absolute Error (MAE): Measures the average magnitude of residuals.
+ - Durbin-Watson: Tests for the presence of autocorrelation in residuals.
+
+ Diagnostic Plots:
+ ------------------
+ 1. Residuals vs. Fitted: Scatter plot to check for non-linearity or unequal error variance.
+ 2. Histogram of Residuals: Distribution of residuals to check for normality.
+ 3. QQ Plot: Quantile-Quantile plot to assess if residuals follow a normal distribution.
+ 4. Scale-Location: Scatter plot of standardized residuals to check for homoscedasticity.
+
+ Notes:
+ ------
+ - This function uses matplotlib and seaborn for plotting.
+ - Ensure that the input model is fitted and contains the necessary attributes like `resid`,
+ `fittedvalues`, and `mse_resid`.
+
+ Returns:
+ --------
+ None
+ """

  # Calculate additional metrics
  rmse = np.sqrt(model.mse_resid)
@@ -2400,6 +3202,18 @@ def model_metrics(model):
2400
3202
  plt.show()
2401
3203
 
2402
3204
  def check_multicollinearity(x):
3205
+ """
3206
+ Checks multicollinearity of the predictors by computing the Variance Inflation Factor (VIF).
3207
+
3208
+ Args:
3209
+ x (pd.DataFrame): A DataFrame containing the predictor variables.
3210
+
3211
+ Returns:
3212
+ pd.DataFrame: A DataFrame with two columns:
3213
+ - 'Variable': The names of the predictor variables.
3214
+ - 'VIF': The Variance Inflation Factor for each predictor variable.
3215
+ A VIF value greater than 10 indicates high multicollinearity.
3216
+ """
2403
- """Checks multicollinearity of the predictors by computing the VIF."""
2404
3218
  vif_data = pd.DataFrame()
2405
3219
  vif_data["Variable"] = x.columns
@@ -2407,6 +3221,17 @@ def check_multicollinearity(x):
2407
3221
  return vif_data
2408
3222
 
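A small sketch; 'b' is nearly collinear with 'a', so both should show inflated VIFs:

    >>> import pandas as pd
    >>> X = pd.DataFrame({'a': [1, 2, 3, 4, 5], 'b': [2, 4, 6, 8, 11], 'c': [5, 3, 4, 1, 2]})
    >>> vif = check_multicollinearity(X)
    >>> list(vif.columns)
    ['Variable', 'VIF']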
2409
3223
  def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
3224
+ """
3225
+ Perform Lasso or Ridge regression on the input DataFrame.
3226
+
3227
+ Args:
3228
+ merged_df (pandas.DataFrame): DataFrame with columns 'gene', 'grna', 'plateID', 'rowID', 'columnID', and 'pred'.
3229
+ alpha_value (float, optional): Regularization strength. Defaults to 0.01.
3230
+ reg_type (str, optional): Type of regression to perform, either 'lasso' or 'ridge'. Defaults to 'lasso'.
3231
+
3232
+ Returns:
3233
+ pandas.DataFrame: DataFrame with 'Feature' and 'Coefficient' columns.
3234
+ """
2410
3235
  # Separate predictors and response
2411
3236
  X = merged_df[['gene', 'grna', 'plateID', 'rowID', 'columnID']]
2412
3237
  y = merged_df['pred']
@@ -2433,7 +3258,20 @@ def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
2433
3258
  return coeff_df
2434
3259
 
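The fitting step itself falls outside this hunk; a plausible sketch of the core the docstring describes, assuming one-hot encoding of the categorical predictors:

    import pandas as pd
    from sklearn.linear_model import Lasso, Ridge

    # One-hot encode the categorical predictors named in the docstring
    X = pd.get_dummies(merged_df[['gene', 'grna', 'plateID', 'rowID', 'columnID']])
    y = merged_df['pred']
    model = Lasso(alpha=alpha_value) if reg_type == 'lasso' else Ridge(alpha=alpha_value)
    model.fit(X, y)
    coeff_df = pd.DataFrame({'Feature': X.columns, 'Coefficient': model.coef_})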
2435
3260
  def MLR(merged_df, refine_model):
2436
-
3261
+ """
3262
+ Perform multiple linear regression (MLR) and extract interaction coefficients.
3263
+
3264
+ Args:
3265
+ merged_df (pd.DataFrame): Input DataFrame with data for regression analysis.
3266
+ refine_model (bool): Whether to remove outliers before refitting the model.
3267
+
3268
+ Returns:
3269
+ tuple: Contains:
3270
+ - max_effects (dict): Maximum interaction effect size per gene.
3271
+ - max_effects_pvalues (dict): Corresponding p-values.
3272
+ - model (statsmodels.regression.linear_model.RegressionResultsWrapper): Fitted regression model.
3273
+ - df (pd.DataFrame): DataFrame with sorted interaction effects and p-values.
3274
+ """
2437
3275
  from .plot import _reg_v_plot
2438
3276
 
2439
3277
  #model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
@@ -2485,9 +3323,35 @@ def MLR(merged_df, refine_model):
2485
3323
  return max_effects, max_effects_pvalues, model, df
2486
3324
 
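A usage sketch showing how the four return values unpack (`merged_df` is assumed):

    >>> max_effects, max_effects_pvalues, model, effects_df = MLR(merged_df, refine_model=False)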
2487
3325
  def get_files_from_dir(dir_path, file_extension="*"):
3326
+ """
3327
+ Retrieves a list of files from the specified directory that match the given file extension.
3328
+
3329
+ Args:
3330
+ dir_path (str): The path to the directory from which to retrieve files.
3331
+ file_extension (str, optional): A glob pattern used to select files (e.g., '*.tif'). Defaults to "*"
3332
+ (matches all files).
3333
+
3334
+ Returns:
3335
+ list: A list of file paths matching the specified file extension in the given directory.
3336
+ """
2488
3337
  return glob(os.path.join(dir_path, file_extension))
2489
3338
 
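For example, with an illustrative directory:

    >>> tif_paths = get_files_from_dir('/data/plate1', file_extension='*.tif')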
2490
3339
  def create_circular_mask(h, w, center=None, radius=None):
3340
+ """
3341
+ Creates a circular mask for a 2D array with the specified dimensions.
3342
+
3343
+ Args:
3344
+ h (int): The height of the 2D array.
3345
+ w (int): The width of the 2D array.
3346
+ center (tuple, optional): The (x, y) coordinates of the circle's center.
3347
+ Defaults to the center of the array.
3348
+ radius (int, optional): The radius of the circle. Defaults to the smallest
3349
+ distance from the center to the array's edges.
3350
+
3351
+ Returns:
3352
+ numpy.ndarray: A boolean 2D array where `True` represents the pixels
3353
+ inside the circle and `False` represents the pixels outside.
3354
+ """
2491
3355
  if center is None: # use the middle of the image
2492
3356
  center = (int(w/2), int(h/2))
2493
3357
  if radius is None: # use the smallest distance between the center and image walls
@@ -2500,6 +3364,16 @@ def create_circular_mask(h, w, center=None, radius=None):
2500
3364
  return mask
2501
3365
 
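A small example; the boolean mask can zero out pixels outside the circle:

    >>> import numpy as np
    >>> mask = create_circular_mask(100, 100, radius=30)
    >>> mask.shape
    (100, 100)
    >>> img = np.random.rand(100, 100)
    >>> img[~mask] = 0  # keep only the pixels inside the circle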
2502
3366
  def apply_mask(image, output_value=0):
3367
+ """
3368
+ Apply a circular mask to an image, setting pixels outside the mask to a specified value.
3369
+
3370
+ Args:
3371
+ image (np.ndarray): Input image (2D grayscale or 3D RGB array).
3372
+ output_value (int, optional): Value for pixels outside the mask. Defaults to 0.
3373
+
3374
+ Returns:
3375
+ np.ndarray: Image with circular mask applied.
3376
+ """
2503
3377
  h, w = image.shape[:2] # Assuming image is grayscale or RGB
2504
3378
  mask = create_circular_mask(h, w)
2505
3379
 
@@ -2512,12 +3386,53 @@ def apply_mask(image, output_value=0):
2512
3386
  return masked_image
2513
3387
 
2514
3388
  def invert_image(image):
3389
+ """
3390
+ Inverts the pixel values of an image.
3391
+
3392
+ The function calculates the inverted image by subtracting each pixel value
3393
+ from the maximum possible value for the image's data type. For example,
3394
+ for an image with dtype `uint8`, the maximum value is 255.
3395
+
3396
+ Args:
3397
+ image (numpy.ndarray): The input image to be inverted. The image should
3398
+ be a NumPy array with a valid integer data type (e.g., uint8, uint16).
3399
+
3400
+ Returns:
3401
+ numpy.ndarray: The inverted image, with the same shape and dtype as the input.
3402
+ """
2515
3403
  # The maximum value depends on the image dtype (e.g., 255 for uint8)
2516
3404
  max_value = np.iinfo(image.dtype).max
2517
3405
  inverted_image = max_value - image
2518
3406
  return inverted_image
2519
3407
 
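For example, with a uint8 image (maximum value 255):

    >>> img = np.array([[0, 100], [200, 255]], dtype=np.uint8)
    >>> invert_image(img)
    array([[255, 155],
           [ 55,   0]], dtype=uint8)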
2520
3408
  def resize_images_and_labels(images, labels, target_height, target_width, show_example=True):
3409
+ """
3410
+ Resize images and labels to the specified target dimensions.
3411
+
3412
+ Args:
3413
+ images (list or None): List of 2D or 3D numpy arrays representing input images.
3414
+ If None, only labels will be resized.
3415
+ labels (list or None): List of 2D numpy arrays representing label masks.
3416
+ If None, only images will be resized.
3417
+ target_height (int): Desired height of output arrays.
3418
+ target_width (int): Desired width of output arrays.
3419
+ show_example (bool, optional): Whether to display an example of original vs resized output.
3420
+ Defaults to True.
3421
+
3422
+ Returns:
3423
+ tuple: A tuple (resized_images, resized_labels), where:
3424
+ - resized_images (list): List of resized image arrays. Empty if `images` is None.
3425
+ - resized_labels (list): List of resized label arrays. Empty if `labels` is None.
3426
+
3427
+ Raises:
3428
+ ValueError: If both `images` and `labels` are None.
3429
+
3430
+ Notes:
3431
+ - Uses `resizescikit` for resizing.
3432
+ - Applies anti-aliasing when resizing images.
3433
+ - Uses nearest-neighbor interpolation (`order=0`) for labels to preserve class values.
3434
+ - Visualization of the resizing process is shown using `plot_resize` if `show_example` is True.
3435
+ """
2521
3436
 
2522
3437
  from .plot import plot_resize
2523
3438
 
@@ -2571,6 +3486,29 @@ def resize_images_and_labels(images, labels, target_height, target_width, show_e
2571
3486
  return resized_images, resized_labels
2572
3487
 
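A usage sketch with synthetic inputs:

    >>> imgs = [np.random.rand(256, 256)]
    >>> masks = [np.zeros((256, 256), dtype=np.uint16)]
    >>> r_imgs, r_masks = resize_images_and_labels(imgs, masks, 128, 128, show_example=False)
    >>> r_imgs[0].shape
    (128, 128)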
2573
3488
  def resize_labels_back(labels, orig_dims):
3489
+ """
3490
+ Resize a list of label arrays back to their original dimensions.
3491
+
3492
+ Args:
3493
+ labels (list of numpy.ndarray): A list of label arrays to be resized.
3494
+ orig_dims (list of tuple): A list of tuples where each tuple contains
3495
+ two integers representing the original dimensions (width, height)
3496
+ of the corresponding label array.
3497
+
3498
+ Returns:
3499
+ list of numpy.ndarray: A list of resized label arrays with dimensions
3500
+ matching the corresponding tuples in `orig_dims`.
3501
+
3502
+ Raises:
3503
+ ValueError: If the length of `labels` and `orig_dims` do not match.
3504
+ ValueError: If any element in `orig_dims` is not a tuple of two integers.
3505
+
3506
+ Notes:
3507
+ - The resizing operation uses nearest-neighbor interpolation (order=0).
3508
+ - The `preserve_range` parameter ensures that the data range of the
3509
+ input is preserved during resizing.
3510
+ - Anti-aliasing is disabled for this operation.
3511
+ """
2574
3512
  resized_labels = []
2575
3513
 
2576
3514
  if len(labels) != len(orig_dims):
@@ -2587,12 +3525,37 @@ def resize_labels_back(labels, orig_dims):
2587
3525
  return resized_labels
2588
3526
 
2589
3527
  def calculate_iou(mask1, mask2):
3528
+ """
3529
+ Calculate the Intersection over Union (IoU) between two binary masks.
3530
+
3531
+ The IoU is a measure of the overlap between two binary masks, defined as the
3532
+ ratio of the intersection area to the union area of the masks.
3533
+
3534
+ Args:
3535
+ mask1 (numpy.ndarray): The first binary mask. Must be a 2D array.
3536
+ mask2 (numpy.ndarray): The second binary mask. Must be a 2D array.
3537
+
3538
+ Returns:
3539
+ float: The IoU value, ranging from 0 to 1. Returns 0 if the union of the
3540
+ masks is empty.
3541
+ """
2590
3542
  mask1, mask2 = pad_to_same_shape(mask1, mask2)
2591
3543
  intersection = np.logical_and(mask1, mask2).sum()
2592
3544
  union = np.logical_or(mask1, mask2).sum()
2593
3545
  return intersection / union if union != 0 else 0
2594
3546
 
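For example, two overlapping rectangles:

    >>> a = np.zeros((4, 4), dtype=bool); a[:2, :2] = True  # 4 pixels
    >>> b = np.zeros((4, 4), dtype=bool); b[:2, :] = True   # 8 pixels
    >>> float(calculate_iou(a, b))  # intersection 4, union 8
    0.5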
2595
3547
  def match_masks(true_masks, pred_masks, iou_threshold):
3548
+ """
3549
+ Matches predicted masks to ground truth masks based on Intersection over Union (IoU) threshold.
3550
+
3551
+ Args:
3552
+ true_masks (list): A list of ground truth masks.
3553
+ pred_masks (list): A list of predicted masks.
3554
+ iou_threshold (float): The IoU threshold for determining a match between a true mask and a predicted mask.
3555
+
3556
+ Returns:
3557
+ list: A list of tuples where each tuple contains a matched pair of (true_mask, pred_mask).
3558
+ """
2596
3559
  matches = []
2597
3560
  matched_true_masks_indices = set() # Use set to store indices of matched true masks
2598
3561
 
@@ -2607,6 +3570,20 @@ def match_masks(true_masks, pred_masks, iou_threshold):
2607
3570
  return matches
2608
3571
 
2609
3572
  def compute_average_precision(matches, num_true_masks, num_pred_masks):
3573
+ """
3574
+ Computes the precision and recall based on the provided matches, number of true masks,
3575
+ and number of predicted masks.
3576
+
3577
+ Args:
3578
+ matches (list): A list of matched predictions to ground truth masks.
3579
+ num_true_masks (int): The total number of ground truth masks.
3580
+ num_pred_masks (int): The total number of predicted masks.
3581
+
3582
+ Returns:
3583
+ tuple: A tuple containing:
3584
+ - precision (float): The precision value, calculated as TP / (TP + FP).
3585
+ - recall (float): The recall value, calculated as TP / (TP + FN).
3586
+ """
2610
3587
  TP = len(matches)
2611
3588
  FP = num_pred_masks - TP
2612
3589
  FN = num_true_masks - TP
@@ -2615,6 +3592,20 @@ def compute_average_precision(matches, num_true_masks, num_pred_masks):
2615
3592
  return precision, recall
2616
3593
 
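For example, three matched pairs against four true and five predicted masks:

    >>> matches = [(0, 0), (1, 2), (3, 4)]
    >>> compute_average_precision(matches, num_true_masks=4, num_pred_masks=5)
    (0.6, 0.75)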
2617
3594
  def pad_to_same_shape(mask1, mask2):
3595
+ """
3596
+ Pads two 2D arrays (masks) to the same shape by adding zero-padding to the
3597
+ right and bottom of each array as needed.
3598
+ Args:
3599
+ mask1 (numpy.ndarray): The first 2D array to be padded.
3600
+ mask2 (numpy.ndarray): The second 2D array to be padded.
3601
+ Returns:
3602
+ tuple: A tuple containing two 2D numpy arrays:
3603
+ - padded_mask1 (numpy.ndarray): The first array padded to match the shape of the larger array.
3604
+ - padded_mask2 (numpy.ndarray): The second array padded to match the shape of the larger array.
3605
+ Notes:
3606
+ - The padding is applied with constant values of 0.
3607
+ - The function assumes that both inputs are 2D arrays.
3608
+ """
2618
3609
  # Find the shape differences
2619
3610
  shape_diff = np.array([max(mask1.shape[0], mask2.shape[0]) - mask1.shape[0],
2620
3611
  max(mask1.shape[1], mask2.shape[1]) - mask1.shape[1]])
@@ -2629,6 +3620,24 @@ def pad_to_same_shape(mask1, mask2):
2629
3620
  return padded_mask1, padded_mask2
2630
3621
 
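For example:

    >>> m1, m2 = np.ones((2, 3)), np.ones((3, 2))
    >>> p1, p2 = pad_to_same_shape(m1, m2)
    >>> p1.shape, p2.shape
    ((3, 3), (3, 3))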
2631
3622
  def compute_ap_over_iou_thresholds(true_masks, pred_masks, iou_thresholds):
3623
+ """
3624
+ Compute the Average Precision (AP) over a range of Intersection over Union (IoU) thresholds.
3625
+
3626
+ This function calculates the precision-recall pairs for each IoU threshold, validates that
3627
+ precision and recall values are within the range [0, 1], and computes the Average Precision
3628
+ by integrating the precision-recall curve using the trapezoidal rule.
3629
+
3630
+ Args:
3631
+ true_masks (list or array-like): Ground truth masks.
3632
+ pred_masks (list or array-like): Predicted masks.
3633
+ iou_thresholds (list or array-like): A list of IoU thresholds to evaluate.
3634
+
3635
+ Returns:
3636
+ float: The computed Average Precision (AP) over the specified IoU thresholds.
3637
+
3638
+ Raises:
3639
+ ValueError: If precision or recall values are out of the valid range [0, 1].
3640
+ """
2632
3641
  precision_recall_pairs = []
2633
3642
  for iou_threshold in iou_thresholds:
2634
3643
  matches = match_masks(true_masks, pred_masks, iou_threshold)
@@ -2645,6 +3654,23 @@ def compute_ap_over_iou_thresholds(true_masks, pred_masks, iou_thresholds):
2645
3654
  return np.trapz(sorted_precisions, x=sorted_recalls)
2646
3655
 
2647
3656
  def compute_segmentation_ap(true_masks, pred_masks, iou_thresholds=np.linspace(0.5, 0.95, 10)):
3657
+ """
3658
+ Compute the Average Precision (AP) for segmentation masks over a range of IoU thresholds.
3659
+
3660
+ This function calculates the AP by comparing the ground truth masks (`true_masks`) with the
3661
+ predicted masks (`pred_masks`) across multiple Intersection over Union (IoU) thresholds.
3662
+
3663
+ Args:
3664
+ true_masks (ndarray): A binary array representing the ground truth segmentation masks.
3665
+ Each connected component is treated as a separate object.
3666
+ pred_masks (ndarray): A binary array representing the predicted segmentation masks.
3667
+ Each connected component is treated as a separate object.
3668
+ iou_thresholds (ndarray, optional): A 1D array of IoU thresholds to evaluate AP over.
3669
+ Defaults to `np.linspace(0.5, 0.95, 10)`.
3670
+
3671
+ Returns:
3672
+ float: The computed Average Precision (AP) over the specified IoU thresholds.
3673
+ """
2648
3674
  true_mask_labels = label(true_masks)
2649
3675
  pred_mask_labels = label(pred_masks)
2650
3676
  true_mask_regions = [region.image for region in regionprops(true_mask_labels)]
@@ -2652,11 +3678,42 @@ def compute_segmentation_ap(true_masks, pred_masks, iou_thresholds=np.linspace(0
2652
3678
  return compute_ap_over_iou_thresholds(true_mask_regions, pred_mask_regions, iou_thresholds)
2653
3679
 
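A usage sketch with two offset square objects:

    >>> gt = np.zeros((16, 16), dtype=np.uint8); gt[2:8, 2:8] = 1
    >>> pred = np.zeros((16, 16), dtype=np.uint8); pred[3:9, 3:9] = 1
    >>> ap = compute_segmentation_ap(gt, pred)  # AP averaged over IoU 0.50-0.95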
2654
3680
  def jaccard_index(mask1, mask2):
3681
+ """
3682
+ Computes the Jaccard Index (Intersection over Union) between two binary masks.
3683
+
3684
+ The Jaccard Index is a measure of similarity between two sets, defined as the size
3685
+ of the intersection divided by the size of the union of the sets.
3686
+
3687
+ Args:
3688
+ mask1 (numpy.ndarray): A binary mask (e.g., a 2D array of boolean or 0/1 values).
3689
+ mask2 (numpy.ndarray): Another binary mask of the same shape as `mask1`.
3690
+
3691
+ Returns:
3692
+ float: The Jaccard Index, a value between 0 and 1, where 1 indicates perfect overlap
3693
+ and 0 indicates no overlap.
3694
+ """
2655
3695
  intersection = np.logical_and(mask1, mask2)
2656
3696
  union = np.logical_or(mask1, mask2)
2657
3697
  return np.sum(intersection) / np.sum(union)
2658
3698
 
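For example:

    >>> a = np.array([[1, 1], [0, 0]], dtype=bool)
    >>> b = np.array([[1, 0], [1, 0]], dtype=bool)
    >>> float(jaccard_index(a, b))  # intersection 1, union 3
    0.3333333333333333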
2659
3699
  def dice_coefficient(mask1, mask2):
3700
+ """
3701
+ Compute the Dice coefficient, a measure of overlap between two binary masks.
3702
+
3703
+ The Dice coefficient is defined as:
3704
+ ``2 * |A ∩ B| / (|A| + |B|)``
3705
+
3706
+ where A and B are binary masks, ``|A ∩ B|`` is the number of overlapping non-zero elements,
3707
+ and ``|A| + |B|`` is the total number of non-zero elements in both masks.
3708
+
3709
+ Args:
3710
+ mask1 (numpy.ndarray): First binary mask.
3711
+ mask2 (numpy.ndarray): Second binary mask.
3712
+
3713
+ Returns:
3714
+ float: Dice coefficient between 0.0 (no overlap) and 1.0 (perfect overlap).
3715
+ Returns 1.0 if both masks are empty.
3716
+ """
2660
3717
  # Convert to binary masks
2661
3718
  mask1 = np.where(mask1 > 0, 1, 0)
2662
3719
  mask2 = np.where(mask2 > 0, 1, 0)
@@ -2673,6 +3730,22 @@ def dice_coefficient(mask1, mask2):
2673
3730
  return 2.0 * intersection / total
2674
3731
 
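For example:

    >>> a = np.array([[1, 1], [0, 0]])
    >>> b = np.array([[1, 0], [1, 0]])
    >>> float(dice_coefficient(a, b))  # 2 * 1 / (2 + 2)
    0.5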
2675
3732
  def extract_boundaries(mask, dilation_radius=1):
3733
+ """
3734
+ Extracts the boundaries of a binary mask by applying morphological dilation
3735
+ and erosion operations and computing their difference.
3736
+
3737
+ Args:
3738
+ mask (numpy.ndarray): A 2D array representing the input mask. Non-zero
3739
+ values are considered part of the mask.
3740
+ dilation_radius (int, optional): The radius of the structuring element
3741
+ used for dilation and erosion.
3742
+ Defaults to 1.
3743
+
3744
+ Returns:
3745
+ numpy.ndarray: A binary array of the same shape as the input mask,
3746
+ where the boundary pixels are marked as 1 and all
3747
+ other pixels are 0.
3748
+ """
2676
3749
  binary_mask = (mask > 0).astype(np.uint8)
2677
3750
  struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
2678
3751
  dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
@@ -2681,6 +3754,21 @@ def extract_boundaries(mask, dilation_radius=1):
2681
3754
  return boundary
2682
3755
 
2683
3756
  def boundary_f1_score(mask_true, mask_pred, dilation_radius=1):
3757
+ """
3758
+ Calculate the boundary F1 score between two binary masks.
3759
+
3760
+ The boundary F1 score evaluates the alignment of object boundaries between
3761
+ predicted and ground truth masks. It computes the harmonic mean of boundary
3762
+ precision and recall after dilating the boundaries.
3763
+
3764
+ Args:
3765
+ mask_true (np.ndarray): Ground truth binary mask. Non-zero values are considered foreground.
3766
+ mask_pred (np.ndarray): Predicted binary mask. Non-zero values are considered foreground.
3767
+ dilation_radius (int, optional): Radius used for boundary dilation to allow tolerance in matching. Defaults to 1.
3768
+
3769
+ Returns:
3770
+ float: Boundary F1 score between 0 (no boundary match) and 1 (perfect boundary alignment).
3771
+ """
2684
3772
  # Assume extract_boundaries is defined to extract object boundaries with given dilation_radius
2685
3773
  boundary_true = extract_boundaries(mask_true, dilation_radius)
2686
3774
  boundary_pred = extract_boundaries(mask_pred, dilation_radius)
@@ -2697,8 +3785,6 @@ def boundary_f1_score(mask_true, mask_pred, dilation_radius=1):
2697
3785
 
2698
3786
  return f1
2699
3787
 
2700
-
2701
-
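A usage sketch with a predicted square shifted one pixel right of the ground truth:

    >>> gt = np.zeros((12, 12), dtype=np.uint8); gt[3:9, 3:9] = 1
    >>> pred = np.zeros((12, 12), dtype=np.uint8); pred[3:9, 4:10] = 1
    >>> f1 = boundary_f1_score(gt, pred, dilation_radius=1)  # 1.0 only for perfectly aligned boundaries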
2702
3788
  def _remove_noninfected(stack, cell_dim, nucleus_dim, pathogen_dim):
2703
3789
  """
2704
3790
  Remove non-infected cells from the stack based on the provided dimensions.
@@ -3179,6 +4265,23 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
3179
4265
  return df
3180
4266
 
3181
4267
  def _get_regex(metadata_type, img_format, custom_regex=None):
4268
+ """
4269
+ Generates a regex pattern based on the specified metadata type and image format.
4270
+ Args:
4271
+ metadata_type (str): The type of metadata to generate the regex for.
4272
+ Supported values are:
4273
+ - 'cellvoyager': Generates a regex for CellVoyager metadata format.
4274
+ - 'cq1': Generates a regex for CQ1 metadata format.
4275
+ - 'auto': Generates a regex for automatic detection of metadata.
4276
+ - 'custom': Uses a custom regex provided by the user.
4277
+ img_format (str): The image file format (e.g., 'tif'). If None, defaults to 'tif'.
4278
+ custom_regex (str, optional): A custom regex pattern to use when metadata_type is 'custom'.
4279
+ Returns:
4280
+ str: The generated regex pattern.
4281
+ Notes:
4282
+ - Prints the image format and the generated regex for debugging purposes.
4283
+ - If `img_format` is None, it defaults to 'tif'.
4284
+ """
3182
4285
 
3183
4286
  print(f"Image_format: {img_format}")
3184
4287
 
@@ -3197,6 +4300,27 @@ def _get_regex(metadata_type, img_format, custom_regex=None):
3197
4300
  return regex
3198
4301
 
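A sketch of the 'custom' path, with an illustrative filename layout providing the named groups `_run_test_mode` expects:

    >>> regex = _get_regex('custom', 'tif', custom_regex=r'(?P<plateID>.+)_(?P<wellID>.+)_(?P<fieldID>.+)\.tif')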
3199
4302
  def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
4303
+ """
4304
+ Prepares a test dataset by selecting and copying a subset of images from the source directory
4305
+ based on a regular expression and other criteria.
4306
+ Args:
4307
+ src (str): The source directory containing the images.
4308
+ regex (str): A regular expression to match filenames. The regex should include named groups
4309
+ 'plateID', 'wellID', and 'fieldID' if applicable.
4310
+ timelapse (bool, optional): If True, limits the selection to one set of images to ensure
4311
+ full sequence inclusion. Defaults to False.
4312
+ test_images (int, optional): The number of image sets to include in the test dataset.
4313
+ Defaults to 10.
4314
+ random_test (bool, optional): If True, randomizes the selection of image sets. Defaults to True.
4315
+ Returns:
4316
+ str: The path to the test folder containing the selected images.
4317
+ Notes:
4318
+ - If a subdirectory named 'orig' exists within the source directory, it will be used as the
4319
+ source for image selection.
4320
+ - The selected images are grouped by a combination of plate, well, and field identifiers
4321
+ extracted from the filename using the provided regex.
4322
+ - The function ensures the test folder is created and populated with the selected images.
4323
+ """
3200
4324
 
3201
4325
  if timelapse:
3202
4326
  test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
@@ -3241,6 +4365,51 @@ def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True
3241
4365
  return test_folder_path
3242
4366
 
3243
4367
  def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
4368
+ """
4369
+ Select and initialize a segmentation model based on the provided parameters.
4370
+
4371
+ Args:
4372
+ model_name (str): The name of the model to use. Supported values include
4373
+ 'toxo_pv_lumen' (specific to the 'pathogen' object_type) and
4374
+ 'sam', 'cyto', 'cyto2', 'cyto3', 'nuclei' (default models).
4375
+ device (torch.device): The device on which the model will run (e.g., CPU or GPU).
4376
+ object_type (str, optional): The type of object the model is intended for.
4377
+ Defaults to 'cell'. Supported values:
4378
+ - 'pathogen': Uses a specific model for the Toxoplasma PV lumen.
4379
+ - 'nucleus': Uses denoising or restoration models.
4380
+ - 'cell': Uses default Cellpose models.
4381
+ restore_type (str or None, optional): The type of restoration to apply. One of
4382
+ 'denoise', 'deblur', 'upsample', or None (default). If None, standard models are used.
4383
+ object_settings (dict, optional): Additional settings for specific object types.
4384
+ For 'pathogen', requires a 'diameter' key specifying the object diameter.
4385
+
4386
+ Returns:
4387
+ CellposeModel or CellposeDenoiseModel: The initialized model based on the provided parameters.
4388
+
4389
+ Notes:
4390
+ - For the 'pathogen' object_type with the 'toxo_pv_lumen' model_name, a custom
4391
+ pretrained model is loaded from the 'models/cp' directory.
4392
+ - For the 'nucleus' object_type, restoration models are initialized with nucleus-specific restore types.
4393
+ - If an invalid restore_type is provided, it defaults to None with a warning.
4394
+ - GPU usage is determined automatically based on availability.
4395
+
4396
+ Raises:
4397
+ KeyError: If 'diameter' is not provided in `object_settings` when using the
4398
+ 'toxo_pv_lumen' model for the 'pathogen' object_type.
4399
+ """
3244
4413
  if object_type == 'pathogen':
3245
4414
  if model_name == 'toxo_pv_lumen':
3246
4415
  diameter = object_settings['diameter']
@@ -3292,47 +4461,46 @@ def _choose_model(model_name, device, object_type='cell', restore_type=None, obj
3292
4461
  )
3293
4462
  return model
3294
4463
 
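A usage sketch for the default Cellpose path:

    >>> device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    >>> model = _choose_model('cyto3', device, object_type='cell', restore_type=None)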
3295
- def _choose_model_v1(model_name, device, object_type='cell', restore_type=None, object_settings={}):
3296
-
3297
- if object_type == 'pathogen':
3298
- if model_name == 'toxo_pv_lumen':
3299
- diameter = object_settings['diameter']
3300
- current_dir = os.path.dirname(__file__)
3301
- model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
3302
- print(model_path)
3303
- model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
3304
- print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
3305
- return model
3306
-
3307
- restore_list = ['denoise', 'deblur', 'upsample', None]
3308
- if restore_type not in restore_list:
3309
- print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
3310
- restore_type = None
4464
+ class SelectChannels:
4465
+ """
4466
+ Selectively retain specific color channels in a 3-channel image tensor.
3311
4467
 
3312
- if restore_type == None:
3313
- if model_name in ['cyto', 'cyto2', 'cyto3', 'nuclei']:
3314
- model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
3315
- return model
3316
- else:
3317
- if object_type == 'nucleus':
3318
- restore = f'{restore_type}_nuclei'
3319
- model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
3320
- return model
4468
+ This transformation zeros out unselected channels based on a list of 1-based
4469
+ indices corresponding to RGB channels.
3321
4470
 
3322
- else:
3323
- restore = f'{restore_type}_cyto3'
3324
- if model_name =='cyto2':
3325
- chan2_restore = True
3326
- if model_name =='cyto':
3327
- chan2_restore = False
3328
- model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
3329
- return model
4471
+ Args:
4472
+ channels (list of int): List of 1-based indices for channels to retain.
4473
+ Valid values are:
4474
+ - 1: Red
4475
+ - 2: Green
4476
+ - 3: Blue
3330
4477
 
3331
- class SelectChannels:
4478
+ Example:
4479
+ >>> select_channels = SelectChannels([1, 3])
4480
+ >>> modified_img = select_channels(img)
4481
+
4482
+ Note:
4483
+ The input image must be a PyTorch tensor with shape (3, H, W).
4484
+ """
3332
4485
  def __init__(self, channels):
4486
+ """
4487
+ Initialize the SelectChannels instance.
4488
+
4489
+ Args:
4490
+ channels (list of int): Channels (1-based) to retain.
4491
+ """
3333
4492
  self.channels = channels
3334
4493
 
3335
4494
  def __call__(self, img):
4495
+ """
4496
+ Apply the channel selection to the input image tensor.
4497
+
4498
+ Args:
4499
+ img (torch.Tensor): A 3-channel image tensor of shape (3, H, W).
4500
+
4501
+ Returns:
4502
+ torch.Tensor: A tensor where unselected channels are set to zero.
4503
+ """
3336
4504
  img = img.clone()
3337
4505
  if 1 not in self.channels:
3338
4506
  img[0, :, :] = 0 # Zero out the red channel
@@ -3343,6 +4511,20 @@ class SelectChannels:
3343
4511
  return img
3344
4512
 
3345
4513
  def preprocess_image_v1(image_path, image_size=224, channels=[1,2,3], normalize=True):
4514
+ """
4515
+ Preprocess an image for input into a machine learning model.
4516
+
4517
+ Args:
4518
+ image_path (str): The file path to the image to be processed.
4519
+ image_size (int, optional): The size to which the image will be center-cropped. Defaults to 224.
4520
+ channels (list, optional): A list of channel indices to select from the image. Defaults to [1, 2, 3].
4521
+ normalize (bool, optional): Whether to normalize the image using mean and standard deviation. Defaults to True.
4522
+
4523
+ Returns:
4524
+ tuple: A tuple containing:
4525
+ - image (PIL.Image.Image): The original image loaded as a PIL Image object.
4526
+ - input_tensor (torch.Tensor): The preprocessed image as a PyTorch tensor with an added batch dimension.
4527
+ """
3346
4528
 
3347
4529
  if normalize:
3348
4530
  transform = transforms.Compose([
@@ -3361,10 +4543,27 @@ def preprocess_image_v1(image_path, image_size=224, channels=[1,2,3], normalize=
3361
4543
  return image, input_tensor
3362
4544
 
3363
4545
  class SaliencyMapGenerator:
4546
+ """
4547
+ SaliencyMapGenerator is a utility class for generating saliency maps and visualizing model predictions
4548
+ for PyTorch models in binary classification tasks.
4549
+ """
3364
4550
  def __init__(self, model):
4551
+ """
4552
+ Initialize the SaliencyMapGenerator.
4553
+ """
3365
4554
  self.model = model
3366
4555
 
3367
4556
  def compute_saliency_maps(self, X, y):
4557
+ """
4558
+ Compute saliency maps for the given inputs and target labels.
4559
+
4560
+ Args:
4561
+ X (torch.Tensor): Input tensor with requires_grad enabled.
4562
+ y (torch.Tensor): Ground truth labels for the input samples.
4563
+
4564
+ Returns:
4565
+ torch.Tensor: The computed saliency maps, same shape as input.
4566
+ """
3368
4567
  self.model.eval()
3369
4568
  X.requires_grad_()
3370
4569
 
@@ -3381,6 +4580,17 @@ class SaliencyMapGenerator:
3381
4580
  return saliency
3382
4581
 
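A usage sketch; `model`, `X` and `y` (a trained binary classifier, an image batch and its labels) are assumed:

    >>> gen = SaliencyMapGenerator(model)
    >>> saliency = gen.compute_saliency_maps(X, y)  # same shape as X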
3383
4582
  def compute_saliency_and_predictions(self, X):
4583
+ """
4584
+ Compute saliency maps and predictions for the given input batch.
4585
+
4586
+ Args:
4587
+ X (torch.Tensor): Input image batch; requires_grad is enabled internally.
4588
+
4589
+ Returns:
4590
+ tuple:
4591
+ - torch.Tensor: Saliency maps for the input.
4592
+ - torch.Tensor: Predicted class labels.
4593
+ """
3384
4594
  self.model.eval()
3385
4595
  X.requires_grad_()
3386
4596
 
@@ -3400,6 +4610,19 @@ class SaliencyMapGenerator:
3400
4610
  return saliency, predictions
3401
4611
 
3402
4612
  def plot_activation_grid(self, X, saliency, predictions, overlay=True, normalize=False):
4613
+ """
4614
+ Plot a grid of input images with overlaid saliency maps and predicted labels.
4615
+
4616
+ Args:
4617
+ X (torch.Tensor): Input tensor of images (N, C, H, W).
4618
+ saliency (torch.Tensor): Corresponding saliency maps (N, C, H, W).
4619
+ predictions (torch.Tensor): Predicted class labels (N,).
4620
+ overlay (bool): Whether to overlay saliency maps on input images. Default is True.
4621
+ normalize (bool): Whether to normalize input images by percentiles. Default is False.
4622
+
4623
+ Returns:
4624
+ matplotlib.figure.Figure: A matplotlib figure object showing the grid.
4625
+ """
3403
4626
  N = X.shape[0]
3404
4627
  rows = (N + 7) // 8
3405
4628
  fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
@@ -3428,6 +4651,17 @@ class SaliencyMapGenerator:
3428
4651
  return fig
3429
4652
 
3430
4653
  def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
4654
+ """
4655
+ Normalize an image's intensity per channel using percentile clipping.
4656
+
4657
+ Args:
4658
+ img (np.ndarray): Image of shape (H, W, C) to be normalized.
4659
+ lower_percentile (int): Lower percentile for intensity clipping. Default is 2.
4660
+ upper_percentile (int): Upper percentile for intensity clipping. Default is 98.
4661
+
4662
+ Returns:
4663
+ np.ndarray: Percentile-normalized image.
4664
+ """
3431
4665
  img_normalized = np.zeros_like(img)
3432
4666
 
3433
4667
  for c in range(img.shape[2]): # Iterate over each channel
@@ -3438,7 +4672,19 @@ class SaliencyMapGenerator:
3438
4672
  return img_normalized
3439
4673
 
3440
4674
  class GradCAMGenerator:
4675
+ """
4676
+ GradCAMGenerator generates Grad-CAM (Gradient-weighted Class Activation Mapping) visualizations
4677
+ for CNN-based PyTorch models, supporting binary classification tasks.
4678
+ """
3441
4679
  def __init__(self, model, target_layer, cam_type='gradcam'):
4680
+ """
4681
+ Initialize the GradCAMGenerator and register hooks.
4682
+
4683
+ Args:
4684
+ model (torch.nn.Module): A trained PyTorch model.
4685
+ target_layer (str): The name of the layer to compute Grad-CAM on.
4686
+ cam_type (str): Type of CAM method (default is 'gradcam').
4687
+ """
3442
4688
  self.model = model
3443
4689
  self.model.eval()
3444
4690
  self.target_layer = target_layer
@@ -3451,6 +4697,10 @@ class GradCAMGenerator:
3451
4697
  self.hook_layers()
3452
4698
 
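A usage sketch; the model, batch and layer name are illustrative:

    >>> cam_gen = GradCAMGenerator(model, target_layer='features.7')
    >>> heatmaps, preds = cam_gen.compute_gradcam_and_predictions(X)
    >>> fig = cam_gen.plot_activation_grid(X, heatmaps, preds)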
3453
4699
  def hook_layers(self):
4700
+ """
4701
+ Register forward and backward hooks to capture activations and gradients
4702
+ from the specified target layer during inference and backpropagation.
4703
+ """
3454
4704
  # Forward hook to get activations
3455
4705
  def forward_hook(module, input, output):
3456
4706
  self.activations = output
@@ -3463,6 +4713,16 @@ class GradCAMGenerator:
3463
4713
  self.target_layer_module.register_backward_hook(backward_hook)
3464
4714
 
3465
4715
  def get_layer(self, model, target_layer):
4716
+ """
4717
+ Recursively retrieve a layer object from a nested model.
4718
+
4719
+ Args:
4720
+ model (torch.nn.Module): The model containing the layer.
4721
+ target_layer (str): Dot-separated string path to the desired layer.
4722
+
4723
+ Returns:
4724
+ torch.nn.Module: The resolved layer module.
4725
+ """
3466
4726
  # Recursively find the layer specified in target_layer
3467
4727
  modules = target_layer.split('.')
3468
4728
  layer = model
@@ -3471,6 +4731,16 @@ class GradCAMGenerator:
3471
4731
  return layer
3472
4732
 
3473
4733
  def compute_gradcam_maps(self, X, y):
4734
+ """
4735
+ Compute Grad-CAM heatmaps for an input batch and target labels.
4736
+
4737
+ Args:
4738
+ X (torch.Tensor): Input tensor of shape (N, C, H, W).
4739
+ y (torch.Tensor): Target labels (0 or 1) for each sample.
4740
+
4741
+ Returns:
4742
+ np.ndarray: Grad-CAM heatmaps normalized to [0, 1] for each input.
4743
+ """
3474
4744
  X.requires_grad_()
3475
4745
 
3476
4746
  # Forward pass
@@ -3495,6 +4765,17 @@ class GradCAMGenerator:
3495
4765
  return gradcam
3496
4766
 
3497
4767
  def compute_gradcam_and_predictions(self, X):
4768
+ """
4769
+ Compute Grad-CAM heatmaps and class predictions for a batch.
4770
+
4771
+ Args:
4772
+ X (torch.Tensor): Input tensor of shape (N, C, H, W).
4773
+
4774
+ Returns:
4775
+ Tuple[torch.Tensor, torch.Tensor]:
4776
+ - Grad-CAM heatmaps for each sample.
4777
+ - Predicted class labels (0 or 1).
4778
+ """
3498
4779
  self.model.eval()
3499
4780
  X.requires_grad_()
3500
4781
 
@@ -3513,6 +4794,19 @@ class GradCAMGenerator:
3513
4794
  return torch.tensor(gradcam_maps), predictions
3514
4795
 
3515
4796
  def plot_activation_grid(self, X, gradcam, predictions, overlay=True, normalize=False):
4797
+ """
4798
+ Plot a grid of input images overlaid with Grad-CAM heatmaps.
4799
+
4800
+ Args:
4801
+ X (torch.Tensor): Input image batch (N, C, H, W).
4802
+ gradcam (torch.Tensor): Grad-CAM heatmaps (N, H, W).
4803
+ predictions (torch.Tensor): Predicted class labels.
4804
+ overlay (bool): Whether to overlay Grad-CAM on input images.
4805
+ normalize (bool): Whether to normalize image intensities by percentiles.
4806
+
4807
+ Returns:
4808
+ matplotlib.figure.Figure: The generated grid figure.
4809
+ """
3516
4810
  N = X.shape[0]
3517
4811
  rows = (N + 7) // 8
3518
4812
  fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
@@ -3541,6 +4835,17 @@ class GradCAMGenerator:
3541
4835
  return fig
3542
4836
 
3543
4837
  def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
4838
+ """
4839
+ Normalize each channel of the input image to the specified percentiles.
4840
+
4841
+ Args:
4842
+ img (np.ndarray): Image array (H, W, C).
4843
+ lower_percentile (int): Lower clipping percentile.
4844
+ upper_percentile (int): Upper clipping percentile.
4845
+
4846
+ Returns:
4847
+ np.ndarray: Percentile-normalized image.
4848
+ """
3544
4849
  img_normalized = np.zeros_like(img)
3545
4850
 
3546
4851
  for c in range(img.shape[2]): # Iterate over each channel
@@ -3551,6 +4856,23 @@ class GradCAMGenerator:
3551
4856
  return img_normalized
3552
4857
 
3553
4858
  def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
4859
+ """
4860
+ Preprocess an image for input into a machine learning model.
4861
+
4862
+ Args:
4863
+ image_path (str): Path to the input image file.
4864
+ normalize (bool, optional): If True, apply ImageNet normalization
4865
+ (mean and std). Defaults to True.
4866
+ image_size (int, optional): Target size (height and width) for resizing.
4867
+ Defaults to 224.
4868
+ channels (list of int, optional): 1-based channel indices to retain (e.g., [1, 2, 3]
4869
+ for RGB). Defaults to [1, 2, 3].
4870
+
4871
+ Returns:
4872
+ tuple:
4873
+ - PIL.Image.Image: The original image.
4874
+ - torch.Tensor: The processed image tensor suitable for model input.
4875
+ """
3554
4876
  preprocess = transforms.Compose([
3555
4877
  transforms.Resize((image_size, image_size)),
3556
4878
  transforms.ToTensor(),
@@ -3565,6 +4887,29 @@ def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3
3565
4887
  return image, input_tensor
3566
4888
 
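A usage sketch with a hypothetical file, keeping only the red and green channels:

    >>> image, input_tensor = preprocess_image('example.png', normalize=True, image_size=224, channels=[1, 2])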
3567
4889
  def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
4891
+ """
4892
+ Generate a class visualization for a given target class using gradient ascent.
4893
+ Args:
4894
+ target_y (int): The target class index for which the visualization is generated.
4895
+ model_path (str): Path to the pre-trained model file.
4896
+ dtype (torch.dtype): Data type for the tensors (e.g., torch.FloatTensor or torch.cuda.FloatTensor).
4897
+ img_size (int, optional): Size of the square input image. Default is 224.
4898
+ channels (list, optional): List of channel indices to use in the input image. Default is [0, 1, 2].
4899
+ l2_reg (float, optional): L2 regularization strength to prevent overfitting. Default is 1e-3.
4900
+ learning_rate (float, optional): Learning rate for gradient ascent. Default is 25.
4901
+ num_iterations (int, optional): Number of iterations for gradient ascent. Default is 100.
4902
+ blur_every (int, optional): Frequency (in iterations) to apply Gaussian blur as a regularizer. Default is 10.
4903
+ max_jitter (int, optional): Maximum pixel jitter for random image translation. Default is 16.
4904
+ show_every (int, optional): Frequency (in iterations) to display the intermediate visualization. Default is 25.
4905
+ class_names (list, optional): List of class names corresponding to the target indices. Default is ['nc', 'pc'].
4906
+ Returns:
4907
+ numpy.ndarray: The final deprocessed image as a NumPy array.
4908
+ Notes:
4909
+ - The function performs gradient ascent on the input image to maximize the score of the target class.
4910
+ - Regularization techniques such as L2 regularization, clamping, and Gaussian blur are applied to improve visualization quality.
4911
+ - The function assumes the model is compatible with the input image size and channels.
4912
+ """
3568
4913
 
3569
4914
  def jitter(img, ox, oy):
3570
4915
  # Randomly jitter the image
@@ -3645,6 +4990,18 @@ def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1
3645
4990
  return deprocess(img.data.cpu())
3646
4991
 
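A usage sketch; the model path is a placeholder:

    >>> dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    >>> img = class_visualization(target_y=1, model_path='model.pth', dtype=dtype)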
3647
4992
  def get_submodules(model, prefix=''):
4993
+ """
4994
+ Recursively retrieves the names of all submodules in a given model.
4995
+
4996
+ Args:
4997
+ model (torch.nn.Module): The model whose submodules are to be retrieved.
4998
+ prefix (str, optional): A prefix to prepend to the names of the submodules.
4999
+ Defaults to an empty string.
5000
+
5001
+ Returns:
5002
+ list of str: A list of strings representing the full names of all submodules
5003
+ in the model, including nested submodules.
5004
+ """
3648
5005
  submodules = []
3649
5006
  for name, module in model.named_children():
3650
5007
  full_name = prefix + ('.' if prefix else '') + name
@@ -3653,7 +5010,18 @@ def get_submodules(model, prefix=''):
3653
5010
  return submodules
3654
5011
 
3655
5012
  class GradCAM:
5013
+ """
5014
+ Compute Grad-CAM (Gradient-weighted Class Activation Mapping) for a given model and target layer(s).
5015
+ """
3656
5016
  def __init__(self, model, target_layers=None, use_cuda=True):
5017
+ """
5018
+ Initialize the GradCAM object.
5019
+
5020
+ Args:
5021
+ model (nn.Module): The model for which Grad-CAM will be computed.
5022
+ target_layers (list of str): Names of layers to register hooks on.
5023
+ use_cuda (bool): Whether to use CUDA (GPU) for computation.
5024
+ """
3657
5025
  self.model = model
3658
5026
  self.model.eval()
3659
5027
  self.target_layers = target_layers
@@ -3662,9 +5030,28 @@ class GradCAM:
3662
5030
  self.model = model.cuda()
3663
5031
 
3664
5032
  def forward(self, input):
5033
+ """
5034
+ Run a forward pass through the model.
5035
+
5036
+ Args:
5037
+ input (torch.Tensor): Input tensor.
5038
+
5039
+ Returns:
5040
+ torch.Tensor: Model output.
5041
+ """
3665
5042
  return self.model(input)
3666
5043
 
3667
5044
  def __call__(self, x, index=None):
5045
+ """
5046
+ Compute the Grad-CAM heatmap for an input image.
5047
+
5048
+ Args:
5049
+ x (torch.Tensor): Input tensor of shape (1, C, H, W).
5050
+ index (int or None): Class index to compute gradients for. If None, uses the predicted class.
5051
+
5052
+ Returns:
5053
+ numpy.ndarray: Normalized Grad-CAM heatmap of shape (H, W).
5054
+ """
3668
5055
  if self.cuda:
3669
5056
  x = x.cuda()
3670
5057
 
@@ -3711,6 +5098,23 @@ class GradCAM:
3711
5098
  return cam
3712
5099
 
3713
5100
  def show_cam_on_image(img, mask):
5101
+ """
5102
+ Overlay a heatmap generated from a mask onto an image.
5103
+
5104
+ This function applies a color map to the mask, combines it with the input
5105
+ image, and normalizes the result to create a visually interpretable
5106
+ representation of the mask overlaid on the image.
5107
+
5108
+ Args:
5109
+ img (numpy.ndarray): The input image as a NumPy array with pixel values
5110
+ normalized between 0 and 1.
5111
+ mask (numpy.ndarray): The mask to overlay on the image, with values
5112
+ normalized between 0 and 1.
5113
+
5114
+ Returns:
5115
+ numpy.ndarray: The resulting image with the heatmap overlay, as a
5116
+ NumPy array with pixel values in the range [0, 255].
5117
+ """
3714
5118
  heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
3715
5119
  heatmap = np.float32(heatmap) / 255
3716
5120
  cam = heatmap + np.float32(img)
@@ -3718,6 +5122,24 @@ def show_cam_on_image(img, mask):
3718
5122
  return np.uint8(255 * cam)
3719
5123
 
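A sketch of the two pieces together; the layer name and inputs are illustrative, and `img` is a float array in [0, 1]:

    >>> cam = GradCAM(model, target_layers=['layer4'], use_cuda=False)
    >>> heatmap = cam(input_tensor)  # (1, C, H, W); uses the predicted class when index is None
    >>> overlay = show_cam_on_image(img, heatmap)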
3720
5124
  def recommend_target_layers(model):
5125
+ """
5126
+ Identifies and recommends target layers in a given model for further processing.
5127
+
5128
+ This function iterates through all the modules in the provided model and collects
5129
+ the names of all 2D convolutional layers (`torch.nn.Conv2d`). It then recommends
5130
+ the last convolutional layer as the primary target layer.
5131
+
5132
+ Args:
5133
+ model (torch.nn.Module): The neural network model to analyze.
5134
+
5135
+ Returns:
5136
+ tuple:
5137
+ - list: A list containing the name of the recommended target layer (last Conv2d layer).
5138
+ - list: A list of all Conv2d layer names found in the model.
5139
+
5140
+ Raises:
5141
+ ValueError: If no convolutional layers (`torch.nn.Conv2d`) are found in the model.
5142
+ """
3721
5143
  target_layers = []
3722
5144
  for name, module in model.named_modules():
3723
5145
  if isinstance(module, torch.nn.Conv2d):
@@ -3729,11 +5151,38 @@ def recommend_target_layers(model):
3729
5151
  raise ValueError("No convolutional layers found in the model.")
3730
5152
 
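For example, assuming torchvision is available:

    >>> import torchvision.models as tv
    >>> recommended, all_convs = recommend_target_layers(tv.resnet18())
    >>> recommended
    ['layer4.1.conv2']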
3731
5153
  class IntegratedGradients:
5154
+ """
5155
+ Compute Integrated Gradients for model interpretability.
5156
+
5157
+ This class implements the Integrated Gradients method to attribute the prediction
5158
+ of a neural network to its input features. It approximates the integral of gradients
5159
+ along a straight path from a baseline to the input.
5160
+ """
5161
+
5162
+
3732
5163
  def __init__(self, model):
5164
+ """
5165
+ Initialize the IntegratedGradients instance.
5166
+
5167
+ Args:
5168
+ model (torch.nn.Module): A trained PyTorch model.
5169
+ """
3733
5170
  self.model = model
3734
5171
  self.model.eval()
3735
5172
 
3736
5173
  def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
5174
+ """
5175
+ Compute the integrated gradients for a given input and target class.
5176
+
5177
+ Args:
5178
+ input_tensor (torch.Tensor): The input tensor of shape (1, C, H, W) or similar.
5179
+ target_label_idx (int): Index of the target class for which gradients are computed.
5180
+ baseline (torch.Tensor, optional): Baseline tensor with the same shape as input. Defaults to zeros.
5181
+ num_steps (int, optional): Number of steps in the Riemann approximation of the integral. Defaults to 50.
5182
+
5183
+ Returns:
5184
+ np.ndarray: Integrated gradients as a NumPy array with the same shape as `input_tensor`.
5185
+ """
3737
5186
  if baseline is None:
3738
5187
  baseline = torch.zeros_like(input_tensor)
3739
5188
 
@@ -3753,18 +5202,61 @@ class IntegratedGradients:
3753
5202
  return integrated_grads
3754
5203
 
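A usage sketch; `model` and `input_tensor` are assumed:

    >>> ig = IntegratedGradients(model)
    >>> attributions = ig.generate_integrated_gradients(input_tensor, target_label_idx=1, num_steps=50)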
3755
5204
  def get_db_paths(src):
5205
+ """
5206
+ Generate a list of database file paths based on the given source(s).
5207
+
5208
+ This function takes a single source path or a list of source paths and
5209
+ constructs the corresponding paths to the 'measurements.db' file located
5210
+ in the 'measurements' subdirectory of each source.
5211
+
5212
+ Args:
5213
+ src (str or list of str): A single source path as a string or a list
5214
+ of source paths.
5215
+
5216
+ Returns:
5217
+ list of str: A list of file paths pointing to 'measurements/measurements.db'
5218
+ for each source in the input.
5219
+ """
3756
5220
  if isinstance(src, str):
3757
5221
  src = [src]
3758
5222
  db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
3759
5223
  return db_paths
3760
5224
 
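For example:

    >>> get_db_paths(['/data/plate1', '/data/plate2'])
    ['/data/plate1/measurements/measurements.db', '/data/plate2/measurements/measurements.db']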
3761
5225
  def get_sequencing_paths(src):
5226
+ """
5227
+ Generate a list of file paths pointing to sequencing data CSV files.
5228
+
5229
+ This function takes a single source path or a list of source paths and
5230
+ constructs the full file paths to the 'sequencing/sequencing_data.csv'
5231
+ file located within each source directory.
5232
+
5233
+ Args:
5234
+ src (str or list of str): A single source directory path as a string
5235
+ or a list of source directory paths.
5236
+
5237
+ Returns:
5238
+ list of str: A list of full file paths to the 'sequencing_data.csv'
5239
+ files for each source directory.
5240
+ """
3762
5241
  if isinstance(src, str):
3763
5242
  src = [src]
3764
5243
  seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
3765
5244
  return seq_paths
3766
5245
 
3767
5246
  def load_image_paths(c, visualize):
5247
+ """
5248
+ Loads image paths from a database table and optionally filters them based on a visualization keyword.
5249
+
5250
+ Args:
5251
+ c (sqlite3.Cursor): A database cursor object used to execute SQL queries.
5252
+ visualize (str): A keyword to filter image paths. If provided, only rows where the 'png_path' column
5253
+ contains the keyword followed by '_png' will be included. If None or empty, no filtering
5254
+ is applied.
5255
+
5256
+ Returns:
5257
+ pandas.DataFrame: A DataFrame containing the image paths and other associated data from the 'png_list' table.
5258
+ The DataFrame is indexed by the 'prcfo' column.
5259
+ """
3768
5260
  c.execute(f'SELECT * FROM png_list')
3769
5261
  data = c.fetchall()
3770
5262
  columns_info = c.execute(f'PRAGMA table_info(png_list)').fetchall()
@@ -3777,6 +5269,18 @@ def load_image_paths(c, visualize):
3777
5269
  return image_paths_df
3778
5270
 
3779
5271
  def merge_dataframes(df, image_paths_df, verbose):
5272
+ """
5273
+ Merges two pandas DataFrames on their indices and optionally displays the result.
5274
+
5275
+ Args:
5276
+ df (pandas.DataFrame): The main DataFrame to be merged. It must have a column named 'prcfo',
5277
+ which will be set as the index before merging.
5278
+ image_paths_df (pandas.DataFrame): The DataFrame containing image paths to be merged with `df`.
5279
+ verbose (bool): If True, the resulting merged DataFrame will be displayed.
5280
+
5281
+ Returns:
5282
+ pandas.DataFrame: The merged DataFrame with the indices aligned.
5283
+ """
3780
5284
  df.set_index('prcfo', inplace=True)
3781
5285
  df = image_paths_df.merge(df, left_index=True, right_index=True)
3782
5286
  if verbose:
@@ -3784,12 +5288,48 @@ def merge_dataframes(df, image_paths_df, verbose):
3784
5288
  return df
3785
5289
 
3786
5290
  def remove_highly_correlated_columns_v1(df, threshold):
5291
+ """
5292
+ Removes columns from a DataFrame that are highly correlated with other columns.
5293
+
5294
+ This function calculates the correlation matrix of the given DataFrame, identifies
5295
+ columns with a correlation higher than the specified threshold, and removes them
5296
+ to reduce multicollinearity.
5297
+
5298
+ Args:
5299
+ df (pandas.DataFrame): The input DataFrame containing the data.
5300
+ threshold (float): The correlation threshold above which columns are considered
5301
+ highly correlated and will be removed.
5302
+
5303
+ Returns:
5304
+ pandas.DataFrame: A DataFrame with highly correlated columns removed.
5305
+
5306
+ Example:
5307
+ >>> import pandas as pd
5308
+ >>> import numpy as np
5309
+ >>> data = {'A': [1, 2, 3], 'B': [2, 4, 6], 'C': [10, 2, 8]}
5310
+ >>> df = pd.DataFrame(data)
5311
+ >>> remove_highly_correlated_columns_v1(df, threshold=0.9)
5312
+ A C
5313
+ """
3787
5314
  corr_matrix = df.corr().abs()
3788
5315
  upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
3789
5316
  to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
3790
5317
  return df.drop(to_drop, axis=1)
3791
5318
 
3792
5319
  def filter_columns(df, filter_by):
5320
+ """
5321
+ Filters the columns of a DataFrame based on a specified criterion.
5322
+
5323
+ Args:
5324
+ df (pandas.DataFrame): The input DataFrame whose columns are to be filtered.
5325
+ filter_by (str): The criterion for filtering columns. If 'morphology',
5326
+ columns containing 'channel' in their names are excluded.
5327
+ Otherwise, only columns containing the specified string
5328
+ are included.
5329
+
5330
+ Returns:
5331
+ pandas.DataFrame: A DataFrame containing only the filtered columns.
5332
+ """
3793
5333
  if filter_by != 'morphology':
3794
5334
  cols_to_include = [col for col in df.columns if filter_by in str(col)]
3795
5335
  else:
@@ -3801,7 +5341,7 @@ def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, m
3801
5341
  """
3802
5342
  Perform dimensionality reduction and clustering on the given data.
3803
5343
 
3804
- Parameters:
5344
+ Args:
3805
5345
  numeric_data (np.ndarray): Numeric data for embedding and clustering.
3806
5346
  n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
3807
5347
  min_dist (float): Minimum distance for UMAP.
@@ -3903,12 +5443,51 @@ def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, m
3903
5443
  return embedding, labels, reducer
3904
5444
 
3905
5445
  def remove_noise(embedding, labels):
5446
+ """
5447
+ Removes noise from the given embedding and labels by filtering out elements
5448
+ where the corresponding label is -1.
5449
+
5450
+ Args:
5451
+ embedding (numpy.ndarray): The embedding array, where each row corresponds
5452
+ to a data point.
5453
+ labels (numpy.ndarray): The array of labels corresponding to the embedding,
5454
+ where a label of -1 indicates noise.
5455
+
5456
+ Returns:
5457
+ tuple: A tuple containing:
5458
+ - numpy.ndarray: The filtered embedding array with noise removed.
5459
+ - numpy.ndarray: The filtered labels array with noise removed.
5460
+ """
3906
5461
  non_noise_indices = labels != -1
3907
5462
  embedding = embedding[non_noise_indices]
3908
5463
  labels = labels[non_noise_indices]
3909
5464
  return embedding, labels
3910
5465
 
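For example:

    >>> emb = np.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])
    >>> lab = np.array([0, -1, 1])
    >>> emb2, lab2 = remove_noise(emb, lab)
    >>> lab2
    array([0, 1])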
3911
5466
  def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
5467
+ """
5468
+ Plots a 2D embedding with optional images, clusters, and customization options.
5469
+
5470
+ Args:
5471
+ embedding (np.ndarray): A 2D array of shape (n_samples, 2) representing the embedding coordinates.
5472
+ image_paths (list or None): A list of file paths to images corresponding to the data points, or None if no images are used.
5473
+ labels (np.ndarray): An array of cluster labels for each data point.
5474
+ image_nr (int): The number of images to display on the plot.
5475
+ img_zoom (float): The zoom factor for the displayed images.
5476
+ colors (list or None): A list of colors to use for clusters, or None to use default colors.
5477
+ plot_by_cluster (bool): Whether to plot images grouped by cluster.
5478
+ plot_outlines (bool): Whether to draw outlines around clusters.
5479
+ plot_points (bool): Whether to plot individual data points.
5480
+ plot_images (bool): Whether to overlay images on the embedding.
5481
+ smooth_lines (bool): Whether to draw smooth lines between cluster centers.
5482
+ black_background (bool): Whether to use a black background for the plot.
5483
+ figuresize (tuple): The size of the figure in inches (width, height).
5484
+ dot_size (float): The size of the dots representing data points.
5485
+ remove_image_canvas (bool): Whether to remove the canvas around the images.
5486
+ verbose (bool): Whether to print verbose output during the plotting process.
5487
+
5488
+ Returns:
5489
+ matplotlib.figure.Figure: The generated plot as a Matplotlib figure object.
5490
+ """
3912
5491
  unique_labels = np.unique(labels)
3913
5492
  #num_clusters = len(unique_labels[unique_labels != 0])
3914
5493
  colors, label_to_color_index = assign_colors(unique_labels, colors)
@@ -3921,6 +5500,22 @@ def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, p
3921
5500
  return fig
3922
5501
 
3923
5502
  def generate_colors(num_clusters, black_background):
5503
+ """
5504
+ Generate a set of RGBA colors for visualization purposes.
5505
+
5506
+ This function generates a list of random RGBA colors, appends specific predefined colors,
5507
+ and optionally includes a black background color.
5508
+
5509
+ Args:
5510
+ num_clusters (int): The number of clusters for which colors need to be generated.
5511
+ Additional random colors will be generated beyond the predefined ones.
5512
+ black_background (bool): If True, a black background color ([0, 0, 0, 1]) will be included
5513
+ at the beginning of the color list.
5514
+
5515
+ Returns:
5516
+ numpy.ndarray: A 2D array of shape (num_colors, 4), where each row represents an RGBA color.
5517
+ The first dimension corresponds to the total number of colors generated.
5518
+ """
3924
5519
  random_colors = np.random.rand(num_clusters + 1, 4)
3925
5520
  random_colors[:, 3] = 1
3926
5521
  specific_colors = [
@@ -3935,6 +5530,19 @@ def generate_colors(num_clusters, black_background):
3935
5530
  return random_colors
3936
5531
 
3937
5532
  def assign_colors(unique_labels, random_colors):
5533
+ """
5534
+ Assigns colors to unique labels and creates a mapping from labels to color indices.
5535
+
5536
+ Args:
5537
+ unique_labels (list or iterable): A collection of unique labels for which colors need to be assigned.
5538
+ random_colors (numpy.ndarray or list): An array or list of RGB color values, where each color is represented
5539
+ as a triplet of integers in the range [0, 255].
5540
+
5541
+ Returns:
5542
+ tuple: A tuple containing:
5543
+ - colors (list of tuple): A list of RGB color tuples in the original [0, 255] range.
5544
+ - label_to_color_index (dict): A dictionary mapping each unique label to its corresponding color index.
5545
+ """
3938
5546
  normalized_colors = random_colors / 255
3939
5547
  colors_img = [tuple(color) for color in normalized_colors]
3940
5548
  colors = [tuple(color) for color in random_colors]
@@ -3942,6 +5550,17 @@ def assign_colors(unique_labels, random_colors):
3942
5550
  return colors, label_to_color_index
3943
5551
 
3944
5552
  def setup_plot(figuresize, black_background):
5553
+ """
5554
+ Sets up a matplotlib plot with specified figure size and background color.
5555
+
5556
+ Args:
5557
+ figuresize (float): The size of the figure in inches (used for both width and height).
5558
+ black_background (bool): If True, sets the plot to have a black background with white text and labels.
5559
+ If False, sets the plot to have a white background with black text and labels.
5560
+
5561
+ Returns:
5562
+ tuple: A tuple containing the figure (`matplotlib.figure.Figure`) and axes (`matplotlib.axes._axes.Axes`) objects.
5563
+ """
3945
5564
  if black_background:
3946
5565
  plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
3947
5566
  else:
@@ -3950,6 +5569,30 @@ def setup_plot(figuresize, black_background):
3950
5569
  return fig, ax
3951
5570
 
3952
5571
  def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=10, dot_size=50, verbose=False):
5572
+ """
5573
+ Plots clusters on a 2D embedding using matplotlib.
5574
+
5575
+ Args:
5576
+ ax (matplotlib.axes.Axes): The matplotlib Axes object to plot on.
5577
+ embedding (numpy.ndarray): A 2D array of shape (n_samples, 2) representing the embedding coordinates.
5578
+ labels (numpy.ndarray): An array of cluster labels for each point in the embedding.
5579
+ colors (list): A list of colors corresponding to each cluster.
5580
+ cluster_centers (numpy.ndarray): A 2D array of shape (n_clusters, 2) representing the coordinates of cluster centers.
5581
+ plot_outlines (bool): Whether to plot the outlines of clusters using convex hulls or smoothed lines.
5582
+ plot_points (bool): Whether to plot individual points in the clusters.
5583
+ smooth_lines (bool): Whether to use smoothed lines for cluster outlines instead of convex hulls.
5584
+ figuresize (int, optional): The size of the figure. Defaults to 10.
5585
+ dot_size (int, optional): The size of the points in the scatter plot. Defaults to 50.
5586
+ verbose (bool, optional): Whether to print additional information for debugging. Defaults to False.
5587
+
5588
+ Returns:
5589
+ None
5590
+
5591
+ Notes:
5592
+ - This function assumes that the embedding is 2D.
5593
+ - Cluster labels should be integers, with -1 typically representing noise.
5594
+ - The function uses matplotlib for plotting and assumes that the required libraries (e.g., numpy, matplotlib) are imported.
5595
+ """
3953
5596
  unique_labels = np.unique(labels)
3954
5597
  for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
3955
5598
  cluster_data = embedding[labels == cluster_label]
@@ -3975,6 +5618,24 @@ def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines,
3975
5618
  plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))
3976
5619
 
3977
5620
  def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
5621
+ """
5622
+ Plots UMAP embeddings with associated images on a given matplotlib axis.
5623
+
5624
+ Args:
5625
+ ax (matplotlib.axes.Axes): The matplotlib axis on which to plot the images.
5626
+ image_paths (list of str): List of file paths to the images to be plotted.
5627
+ embedding (numpy.ndarray): 2D array of UMAP embeddings with shape (n_samples, 2).
5628
+ labels (numpy.ndarray): Array of cluster labels for each embedding point.
5629
+ image_nr (int): Number of images to plot.
5630
+ img_zoom (float): Zoom factor for the images.
5631
+ colors (list): List of colors for each cluster.
5632
+ plot_by_cluster (bool): If True, plot images grouped by cluster; otherwise, plot randomly sampled images.
5633
+ remove_image_canvas (bool): If True, remove the image canvas (background) when plotting.
5634
+ verbose (bool): If True, print additional information during execution.
5635
+
5636
+ Returns:
5637
+ None
5638
+ """
3978
5639
  if plot_by_cluster:
3979
5640
  cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
3980
5641
  plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
@@ -3986,6 +5647,24 @@ def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, col
3986
5647
  plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
3987
5648
 
3988
5649
  def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
5650
+ """
5651
+ Plots images on a given axis based on their cluster assignments and embeddings.
5652
+
5653
+ Args:
5654
+ ax (matplotlib.axes.Axes): The matplotlib axis on which to plot the images.
5655
+ image_paths (list of str): List of file paths to the images to be plotted.
5656
+ embedding (array-like): 2D array of shape (n_samples, 2) containing the x and y coordinates for each image.
5657
+ labels (array-like): Array of cluster labels for each image. -1 indicates noise or unclustered points.
5658
+ image_nr (int): Maximum number of images to display per cluster.
5659
+ img_zoom (float): Zoom factor for the displayed images.
5660
+ colors (list of str): List of colors corresponding to each cluster.
5661
+ cluster_indices (dict): Dictionary mapping cluster labels to lists of indices of images in each cluster.
5662
+ remove_image_canvas (bool): If True, removes the canvas (border) around the plotted images.
5663
+ verbose (bool): If True, prints additional information during execution.
5664
+
5665
+ Returns:
5666
+ None
5667
+ """
3989
5668
  for cluster_label, color in zip(np.unique(labels), colors):
3990
5669
  if cluster_label == -1:
3991
5670
  continue
@@ -3998,6 +5677,22 @@ def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoo
3998
5677
  plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
3999
5678
 
4000
5679
  def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
5680
+ """
5681
+ Plots an image on a given matplotlib axis at specified coordinates.
5682
+
5683
+ Args:
5684
+ ax (matplotlib.axes.Axes): The axis on which to plot the image.
5685
+ x (float): The x-coordinate where the image will be placed.
5686
+ y (float): The y-coordinate where the image will be placed.
5687
+ img (numpy.ndarray or array-like): The image data to be plotted.
5688
+ img_zoom (float): The zoom factor for the image.
5689
+ remove_image_canvas (bool, optional): If True, converts the image to RGBA and
5690
+ makes zero-valued background pixels transparent before plotting.
5691
+ Defaults to True.
5692
+
5693
+ Returns:
5694
+ None
5695
+ """
4001
5696
  img = np.array(img)
4002
5697
  if remove_image_canvas:
4003
5698
  img = remove_canvas(img)
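
Thumbnail placement like this is typically done with matplotlib's OffsetImage/AnnotationBbox pair (the `ab` artist added in the next hunk). A minimal, self-contained sketch of that pattern, on synthetic data rather than anything from the package:

    import numpy as np
    import matplotlib.pyplot as plt
    from matplotlib.offsetbox import OffsetImage, AnnotationBbox

    fig, ax = plt.subplots()
    thumb = np.random.rand(16, 16)                      # stand-in thumbnail
    box = OffsetImage(thumb, zoom=1.5)                  # img_zoom scales the thumbnail
    ab = AnnotationBbox(box, (0.5, 0.5), frameon=False)
    ax.add_artist(ab)                                   # pin the image at (x, y)
    plt.show()
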
@@ -4006,6 +5701,26 @@ def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
4006
5701
  ax.add_artist(ab)
4007
5702
 
4008
5703
  def remove_canvas(img):
5704
+ """
5705
+ Converts an image to a normalized RGBA format by adding an alpha channel.
5706
+
5707
+ This function processes images in either grayscale ('L', 'I') or RGB ('RGB') mode.
5708
+ For grayscale images, the pixel values are normalized, and an alpha channel is
5709
+ created based on non-zero pixel values. For RGB images, the pixel values are
5710
+ normalized to the range [0, 1], and an alpha channel is created based on the
5711
+ presence of non-zero pixel values across all channels.
5712
+
5713
+ Args:
5714
+ img (PIL.Image.Image): The input image to process. Must be in 'L', 'I', or 'RGB' mode.
5715
+
5716
+ Returns:
5717
+ numpy.ndarray: A 4-channel RGBA image as a NumPy array, where the first three
5718
+ channels represent the normalized RGB values, and the fourth channel represents
5719
+ the alpha channel.
5720
+
5721
+ Raises:
5722
+ ValueError: If the input image mode is not 'L', 'I', or 'RGB'.
5723
+ """
4009
5724
  if img.mode in ['L', 'I']:
4010
5725
  img_data = np.array(img)
4011
5726
  img_data = img_data / np.max(img_data)
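
The alpha-masking idea the docstring describes can be sketched in a few lines; this illustrates the technique, not the package's exact code:

    import numpy as np

    gray = np.array([[0, 128], [255, 0]], dtype=float)
    norm = gray / gray.max()                     # normalize intensities to [0, 1]
    alpha = (gray > 0).astype(float)             # zero pixels become transparent
    rgba = np.dstack([norm, norm, norm, alpha])  # RGBA array of shape (H, W, 4)
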
@@ -4022,6 +5737,27 @@ def remove_canvas(img):
4022
5737
  return img_data_with_alpha
4023
5738
 
4024
5739
  def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
5740
+ """
5741
+ Plot a grid of images for each cluster based on the given labels and embeddings.
5742
+
5743
+ Args:
5744
+ embedding (np.ndarray): Embedding of data points for visualization.
5745
+ labels (np.ndarray): Cluster labels for each data point. A value of -1 indicates noise or outliers.
5746
+ image_nr (int): Maximum number of images to display per cluster.
5747
+ image_paths (list of str): File paths to images corresponding to the data points.
5748
+ colors (list of str): List of colors for each cluster.
5749
+ figuresize (tuple): Size of the figure (width, height).
5750
+ black_background (bool): Whether to use a black background.
5751
+ verbose (bool): Whether to print progress information.
5752
+
5753
+ Returns:
5754
+ matplotlib.figure.Figure or None: The generated figure, or None if no valid clusters are found.
5755
+
5756
+ Notes:
5757
+ - Clusters larger than `image_nr` are randomly subsampled.
5758
+ - If all labels are -1, the function returns None.
5759
+ - Relies on an external `plot_grid` function for grid rendering.
5760
+ """
4025
5761
  unique_labels = np.unique(labels)
4026
5762
  num_clusters = len(unique_labels[unique_labels != -1])
4027
5763
  if num_clusters == 0:
@@ -4042,7 +5778,25 @@ def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figures
4042
5778
  fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
4043
5779
  return fig
4044
5780
 
4045
- def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
5781
+ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
5782
+ """
5783
+ Plot a grid of images grouped by cluster with optional background and labels.
5784
+
5785
+ Args:
5786
+ cluster_images (dict): Dictionary mapping cluster labels to lists of images.
5787
+ colors (list): List of RGB tuples specifying colors for each cluster.
5788
+ figuresize (float): Base figure size; actual size scales with the number of clusters.
5789
+ black_background (bool): If True, use a black background; otherwise, use white.
5790
+ verbose (bool): If True, print cluster labels and index info during plotting.
5791
+
5792
+ Returns:
5793
+ matplotlib.figure.Figure: The generated figure containing the image grid.
5794
+
5795
+ Notes:
5796
+ - Grid size is dynamically adjusted per cluster.
5797
+ - Cluster labels are shown alongside image grids using corresponding colors.
5798
+ - A maximum figure size limit prevents overly large plots.
5799
+ """
4046
5800
  num_clusters = len(cluster_images)
4047
5801
  max_figsize = 200 # Set a maximum figure size
4048
5802
  if figuresize * num_clusters > max_figsize:
@@ -4090,6 +5844,27 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
4090
5844
  return grid_fig
4091
5845
 
4092
5846
  def generate_path_list_from_db(db_path, file_metadata):
5847
+ """
5848
+ Generate a list of file paths from a SQLite database using optional metadata filters.
5849
+
5850
+ Args:
5851
+ db_path (str): Path to the SQLite database.
5852
+ file_metadata (str | list[str] | None): Filter criteria for file paths.
5853
+ - str: Only include paths containing the string.
5854
+ - list of str: Include paths containing any of the strings.
5855
+ - None or empty: Include all paths.
5856
+
5857
+ Returns:
5858
+ list[str] or None: List of matching file paths, or None if an error occurs.
5859
+
5860
+ Raises:
5861
+ sqlite3.Error: If a database operation fails.
5862
+ Exception: For any other unexpected error.
5863
+
5864
+ Notes:
5865
+ - Paths are fetched from the 'png_list' table using the 'png_path' column.
5866
+ - Results are retrieved in batches of 1000 rows for efficiency.
5867
+ """
4093
5868
  all_paths = []
4094
5869
 
4095
5870
  # Connect to the database and retrieve the image paths
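
A minimal sketch of the batched retrieval described above, assuming a hypothetical database file and the 'png_list'/'png_path' schema named in the docstring:

    import sqlite3

    con = sqlite3.connect("measurements.db")   # hypothetical path
    cur = con.cursor()
    cur.execute("SELECT png_path FROM png_list WHERE png_path LIKE ?", ("%cell%",))
    paths = []
    while True:
        rows = cur.fetchmany(1000)             # batches of 1000 rows
        if not rows:
            break
        paths.extend(r[0] for r in rows)
    con.close()
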
@@ -4128,6 +5903,27 @@ def generate_path_list_from_db(db_path, file_metadata):
4128
5903
  return all_paths
4129
5904
 
4130
5905
  def correct_paths(df, base_path, folder='data'):
5906
+ """
5907
+ Adjust file paths to include the specified base directory and folder.
5908
+
5909
+ Args:
5910
+ df (pandas.DataFrame or list): Input containing file paths.
5911
+ - If a DataFrame, it must have a 'png_path' column.
5912
+ - If a list, it should contain file path strings.
5913
+ base_path (str): Base directory to prepend if not already present.
5914
+ folder (str, optional): Folder name to insert into paths (default: 'data').
5915
+
5916
+ Returns:
5917
+ tuple or list:
5918
+ - If input is a DataFrame: (updated DataFrame, list of adjusted paths).
5919
+ - If input is a list: list of adjusted paths.
5920
+
5921
+ Raises:
5922
+ ValueError: If the DataFrame does not contain a 'png_path' column.
5923
+
5924
+ Notes:
5925
+ Paths already containing the base path are not modified.
5926
+ """
4131
5927
 
4132
5928
  if isinstance(df, pd.DataFrame):
4133
5929
 
@@ -4159,6 +5955,24 @@ def correct_paths(df, base_path, folder='data'):
4159
5955
  return adjusted_image_paths
4160
5956
 
4161
5957
  def delete_folder(folder_path):
5958
+ """
5959
+ Deletes a folder and all of its contents, including subdirectories and files.
5960
+
5961
+ Args:
5962
+ folder_path (str): The path to the folder to be deleted.
5963
+
5964
+ Behavior:
5965
+ - If the specified folder exists and is a directory, it recursively deletes all files
5966
+ and subdirectories within it, and then removes the folder itself.
5967
+ - If the folder does not exist or is not a directory, a message is printed indicating this.
5968
+
5969
+ Prints:
5970
+ - A confirmation message if the folder is successfully deleted.
5971
+ - An error message if the folder does not exist or is not a directory.
5972
+
5973
+ Example:
5974
+ delete_folder('/path/to/folder')
5975
+ """
4162
5976
  if os.path.exists(folder_path) and os.path.isdir(folder_path):
4163
5977
  for root, dirs, files in os.walk(folder_path, topdown=False):
4164
5978
  for name in files:
@@ -4171,6 +5985,28 @@ def delete_folder(folder_path):
4171
5985
  print(f"Folder '{folder_path}' does not exist or is not a directory.")
4172
5986
 
4173
5987
  def measure_test_mode(settings):
5988
+ """
5989
+ Adjusts the source folder in the settings dictionary for test mode.
5990
+
5991
+ If `test_mode` is enabled in the `settings` dictionary, this function:
5992
+ - Checks if the current source folder (`settings['src']`) is not already set to 'test'.
5993
+ - Selects a random subset of files from the source folder based on `settings['test_nr']`.
5994
+ - Copies the selected files into a new 'test/merged' directory.
5995
+ - Updates the `settings['src']` to point to the new 'test/merged' directory.
5996
+ - Prints a message indicating the change in the source folder.
5997
+
5998
+ If the source folder is already set to 'test', it simply prints a message confirming the test mode.
5999
+
6000
+ Args:
6001
+ settings (dict): A dictionary containing configuration settings.
6002
+ Expected keys:
6003
+ - 'test_mode' (bool): Whether test mode is enabled.
6004
+ - 'src' (str): Path to the source folder.
6005
+ - 'test_nr' (int): Number of files to select for test mode.
6006
+
6007
+ Returns:
6008
+ dict: The updated settings dictionary with the modified source folder if test mode is enabled.
6009
+ """
4174
6010
 
4175
6011
  if settings['test_mode']:
4176
6012
  if not os.path.basename(settings['src']) == 'test':
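
The sampling-and-copying step described above could look roughly like this; `make_test_subset` is a hypothetical helper for illustration, not part of spacr:

    import os, random, shutil

    def make_test_subset(src, test_nr):
        files = os.listdir(src)
        dst = os.path.join(os.path.dirname(src), "test", "merged")
        os.makedirs(dst, exist_ok=True)
        for f in random.sample(files, min(test_nr, len(files))):
            shutil.copy(os.path.join(src, f), dst)   # copy the random subset
        return dst
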
@@ -4251,7 +6087,7 @@ def remove_low_variance_columns(df, threshold=0.01, verbose=False):
4251
6087
  """
4252
6088
  Removes columns from the dataframe that have low variance.
4253
6089
 
4254
- Parameters:
6090
+ Args:
4255
6091
  df (pandas.DataFrame): The DataFrame containing the data.
4256
6092
  threshold (float): The variance threshold below which columns will be removed.
4257
6093
 
@@ -4273,7 +6109,7 @@ def remove_highly_correlated_columns(df, threshold=0.95, verbose=False):
4273
6109
  """
4274
6110
  Removes columns from the dataframe that are highly correlated with one another.
4275
6111
 
4276
- Parameters:
6112
+ Args:
4277
6113
  df (pandas.DataFrame): The DataFrame containing the data.
4278
6114
  threshold (float): The correlation threshold above which columns will be removed.
4279
6115
 
@@ -4301,7 +6137,7 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
4301
6137
  """
4302
6138
  Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.
4303
6139
 
4304
- Parameters:
6140
+ Args:
4305
6141
  - df (pandas.DataFrame): The input dataframe to be filtered.
4306
6142
  - channel_of_interest (str, int, list, None): The channel(s) of interest to filter the dataframe. If None, no filtering is applied. If 'morphology', only morphology features are included.If an integer, only the specified channel is included. If a list, only the specified channels are included.If a string, only the specified channel is included.
4307
6143
  - exclude (str, list, None): The feature(s) to exclude from the filtered dataframe. If None, no features are excluded. If a string, the specified feature is excluded.If a list, the specified features are excluded.
@@ -4381,6 +6217,20 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
4381
6217
 
4382
6218
  # Create a function to check if images overlap
4383
6219
  def check_overlap(current_position, other_positions, threshold):
6220
+ """
6221
+ Checks if the current position overlaps with any of the other positions
6222
+ within a specified threshold distance.
6223
+
6224
+ Args:
6225
+ current_position (iterable): The current position as a list, tuple, or array of coordinates.
6226
+ other_positions (iterable): A collection of positions to compare against,
6227
+ where each position is a list, tuple, or array of coordinates.
6228
+ threshold (float): The distance threshold below which two positions are considered overlapping.
6229
+
6230
+ Returns:
6231
+ bool: True if the current position overlaps with any of the other positions,
6232
+ False otherwise.
6233
+ """
4384
6234
  for other_position in other_positions:
4385
6235
  distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
4386
6236
  if distance < threshold:
@@ -4389,6 +6239,25 @@ def check_overlap(current_position, other_positions, threshold):
4389
6239
 
4390
6240
  # Define a function to try random positions around a given point
4391
6241
  def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
6242
+ """
6243
+ Finds a new position near the given coordinates (x, y) that does not overlap
6244
+ with any of the positions in the provided image_positions list, based on a
6245
+ specified threshold.
6246
+
6247
+ Args:
6248
+ x (float): The x-coordinate of the initial position.
6249
+ y (float): The y-coordinate of the initial position.
6250
+ image_positions (list of tuples): A list of (x, y) tuples representing
6251
+ existing positions to avoid overlapping with.
6252
+ threshold (float): The minimum distance required to avoid overlap.
6253
+ max_attempts (int, optional): The maximum number of attempts to find a
6254
+ non-overlapping position. Defaults to 100.
6255
+
6256
+ Returns:
6257
+ tuple: A tuple (new_x, new_y) representing the new non-overlapping
6258
+ position. If no suitable position is found within the maximum attempts,
6259
+ the original position (x, y) is returned.
6260
+ """
4392
6261
  offset_range = 10 # Adjust the range for random offsets
4393
6262
  attempts = 0
4394
6263
  while attempts < max_attempts:
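
Both helpers reduce to a Euclidean-distance test plus bounded random retries; a compact sketch under those assumptions:

    import numpy as np

    positions = [(0.0, 0.0), (3.0, 4.0)]      # occupied spots
    x, y, threshold = 0.5, 0.5, 1.0

    def overlaps(p, others, t):
        return any(np.linalg.norm(np.array(p) - np.array(q)) < t for q in others)

    for _ in range(100):                       # max_attempts
        if not overlaps((x, y), positions, threshold):
            break
        x += np.random.uniform(-10, 10)        # offset_range of 10, as above
        y += np.random.uniform(-10, 10)
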
@@ -4405,7 +6274,7 @@ def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric,
4405
6274
  """
4406
6275
  Perform dimensionality reduction and clustering on the given data.
4407
6276
 
4408
- Parameters:
6277
+ Args:
4409
6278
  numeric_data (np.array): Numeric data to process.
4410
6279
  n_neighbors (int): Number of neighbors for UMAP or perplexity for tSNE.
4411
6280
  min_dist (float): Minimum distance for UMAP.
@@ -4457,7 +6326,18 @@ def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric,
4457
6326
  return embedding, labels
4458
6327
 
4459
6328
  def load_image(image_path):
4460
- """Load and preprocess an image."""
6329
+ """
6330
+ Load and transform an image into a normalized tensor.
6331
+
6332
+ Applies resizing to 224x224, converts to a tensor, and normalizes using
6333
+ ImageNet mean and standard deviation.
6334
+
6335
+ Args:
6336
+ image_path (str): Path to the input image.
6337
+
6338
+ Returns:
6339
+ torch.Tensor: Transformed image tensor with shape (1, 3, 224, 224).
6340
+ """
4461
6341
  transform = transforms.Compose([
4462
6342
  transforms.Resize((224, 224)),
4463
6343
  transforms.ToTensor(),
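
The hunk breaks off mid-Compose; per the docstring, a Normalize with the standard ImageNet statistics (mean 0.485/0.456/0.406, std 0.229/0.224/0.225) presumably follows. A usage sketch under that assumption (the file name is a placeholder):

    from PIL import Image
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    tensor = transform(Image.open("example.png").convert("RGB")).unsqueeze(0)
    # tensor.shape -> torch.Size([1, 3, 224, 224])
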
@@ -4468,7 +6348,17 @@ def load_image(image_path):
4468
6348
  return image
4469
6349
 
4470
6350
  def extract_features(image_paths, resnet=resnet50):
4471
- """Extract features from images using a pre-trained ResNet model."""
6351
+ """
6352
+ Extracts features from a list of image paths using a pre-trained ResNet model.
6353
+
6354
+ Args:
6355
+ image_paths (list of str): A list of file paths to the images from which features are to be extracted.
6356
+ resnet (torchvision.models, optional): A ResNet model class to use for feature extraction.
6357
+ Defaults to torchvision.models.resnet50.
6358
+
6359
+ Returns:
6360
+ numpy.ndarray: A 2D array where each row corresponds to the extracted features of an image.
6361
+ """
4472
6362
  model = resnet(pretrained=True)
4473
6363
  model = model.eval()
4474
6364
  model = torch.nn.Sequential(*list(model.children())[:-1]) # Remove the last classification layer
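
Stripping the final classification layer leaves the global-average-pooled feature vector, which is 2048-dimensional for ResNet-50. A minimal sketch of the same idea on dummy input:

    import torch
    from torchvision.models import resnet50

    model = resnet50(pretrained=True).eval()
    backbone = torch.nn.Sequential(*list(model.children())[:-1])  # drop the fc head
    with torch.no_grad():
        feats = backbone(torch.randn(1, 3, 224, 224)).flatten(1)
    # feats.shape -> torch.Size([1, 2048])
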
@@ -4483,7 +6373,20 @@ def extract_features(image_paths, resnet=resnet50):
4483
6373
  return np.array(features)
4484
6374
 
4485
6375
  def check_normality(series):
4486
- """Helper function to check if a feature is normally distributed."""
6376
+ """
6377
+ Test whether a given data series follows a normal distribution.
6378
+
6379
+ This function uses the D'Agostino and Pearson's test to check the null
6380
+ hypothesis that the data comes from a normal distribution. If the p-value
6381
+ is less than the significance level (alpha), the null hypothesis is rejected.
6382
+
6383
+ Args:
6384
+ series (array-like): The data series to test for normality.
6385
+
6386
+ Returns:
6387
+ bool: True if the data follows a normal distribution (p >= alpha),
6388
+ False otherwise.
6389
+ """
4487
6390
  k2, p = stats.normaltest(series)
4488
6391
  alpha = 0.05
4489
6392
  if p < alpha: # null hypothesis: x comes from a normal distribution
@@ -4491,7 +6394,29 @@ def check_normality(series):
4491
6394
  return True
4492
6395
 
4493
6396
  def random_forest_feature_importance(all_df, cluster_col='cluster'):
4494
- """Random Forest feature importance."""
6397
+ """
6398
+ Computes feature importance using a Random Forest Classifier.
6399
+
6400
+ This function takes a DataFrame, selects numeric features, and computes
6401
+ the importance of each feature in predicting the specified cluster column
6402
+ using a Random Forest Classifier. The results are returned as a sorted
6403
+ DataFrame of feature importances.
6404
+
6405
+ Args:
6406
+ all_df (pd.DataFrame): The input DataFrame containing the data.
6407
+ cluster_col (str, optional): The name of the column representing the
6408
+ target variable (cluster). Defaults to 'cluster'.
6409
+
6410
+ Returns:
6411
+ pd.DataFrame: A DataFrame containing the features and their corresponding
6412
+ importance scores, sorted in descending order of importance.
6413
+
6414
+ Notes:
6415
+ - The function assumes that the target column (`cluster_col`) is numeric.
6416
+ - Standard scaling is applied to the numeric features before fitting the model.
6417
+ - The Random Forest Classifier is initialized with 100 estimators and a
6418
+ random state of 42 for reproducibility.
6419
+ """
4495
6420
  numeric_features = all_df.select_dtypes(include=[np.number]).columns.tolist()
4496
6421
  if cluster_col in numeric_features:
4497
6422
  numeric_features.remove(cluster_col)
@@ -4515,7 +6440,29 @@ def random_forest_feature_importance(all_df, cluster_col='cluster'):
4515
6440
  return importance_df
4516
6441
 
4517
6442
  def perform_statistical_tests(all_df, cluster_col='cluster'):
4518
- """Perform ANOVA or Kruskal-Wallis tests depending on normality of features."""
6443
+ """
6444
+ Perform ANOVA and Kruskal-Wallis tests on numeric features grouped by clusters.
6445
+
6446
+ This function evaluates whether numeric features differ significantly across groups
6447
+ defined by the `cluster_col`.
6448
+
6449
+ Args:
6450
+ all_df (pd.DataFrame): DataFrame containing numeric features and cluster assignments.
6451
+ cluster_col (str, optional): Name of the column indicating cluster/group labels. Defaults to 'cluster'.
6452
+
6453
+ Returns:
6454
+ tuple:
6455
+ - anova_df (pd.DataFrame): ANOVA test results with columns
6456
+ ['Feature', 'ANOVA_Statistic', 'ANOVA_pValue'].
6457
+ - kruskal_df (pd.DataFrame): Kruskal-Wallis test results with columns
6458
+ ['Feature', 'Kruskal_Statistic', 'Kruskal_pValue'].
6459
+
6460
+ Notes:
6461
+ - Normality of each feature is assessed using `check_normality`.
6462
+ - ANOVA is used for normally distributed features.
6463
+ - Kruskal-Wallis is used for non-normal features.
6464
+ - Assumes `check_normality`, `scipy.stats.f_oneway`, and `scipy.stats.kruskal` are available.
6465
+ """
4519
6466
  numeric_features = all_df.select_dtypes(include=[np.number]).columns.tolist()
4520
6467
  if cluster_col in numeric_features:
4521
6468
  numeric_features.remove(cluster_col)
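
The normality-dispatched testing described above, sketched with scipy on synthetic groups (illustrative only):

    import numpy as np
    from scipy import stats

    groups = [np.random.normal(0, 1, 50), np.random.normal(0.5, 1, 50)]

    k2, p = stats.normaltest(np.concatenate(groups))   # D'Agostino-Pearson
    if p >= 0.05:                                      # plausibly normal -> ANOVA
        stat, pval = stats.f_oneway(*groups)
    else:                                              # otherwise Kruskal-Wallis
        stat, pval = stats.kruskal(*groups)
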
@@ -4539,15 +6486,47 @@ def perform_statistical_tests(all_df, cluster_col='cluster'):
4539
6486
  return anova_df, kruskal_df
4540
6487
 
4541
6488
  def combine_results(rf_df, anova_df, kruskal_df):
4542
- """Combine the results into a single DataFrame."""
6489
+ """
6490
+ Combines results from multiple DataFrames into a single DataFrame.
6491
+
6492
+ This function merges three DataFrames (`rf_df`, `anova_df`, and `kruskal_df`)
6493
+ on the 'Feature' column using a left join. The resulting DataFrame contains
6494
+ all features from `rf_df` and their corresponding data from `anova_df` and
6495
+ `kruskal_df` where available.
6496
+
6497
+ Args:
6498
+ rf_df (pd.DataFrame): A DataFrame containing features and their associated
6499
+ data from a random forest analysis.
6500
+ anova_df (pd.DataFrame): A DataFrame containing features and their associated
6501
+ data from an ANOVA analysis.
6502
+ kruskal_df (pd.DataFrame): A DataFrame containing features and their associated
6503
+ data from a Kruskal-Wallis analysis.
6504
+
6505
+ Returns:
6506
+ pd.DataFrame: A combined DataFrame with features and their associated data
6507
+ from all three input DataFrames.
6508
+ """
4543
6509
  combined_df = rf_df.merge(anova_df, on='Feature', how='left')
4544
6510
  combined_df = combined_df.merge(kruskal_df, on='Feature', how='left')
4545
6511
  return combined_df
4546
6512
 
4547
6513
  def cluster_feature_analysis(all_df, cluster_col='cluster'):
4548
6514
  """
4549
- Perform Random Forest feature importance, ANOVA for normally distributed features,
4550
- and Kruskal-Wallis for non-normally distributed features. Combine results into a single DataFrame.
6515
+ Perform feature analysis for clustering by combining results from
6516
+ random forest feature importance and statistical tests.
6517
+
6518
+ This function calculates feature importance using a random forest model,
6519
+ performs statistical tests (ANOVA and Kruskal-Wallis) to assess the
6520
+ significance of features, and combines the results into a single DataFrame.
6521
+
6522
+ Args:
6523
+ all_df (pd.DataFrame): The input DataFrame containing features and cluster labels.
6524
+ cluster_col (str, optional): The name of the column representing cluster labels.
6525
+ Defaults to 'cluster'.
6526
+
6527
+ Returns:
6528
+ pd.DataFrame: A DataFrame combining the results of random forest feature
6529
+ importance and statistical tests for feature analysis.
4551
6530
  """
4552
6531
  rf_df = random_forest_feature_importance(all_df, cluster_col)
4553
6532
  anova_df, kruskal_df = perform_statistical_tests(all_df, cluster_col)
@@ -4657,7 +6636,7 @@ def _merge_cells_without_nucleus(adj_cell_mask: np.ndarray, nuclei_mask: np.ndar
4657
6636
  Relabel any cell that lacks a nucleus to the ID of an adjacent
4658
6637
  cell that *does* contain a nucleus.
4659
6638
 
4660
- Parameters
6639
+ Args
4661
6640
  ----------
4662
6641
  adj_cell_mask : np.ndarray
4663
6642
  Labelled (0 = background) cell mask after all other merging steps.
@@ -4723,6 +6702,26 @@ def _merge_cells_without_nucleus(adj_cell_mask: np.ndarray, nuclei_mask: np.ndar
4723
6702
  return out.astype(np.uint16)
4724
6703
 
4725
6704
  def process_mask_file_adjust_cell(file_name, parasite_folder, cell_folder, nuclei_folder, overlap_threshold, perimeter_threshold):
6705
+ """
6706
+ Processes and adjusts a cell mask file based on parasite overlap and perimeter thresholds.
6707
+
6708
+ This function loads parasite, cell, and nuclei mask files, merges cells based on parasite overlap
6709
+ and perimeter thresholds, and saves the updated cell mask back to the file system.
6710
+
6711
+ Args:
6712
+ file_name (str): The name of the mask file to process.
6713
+ parasite_folder (str): The directory containing parasite mask files.
6714
+ cell_folder (str): The directory containing cell mask files.
6715
+ nuclei_folder (str): The directory containing nuclei mask files.
6716
+ overlap_threshold (float): The threshold for parasite overlap to merge cells.
6717
+ perimeter_threshold (float): The threshold for cell perimeter to merge cells.
6718
+
6719
+ Returns:
6720
+ float: The time taken to process the mask file, in seconds.
6721
+
6722
+ Raises:
6723
+ ValueError: If the corresponding cell or nuclei mask file for the given file_name is not found.
6724
+ """
4726
6725
  start = time.perf_counter()
4727
6726
 
4728
6727
  parasite_path = os.path.join(parasite_folder, file_name)
@@ -4745,6 +6744,27 @@ def process_mask_file_adjust_cell(file_name, parasite_folder, cell_folder, nucle
4745
6744
  return end - start
4746
6745
 
4747
6746
  def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30, n_jobs=None):
6747
+ """
6748
+ Adjusts cell masks based on parasite, cell, and nuclei data files.
6749
+ This function processes `.npy` files from the specified folders to adjust cell masks
6750
+ by considering overlap and perimeter thresholds. It uses multiprocessing to parallelize
6751
+ the processing of files.
6752
+ Args:
6753
+ parasite_folder (str): Path to the folder containing parasite `.npy` files.
6754
+ cell_folder (str): Path to the folder containing cell `.npy` files.
6755
+ nuclei_folder (str): Path to the folder containing nuclei `.npy` files.
6756
+ overlap_threshold (int, optional): Threshold for overlap adjustment. Defaults to 5.
6757
+ perimeter_threshold (int, optional): Threshold for perimeter adjustment. Defaults to 30.
6758
+ n_jobs (int, optional): Number of parallel jobs to run. Defaults to the number of CPU cores minus 2.
6759
+ Raises:
6760
+ ValueError: If the number of files in the parasite, cell, and nuclei folders do not match.
6761
+ Notes:
6762
+ - The function assumes that the files in the folders are named in a way that allows
6763
+ them to be sorted and matched correctly.
6764
+ - Progress is printed to the console during processing.
6765
+ Returns:
6766
+ None
6767
+ """
4748
6768
 
4749
6769
  parasite_files = sorted([f for f in os.listdir(parasite_folder) if f.endswith('.npy')])
4750
6770
  cell_files = sorted([f for f in os.listdir(cell_folder) if f.endswith('.npy')])
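
The parallel fan-out the docstring describes maps naturally onto multiprocessing.Pool; a sketch under that assumption, where the worker is any picklable one-argument callable:

    from multiprocessing import Pool, cpu_count

    def run_parallel(file_names, worker, n_jobs=None):
        n_jobs = n_jobs or max(1, cpu_count() - 2)   # default: cores minus 2
        with Pool(processes=n_jobs) as pool:
            return pool.map(worker, file_names)      # e.g., per-file timings
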
@@ -4822,6 +6842,25 @@ def adjust_cell_masks_v1(parasite_folder, cell_folder, nuclei_folder, overlap_th
4822
6842
  print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'adjust_cell_masks')
4823
6843
 
4824
6844
  def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
6845
+ """
6846
+ Processes mask files by measuring object properties, clustering objects, and removing objects
6847
+ not belonging to the largest cluster.
6848
+ Args:
6849
+ mask_folder (str): Path to the folder containing mask files (.npy format).
6850
+ image_folder (str): Path to the folder containing corresponding image files (.npy format).
6851
+ channel (int): The channel index to extract from the image files.
6852
+ batch_size (int, optional): Number of files to process in each batch. Defaults to 50.
6853
+ n_clusters (int, optional): Number of clusters for KMeans clustering. Defaults to 2.
6854
+ plot (bool, optional): Whether to plot the clustering results using PCA. Defaults to False.
6855
+ Returns:
6856
+ None: The function modifies the mask files in place by removing objects not in the largest cluster.
6857
+ Notes:
6858
+ - The mask files are expected to be in .npy format and contain labeled regions.
6859
+ - The image files are expected to be in .npy format and have the same names as the mask files.
6860
+ - The function assumes that the mask and image files are sorted in the same order.
6861
+ - The clustering is performed on accumulated object properties across all files.
6862
+ - The largest cluster is determined based on the number of objects in each cluster.
6863
+ """
4825
6864
 
4826
6865
  def read_files_in_batches(folder, batch_size=50):
4827
6866
  files = [f for f in os.listdir(folder) if f.endswith('.npy')]
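
The measure-cluster-filter loop could be sketched with skimage and scikit-learn as below; the property set and file name are assumptions for illustration:

    import numpy as np
    from skimage.measure import regionprops_table
    from sklearn.cluster import KMeans

    mask = np.load("mask_0001.npy")                        # hypothetical labeled mask
    props = regionprops_table(mask, properties=("label", "area", "eccentricity"))
    X = np.column_stack([props["area"], props["eccentricity"]])
    labels = KMeans(n_clusters=2, n_init=10).fit_predict(X)
    keep = np.bincount(labels).argmax()                    # largest cluster
    kept = np.asarray(props["label"])[labels == keep]      # objects to retain
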
@@ -4895,6 +6934,25 @@ def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=
4895
6934
  label_index += len(batch_properties)
4896
6935
 
4897
6936
  def merge_regression_res_with_metadata(results_file, metadata_file, name='_metadata'):
6937
+ """
6938
+ Merge regression results with metadata using gene identifiers.
6939
+
6940
+ Reads regression results and metadata from two CSV files, extracts gene identifiers,
6941
+ merges the data on gene names, and saves the merged DataFrame.
6942
+
6943
+ Args:
6944
+ results_file (str): Path to the regression results CSV. Must contain a 'feature' column.
6945
+ metadata_file (str): Path to the metadata CSV. Must contain a 'Gene ID' column.
6946
+ name (str, optional): Suffix for the output file name. Defaults to '_metadata'.
6947
+
6948
+ Returns:
6949
+ pandas.DataFrame: Merged DataFrame with regression results and metadata.
6950
+
6951
+ Notes:
6952
+ - Extracts gene names from 'feature' (format: '[gene]') and 'Gene ID' (format: 'prefix_gene').
6953
+ - Gene extraction failures result in NaNs in the merge but are not dropped.
6954
+ - Output CSV is saved next to the input `results_file`, with the suffix appended.
6955
+ """
4898
6956
  # Read the CSV files into dataframes
4899
6957
  df_results = pd.read_csv(results_file)
4900
6958
  df_metadata = pd.read_csv(metadata_file)
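
The two gene-name formats named in the docstring can be reconciled as follows (hypothetical values; the package's extraction details may differ):

    import pandas as pd

    df_results = pd.DataFrame({"feature": ["[geneA]", "[geneB]"]})
    df_meta = pd.DataFrame({"Gene ID": ["prefix_geneA", "prefix_geneB"]})

    df_results["gene"] = df_results["feature"].str.extract(r"\[(.+)\]", expand=False)
    df_meta["gene"] = df_meta["Gene ID"].str.split("_", n=1).str[1]
    merged = df_results.merge(df_meta, on="gene", how="left")
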
@@ -4932,6 +6990,19 @@ def merge_regression_res_with_metadata(results_file, metadata_file, name='_metad
4932
6990
  return merged_df
4933
6991
 
4934
6992
  def process_vision_results(df, threshold=0.5):
6993
+ """
6994
+ Process vision results by extracting metadata from file paths and thresholding predictions.
6995
+
6996
+ Args:
6997
+ df (pd.DataFrame): DataFrame with vision results. Must include 'path' and 'pred' columns.
6998
+ threshold (float, optional): Threshold for classifying predictions. Defaults to 0.5.
6999
+
7000
+ Returns:
7001
+ pd.DataFrame: Modified DataFrame with added columns:
7002
+ - 'plateID', 'rowID', 'columnID', 'fieldID', 'object': extracted from 'path'
7003
+ - 'prc': combination of 'plateID', 'rowID', and 'columnID'
7004
+ - 'cv_predictions': binary classification based on `threshold`
7005
+ """
4935
7006
 
4936
7007
  # Split the 'path' column using _map_wells function
4937
7008
  mapped_values = df['path'].apply(lambda x: _map_wells(x))
@@ -4947,6 +7018,50 @@ def process_vision_results(df, threshold=0.5):
4947
7018
  return df
4948
7019
 
4949
7020
  def get_ml_results_paths(src, model_type='xgboost', channel_of_interest=1):
7021
+ """
7022
+ Generate file paths for machine learning result outputs based on model type and feature selection.
7023
+
7024
+ Args:
7025
+ src (str): Base directory where the results folder structure will be created.
7026
+ model_type (str, optional): Type of ML model used (e.g., 'xgboost', 'random_forest'). Defaults to 'xgboost'.
7027
+ channel_of_interest (int | list[int] | str | None, optional): Feature set specification:
7028
+ - int: Single channel (e.g., 1)
7029
+ - list[int]: Multiple channels (e.g., [1, 2, 3])
7030
+ - 'morphology': Use only morphology features
7031
+ - None: Use all features
7032
+ Defaults to 1.
7033
+
7034
+ Returns:
7035
+ tuple[str, ...]: Tuple of 10 file paths:
7036
+ - data_path: CSV with predictions or main results
7037
+ - permutation_path: CSV with permutation importances
7038
+ - feature_importance_path: CSV with model feature importances
7039
+ - model_metricks_path: CSV with metrics from trained model
7040
+ - permutation_fig_path: PDF plot of permutation importances
7041
+ - feature_importance_fig_path: PDF plot of feature importances
7042
+ - shap_fig_path: PDF SHAP summary plot
7043
+ - plate_heatmap_path: PDF visualization of plate layout
7044
+ - settings_csv: CSV with ML settings and parameters
7045
+ - ml_features: CSV with extracted feature data used for training
7046
+
7047
+ Raises:
7048
+ ValueError: If `channel_of_interest` is not an int, list, None, or 'morphology'.
7049
+
7050
+ Example:
7051
+ >>> get_ml_results_paths('/home/user/data', model_type='random_forest', channel_of_interest=[1, 2])
7052
+ (
7053
+ '/home/user/data/results/random_forest/channels_1_2/results.csv',
7054
+ '/home/user/data/results/random_forest/channels_1_2/permutation.csv',
7055
+ '/home/user/data/results/random_forest/channels_1_2/feature_importance.csv',
7056
+ '/home/user/data/results/random_forest/channels_1_2/random_forest_model.csv',
7057
+ '/home/user/data/results/random_forest/channels_1_2/permutation.pdf',
7058
+ '/home/user/data/results/random_forest/channels_1_2/feature_importance.pdf',
7059
+ '/home/user/data/results/random_forest/channels_1_2/shap.pdf',
7060
+ '/home/user/data/results/random_forest/channels_1_2/plate_heatmap.pdf',
7061
+ '/home/user/data/results/random_forest/channels_1_2/ml_settings.csv',
7062
+ '/home/user/data/results/random_forest/channels_1_2/ml_features.csv'
7063
+ )
7064
+ """
4950
7065
 
4951
7066
  if isinstance(channel_of_interest, list):
4952
7067
  feature_string = "channels_" + "_".join(map(str, channel_of_interest))
@@ -4981,7 +7096,7 @@ def augment_image(image):
4981
7096
  """
4982
7097
  Perform data augmentation by rotating and reflecting the image.
4983
7098
 
4984
- Parameters:
7099
+ Args:
4985
7100
  - image (PIL Image or numpy array): The input image.
4986
7101
 
4987
7102
  Returns:
@@ -5025,7 +7140,7 @@ def augment_dataset(dataset, is_grayscale=False):
5025
7140
  """
5026
7141
  Perform data augmentation on the entire dataset by rotating and reflecting the images.
5027
7142
 
5028
- Parameters:
7143
+ Args:
5029
7144
  - dataset (list of tuples): The input dataset, each entry is a tuple (image, label, filename).
5030
7145
  - is_grayscale (bool): Flag indicating if the images are grayscale.
5031
7146
 
@@ -5063,7 +7178,7 @@ def convert_and_relabel_masks(folder_path):
5063
7178
  """
5064
7179
  Converts all int64 npy masks in a folder to uint16 with relabeling to ensure all labels are retained.
5065
7180
 
5066
- Parameters:
7181
+ Args:
5067
7182
  - folder_path (str): The path to the folder containing int64 npy mask files.
5068
7183
 
5069
7184
  Returns:
@@ -5104,6 +7219,20 @@ def convert_and_relabel_masks(folder_path):
5104
7219
  print(f"Converted {file} and saved as uint16_{file}")
5105
7220
 
5106
7221
  def correct_masks(src):
7222
+ """
7223
+ Corrects and processes mask files located in the specified source directory.
7224
+
7225
+ This function performs the following steps:
7226
+ 1. Constructs the file path for the cell mask stack within the 'masks' subdirectory of the source.
7227
+ 2. Converts and relabels the masks using the `convert_and_relabel_masks` function.
7228
+ 3. Loads and concatenates arrays from the source directory using the `_load_and_concatenate_arrays` function.
7229
+
7230
+ Args:
7231
+ src (str): The path to the source directory containing the mask files.
7232
+
7233
+ Returns:
7234
+ None
7235
+ """
5107
7236
 
5108
7237
  from .io import _load_and_concatenate_arrays
5109
7238
 
@@ -5112,15 +7241,35 @@ def correct_masks(src):
5112
7241
  _load_and_concatenate_arrays(src, [0,1,2,3], 1, 0, 2)
5113
7242
 
5114
7243
  def count_reads_in_fastq(fastq_file):
7244
+ """
7245
+ Counts the number of reads in a FASTQ file.
7246
+
7247
+ A FASTQ file contains sequencing reads, where each read is represented
7248
+ by four lines: a header, the sequence, a separator, and the quality scores.
7249
+ This function calculates the total number of reads by dividing the total
7250
+ number of lines in the file by 4.
7251
+
7252
+ Args:
7253
+ fastq_file (str): Path to the FASTQ file, which can be gzip-compressed.
7254
+
7255
+ Returns:
7256
+ int: The number of reads in the FASTQ file.
7257
+ """
5115
7258
  count = 0
5116
7259
  with gzip.open(fastq_file, "rt") as f:
5117
7260
  for _ in f:
5118
7261
  count += 1
5119
7262
  return count // 4
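
As a worked check of the four-lines-per-read arithmetic, writing a throwaway file (nothing here comes from the package):

    import gzip

    record = "@read1\nACGT\n+\nIIII\n"            # header, sequence, '+', qualities
    with gzip.open("tiny.fastq.gz", "wt") as f:
        f.write(record * 3)                        # three reads -> 12 lines

    with gzip.open("tiny.fastq.gz", "rt") as f:
        print(sum(1 for _ in f) // 4)              # -> 3
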
5120
7263
 
5121
-
5122
7264
  # Function to determine the CUDA version
5123
7265
  def get_cuda_version():
7266
+ """
7267
+ Retrieves the installed CUDA version by invoking the `nvcc --version` command.
7268
+
7269
+ Returns:
7270
+ str: The CUDA version as a string with dots removed (e.g., '110' for version 11.0),
7271
+ or None if the `nvcc` command is not found or an error occurs.
7272
+ """
5124
7273
  try:
5125
7274
  output = subprocess.check_output(['nvcc', '--version'], stderr=subprocess.STDOUT).decode('utf-8')
5126
7275
  if 'release' in output:
@@ -5129,10 +7278,36 @@ def get_cuda_version():
5129
7278
  return None
5130
7279
 
5131
7280
  def all_elements_match(list1, list2):
7281
+ """
7282
+ Check if all elements in the first list are present in the second list.
7283
+
7284
+ Args:
7285
+ list1 (list): The first list containing elements to check.
7286
+ list2 (list): The second list to check against.
7287
+
7288
+ Returns:
7289
+ bool: True if all elements in list1 are found in list2, False otherwise.
7290
+ """
5132
7291
  # Check if all elements in list1 are in list2
5133
7292
  return all(element in list2 for element in list1)
5134
7293
 
5135
7294
  def prepare_batch_for_segmentation(batch):
7295
+ """
7296
+ Prepare a batch of images for segmentation by ensuring correct data type and normalization.
7297
+
7298
+ Args:
7299
+ batch (np.ndarray): Batch of images with shape (N, H, W, C), where:
7300
+ - N: number of images
7301
+ - H, W: height and width
7302
+ - C: number of channels (e.g., 1 for grayscale, 3 for RGB)
7303
+
7304
+ Returns:
7305
+ np.ndarray: Batch with dtype `float32`, normalized to [0, 1] if needed.
7306
+
7307
+ Notes:
7308
+ - Converts to `float32` if not already.
7309
+ - Each image is divided by its own maximum pixel value if that value > 1.
7310
+ """
5136
7311
  # Ensure the batch is of dtype float32
5137
7312
  if batch.dtype != np.float32:
5138
7313
  batch = batch.astype(np.float32)
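
The per-image normalization the docstring describes, sketched on synthetic data:

    import numpy as np

    batch = np.random.randint(0, 65535, size=(2, 64, 64, 1)).astype(np.float32)
    for i, img in enumerate(batch):
        m = img.max()
        if m > 1:
            batch[i] = img / m    # each image scaled by its own maximum
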
@@ -5145,6 +7320,25 @@ def prepare_batch_for_segmentation(batch):
5145
7320
  return batch
5146
7321
 
5147
7322
  def check_index(df, elements=5, split_char='_'):
7323
+ """
7324
+ Checks the indices of a DataFrame to ensure they can be split into a specified number of parts.
7325
+
7326
+ Args:
7327
+ df (pandas.DataFrame): The DataFrame whose indices are to be checked.
7328
+ elements (int, optional): The expected number of parts after splitting an index. Defaults to 5.
7329
+ split_char (str, optional): The character used to split the index. Defaults to '_'.
7330
+
7331
+ Raises:
7332
+ ValueError: If any index cannot be split into the specified number of parts,
7333
+ a ValueError is raised listing the problematic indices.
7334
+
7335
+ Example:
7336
+ >>> import pandas as pd
7337
+ >>> data = {'col1': [1, 2, 3]}
7338
+ >>> df = pd.DataFrame(data, index=['a_b_c_d_e', 'f_g_h_i', 'j_k_l_m_n'])
7339
+ >>> check_index(df)
7340
+ ValueError: Found 1 problematic indices that do not split into 5 parts.
7341
+ """
5148
7342
  problematic_indices = []
5149
7343
  for idx in df.index:
5150
7344
  parts = str(idx).split(split_char)
@@ -5158,6 +7352,22 @@ def check_index(df, elements=5, split_char='_'):
5158
7352
 
5159
7353
  # Define the mapping function
5160
7354
  def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
7355
+ """
7356
+ Maps a given column value to a specific condition label.
7357
+
7358
+ Args:
7359
+ col_value (str): The value to be mapped.
7360
+ neg (str, optional): The value representing the 'neg' condition. Defaults to 'c1'.
7361
+ pos (str, optional): The value representing the 'pos' condition. Defaults to 'c2'.
7362
+ mix (str, optional): The value representing the 'mix' condition. Defaults to 'c3'.
7363
+
7364
+ Returns:
7365
+ str: A string representing the mapped condition:
7366
+ - 'neg' if col_value matches the neg parameter.
7367
+ - 'pos' if col_value matches the pos parameter.
7368
+ - 'mix' if col_value matches the mix parameter.
7369
+ - 'screen' if col_value does not match any of the above.
7370
+ """
5161
7371
  if col_value == neg:
5162
7372
  return 'neg'
5163
7373
  elif col_value == pos:
@@ -5236,7 +7446,7 @@ def generate_cytoplasm_mask(nucleus_mask, cell_mask):
5236
7446
  """
5237
7447
  Generates a cytoplasm mask from nucleus and cell masks.
5238
7448
 
5239
- Parameters:
7449
+ Args:
5240
7450
  - nucleus_mask (np.array): Binary or segmented mask of the nucleus (non-zero values represent nucleus).
5241
7451
  - cell_mask (np.array): Binary or segmented mask of the whole cell (non-zero values represent cell).
5242
7452
 
@@ -5259,7 +7469,7 @@ def add_column_to_database(settings):
5259
7469
  If the column already exists in the database, it adds the column with a suffix.
5260
7470
  NaN values will remain as NULL in the database.
5261
7471
 
5262
- Parameters:
7472
+ Args:
5263
7473
  settings (dict): A dictionary containing the following keys:
5264
7474
  csv_path (str): Path to the CSV file with the data to be added.
5265
7475
  db_path (str): Path to the SQLite database (or connection string for other databases).
@@ -5355,6 +7565,24 @@ def fill_holes_in_mask(mask):
5355
7565
  return filled_mask
5356
7566
 
5357
7567
  def correct_metadata_column_names(df):
7568
+ """
7569
+ Standardize column names in a metadata DataFrame.
7570
+
7571
+ This function renames commonly used but inconsistent metadata columns to a standardized format.
7572
+
7573
+ Renaming rules:
7574
+ - 'plate_name' → 'plateID'
7575
+ - 'column_name' or 'col' → 'columnID'
7576
+ - 'row_name' → 'rowID'
7577
+ - 'grna_name' → 'grna'
7578
+ - If 'plate_row' exists, it is split into 'plateID' and 'rowID' using '_' as a delimiter.
7579
+
7580
+ Args:
7581
+ df (pd.DataFrame): Input DataFrame with metadata columns.
7582
+
7583
+ Returns:
7584
+ pd.DataFrame: DataFrame with standardized column names.
7585
+ """
5358
7586
  if 'plate_name' in df.columns:
5359
7587
  df = df.rename(columns={'plate_name': 'plateID'})
5360
7588
  if 'column_name' in df.columns:
@@ -5370,6 +7598,27 @@ def correct_metadata_column_names(df):
5370
7598
  return df
5371
7599
 
5372
7600
  def control_filelist(folder, mode='columnID', values=['01','02']):
7602
+ """
7603
+ Filters a list of files in a given folder based on a specified mode and values.
7604
+
7605
+ Args:
7606
+ folder (str): The path to the folder containing the files to be filtered.
7607
+ mode (str, optional): The filtering mode. Can be 'columnID' or 'rowID'.
7608
+ - 'columnID': Filters files where the second part of the filename (split by '_')
7609
+ has a substring starting from the second character that matches one of the values.
7610
+ - 'rowID': Filters files where the second part of the filename (split by '_')
7611
+ has a substring starting from the first character that matches one of the values.
7612
+ Defaults to 'columnID'.
7613
+ values (list of str, optional): A list of string values to filter the files by. Defaults to ['01', '02'].
7614
+
7615
+ Returns:
7616
+ list of str: A list of filenames that match the filtering criteria.
7617
+
7618
+ Raises:
7619
+ IndexError: If the filenames in the folder do not conform to the expected format
7620
+ (e.g., missing the second part after splitting by '_').
7621
+ """
5373
7622
  files = os.listdir(folder)
5374
7623
  if mode == 'columnID':
5375
7624
  filtered_files = [file for file in files if file.split('_')[1][1:] in values]
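
On the filename convention this implies, the second '_'-token encodes the well: its first character is the row and the remainder the column (the example name is hypothetical):

    name = "plate1_A01_field1.npy"
    well = name.split("_")[1]                 # 'A01'
    row_id, column_id = well[0], well[1:]     # rowID 'A', columnID '01'
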
@@ -5378,6 +7627,33 @@ def control_filelist(folder, mode='columnID', values=['01','02']):
5378
7627
  return filtered_files
5379
7628
 
5380
7629
  def rename_columns_in_db(db_path):
7630
+ """
7631
+ Renames specific columns in all user tables of a SQLite database based on a predefined mapping.
7632
+
7633
+ This function connects to the SQLite database at the given path, retrieves all user-defined tables,
7634
+ and renames columns in those tables according to the `rename_map` dictionary. If a column with the
7635
+ old name exists and the new name does not already exist in the same table, the column is renamed.
7636
+
7637
+ Args:
7638
+ db_path (str): The file path to the SQLite database.
7639
+
7640
+ Behavior:
7641
+ - Retrieves all user-defined tables in the database.
7642
+ - For each table, checks the column names against the `rename_map`.
7643
+ - Renames columns as specified in the `rename_map` if conditions are met.
7644
+ - Commits the changes to the database.
7645
+
7646
+ Notes:
7647
+ - The `rename_map` dictionary defines the mapping of old column names to new column names.
7648
+ - If a column with the new name already exists in a table, the old column will not be renamed.
7649
+ - The function uses SQLite's `ALTER TABLE ... RENAME COLUMN` syntax, which requires SQLite version 3.25.0 or higher.
7650
+
7651
+ Example:
7652
+ rename_columns_in_db("/path/to/database.db")
7653
+
7654
+ Raises:
7655
+ sqlite3.OperationalError: If there are issues executing SQL commands, such as unsupported SQLite versions.
7656
+ """
5381
7657
  # map old column names → new names
5382
7658
  rename_map = {
5383
7659
  'row': 'rowID',
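
A minimal sketch of the rename mechanics for one table and one mapping entry (hypothetical database path; requires SQLite >= 3.25.0, as noted above):

    import sqlite3

    con = sqlite3.connect("measurements.db")
    cur = con.cursor()
    cols = [c[1] for c in cur.execute("PRAGMA table_info(png_list)")]
    if "row" in cols and "rowID" not in cols:
        cur.execute('ALTER TABLE png_list RENAME COLUMN "row" TO "rowID"')
    con.commit()
    con.close()
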
@@ -5411,6 +7687,29 @@ def rename_columns_in_db(db_path):
5411
7687
  con.close()
5412
7688
 
5413
7689
  def group_feature_class(df, feature_groups=['cell', 'cytoplasm', 'nucleus', 'pathogen'], name='compartment'):
7690
+ """
7691
+ Classify and group features by category, then compute summed importance for each group.
7692
+
7693
+ This function adds a new column to the DataFrame to classify features based on regex matching
7694
+ against the given `feature_groups`. It then computes the total importance per group.
7695
+
7696
+ Args:
7697
+ df (pd.DataFrame): Input DataFrame with at least the columns `'feature'` and `'importance'`.
7698
+ feature_groups (list of str, optional): List of feature group identifiers (used in regex matching).
7699
+ Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen'].
7700
+ name (str, optional): Name of the new column used to store group classifications.
7701
+ If set to `'channel'`, missing values in that column will be filled with `'morphology'`.
7702
+ Defaults to 'compartment'.
7703
+
7704
+ Returns:
7705
+ pd.DataFrame: Modified DataFrame including:
7706
+ - A new column with group classifications (based on `name` argument).
7707
+ - A summary with summed importance for each group and a row for total importance.
7708
+
7709
+ Notes:
7710
+ - A feature matching multiple groups has the matching labels joined with a hyphen (e.g., "cell-nucleus").
7711
+ - Classification is done via regex search on the `'feature'` column.
7712
+ """
5414
7713
 
5415
7714
  # Function to determine compartment based on multiple matches
5416
7715
  def find_feature_class(feature, compartments):
@@ -5441,6 +7740,33 @@ def group_feature_class(df, feature_groups=['cell', 'cytoplasm', 'nucleus', 'pat
5441
7740
  return df
5442
7741
 
5443
7742
  def delete_intermedeate_files(settings):
7743
+ """
7744
+ Delete intermediate files and directories specified in the settings dictionary.
7745
+
7746
+ This function removes a predefined set of subdirectories under the path given by
7747
+ `settings['src']`. If deletion fails, an error message is printed.
7748
+
7749
+ Args:
7750
+ settings (dict): Must include the key `'src'`, the base directory containing
7751
+ intermediate files and subdirectories.
7752
+
7753
+ Behavior:
7754
+ - Verifies that `'src'` and its `orig/` subdirectory exist.
7755
+ - Deletes the following under `src` if they exist:
7756
+ * 'stack'
7757
+ * 'masks'
7758
+ * directories '1' through '10'
7759
+ - Compares lengths of `merged_stack` and `path_stack` to decide on deletion.
7760
+ - Prints success or error messages for each directory.
7761
+
7762
+ Notes:
7763
+ - If `src` or `orig/` is missing, the function exits early.
7764
+ - Any deletion exceptions are caught and reported; no exceptions are raised.
7765
+
7766
+ Example:
7767
+ >>> settings = {'src': '/path/to/source'}
7768
+ >>> delete_intermedeate_files(settings)
7769
+ """
5444
7770
 
5445
7771
  path_orig = os.path.join(settings['src'], 'orig')
5446
7772
  path_stack = os.path.join(settings['src'], 'stack')
@@ -5485,7 +7811,7 @@ def filter_and_save_csv(input_csv, output_csv, column_name, upper_threshold, low
5485
7811
  Reads a CSV into a DataFrame, filters rows based on a column for values > upper_threshold and < lower_threshold,
5486
7812
  and saves the filtered DataFrame to a new CSV file.
5487
7813
 
5488
- Parameters:
7814
+ Args:
5489
7815
  input_csv (str): Path to the input CSV file.
5490
7816
  output_csv (str): Path to save the filtered CSV file.
5491
7817
  column_name (str): Column name to apply the filters on.
@@ -5511,7 +7837,7 @@ def extract_tar_bz2_files(folder_path):
5511
7837
  """
5512
7838
  Extracts all .tar.bz2 files in the given folder into subfolders with the same name as the tar file.
5513
7839
 
5514
- Parameters:
7840
+ Args:
5515
7841
  folder_path (str): Path to the folder containing .tar.bz2 files.
5516
7842
  """
5517
7843
  if not os.path.isdir(folder_path):
@@ -5539,7 +7865,7 @@ def calculate_shortest_distance(df, object1, object2):
5539
7865
  """
5540
7866
  Calculate the shortest edge-to-edge distance between two objects (e.g., pathogen and nucleus).
5541
7867
 
5542
- Parameters:
7868
+ Args:
5543
7869
  - df: Pandas DataFrame containing measurements
5544
7870
  - object1: String, name of the first object (e.g., "pathogen")
5545
7871
  - object2: String, name of the second object (e.g., "nucleus")
@@ -5694,6 +8020,22 @@ def copy_images_to_consolidated(image_path_map, root_folder):
5694
8020
  #print(f"Copied: {original_path} -> {new_file_path}")
5695
8021
 
5696
8022
  def correct_metadata(df):
8023
+ """
8024
+ Corrects and standardizes the metadata column names in a DataFrame.
8025
+ This function checks for specific column names in the input DataFrame and
8026
+ renames or maps them to standardized names for consistency. The following
8027
+ transformations are applied:
8028
+ - If 'object_name' exists, it is mapped to 'objectID'.
8029
+ - If 'field_name' exists, it is mapped to 'fieldID'.
8030
+ - If 'plate' or 'plate_name' exists, they are mapped to 'plateID'.
8031
+ - If 'row' or 'row_name' exists, they are renamed to 'rowID'.
8032
+ - If 'col', 'column', or 'column_name' exists, they are renamed to 'columnID'.
8033
+ - If 'field' or 'field_name' exists, they are renamed to 'fieldID'.
8034
+ Args:
8035
+ df (pandas.DataFrame): The input DataFrame containing metadata columns.
8036
+ Returns:
8037
+ pandas.DataFrame: The DataFrame with standardized metadata column names.
8038
+ """
5697
8039
 
5698
8040
  #if 'object' in df.columns:
5699
8041
  # df['objectID'] = df['object']
@@ -5737,7 +8079,7 @@ def remove_outliers_by_group(df, group_col, value_col, method='iqr', threshold=1
5737
8079
  """
5738
8080
  Removes outliers from `value_col` within each group defined by `group_col`.
5739
8081
 
5740
- Parameters:
8082
+ Args:
5741
8083
  df (pd.DataFrame): The input DataFrame.
5742
8084
  group_col (str): Column name to group by.
5743
8085
  value_col (str): Column containing values to check for outliers.