spacr 1.0.7__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/app_classify.py +10 -0
- spacr/app_mask.py +9 -0
- spacr/app_measure.py +9 -0
- spacr/app_sequencing.py +9 -0
- spacr/core.py +172 -1
- spacr/deep_spacr.py +296 -7
- spacr/gui.py +68 -0
- spacr/gui_core.py +319 -10
- spacr/gui_elements.py +772 -13
- spacr/gui_utils.py +301 -151
- spacr/io.py +887 -71
- spacr/logger.py +36 -0
- spacr/measure.py +206 -28
- spacr/ml.py +606 -142
- spacr/plot.py +797 -131
- spacr/sequencing.py +363 -8
- spacr/settings.py +1158 -38
- spacr/sp_stats.py +80 -12
- spacr/spacr_cellpose.py +115 -2
- spacr/submodules.py +747 -19
- spacr/timelapse.py +237 -53
- spacr/toxo.py +132 -6
- spacr/utils.py +2422 -80
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/METADATA +31 -17
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/RECORD +29 -29
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/LICENSE +0 -0
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/WHEEL +0 -0
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/entry_points.txt +0 -0
- {spacr-1.0.7.dist-info → spacr-1.1.0.dist-info}/top_level.txt +0 -0
spacr/utils.py
CHANGED
@@ -77,7 +77,7 @@ def _generate_mask_random_cmap(mask):
|
|
77
77
|
"""
|
78
78
|
Generate a random colormap based on the unique labels in the given mask.
|
79
79
|
|
80
|
-
|
80
|
+
Args:
|
81
81
|
mask (ndarray): The mask array containing unique labels.
|
82
82
|
|
83
83
|
Returns:
|
@@ -92,7 +92,18 @@ def _generate_mask_random_cmap(mask):
|
|
92
92
|
return random_cmap
|
93
93
|
|
94
94
|
def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
|
95
|
+
"""
|
96
|
+
Insert image paths and metadata into a SQLite database.
|
97
|
+
|
98
|
+
Args:
|
99
|
+
img_paths (list of str): Image file paths to insert.
|
100
|
+
settings (dict): Configuration dictionary. Must contain a 'timelapse' key (bool).
|
101
|
+
source_folder (str): Folder containing the SQLite database.
|
102
|
+
crop_mode (str): One of 'cell', 'nucleus', 'pathogen', or 'cytoplasm'.
|
95
103
|
|
104
|
+
Raises:
|
105
|
+
sqlite3.OperationalError: On database connection or write errors.
|
106
|
+
"""
|
96
107
|
png_df = pd.DataFrame(img_paths, columns=['png_path'])
|
97
108
|
|
98
109
|
png_df['file_name'] = png_df['png_path'].apply(lambda x: os.path.basename(x))
|
@@ -129,6 +140,17 @@ def filepaths_to_database(img_paths, settings, source_folder, crop_mode):
|
|
129
140
|
traceback.print_exc()
|
130
141
|
|
131
142
|
def activation_maps_to_database(img_paths, source_folder, settings):
|
143
|
+
"""
|
144
|
+
Store image paths and metadata in a SQLite database.
|
145
|
+
|
146
|
+
Args:
|
147
|
+
img_paths (list of str): Paths to image files.
|
148
|
+
source_folder (str): Folder where the database is stored.
|
149
|
+
settings (dict): Must include 'dataset' and 'cam_type' keys.
|
150
|
+
|
151
|
+
Raises:
|
152
|
+
sqlite3.OperationalError: On database errors.
|
153
|
+
"""
|
132
154
|
from .io import _create_database
|
133
155
|
|
134
156
|
png_df = pd.DataFrame(img_paths, columns=['png_path'])
|
@@ -152,6 +174,18 @@ def activation_maps_to_database(img_paths, source_folder, settings):
|
|
152
174
|
traceback.print_exc()
|
153
175
|
|
154
176
|
def activation_correlations_to_database(df, img_paths, source_folder, settings):
|
177
|
+
"""
|
178
|
+
Save activation correlation data to a SQLite database.
|
179
|
+
|
180
|
+
Args:
|
181
|
+
df (pd.DataFrame): DataFrame with correlation data. Must include 'file_name'.
|
182
|
+
img_paths (list): List of image file paths.
|
183
|
+
source_folder (str): Folder where the database is stored.
|
184
|
+
settings (dict): Must include 'dataset' and 'cam_type' keys.
|
185
|
+
|
186
|
+
Raises:
|
187
|
+
sqlite3.OperationalError: On database errors.
|
188
|
+
"""
|
155
189
|
from .io import _create_database
|
156
190
|
|
157
191
|
png_df = pd.DataFrame(img_paths, columns=['png_path'])
|
@@ -341,6 +375,39 @@ def load_settings(csv_file_path, show=False, setting_key='setting_key', setting_
|
|
341
375
|
return result_dict
|
342
376
|
|
343
377
|
def save_settings(settings, name='settings', show=False):
|
378
|
+
"""
|
379
|
+
Save a dictionary of settings to a CSV file.
|
380
|
+
|
381
|
+
This function takes a dictionary of settings, processes it, and saves it
|
382
|
+
as a CSV file in a specified directory. It also provides an option to
|
383
|
+
display the settings as a DataFrame.
|
384
|
+
|
385
|
+
Args:
|
386
|
+
settings (dict): A dictionary containing the settings to be saved.
|
387
|
+
name (str, optional): The base name for the output CSV file. Defaults to 'settings'.
|
388
|
+
show (bool, optional): If True, displays the settings as a DataFrame. Defaults to False.
|
389
|
+
|
390
|
+
Raises:
|
391
|
+
KeyError: If the 'src' key is not present in the settings dictionary.
|
392
|
+
|
393
|
+
Notes:
|
394
|
+
- If the 'src' key in the settings dictionary is a list, the first element
|
395
|
+
is used as the source directory, and the file name is appended with '_list'.
|
396
|
+
- If the 'test_mode' key exists in the settings dictionary, it is set to False.
|
397
|
+
- If the 'plot' key exists in the settings dictionary, it is set to False.
|
398
|
+
- A directory named 'settings' is created inside the source directory if it does not exist.
|
399
|
+
- The settings are saved as a CSV file in the 'settings' directory.
|
400
|
+
- The file path where the settings are saved is printed.
|
401
|
+
|
402
|
+
Example:
|
403
|
+
>>> settings = {
|
404
|
+
... 'src': '/path/to/source',
|
405
|
+
... 'test_mode': True,
|
406
|
+
... 'plot': True,
|
407
|
+
... 'param1': 42
|
408
|
+
... }
|
409
|
+
>>> save_settings(settings, name='experiment1', show=True)
|
410
|
+
"""
|
344
411
|
|
345
412
|
settings_2 = settings.copy()
|
346
413
|
|
@@ -367,6 +434,33 @@ def save_settings(settings, name='settings', show=False):
|
|
367
434
|
settings_df.to_csv(settings_csv, index=False)
|
368
435
|
|
369
436
|
def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batch_size=None, operation_type=""):
|
437
|
+
"""
|
438
|
+
Prints the progress of a file processing operation along with estimated time information.
|
439
|
+
|
440
|
+
Args:
|
441
|
+
files_processed (int or list): The number of files processed so far. If a list is provided,
|
442
|
+
its unique length will be used.
|
443
|
+
files_to_process (int or list): The total number of files to process. If a list is provided,
|
444
|
+
its unique length will be used.
|
445
|
+
n_jobs (int): The number of parallel jobs being used for processing.
|
446
|
+
time_ls (list, optional): A list of time durations for processing batches or files. Used to
|
447
|
+
calculate average time and estimated time left. Defaults to None.
|
448
|
+
batch_size (int or list, optional): The size of each batch being processed. If a list is
|
449
|
+
provided, its length will be used. Defaults to None.
|
450
|
+
operation_type (str, optional): A string describing the type of operation being performed.
|
451
|
+
Defaults to an empty string.
|
452
|
+
|
453
|
+
Returns:
|
454
|
+
None: This function prints the progress and time information to the console.
|
455
|
+
|
456
|
+
Notes:
|
457
|
+
- If `time_ls` is provided, the function calculates the average time per batch or file and
|
458
|
+
estimates the time remaining for the operation.
|
459
|
+
- If `batch_size` is provided, the function calculates the average time per image within
|
460
|
+
a batch.
|
461
|
+
- Handles cases where inputs are lists or non-integer types by converting them to integers
|
462
|
+
or calculating their lengths.
|
463
|
+
"""
|
370
464
|
if isinstance(files_processed, list):
|
371
465
|
files_processed = len(set(files_processed))
|
372
466
|
if isinstance(files_to_process, list):
|
@@ -399,6 +493,20 @@ def print_progress(files_processed, files_to_process, n_jobs, time_ls=None, batc
|
|
399
493
|
print(f'Progress: {files_processed}/{files_to_process}, operation_type: {operation_type}, {time_info}')
|
400
494
|
|
401
495
|
def reset_mp():
|
496
|
+
"""
|
497
|
+
Resets the multiprocessing start method based on the operating system.
|
498
|
+
On Windows, the start method is set to 'spawn' if it is not already set.
|
499
|
+
On Linux and macOS (Darwin), the start method is set to 'fork' if it is not already set.
|
500
|
+
This function ensures compatibility with the multiprocessing module by
|
501
|
+
enforcing the appropriate start method for the current platform.
|
502
|
+
Note:
|
503
|
+
- The function uses `get_start_method` to retrieve the current start method.
|
504
|
+
- The `set_start_method` function is called with `force=True` to override
|
505
|
+
the existing start method if necessary.
|
506
|
+
Raises:
|
507
|
+
ValueError: If an invalid start method is encountered or if the start
|
508
|
+
method cannot be set for some reason.
|
509
|
+
"""
|
402
510
|
current_method = get_start_method()
|
403
511
|
system = platform.system()
|
404
512
|
|
@@ -410,6 +518,23 @@ def reset_mp():
|
|
410
518
|
set_start_method('fork', force=True)
|
411
519
|
|
412
520
|
def is_multiprocessing_process(process):
|
521
|
+
"""
|
522
|
+
Check if the given process is a multiprocessing process.
|
523
|
+
|
524
|
+
This function examines the command-line arguments of the provided process
|
525
|
+
to determine if it is associated with Python's multiprocessing module.
|
526
|
+
|
527
|
+
Args:
|
528
|
+
process (psutil.Process): A psutil Process object representing the process to check.
|
529
|
+
|
530
|
+
Returns:
|
531
|
+
bool: True if the process is a multiprocessing process, False otherwise.
|
532
|
+
|
533
|
+
Raises:
|
534
|
+
psutil.NoSuchProcess: If the process no longer exists.
|
535
|
+
psutil.AccessDenied: If access to the process information is denied.
|
536
|
+
psutil.ZombieProcess: If the process is a zombie process.
|
537
|
+
"""
|
413
538
|
""" Check if the process is a multiprocessing process. """
|
414
539
|
try:
|
415
540
|
for cmd in process.cmdline():
|
@@ -420,6 +545,24 @@ def is_multiprocessing_process(process):
|
|
420
545
|
return False
|
421
546
|
|
422
547
|
def close_file_descriptors():
|
548
|
+
"""
|
549
|
+
Closes all open file descriptors starting from 3 up to the soft limit
|
550
|
+
of the maximum number of file descriptors allowed for the process.
|
551
|
+
|
552
|
+
This function is useful for cleaning up resources by ensuring that
|
553
|
+
no unnecessary file descriptors remain open. It skips standard input,
|
554
|
+
output, and error (file descriptors 0, 1, and 2).
|
555
|
+
|
556
|
+
Exceptions during the closing of file descriptors are caught and ignored.
|
557
|
+
|
558
|
+
Note:
|
559
|
+
The function uses the `resource` module to retrieve the soft limit
|
560
|
+
for the maximum number of file descriptors.
|
561
|
+
|
562
|
+
Raises:
|
563
|
+
OSError: If an error occurs while closing a file descriptor, it is
|
564
|
+
caught and ignored.
|
565
|
+
"""
|
423
566
|
""" Close file descriptors and shared memory objects. """
|
424
567
|
import resource
|
425
568
|
|
@@ -431,6 +574,23 @@ def close_file_descriptors():
|
|
431
574
|
pass
|
432
575
|
|
433
576
|
def close_multiprocessing_processes():
|
577
|
+
"""
|
578
|
+
Close all multiprocessing processes and clean up associated resources.
|
579
|
+
This function iterates through all running processes and terminates any
|
580
|
+
that are identified as multiprocessing processes, excluding the current
|
581
|
+
process. It waits for up to 5 seconds for each process to terminate and
|
582
|
+
logs the termination status. Additionally, it handles exceptions that may
|
583
|
+
occur during the termination process, such as access denial or the process
|
584
|
+
no longer existing.
|
585
|
+
After terminating the processes, it ensures that any open file descriptors
|
586
|
+
are properly closed.
|
587
|
+
Note:
|
588
|
+
This function relies on the `psutil` library to inspect and manage
|
589
|
+
processes, and assumes the existence of helper functions:
|
590
|
+
- `is_multiprocessing_process(proc)`: Determines if a process is a
|
591
|
+
multiprocessing process.
|
592
|
+
- `close_file_descriptors()`: Closes any open file descriptors.
|
593
|
+
"""
|
434
594
|
""" Close all multiprocessing processes. """
|
435
595
|
current_pid = os.getpid()
|
436
596
|
for proc in psutil.process_iter(['pid', 'cmdline']):
|
@@ -452,6 +612,28 @@ def close_multiprocessing_processes():
|
|
452
612
|
close_file_descriptors()
|
453
613
|
|
454
614
|
def check_mask_folder(src,mask_fldr):
|
615
|
+
def check_mask_folder(src, mask_fldr):
|
616
|
+
"""
|
617
|
+
Checks the status of a mask folder within a given source directory.
|
618
|
+
This function verifies whether the specified mask folder exists and
|
619
|
+
whether the number of mask files matches the number of stack files
|
620
|
+
in the corresponding stack folder. It returns a boolean indicating
|
621
|
+
whether further processing is needed.
|
622
|
+
Args:
|
623
|
+
src (str): The source directory containing the 'masks' and 'stack' subdirectories.
|
624
|
+
mask_fldr (str): The name of the mask folder to check within the 'masks' directory.
|
625
|
+
Returns:
|
626
|
+
bool:
|
627
|
+
- True if the mask folder does not exist or if the number of mask files
|
628
|
+
does not match the number of stack files.
|
629
|
+
- False if the mask folder exists and all masks have been generated.
|
630
|
+
Side Effects:
|
631
|
+
Prints a message if all masks have been generated for the specified mask folder.
|
632
|
+
Example:
|
633
|
+
>>> check_mask_folder('/home/user/data', 'experiment1')
|
634
|
+
All masks have been generated for experiment1
|
635
|
+
False
|
636
|
+
"""
|
455
637
|
|
456
638
|
mask_folder = os.path.join(src,'masks',mask_fldr)
|
457
639
|
stack_folder = os.path.join(src,'stack')
|
@@ -469,6 +651,17 @@ def check_mask_folder(src,mask_fldr):
|
|
469
651
|
return True
|
470
652
|
|
471
653
|
def smooth_hull_lines(cluster_data):
|
654
|
+
"""
|
655
|
+
Smooths the convex hull of a set of 2D points using spline interpolation.
|
656
|
+
|
657
|
+
Args:
|
658
|
+
cluster_data (numpy.ndarray): A 2D array of shape (n_points, 2) representing
|
659
|
+
the coordinates of the points in the cluster.
|
660
|
+
|
661
|
+
Returns:
|
662
|
+
tuple: Two 1D numpy arrays representing the x and y coordinates of the smoothed
|
663
|
+
convex hull, respectively. Each array contains 100 points.
|
664
|
+
"""
|
472
665
|
hull = ConvexHull(cluster_data)
|
473
666
|
# Extract vertices of the hull
|
474
667
|
vertices = hull.points[hull.vertices]
|
@@ -498,6 +691,27 @@ def _gen_rgb_image(image, channels):
|
|
498
691
|
return rgb_image
|
499
692
|
|
500
693
|
def _outline_and_overlay(image, rgb_image, mask_dims, outline_colors, outline_thickness):
|
694
|
+
"""
|
695
|
+
Generate outlines for specified mask dimensions and overlay them onto an RGB image.
|
696
|
+
|
697
|
+
This function processes a multi-dimensional mask to extract contours for each specified
|
698
|
+
dimension, creates outlines with specified thickness, and overlays these outlines onto
|
699
|
+
the provided RGB image using the given colors.
|
700
|
+
|
701
|
+
Args:
|
702
|
+
image (numpy.ndarray): The input image containing multi-dimensional masks.
|
703
|
+
rgb_image (numpy.ndarray): The RGB image onto which the outlines will be overlaid.
|
704
|
+
mask_dims (list[int]): A list of dimensions (indices) in the mask to process.
|
705
|
+
outline_colors (list[tuple[int, int, int]]): A list of RGB color tuples for the outlines.
|
706
|
+
outline_thickness (int): The thickness of the outlines to be drawn.
|
707
|
+
|
708
|
+
Returns:
|
709
|
+
tuple:
|
710
|
+
- overlayed_image (numpy.ndarray): The RGB image with outlines overlaid.
|
711
|
+
- outlines (list[numpy.ndarray]): A list of binary masks representing the outlines
|
712
|
+
for each processed dimension.
|
713
|
+
- image (numpy.ndarray): The original input image (unchanged).
|
714
|
+
"""
|
501
715
|
outlines = []
|
502
716
|
overlayed_image = rgb_image.copy()
|
503
717
|
|
@@ -554,6 +768,20 @@ def _convert_cq1_well_id(well_id):
|
|
554
768
|
return well_format
|
555
769
|
|
556
770
|
def _get_cellpose_batch_size():
|
771
|
+
"""
|
772
|
+
Determines the appropriate batch size for Cellpose based on the available GPU's VRAM.
|
773
|
+
|
774
|
+
This function checks if CUDA is available and retrieves the VRAM (Video RAM) of the first GPU device.
|
775
|
+
Based on the VRAM size, it assigns a batch size for Cellpose processing. If CUDA is not available
|
776
|
+
or an error occurs, a default batch size of 8 is returned.
|
777
|
+
|
778
|
+
Returns:
|
779
|
+
int: The batch size for Cellpose. Possible values are:
|
780
|
+
- 8 (default or if VRAM < 8 GB)
|
781
|
+
- 16 (if 8 GB <= VRAM < 12 GB)
|
782
|
+
- 48 (if 12 GB <= VRAM < 24 GB)
|
783
|
+
- 96 (if VRAM >= 24 GB)
|
784
|
+
"""
|
557
785
|
try:
|
558
786
|
# Check if CUDA is available
|
559
787
|
if torch.cuda.is_available():
|
@@ -576,6 +804,32 @@ def _get_cellpose_batch_size():
|
|
576
804
|
return 8
|
577
805
|
|
578
806
|
def _extract_filename_metadata(filenames, src, regular_expression, metadata_type='cellvoyager'):
|
807
|
+
"""
|
808
|
+
Extracts metadata from a list of filenames based on a provided regular expression
|
809
|
+
and organizes the files into a dictionary grouped by metadata keys.
|
810
|
+
Args:
|
811
|
+
filenames (list of str): A list of filenames to process.
|
812
|
+
src (str): The source directory containing the files.
|
813
|
+
regular_expression (re.Pattern): A compiled regular expression used to extract metadata
|
814
|
+
from the filenames. The regex should define named groups such as 'plateID', 'wellID',
|
815
|
+
'fieldID', 'chanID', 'timeID', and 'sliceID' as applicable.
|
816
|
+
metadata_type (str, optional): The type of metadata to process. Defaults to 'cellvoyager'.
|
817
|
+
If set to 'cq1', well IDs will be converted using `_convert_cq1_well_id`.
|
818
|
+
Returns:
|
819
|
+
defaultdict: A dictionary where keys are tuples of metadata values
|
820
|
+
(plate, well, field, channel, timeID, sliceID), and values are lists of file paths
|
821
|
+
corresponding to those metadata keys.
|
822
|
+
Raises:
|
823
|
+
IndexError: If the filename matches the regex but does not contain the expected groups.
|
824
|
+
Prints warnings for filenames that do not match the regex or fail to extract metadata.
|
825
|
+
Notes:
|
826
|
+
- The function assumes that the provided regular expression includes named groups
|
827
|
+
for extracting metadata.
|
828
|
+
- If a metadata group (e.g., 'timeID' or 'sliceID') is not present in the regex,
|
829
|
+
its value will be set to `None`.
|
830
|
+
- The function handles numeric metadata values by converting them to integers
|
831
|
+
where applicable.
|
832
|
+
"""
|
579
833
|
|
580
834
|
images_by_key = defaultdict(list)
|
581
835
|
|
@@ -635,7 +889,7 @@ def mask_object_count(mask):
|
|
635
889
|
"""
|
636
890
|
Counts the number of objects in a given mask.
|
637
891
|
|
638
|
-
|
892
|
+
Args:
|
639
893
|
- mask: numpy.ndarray. The mask containing object labels.
|
640
894
|
|
641
895
|
Returns:
|
@@ -777,7 +1031,7 @@ def normalize_to_dtype(array, p1=2, p2=98, percentile_list=None, new_dtype=None)
|
|
777
1031
|
"""
|
778
1032
|
Normalize each image in the stack to its own percentiles.
|
779
1033
|
|
780
|
-
|
1034
|
+
Args:
|
781
1035
|
- array: numpy array
|
782
1036
|
The input stack to be normalized.
|
783
1037
|
- p1: int, optional
|
@@ -895,7 +1149,7 @@ def _find_bounding_box(crop_mask, _id, buffer=10):
|
|
895
1149
|
"""
|
896
1150
|
Find the bounding box coordinates for a given object ID in a crop mask.
|
897
1151
|
|
898
|
-
|
1152
|
+
Args:
|
899
1153
|
crop_mask (ndarray): The crop mask containing object IDs.
|
900
1154
|
_id (int): The object ID to find the bounding box for.
|
901
1155
|
buffer (int, optional): The buffer size to add to the bounding box coordinates. Defaults to 10.
|
@@ -1093,7 +1347,7 @@ def _get_percentiles(array, p1=2, p2=98):
|
|
1093
1347
|
"""
|
1094
1348
|
Calculate the percentiles of each image in the given array.
|
1095
1349
|
|
1096
|
-
|
1350
|
+
Args:
|
1097
1351
|
- array: numpy.ndarray
|
1098
1352
|
The input array containing the images.
|
1099
1353
|
- q1: float, optional
|
@@ -1125,7 +1379,7 @@ def _crop_center(img, cell_mask, new_width, new_height):
|
|
1125
1379
|
"""
|
1126
1380
|
Crop the image around the center of the cell mask.
|
1127
1381
|
|
1128
|
-
|
1382
|
+
Args:
|
1129
1383
|
- img: numpy.ndarray
|
1130
1384
|
The input image.
|
1131
1385
|
- cell_mask: numpy.ndarray
|
@@ -1195,6 +1449,35 @@ def _get_diam(mag, obj):
|
|
1195
1449
|
return int(diamiter)
|
1196
1450
|
|
1197
1451
|
def _get_object_settings(object_type, settings):
|
1452
|
+
"""
|
1453
|
+
Generate and return a dictionary of settings for a specific object type.
|
1454
|
+
This function configures object-specific settings based on the provided
|
1455
|
+
`object_type` and `settings`. It supports three object types: 'cell',
|
1456
|
+
'nucleus', and 'pathogen'. If an unsupported object type is provided,
|
1457
|
+
a message is printed, and no settings are returned.
|
1458
|
+
Args:
|
1459
|
+
object_type (str): The type of object to configure settings for.
|
1460
|
+
Supported values are 'cell', 'nucleus', and 'pathogen'.
|
1461
|
+
settings (dict): A dictionary containing global settings, which may include:
|
1462
|
+
- 'magnification' (float): Magnification level used to calculate diameter.
|
1463
|
+
- 'nucleus_channel' (optional): Specifies if a nucleus channel is present.
|
1464
|
+
- 'cell_restore_type' (optional): Restore type for cell objects.
|
1465
|
+
- 'nucleus_restore_type' (optional): Restore type for nucleus objects.
|
1466
|
+
- 'pathogen_restore_type' (optional): Restore type for pathogen objects.
|
1467
|
+
- 'merge_pathogens' (bool): Whether to merge pathogens.
|
1468
|
+
- 'verbose' (bool): Whether to print the generated settings.
|
1469
|
+
Returns:
|
1470
|
+
dict: A dictionary containing the configured settings for the specified object type.
|
1471
|
+
The dictionary includes keys such as 'diameter', 'minimum_size', 'maximum_size',
|
1472
|
+
'merge', 'resample', 'remove_border_objects', 'model_name', 'filter_size',
|
1473
|
+
'filter_intensity', and 'restore_type'.
|
1474
|
+
Notes:
|
1475
|
+
- The 'diameter' is calculated using the `_get_diam` function based on the
|
1476
|
+
magnification level and object type.
|
1477
|
+
- The 'minimum_size' and 'maximum_size' are derived from the diameter.
|
1478
|
+
- The 'model_name' varies depending on the object type and other settings.
|
1479
|
+
- If `settings['verbose']` is True, the generated settings are printed.
|
1480
|
+
"""
|
1198
1481
|
object_settings = {}
|
1199
1482
|
|
1200
1483
|
object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
|
@@ -1242,7 +1525,7 @@ def _pivot_counts_table(db_path):
|
|
1242
1525
|
"""
|
1243
1526
|
Read a table from an SQLite database into a pandas DataFrame.
|
1244
1527
|
|
1245
|
-
|
1528
|
+
Args:
|
1246
1529
|
- db_path (str): The path to the SQLite database file.
|
1247
1530
|
- table_name (str): The name of the table to read. Default is 'object_counts'.
|
1248
1531
|
|
@@ -1287,6 +1570,38 @@ def _pivot_counts_table(db_path):
|
|
1287
1570
|
conn.close()
|
1288
1571
|
|
1289
1572
|
def _get_cellpose_channels_v2(src, nucleus_channel, pathogen_channel, cell_channel):
|
1573
|
+
"""
|
1574
|
+
Generate a dictionary of Cellpose channels based on the provided input channels.
|
1575
|
+
|
1576
|
+
This function constructs a mapping of Cellpose channels for nucleus, pathogen,
|
1577
|
+
and cell segmentation. It also checks for the existence of precomputed mask
|
1578
|
+
files and issues a warning if any channel is `None` while masks already exist.
|
1579
|
+
|
1580
|
+
Args:
|
1581
|
+
src (str): The source directory containing the 'masks' folder with precomputed
|
1582
|
+
mask files ('cell_mask_stack', 'nucleus_mask_stack', 'pathogen_mask_stack').
|
1583
|
+
nucleus_channel (int or None): The channel index for the nucleus. If `None`,
|
1584
|
+
the nucleus channel is not included.
|
1585
|
+
pathogen_channel (int or None): The channel index for the pathogen. If `None`,
|
1586
|
+
the pathogen channel is not included.
|
1587
|
+
cell_channel (int or None): The channel index for the cell. If `None`, the cell
|
1588
|
+
channel is not included.
|
1589
|
+
|
1590
|
+
Returns:
|
1591
|
+
dict: A dictionary where keys are 'nucleus', 'pathogen', and 'cell', and values
|
1592
|
+
are lists of two integers representing the channels to be used for each
|
1593
|
+
type. If a channel is not provided (`None`), it is excluded from the dictionary.
|
1594
|
+
|
1595
|
+
Warnings:
|
1596
|
+
Prints a warning if any of the channels (`nucleus_channel`, `pathogen_channel`,
|
1597
|
+
`cell_channel`) is `None` while precomputed mask files already exist in the
|
1598
|
+
specified source directory.
|
1599
|
+
|
1600
|
+
Notes:
|
1601
|
+
- The nucleus and pathogen channels are always duplicated as [channel, channel].
|
1602
|
+
- The cell channel prefers the nucleus channel as the first channel if available;
|
1603
|
+
otherwise, it duplicates the cell channel as [cell_channel, cell_channel].
|
1604
|
+
"""
|
1290
1605
|
cell_mask_path = os.path.join(src, 'masks', 'cell_mask_stack')
|
1291
1606
|
nucleus_mask_path = os.path.join(src, 'masks', 'nucleus_mask_stack')
|
1292
1607
|
pathogen_mask_path = os.path.join(src, 'masks', 'pathogen_mask_stack')
|
@@ -1315,6 +1630,34 @@ def _get_cellpose_channels_v2(src, nucleus_channel, pathogen_channel, cell_chann
|
|
1315
1630
|
return cellpose_channels
|
1316
1631
|
|
1317
1632
|
def _get_cellpose_channels_v1(src, nucleus_channel, pathogen_channel, cell_channel):
|
1633
|
+
"""
|
1634
|
+
Generates a dictionary mapping object types (nucleus, pathogen, cell) to their respective
|
1635
|
+
Cellpose channel configurations based on the provided input channels.
|
1636
|
+
This function checks for the existence of pre-generated Cellpose masks in the specified
|
1637
|
+
source directory and warns the user if any of the input channels are set to None while
|
1638
|
+
masks already exist. The function then constructs a dictionary of Cellpose channel
|
1639
|
+
configurations for the specified object types.
|
1640
|
+
Args:
|
1641
|
+
src (str): The source directory path where the Cellpose masks are stored.
|
1642
|
+
nucleus_channel (int or None): The channel index for the nucleus. If None, the nucleus
|
1643
|
+
is not included in the configuration.
|
1644
|
+
pathogen_channel (int or None): The channel index for the pathogen. If None, the pathogen
|
1645
|
+
is not included in the configuration.
|
1646
|
+
cell_channel (int or None): The channel index for the cell. If None, the cell is not
|
1647
|
+
included in the configuration.
|
1648
|
+
Returns:
|
1649
|
+
dict: A dictionary where keys are object types ('nucleus', 'pathogen', 'cell') and
|
1650
|
+
values are lists specifying the Cellpose channel configuration for each object type.
|
1651
|
+
Warnings:
|
1652
|
+
Prints a warning if any of the input channels are set to None while Cellpose masks
|
1653
|
+
already exist in the specified source directory.
|
1654
|
+
Notes:
|
1655
|
+
- The function assumes that the Cellpose masks are stored in subdirectories named
|
1656
|
+
'masks/cell_mask_stack', 'masks/nucleus_mask_stack', and 'masks/pathogen_mask_stack'
|
1657
|
+
within the source directory.
|
1658
|
+
- The channel configuration is represented as a list of two integers, where the first
|
1659
|
+
integer is always 0, and the second integer depends on the presence of other channels.
|
1660
|
+
"""
|
1318
1661
|
|
1319
1662
|
cell_mask_path = os.path.join(src, 'masks', 'cell_mask_stack')
|
1320
1663
|
nucleus_mask_path = os.path.join(src, 'masks', 'nucleus_mask_stack')
|
@@ -1347,6 +1690,33 @@ def _get_cellpose_channels_v1(src, nucleus_channel, pathogen_channel, cell_chann
|
|
1347
1690
|
return cellpose_channels
|
1348
1691
|
|
1349
1692
|
def _get_cellpose_channels(src, nucleus_channel, pathogen_channel, cell_channel):
|
1693
|
+
"""
|
1694
|
+
Generate a mapping of channels for Cellpose segmentation based on the provided input channels.
|
1695
|
+
|
1696
|
+
This function determines the logical order of channels for nucleus, cell, and pathogen,
|
1697
|
+
removes duplicates while preserving the first occurrence, and maps the channels to indices
|
1698
|
+
capped at a maximum value of 2. It then assigns these indices to the respective Cellpose
|
1699
|
+
segmentation categories.
|
1700
|
+
|
1701
|
+
Args:
|
1702
|
+
src (str): The source directory containing the mask files.
|
1703
|
+
nucleus_channel (int or None): The channel index for the nucleus. If None, the nucleus
|
1704
|
+
is not included in the mapping.
|
1705
|
+
pathogen_channel (int or None): The channel index for the pathogen. If None, the pathogen
|
1706
|
+
is not included in the mapping.
|
1707
|
+
cell_channel (int or None): The channel index for the cell. If None, the cell is not
|
1708
|
+
included in the mapping.
|
1709
|
+
|
1710
|
+
Returns:
|
1711
|
+
dict: A dictionary mapping segmentation categories ('nucleus', 'cell', 'pathogen') to
|
1712
|
+
their respective channel indices. Each category is assigned a list of two indices,
|
1713
|
+
which are either identical or derived from the logical order of the input channels.
|
1714
|
+
|
1715
|
+
Warnings:
|
1716
|
+
If any of the mask files already exist in the `src` directory and any of the input
|
1717
|
+
channels (`nucleus_channel`, `pathogen_channel`, `cell_channel`) are None, a warning
|
1718
|
+
is printed indicating potential unexpected behavior.
|
1719
|
+
"""
|
1350
1720
|
cell_mask_path = os.path.join(src, 'masks', 'cell_mask_stack')
|
1351
1721
|
nucleus_mask_path = os.path.join(src, 'masks', 'nucleus_mask_stack')
|
1352
1722
|
pathogen_mask_path = os.path.join(src, 'masks', 'pathogen_mask_stack')
|
@@ -1474,7 +1844,7 @@ def _split_data(df, group_by, object_type):
|
|
1474
1844
|
Splits the input dataframe into numeric and non-numeric parts, groups them by the specified column,
|
1475
1845
|
and returns the grouped dataframes with conditional aggregation.
|
1476
1846
|
|
1477
|
-
|
1847
|
+
Args:
|
1478
1848
|
df (pandas.DataFrame): The input dataframe.
|
1479
1849
|
group_by (str): The column name to group the dataframes by.
|
1480
1850
|
object_type (str): The column name to concatenate with 'prcf' to create a new column 'prcfo'.
|
@@ -1574,7 +1944,7 @@ def _group_by_well(df):
|
|
1574
1944
|
Group the DataFrame by well coordinates (plate, row, col) and apply mean function to numeric columns
|
1575
1945
|
and select the first value for non-numeric columns.
|
1576
1946
|
|
1577
|
-
|
1947
|
+
Args:
|
1578
1948
|
df (DataFrame): The input DataFrame to be grouped.
|
1579
1949
|
|
1580
1950
|
Returns:
|
@@ -1792,22 +2162,68 @@ class SpatialAttention(nn.Module):
|
|
1792
2162
|
|
1793
2163
|
# Multi-Scale Block with Attention
|
1794
2164
|
class MultiScaleBlockWithAttention(nn.Module):
|
2165
|
+
"""
|
2166
|
+
A PyTorch module implementing a multi-scale convolutional block with spatial attention.
|
2167
|
+
|
2168
|
+
This module applies a dilated convolution followed by a spatial attention mechanism
|
2169
|
+
to enhance input feature maps.
|
2170
|
+
"""
|
1795
2171
|
def __init__(self, in_channels, out_channels):
|
2172
|
+
"""
|
2173
|
+
Initialize the MultiScaleBlockWithAttention module.
|
2174
|
+
|
2175
|
+
Args:
|
2176
|
+
in_channels (int): Number of input channels.
|
2177
|
+
out_channels (int): Number of output channels after convolution.
|
2178
|
+
"""
|
1796
2179
|
super(MultiScaleBlockWithAttention, self).__init__()
|
1797
2180
|
self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
|
1798
2181
|
self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)
|
1799
2182
|
|
1800
2183
|
def custom_forward(self, x):
|
2184
|
+
"""
|
2185
|
+
Apply dilated convolution followed by spatial attention.
|
2186
|
+
|
2187
|
+
Args:
|
2188
|
+
x (torch.Tensor): Input tensor of shape (N, C, H, W).
|
2189
|
+
|
2190
|
+
Returns:
|
2191
|
+
torch.Tensor: Output tensor after attention-enhanced feature transformation.
|
2192
|
+
"""
|
1801
2193
|
x1 = F.relu(self.dilated_conv1(x), inplace=True)
|
1802
2194
|
x = self.spatial_attention(x1)
|
1803
2195
|
return x
|
1804
2196
|
|
1805
2197
|
def forward(self, x):
|
2198
|
+
"""
|
2199
|
+
Standard forward pass that delegates to `custom_forward`.
|
2200
|
+
|
2201
|
+
Args:
|
2202
|
+
x (torch.Tensor): Input tensor.
|
2203
|
+
|
2204
|
+
Returns:
|
2205
|
+
torch.Tensor: Output tensor.
|
2206
|
+
"""
|
1806
2207
|
return self.custom_forward(x)
|
1807
2208
|
|
1808
2209
|
# Final Classifier
|
1809
2210
|
class CustomCellClassifier(nn.Module):
|
2211
|
+
"""
|
2212
|
+
A custom neural network module for single-cell classification using early fusion and multi-scale attention.
|
2213
|
+
|
2214
|
+
This architecture supports optional gradient checkpointing for reduced memory usage during training.
|
2215
|
+
"""
|
1810
2216
|
def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
|
2217
|
+
"""
|
2218
|
+
Initialize the classifier with early fusion and attention blocks.
|
2219
|
+
|
2220
|
+
Args:
|
2221
|
+
num_classes (int): Number of classification categories.
|
2222
|
+
pathogen_channel (int): Unused; reserved for future feature fusion.
|
2223
|
+
use_attention (bool): Unused; reserved for attention gating.
|
2224
|
+
use_checkpoint (bool): Whether to use checkpointing to reduce memory.
|
2225
|
+
dropout_rate (float): Unused; reserved for future regularization.
|
2226
|
+
"""
|
1811
2227
|
super(CustomCellClassifier, self).__init__()
|
1812
2228
|
self.early_fusion = EarlyFusion(in_channels=3)
|
1813
2229
|
|
@@ -1820,6 +2236,15 @@ class CustomCellClassifier(nn.Module):
|
|
1820
2236
|
param.requires_grad = True
|
1821
2237
|
|
1822
2238
|
def custom_forward(self, x):
|
2239
|
+
"""
|
2240
|
+
Perform a standard forward pass without gradient checkpointing.
|
2241
|
+
|
2242
|
+
Args:
|
2243
|
+
x (torch.Tensor): Input tensor of shape (N, C, H, W).
|
2244
|
+
|
2245
|
+
Returns:
|
2246
|
+
torch.Tensor: Logits for each class.
|
2247
|
+
"""
|
1823
2248
|
x.requires_grad = True
|
1824
2249
|
x = self.early_fusion(x)
|
1825
2250
|
x = self.multi_scale_block_1(x)
|
@@ -1828,6 +2253,15 @@ class CustomCellClassifier(nn.Module):
|
|
1828
2253
|
return x
|
1829
2254
|
|
1830
2255
|
def forward(self, x):
|
2256
|
+
"""
|
2257
|
+
Forward pass with optional gradient checkpointing.
|
2258
|
+
|
2259
|
+
Args:
|
2260
|
+
x (torch.Tensor): Input tensor.
|
2261
|
+
|
2262
|
+
Returns:
|
2263
|
+
torch.Tensor: Output logits.
|
2264
|
+
"""
|
1831
2265
|
if self.use_checkpoint:
|
1832
2266
|
x.requires_grad = True
|
1833
2267
|
return checkpoint(self.custom_forward, x)
|
@@ -1836,7 +2270,22 @@ class CustomCellClassifier(nn.Module):
|
|
1836
2270
|
|
1837
2271
|
#CNN and Transformer class, pick any Torch model.
|
1838
2272
|
class TorchModel(nn.Module):
|
2273
|
+
"""
|
2274
|
+
A PyTorch wrapper for pretrained torchvision models with a custom SPACR classifier head.
|
2275
|
+
|
2276
|
+
This class supports custom dropout insertion and optional gradient checkpointing
|
2277
|
+
for memory efficiency during training.
|
2278
|
+
"""
|
1839
2279
|
def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
|
2280
|
+
"""
|
2281
|
+
Initialize the TorchModel with optional dropout and checkpointing.
|
2282
|
+
|
2283
|
+
Args:
|
2284
|
+
model_name (str): The model architecture to load from torchvision.models.
|
2285
|
+
pretrained (bool): Whether to initialize with pretrained weights.
|
2286
|
+
dropout_rate (float or None): Dropout rate for the classifier head.
|
2287
|
+
use_checkpoint (bool): Whether to enable gradient checkpointing.
|
2288
|
+
"""
|
1840
2289
|
super(TorchModel, self).__init__()
|
1841
2290
|
self.model_name = model_name
|
1842
2291
|
self.use_checkpoint = use_checkpoint
|
@@ -1853,13 +2302,27 @@ class TorchModel(nn.Module):
|
|
1853
2302
|
self.init_spacr_classifier(dropout_rate)
|
1854
2303
|
|
1855
2304
|
def apply_dropout_rate(self, model, dropout_rate):
|
1856
|
-
"""
|
2305
|
+
"""
|
2306
|
+
Recursively set dropout probability for all nn.Dropout layers in the model.
|
2307
|
+
|
2308
|
+
Args:
|
2309
|
+
model (nn.Module): The model to modify.
|
2310
|
+
dropout_rate (float): New dropout probability.
|
2311
|
+
"""
|
1857
2312
|
for module in model.modules():
|
1858
2313
|
if isinstance(module, nn.Dropout):
|
1859
2314
|
module.p = dropout_rate
|
1860
2315
|
|
1861
2316
|
def init_base_model(self, pretrained):
|
1862
|
-
"""
|
2317
|
+
"""
|
2318
|
+
Load the base model from torchvision.models.
|
2319
|
+
|
2320
|
+
Args:
|
2321
|
+
pretrained (bool): Whether to load pretrained weights.
|
2322
|
+
|
2323
|
+
Returns:
|
2324
|
+
nn.Module: The base feature extractor.
|
2325
|
+
"""
|
1863
2326
|
model_func = models.__dict__.get(self.model_name, None)
|
1864
2327
|
if not model_func:
|
1865
2328
|
raise ValueError(f"Model {self.model_name} is not recognized.")
|
@@ -1870,7 +2333,12 @@ class TorchModel(nn.Module):
|
|
1870
2333
|
return model_func(pretrained=pretrained)
|
1871
2334
|
|
1872
2335
|
def get_weight_choice(self):
|
1873
|
-
"""
|
2336
|
+
"""
|
2337
|
+
Get the default weights enum for the selected model.
|
2338
|
+
|
2339
|
+
Returns:
|
2340
|
+
Optional[torchvision.models.WeightsEnum]: Default weights or None.
|
2341
|
+
"""
|
1874
2342
|
weight_enum = None
|
1875
2343
|
for attr_name in dir(models):
|
1876
2344
|
if attr_name.lower() == f"{self.model_name}_weights".lower():
|
@@ -1879,7 +2347,12 @@ class TorchModel(nn.Module):
|
|
1879
2347
|
return weight_enum.DEFAULT if weight_enum else None
|
1880
2348
|
|
1881
2349
|
def get_num_ftrs(self):
|
1882
|
-
"""
|
2350
|
+
"""
|
2351
|
+
Determine output feature dimensionality from the base model.
|
2352
|
+
|
2353
|
+
Returns:
|
2354
|
+
int: Feature vector size.
|
2355
|
+
"""
|
1883
2356
|
if hasattr(self.base_model, 'fc'):
|
1884
2357
|
self.base_model.fc = nn.Identity()
|
1885
2358
|
elif hasattr(self.base_model, 'classifier'):
|
@@ -1892,14 +2365,27 @@ class TorchModel(nn.Module):
|
|
1892
2365
|
return output.size(1)
|
1893
2366
|
|
1894
2367
|
def init_spacr_classifier(self, dropout_rate):
|
1895
|
-
"""
|
2368
|
+
"""
|
2369
|
+
Create the final classification layer and optional dropout.
|
2370
|
+
|
2371
|
+
Args:
|
2372
|
+
dropout_rate (float or None): Dropout probability. If None, dropout is skipped.
|
2373
|
+
"""
|
1896
2374
|
self.use_dropout = dropout_rate is not None
|
1897
2375
|
if self.use_dropout:
|
1898
2376
|
self.dropout = nn.Dropout(dropout_rate)
|
1899
2377
|
self.spacr_classifier = nn.Linear(self.num_ftrs, 1)
|
1900
2378
|
|
1901
2379
|
def forward(self, x):
|
1902
|
-
"""
|
2380
|
+
"""
|
2381
|
+
Forward pass through base model, optional dropout, and final classifier.
|
2382
|
+
|
2383
|
+
Args:
|
2384
|
+
x (torch.Tensor): Input image tensor of shape (N, 3, H, W).
|
2385
|
+
|
2386
|
+
Returns:
|
2387
|
+
torch.Tensor: Output logits of shape (N,).
|
2388
|
+
"""
|
1903
2389
|
if self.use_checkpoint:
|
1904
2390
|
x = checkpoint(self.base_model, x)
|
1905
2391
|
else:
|
@@ -1910,19 +2396,62 @@ class TorchModel(nn.Module):
|
|
1910
2396
|
return logits
|
1911
2397
|
|
1912
2398
|
class FocalLossWithLogits(nn.Module):
|
2399
|
+
"""
|
2400
|
+
Focal Loss with logits for binary classification.
|
2401
|
+
|
2402
|
+
This loss function is especially useful for addressing class imbalance by focusing more
|
2403
|
+
on hard-to-classify examples.
|
2404
|
+
|
2405
|
+
Args:
|
2406
|
+
alpha (float): Balancing factor for positive/negative examples. Default is 1.
|
2407
|
+
gamma (float): Focusing parameter that down-weights easy examples. Default is 2.
|
2408
|
+
"""
|
2409
|
+
|
1913
2410
|
def __init__(self, alpha=1, gamma=2):
|
2411
|
+
"""
|
2412
|
+
Initialize the focal loss.
|
2413
|
+
|
2414
|
+
Args:
|
2415
|
+
alpha (float): Balancing factor for positive/negative examples. Default is 1.
|
2416
|
+
gamma (float): Focusing parameter to down-weight well-classified examples. Default is 2.
|
2417
|
+
"""
|
1914
2418
|
super(FocalLossWithLogits, self).__init__()
|
1915
2419
|
self.alpha = alpha
|
1916
2420
|
self.gamma = gamma
|
1917
2421
|
|
1918
2422
|
def forward(self, logits, target):
|
2423
|
+
"""
|
2424
|
+
Compute the focal loss between logits and targets.
|
2425
|
+
|
2426
|
+
Args:
|
2427
|
+
logits (torch.Tensor): Predicted unnormalized scores (logits).
|
2428
|
+
target (torch.Tensor): Ground truth binary labels (same shape as logits).
|
2429
|
+
|
2430
|
+
Returns:
|
2431
|
+
torch.Tensor: Scalar focal loss value.
|
2432
|
+
"""
|
1919
2433
|
BCE_loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
|
1920
2434
|
pt = torch.exp(-BCE_loss)
|
1921
2435
|
focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
|
1922
2436
|
return focal_loss.mean()
|
1923
2437
|
|
1924
2438
|
class ResNet(nn.Module):
|
2439
|
+
"""
|
2440
|
+
A wrapper around torchvision ResNet models with optional dropout, checkpointing,
|
2441
|
+
and a custom classifier head.
|
2442
|
+
|
2443
|
+
Supported ResNet variants: resnet18, resnet34, resnet50, resnet101, resnet152.
|
2444
|
+
"""
|
1925
2445
|
def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
|
2446
|
+
"""
|
2447
|
+
Initialize the ResNet model wrapper.
|
2448
|
+
|
2449
|
+
Args:
|
2450
|
+
resnet_type (str): Which ResNet variant to use. Options: 'resnet18', 'resnet34', etc.
|
2451
|
+
dropout_rate (float or None): Dropout rate to apply before the final layer.
|
2452
|
+
use_checkpoint (bool): Whether to enable gradient checkpointing.
|
2453
|
+
init_weights (str): Either 'imagenet' to load pretrained weights or 'none'.
|
2454
|
+
"""
|
1926
2455
|
super(ResNet, self).__init__()
|
1927
2456
|
|
1928
2457
|
resnet_map = {
|
@@ -1939,6 +2468,15 @@ class ResNet(nn.Module):
|
|
1939
2468
|
self.initialize_base(resnet_map[resnet_type], dropout_rate, use_checkpoint, init_weights)
|
1940
2469
|
|
1941
2470
|
def initialize_base(self, base_model_dict, dropout_rate, use_checkpoint, init_weights):
|
2471
|
+
"""
|
2472
|
+
Initialize the base model and classifier layers.
|
2473
|
+
|
2474
|
+
Args:
|
2475
|
+
base_model_dict (dict): Contains model constructor and weight enum.
|
2476
|
+
dropout_rate (float or None): Dropout rate to use.
|
2477
|
+
use_checkpoint (bool): Whether to use gradient checkpointing.
|
2478
|
+
init_weights (str): Weight initialization mode.
|
2479
|
+
"""
|
1942
2480
|
if init_weights == 'imagenet':
|
1943
2481
|
self.resnet = base_model_dict['func'](weights=base_model_dict['weights'])
|
1944
2482
|
elif init_weights == 'none':
|
@@ -1956,6 +2494,15 @@ class ResNet(nn.Module):
|
|
1956
2494
|
self.fc2 = nn.Linear(500, 1)
|
1957
2495
|
|
1958
2496
|
def forward(self, x):
|
2497
|
+
"""
|
2498
|
+
Forward pass through ResNet and classification layers.
|
2499
|
+
|
2500
|
+
Args:
|
2501
|
+
x (torch.Tensor): Input tensor of shape (N, 3, H, W).
|
2502
|
+
|
2503
|
+
Returns:
|
2504
|
+
torch.Tensor: Logits of shape (N,).
|
2505
|
+
"""
|
1959
2506
|
x.requires_grad = True # Ensure that the tensor has requires_grad set to True
|
1960
2507
|
|
1961
2508
|
if self.use_checkpoint:
|
@@ -1995,7 +2542,7 @@ def classification_metrics(all_labels, prediction_pos_probs, loss, epoch):
|
|
1995
2542
|
"""
|
1996
2543
|
Calculate classification metrics for binary classification.
|
1997
2544
|
|
1998
|
-
|
2545
|
+
Args:
|
1999
2546
|
- all_labels (list): List of true labels.
|
2000
2547
|
- prediction_pos_probs (list): List of predicted positive probabilities.
|
2001
2548
|
- loader_name (str): Name of the data loader.
|
@@ -2129,6 +2676,23 @@ def choose_model(model_type, device, init_weights=True, dropout_rate=0, use_chec
|
|
2129
2676
|
return base_model
|
2130
2677
|
|
2131
2678
|
def calculate_loss(output, target, loss_type='binary_cross_entropy_with_logits'):
|
2679
|
+
"""
|
2680
|
+
Calculates the loss between the model output and the target based on the specified loss type.
|
2681
|
+
|
2682
|
+
Args:
|
2683
|
+
output (Tensor): The predicted output from the model.
|
2684
|
+
target (Tensor): The ground truth target values.
|
2685
|
+
loss_type (str, optional): The type of loss function to use.
|
2686
|
+
Supported values are:
|
2687
|
+
- 'binary_cross_entropy_with_logits': Uses binary cross-entropy loss with logits.
|
2688
|
+
- 'focal_loss': Uses focal loss with logits. Defaults to 'binary_cross_entropy_with_logits'.
|
2689
|
+
|
2690
|
+
Returns:
|
2691
|
+
Tensor: The computed loss value.
|
2692
|
+
|
2693
|
+
Raises:
|
2694
|
+
ValueError: If an unsupported loss_type is provided.
|
2695
|
+
"""
|
2132
2696
|
if loss_type == 'binary_cross_entropy_with_logits':
|
2133
2697
|
loss = F.binary_cross_entropy_with_logits(output, target)
|
2134
2698
|
elif loss_type == 'focal_loss':
|
@@ -2137,6 +2701,22 @@ def calculate_loss(output, target, loss_type='binary_cross_entropy_with_logits')
|
|
2137
2701
|
return loss
|
2138
2702
|
|
2139
2703
|
def pick_best_model(src):
|
2704
|
+
"""
|
2705
|
+
Selects the best model file from a given directory based on accuracy and epoch.
|
2706
|
+
This function scans the specified directory for files with a `.pth` extension,
|
2707
|
+
extracts accuracy and epoch information from their filenames using a predefined
|
2708
|
+
pattern, and selects the file with the highest accuracy. If multiple files have
|
2709
|
+
the same accuracy, the one with the highest epoch is selected.
|
2710
|
+
Args:
|
2711
|
+
src (str): The path to the directory containing the model files.
|
2712
|
+
Returns:
|
2713
|
+
str: The full path to the best model file based on accuracy and epoch.
|
2714
|
+
Notes:
|
2715
|
+
- The filenames are expected to follow the pattern `_epoch_<epoch>_acc_<accuracy>.pth`,
|
2716
|
+
where `<epoch>` is an integer and `<accuracy>` is a float.
|
2717
|
+
- If no files match the pattern, the function may raise an IndexError when
|
2718
|
+
attempting to access the first element of the sorted list.
|
2719
|
+
"""
|
2140
2720
|
all_files = os.listdir(src)
|
2141
2721
|
pth_files = [f for f in all_files if f.endswith('.pth')]
|
2142
2722
|
pattern = re.compile(r'_epoch_(\d+)_acc_(\d+(?:\.\d+)?)')
|
@@ -2153,18 +2733,67 @@ def pick_best_model(src):
|
|
2153
2733
|
return os.path.join(src, best_model)
|
2154
2734
|
|
2155
2735
|
def get_paths_from_db(df, png_df, image_type='cell_png'):
|
2736
|
+
"""
|
2737
|
+
Filters and retrieves paths from a DataFrame based on specified criteria.
|
2738
|
+
|
2739
|
+
Args:
|
2740
|
+
df (pd.DataFrame): A DataFrame whose index contains the objects of interest.
|
2741
|
+
png_df (pd.DataFrame): A DataFrame containing a 'png_path' column and a 'prcfo' column.
|
2742
|
+
image_type (str, optional): A string to filter the 'png_path' column. Defaults to 'cell_png'.
|
2743
|
+
|
2744
|
+
Returns:
|
2745
|
+
pd.DataFrame: A filtered DataFrame containing rows from `png_df` where the 'png_path'
|
2746
|
+
column contains the `image_type` string and the 'prcfo' column matches
|
2747
|
+
the index of `df`.
|
2748
|
+
"""
|
2156
2749
|
objects = df.index.tolist()
|
2157
2750
|
filtered_df = png_df[png_df['png_path'].str.contains(image_type) & png_df['prcfo'].isin(objects)]
|
2158
2751
|
return filtered_df
|
2159
2752
|
|
2160
2753
|
def save_file_lists(dst, data_set, ls):
|
2754
|
+
"""
|
2755
|
+
Saves a list of file paths or data entries to a CSV file.
|
2756
|
+
|
2757
|
+
Args:
|
2758
|
+
dst (str): The destination directory where the CSV file will be saved.
|
2759
|
+
data_set (str): The name of the dataset, which will also be used as the column name in the CSV file and the filename.
|
2760
|
+
ls (list): A list of file paths or data entries to be saved.
|
2761
|
+
|
2762
|
+
Returns:
|
2763
|
+
None
|
2764
|
+
"""
|
2161
2765
|
df = pd.DataFrame(ls, columns=[data_set])
|
2162
2766
|
df.to_csv(f'{dst}/{data_set}.csv', index=False)
|
2163
2767
|
return
|
2164
2768
|
|
2165
2769
|
def augment_single_image(args):
|
2166
|
-
|
2167
|
-
|
2770
|
+
"""
|
2771
|
+
Augment a single image by applying various transformations and saving the results.
|
2772
|
+
|
2773
|
+
This function reads an image from the specified file path, applies a series of
|
2774
|
+
transformations (original, rotations, and flips), and saves the transformed images
|
2775
|
+
to the destination directory with appropriate filenames.
|
2776
|
+
|
2777
|
+
Args:
|
2778
|
+
args (tuple): A tuple containing:
|
2779
|
+
img_path (str): The file path to the input image.
|
2780
|
+
dst (str): The destination directory where the augmented images will be saved.
|
2781
|
+
|
2782
|
+
Notes:
|
2783
|
+
The following transformations are applied to the input image:
|
2784
|
+
- Original image (no transformation)
|
2785
|
+
- 90-degree clockwise rotation
|
2786
|
+
- 180-degree rotation
|
2787
|
+
- 270-degree clockwise rotation
|
2788
|
+
- Horizontal flip
|
2789
|
+
- Vertical flip
|
2790
|
+
|
2791
|
+
Side Effects:
|
2792
|
+
Saves the augmented images to the specified destination directory.
|
2793
|
+
Filenames indicate the type of transformation applied.
|
2794
|
+
"""
|
2795
|
+
img_path, dst = args
|
2796
|
+
img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
|
2168
2797
|
filename = os.path.basename(img_path).split('.')[0]
|
2169
2798
|
|
2170
2799
|
# Original Image
|
@@ -2191,6 +2820,17 @@ def augment_single_image(args):
|
|
2191
2820
|
cv2.imwrite(os.path.join(dst, f"{filename}_flip_ver.png"), img_flip_ver)
|
2192
2821
|
|
2193
2822
|
def augment_images(file_paths, dst):
|
2823
|
+
"""
|
2824
|
+
Augments a list of images and saves the augmented images to the specified destination directory.
|
2825
|
+
|
2826
|
+
Args:
|
2827
|
+
file_paths (list of str): A list of file paths to the images to be augmented.
|
2828
|
+
dst (str): The destination directory where the augmented images will be saved.
|
2829
|
+
If the directory does not exist, it will be created.
|
2830
|
+
|
2831
|
+
Returns:
|
2832
|
+
None
|
2833
|
+
"""
|
2194
2834
|
if not os.path.exists(dst):
|
2195
2835
|
os.makedirs(dst)
|
2196
2836
|
|
@@ -2200,6 +2840,52 @@ def augment_images(file_paths, dst):
|
|
2200
2840
|
pool.map(augment_single_image, args_list)
|
2201
2841
|
|
2202
2842
|
def augment_classes(dst, nc, pc, generate=True,move=True):
|
2843
|
+
def augment_classes(dst, nc, pc, generate=True, move=True):
|
2844
|
+
"""
|
2845
|
+
Augments and organizes image datasets into training and testing directories.
|
2846
|
+
|
2847
|
+
Args:
|
2848
|
+
-----------
|
2849
|
+
dst : str
|
2850
|
+
The destination directory where augmented data will be stored.
|
2851
|
+
nc : list
|
2852
|
+
List of file paths for the "negative class" images to be augmented.
|
2853
|
+
pc : list
|
2854
|
+
List of file paths for the "positive class" images to be augmented.
|
2855
|
+
generate : bool, optional
|
2856
|
+
If True, generates augmented images for both classes and saves them in
|
2857
|
+
separate directories (`aug_nc` for negative class and `aug_pc` for positive class).
|
2858
|
+
Default is True.
|
2859
|
+
move : bool, optional
|
2860
|
+
If True, splits the augmented images into training and testing datasets,
|
2861
|
+
and organizes them into subdirectories under `aug/train` and `aug/test`.
|
2862
|
+
Default is True.
|
2863
|
+
|
2864
|
+
Returns:
|
2865
|
+
--------
|
2866
|
+
None
|
2867
|
+
|
2868
|
+
Notes:
|
2869
|
+
------
|
2870
|
+
- The function uses `train_test_split` to split the augmented images into
|
2871
|
+
training (90%) and testing (10%) datasets.
|
2872
|
+
- The progress of moving files is displayed in the console.
|
2873
|
+
- The final counts of training and testing images for each class are printed.
|
2874
|
+
|
2875
|
+
Example Directory Structure:
|
2876
|
+
----------------------------
|
2877
|
+
After execution, the directory structure will look like this:
|
2878
|
+
dst/
|
2879
|
+
├── aug_nc/ # Augmented negative class images
|
2880
|
+
├── aug_pc/ # Augmented positive class images
|
2881
|
+
├── aug/
|
2882
|
+
├── train/
|
2883
|
+
├── nc/ # Training negative class images
|
2884
|
+
├── pc/ # Training positive class images
|
2885
|
+
├── test/
|
2886
|
+
├── nc/ # Testing negative class images
|
2887
|
+
├── pc/ # Testing positive class images
|
2888
|
+
"""
|
2203
2889
|
aug_nc = os.path.join(dst,'aug_nc')
|
2204
2890
|
aug_pc = os.path.join(dst,'aug_pc')
|
2205
2891
|
all_ = len(nc)+len(pc)
|
@@ -2251,6 +2937,21 @@ def augment_classes(dst, nc, pc, generate=True,move=True):
|
|
2251
2937
|
return
|
2252
2938
|
|
2253
2939
|
def annotate_predictions(csv_loc):
|
2940
|
+
"""
|
2941
|
+
Reads a CSV file containing image metadata, processes the data to extract
|
2942
|
+
additional information, and assigns a condition label to each row based on
|
2943
|
+
specific rules.
|
2944
|
+
Args:
|
2945
|
+
csv_loc (str): The file path to the CSV file containing the metadata.
|
2946
|
+
Returns:
|
2947
|
+
pandas.DataFrame: A DataFrame with the following additional columns:
|
2948
|
+
- 'filename': Extracted filename from the 'path' column.
|
2949
|
+
- 'plateID': Plate ID extracted from the filename.
|
2950
|
+
- 'well': Well information extracted from the filename.
|
2951
|
+
- 'fieldID': Field ID extracted from the filename.
|
2952
|
+
- 'object': Object ID extracted from the filename (with '.png' removed).
|
2953
|
+
- 'cond': Assigned condition label ('screen', 'pc', 'nc', or '') based on rules.
|
2954
|
+
"""
|
2254
2955
|
df = pd.read_csv(csv_loc)
|
2255
2956
|
df['filename'] = df['path'].apply(lambda x: x.split('/')[-1])
|
2256
2957
|
df[['plateID', 'well', 'fieldID', 'object']] = df['filename'].str.split('_', expand=True)
|
@@ -2274,11 +2975,47 @@ def annotate_predictions(csv_loc):
|
|
2274
2975
|
return df
|
2275
2976
|
|
2276
2977
|
def initiate_counter(counter_, lock_):
|
2978
|
+
"""
|
2979
|
+
Initializes global variables `counter` and `lock` with the provided arguments.
|
2980
|
+
|
2981
|
+
This function sets the global variables `counter` and `lock` to the values
|
2982
|
+
passed as `counter_` and `lock_`, respectively. It is typically used to
|
2983
|
+
share a counter and a lock object across multiple threads or processes.
|
2984
|
+
|
2985
|
+
Args:
|
2986
|
+
counter_ (Any): The counter object to be assigned to the global `counter`.
|
2987
|
+
lock_ (Any): The lock object to be assigned to the global `lock`.
|
2988
|
+
"""
|
2277
2989
|
global counter, lock
|
2278
2990
|
counter = counter_
|
2279
2991
|
lock = lock_
|
2280
2992
|
|
2281
2993
|
def add_images_to_tar(paths_chunk, tar_path, total_images):
|
2994
|
+
"""
|
2995
|
+
Adds a chunk of image files to a tar archive.
|
2996
|
+
|
2997
|
+
Args:
|
2998
|
+
paths_chunk (list of str): A list of file paths to the images to be added to the tar archive.
|
2999
|
+
tar_path (str): The path where the tar archive will be created or overwritten.
|
3000
|
+
total_images (int): The total number of images being processed, used for progress tracking.
|
3001
|
+
|
3002
|
+
Behavior:
|
3003
|
+
- Opens a tar archive at the specified `tar_path` in write mode.
|
3004
|
+
- Iterates through the provided `paths_chunk` and adds each image to the tar archive.
|
3005
|
+
- Tracks progress using a shared counter and prints progress updates every 10 images.
|
3006
|
+
- Handles missing files gracefully by printing a warning message if a file is not found.
|
3007
|
+
|
3008
|
+
Notes:
|
3009
|
+
- This function assumes the existence of a global `lock` object for thread-safe counter updates.
|
3010
|
+
- The `counter` object is expected to be a shared multiprocessing.Value or similar.
|
3011
|
+
- The `print_progress` function is used to display progress updates.
|
3012
|
+
|
3013
|
+
Exceptions:
|
3014
|
+
- Prints a warning message if a file in `paths_chunk` is not found, but continues processing other files.
|
3015
|
+
|
3016
|
+
Example:
|
3017
|
+
add_images_to_tar(['/path/to/image1.jpg', '/path/to/image2.jpg'], '/path/to/archive.tar', 100)
|
3018
|
+
"""
|
2282
3019
|
with tarfile.open(tar_path, 'w') as tar:
|
2283
3020
|
for i, img_path in enumerate(paths_chunk):
|
2284
3021
|
arcname = os.path.basename(img_path)
|
@@ -2294,6 +3031,29 @@ def add_images_to_tar(paths_chunk, tar_path, total_images):
|
|
2294
3031
|
print(f"File not found: {img_path}")
|
2295
3032
|
|
2296
3033
|
def generate_fraction_map(df, gene_column, min_frequency=0.0):
|
3034
|
+
"""
|
3035
|
+
Generates a fraction map from a given DataFrame and writes it to a CSV file.
|
3036
|
+
|
3037
|
+
This function calculates the fraction of counts for each gene and well,
|
3038
|
+
organizes the data into a pivot table-like structure, and filters out
|
3039
|
+
columns based on a minimum frequency threshold. The resulting DataFrame
|
3040
|
+
is saved as a CSV file.
|
3041
|
+
|
3042
|
+
Args:
|
3043
|
+
df (pd.DataFrame): Input DataFrame containing the data. It must include
|
3044
|
+
the columns 'count', 'well_read_sum', 'prc', and the specified
|
3045
|
+
`gene_column`.
|
3046
|
+
gene_column (str): The name of the column in `df` that contains gene
|
3047
|
+
identifiers.
|
3048
|
+
min_frequency (float, optional): The minimum frequency threshold for
|
3049
|
+
filtering columns. Columns with a maximum value below this
|
3050
|
+
threshold are dropped. Defaults to 0.0.
|
3051
|
+
|
3052
|
+
Returns:
|
3053
|
+
pd.DataFrame: A DataFrame containing the fraction map, with wells as
|
3054
|
+
rows and genes as columns. Cells contain the fraction values, and
|
3055
|
+
missing values are filled with 0.0.
|
3056
|
+
"""
|
2297
3057
|
df['fraction'] = df['count']/df['well_read_sum']
|
2298
3058
|
genes = df[gene_column].unique().tolist()
|
2299
3059
|
wells = df['prc'].unique().tolist()
|
@@ -2319,6 +3079,17 @@ def generate_fraction_map(df, gene_column, min_frequency=0.0):
|
|
2319
3079
|
return independent_variables
|
2320
3080
|
|
2321
3081
|
def fishers_odds(df, threshold=0.5, phenotyp_col='mean_pred'):
|
3082
|
+
"""
|
3083
|
+
Perform Fisher's exact test to evaluate the association between mutants and a binned phenotype score.
|
3084
|
+
|
3085
|
+
Args:
|
3086
|
+
df (pandas.DataFrame): DataFrame containing binary mutant indicators and a phenotype score column.
|
3087
|
+
threshold (float, optional): Threshold to bin the phenotype score. Defaults to 0.5.
|
3088
|
+
phenotyp_col (str, optional): Name of the column containing phenotype scores. Defaults to 'mean_pred'.
|
3089
|
+
|
3090
|
+
Returns:
|
3091
|
+
pandas.DataFrame: DataFrame with columns 'Mutant', 'OddsRatio', 'PValue', and 'AdjustedPValue'.
|
3092
|
+
"""
|
2322
3093
|
# Binning based on phenotype score (e.g., above 0.8 as high)
|
2323
3094
|
df['high_phenotype'] = df[phenotyp_col] < threshold
|
2324
3095
|
|
@@ -2359,6 +3130,37 @@ def fishers_odds(df, threshold=0.5, phenotyp_col='mean_pred'):
|
|
2359
3130
|
return filtered_results_df
|
2360
3131
|
|
2361
3132
|
def model_metrics(model):
|
3133
|
+
"""
|
3134
|
+
Calculate and display additional metrics and generate diagnostic plots for a given model.
|
3135
|
+
|
3136
|
+
Args:
|
3137
|
+
-----------
|
3138
|
+
model : statsmodels.regression.linear_model.RegressionResultsWrapper
|
3139
|
+
A fitted regression model object from the statsmodels library.
|
3140
|
+
|
3141
|
+
Metrics Calculated:
|
3142
|
+
--------------------
|
3143
|
+
- Root Mean Squared Error (RMSE): Measures the standard deviation of residuals.
|
3144
|
+
- Mean Absolute Error (MAE): Measures the average magnitude of residuals.
|
3145
|
+
- Durbin-Watson: Tests for the presence of autocorrelation in residuals.
|
3146
|
+
|
3147
|
+
Diagnostic Plots:
|
3148
|
+
------------------
|
3149
|
+
1. Residuals vs. Fitted: Scatter plot to check for non-linearity or unequal error variance.
|
3150
|
+
2. Histogram of Residuals: Distribution of residuals to check for normality.
|
3151
|
+
3. QQ Plot: Quantile-Quantile plot to assess if residuals follow a normal distribution.
|
3152
|
+
4. Scale-Location: Scatter plot of standardized residuals to check for homoscedasticity.
|
3153
|
+
|
3154
|
+
Notes:
|
3155
|
+
------
|
3156
|
+
- This function uses matplotlib and seaborn for plotting.
|
3157
|
+
- Ensure that the input model is fitted and contains the necessary attributes like `resid`,
|
3158
|
+
`fittedvalues`, and `mse_resid`.
|
3159
|
+
|
3160
|
+
Returns:
|
3161
|
+
--------
|
3162
|
+
None
|
3163
|
+
"""
|
2362
3164
|
|
2363
3165
|
# Calculate additional metrics
|
2364
3166
|
rmse = np.sqrt(model.mse_resid)
|
@@ -2400,6 +3202,18 @@ def model_metrics(model):
|
|
2400
3202
|
plt.show()
|
2401
3203
|
|
2402
3204
|
def check_multicollinearity(x):
|
3205
|
+
"""
|
3206
|
+
Checks multicollinearity of the predictors by computing the Variance Inflation Factor (VIF).
|
3207
|
+
|
3208
|
+
Args:
|
3209
|
+
x (pd.DataFrame): A DataFrame containing the predictor variables.
|
3210
|
+
|
3211
|
+
Returns:
|
3212
|
+
pd.DataFrame: A DataFrame with two columns:
|
3213
|
+
- 'Variable': The names of the predictor variables.
|
3214
|
+
- 'VIF': The Variance Inflation Factor for each predictor variable.
|
3215
|
+
A VIF value greater than 10 indicates high multicollinearity.
|
3216
|
+
"""
|
2403
3217
|
"""Checks multicollinearity of the predictors by computing the VIF."""
|
2404
3218
|
vif_data = pd.DataFrame()
|
2405
3219
|
vif_data["Variable"] = x.columns
|
@@ -2407,6 +3221,17 @@ def check_multicollinearity(x):
|
|
2407
3221
|
return vif_data
|
2408
3222
|
|
2409
3223
|
def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
|
3224
|
+
"""
|
3225
|
+
Perform Lasso or Ridge regression on the input DataFrame.
|
3226
|
+
|
3227
|
+
Args:
|
3228
|
+
merged_df (pandas.DataFrame): DataFrame with columns 'gene', 'grna', 'plateID', 'rowID', 'columnID', and 'pred'.
|
3229
|
+
alpha_value (float, optional): Regularization strength. Defaults to 0.01.
|
3230
|
+
reg_type (str, optional): Type of regression to perform, either 'lasso' or 'ridge'. Defaults to 'lasso'.
|
3231
|
+
|
3232
|
+
Returns:
|
3233
|
+
pandas.DataFrame: DataFrame with 'Feature' and 'Coefficient' columns.
|
3234
|
+
"""
|
2410
3235
|
# Separate predictors and response
|
2411
3236
|
X = merged_df[['gene', 'grna', 'plateID', 'rowID', 'columnID']]
|
2412
3237
|
y = merged_df['pred']
|
@@ -2433,7 +3258,20 @@ def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
|
|
2433
3258
|
return coeff_df
|
2434
3259
|
|
2435
3260
|
def MLR(merged_df, refine_model):
|
2436
|
-
|
3261
|
+
"""
|
3262
|
+
Perform multiple linear regression (MLR) and extract interaction coefficients.
|
3263
|
+
|
3264
|
+
Args:
|
3265
|
+
merged_df (pd.DataFrame): Input DataFrame with data for regression analysis.
|
3266
|
+
refine_model (bool): Whether to remove outliers before refitting the model.
|
3267
|
+
|
3268
|
+
Returns:
|
3269
|
+
tuple: Contains:
|
3270
|
+
- max_effects (dict): Maximum interaction effect size per gene.
|
3271
|
+
- max_effects_pvalues (dict): Corresponding p-values.
|
3272
|
+
- model (statsmodels.regression.linear_model.RegressionResultsWrapper): Fitted regression model.
|
3273
|
+
- df (pd.DataFrame): DataFrame with sorted interaction effects and p-values.
|
3274
|
+
"""
|
2437
3275
|
from .plot import _reg_v_plot
|
2438
3276
|
|
2439
3277
|
#model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
|
@@ -2485,9 +3323,35 @@ def MLR(merged_df, refine_model):
|
|
2485
3323
|
return max_effects, max_effects_pvalues, model, df
|
2486
3324
|
|
2487
3325
|
def get_files_from_dir(dir_path, file_extension="*"):
|
3326
|
+
"""
|
3327
|
+
Retrieves a list of files from the specified directory that match the given file extension.
|
3328
|
+
|
3329
|
+
Args:
|
3330
|
+
dir_path (str): The path to the directory from which to retrieve files.
|
3331
|
+
file_extension (str, optional): The file extension to filter files by. Defaults to "*"
|
3332
|
+
(matches all files).
|
3333
|
+
|
3334
|
+
Returns:
|
3335
|
+
list: A list of file paths matching the specified file extension in the given directory.
|
3336
|
+
"""
|
2488
3337
|
return glob(os.path.join(dir_path, file_extension))
|
2489
3338
|
|
2490
3339
|
def create_circular_mask(h, w, center=None, radius=None):
|
3340
|
+
"""
|
3341
|
+
Creates a circular mask for a 2D array with the specified dimensions.
|
3342
|
+
|
3343
|
+
Args:
|
3344
|
+
h (int): The height of the 2D array.
|
3345
|
+
w (int): The width of the 2D array.
|
3346
|
+
center (tuple, optional): The (x, y) coordinates of the circle's center.
|
3347
|
+
Defaults to the center of the array.
|
3348
|
+
radius (int, optional): The radius of the circle. Defaults to the smallest
|
3349
|
+
distance from the center to the array's edges.
|
3350
|
+
|
3351
|
+
Returns:
|
3352
|
+
numpy.ndarray: A boolean 2D array where `True` represents the pixels
|
3353
|
+
inside the circle and `False` represents the pixels outside.
|
3354
|
+
"""
|
2491
3355
|
if center is None: # use the middle of the image
|
2492
3356
|
center = (int(w/2), int(h/2))
|
2493
3357
|
if radius is None: # use the smallest distance between the center and image walls
|
@@ -2500,6 +3364,16 @@ def create_circular_mask(h, w, center=None, radius=None):
|
|
2500
3364
|
return mask
|
2501
3365
|
|
2502
3366
|
def apply_mask(image, output_value=0):
|
3367
|
+
"""
|
3368
|
+
Apply a circular mask to an image, setting pixels outside the mask to a specified value.
|
3369
|
+
|
3370
|
+
Args:
|
3371
|
+
image (np.ndarray): Input image (2D grayscale or 3D RGB array).
|
3372
|
+
output_value (int, optional): Value for pixels outside the mask. Defaults to 0.
|
3373
|
+
|
3374
|
+
Returns:
|
3375
|
+
np.ndarray: Image with circular mask applied.
|
3376
|
+
"""
|
2503
3377
|
h, w = image.shape[:2] # Assuming image is grayscale or RGB
|
2504
3378
|
mask = create_circular_mask(h, w)
|
2505
3379
|
|
@@ -2512,12 +3386,53 @@ def apply_mask(image, output_value=0):
|
|
2512
3386
|
return masked_image
|
2513
3387
|
|
2514
3388
|
def invert_image(image):
|
3389
|
+
"""
|
3390
|
+
Inverts the pixel values of an image.
|
3391
|
+
|
3392
|
+
The function calculates the inverted image by subtracting each pixel value
|
3393
|
+
from the maximum possible value for the image's data type. For example,
|
3394
|
+
for an image with dtype `uint8`, the maximum value is 255.
|
3395
|
+
|
3396
|
+
Args:
|
3397
|
+
image (numpy.ndarray): The input image to be inverted. The image should
|
3398
|
+
be a NumPy array with a valid integer data type (e.g., uint8, uint16).
|
3399
|
+
|
3400
|
+
Returns:
|
3401
|
+
numpy.ndarray: The inverted image, with the same shape and dtype as the input.
|
3402
|
+
"""
|
2515
3403
|
# The maximum value depends on the image dtype (e.g., 255 for uint8)
|
2516
3404
|
max_value = np.iinfo(image.dtype).max
|
2517
3405
|
inverted_image = max_value - image
|
2518
3406
|
return inverted_image
|
2519
3407
|
|
2520
3408
|
def resize_images_and_labels(images, labels, target_height, target_width, show_example=True):
|
3409
|
+
"""
|
3410
|
+
Resize images and labels to the specified target dimensions.
|
3411
|
+
|
3412
|
+
Args:
|
3413
|
+
images (list or None): List of 2D or 3D numpy arrays representing input images.
|
3414
|
+
If None, only labels will be resized.
|
3415
|
+
labels (list or None): List of 2D numpy arrays representing label masks.
|
3416
|
+
If None, only images will be resized.
|
3417
|
+
target_height (int): Desired height of output arrays.
|
3418
|
+
target_width (int): Desired width of output arrays.
|
3419
|
+
show_example (bool, optional): Whether to display an example of original vs resized output.
|
3420
|
+
Defaults to True.
|
3421
|
+
|
3422
|
+
Returns:
|
3423
|
+
tuple: A tuple (resized_images, resized_labels), where:
|
3424
|
+
- resized_images (list): List of resized image arrays. Empty if `images` is None.
|
3425
|
+
- resized_labels (list): List of resized label arrays. Empty if `labels` is None.
|
3426
|
+
|
3427
|
+
Raises:
|
3428
|
+
ValueError: If both `images` and `labels` are None.
|
3429
|
+
|
3430
|
+
Notes:
|
3431
|
+
- Uses `resizescikit` for resizing.
|
3432
|
+
- Applies anti-aliasing when resizing images.
|
3433
|
+
- Uses nearest-neighbor interpolation (`order=0`) for labels to preserve class values.
|
3434
|
+
- Visualization of the resizing process is shown using `plot_resize` if `show_example` is True.
|
3435
|
+
"""
|
2521
3436
|
|
2522
3437
|
from .plot import plot_resize
|
2523
3438
|
|
@@ -2571,6 +3486,29 @@ def resize_images_and_labels(images, labels, target_height, target_width, show_e
|
|
2571
3486
|
return resized_images, resized_labels
|
2572
3487
|
|
2573
3488
|
def resize_labels_back(labels, orig_dims):
|
3489
|
+
"""
|
3490
|
+
Resize a list of label arrays back to their original dimensions.
|
3491
|
+
|
3492
|
+
Args:
|
3493
|
+
labels (list of numpy.ndarray): A list of label arrays to be resized.
|
3494
|
+
orig_dims (list of tuple): A list of tuples where each tuple contains
|
3495
|
+
two integers representing the original dimensions (width, height)
|
3496
|
+
of the corresponding label array.
|
3497
|
+
|
3498
|
+
Returns:
|
3499
|
+
list of numpy.ndarray: A list of resized label arrays with dimensions
|
3500
|
+
matching the corresponding tuples in `orig_dims`.
|
3501
|
+
|
3502
|
+
Raises:
|
3503
|
+
ValueError: If the length of `labels` and `orig_dims` do not match.
|
3504
|
+
ValueError: If any element in `orig_dims` is not a tuple of two integers.
|
3505
|
+
|
3506
|
+
Notes:
|
3507
|
+
- The resizing operation uses nearest-neighbor interpolation (order=0).
|
3508
|
+
- The `preserve_range` parameter ensures that the data range of the
|
3509
|
+
input is preserved during resizing.
|
3510
|
+
- Anti-aliasing is disabled for this operation.
|
3511
|
+
"""
|
2574
3512
|
resized_labels = []
|
2575
3513
|
|
2576
3514
|
if len(labels) != len(orig_dims):
|
@@ -2587,12 +3525,37 @@ def resize_labels_back(labels, orig_dims):
|
|
2587
3525
|
return resized_labels
|
2588
3526
|
|
2589
3527
|
def calculate_iou(mask1, mask2):
|
3528
|
+
"""
|
3529
|
+
Calculate the Intersection over Union (IoU) between two binary masks.
|
3530
|
+
|
3531
|
+
The IoU is a measure of the overlap between two binary masks, defined as the
|
3532
|
+
ratio of the intersection area to the union area of the masks.
|
3533
|
+
|
3534
|
+
Args:
|
3535
|
+
mask1 (numpy.ndarray): The first binary mask. Must be a 2D array.
|
3536
|
+
mask2 (numpy.ndarray): The second binary mask. Must be a 2D array.
|
3537
|
+
|
3538
|
+
Returns:
|
3539
|
+
float: The IoU value, ranging from 0 to 1. Returns 0 if the union of the
|
3540
|
+
masks is empty.
|
3541
|
+
"""
|
2590
3542
|
mask1, mask2 = pad_to_same_shape(mask1, mask2)
|
2591
3543
|
intersection = np.logical_and(mask1, mask2).sum()
|
2592
3544
|
union = np.logical_or(mask1, mask2).sum()
|
2593
3545
|
return intersection / union if union != 0 else 0
|
2594
3546
|
|
2595
3547
|
def match_masks(true_masks, pred_masks, iou_threshold):
|
3548
|
+
"""
|
3549
|
+
Matches predicted masks to ground truth masks based on Intersection over Union (IoU) threshold.
|
3550
|
+
|
3551
|
+
Args:
|
3552
|
+
true_masks (list): A list of ground truth masks.
|
3553
|
+
pred_masks (list): A list of predicted masks.
|
3554
|
+
iou_threshold (float): The IoU threshold for determining a match between a true mask and a predicted mask.
|
3555
|
+
|
3556
|
+
Returns:
|
3557
|
+
list: A list of tuples where each tuple contains a matched pair of (true_mask, pred_mask).
|
3558
|
+
"""
|
2596
3559
|
matches = []
|
2597
3560
|
matched_true_masks_indices = set() # Use set to store indices of matched true masks
|
2598
3561
|
|
@@ -2607,6 +3570,20 @@ def match_masks(true_masks, pred_masks, iou_threshold):
|
|
2607
3570
|
return matches
|
2608
3571
|
|
2609
3572
|
def compute_average_precision(matches, num_true_masks, num_pred_masks):
|
3573
|
+
"""
|
3574
|
+
Computes the precision and recall based on the provided matches, number of true masks,
|
3575
|
+
and number of predicted masks.
|
3576
|
+
|
3577
|
+
Args:
|
3578
|
+
matches (list): A list of matched predictions to ground truth masks.
|
3579
|
+
num_true_masks (int): The total number of ground truth masks.
|
3580
|
+
num_pred_masks (int): The total number of predicted masks.
|
3581
|
+
|
3582
|
+
Returns:
|
3583
|
+
tuple: A tuple containing:
|
3584
|
+
- precision (float): The precision value, calculated as TP / (TP + FP).
|
3585
|
+
- recall (float): The recall value, calculated as TP / (TP + FN).
|
3586
|
+
"""
|
2610
3587
|
TP = len(matches)
|
2611
3588
|
FP = num_pred_masks - TP
|
2612
3589
|
FN = num_true_masks - TP
|
@@ -2615,6 +3592,20 @@ def compute_average_precision(matches, num_true_masks, num_pred_masks):
|
|
2615
3592
|
return precision, recall
|
2616
3593
|
|
2617
3594
|
def pad_to_same_shape(mask1, mask2):
|
3595
|
+
"""
|
3596
|
+
Pads two 2D arrays (masks) to the same shape by adding zero-padding to the
|
3597
|
+
right and bottom of each array as needed.
|
3598
|
+
Args:
|
3599
|
+
mask1 (numpy.ndarray): The first 2D array to be padded.
|
3600
|
+
mask2 (numpy.ndarray): The second 2D array to be padded.
|
3601
|
+
Returns:
|
3602
|
+
tuple: A tuple containing two 2D numpy arrays:
|
3603
|
+
- padded_mask1 (numpy.ndarray): The first array padded to match the shape of the larger array.
|
3604
|
+
- padded_mask2 (numpy.ndarray): The second array padded to match the shape of the larger array.
|
3605
|
+
Notes:
|
3606
|
+
- The padding is applied with constant values of 0.
|
3607
|
+
- The function assumes that both inputs are 2D arrays.
|
3608
|
+
"""
|
2618
3609
|
# Find the shape differences
|
2619
3610
|
shape_diff = np.array([max(mask1.shape[0], mask2.shape[0]) - mask1.shape[0],
|
2620
3611
|
max(mask1.shape[1], mask2.shape[1]) - mask1.shape[1]])
|
@@ -2629,6 +3620,24 @@ def pad_to_same_shape(mask1, mask2):
|
|
2629
3620
|
return padded_mask1, padded_mask2
|
2630
3621
|
|
2631
3622
|
def compute_ap_over_iou_thresholds(true_masks, pred_masks, iou_thresholds):
|
3623
|
+
"""
|
3624
|
+
Compute the Average Precision (AP) over a range of Intersection over Union (IoU) thresholds.
|
3625
|
+
|
3626
|
+
This function calculates the precision-recall pairs for each IoU threshold, validates that
|
3627
|
+
precision and recall values are within the range [0, 1], and computes the Average Precision
|
3628
|
+
by integrating the precision-recall curve using the trapezoidal rule.
|
3629
|
+
|
3630
|
+
Args:
|
3631
|
+
true_masks (list or array-like): Ground truth masks.
|
3632
|
+
pred_masks (list or array-like): Predicted masks.
|
3633
|
+
iou_thresholds (list or array-like): A list of IoU thresholds to evaluate.
|
3634
|
+
|
3635
|
+
Returns:
|
3636
|
+
float: The computed Average Precision (AP) over the specified IoU thresholds.
|
3637
|
+
|
3638
|
+
Raises:
|
3639
|
+
ValueError: If precision or recall values are out of the valid range [0, 1].
|
3640
|
+
"""
|
2632
3641
|
precision_recall_pairs = []
|
2633
3642
|
for iou_threshold in iou_thresholds:
|
2634
3643
|
matches = match_masks(true_masks, pred_masks, iou_threshold)
|
@@ -2645,6 +3654,23 @@ def compute_ap_over_iou_thresholds(true_masks, pred_masks, iou_thresholds):
|
|
2645
3654
|
return np.trapz(sorted_precisions, x=sorted_recalls)
|
2646
3655
|
|
2647
3656
|
def compute_segmentation_ap(true_masks, pred_masks, iou_thresholds=np.linspace(0.5, 0.95, 10)):
|
3657
|
+
"""
|
3658
|
+
Compute the Average Precision (AP) for segmentation masks over a range of IoU thresholds.
|
3659
|
+
|
3660
|
+
This function calculates the AP by comparing the ground truth masks (`true_masks`) with the
|
3661
|
+
predicted masks (`pred_masks`) across multiple Intersection over Union (IoU) thresholds.
|
3662
|
+
|
3663
|
+
Args:
|
3664
|
+
true_masks (ndarray): A binary array representing the ground truth segmentation masks.
|
3665
|
+
Each connected component is treated as a separate object.
|
3666
|
+
pred_masks (ndarray): A binary array representing the predicted segmentation masks.
|
3667
|
+
Each connected component is treated as a separate object.
|
3668
|
+
iou_thresholds (ndarray, optional): A 1D array of IoU thresholds to evaluate AP over.
|
3669
|
+
Defaults to `np.linspace(0.5, 0.95, 10)`.
|
3670
|
+
|
3671
|
+
Returns:
|
3672
|
+
float: The computed Average Precision (AP) over the specified IoU thresholds.
|
3673
|
+
"""
|
2648
3674
|
true_mask_labels = label(true_masks)
|
2649
3675
|
pred_mask_labels = label(pred_masks)
|
2650
3676
|
true_mask_regions = [region.image for region in regionprops(true_mask_labels)]
|
@@ -2652,11 +3678,42 @@ def compute_segmentation_ap(true_masks, pred_masks, iou_thresholds=np.linspace(0
|
|
2652
3678
|
return compute_ap_over_iou_thresholds(true_mask_regions, pred_mask_regions, iou_thresholds)
|
2653
3679
|
|
2654
3680
|
def jaccard_index(mask1, mask2):
|
3681
|
+
"""
|
3682
|
+
Computes the Jaccard Index (Intersection over Union) between two binary masks.
|
3683
|
+
|
3684
|
+
The Jaccard Index is a measure of similarity between two sets, defined as the size
|
3685
|
+
of the intersection divided by the size of the union of the sets.
|
3686
|
+
|
3687
|
+
Args:
|
3688
|
+
mask1 (numpy.ndarray): A binary mask (e.g., a 2D array of boolean or 0/1 values).
|
3689
|
+
mask2 (numpy.ndarray): Another binary mask of the same shape as `mask1`.
|
3690
|
+
|
3691
|
+
Returns:
|
3692
|
+
float: The Jaccard Index, a value between 0 and 1, where 1 indicates perfect overlap
|
3693
|
+
and 0 indicates no overlap.
|
3694
|
+
"""
|
2655
3695
|
intersection = np.logical_and(mask1, mask2)
|
2656
3696
|
union = np.logical_or(mask1, mask2)
|
2657
3697
|
return np.sum(intersection) / np.sum(union)
|
2658
3698
|
|
2659
3699
|
def dice_coefficient(mask1, mask2):
|
3700
|
+
"""
|
3701
|
+
Compute the Dice coefficient, a measure of overlap between two binary masks.
|
3702
|
+
|
3703
|
+
The Dice coefficient is defined as:
|
3704
|
+
``2 * |A ∩ B| / (|A| + |B|)``
|
3705
|
+
|
3706
|
+
where A and B are binary masks, ``|A ∩ B|`` is the number of overlapping non-zero elements,
|
3707
|
+
and ``|A| + |B|`` is the total number of non-zero elements in both masks.
|
3708
|
+
|
3709
|
+
Args:
|
3710
|
+
mask1 (numpy.ndarray): First binary mask.
|
3711
|
+
mask2 (numpy.ndarray): Second binary mask.
|
3712
|
+
|
3713
|
+
Returns:
|
3714
|
+
float: Dice coefficient between 0.0 (no overlap) and 1.0 (perfect overlap).
|
3715
|
+
Returns 1.0 if both masks are empty.
|
3716
|
+
"""
|
2660
3717
|
# Convert to binary masks
|
2661
3718
|
mask1 = np.where(mask1 > 0, 1, 0)
|
2662
3719
|
mask2 = np.where(mask2 > 0, 1, 0)
|
@@ -2673,6 +3730,22 @@ def dice_coefficient(mask1, mask2):
|
|
2673
3730
|
return 2.0 * intersection / total
|
2674
3731
|
|
2675
3732
|
def extract_boundaries(mask, dilation_radius=1):
|
3733
|
+
"""
|
3734
|
+
Extracts the boundaries of a binary mask by applying morphological dilation
|
3735
|
+
and erosion operations and computing their difference.
|
3736
|
+
|
3737
|
+
Args:
|
3738
|
+
mask (numpy.ndarray): A 2D array representing the input mask. Non-zero
|
3739
|
+
values are considered part of the mask.
|
3740
|
+
dilation_radius (int, optional): The radius of the structuring element
|
3741
|
+
used for dilation and erosion.
|
3742
|
+
Defaults to 1.
|
3743
|
+
|
3744
|
+
Returns:
|
3745
|
+
numpy.ndarray: A binary array of the same shape as the input mask,
|
3746
|
+
where the boundary pixels are marked as 1 and all
|
3747
|
+
other pixels are 0.
|
3748
|
+
"""
|
2676
3749
|
binary_mask = (mask > 0).astype(np.uint8)
|
2677
3750
|
struct_elem = np.ones((dilation_radius*2+1, dilation_radius*2+1))
|
2678
3751
|
dilated = morphology.binary_dilation(binary_mask, footprint=struct_elem)
|
@@ -2681,6 +3754,21 @@ def extract_boundaries(mask, dilation_radius=1):
|
|
2681
3754
|
return boundary
|
2682
3755
|
|
2683
3756
|
def boundary_f1_score(mask_true, mask_pred, dilation_radius=1):
|
3757
|
+
"""
|
3758
|
+
Calculate the boundary F1 score between two binary masks.
|
3759
|
+
|
3760
|
+
The boundary F1 score evaluates the alignment of object boundaries between
|
3761
|
+
predicted and ground truth masks. It computes the harmonic mean of boundary
|
3762
|
+
precision and recall after dilating the boundaries.
|
3763
|
+
|
3764
|
+
Args:
|
3765
|
+
mask_true (np.ndarray): Ground truth binary mask. Non-zero values are considered foreground.
|
3766
|
+
mask_pred (np.ndarray): Predicted binary mask. Non-zero values are considered foreground.
|
3767
|
+
dilation_radius (int, optional): Radius used for boundary dilation to allow tolerance in matching. Defaults to 1.
|
3768
|
+
|
3769
|
+
Returns:
|
3770
|
+
float: Boundary F1 score between 0 (no boundary match) and 1 (perfect boundary alignment).
|
3771
|
+
"""
|
2684
3772
|
# Assume extract_boundaries is defined to extract object boundaries with given dilation_radius
|
2685
3773
|
boundary_true = extract_boundaries(mask_true, dilation_radius)
|
2686
3774
|
boundary_pred = extract_boundaries(mask_pred, dilation_radius)
|
@@ -2697,8 +3785,6 @@ def boundary_f1_score(mask_true, mask_pred, dilation_radius=1):
|
|
2697
3785
|
|
2698
3786
|
return f1
|
2699
3787
|
|
2700
|
-
|
2701
|
-
|
2702
3788
|
def _remove_noninfected(stack, cell_dim, nucleus_dim, pathogen_dim):
|
2703
3789
|
"""
|
2704
3790
|
Remove non-infected cells from the stack based on the provided dimensions.
|
@@ -3179,6 +4265,23 @@ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mas
|
|
3179
4265
|
return df
|
3180
4266
|
|
3181
4267
|
def _get_regex(metadata_type, img_format, custom_regex=None):
|
4268
|
+
"""
|
4269
|
+
Generates a regex pattern based on the specified metadata type and image format.
|
4270
|
+
Args:
|
4271
|
+
metadata_type (str): The type of metadata to generate the regex for.
|
4272
|
+
Supported values are:
|
4273
|
+
- 'cellvoyager': Generates a regex for CellVoyager metadata format.
|
4274
|
+
- 'cq1': Generates a regex for CQ1 metadata format.
|
4275
|
+
- 'auto': Generates a regex for automatic detection of metadata.
|
4276
|
+
- 'custom': Uses a custom regex provided by the user.
|
4277
|
+
img_format (str): The image file format (e.g., 'tif'). If None, defaults to 'tif'.
|
4278
|
+
custom_regex (str, optional): A custom regex pattern to use when metadata_type is 'custom'.
|
4279
|
+
Returns:
|
4280
|
+
str: The generated regex pattern.
|
4281
|
+
Notes:
|
4282
|
+
- Prints the image format and the generated regex for debugging purposes.
|
4283
|
+
- If `img_format` is None, it defaults to 'tif'.
|
4284
|
+
"""
|
3182
4285
|
|
3183
4286
|
print(f"Image_format: {img_format}")
|
3184
4287
|
|
@@ -3197,6 +4300,27 @@ def _get_regex(metadata_type, img_format, custom_regex=None):
|
|
3197
4300
|
return regex
|
3198
4301
|
|
3199
4302
|
def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True):
|
4303
|
+
"""
|
4304
|
+
Prepares a test dataset by selecting and copying a subset of images from the source directory
|
4305
|
+
based on a regular expression and other criteria.
|
4306
|
+
Args:
|
4307
|
+
src (str): The source directory containing the images.
|
4308
|
+
regex (str): A regular expression to match filenames. The regex should include named groups
|
4309
|
+
'plateID', 'wellID', and 'fieldID' if applicable.
|
4310
|
+
timelapse (bool, optional): If True, limits the selection to one set of images to ensure
|
4311
|
+
full sequence inclusion. Defaults to False.
|
4312
|
+
test_images (int, optional): The number of image sets to include in the test dataset.
|
4313
|
+
Defaults to 10.
|
4314
|
+
random_test (bool, optional): If True, randomizes the selection of image sets. Defaults to True.
|
4315
|
+
Returns:
|
4316
|
+
str: The path to the test folder containing the selected images.
|
4317
|
+
Notes:
|
4318
|
+
- If a subdirectory named 'orig' exists within the source directory, it will be used as the
|
4319
|
+
source for image selection.
|
4320
|
+
- The selected images are grouped by a combination of plate, well, and field identifiers
|
4321
|
+
extracted from the filename using the provided regex.
|
4322
|
+
- The function ensures the test folder is created and populated with the selected images.
|
4323
|
+
"""
|
3200
4324
|
|
3201
4325
|
if timelapse:
|
3202
4326
|
test_images = 1 # Use only 1 set for timelapse to ensure full sequence inclusion
|
@@ -3241,6 +4365,51 @@ def _run_test_mode(src, regex, timelapse=False, test_images=10, random_test=True
|
|
3241
4365
|
return test_folder_path
|
3242
4366
|
|
3243
4367
|
def _choose_model(model_name, device, object_type='cell', restore_type=None, object_settings={}):
|
4368
|
+
"""
|
4369
|
+
Selects and initializes a model based on the provided parameters.
|
4370
|
+
|
4371
|
+
Args:
|
4372
|
+
-----------
|
4373
|
+
model_name : str
|
4374
|
+
The name of the model to use. Supported values include:
|
4375
|
+
- 'toxo_pv_lumen' (specific to 'pathogen' object_type)
|
4376
|
+
- 'sam', 'cyto', 'cyto2', 'cyto3', 'nuclei' (for default models)
|
4377
|
+
device : torch.device
|
4378
|
+
The device on which the model will run (e.g., CPU or GPU).
|
4379
|
+
object_type : str, optional
|
4380
|
+
The type of object the model is intended for. Default is 'cell'.
|
4381
|
+
Supported values:
|
4382
|
+
- 'pathogen': Uses a specific model for Toxoplasma PV lumen.
|
4383
|
+
- 'nucleus': Uses denoising or restoration models.
|
4384
|
+
- 'cell': Uses default Cellpose models.
|
4385
|
+
restore_type : str or None, optional
|
4386
|
+
The type of restoration to apply. Supported values:
|
4387
|
+
- 'denoise', 'deblur', 'upsample', or None (default).
|
4388
|
+
If None, standard models are used.
|
4389
|
+
object_settings : dict, optional
|
4390
|
+
Additional settings for specific object types. For example:
|
4391
|
+
- For 'pathogen', requires 'diameter' key to specify object diameter.
|
4392
|
+
|
4393
|
+
Returns:
|
4394
|
+
--------
|
4395
|
+
model : CellposeModel or CellposeDenoiseModel
|
4396
|
+
The initialized model based on the provided parameters.
|
4397
|
+
|
4398
|
+
Notes:
|
4399
|
+
------
|
4400
|
+
- For 'pathogen' object_type with 'toxo_pv_lumen' model_name, a custom
|
4401
|
+
pretrained model is loaded from the 'models/cp' directory.
|
4402
|
+
- For 'nucleus' object_type, restoration models are initialized with
|
4403
|
+
specific restore types.
|
4404
|
+
- If an invalid restore_type is provided, it defaults to None with a warning.
|
4405
|
+
- GPU usage is determined automatically based on availability.
|
4406
|
+
|
4407
|
+
Raises:
|
4408
|
+
-------
|
4409
|
+
KeyError:
|
4410
|
+
If 'diameter' is not provided in `object_settings` when using the
|
4411
|
+
'toxo_pv_lumen' model for 'pathogen' object_type.
|
4412
|
+
"""
|
3244
4413
|
if object_type == 'pathogen':
|
3245
4414
|
if model_name == 'toxo_pv_lumen':
|
3246
4415
|
diameter = object_settings['diameter']
|
@@ -3292,47 +4461,46 @@ def _choose_model(model_name, device, object_type='cell', restore_type=None, obj
|
|
3292
4461
|
)
|
3293
4462
|
return model
|
3294
4463
|
|
3295
|
-
|
3296
|
-
|
3297
|
-
|
3298
|
-
if model_name == 'toxo_pv_lumen':
|
3299
|
-
diameter = object_settings['diameter']
|
3300
|
-
current_dir = os.path.dirname(__file__)
|
3301
|
-
model_path = os.path.join(current_dir, 'models', 'cp', 'toxo_pv_lumen.CP_model')
|
3302
|
-
print(model_path)
|
3303
|
-
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=None, pretrained_model=model_path, diam_mean=diameter, device=device)
|
3304
|
-
print(f'Using Toxoplasma PV lumen model to generate pathogen masks')
|
3305
|
-
return model
|
3306
|
-
|
3307
|
-
restore_list = ['denoise', 'deblur', 'upsample', None]
|
3308
|
-
if restore_type not in restore_list:
|
3309
|
-
print(f"Invalid restore type. Choose from {restore_list} defaulting to None")
|
3310
|
-
restore_type = None
|
4464
|
+
class SelectChannels:
|
4465
|
+
"""
|
4466
|
+
Selectively retain specific color channels in a 3-channel image tensor.
|
3311
4467
|
|
3312
|
-
|
3313
|
-
|
3314
|
-
model = cp_models.CellposeModel(gpu=torch.cuda.is_available(), model_type=model_name, device=device)
|
3315
|
-
return model
|
3316
|
-
else:
|
3317
|
-
if object_type == 'nucleus':
|
3318
|
-
restore = f'{restore_type}_nuclei'
|
3319
|
-
model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="nuclei",restore_type=restore, chan2_restore=False, device=device)
|
3320
|
-
return model
|
4468
|
+
This transformation zeros out unselected channels based on a list of 1-based
|
4469
|
+
indices corresponding to RGB channels.
|
3321
4470
|
|
3322
|
-
|
3323
|
-
|
3324
|
-
|
3325
|
-
|
3326
|
-
|
3327
|
-
|
3328
|
-
model = denoise.CellposeDenoiseModel(gpu=torch.cuda.is_available(), model_type="cyto3",restore_type=restore, chan2_restore=chan2_restore, device=device)
|
3329
|
-
return model
|
4471
|
+
Args:
|
4472
|
+
channels (list of int): List of 1-based indices for channels to retain.
|
4473
|
+
Valid values are:
|
4474
|
+
- 1: Red
|
4475
|
+
- 2: Green
|
4476
|
+
- 3: Blue
|
3330
4477
|
|
3331
|
-
|
4478
|
+
Example:
|
4479
|
+
>>> select_channels = SelectChannels([1, 3])
|
4480
|
+
>>> modified_img = select_channels(img)
|
4481
|
+
|
4482
|
+
Note:
|
4483
|
+
The input image must be a PyTorch tensor with shape (3, H, W).
|
4484
|
+
"""
|
3332
4485
|
def __init__(self, channels):
|
4486
|
+
"""
|
4487
|
+
Initialize the SelectChannels instance.
|
4488
|
+
|
4489
|
+
Args:
|
4490
|
+
channels (list of int): Channels (1-based) to retain.
|
4491
|
+
"""
|
3333
4492
|
self.channels = channels
|
3334
4493
|
|
3335
4494
|
def __call__(self, img):
|
4495
|
+
"""
|
4496
|
+
Apply the channel selection to the input image tensor.
|
4497
|
+
|
4498
|
+
Args:
|
4499
|
+
img (torch.Tensor): A 3-channel image tensor of shape (3, H, W).
|
4500
|
+
|
4501
|
+
Returns:
|
4502
|
+
torch.Tensor: A tensor where unselected channels are set to zero.
|
4503
|
+
"""
|
3336
4504
|
img = img.clone()
|
3337
4505
|
if 1 not in self.channels:
|
3338
4506
|
img[0, :, :] = 0 # Zero out the red channel
|
@@ -3343,6 +4511,20 @@ class SelectChannels:
|
|
3343
4511
|
return img
|
3344
4512
|
|
3345
4513
|
def preprocess_image_v1(image_path, image_size=224, channels=[1,2,3], normalize=True):
|
4514
|
+
"""
|
4515
|
+
Preprocess an image for input into a machine learning model.
|
4516
|
+
|
4517
|
+
Args:
|
4518
|
+
image_path (str): The file path to the image to be processed.
|
4519
|
+
image_size (int, optional): The size to which the image will be center-cropped. Defaults to 224.
|
4520
|
+
channels (list, optional): A list of channel indices to select from the image. Defaults to [1, 2, 3].
|
4521
|
+
normalize (bool, optional): Whether to normalize the image using mean and standard deviation. Defaults to True.
|
4522
|
+
|
4523
|
+
Returns:
|
4524
|
+
tuple: A tuple containing:
|
4525
|
+
- image (PIL.Image.Image): The original image loaded as a PIL Image object.
|
4526
|
+
- input_tensor (torch.Tensor): The preprocessed image as a PyTorch tensor with an added batch dimension.
|
4527
|
+
"""
|
3346
4528
|
|
3347
4529
|
if normalize:
|
3348
4530
|
transform = transforms.Compose([
|
@@ -3361,10 +4543,27 @@ def preprocess_image_v1(image_path, image_size=224, channels=[1,2,3], normalize=
|
|
3361
4543
|
return image, input_tensor
|
3362
4544
|
|
3363
4545
|
class SaliencyMapGenerator:
|
4546
|
+
"""
|
4547
|
+
SaliencyMapGenerator is a utility class for generating saliency maps and visualizing model predictions
|
4548
|
+
for PyTorch models in binary classification tasks.
|
4549
|
+
"""
|
3364
4550
|
def __init__(self, model):
|
4551
|
+
"""
|
4552
|
+
Initialize the SaliencyMapGenerator.
|
4553
|
+
"""
|
3365
4554
|
self.model = model
|
3366
4555
|
|
3367
4556
|
def compute_saliency_maps(self, X, y):
|
4557
|
+
"""
|
4558
|
+
Compute saliency maps for the given inputs and target labels.
|
4559
|
+
|
4560
|
+
Args:
|
4561
|
+
X (torch.Tensor): Input tensor with requires_grad enabled.
|
4562
|
+
y (torch.Tensor): Ground truth labels for the input samples.
|
4563
|
+
|
4564
|
+
Returns:
|
4565
|
+
torch.Tensor: The computed saliency maps, same shape as input.
|
4566
|
+
"""
|
3368
4567
|
self.model.eval()
|
3369
4568
|
X.requires_grad_()
|
3370
4569
|
|
@@ -3381,6 +4580,17 @@ class SaliencyMapGenerator:
|
|
3381
4580
|
return saliency
|
3382
4581
|
|
3383
4582
|
def compute_saliency_and_predictions(self, X):
|
4583
|
+
"""
|
4584
|
+
Compute saliency maps and predictions for the given input batch.
|
4585
|
+
|
4586
|
+
Args:
|
4587
|
+
X (torch.Tensor): Input tensor with requires_grad enabled.
|
4588
|
+
|
4589
|
+
Returns:
|
4590
|
+
tuple:
|
4591
|
+
- torch.Tensor: Saliency maps for the input.
|
4592
|
+
- torch.Tensor: Predicted class labels.
|
4593
|
+
"""
|
3384
4594
|
self.model.eval()
|
3385
4595
|
X.requires_grad_()
|
3386
4596
|
|
@@ -3400,6 +4610,19 @@ class SaliencyMapGenerator:
|
|
3400
4610
|
return saliency, predictions
|
3401
4611
|
|
3402
4612
|
def plot_activation_grid(self, X, saliency, predictions, overlay=True, normalize=False):
|
4613
|
+
"""
|
4614
|
+
Plot a grid of input images with overlaid saliency maps and predicted labels.
|
4615
|
+
|
4616
|
+
Args:
|
4617
|
+
X (torch.Tensor): Input tensor of images (N, C, H, W).
|
4618
|
+
saliency (torch.Tensor): Corresponding saliency maps (N, C, H, W).
|
4619
|
+
predictions (torch.Tensor): Predicted class labels (N,).
|
4620
|
+
overlay (bool): Whether to overlay saliency maps on input images. Default is True.
|
4621
|
+
normalize (bool): Whether to normalize input images by percentiles. Default is False.
|
4622
|
+
|
4623
|
+
Returns:
|
4624
|
+
matplotlib.figure.Figure: A matplotlib figure object showing the grid.
|
4625
|
+
"""
|
3403
4626
|
N = X.shape[0]
|
3404
4627
|
rows = (N + 7) // 8
|
3405
4628
|
fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
|
@@ -3428,6 +4651,17 @@ class SaliencyMapGenerator:
|
|
3428
4651
|
return fig
|
3429
4652
|
|
3430
4653
|
def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
|
4654
|
+
"""
|
4655
|
+
Normalize an image's intensity per channel using percentile clipping.
|
4656
|
+
|
4657
|
+
Args:
|
4658
|
+
img (np.ndarray): Image of shape (H, W, C) to be normalized.
|
4659
|
+
lower_percentile (int): Lower percentile for intensity clipping. Default is 2.
|
4660
|
+
upper_percentile (int): Upper percentile for intensity clipping. Default is 98.
|
4661
|
+
|
4662
|
+
Returns:
|
4663
|
+
np.ndarray: Percentile-normalized image.
|
4664
|
+
"""
|
3431
4665
|
img_normalized = np.zeros_like(img)
|
3432
4666
|
|
3433
4667
|
for c in range(img.shape[2]): # Iterate over each channel
|
@@ -3438,7 +4672,19 @@ class SaliencyMapGenerator:
|
|
3438
4672
|
return img_normalized
|
3439
4673
|
|
3440
4674
|
class GradCAMGenerator:
|
4675
|
+
"""
|
4676
|
+
GradCAMGenerator generates Grad-CAM (Gradient-weighted Class Activation Mapping) visualizations
|
4677
|
+
for CNN-based PyTorch models, supporting binary classification tasks.
|
4678
|
+
"""
|
3441
4679
|
def __init__(self, model, target_layer, cam_type='gradcam'):
|
4680
|
+
"""
|
4681
|
+
Initialize the GradCAMGenerator and register hooks.
|
4682
|
+
|
4683
|
+
Args:
|
4684
|
+
model (torch.nn.Module): A trained PyTorch model.
|
4685
|
+
target_layer (str): The name of the layer to compute Grad-CAM on.
|
4686
|
+
cam_type (str): Type of CAM method (default is 'gradcam').
|
4687
|
+
"""
|
3442
4688
|
self.model = model
|
3443
4689
|
self.model.eval()
|
3444
4690
|
self.target_layer = target_layer
|
@@ -3451,6 +4697,10 @@ class GradCAMGenerator:
|
|
3451
4697
|
self.hook_layers()
|
3452
4698
|
|
3453
4699
|
def hook_layers(self):
|
4700
|
+
"""
|
4701
|
+
Register forward and backward hooks to capture activations and gradients
|
4702
|
+
from the specified target layer during inference and backpropagation.
|
4703
|
+
"""
|
3454
4704
|
# Forward hook to get activations
|
3455
4705
|
def forward_hook(module, input, output):
|
3456
4706
|
self.activations = output
|
@@ -3463,6 +4713,16 @@ class GradCAMGenerator:
|
|
3463
4713
|
self.target_layer_module.register_backward_hook(backward_hook)
|
3464
4714
|
|
3465
4715
|
def get_layer(self, model, target_layer):
|
4716
|
+
"""
|
4717
|
+
Recursively retrieve a layer object from a nested model.
|
4718
|
+
|
4719
|
+
Args:
|
4720
|
+
model (torch.nn.Module): The model containing the layer.
|
4721
|
+
target_layer (str): Dot-separated string path to the desired layer.
|
4722
|
+
|
4723
|
+
Returns:
|
4724
|
+
torch.nn.Module: The resolved layer module.
|
4725
|
+
"""
|
3466
4726
|
# Recursively find the layer specified in target_layer
|
3467
4727
|
modules = target_layer.split('.')
|
3468
4728
|
layer = model
|
@@ -3471,6 +4731,16 @@ class GradCAMGenerator:
|
|
3471
4731
|
return layer
|
3472
4732
|
|
3473
4733
|
def compute_gradcam_maps(self, X, y):
|
4734
|
+
"""
|
4735
|
+
Compute Grad-CAM heatmaps for an input batch and target labels.
|
4736
|
+
|
4737
|
+
Args:
|
4738
|
+
X (torch.Tensor): Input tensor of shape (N, C, H, W).
|
4739
|
+
y (torch.Tensor): Target labels (0 or 1) for each sample.
|
4740
|
+
|
4741
|
+
Returns:
|
4742
|
+
np.ndarray: Grad-CAM heatmaps normalized to [0, 1] for each input.
|
4743
|
+
"""
|
3474
4744
|
X.requires_grad_()
|
3475
4745
|
|
3476
4746
|
# Forward pass
|
@@ -3495,6 +4765,17 @@ class GradCAMGenerator:
|
|
3495
4765
|
return gradcam
|
3496
4766
|
|
3497
4767
|
def compute_gradcam_and_predictions(self, X):
|
4768
|
+
"""
|
4769
|
+
Compute Grad-CAM heatmaps and class predictions for a batch.
|
4770
|
+
|
4771
|
+
Args:
|
4772
|
+
X (torch.Tensor): Input tensor of shape (N, C, H, W).
|
4773
|
+
|
4774
|
+
Returns:
|
4775
|
+
Tuple[torch.Tensor, torch.Tensor]:
|
4776
|
+
- Grad-CAM heatmaps for each sample.
|
4777
|
+
- Predicted class labels (0 or 1).
|
4778
|
+
"""
|
3498
4779
|
self.model.eval()
|
3499
4780
|
X.requires_grad_()
|
3500
4781
|
|
@@ -3513,6 +4794,19 @@ class GradCAMGenerator:
|
|
3513
4794
|
return torch.tensor(gradcam_maps), predictions
|
3514
4795
|
|
3515
4796
|
def plot_activation_grid(self, X, gradcam, predictions, overlay=True, normalize=False):
|
4797
|
+
"""
|
4798
|
+
Plot a grid of input images overlaid with Grad-CAM heatmaps.
|
4799
|
+
|
4800
|
+
Args:
|
4801
|
+
X (torch.Tensor): Input image batch (N, C, H, W).
|
4802
|
+
gradcam (torch.Tensor): Grad-CAM heatmaps (N, H, W).
|
4803
|
+
predictions (torch.Tensor): Predicted class labels.
|
4804
|
+
overlay (bool): Whether to overlay Grad-CAM on input images.
|
4805
|
+
normalize (bool): Whether to normalize image intensities by percentiles.
|
4806
|
+
|
4807
|
+
Returns:
|
4808
|
+
matplotlib.figure.Figure: The generated grid figure.
|
4809
|
+
"""
|
3516
4810
|
N = X.shape[0]
|
3517
4811
|
rows = (N + 7) // 8
|
3518
4812
|
fig, axs = plt.subplots(rows, 8, figsize=(16, rows * 2))
|
@@ -3541,6 +4835,17 @@ class GradCAMGenerator:
|
|
3541
4835
|
return fig
|
3542
4836
|
|
3543
4837
|
def percentile_normalize(self, img, lower_percentile=2, upper_percentile=98):
|
4838
|
+
"""
|
4839
|
+
Normalize each channel of the input image to the specified percentiles.
|
4840
|
+
|
4841
|
+
Args:
|
4842
|
+
img (np.ndarray): Image array (H, W, C).
|
4843
|
+
lower_percentile (int): Lower clipping percentile.
|
4844
|
+
upper_percentile (int): Upper clipping percentile.
|
4845
|
+
|
4846
|
+
Returns:
|
4847
|
+
np.ndarray: Percentile-normalized image.
|
4848
|
+
"""
|
3544
4849
|
img_normalized = np.zeros_like(img)
|
3545
4850
|
|
3546
4851
|
for c in range(img.shape[2]): # Iterate over each channel
|
@@ -3551,6 +4856,23 @@ class GradCAMGenerator:
|
|
3551
4856
|
return img_normalized
|
3552
4857
|
|
3553
4858
|
def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3]):
|
4859
|
+
"""
|
4860
|
+
Preprocess an image for input into a machine learning model.
|
4861
|
+
|
4862
|
+
Args:
|
4863
|
+
image_path (str): Path to the input image file.
|
4864
|
+
normalize (bool, optional): If True, apply ImageNet normalization
|
4865
|
+
(mean and std). Defaults to True.
|
4866
|
+
image_size (int, optional): Target size (height and width) for resizing.
|
4867
|
+
Defaults to 224.
|
4868
|
+
channels (list of int, optional): 1-based channel indices to retain (e.g., [1, 2, 3]
|
4869
|
+
for RGB). Defaults to [1, 2, 3].
|
4870
|
+
|
4871
|
+
Returns:
|
4872
|
+
tuple:
|
4873
|
+
- PIL.Image.Image: The original image.
|
4874
|
+
- torch.Tensor: The processed image tensor suitable for model input.
|
4875
|
+
"""
|
3554
4876
|
preprocess = transforms.Compose([
|
3555
4877
|
transforms.Resize((image_size, image_size)),
|
3556
4878
|
transforms.ToTensor(),
|
@@ -3565,6 +4887,29 @@ def preprocess_image(image_path, normalize=True, image_size=224, channels=[1,2,3
|
|
3565
4887
|
return image, input_tensor
|
3566
4888
|
|
3567
4889
|
def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names = ['nc', 'pc']):
|
4890
|
+
def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1,2], l2_reg=1e-3, learning_rate=25, num_iterations=100, blur_every=10, max_jitter=16, show_every=25, class_names=['nc', 'pc']):
|
4891
|
+
"""
|
4892
|
+
Generate a class visualization for a given target class using gradient ascent.
|
4893
|
+
Args:
|
4894
|
+
target_y (int): The target class index for which the visualization is generated.
|
4895
|
+
model_path (str): Path to the pre-trained model file.
|
4896
|
+
dtype (torch.dtype): Data type for the tensors (e.g., torch.FloatTensor or torch.cuda.FloatTensor).
|
4897
|
+
img_size (int, optional): Size of the square input image. Default is 224.
|
4898
|
+
channels (list, optional): List of channel indices to use in the input image. Default is [0, 1, 2].
|
4899
|
+
l2_reg (float, optional): L2 regularization strength to prevent overfitting. Default is 1e-3.
|
4900
|
+
learning_rate (float, optional): Learning rate for gradient ascent. Default is 25.
|
4901
|
+
num_iterations (int, optional): Number of iterations for gradient ascent. Default is 100.
|
4902
|
+
blur_every (int, optional): Frequency (in iterations) to apply Gaussian blur as a regularizer. Default is 10.
|
4903
|
+
max_jitter (int, optional): Maximum pixel jitter for random image translation. Default is 16.
|
4904
|
+
show_every (int, optional): Frequency (in iterations) to display the intermediate visualization. Default is 25.
|
4905
|
+
class_names (list, optional): List of class names corresponding to the target indices. Default is ['nc', 'pc'].
|
4906
|
+
Returns:
|
4907
|
+
numpy.ndarray: The final deprocessed image as a NumPy array.
|
4908
|
+
Notes:
|
4909
|
+
- The function performs gradient ascent on the input image to maximize the score of the target class.
|
4910
|
+
- Regularization techniques such as L2 regularization, clamping, and Gaussian blur are applied to improve visualization quality.
|
4911
|
+
- The function assumes the model is compatible with the input image size and channels.
|
4912
|
+
"""
|
3568
4913
|
|
3569
4914
|
def jitter(img, ox, oy):
|
3570
4915
|
# Randomly jitter the image
|
@@ -3645,6 +4990,18 @@ def class_visualization(target_y, model_path, dtype, img_size=224, channels=[0,1
|
|
3645
4990
|
return deprocess(img.data.cpu())
|
3646
4991
|
|
3647
4992
|
def get_submodules(model, prefix=''):
|
4993
|
+
"""
|
4994
|
+
Recursively retrieves the names of all submodules in a given model.
|
4995
|
+
|
4996
|
+
Args:
|
4997
|
+
model (torch.nn.Module): The model whose submodules are to be retrieved.
|
4998
|
+
prefix (str, optional): A prefix to prepend to the names of the submodules.
|
4999
|
+
Defaults to an empty string.
|
5000
|
+
|
5001
|
+
Returns:
|
5002
|
+
list of str: A list of strings representing the full names of all submodules
|
5003
|
+
in the model, including nested submodules.
|
5004
|
+
"""
|
3648
5005
|
submodules = []
|
3649
5006
|
for name, module in model.named_children():
|
3650
5007
|
full_name = prefix + ('.' if prefix else '') + name
|
@@ -3653,7 +5010,18 @@ def get_submodules(model, prefix=''):
|
|
3653
5010
|
return submodules
|
3654
5011
|
|
3655
5012
|
class GradCAM:
|
5013
|
+
"""
|
5014
|
+
Compute Grad-CAM (Gradient-weighted Class Activation Mapping) for a given model and target layer(s).
|
5015
|
+
"""
|
3656
5016
|
def __init__(self, model, target_layers=None, use_cuda=True):
|
5017
|
+
"""
|
5018
|
+
Initialize the GradCAM object.
|
5019
|
+
|
5020
|
+
Args:
|
5021
|
+
model (nn.Module): The model for which Grad-CAM will be computed.
|
5022
|
+
target_layers (list of str): Names of layers to register hooks on.
|
5023
|
+
use_cuda (bool): Whether to use CUDA (GPU) for computation.
|
5024
|
+
"""
|
3657
5025
|
self.model = model
|
3658
5026
|
self.model.eval()
|
3659
5027
|
self.target_layers = target_layers
|
@@ -3662,9 +5030,28 @@ class GradCAM:
|
|
3662
5030
|
self.model = model.cuda()
|
3663
5031
|
|
3664
5032
|
def forward(self, input):
|
5033
|
+
"""
|
5034
|
+
Run a forward pass through the model.
|
5035
|
+
|
5036
|
+
Args:
|
5037
|
+
input (torch.Tensor): Input tensor.
|
5038
|
+
|
5039
|
+
Returns:
|
5040
|
+
torch.Tensor: Model output.
|
5041
|
+
"""
|
3665
5042
|
return self.model(input)
|
3666
5043
|
|
3667
5044
|
def __call__(self, x, index=None):
|
5045
|
+
"""
|
5046
|
+
Compute the Grad-CAM heatmap for an input image.
|
5047
|
+
|
5048
|
+
Args:
|
5049
|
+
x (torch.Tensor): Input tensor of shape (1, C, H, W).
|
5050
|
+
index (int or None): Class index to compute gradients for. If None, uses the predicted class.
|
5051
|
+
|
5052
|
+
Returns:
|
5053
|
+
numpy.ndarray: Normalized Grad-CAM heatmap of shape (H, W).
|
5054
|
+
"""
|
3668
5055
|
if self.cuda:
|
3669
5056
|
x = x.cuda()
|
3670
5057
|
|
@@ -3711,6 +5098,23 @@ class GradCAM:
|
|
3711
5098
|
return cam
|
3712
5099
|
|
3713
5100
|
def show_cam_on_image(img, mask):
|
5101
|
+
"""
|
5102
|
+
Overlay a heatmap generated from a mask onto an image.
|
5103
|
+
|
5104
|
+
This function applies a color map to the mask, combines it with the input
|
5105
|
+
image, and normalizes the result to create a visually interpretable
|
5106
|
+
representation of the mask overlaid on the image.
|
5107
|
+
|
5108
|
+
Args:
|
5109
|
+
img (numpy.ndarray): The input image as a NumPy array with pixel values
|
5110
|
+
normalized between 0 and 1.
|
5111
|
+
mask (numpy.ndarray): The mask to overlay on the image, with values
|
5112
|
+
normalized between 0 and 1.
|
5113
|
+
|
5114
|
+
Returns:
|
5115
|
+
numpy.ndarray: The resulting image with the heatmap overlay, as a
|
5116
|
+
NumPy array with pixel values in the range [0, 255].
|
5117
|
+
"""
|
3714
5118
|
heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
|
3715
5119
|
heatmap = np.float32(heatmap) / 255
|
3716
5120
|
cam = heatmap + np.float32(img)
|
@@ -3718,6 +5122,24 @@ def show_cam_on_image(img, mask):
|
|
3718
5122
|
return np.uint8(255 * cam)
|
3719
5123
|
|
3720
5124
|
def recommend_target_layers(model):
|
5125
|
+
"""
|
5126
|
+
Identifies and recommends target layers in a given model for further processing.
|
5127
|
+
|
5128
|
+
This function iterates through all the modules in the provided model and collects
|
5129
|
+
the names of all 2D convolutional layers (`torch.nn.Conv2d`). It then recommends
|
5130
|
+
the last convolutional layer as the primary target layer.
|
5131
|
+
|
5132
|
+
Args:
|
5133
|
+
model (torch.nn.Module): The neural network model to analyze.
|
5134
|
+
|
5135
|
+
Returns:
|
5136
|
+
tuple:
|
5137
|
+
- list: A list containing the name of the recommended target layer (last Conv2d layer).
|
5138
|
+
- list: A list of all Conv2d layer names found in the model.
|
5139
|
+
|
5140
|
+
Raises:
|
5141
|
+
ValueError: If no convolutional layers (`torch.nn.Conv2d`) are found in the model.
|
5142
|
+
"""
|
3721
5143
|
target_layers = []
|
3722
5144
|
for name, module in model.named_modules():
|
3723
5145
|
if isinstance(module, torch.nn.Conv2d):
|
@@ -3729,11 +5151,38 @@ def recommend_target_layers(model):
|
|
3729
5151
|
raise ValueError("No convolutional layers found in the model.")
|
3730
5152
|
|
3731
5153
|
class IntegratedGradients:
|
5154
|
+
"""
|
5155
|
+
Compute Integrated Gradients for model interpretability.
|
5156
|
+
|
5157
|
+
This class implements the Integrated Gradients method to attribute the prediction
|
5158
|
+
of a neural network to its input features. It approximates the integral of gradients
|
5159
|
+
along a straight path from a baseline to the input.
|
5160
|
+
"""
|
5161
|
+
|
5162
|
+
|
3732
5163
|
def __init__(self, model):
|
5164
|
+
"""
|
5165
|
+
Initialize the IntegratedGradients instance.
|
5166
|
+
|
5167
|
+
Args:
|
5168
|
+
model (torch.nn.Module): A trained PyTorch model.
|
5169
|
+
"""
|
3733
5170
|
self.model = model
|
3734
5171
|
self.model.eval()
|
3735
5172
|
|
3736
5173
|
def generate_integrated_gradients(self, input_tensor, target_label_idx, baseline=None, num_steps=50):
|
5174
|
+
"""
|
5175
|
+
Compute the integrated gradients for a given input and target class.
|
5176
|
+
|
5177
|
+
Args:
|
5178
|
+
input_tensor (torch.Tensor): The input tensor of shape (1, C, H, W) or similar.
|
5179
|
+
target_label_idx (int): Index of the target class for which gradients are computed.
|
5180
|
+
baseline (torch.Tensor, optional): Baseline tensor with the same shape as input. Defaults to zeros.
|
5181
|
+
num_steps (int, optional): Number of steps in the Riemann approximation of the integral. Defaults to 50.
|
5182
|
+
|
5183
|
+
Returns:
|
5184
|
+
np.ndarray: Integrated gradients as a NumPy array with the same shape as `input_tensor`.
|
5185
|
+
"""
|
3737
5186
|
if baseline is None:
|
3738
5187
|
baseline = torch.zeros_like(input_tensor)
|
3739
5188
|
|
@@ -3753,18 +5202,61 @@ class IntegratedGradients:
|
|
3753
5202
|
return integrated_grads
|
3754
5203
|
|
3755
5204
|
def get_db_paths(src):
|
5205
|
+
"""
|
5206
|
+
Generate a list of database file paths based on the given source(s).
|
5207
|
+
|
5208
|
+
This function takes a single source path or a list of source paths and
|
5209
|
+
constructs the corresponding paths to the 'measurements.db' file located
|
5210
|
+
in the 'measurements' subdirectory of each source.
|
5211
|
+
|
5212
|
+
Args:
|
5213
|
+
src (str or list of str): A single source path as a string or a list
|
5214
|
+
of source paths.
|
5215
|
+
|
5216
|
+
Returns:
|
5217
|
+
list of str: A list of file paths pointing to 'measurements/measurements.db'
|
5218
|
+
for each source in the input.
|
5219
|
+
"""
|
3756
5220
|
if isinstance(src, str):
|
3757
5221
|
src = [src]
|
3758
5222
|
db_paths = [os.path.join(source, 'measurements/measurements.db') for source in src]
|
3759
5223
|
return db_paths
|
3760
5224
|
|
3761
5225
|
def get_sequencing_paths(src):
|
5226
|
+
"""
|
5227
|
+
Generate a list of file paths pointing to sequencing data CSV files.
|
5228
|
+
|
5229
|
+
This function takes a single source path or a list of source paths and
|
5230
|
+
constructs the full file paths to the 'sequencing/sequencing_data.csv'
|
5231
|
+
file located within each source directory.
|
5232
|
+
|
5233
|
+
Args:
|
5234
|
+
src (str or list of str): A single source directory path as a string
|
5235
|
+
or a list of source directory paths.
|
5236
|
+
|
5237
|
+
Returns:
|
5238
|
+
list of str: A list of full file paths to the 'sequencing_data.csv'
|
5239
|
+
files for each source directory.
|
5240
|
+
"""
|
3762
5241
|
if isinstance(src, str):
|
3763
5242
|
src = [src]
|
3764
5243
|
seq_paths = [os.path.join(source, 'sequencing/sequencing_data.csv') for source in src]
|
3765
5244
|
return seq_paths
|
3766
5245
|
|
3767
5246
|
def load_image_paths(c, visualize):
|
5247
|
+
"""
|
5248
|
+
Loads image paths from a database table and optionally filters them based on a visualization keyword.
|
5249
|
+
|
5250
|
+
Args:
|
5251
|
+
c (sqlite3.Cursor): A database cursor object used to execute SQL queries.
|
5252
|
+
visualize (str): A keyword to filter image paths. If provided, only rows where the 'png_path' column
|
5253
|
+
contains the keyword followed by '_png' will be included. If None or empty, no filtering
|
5254
|
+
is applied.
|
5255
|
+
|
5256
|
+
Returns:
|
5257
|
+
pandas.DataFrame: A DataFrame containing the image paths and other associated data from the 'png_list' table.
|
5258
|
+
The DataFrame is indexed by the 'prcfo' column.
|
5259
|
+
"""
|
3768
5260
|
c.execute(f'SELECT * FROM png_list')
|
3769
5261
|
data = c.fetchall()
|
3770
5262
|
columns_info = c.execute(f'PRAGMA table_info(png_list)').fetchall()
|
@@ -3777,6 +5269,18 @@ def load_image_paths(c, visualize):
|
|
3777
5269
|
return image_paths_df
|
3778
5270
|
|
3779
5271
|
def merge_dataframes(df, image_paths_df, verbose):
|
5272
|
+
"""
|
5273
|
+
Merges two pandas DataFrames on their indices and optionally displays the result.
|
5274
|
+
|
5275
|
+
Parameters:
|
5276
|
+
df (pandas.DataFrame): The main DataFrame to be merged. It must have a column named 'prcfo',
|
5277
|
+
which will be set as the index before merging.
|
5278
|
+
image_paths_df (pandas.DataFrame): The DataFrame containing image paths to be merged with `df`.
|
5279
|
+
verbose (bool): If True, the resulting merged DataFrame will be displayed.
|
5280
|
+
|
5281
|
+
Returns:
|
5282
|
+
pandas.DataFrame: The merged DataFrame with the indices aligned.
|
5283
|
+
"""
|
3780
5284
|
df.set_index('prcfo', inplace=True)
|
3781
5285
|
df = image_paths_df.merge(df, left_index=True, right_index=True)
|
3782
5286
|
if verbose:
|
@@ -3784,12 +5288,48 @@ def merge_dataframes(df, image_paths_df, verbose):
|
|
3784
5288
|
return df
|
3785
5289
|
|
3786
5290
|
def remove_highly_correlated_columns_v1(df, threshold):
|
5291
|
+
"""
|
5292
|
+
Removes columns from a DataFrame that are highly correlated with other columns.
|
5293
|
+
|
5294
|
+
This function calculates the correlation matrix of the given DataFrame, identifies
|
5295
|
+
columns with a correlation higher than the specified threshold, and removes them
|
5296
|
+
to reduce multicollinearity.
|
5297
|
+
|
5298
|
+
Args:
|
5299
|
+
df (pandas.DataFrame): The input DataFrame containing the data.
|
5300
|
+
threshold (float): The correlation threshold above which columns are considered
|
5301
|
+
highly correlated and will be removed.
|
5302
|
+
|
5303
|
+
Returns:
|
5304
|
+
pandas.DataFrame: A DataFrame with highly correlated columns removed.
|
5305
|
+
|
5306
|
+
Example:
|
5307
|
+
>>> import pandas as pd
|
5308
|
+
>>> import numpy as np
|
5309
|
+
>>> data = {'A': [1, 2, 3], 'B': [2, 4, 6], 'C': [7, 8, 9]}
|
5310
|
+
>>> df = pd.DataFrame(data)
|
5311
|
+
>>> remove_highly_correlated_columns_v1(df, threshold=0.9)
|
5312
|
+
A C
|
5313
|
+
"""
|
3787
5314
|
corr_matrix = df.corr().abs()
|
3788
5315
|
upper_tri = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
|
3789
5316
|
to_drop = [column for column in upper_tri.columns if any(upper_tri[column] > threshold)]
|
3790
5317
|
return df.drop(to_drop, axis=1)
|
3791
5318
|
|
3792
5319
|
def filter_columns(df, filter_by):
|
5320
|
+
"""
|
5321
|
+
Filters the columns of a DataFrame based on a specified criterion.
|
5322
|
+
|
5323
|
+
Args:
|
5324
|
+
df (pandas.DataFrame): The input DataFrame whose columns are to be filtered.
|
5325
|
+
filter_by (str): The criterion for filtering columns. If 'morphology',
|
5326
|
+
columns containing 'channel' in their names are excluded.
|
5327
|
+
Otherwise, only columns containing the specified string
|
5328
|
+
are included.
|
5329
|
+
|
5330
|
+
Returns:
|
5331
|
+
pandas.DataFrame: A DataFrame containing only the filtered columns.
|
5332
|
+
"""
|
3793
5333
|
if filter_by != 'morphology':
|
3794
5334
|
cols_to_include = [col for col in df.columns if filter_by in str(col)]
|
3795
5335
|
else:
|
@@ -3801,7 +5341,7 @@ def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, m
|
|
3801
5341
|
"""
|
3802
5342
|
Perform dimensionality reduction and clustering on the given data.
|
3803
5343
|
|
3804
|
-
|
5344
|
+
Args:
|
3805
5345
|
numeric_data (np.ndarray): Numeric data for embedding and clustering.
|
3806
5346
|
n_neighbors (int or float): Number of neighbors for UMAP or perplexity for t-SNE.
|
3807
5347
|
min_dist (float): Minimum distance for UMAP.
|
@@ -3903,12 +5443,51 @@ def reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric, eps, m
|
|
3903
5443
|
return embedding, labels, reducer
|
3904
5444
|
|
3905
5445
|
def remove_noise(embedding, labels):
|
5446
|
+
"""
|
5447
|
+
Removes noise from the given embedding and labels by filtering out elements
|
5448
|
+
where the corresponding label is -1.
|
5449
|
+
|
5450
|
+
Args:
|
5451
|
+
embedding (numpy.ndarray): The embedding array, where each row corresponds
|
5452
|
+
to a data point.
|
5453
|
+
labels (numpy.ndarray): The array of labels corresponding to the embedding,
|
5454
|
+
where a label of -1 indicates noise.
|
5455
|
+
|
5456
|
+
Returns:
|
5457
|
+
tuple: A tuple containing:
|
5458
|
+
- numpy.ndarray: The filtered embedding array with noise removed.
|
5459
|
+
- numpy.ndarray: The filtered labels array with noise removed.
|
5460
|
+
"""
|
3906
5461
|
non_noise_indices = labels != -1
|
3907
5462
|
embedding = embedding[non_noise_indices]
|
3908
5463
|
labels = labels[non_noise_indices]
|
3909
5464
|
return embedding, labels
|
3910
5465
|
|
3911
5466
|
def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, plot_by_cluster, plot_outlines, plot_points, plot_images, smooth_lines, black_background, figuresize, dot_size, remove_image_canvas, verbose):
|
5467
|
+
"""
|
5468
|
+
Plots a 2D embedding with optional images, clusters, and customization options.
|
5469
|
+
|
5470
|
+
Args:
|
5471
|
+
embedding (np.ndarray): A 2D array of shape (n_samples, 2) representing the embedding coordinates.
|
5472
|
+
image_paths (list or None): A list of file paths to images corresponding to the data points, or None if no images are used.
|
5473
|
+
labels (np.ndarray): An array of cluster labels for each data point.
|
5474
|
+
image_nr (int): The number of images to display on the plot.
|
5475
|
+
img_zoom (float): The zoom factor for the displayed images.
|
5476
|
+
colors (list or None): A list of colors to use for clusters, or None to use default colors.
|
5477
|
+
plot_by_cluster (bool): Whether to plot images grouped by cluster.
|
5478
|
+
plot_outlines (bool): Whether to draw outlines around clusters.
|
5479
|
+
plot_points (bool): Whether to plot individual data points.
|
5480
|
+
plot_images (bool): Whether to overlay images on the embedding.
|
5481
|
+
smooth_lines (bool): Whether to draw smooth lines between cluster centers.
|
5482
|
+
black_background (bool): Whether to use a black background for the plot.
|
5483
|
+
figuresize (tuple): The size of the figure in inches (width, height).
|
5484
|
+
dot_size (float): The size of the dots representing data points.
|
5485
|
+
remove_image_canvas (bool): Whether to remove the canvas around the images.
|
5486
|
+
verbose (bool): Whether to print verbose output during the plotting process.
|
5487
|
+
|
5488
|
+
Returns:
|
5489
|
+
matplotlib.figure.Figure: The generated plot as a Matplotlib figure object.
|
5490
|
+
"""
|
3912
5491
|
unique_labels = np.unique(labels)
|
3913
5492
|
#num_clusters = len(unique_labels[unique_labels != 0])
|
3914
5493
|
colors, label_to_color_index = assign_colors(unique_labels, colors)
|
@@ -3921,6 +5500,22 @@ def plot_embedding(embedding, image_paths, labels, image_nr, img_zoom, colors, p
|
|
3921
5500
|
return fig
|
3922
5501
|
|
3923
5502
|
def generate_colors(num_clusters, black_background):
|
5503
|
+
"""
|
5504
|
+
Generate a set of RGBA colors for visualization purposes.
|
5505
|
+
|
5506
|
+
This function generates a list of random RGBA colors, appends specific predefined colors,
|
5507
|
+
and optionally includes a black background color.
|
5508
|
+
|
5509
|
+
Args:
|
5510
|
+
num_clusters (int): The number of clusters for which colors need to be generated.
|
5511
|
+
Additional random colors will be generated beyond the predefined ones.
|
5512
|
+
black_background (bool): If True, a black background color ([0, 0, 0, 1]) will be included
|
5513
|
+
at the beginning of the color list.
|
5514
|
+
|
5515
|
+
Returns:
|
5516
|
+
numpy.ndarray: A 2D array of shape (num_colors, 4), where each row represents an RGBA color.
|
5517
|
+
The first dimension corresponds to the total number of colors generated.
|
5518
|
+
"""
|
3924
5519
|
random_colors = np.random.rand(num_clusters + 1, 4)
|
3925
5520
|
random_colors[:, 3] = 1
|
3926
5521
|
specific_colors = [
|
@@ -3935,6 +5530,19 @@ def generate_colors(num_clusters, black_background):
|
|
3935
5530
|
return random_colors
|
3936
5531
|
|
3937
5532
|
def assign_colors(unique_labels, random_colors):
|
5533
|
+
"""
|
5534
|
+
Assigns colors to unique labels and creates a mapping from labels to color indices.
|
5535
|
+
|
5536
|
+
Args:
|
5537
|
+
unique_labels (list or iterable): A collection of unique labels for which colors need to be assigned.
|
5538
|
+
random_colors (numpy.ndarray or list): An array or list of RGB color values, where each color is represented
|
5539
|
+
as a triplet of integers in the range [0, 255].
|
5540
|
+
|
5541
|
+
Returns:
|
5542
|
+
tuple: A tuple containing:
|
5543
|
+
- colors (list of tuple): A list of RGB color tuples in the original [0, 255] range.
|
5544
|
+
- label_to_color_index (dict): A dictionary mapping each unique label to its corresponding color index.
|
5545
|
+
"""
|
3938
5546
|
normalized_colors = random_colors / 255
|
3939
5547
|
colors_img = [tuple(color) for color in normalized_colors]
|
3940
5548
|
colors = [tuple(color) for color in random_colors]
|
@@ -3942,6 +5550,17 @@ def assign_colors(unique_labels, random_colors):
|
|
3942
5550
|
return colors, label_to_color_index
|
3943
5551
|
|
3944
5552
|
def setup_plot(figuresize, black_background):
|
5553
|
+
"""
|
5554
|
+
Sets up a matplotlib plot with specified figure size and background color.
|
5555
|
+
|
5556
|
+
Args:
|
5557
|
+
figuresize (float): The size of the figure in inches (used for both width and height).
|
5558
|
+
black_background (bool): If True, sets the plot to have a black background with white text and labels.
|
5559
|
+
If False, sets the plot to have a white background with black text and labels.
|
5560
|
+
|
5561
|
+
Returns:
|
5562
|
+
tuple: A tuple containing the figure (`matplotlib.figure.Figure`) and axes (`matplotlib.axes._axes.Axes`) objects.
|
5563
|
+
"""
|
3945
5564
|
if black_background:
|
3946
5565
|
plt.rcParams.update({'figure.facecolor': 'black', 'axes.facecolor': 'black', 'text.color': 'white', 'xtick.color': 'white', 'ytick.color': 'white', 'axes.labelcolor': 'white'})
|
3947
5566
|
else:
|
@@ -3950,6 +5569,30 @@ def setup_plot(figuresize, black_background):
|
|
3950
5569
|
return fig, ax
|
3951
5570
|
|
3952
5571
|
def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines, plot_points, smooth_lines, figuresize=10, dot_size=50, verbose=False):
|
5572
|
+
"""
|
5573
|
+
Plots clusters on a 2D embedding using matplotlib.
|
5574
|
+
|
5575
|
+
Args:
|
5576
|
+
ax (matplotlib.axes.Axes): The matplotlib Axes object to plot on.
|
5577
|
+
embedding (numpy.ndarray): A 2D array of shape (n_samples, 2) representing the embedding coordinates.
|
5578
|
+
labels (numpy.ndarray): An array of cluster labels for each point in the embedding.
|
5579
|
+
colors (list): A list of colors corresponding to each cluster.
|
5580
|
+
cluster_centers (numpy.ndarray): A 2D array of shape (n_clusters, 2) representing the coordinates of cluster centers.
|
5581
|
+
plot_outlines (bool): Whether to plot the outlines of clusters using convex hulls or smoothed lines.
|
5582
|
+
plot_points (bool): Whether to plot individual points in the clusters.
|
5583
|
+
smooth_lines (bool): Whether to use smoothed lines for cluster outlines instead of convex hulls.
|
5584
|
+
figuresize (int, optional): The size of the figure. Defaults to 10.
|
5585
|
+
dot_size (int, optional): The size of the points in the scatter plot. Defaults to 50.
|
5586
|
+
verbose (bool, optional): Whether to print additional information for debugging. Defaults to False.
|
5587
|
+
|
5588
|
+
Returns:
|
5589
|
+
None
|
5590
|
+
|
5591
|
+
Notes:
|
5592
|
+
- This function assumes that the embedding is 2D.
|
5593
|
+
- Cluster labels should be integers, with -1 typically representing noise.
|
5594
|
+
- The function uses matplotlib for plotting and assumes that the required libraries (e.g., numpy, matplotlib) are imported.
|
5595
|
+
"""
|
3953
5596
|
unique_labels = np.unique(labels)
|
3954
5597
|
for cluster_label, color, center in zip(unique_labels, colors, cluster_centers):
|
3955
5598
|
cluster_data = embedding[labels == cluster_label]
|
@@ -3975,6 +5618,24 @@ def plot_clusters(ax, embedding, labels, colors, cluster_centers, plot_outlines,
|
|
3975
5618
|
plt.tick_params(axis='both', which='major', labelsize=int(figuresize * 0.75))
|
3976
5619
|
|
3977
5620
|
def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, plot_by_cluster, remove_image_canvas, verbose):
|
5621
|
+
"""
|
5622
|
+
Plots UMAP embeddings with associated images on a given matplotlib axis.
|
5623
|
+
|
5624
|
+
Args:
|
5625
|
+
ax (matplotlib.axes.Axes): The matplotlib axis on which to plot the images.
|
5626
|
+
image_paths (list of str): List of file paths to the images to be plotted.
|
5627
|
+
embedding (numpy.ndarray): 2D array of UMAP embeddings with shape (n_samples, 2).
|
5628
|
+
labels (numpy.ndarray): Array of cluster labels for each embedding point.
|
5629
|
+
image_nr (int): Number of images to plot.
|
5630
|
+
img_zoom (float): Zoom factor for the images.
|
5631
|
+
colors (list): List of colors for each cluster.
|
5632
|
+
plot_by_cluster (bool): If True, plot images grouped by cluster; otherwise, plot randomly sampled images.
|
5633
|
+
remove_image_canvas (bool): If True, remove the image canvas (background) when plotting.
|
5634
|
+
verbose (bool): If True, print additional information during execution.
|
5635
|
+
|
5636
|
+
Returns:
|
5637
|
+
None
|
5638
|
+
"""
|
3978
5639
|
if plot_by_cluster:
|
3979
5640
|
cluster_indices = {label: np.where(labels == label)[0] for label in np.unique(labels) if label != -1}
|
3980
5641
|
plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose)
|
@@ -3986,6 +5647,24 @@ def plot_umap_images(ax, image_paths, embedding, labels, image_nr, img_zoom, col
|
|
3986
5647
|
plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
|
3987
5648
|
|
3988
5649
|
def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoom, colors, cluster_indices, remove_image_canvas, verbose):
|
5650
|
+
"""
|
5651
|
+
Plots images on a given axis based on their cluster assignments and embeddings.
|
5652
|
+
|
5653
|
+
Args:
|
5654
|
+
ax (matplotlib.axes.Axes): The matplotlib axis on which to plot the images.
|
5655
|
+
image_paths (list of str): List of file paths to the images to be plotted.
|
5656
|
+
embedding (array-like): 2D array of shape (n_samples, 2) containing the x and y coordinates for each image.
|
5657
|
+
labels (array-like): Array of cluster labels for each image. -1 indicates noise or unclustered points.
|
5658
|
+
image_nr (int): Maximum number of images to display per cluster.
|
5659
|
+
img_zoom (float): Zoom factor for the displayed images.
|
5660
|
+
colors (list of str): List of colors corresponding to each cluster.
|
5661
|
+
cluster_indices (dict): Dictionary mapping cluster labels to lists of indices of images in each cluster.
|
5662
|
+
remove_image_canvas (bool): If True, removes the canvas (border) around the plotted images.
|
5663
|
+
verbose (bool): If True, prints additional information during execution.
|
5664
|
+
|
5665
|
+
Returns:
|
5666
|
+
None
|
5667
|
+
"""
|
3989
5668
|
for cluster_label, color in zip(np.unique(labels), colors):
|
3990
5669
|
if cluster_label == -1:
|
3991
5670
|
continue
|
@@ -3998,6 +5677,22 @@ def plot_images_by_cluster(ax, image_paths, embedding, labels, image_nr, img_zoo
|
|
3998
5677
|
plot_image(ax, x, y, img, img_zoom, remove_image_canvas)
|
3999
5678
|
|
4000
5679
|
def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
|
5680
|
+
"""
|
5681
|
+
Plots an image on a given matplotlib axis at specified coordinates.
|
5682
|
+
|
5683
|
+
Args:
|
5684
|
+
ax (matplotlib.axes.Axes): The axis on which to plot the image.
|
5685
|
+
x (float): The x-coordinate where the image will be placed.
|
5686
|
+
y (float): The y-coordinate where the image will be placed.
|
5687
|
+
img (numpy.ndarray or array-like): The image data to be plotted.
|
5688
|
+
img_zoom (float): The zoom factor for the image.
|
5689
|
+
remove_image_canvas (bool, optional): If True, removes the canvas
|
5690
|
+
(e.g., padding or borders) from the image before plotting.
|
5691
|
+
Defaults to True.
|
5692
|
+
|
5693
|
+
Returns:
|
5694
|
+
None
|
5695
|
+
"""
|
4001
5696
|
img = np.array(img)
|
4002
5697
|
if remove_image_canvas:
|
4003
5698
|
img = remove_canvas(img)
|
@@ -4006,6 +5701,26 @@ def plot_image(ax, x, y, img, img_zoom, remove_image_canvas=True):
|
|
4006
5701
|
ax.add_artist(ab)
|
4007
5702
|
|
4008
5703
|
def remove_canvas(img):
|
5704
|
+
"""
|
5705
|
+
Converts an image to a normalized RGBA format by adding an alpha channel.
|
5706
|
+
|
5707
|
+
This function processes images in either grayscale ('L', 'I') or RGB ('RGB') mode.
|
5708
|
+
For grayscale images, the pixel values are normalized, and an alpha channel is
|
5709
|
+
created based on non-zero pixel values. For RGB images, the pixel values are
|
5710
|
+
normalized to the range [0, 1], and an alpha channel is created based on the
|
5711
|
+
presence of non-zero pixel values across all channels.
|
5712
|
+
|
5713
|
+
Args:
|
5714
|
+
img (PIL.Image.Image): The input image to process. Must be in 'L', 'I', or 'RGB' mode.
|
5715
|
+
|
5716
|
+
Returns:
|
5717
|
+
numpy.ndarray: A 4-channel RGBA image as a NumPy array, where the first three
|
5718
|
+
channels represent the normalized RGB values, and the fourth channel represents
|
5719
|
+
the alpha channel.
|
5720
|
+
|
5721
|
+
Raises:
|
5722
|
+
ValueError: If the input image mode is not 'L', 'I', or 'RGB'.
|
5723
|
+
"""
|
4009
5724
|
if img.mode in ['L', 'I']:
|
4010
5725
|
img_data = np.array(img)
|
4011
5726
|
img_data = img_data / np.max(img_data)
|
@@ -4022,6 +5737,27 @@ def remove_canvas(img):
|
|
4022
5737
|
return img_data_with_alpha
|
4023
5738
|
|
4024
5739
|
def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figuresize, black_background, verbose):
|
5740
|
+
"""
|
5741
|
+
Plot a grid of images for each cluster based on the given labels and embeddings.
|
5742
|
+
|
5743
|
+
Args:
|
5744
|
+
embedding (np.ndarray): Embedding of data points for visualization.
|
5745
|
+
labels (np.ndarray): Cluster labels for each data point. A value of -1 indicates noise or outliers.
|
5746
|
+
image_nr (int): Maximum number of images to display per cluster.
|
5747
|
+
image_paths (list of str): File paths to images corresponding to the data points.
|
5748
|
+
colors (list of str): List of colors for each cluster.
|
5749
|
+
figuresize (tuple): Size of the figure (width, height).
|
5750
|
+
black_background (bool): Whether to use a black background.
|
5751
|
+
verbose (bool): Whether to print progress information.
|
5752
|
+
|
5753
|
+
Returns:
|
5754
|
+
matplotlib.figure.Figure or None: The generated figure, or None if no valid clusters are found.
|
5755
|
+
|
5756
|
+
Notes:
|
5757
|
+
- Clusters larger than `image_nr` are randomly subsampled.
|
5758
|
+
- If all labels are -1, the function returns None.
|
5759
|
+
- Relies on an external `plot_grid` function for grid rendering.
|
5760
|
+
"""
|
4025
5761
|
unique_labels = np.unique(labels)
|
4026
5762
|
num_clusters = len(unique_labels[unique_labels != -1])
|
4027
5763
|
if num_clusters == 0:
|
@@ -4042,7 +5778,25 @@ def plot_clusters_grid(embedding, labels, image_nr, image_paths, colors, figures
|
|
4042
5778
|
fig = plot_grid(cluster_images, colors, figuresize, black_background, verbose)
|
4043
5779
|
return fig
|
4044
5780
|
|
4045
|
-
def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
|
5781
|
+
def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
|
5782
|
+
"""
|
5783
|
+
Plot a grid of images grouped by cluster with optional background and labels.
|
5784
|
+
|
5785
|
+
Args:
|
5786
|
+
cluster_images (dict): Dictionary mapping cluster labels to lists of images.
|
5787
|
+
colors (list): List of RGB tuples specifying colors for each cluster.
|
5788
|
+
figuresize (float): Base figure size; actual size scales with the number of clusters.
|
5789
|
+
black_background (bool): If True, use a black background; otherwise, use white.
|
5790
|
+
verbose (bool): If True, print cluster labels and index info during plotting.
|
5791
|
+
|
5792
|
+
Returns:
|
5793
|
+
matplotlib.figure.Figure: The generated figure containing the image grid.
|
5794
|
+
|
5795
|
+
Notes:
|
5796
|
+
- Grid size is dynamically adjusted per cluster.
|
5797
|
+
- Cluster labels are shown alongside image grids using corresponding colors.
|
5798
|
+
- A maximum figure size limit prevents overly large plots.
|
5799
|
+
"""
|
4046
5800
|
num_clusters = len(cluster_images)
|
4047
5801
|
max_figsize = 200 # Set a maximum figure size
|
4048
5802
|
if figuresize * num_clusters > max_figsize:
|
@@ -4090,6 +5844,27 @@ def plot_grid(cluster_images, colors, figuresize, black_background, verbose):
|
|
4090
5844
|
return grid_fig
|
4091
5845
|
|
4092
5846
|
def generate_path_list_from_db(db_path, file_metadata):
|
5847
|
+
"""
|
5848
|
+
Generate a list of file paths from a SQLite database using optional metadata filters.
|
5849
|
+
|
5850
|
+
Args:
|
5851
|
+
db_path (str): Path to the SQLite database.
|
5852
|
+
file_metadata (str | list[str] | None): Filter criteria for file paths.
|
5853
|
+
- str: Only include paths containing the string.
|
5854
|
+
- list of str: Include paths containing any of the strings.
|
5855
|
+
- None or empty: Include all paths.
|
5856
|
+
|
5857
|
+
Returns:
|
5858
|
+
list[str] or None: List of matching file paths, or None if an error occurs.
|
5859
|
+
|
5860
|
+
Raises:
|
5861
|
+
sqlite3.Error: If a database operation fails.
|
5862
|
+
Exception: For any other unexpected error.
|
5863
|
+
|
5864
|
+
Notes:
|
5865
|
+
- Paths are fetched from the 'png_list' table using the 'png_path' column.
|
5866
|
+
- Results are retrieved in batches of 1000 rows for efficiency.
|
5867
|
+
"""
|
4093
5868
|
all_paths = []
|
4094
5869
|
|
4095
5870
|
# Connect to the database and retrieve the image paths
|
@@ -4128,6 +5903,27 @@ def generate_path_list_from_db(db_path, file_metadata):
|
|
4128
5903
|
return all_paths
|
4129
5904
|
|
4130
5905
|
def correct_paths(df, base_path, folder='data'):
|
5906
|
+
"""
|
5907
|
+
Adjust file paths to include the specified base directory and folder.
|
5908
|
+
|
5909
|
+
Args:
|
5910
|
+
df (pandas.DataFrame or list): Input containing file paths.
|
5911
|
+
- If a DataFrame, it must have a 'png_path' column.
|
5912
|
+
- If a list, it should contain file path strings.
|
5913
|
+
base_path (str): Base directory to prepend if not already present.
|
5914
|
+
folder (str, optional): Folder name to insert into paths (default: 'data').
|
5915
|
+
|
5916
|
+
Returns:
|
5917
|
+
tuple or list:
|
5918
|
+
- If input is a DataFrame: (updated DataFrame, list of adjusted paths).
|
5919
|
+
- If input is a list: list of adjusted paths.
|
5920
|
+
|
5921
|
+
Raises:
|
5922
|
+
ValueError: If the DataFrame does not contain a 'png_path' column.
|
5923
|
+
|
5924
|
+
Notes:
|
5925
|
+
Paths already containing the base path are not modified.
|
5926
|
+
"""
|
4131
5927
|
|
4132
5928
|
if isinstance(df, pd.DataFrame):
|
4133
5929
|
|
@@ -4159,6 +5955,24 @@ def correct_paths(df, base_path, folder='data'):
|
|
4159
5955
|
return adjusted_image_paths
|
4160
5956
|
|
4161
5957
|
def delete_folder(folder_path):
|
5958
|
+
"""
|
5959
|
+
Deletes a folder and all of its contents, including subdirectories and files.
|
5960
|
+
|
5961
|
+
Args:
|
5962
|
+
folder_path (str): The path to the folder to be deleted.
|
5963
|
+
|
5964
|
+
Behavior:
|
5965
|
+
- If the specified folder exists and is a directory, it recursively deletes all files
|
5966
|
+
and subdirectories within it, and then removes the folder itself.
|
5967
|
+
- If the folder does not exist or is not a directory, a message is printed indicating this.
|
5968
|
+
|
5969
|
+
Prints:
|
5970
|
+
- A confirmation message if the folder is successfully deleted.
|
5971
|
+
- An error message if the folder does not exist or is not a directory.
|
5972
|
+
|
5973
|
+
Example:
|
5974
|
+
delete_folder('/path/to/folder')
|
5975
|
+
"""
|
4162
5976
|
if os.path.exists(folder_path) and os.path.isdir(folder_path):
|
4163
5977
|
for root, dirs, files in os.walk(folder_path, topdown=False):
|
4164
5978
|
for name in files:
|
@@ -4171,6 +5985,28 @@ def delete_folder(folder_path):
|
|
4171
5985
|
print(f"Folder '{folder_path}' does not exist or is not a directory.")
|
4172
5986
|
|
4173
5987
|
def measure_test_mode(settings):
|
5988
|
+
"""
|
5989
|
+
Adjusts the source folder in the settings dictionary for test mode.
|
5990
|
+
|
5991
|
+
If `test_mode` is enabled in the `settings` dictionary, this function:
|
5992
|
+
- Checks if the current source folder (`settings['src']`) is not already set to 'test'.
|
5993
|
+
- Selects a random subset of files from the source folder based on `settings['test_nr']`.
|
5994
|
+
- Copies the selected files into a new 'test/merged' directory.
|
5995
|
+
- Updates the `settings['src']` to point to the new 'test/merged' directory.
|
5996
|
+
- Prints a message indicating the change in the source folder.
|
5997
|
+
|
5998
|
+
If the source folder is already set to 'test', it simply prints a message confirming the test mode.
|
5999
|
+
|
6000
|
+
Args:
|
6001
|
+
settings (dict): A dictionary containing configuration settings.
|
6002
|
+
Expected keys:
|
6003
|
+
- 'test_mode' (bool): Whether test mode is enabled.
|
6004
|
+
- 'src' (str): Path to the source folder.
|
6005
|
+
- 'test_nr' (int): Number of files to select for test mode.
|
6006
|
+
|
6007
|
+
Returns:
|
6008
|
+
dict: The updated settings dictionary with the modified source folder if test mode is enabled.
|
6009
|
+
"""
|
4174
6010
|
|
4175
6011
|
if settings['test_mode']:
|
4176
6012
|
if not os.path.basename(settings['src']) == 'test':
|
@@ -4251,7 +6087,7 @@ def remove_low_variance_columns(df, threshold=0.01, verbose=False):
|
|
4251
6087
|
"""
|
4252
6088
|
Removes columns from the dataframe that have low variance.
|
4253
6089
|
|
4254
|
-
|
6090
|
+
Args:
|
4255
6091
|
df (pandas.DataFrame): The DataFrame containing the data.
|
4256
6092
|
threshold (float): The variance threshold below which columns will be removed.
|
4257
6093
|
|
@@ -4273,7 +6109,7 @@ def remove_highly_correlated_columns(df, threshold=0.95, verbose=False):
|
|
4273
6109
|
"""
|
4274
6110
|
Removes columns from the dataframe that are highly correlated with one another.
|
4275
6111
|
|
4276
|
-
|
6112
|
+
Args:
|
4277
6113
|
df (pandas.DataFrame): The DataFrame containing the data.
|
4278
6114
|
threshold (float): The correlation threshold above which columns will be removed.
|
4279
6115
|
|
@@ -4301,7 +6137,7 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
|
|
4301
6137
|
"""
|
4302
6138
|
Filter the dataframe `df` based on the specified `channel_of_interest` and `exclude` parameters.
|
4303
6139
|
|
4304
|
-
|
6140
|
+
Args:
|
4305
6141
|
- df (pandas.DataFrame): The input dataframe to be filtered.
|
4306
6142
|
- channel_of_interest (str, int, list, None): The channel(s) of interest to filter the dataframe. If None, no filtering is applied. If 'morphology', only morphology features are included.If an integer, only the specified channel is included. If a list, only the specified channels are included.If a string, only the specified channel is included.
|
4307
6143
|
- exclude (str, list, None): The feature(s) to exclude from the filtered dataframe. If None, no features are excluded. If a string, the specified feature is excluded.If a list, the specified features are excluded.
|
@@ -4381,6 +6217,20 @@ def filter_dataframe_features(df, channel_of_interest, exclude=None, remove_low_
|
|
4381
6217
|
|
4382
6218
|
# Create a function to check if images overlap
|
4383
6219
|
def check_overlap(current_position, other_positions, threshold):
|
6220
|
+
"""
|
6221
|
+
Checks if the current position overlaps with any of the other positions
|
6222
|
+
within a specified threshold distance.
|
6223
|
+
|
6224
|
+
Args:
|
6225
|
+
current_position (iterable): The current position as a list, tuple, or array of coordinates.
|
6226
|
+
other_positions (iterable): A collection of positions to compare against,
|
6227
|
+
where each position is a list, tuple, or array of coordinates.
|
6228
|
+
threshold (float): The distance threshold below which two positions are considered overlapping.
|
6229
|
+
|
6230
|
+
Returns:
|
6231
|
+
bool: True if the current position overlaps with any of the other positions,
|
6232
|
+
False otherwise.
|
6233
|
+
"""
|
4384
6234
|
for other_position in other_positions:
|
4385
6235
|
distance = np.linalg.norm(np.array(current_position) - np.array(other_position))
|
4386
6236
|
if distance < threshold:
|
@@ -4389,6 +6239,25 @@ def check_overlap(current_position, other_positions, threshold):
|
|
4389
6239
|
|
4390
6240
|
# Define a function to try random positions around a given point
|
4391
6241
|
def find_non_overlapping_position(x, y, image_positions, threshold, max_attempts=100):
|
6242
|
+
"""
|
6243
|
+
Finds a new position near the given coordinates (x, y) that does not overlap
|
6244
|
+
with any of the positions in the provided image_positions list, based on a
|
6245
|
+
specified threshold.
|
6246
|
+
|
6247
|
+
Args:
|
6248
|
+
x (float): The x-coordinate of the initial position.
|
6249
|
+
y (float): The y-coordinate of the initial position.
|
6250
|
+
image_positions (list of tuples): A list of (x, y) tuples representing
|
6251
|
+
existing positions to avoid overlapping with.
|
6252
|
+
threshold (float): The minimum distance required to avoid overlap.
|
6253
|
+
max_attempts (int, optional): The maximum number of attempts to find a
|
6254
|
+
non-overlapping position. Defaults to 100.
|
6255
|
+
|
6256
|
+
Returns:
|
6257
|
+
tuple: A tuple (new_x, new_y) representing the new non-overlapping
|
6258
|
+
position. If no suitable position is found within the maximum attempts,
|
6259
|
+
the original position (x, y) is returned.
|
6260
|
+
"""
|
4392
6261
|
offset_range = 10 # Adjust the range for random offsets
|
4393
6262
|
attempts = 0
|
4394
6263
|
while attempts < max_attempts:
|
@@ -4405,7 +6274,7 @@ def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric,
|
|
4405
6274
|
"""
|
4406
6275
|
Perform dimensionality reduction and clustering on the given data.
|
4407
6276
|
|
4408
|
-
|
6277
|
+
Args:
|
4409
6278
|
numeric_data (np.array): Numeric data to process.
|
4410
6279
|
n_neighbors (int): Number of neighbors for UMAP or perplexity for tSNE.
|
4411
6280
|
min_dist (float): Minimum distance for UMAP.
|
@@ -4457,7 +6326,18 @@ def search_reduction_and_clustering(numeric_data, n_neighbors, min_dist, metric,
|
|
4457
6326
|
return embedding, labels
|
4458
6327
|
|
4459
6328
|
def load_image(image_path):
|
4460
|
-
"""
|
6329
|
+
"""
|
6330
|
+
Load and transform an image into a normalized tensor.
|
6331
|
+
|
6332
|
+
Applies resizing to 224x224, converts to a tensor, and normalizes using
|
6333
|
+
ImageNet mean and standard deviation.
|
6334
|
+
|
6335
|
+
Args:
|
6336
|
+
image_path (str): Path to the input image.
|
6337
|
+
|
6338
|
+
Returns:
|
6339
|
+
torch.Tensor: Transformed image tensor with shape (1, 3, 224, 224).
|
6340
|
+
"""
|
4461
6341
|
transform = transforms.Compose([
|
4462
6342
|
transforms.Resize((224, 224)),
|
4463
6343
|
transforms.ToTensor(),
|
@@ -4468,7 +6348,17 @@ def load_image(image_path):
|
|
4468
6348
|
return image
|
4469
6349
|
|
4470
6350
|
def extract_features(image_paths, resnet=resnet50):
|
4471
|
-
"""
|
6351
|
+
"""
|
6352
|
+
Extracts features from a list of image paths using a pre-trained ResNet model.
|
6353
|
+
|
6354
|
+
Args:
|
6355
|
+
image_paths (list of str): A list of file paths to the images from which features are to be extracted.
|
6356
|
+
resnet (torchvision.models, optional): A ResNet model class to use for feature extraction.
|
6357
|
+
Defaults to torchvision.models.resnet50.
|
6358
|
+
|
6359
|
+
Returns:
|
6360
|
+
numpy.ndarray: A 2D array where each row corresponds to the extracted features of an image.
|
6361
|
+
"""
|
4472
6362
|
model = resnet(pretrained=True)
|
4473
6363
|
model = model.eval()
|
4474
6364
|
model = torch.nn.Sequential(*list(model.children())[:-1]) # Remove the last classification layer
|
@@ -4483,7 +6373,20 @@ def extract_features(image_paths, resnet=resnet50):
|
|
4483
6373
|
return np.array(features)
|
4484
6374
|
|
4485
6375
|
def check_normality(series):
|
4486
|
-
"""
|
6376
|
+
"""
|
6377
|
+
Test whether a given data series follows a normal distribution.
|
6378
|
+
|
6379
|
+
This function uses the D'Agostino and Pearson's test to check the null
|
6380
|
+
hypothesis that the data comes from a normal distribution. If the p-value
|
6381
|
+
is less than the significance level (alpha), the null hypothesis is rejected.
|
6382
|
+
|
6383
|
+
Args:
|
6384
|
+
series (array-like): The data series to test for normality.
|
6385
|
+
|
6386
|
+
Returns:
|
6387
|
+
bool: True if the data follows a normal distribution (p >= alpha),
|
6388
|
+
False otherwise.
|
6389
|
+
"""
|
4487
6390
|
k2, p = stats.normaltest(series)
|
4488
6391
|
alpha = 0.05
|
4489
6392
|
if p < alpha: # null hypothesis: x comes from a normal distribution
|
@@ -4491,7 +6394,29 @@ def check_normality(series):
|
|
4491
6394
|
return True
|
4492
6395
|
|
4493
6396
|
def random_forest_feature_importance(all_df, cluster_col='cluster'):
|
4494
|
-
"""
|
6397
|
+
"""
|
6398
|
+
Computes feature importance using a Random Forest Classifier.
|
6399
|
+
|
6400
|
+
This function takes a DataFrame, selects numeric features, and computes
|
6401
|
+
the importance of each feature in predicting the specified cluster column
|
6402
|
+
using a Random Forest Classifier. The results are returned as a sorted
|
6403
|
+
DataFrame of feature importances.
|
6404
|
+
|
6405
|
+
Args:
|
6406
|
+
all_df (pd.DataFrame): The input DataFrame containing the data.
|
6407
|
+
cluster_col (str, optional): The name of the column representing the
|
6408
|
+
target variable (cluster). Defaults to 'cluster'.
|
6409
|
+
|
6410
|
+
Returns:
|
6411
|
+
pd.DataFrame: A DataFrame containing the features and their corresponding
|
6412
|
+
importance scores, sorted in descending order of importance.
|
6413
|
+
|
6414
|
+
Notes:
|
6415
|
+
- The function assumes that the target column (`cluster_col`) is numeric.
|
6416
|
+
- Standard scaling is applied to the numeric features before fitting the model.
|
6417
|
+
- The Random Forest Classifier is initialized with 100 estimators and a
|
6418
|
+
random state of 42 for reproducibility.
|
6419
|
+
"""
|
4495
6420
|
numeric_features = all_df.select_dtypes(include=[np.number]).columns.tolist()
|
4496
6421
|
if cluster_col in numeric_features:
|
4497
6422
|
numeric_features.remove(cluster_col)
|
@@ -4515,7 +6440,29 @@ def random_forest_feature_importance(all_df, cluster_col='cluster'):
|
|
4515
6440
|
return importance_df
|
4516
6441
|
|
4517
6442
|
def perform_statistical_tests(all_df, cluster_col='cluster'):
|
4518
|
-
"""
|
6443
|
+
"""
|
6444
|
+
Perform ANOVA and Kruskal-Wallis tests on numeric features grouped by clusters.
|
6445
|
+
|
6446
|
+
This function evaluates whether numeric features differ significantly across groups
|
6447
|
+
defined by the `cluster_col`.
|
6448
|
+
|
6449
|
+
Args:
|
6450
|
+
all_df (pd.DataFrame): DataFrame containing numeric features and cluster assignments.
|
6451
|
+
cluster_col (str, optional): Name of the column indicating cluster/group labels. Defaults to 'cluster'.
|
6452
|
+
|
6453
|
+
Returns:
|
6454
|
+
tuple:
|
6455
|
+
- anova_df (pd.DataFrame): ANOVA test results with columns
|
6456
|
+
['Feature', 'ANOVA_Statistic', 'ANOVA_pValue'].
|
6457
|
+
- kruskal_df (pd.DataFrame): Kruskal-Wallis test results with columns
|
6458
|
+
['Feature', 'Kruskal_Statistic', 'Kruskal_pValue'].
|
6459
|
+
|
6460
|
+
Notes:
|
6461
|
+
- Normality of each feature is assessed using `check_normality`.
|
6462
|
+
- ANOVA is used for normally distributed features.
|
6463
|
+
- Kruskal-Wallis is used for non-normal features.
|
6464
|
+
- Assumes `check_normality`, `scipy.stats.f_oneway`, and `scipy.stats.kruskal` are available.
|
6465
|
+
"""
|
4519
6466
|
numeric_features = all_df.select_dtypes(include=[np.number]).columns.tolist()
|
4520
6467
|
if cluster_col in numeric_features:
|
4521
6468
|
numeric_features.remove(cluster_col)
|
@@ -4539,15 +6486,47 @@ def perform_statistical_tests(all_df, cluster_col='cluster'):
|
|
4539
6486
|
return anova_df, kruskal_df
|
4540
6487
|
|
4541
6488
|
def combine_results(rf_df, anova_df, kruskal_df):
|
4542
|
-
"""
|
6489
|
+
"""
|
6490
|
+
Combines results from multiple DataFrames into a single DataFrame.
|
6491
|
+
|
6492
|
+
This function merges three DataFrames (`rf_df`, `anova_df`, and `kruskal_df`)
|
6493
|
+
on the 'Feature' column using a left join. The resulting DataFrame contains
|
6494
|
+
all features from `rf_df` and their corresponding data from `anova_df` and
|
6495
|
+
`kruskal_df` where available.
|
6496
|
+
|
6497
|
+
Args:
|
6498
|
+
rf_df (pd.DataFrame): A DataFrame containing features and their associated
|
6499
|
+
data from a random forest analysis.
|
6500
|
+
anova_df (pd.DataFrame): A DataFrame containing features and their associated
|
6501
|
+
data from an ANOVA analysis.
|
6502
|
+
kruskal_df (pd.DataFrame): A DataFrame containing features and their associated
|
6503
|
+
data from a Kruskal-Wallis analysis.
|
6504
|
+
|
6505
|
+
Returns:
|
6506
|
+
pd.DataFrame: A combined DataFrame with features and their associated data
|
6507
|
+
from all three input DataFrames.
|
6508
|
+
"""
|
4543
6509
|
combined_df = rf_df.merge(anova_df, on='Feature', how='left')
|
4544
6510
|
combined_df = combined_df.merge(kruskal_df, on='Feature', how='left')
|
4545
6511
|
return combined_df
|
4546
6512
|
|
4547
6513
|
def cluster_feature_analysis(all_df, cluster_col='cluster'):
|
4548
6514
|
"""
|
4549
|
-
Perform
|
4550
|
-
|
6515
|
+
Perform feature analysis for clustering by combining results from
|
6516
|
+
random forest feature importance and statistical tests.
|
6517
|
+
|
6518
|
+
This function calculates feature importance using a random forest model,
|
6519
|
+
performs statistical tests (ANOVA and Kruskal-Wallis) to assess the
|
6520
|
+
significance of features, and combines the results into a single DataFrame.
|
6521
|
+
|
6522
|
+
Args:
|
6523
|
+
all_df (pd.DataFrame): The input DataFrame containing features and cluster labels.
|
6524
|
+
cluster_col (str, optional): The name of the column representing cluster labels.
|
6525
|
+
Defaults to 'cluster'.
|
6526
|
+
|
6527
|
+
Returns:
|
6528
|
+
pd.DataFrame: A DataFrame combining the results of random forest feature
|
6529
|
+
importance and statistical tests for feature analysis.
|
4551
6530
|
"""
|
4552
6531
|
rf_df = random_forest_feature_importance(all_df, cluster_col)
|
4553
6532
|
anova_df, kruskal_df = perform_statistical_tests(all_df, cluster_col)
|
@@ -4657,7 +6636,7 @@ def _merge_cells_without_nucleus(adj_cell_mask: np.ndarray, nuclei_mask: np.ndar
|
|
4657
6636
|
Relabel any cell that lacks a nucleus to the ID of an adjacent
|
4658
6637
|
cell that *does* contain a nucleus.
|
4659
6638
|
|
4660
|
-
|
6639
|
+
Args
|
4661
6640
|
----------
|
4662
6641
|
adj_cell_mask : np.ndarray
|
4663
6642
|
Labelled (0 = background) cell mask after all other merging steps.
|
@@ -4723,6 +6702,26 @@ def _merge_cells_without_nucleus(adj_cell_mask: np.ndarray, nuclei_mask: np.ndar
|
|
4723
6702
|
return out.astype(np.uint16)
|
4724
6703
|
|
4725
6704
|
def process_mask_file_adjust_cell(file_name, parasite_folder, cell_folder, nuclei_folder, overlap_threshold, perimeter_threshold):
|
6705
|
+
"""
|
6706
|
+
Processes and adjusts a cell mask file based on parasite overlap and perimeter thresholds.
|
6707
|
+
|
6708
|
+
This function loads parasite, cell, and nuclei mask files, merges cells based on parasite overlap
|
6709
|
+
and perimeter thresholds, and saves the updated cell mask back to the file system.
|
6710
|
+
|
6711
|
+
Args:
|
6712
|
+
file_name (str): The name of the mask file to process.
|
6713
|
+
parasite_folder (str): The directory containing parasite mask files.
|
6714
|
+
cell_folder (str): The directory containing cell mask files.
|
6715
|
+
nuclei_folder (str): The directory containing nuclei mask files.
|
6716
|
+
overlap_threshold (float): The threshold for parasite overlap to merge cells.
|
6717
|
+
perimeter_threshold (float): The threshold for cell perimeter to merge cells.
|
6718
|
+
|
6719
|
+
Returns:
|
6720
|
+
float: The time taken to process the mask file, in seconds.
|
6721
|
+
|
6722
|
+
Raises:
|
6723
|
+
ValueError: If the corresponding cell or nuclei mask file for the given file_name is not found.
|
6724
|
+
"""
|
4726
6725
|
start = time.perf_counter()
|
4727
6726
|
|
4728
6727
|
parasite_path = os.path.join(parasite_folder, file_name)
|
@@ -4745,6 +6744,27 @@ def process_mask_file_adjust_cell(file_name, parasite_folder, cell_folder, nucle
|
|
4745
6744
|
return end - start
|
4746
6745
|
|
4747
6746
|
def adjust_cell_masks(parasite_folder, cell_folder, nuclei_folder, overlap_threshold=5, perimeter_threshold=30, n_jobs=None):
|
6747
|
+
"""
|
6748
|
+
Adjusts cell masks based on parasite, cell, and nuclei data files.
|
6749
|
+
This function processes `.npy` files from the specified folders to adjust cell masks
|
6750
|
+
by considering overlap and perimeter thresholds. It uses multiprocessing to parallelize
|
6751
|
+
the processing of files.
|
6752
|
+
Args:
|
6753
|
+
parasite_folder (str): Path to the folder containing parasite `.npy` files.
|
6754
|
+
cell_folder (str): Path to the folder containing cell `.npy` files.
|
6755
|
+
nuclei_folder (str): Path to the folder containing nuclei `.npy` files.
|
6756
|
+
overlap_threshold (int, optional): Threshold for overlap adjustment. Defaults to 5.
|
6757
|
+
perimeter_threshold (int, optional): Threshold for perimeter adjustment. Defaults to 30.
|
6758
|
+
n_jobs (int, optional): Number of parallel jobs to run. Defaults to the number of CPU cores minus 2.
|
6759
|
+
Raises:
|
6760
|
+
ValueError: If the number of files in the parasite, cell, and nuclei folders do not match.
|
6761
|
+
Notes:
|
6762
|
+
- The function assumes that the files in the folders are named in a way that allows
|
6763
|
+
them to be sorted and matched correctly.
|
6764
|
+
- Progress is printed to the console during processing.
|
6765
|
+
Returns:
|
6766
|
+
None
|
6767
|
+
"""
|
4748
6768
|
|
4749
6769
|
parasite_files = sorted([f for f in os.listdir(parasite_folder) if f.endswith('.npy')])
|
4750
6770
|
cell_files = sorted([f for f in os.listdir(cell_folder) if f.endswith('.npy')])
|
@@ -4822,6 +6842,25 @@ def adjust_cell_masks_v1(parasite_folder, cell_folder, nuclei_folder, overlap_th
|
|
4822
6842
|
print_progress(files_processed, files_to_process, n_jobs=1, time_ls=time_ls, batch_size=None, operation_type=f'adjust_cell_masks')
|
4823
6843
|
|
4824
6844
|
def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=2, plot=False):
|
6845
|
+
"""
|
6846
|
+
Processes mask files by measuring object properties, clustering objects, and removing objects
|
6847
|
+
not belonging to the largest cluster.
|
6848
|
+
Args:
|
6849
|
+
mask_folder (str): Path to the folder containing mask files (.npy format).
|
6850
|
+
image_folder (str): Path to the folder containing corresponding image files (.npy format).
|
6851
|
+
channel (int): The channel index to extract from the image files.
|
6852
|
+
batch_size (int, optional): Number of files to process in each batch. Defaults to 50.
|
6853
|
+
n_clusters (int, optional): Number of clusters for KMeans clustering. Defaults to 2.
|
6854
|
+
plot (bool, optional): Whether to plot the clustering results using PCA. Defaults to False.
|
6855
|
+
Returns:
|
6856
|
+
None: The function modifies the mask files in place by removing objects not in the largest cluster.
|
6857
|
+
Notes:
|
6858
|
+
- The mask files are expected to be in .npy format and contain labeled regions.
|
6859
|
+
- The image files are expected to be in .npy format and have the same names as the mask files.
|
6860
|
+
- The function assumes that the mask and image files are sorted in the same order.
|
6861
|
+
- The clustering is performed on accumulated object properties across all files.
|
6862
|
+
- The largest cluster is determined based on the number of objects in each cluster.
|
6863
|
+
"""
|
4825
6864
|
|
4826
6865
|
def read_files_in_batches(folder, batch_size=50):
|
4827
6866
|
files = [f for f in os.listdir(folder) if f.endswith('.npy')]
|
@@ -4895,6 +6934,25 @@ def process_masks(mask_folder, image_folder, channel, batch_size=50, n_clusters=
|
|
4895
6934
|
label_index += len(batch_properties)
|
4896
6935
|
|
4897
6936
|
def merge_regression_res_with_metadata(results_file, metadata_file, name='_metadata'):
|
6937
|
+
"""
|
6938
|
+
Merge regression results with metadata using gene identifiers.
|
6939
|
+
|
6940
|
+
Reads regression results and metadata from two CSV files, extracts gene identifiers,
|
6941
|
+
merges the data on gene names, and saves the merged DataFrame.
|
6942
|
+
|
6943
|
+
Args:
|
6944
|
+
results_file (str): Path to the regression results CSV. Must contain a 'feature' column.
|
6945
|
+
metadata_file (str): Path to the metadata CSV. Must contain a 'Gene ID' column.
|
6946
|
+
name (str, optional): Suffix for the output file name. Defaults to '_metadata'.
|
6947
|
+
|
6948
|
+
Returns:
|
6949
|
+
pandas.DataFrame: Merged DataFrame with regression results and metadata.
|
6950
|
+
|
6951
|
+
Notes:
|
6952
|
+
- Extracts gene names from 'feature' (format: '[gene]') and 'Gene ID' (format: 'prefix_gene').
|
6953
|
+
- Gene extraction failures result in NaNs in the merge but are not dropped.
|
6954
|
+
- Output CSV is saved next to the input `results_file`, with the suffix appended.
|
6955
|
+
"""
|
4898
6956
|
# Read the CSV files into dataframes
|
4899
6957
|
df_results = pd.read_csv(results_file)
|
4900
6958
|
df_metadata = pd.read_csv(metadata_file)
|
@@ -4932,6 +6990,19 @@ def merge_regression_res_with_metadata(results_file, metadata_file, name='_metad
|
|
4932
6990
|
return merged_df
|
4933
6991
|
|
4934
6992
|
def process_vision_results(df, threshold=0.5):
|
6993
|
+
"""
|
6994
|
+
Process vision results by extracting metadata from file paths and thresholding predictions.
|
6995
|
+
|
6996
|
+
Args:
|
6997
|
+
df (pd.DataFrame): DataFrame with vision results. Must include 'path' and 'pred' columns.
|
6998
|
+
threshold (float, optional): Threshold for classifying predictions. Defaults to 0.5.
|
6999
|
+
|
7000
|
+
Returns:
|
7001
|
+
pd.DataFrame: Modified DataFrame with added columns:
|
7002
|
+
- 'plateID', 'rowID', 'columnID', 'fieldID', 'object': extracted from 'path'
|
7003
|
+
- 'prc': combination of 'plateID', 'rowID', and 'columnID'
|
7004
|
+
- 'cv_predictions': binary classification based on `threshold`
|
7005
|
+
"""
|
4935
7006
|
|
4936
7007
|
# Split the 'path' column using _map_wells function
|
4937
7008
|
mapped_values = df['path'].apply(lambda x: _map_wells(x))
|
@@ -4947,6 +7018,50 @@ def process_vision_results(df, threshold=0.5):
|
|
4947
7018
|
return df
|
4948
7019
|
|
4949
7020
|
def get_ml_results_paths(src, model_type='xgboost', channel_of_interest=1):
|
7021
|
+
"""
|
7022
|
+
Generate file paths for machine learning result outputs based on model type and feature selection.
|
7023
|
+
|
7024
|
+
Args:
|
7025
|
+
src (str): Base directory where the results folder structure will be created.
|
7026
|
+
model_type (str, optional): Type of ML model used (e.g., 'xgboost', 'random_forest'). Defaults to 'xgboost'.
|
7027
|
+
channel_of_interest (int | list[int] | str | None, optional): Feature set specification:
|
7028
|
+
- int: Single channel (e.g., 1)
|
7029
|
+
- list[int]: Multiple channels (e.g., [1, 2, 3])
|
7030
|
+
- 'morphology': Use only morphology features
|
7031
|
+
- None: Use all features
|
7032
|
+
Defaults to 1.
|
7033
|
+
|
7034
|
+
Returns:
|
7035
|
+
tuple[str, ...]: Tuple of 10 file paths:
|
7036
|
+
- data_path: CSV with predictions or main results
|
7037
|
+
- permutation_path: CSV with permutation importances
|
7038
|
+
- feature_importance_path: CSV with model feature importances
|
7039
|
+
- model_metricks_path: CSV with metrics from trained model
|
7040
|
+
- permutation_fig_path: PDF plot of permutation importances
|
7041
|
+
- feature_importance_fig_path: PDF plot of feature importances
|
7042
|
+
- shap_fig_path: PDF SHAP summary plot
|
7043
|
+
- plate_heatmap_path: PDF visualization of plate layout
|
7044
|
+
- settings_csv: CSV with ML settings and parameters
|
7045
|
+
- ml_features: CSV with extracted feature data used for training
|
7046
|
+
|
7047
|
+
Raises:
|
7048
|
+
ValueError: If `channel_of_interest` is not an int, list, None, or 'morphology'.
|
7049
|
+
|
7050
|
+
Example:
|
7051
|
+
>>> get_ml_results_paths('/home/user/data', model_type='random_forest', channel_of_interest=[1, 2])
|
7052
|
+
(
|
7053
|
+
'/home/user/data/results/random_forest/channels_1_2/results.csv',
|
7054
|
+
'/home/user/data/results/random_forest/channels_1_2/permutation.csv',
|
7055
|
+
'/home/user/data/results/random_forest/channels_1_2/feature_importance.csv',
|
7056
|
+
'/home/user/data/results/random_forest/channels_1_2/random_forest_model.csv',
|
7057
|
+
'/home/user/data/results/random_forest/channels_1_2/permutation.pdf',
|
7058
|
+
'/home/user/data/results/random_forest/channels_1_2/feature_importance.pdf',
|
7059
|
+
'/home/user/data/results/random_forest/channels_1_2/shap.pdf',
|
7060
|
+
'/home/user/data/results/random_forest/channels_1_2/plate_heatmap.pdf',
|
7061
|
+
'/home/user/data/results/random_forest/channels_1_2/ml_settings.csv',
|
7062
|
+
'/home/user/data/results/random_forest/channels_1_2/ml_features.csv'
|
7063
|
+
)
|
7064
|
+
"""
|
4950
7065
|
|
4951
7066
|
if isinstance(channel_of_interest, list):
|
4952
7067
|
feature_string = "channels_" + "_".join(map(str, channel_of_interest))
|
@@ -4981,7 +7096,7 @@ def augment_image(image):
|
|
4981
7096
|
"""
|
4982
7097
|
Perform data augmentation by rotating and reflecting the image.
|
4983
7098
|
|
4984
|
-
|
7099
|
+
Args:
|
4985
7100
|
- image (PIL Image or numpy array): The input image.
|
4986
7101
|
|
4987
7102
|
Returns:
|
@@ -5025,7 +7140,7 @@ def augment_dataset(dataset, is_grayscale=False):
|
|
5025
7140
|
"""
|
5026
7141
|
Perform data augmentation on the entire dataset by rotating and reflecting the images.
|
5027
7142
|
|
5028
|
-
|
7143
|
+
Args:
|
5029
7144
|
- dataset (list of tuples): The input dataset, each entry is a tuple (image, label, filename).
|
5030
7145
|
- is_grayscale (bool): Flag indicating if the images are grayscale.
|
5031
7146
|
|
@@ -5063,7 +7178,7 @@ def convert_and_relabel_masks(folder_path):
|
|
5063
7178
|
"""
|
5064
7179
|
Converts all int64 npy masks in a folder to uint16 with relabeling to ensure all labels are retained.
|
5065
7180
|
|
5066
|
-
|
7181
|
+
Args:
|
5067
7182
|
- folder_path (str): The path to the folder containing int64 npy mask files.
|
5068
7183
|
|
5069
7184
|
Returns:
|
@@ -5104,6 +7219,20 @@ def convert_and_relabel_masks(folder_path):
|
|
5104
7219
|
print(f"Converted {file} and saved as uint16_{file}")
|
5105
7220
|
|
5106
7221
|
def correct_masks(src):
|
7222
|
+
"""
|
7223
|
+
Corrects and processes mask files located in the specified source directory.
|
7224
|
+
|
7225
|
+
This function performs the following steps:
|
7226
|
+
1. Constructs the file path for the cell mask stack within the 'masks' subdirectory of the source.
|
7227
|
+
2. Converts and relabels the masks using the `convert_and_relabel_masks` function.
|
7228
|
+
3. Loads and concatenates arrays from the source directory using the `_load_and_concatenate_arrays` function.
|
7229
|
+
|
7230
|
+
Args:
|
7231
|
+
src (str): The path to the source directory containing the mask files.
|
7232
|
+
|
7233
|
+
Returns:
|
7234
|
+
None
|
7235
|
+
"""
|
5107
7236
|
|
5108
7237
|
from .io import _load_and_concatenate_arrays
|
5109
7238
|
|
@@ -5112,15 +7241,35 @@ def correct_masks(src):
|
|
5112
7241
|
_load_and_concatenate_arrays(src, [0,1,2,3], 1, 0, 2)
|
5113
7242
|
|
5114
7243
|
def count_reads_in_fastq(fastq_file):
|
7244
|
+
"""
|
7245
|
+
Counts the number of reads in a FASTQ file.
|
7246
|
+
|
7247
|
+
A FASTQ file contains sequencing reads, where each read is represented
|
7248
|
+
by four lines: a header, the sequence, a separator, and the quality scores.
|
7249
|
+
This function calculates the total number of reads by dividing the total
|
7250
|
+
number of lines in the file by 4.
|
7251
|
+
|
7252
|
+
Args:
|
7253
|
+
fastq_file (str): Path to the FASTQ file, which can be gzip-compressed.
|
7254
|
+
|
7255
|
+
Returns:
|
7256
|
+
int: The number of reads in the FASTQ file.
|
7257
|
+
"""
|
5115
7258
|
count = 0
|
5116
7259
|
with gzip.open(fastq_file, "rt") as f:
|
5117
7260
|
for _ in f:
|
5118
7261
|
count += 1
|
5119
7262
|
return count // 4
|
5120
7263
|
|
5121
|
-
|
5122
7264
|
# Function to determine the CUDA version
|
5123
7265
|
def get_cuda_version():
|
7266
|
+
"""
|
7267
|
+
Retrieves the installed CUDA version by invoking the `nvcc --version` command.
|
7268
|
+
|
7269
|
+
Returns:
|
7270
|
+
str: The CUDA version as a string with dots removed (e.g., '110' for version 11.0),
|
7271
|
+
or None if the `nvcc` command is not found or an error occurs.
|
7272
|
+
"""
|
5124
7273
|
try:
|
5125
7274
|
output = subprocess.check_output(['nvcc', '--version'], stderr=subprocess.STDOUT).decode('utf-8')
|
5126
7275
|
if 'release' in output:
|
@@ -5129,10 +7278,36 @@ def get_cuda_version():
|
|
5129
7278
|
return None
|
5130
7279
|
|
5131
7280
|
def all_elements_match(list1, list2):
|
7281
|
+
"""
|
7282
|
+
Check if all elements in the first list are present in the second list.
|
7283
|
+
|
7284
|
+
Args:
|
7285
|
+
list1 (list): The first list containing elements to check.
|
7286
|
+
list2 (list): The second list to check against.
|
7287
|
+
|
7288
|
+
Returns:
|
7289
|
+
bool: True if all elements in list1 are found in list2, False otherwise.
|
7290
|
+
"""
|
5132
7291
|
# Check if all elements in list1 are in list2
|
5133
7292
|
return all(element in list2 for element in list1)
|
5134
7293
|
|
5135
7294
|
def prepare_batch_for_segmentation(batch):
|
7295
|
+
"""
|
7296
|
+
Prepare a batch of images for segmentation by ensuring correct data type and normalization.
|
7297
|
+
|
7298
|
+
Args:
|
7299
|
+
batch (np.ndarray): Batch of images with shape (N, H, W, C), where:
|
7300
|
+
- N: number of images
|
7301
|
+
- H, W: height and width
|
7302
|
+
- C: number of channels (e.g., 1 for grayscale, 3 for RGB)
|
7303
|
+
|
7304
|
+
Returns:
|
7305
|
+
np.ndarray: Batch with dtype `float32`, normalized to [0, 1] if needed.
|
7306
|
+
|
7307
|
+
Notes:
|
7308
|
+
- Converts to `float32` if not already.
|
7309
|
+
- Each image is divided by its own maximum pixel value if that value > 1.
|
7310
|
+
"""
|
5136
7311
|
# Ensure the batch is of dtype float32
|
5137
7312
|
if batch.dtype != np.float32:
|
5138
7313
|
batch = batch.astype(np.float32)
|
@@ -5145,6 +7320,25 @@ def prepare_batch_for_segmentation(batch):
|
|
5145
7320
|
return batch
|
5146
7321
|
|
5147
7322
|
def check_index(df, elements=5, split_char='_'):
|
7323
|
+
"""
|
7324
|
+
Checks the indices of a DataFrame to ensure they can be split into a specified number of parts.
|
7325
|
+
|
7326
|
+
Args:
|
7327
|
+
df (pandas.DataFrame): The DataFrame whose indices are to be checked.
|
7328
|
+
elements (int, optional): The expected number of parts after splitting an index. Defaults to 5.
|
7329
|
+
split_char (str, optional): The character used to split the index. Defaults to '_'.
|
7330
|
+
|
7331
|
+
Raises:
|
7332
|
+
ValueError: If any index cannot be split into the specified number of parts,
|
7333
|
+
a ValueError is raised listing the problematic indices.
|
7334
|
+
|
7335
|
+
Example:
|
7336
|
+
>>> import pandas as pd
|
7337
|
+
>>> data = {'col1': [1, 2, 3]}
|
7338
|
+
>>> df = pd.DataFrame(data, index=['a_b_c_d_e', 'f_g_h_i', 'j_k_l_m_n'])
|
7339
|
+
>>> check_index(df)
|
7340
|
+
ValueError: Found 1 problematic indices that do not split into 5 parts.
|
7341
|
+
"""
|
5148
7342
|
problematic_indices = []
|
5149
7343
|
for idx in df.index:
|
5150
7344
|
parts = str(idx).split(split_char)
|
@@ -5158,6 +7352,22 @@ def check_index(df, elements=5, split_char='_'):
|
|
5158
7352
|
|
5159
7353
|
# Define the mapping function
|
5160
7354
|
def map_condition(col_value, neg='c1', pos='c2', mix='c3'):
|
7355
|
+
"""
|
7356
|
+
Maps a given column value to a specific condition label.
|
7357
|
+
|
7358
|
+
Args:
|
7359
|
+
col_value (str): The value to be mapped.
|
7360
|
+
neg (str, optional): The value representing the 'neg' condition. Defaults to 'c1'.
|
7361
|
+
pos (str, optional): The value representing the 'pos' condition. Defaults to 'c2'.
|
7362
|
+
mix (str, optional): The value representing the 'mix' condition. Defaults to 'c3'.
|
7363
|
+
|
7364
|
+
Returns:
|
7365
|
+
str: A string representing the mapped condition:
|
7366
|
+
- 'neg' if col_value matches the neg parameter.
|
7367
|
+
- 'pos' if col_value matches the pos parameter.
|
7368
|
+
- 'mix' if col_value matches the mix parameter.
|
7369
|
+
- 'screen' if col_value does not match any of the above.
|
7370
|
+
"""
|
5161
7371
|
if col_value == neg:
|
5162
7372
|
return 'neg'
|
5163
7373
|
elif col_value == pos:
|
@@ -5236,7 +7446,7 @@ def generate_cytoplasm_mask(nucleus_mask, cell_mask):
|
|
5236
7446
|
"""
|
5237
7447
|
Generates a cytoplasm mask from nucleus and cell masks.
|
5238
7448
|
|
5239
|
-
|
7449
|
+
Args:
|
5240
7450
|
- nucleus_mask (np.array): Binary or segmented mask of the nucleus (non-zero values represent nucleus).
|
5241
7451
|
- cell_mask (np.array): Binary or segmented mask of the whole cell (non-zero values represent cell).
|
5242
7452
|
|
@@ -5259,7 +7469,7 @@ def add_column_to_database(settings):
|
|
5259
7469
|
If the column already exists in the database, it adds the column with a suffix.
|
5260
7470
|
NaN values will remain as NULL in the database.
|
5261
7471
|
|
5262
|
-
|
7472
|
+
Args:
|
5263
7473
|
settings (dict): A dictionary containing the following keys:
|
5264
7474
|
csv_path (str): Path to the CSV file with the data to be added.
|
5265
7475
|
db_path (str): Path to the SQLite database (or connection string for other databases).
|
@@ -5355,6 +7565,24 @@ def fill_holes_in_mask(mask):
|
|
5355
7565
|
return filled_mask
|
5356
7566
|
|
5357
7567
|
def correct_metadata_column_names(df):
|
7568
|
+
"""
|
7569
|
+
Standardize column names in a metadata DataFrame.
|
7570
|
+
|
7571
|
+
This function renames commonly used but inconsistent metadata columns to a standardized format.
|
7572
|
+
|
7573
|
+
Renaming rules:
|
7574
|
+
- 'plate_name' → 'plateID'
|
7575
|
+
- 'column_name' or 'col' → 'columnID'
|
7576
|
+
- 'row_name' → 'rowID'
|
7577
|
+
- 'grna_name' → 'grna'
|
7578
|
+
- If 'plate_row' exists, it is split into 'plateID' and 'rowID' using '_' as a delimiter.
|
7579
|
+
|
7580
|
+
Args:
|
7581
|
+
df (pd.DataFrame): Input DataFrame with metadata columns.
|
7582
|
+
|
7583
|
+
Returns:
|
7584
|
+
pd.DataFrame: DataFrame with standardized column names.
|
7585
|
+
"""
|
5358
7586
|
if 'plate_name' in df.columns:
|
5359
7587
|
df = df.rename(columns={'plate_name': 'plateID'})
|
5360
7588
|
if 'column_name' in df.columns:
|
@@ -5370,6 +7598,27 @@ def correct_metadata_column_names(df):
|
|
5370
7598
|
return df
|
5371
7599
|
|
5372
7600
|
def control_filelist(folder, mode='columnID', values=['01','02']):
|
7601
|
+
def control_filelist(folder, mode='columnID', values=['01', '02']):
|
7602
|
+
"""
|
7603
|
+
Filters a list of files in a given folder based on a specified mode and values.
|
7604
|
+
|
7605
|
+
Args:
|
7606
|
+
folder (str): The path to the folder containing the files to be filtered.
|
7607
|
+
mode (str, optional): The filtering mode. Can be 'columnID' or 'rowID'.
|
7608
|
+
- 'columnID': Filters files where the second part of the filename (split by '_')
|
7609
|
+
has a substring starting from the second character that matches one of the values.
|
7610
|
+
- 'rowID': Filters files where the second part of the filename (split by '_')
|
7611
|
+
has a substring starting from the first character that matches one of the values.
|
7612
|
+
Defaults to 'columnID'.
|
7613
|
+
values (list of str, optional): A list of string values to filter the files by. Defaults to ['01', '02'].
|
7614
|
+
|
7615
|
+
Returns:
|
7616
|
+
list of str: A list of filenames that match the filtering criteria.
|
7617
|
+
|
7618
|
+
Raises:
|
7619
|
+
IndexError: If the filenames in the folder do not conform to the expected format
|
7620
|
+
(e.g., missing the second part after splitting by '_').
|
7621
|
+
"""
|
5373
7622
|
files = os.listdir(folder)
|
5374
7623
|
if mode == 'columnID':
|
5375
7624
|
filtered_files = [file for file in files if file.split('_')[1][1:] in values]
|
@@ -5378,6 +7627,33 @@ def control_filelist(folder, mode='columnID', values=['01','02']):
|
|
5378
7627
|
return filtered_files
|
5379
7628
|
|
5380
7629
|
def rename_columns_in_db(db_path):
|
7630
|
+
"""
|
7631
|
+
Renames specific columns in all user tables of a SQLite database based on a predefined mapping.
|
7632
|
+
|
7633
|
+
This function connects to the SQLite database at the given path, retrieves all user-defined tables,
|
7634
|
+
and renames columns in those tables according to the `rename_map` dictionary. If a column with the
|
7635
|
+
old name exists and the new name does not already exist in the same table, the column is renamed.
|
7636
|
+
|
7637
|
+
Args:
|
7638
|
+
db_path (str): The file path to the SQLite database.
|
7639
|
+
|
7640
|
+
Behavior:
|
7641
|
+
- Retrieves all user-defined tables in the database.
|
7642
|
+
- For each table, checks the column names against the `rename_map`.
|
7643
|
+
- Renames columns as specified in the `rename_map` if conditions are met.
|
7644
|
+
- Commits the changes to the database.
|
7645
|
+
|
7646
|
+
Notes:
|
7647
|
+
- The `rename_map` dictionary defines the mapping of old column names to new column names.
|
7648
|
+
- If a column with the new name already exists in a table, the old column will not be renamed.
|
7649
|
+
- The function uses SQLite's `ALTER TABLE ... RENAME COLUMN` syntax, which requires SQLite version 3.25.0 or higher.
|
7650
|
+
|
7651
|
+
Example:
|
7652
|
+
rename_columns_in_db("/path/to/database.db")
|
7653
|
+
|
7654
|
+
Raises:
|
7655
|
+
sqlite3.OperationalError: If there are issues executing SQL commands, such as unsupported SQLite versions.
|
7656
|
+
"""
|
5381
7657
|
# map old column names → new names
|
5382
7658
|
rename_map = {
|
5383
7659
|
'row': 'rowID',
|
@@ -5411,6 +7687,29 @@ def rename_columns_in_db(db_path):
|
|
5411
7687
|
con.close()
|
5412
7688
|
|
5413
7689
|
def group_feature_class(df, feature_groups=['cell', 'cytoplasm', 'nucleus', 'pathogen'], name='compartment'):
|
7690
|
+
"""
|
7691
|
+
Classify and group features by category, then compute summed importance for each group.
|
7692
|
+
|
7693
|
+
This function adds a new column to the DataFrame to classify features based on regex matching
|
7694
|
+
against the given `feature_groups`. It then computes the total importance per group.
|
7695
|
+
|
7696
|
+
Args:
|
7697
|
+
df (pd.DataFrame): Input DataFrame with at least the columns `'feature'` and `'importance'`.
|
7698
|
+
feature_groups (list of str, optional): List of feature group identifiers (used in regex matching).
|
7699
|
+
Defaults to ['cell', 'cytoplasm', 'nucleus', 'pathogen'].
|
7700
|
+
name (str, optional): Name of the new column used to store group classifications.
|
7701
|
+
If set to `'channel'`, missing values in that column will be filled with `'morphology'`.
|
7702
|
+
Defaults to 'compartment'.
|
7703
|
+
|
7704
|
+
Returns:
|
7705
|
+
pd.DataFrame: Modified DataFrame including:
|
7706
|
+
- A new column with group classifications (based on `name` argument).
|
7707
|
+
- A summary with summed importance for each group and a row for total importance.
|
7708
|
+
|
7709
|
+
Notes:
|
7710
|
+
- A feature matching multiple groups will have their labels joined with a hyphen (e.g., "cell-nucleus").
|
7711
|
+
- Classification is done via regex search on the `'feature'` column.
|
7712
|
+
"""
|
5414
7713
|
|
5415
7714
|
# Function to determine compartment based on multiple matches
|
5416
7715
|
def find_feature_class(feature, compartments):
|
@@ -5441,6 +7740,33 @@ def group_feature_class(df, feature_groups=['cell', 'cytoplasm', 'nucleus', 'pat
|
|
5441
7740
|
return df
|
5442
7741
|
|
5443
7742
|
def delete_intermedeate_files(settings):
|
7743
|
+
"""
|
7744
|
+
Delete intermediate files and directories specified in the settings dictionary.
|
7745
|
+
|
7746
|
+
This function removes a predefined set of subdirectories under the path given by
|
7747
|
+
`settings['src']`. If deletion fails, an error message is printed.
|
7748
|
+
|
7749
|
+
Args:
|
7750
|
+
settings (dict): Must include the key `'src'`, the base directory containing
|
7751
|
+
intermediate files and subdirectories.
|
7752
|
+
|
7753
|
+
Behavior:
|
7754
|
+
- Verifies that `'src'` and its `orig/` subdirectory exist.
|
7755
|
+
- Deletes the following under `src` if they exist:
|
7756
|
+
* 'stack'
|
7757
|
+
* 'masks'
|
7758
|
+
* directories '1' through '10'
|
7759
|
+
- Compares lengths of `merged_stack` and `path_stack` to decide on deletion.
|
7760
|
+
- Prints success or error messages for each directory.
|
7761
|
+
|
7762
|
+
Notes:
|
7763
|
+
- If `src` or `orig/` is missing, the function exits early.
|
7764
|
+
- Any deletion exceptions are caught and reported; no exceptions are raised.
|
7765
|
+
|
7766
|
+
Example:
|
7767
|
+
>>> settings = {'src': '/path/to/source'}
|
7768
|
+
>>> delete_intermedeate_files(settings)
|
7769
|
+
"""
|
5444
7770
|
|
5445
7771
|
path_orig = os.path.join(settings['src'], 'orig')
|
5446
7772
|
path_stack = os.path.join(settings['src'], 'stack')
|
@@ -5485,7 +7811,7 @@ def filter_and_save_csv(input_csv, output_csv, column_name, upper_threshold, low
|
|
5485
7811
|
Reads a CSV into a DataFrame, filters rows based on a column for values > upper_threshold and < lower_threshold,
|
5486
7812
|
and saves the filtered DataFrame to a new CSV file.
|
5487
7813
|
|
5488
|
-
|
7814
|
+
Args:
|
5489
7815
|
input_csv (str): Path to the input CSV file.
|
5490
7816
|
output_csv (str): Path to save the filtered CSV file.
|
5491
7817
|
column_name (str): Column name to apply the filters on.
|
@@ -5511,7 +7837,7 @@ def extract_tar_bz2_files(folder_path):
|
|
5511
7837
|
"""
|
5512
7838
|
Extracts all .tar.bz2 files in the given folder into subfolders with the same name as the tar file.
|
5513
7839
|
|
5514
|
-
|
7840
|
+
Args:
|
5515
7841
|
folder_path (str): Path to the folder containing .tar.bz2 files.
|
5516
7842
|
"""
|
5517
7843
|
if not os.path.isdir(folder_path):
|
@@ -5539,7 +7865,7 @@ def calculate_shortest_distance(df, object1, object2):
|
|
5539
7865
|
"""
|
5540
7866
|
Calculate the shortest edge-to-edge distance between two objects (e.g., pathogen and nucleus).
|
5541
7867
|
|
5542
|
-
|
7868
|
+
Args:
|
5543
7869
|
- df: Pandas DataFrame containing measurements
|
5544
7870
|
- object1: String, name of the first object (e.g., "pathogen")
|
5545
7871
|
- object2: String, name of the second object (e.g., "nucleus")
|
@@ -5694,6 +8020,22 @@ def copy_images_to_consolidated(image_path_map, root_folder):
|
|
5694
8020
|
#print(f"Copied: {original_path} -> {new_file_path}")
|
5695
8021
|
|
5696
8022
|
def correct_metadata(df):
|
8023
|
+
"""
|
8024
|
+
Corrects and standardizes the metadata column names in a DataFrame.
|
8025
|
+
This function checks for specific column names in the input DataFrame and
|
8026
|
+
renames or maps them to standardized names for consistency. The following
|
8027
|
+
transformations are applied:
|
8028
|
+
- If 'object_name' exists, it is mapped to 'objectID'.
|
8029
|
+
- If 'field_name' exists, it is mapped to 'fieldID'.
|
8030
|
+
- If 'plate' or 'plate_name' exists, they are mapped to 'plateID'.
|
8031
|
+
- If 'row' or 'row_name' exists, they are renamed to 'rowID'.
|
8032
|
+
- If 'col', 'column', or 'column_name' exists, they are renamed to 'columnID'.
|
8033
|
+
- If 'field' or 'field_name' exists, they are renamed to 'fieldID'.
|
8034
|
+
Args:
|
8035
|
+
df (pandas.DataFrame): The input DataFrame containing metadata columns.
|
8036
|
+
Returns:
|
8037
|
+
pandas.DataFrame: The DataFrame with standardized metadata column names.
|
8038
|
+
"""
|
5697
8039
|
|
5698
8040
|
#if 'object' in df.columns:
|
5699
8041
|
# df['objectID'] = df['object']
|
@@ -5737,7 +8079,7 @@ def remove_outliers_by_group(df, group_col, value_col, method='iqr', threshold=1
|
|
5737
8079
|
"""
|
5738
8080
|
Removes outliers from `value_col` within each group defined by `group_col`.
|
5739
8081
|
|
5740
|
-
|
8082
|
+
Args:
|
5741
8083
|
df (pd.DataFrame): The input DataFrame.
|
5742
8084
|
group_col (str): Column name to group by.
|
5743
8085
|
value_col (str): Column containing values to check for outliers.
|