spacr 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
spacr/utils.py ADDED
@@ -0,0 +1,2726 @@
1
+ import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob
2
+
3
+ import numpy as np
4
+ from skimage import morphology
5
+ from skimage.measure import label, regionprops_table, regionprops
6
+ import skimage.measure as measure
7
+ from collections import defaultdict
8
+ from PIL import Image
9
+ import pandas as pd
10
+ from statsmodels.stats.outliers_influence import variance_inflation_factor
11
+ from statsmodels.stats.stattools import durbin_watson
12
+ import statsmodels.formula.api as smf
13
+ import statsmodels.api as sm
14
+ from statsmodels.stats.multitest import multipletests
15
+ from itertools import combinations
16
+ from collections import OrderedDict
17
+ from functools import reduce
18
+ from IPython.display import display, clear_output
19
+ from multiprocessing import Pool, cpu_count
20
+ from skimage.transform import resize as resizescikit
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ #from torchsummary import summary
24
+ from torch.utils.checkpoint import checkpoint
25
+ from torch.utils.data import Subset
26
+ from torch.autograd import grad
27
+ from torchvision import models
28
+ from skimage.segmentation import clear_border
29
+ import seaborn as sns
30
+ import matplotlib.pyplot as plt
31
+ import scipy.ndimage as ndi
32
+ from scipy.stats import fisher_exact
33
+ from scipy.ndimage import binary_erosion, binary_dilation
34
+ from skimage.exposure import rescale_intensity
35
+ from sklearn.metrics import auc, precision_recall_curve
36
+ from sklearn.model_selection import train_test_split
37
+ from sklearn.linear_model import Lasso, Ridge
38
+ from sklearn.preprocessing import OneHotEncoder
39
+ from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights
40
+
41
+ from .logger import log_function_call
42
+
43
+ #from .io import _read_and_join_tables, _save_figure
44
+ #from .timelapse import _btrack_track_cells, _trackpy_track_cells
45
+ #from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot, plot_masks
46
+ #from .core import identify_masks
47
+
48
+ def _convert_cq1_well_id(well_id):
49
+ """
50
+ Converts a well ID to the CQ1 well format.
51
+
52
+ Args:
53
+ well_id (int): The well ID to be converted.
54
+
55
+ Returns:
56
+ str: The well ID in CQ1 well format.
57
+
58
+ """
59
+ well_id = int(well_id)
60
+ # ASCII code for 'A'
61
+ ascii_A = ord('A')
62
+ # Calculate row and column
63
+ row, col = divmod(well_id - 1, 24)
64
+ # Convert row to letter (A-P) and adjust col to start from 1
65
+ row_letter = chr(ascii_A + row)
66
+ # Format column as two digits
67
+ well_format = f"{row_letter}{col + 1:02d}"
68
+ return well_format
69
+
70
+ def _get_cellpose_batch_size():
71
+ try:
72
+ # Check if CUDA is available
73
+ if torch.cuda.is_available():
74
+ device_properties = torch.cuda.get_device_properties(0)
75
+ vram_gb = device_properties.total_memory / (1024**3) # Convert bytes to gigabytes
76
+ else:
77
+ print("CUDA is not available. Please check your installation and GPU.")
78
+ return 8
79
+ if vram_gb < 8:
80
+ batch_size = 8
81
+ elif vram_gb > 8 and vram_gb < 12:
82
+ batch_size = 16
83
+ elif vram_gb > 12 and vram_gb < 24:
84
+ batch_size = 48
85
+ elif vram_gb > 24:
86
+ batch_size = 96
87
+ print(f"Device {0}: {device_properties.name}, VRAM: {vram_gb:.2f} GB, cellpose batch size: {batch_size}")
88
+ return batch_size
89
+ except Exception as e:
90
+ return 8
91
+
92
+ def _extract_filename_metadata(filenames, src, images_by_key, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
93
+ for filename in filenames:
94
+ match = regular_expression.match(filename)
95
+ if match:
96
+ try:
97
+ try:
98
+ plate = match.group('plateID')
99
+ except:
100
+ plate = os.path.basename(src)
101
+
102
+ well = match.group('wellID')
103
+ field = match.group('fieldID')
104
+ channel = match.group('chanID')
105
+ mode = None
106
+
107
+ if well[0].isdigit():
108
+ well = str(_safe_int_convert(well))
109
+ if field[0].isdigit():
110
+ field = str(_safe_int_convert(field))
111
+ if channel[0].isdigit():
112
+ channel = str(_safe_int_convert(channel))
113
+
114
+ if metadata_type =='cq1':
115
+ orig_wellID = wellID
116
+ wellID = _convert_cq1_well_id(wellID)
117
+ clear_output(wait=True)
118
+ print(f'\033[KConverted Well ID: {orig_wellID} to {wellID}', end='\r', flush=True)
119
+
120
+ if pick_slice:
121
+ try:
122
+ mode = match.group('AID')
123
+ except IndexError:
124
+ sliceid = '00'
125
+
126
+ if mode == skip_mode:
127
+ continue
128
+
129
+ key = (plate, well, field, channel, mode)
130
+ with Image.open(os.path.join(src, filename)) as img:
131
+ images_by_key[key].append(np.array(img))
132
+ except IndexError:
133
+ print(f"Could not extract information from filename {filename} using provided regex")
134
+ else:
135
+ print(f"Filename {filename} did not match provided regex")
136
+ continue
137
+
138
+ return images_by_key
139
+
140
def mask_object_count(mask):
    """
    Count the distinct non-zero labels present in a mask.

    Parameters:
    - mask: numpy.ndarray
        The mask containing object labels; 0 is treated as background.

    Returns:
    - int
        The number of distinct labels other than 0.
    """
    distinct_labels = np.unique(mask)
    return np.count_nonzero(distinct_labels)
155
+
156
+ def _update_database_with_merged_info(db_path, df, table='png_list', columns=['pathogen', 'treatment', 'host_cells', 'condition', 'prcfo']):
157
+ """
158
+ Merges additional columns into the png_list table in the SQLite database and updates it.
159
+
160
+ Args:
161
+ db_path (str): The path to the SQLite database file.
162
+ df (pd.DataFrame): DataFrame containing the additional info to be merged.
163
+ table (str): Name of the table to update in the database. Defaults to 'png_list'.
164
+ """
165
+ # Connect to the SQLite database
166
+ conn = sqlite3.connect(db_path)
167
+
168
+ # Read the existing table into a DataFrame
169
+ try:
170
+ existing_df = pd.read_sql(f"SELECT * FROM {table}", conn)
171
+ except Exception as e:
172
+ print(f"Failed to read table {table} from database: {e}")
173
+ conn.close()
174
+ return
175
+
176
+ if 'prcfo' not in df.columns:
177
+ print(f'generating prcfo columns')
178
+ try:
179
+ df['prcfo'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str) + '_' + df['field'].astype(str) + '_o' + df['object_label'].astype(int).astype(str)
180
+ except Exception as e:
181
+ print('Merging on cell failed, trying with cell_id')
182
+ try:
183
+ df['prcfo'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str) + '_' + df['field'].astype(str) + '_o' + df['cell_id'].astype(int).astype(str)
184
+ except Exception as e:
185
+ print(e)
186
+
187
+ # Merge the existing DataFrame with the new info based on the 'prcfo' column
188
+ merged_df = pd.merge(existing_df, df[columns], on='prcfo', how='left')
189
+
190
+ # Drop the existing table and replace it with the updated DataFrame
191
+ try:
192
+ conn.execute(f"DROP TABLE IF EXISTS {table}")
193
+ merged_df.to_sql(table, conn, index=False)
194
+ print(f"Table {table} successfully updated in the database.")
195
+ except Exception as e:
196
+ print(f"Failed to update table {table} in the database: {e}")
197
+ finally:
198
+ conn.close()
199
+
200
def _generate_representative_images(db_path, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None, channel_of_interest=1, compartments = ['pathogen','cytoplasm'], measurement = 'mean_intensity', nr_imgs=16, channel_indices=[0,1,2], um_per_pixel=0.1, scale_bar_length_um=10, plot=False, fontsize=12, show_filename=True, channel_names=None, update_db=True):
    """
    Save image grids of the objects closest to the typical value of each condition.

    For every condition, the rows selected by _filter_closest_to_stat on
    'new_measurement' are plotted once with all requested channels and once per
    individual channel; every figure is saved next to the database file.

    Args:
        db_path (str): The path to the SQLite database file.
        cells (list, optional): The list of host cell types. Defaults to ['HeLa'].
        cell_loc (list, optional): The list of location identifiers for host cells. Defaults to None.
        pathogens (list, optional): The list of pathogens. Defaults to ['rh'].
        pathogen_loc (list, optional): The list of location identifiers for pathogens. Defaults to None.
        treatments (list, optional): The list of treatments. Defaults to ['cm'].
        treatment_loc (list, optional): The list of location identifiers for treatments. Defaults to None.
        channel_of_interest (int, optional): The index of the channel of interest. Defaults to 1.
        compartments (list or str, optional): With a list of at least two compartments, objects are
            ranked by the ratio first/second of the chosen measurement; otherwise 'cell_area' is used.
            Defaults to ['pathogen', 'cytoplasm'].
        measurement (str, optional): The measurement to compare. Defaults to 'mean_intensity'.
        nr_imgs (int, optional): The number of representative images per condition. Defaults to 16.
        channel_indices (list, optional): The indices of the channels to include. Defaults to [0, 1, 2].
        um_per_pixel (float, optional): The scale factor for converting pixels to micrometers. Defaults to 0.1.
        scale_bar_length_um (float, optional): The length of the scale bar in micrometers. Defaults to 10.
        plot (bool, optional): Forwarded to _plot_images_on_grid. Defaults to False.
        fontsize (int, optional): The font size for the plot. Defaults to 12.
        show_filename (bool, optional): Whether to show the filename on the plot. Defaults to True.
        channel_names (list, optional): The names of the channels. Defaults to None.
        update_db (bool, optional): Whether to write the condition annotations back to the
            'png_list' table. Defaults to True.

    Returns:
        None
    """

    from .io import _read_and_join_tables, _save_figure
    from .plot import _plot_images_on_grid

    df = _read_and_join_tables(db_path)
    df = _annotate_conditions(df, cells, cell_loc, pathogens, pathogen_loc, treatments,treatment_loc)

    if update_db:
        _update_database_with_merged_info(db_path, df, table='png_list', columns=['pathogen', 'treatment', 'host_cells', 'condition', 'prcfo'])

    # A ratio needs two compartments; any other input (a string or a one-element
    # list) falls back to cell area so 'new_measurement' is always defined.
    if isinstance(compartments, list) and len(compartments) > 1:
        numerator = f'{compartments[0]}_channel_{channel_of_interest}_{measurement}'
        denominator = f'{compartments[1]}_channel_{channel_of_interest}_{measurement}'
        df['new_measurement'] = df[numerator]/df[denominator]
    else:
        df['new_measurement'] = df['cell_area']

    src = os.path.dirname(db_path)
    if src:
        os.makedirs(src, exist_ok=True)

    # Keep the requested channel list intact for every condition. Rebinding
    # channel_indices inside the loop would make all conditions after the first
    # plot only the last channel.
    all_channels = list(channel_indices)

    dfs = {condition: df_group for condition, df_group in df.groupby('condition')}
    conditions = df['condition'].dropna().unique().tolist()
    for condition in conditions:
        condition_df = _filter_closest_to_stat(dfs[condition], column='new_measurement', n_rows=nr_imgs, use_median=False)
        png_paths_by_condition = condition_df['png_path'].tolist()

        fig = _plot_images_on_grid(png_paths_by_condition, all_channels, um_per_pixel, scale_bar_length_um, fontsize, show_filename, channel_names, plot)
        _save_figure(fig=fig, src=src, text=condition)
        plt.close()

        for channel in all_channels:
            fig = _plot_images_on_grid(png_paths_by_condition, [channel], um_per_pixel, scale_bar_length_um, fontsize, show_filename, channel_names, plot)
            _save_figure(fig, src, text=f'channel_{channel}_{condition}')
            # Close after every save so figures do not pile up across channels.
            plt.close()
257
+
258
+ # Adjusted mapping function to infer type from location identifiers
259
+ def _map_values(row, values, locs):
260
+ """
261
+ Maps values to a specific location in the row or column based on the given locs.
262
+
263
+ Args:
264
+ row (dict): The row dictionary containing the location identifier.
265
+ values (list): The list of values to be mapped.
266
+ locs (list): The list of location identifiers.
267
+
268
+ Returns:
269
+ The mapped value corresponding to the given row or column location, or None if not found.
270
+ """
271
+ if locs:
272
+ value_dict = {loc: value for value, loc_list in zip(values, locs) for loc in loc_list}
273
+ # Determine if we're dealing with row or column based on first location identifier
274
+ type_ = 'row' if locs[0][0][0] == 'r' else 'col'
275
+ return value_dict.get(row[type_], None)
276
+ return values[0] if values else None
277
+
278
+ def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None):
279
+ """
280
+ Annotates conditions in the given DataFrame based on the provided parameters.
281
+
282
+ Args:
283
+ df (pandas.DataFrame): The DataFrame to annotate.
284
+ cells (list, optional): The list of host cell types. Defaults to ['HeLa'].
285
+ cell_loc (list, optional): The list of location identifiers for host cells. Defaults to None.
286
+ pathogens (list, optional): The list of pathogens. Defaults to ['rh'].
287
+ pathogen_loc (list, optional): The list of location identifiers for pathogens. Defaults to None.
288
+ treatments (list, optional): The list of treatments. Defaults to ['cm'].
289
+ treatment_loc (list, optional): The list of location identifiers for treatments. Defaults to None.
290
+
291
+ Returns:
292
+ pandas.DataFrame: The annotated DataFrame with the 'host_cells', 'pathogen', 'treatment', and 'condition' columns.
293
+ """
294
+
295
+
296
+ # Apply mappings or defaults
297
+ df['host_cells'] = [cells[0]] * len(df) if cell_loc is None else df.apply(_map_values, args=(cells, cell_loc), axis=1)
298
+ df['pathogen'] = [pathogens[0]] * len(df) if pathogen_loc is None else df.apply(_map_values, args=(pathogens, pathogen_loc), axis=1)
299
+ df['treatment'] = [treatments[0]] * len(df) if treatment_loc is None else df.apply(_map_values, args=(treatments, treatment_loc), axis=1)
300
+
301
+ # Construct condition column
302
+ df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
303
+ df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
304
+ return df
305
+
306
def normalize_to_dtype(array, q1=2,q2=98, percentiles=None):
    """
    Rescale every channel of an image stack to the full range of its dtype.

    Parameters:
    - array: numpy array
        The input stack; channels are indexed along axis 2.
    - q1: int, optional
        The lower percentile of the non-zero pixels mapped to the bottom of the
        output range. Default is 2.
    - q2: int, optional
        The upper percentile of the non-zero pixels mapped to the top of the
        output range. Default is 98.
    - percentiles: sequence of (min, max) pairs, optional
        Fallback input range per channel. It is only used for channels that
        contain no positive pixels; channels with signal always use q1 and q2.

    Returns:
    - new_stack: numpy array
        The normalized array with the same shape and dtype as the input array.
        A channel without positive pixels and without a fallback range is
        copied unchanged.
    """
    new_stack = np.empty_like(array)
    for channel in range(array.shape[2]):
        img = array[:, :, channel]
        non_zero_img = img[img > 0]

        if non_zero_img.size > 0:
            in_range = (np.percentile(non_zero_img, q1), np.percentile(non_zero_img, q2))
        elif percentiles is not None:
            # `is not None` rather than `== None`, so an array of percentiles
            # does not trigger an element-wise comparison.
            img_min, img_max = percentiles[channel]
            in_range = (img_min, img_max)
        else:
            in_range = None

        if in_range is not None:
            img = rescale_intensity(img, in_range=in_range, out_range='dtype')
        new_stack[:, :, channel] = img
    return new_stack
343
+
344
+ def _list_endpoint_subdirectories(base_dir):
345
+ """
346
+ Returns a list of subdirectories within the given base directory.
347
+
348
+ Args:
349
+ base_dir (str): The base directory to search for subdirectories.
350
+
351
+ Returns:
352
+ list: A list of subdirectories within the base directory.
353
+ """
354
+
355
+ endpoint_subdirectories = []
356
+ for root, dirs, _ in os.walk(base_dir):
357
+ if not dirs:
358
+ endpoint_subdirectories.append(root)
359
+
360
+ endpoint_subdirectories = [path for path in endpoint_subdirectories if os.path.basename(path) != 'figure']
361
+ return endpoint_subdirectories
362
+
363
+ def _generate_names(file_name, cell_id, cell_nucleus_ids, cell_pathogen_ids, source_folder, crop_mode='cell'):
364
+ """
365
+ Generate names for the image, folder, and table based on the given parameters.
366
+
367
+ Args:
368
+ file_name (str): The name of the file.
369
+ cell_id (numpy.ndarray): An array of cell IDs.
370
+ cell_nucleus_ids (numpy.ndarray): An array of cell nucleus IDs.
371
+ cell_pathogen_ids (numpy.ndarray): An array of cell pathogen IDs.
372
+ source_folder (str): The source folder path.
373
+ crop_mode (str, optional): The crop mode. Defaults to 'cell'.
374
+
375
+ Returns:
376
+ tuple: A tuple containing the image name, folder path, and table name.
377
+ """
378
+ non_zero_cell_ids = cell_id[cell_id != 0]
379
+ cell_id_str = "multi" if non_zero_cell_ids.size > 1 else str(non_zero_cell_ids[0]) if non_zero_cell_ids.size == 1 else "none"
380
+ cell_nucleus_ids = cell_nucleus_ids[cell_nucleus_ids != 0]
381
+ cell_nucleus_id_str = "multi" if cell_nucleus_ids.size > 1 else str(cell_nucleus_ids[0]) if cell_nucleus_ids.size == 1 else "none"
382
+ cell_pathogen_ids = cell_pathogen_ids[cell_pathogen_ids != 0]
383
+ cell_pathogen_id_str = "multi" if cell_pathogen_ids.size > 1 else str(cell_pathogen_ids[0]) if cell_pathogen_ids.size == 1 else "none"
384
+ fldr = f"{source_folder}/data/"
385
+ img_name = ""
386
+ if crop_mode == 'nucleus':
387
+ img_name = f"{file_name}_{cell_id_str}_{cell_nucleus_id_str}.png"
388
+ fldr += "single_nucleus/" if cell_nucleus_ids.size == 1 else "multiple_nucleus/" if cell_nucleus_ids.size > 1 else "no_nucleus/"
389
+ fldr += "single_pathogen/" if cell_pathogen_ids.size == 1 else "multiple_pathogens/" if cell_pathogen_ids.size > 1 else "uninfected/"
390
+ elif crop_mode == 'pathogen':
391
+ img_name = f"{file_name}_{cell_id_str}_{cell_pathogen_id_str}.png"
392
+ fldr += "single_nucleus/" if cell_nucleus_ids.size == 1 else "multiple_nucleus/" if cell_nucleus_ids.size > 1 else "no_nucleus/"
393
+ fldr += "infected/" if cell_pathogen_ids.size >= 1 else "uninfected/"
394
+ elif crop_mode == 'cell' or crop_mode == 'cytoplasm':
395
+ img_name = f"{file_name}_{cell_id_str}.png"
396
+ fldr += "single_nucleus/" if cell_nucleus_ids.size == 1 else "multiple_nucleus/" if cell_nucleus_ids.size > 1 else "no_nucleus/"
397
+ fldr += "single_pathogen/" if cell_pathogen_ids.size == 1 else "multiple_pathogens/" if cell_pathogen_ids.size > 1 else "uninfected/"
398
+ parts = file_name.split('_')
399
+ plate = parts[0]
400
+ well = parts[1]
401
+ metadata = f'{plate}_{well}'
402
+ fldr = os.path.join(fldr,metadata)
403
+ table_name = fldr.replace("/", "_")
404
+ return img_name, fldr, table_name
405
+
406
+ def _find_bounding_box(crop_mask, _id, buffer=10):
407
+ """
408
+ Find the bounding box coordinates for a given object ID in a crop mask.
409
+
410
+ Parameters:
411
+ crop_mask (ndarray): The crop mask containing object IDs.
412
+ _id (int): The object ID to find the bounding box for.
413
+ buffer (int, optional): The buffer size to add to the bounding box coordinates. Defaults to 10.
414
+
415
+ Returns:
416
+ ndarray: A new mask with the same dimensions as crop_mask, where the bounding box area is filled with the object ID.
417
+ """
418
+ object_indices = np.where(crop_mask == _id)
419
+
420
+ # Determine the bounding box coordinates
421
+ y_min, y_max = object_indices[0].min(), object_indices[0].max()
422
+ x_min, x_max = object_indices[1].min(), object_indices[1].max()
423
+
424
+ # Add buffer to the bounding box coordinates
425
+ y_min = max(y_min - buffer, 0)
426
+ y_max = min(y_max + buffer, crop_mask.shape[0] - 1)
427
+ x_min = max(x_min - buffer, 0)
428
+ x_max = min(x_max + buffer, crop_mask.shape[1] - 1)
429
+
430
+ # Create a new mask with the same dimensions as crop_mask
431
+ new_mask = np.zeros_like(crop_mask)
432
+
433
+ # Fill in the bounding box area with the _id
434
+ new_mask[y_min:y_max+1, x_min:x_max+1] = _id
435
+
436
+ return new_mask
437
+
438
def _merge_and_save_to_database(morph_df, intensity_df, table_type, source_folder, file_name, experiment, timelapse=False):
    """
    Merge morphology and intensity measurements and append them to the measurements database.

    Both frames are passed through _check_integrity, outer-joined on
    'object_label', annotated with file and well information, reordered so the
    identifying columns come first, and appended to the table named table_type in
    '<source_folder>/measurements/measurements.db'. Nothing is written when
    either frame is empty.

    Args:
        morph_df (pd.DataFrame): Dataframe containing morphology data.
        intensity_df (pd.DataFrame): Dataframe containing intensity data.
        table_type (str): 'cell', 'cytoplasm', 'nucleus' or 'pathogen'; also the
            name of the database table.
        source_folder (str): Path to the source folder.
        file_name (str): Name of the file (without extension).
        experiment (str): Name of the experiment. Not used by this function.
        timelapse (bool, optional): Indicates if the data is from a timelapse experiment. Defaults to False.

    Raises:
        ValueError: If an invalid table_type is provided or if columns are missing in the dataframe.

    """
    morph_df = _check_integrity(morph_df)
    intensity_df = _check_integrity(intensity_df)
    if len(morph_df) == 0 or len(intensity_df) == 0:
        return

    merged_df = pd.merge(morph_df, intensity_df, on='object_label', how='outer')
    merged_df = merged_df.rename(columns={"label_list_x": "label_list_morphology", "label_list_y": "label_list_intensity"})
    merged_df['file_name'] = file_name
    merged_df['path_name'] = os.path.join(source_folder, file_name + '.npy')

    # Every row shares the same file name, so the well mapping is computed once
    # and broadcast instead of being re-derived per row.
    if timelapse:
        well_columns = ['plate', 'row', 'col', 'field', 'timeid', 'prcf']
    else:
        well_columns = ['plate', 'row', 'col', 'field', 'prcf']
    for column_name, column_value in zip(well_columns, _map_wells(file_name, timelapse)):
        merged_df[column_name] = column_value

    base_columns = ['object_label', 'plate', 'row', 'col', 'field', 'prcf', 'file_name', 'path_name']
    if table_type == 'cell' or table_type == 'cytoplasm':
        column_list = list(base_columns)
    elif table_type == 'nucleus' or table_type == 'pathogen':
        column_list = ['object_label', 'cell_id', 'plate', 'row', 'col', 'field', 'prcf', 'file_name', 'path_name']
    else:
        raise ValueError(f"Invalid table_type: {table_type}")

    cols = merged_df.columns.tolist()  # get the list of all columns
    missing_columns = [col for col in column_list if col not in cols]
    if missing_columns == ['cell_id']:
        # 'cell_id' is optional: fall back to the layout without it.
        missing_columns = []
        column_list = list(base_columns)
    if missing_columns:
        raise ValueError(f"Columns missing in DataFrame: {missing_columns}")

    # Move the identifying columns to the front, in column_list order.
    for i, col in enumerate(column_list):
        cols.insert(i, cols.pop(cols.index(col)))
    merged_df = merged_df[cols]

    conn = None
    try:
        conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
        merged_df.to_sql(table_type, conn, if_exists='append', index=False)
    except sqlite3.OperationalError as e:
        print("SQLite error:", e)
    finally:
        # The connection must not outlive the call: this function runs once per
        # image file and would otherwise leak a handle each time.
        if conn is not None:
            conn.close()
490
+
491
+ def _safe_int_convert(value, default=0):
492
+ """
493
+ Converts the given value to an integer if possible, otherwise returns the default value.
494
+
495
+ Args:
496
+ value: The value to be converted to an integer.
497
+ default: The default value to be returned if the conversion fails. Default is 0.
498
+
499
+ Returns:
500
+ The converted integer value if successful, otherwise the default value.
501
+ """
502
+ try:
503
+ return int(value)
504
+ except ValueError:
505
+ print(f'Could not convert {value} to int using {default}', end='\r', flush=True)
506
+ return default
507
+
508
+ def _map_wells(file_name, timelapse=False):
509
+ """
510
+ Maps the components of a file name to plate, row, column, field, and timeid (if timelapse is True).
511
+
512
+ Args:
513
+ file_name (str): The name of the file.
514
+ timelapse (bool, optional): Indicates whether the file is part of a timelapse sequence. Defaults to False.
515
+
516
+ Returns:
517
+ tuple: A tuple containing the mapped values for plate, row, column, field, and timeid (if timelapse is True).
518
+ """
519
+ try:
520
+ parts = file_name.split('_')
521
+ plate = 'p' + parts[0]
522
+ well = parts[1]
523
+ field = 'f' + str(_safe_int_convert(parts[2]))
524
+ if timelapse:
525
+ timeid = 't' + str(_safe_int_convert(parts[3]))
526
+ if well[0].isalpha():
527
+ row = 'r' + str(string.ascii_uppercase.index(well[0]) + 1)
528
+ column = 'c' + str(int(well[1:]))
529
+ else:
530
+ row, column = well, well
531
+ if timelapse:
532
+ prcf = '_'.join([plate, row, column, field, timeid])
533
+ else:
534
+ prcf = '_'.join([plate, row, column, field])
535
+ except Exception as e:
536
+ print(f"Error processing filename: {file_name}")
537
+ print(f"Error: {e}")
538
+ plate, row, column, field, timeid, prcf = 'error','error','error','error','error', 'error'
539
+ if timelapse:
540
+ return plate, row, column, field, timeid, prcf
541
+ else:
542
+ return plate, row, column, field, prcf
543
+
544
+ def _map_wells_png(file_name, timelapse=False):
545
+ """
546
+ Maps the components of a file name to their corresponding values.
547
+
548
+ Args:
549
+ file_name (str): The name of the file.
550
+ timelapse (bool, optional): Indicates whether the file is part of a timelapse sequence. Defaults to False.
551
+
552
+ Returns:
553
+ tuple: A tuple containing the mapped components of the file name.
554
+
555
+ Raises:
556
+ None
557
+
558
+ """
559
+ try:
560
+ root, ext = os.path.splitext(file_name)
561
+ parts = root.split('_')
562
+ plate = 'p' + parts[0]
563
+ well = parts[1]
564
+ field = 'f' + str(_safe_int_convert(parts[2]))
565
+ if timelapse:
566
+ timeid = 't' + str(_safe_int_convert(parts[3]))
567
+ object_id = 'o' + str(_safe_int_convert(parts[-1], default='none'))
568
+ if well[0].isalpha():
569
+ row = 'r' + str(string.ascii_uppercase.index(well[0]) + 1)
570
+ column = 'c' + str(_safe_int_convert(well[1:]))
571
+ else:
572
+ row, column = well, well
573
+ if timelapse:
574
+ prcfo = '_'.join([plate, row, column, field, timeid, object_id])
575
+ else:
576
+ prcfo = '_'.join([plate, row, column, field, object_id])
577
+ except Exception as e:
578
+ print(f"Error processing filename: {file_name}")
579
+ print(f"Error: {e}")
580
+ plate, row, column, field, object_id, prcfo = 'error', 'error', 'error', 'error', 'error', 'error'
581
+ if timelapse:
582
+ return plate, row, column, field, timeid, prcfo, object_id,
583
+ else:
584
+ return plate, row, column, field, prcfo, object_id
585
+
586
+ def _check_integrity(df):
587
+ """
588
+ Check the integrity of the DataFrame and perform necessary modifications.
589
+
590
+ Args:
591
+ df (pandas.DataFrame): The input DataFrame.
592
+
593
+ Returns:
594
+ pandas.DataFrame: The modified DataFrame with integrity checks and modifications applied.
595
+ """
596
+ df.columns = [col + f'_{i}' if df.columns.tolist().count(col) > 1 and i != 0 else col for i, col in enumerate(df.columns)]
597
+ label_cols = [col for col in df.columns if 'label' in col]
598
+ df['label_list'] = df[label_cols].values.tolist()
599
+ df['object_label'] = df['label_list'].apply(lambda x: x[0])
600
+ df = df.drop(columns=label_cols)
601
+ df['label_list'] = df['label_list'].astype(str)
602
+ return df
603
+
604
+ def _get_percentiles(array, q1=2, q2=98):
605
+ """
606
+ Calculate the percentiles of each image in the given array.
607
+
608
+ Parameters:
609
+ - array: numpy.ndarray
610
+ The input array containing the images.
611
+ - q1: float, optional
612
+ The lower percentile value to calculate. Default is 2.
613
+ - q2: float, optional
614
+ The upper percentile value to calculate. Default is 98.
615
+
616
+ Returns:
617
+ - percentiles: list
618
+ A list of tuples, where each tuple contains the minimum and maximum
619
+ values of the corresponding image in the array.
620
+ """
621
+ nimg = array.shape[2]
622
+ percentiles = []
623
+ for v in range(nimg):
624
+ img = np.squeeze(array[:, :, v])
625
+ non_zero_img = img[img > 0]
626
+ if non_zero_img.size > 0: # check if there are non-zero values
627
+ img_min = np.percentile(non_zero_img, q1) # change percentile from 0.02 to 2
628
+ img_max = np.percentile(non_zero_img, q2) # change percentile from 0.98 to 98
629
+ percentiles.append([img_min, img_max])
630
+ else: # if there are no non-zero values, just use the image as it is
631
+ img_min, img_max = img.min(), img.max()
632
+ percentiles.append([img_min, img_max])
633
+ return percentiles
634
+
635
+ def _crop_center(img, cell_mask, new_width, new_height, normalize=(2,98)):
636
+ """
637
+ Crop the image around the center of the cell mask.
638
+
639
+ Parameters:
640
+ - img: numpy.ndarray
641
+ The input image.
642
+ - cell_mask: numpy.ndarray
643
+ The binary mask of the cell.
644
+ - new_width: int
645
+ The desired width of the cropped image.
646
+ - new_height: int
647
+ The desired height of the cropped image.
648
+ - normalize: tuple, optional
649
+ The normalization range for the image pixel values. Default is (2, 98).
650
+
651
+ Returns:
652
+ - img: numpy.ndarray
653
+ The cropped image.
654
+ """
655
+ # Convert all non-zero values in mask to 1
656
+ cell_mask[cell_mask != 0] = 1
657
+ mask_3d = np.repeat(cell_mask[:, :, np.newaxis], img.shape[2], axis=2).astype(img.dtype) # Create 3D mask
658
+ img = np.multiply(img, mask_3d).astype(img.dtype) # Multiply image with mask to set pixel values outside of the mask to 0
659
+ #centroid = np.round(ndi.measurements.center_of_mass(cell_mask)).astype(int) # Compute centroid of the mask
660
+ centroid = np.round(ndi.center_of_mass(cell_mask)).astype(int) # Compute centroid of the mask
661
+ # Pad the image and mask to ensure the crop will not go out of bounds
662
+ pad_width = max(new_width, new_height)
663
+ img = np.pad(img, ((pad_width, pad_width), (pad_width, pad_width), (0, 0)), mode='constant')
664
+ cell_mask = np.pad(cell_mask, ((pad_width, pad_width), (pad_width, pad_width)), mode='constant')
665
+ # Update centroid coordinates due to padding
666
+ centroid += pad_width
667
+ # Compute bounding box
668
+ start_y = max(0, centroid[0] - new_height // 2)
669
+ end_y = min(start_y + new_height, img.shape[0])
670
+ start_x = max(0, centroid[1] - new_width // 2)
671
+ end_x = min(start_x + new_width, img.shape[1])
672
+ # Crop to bounding box
673
+ img = img[start_y:end_y, start_x:end_x, :]
674
+ return img
675
+
676
+
677
+
678
+
679
+ def _masks_to_masks_stack(masks):
680
+ """
681
+ Convert a list of masks into a stack of masks.
682
+
683
+ Args:
684
+ masks (list): A list of masks.
685
+
686
+ Returns:
687
+ list: A stack of masks.
688
+ """
689
+ mask_stack = []
690
+ for idx, mask in enumerate(masks):
691
+ mask_stack.append(mask)
692
+ return mask_stack
693
+
694
def _get_diam(mag, obj):
    """
    Return the diameter used for an object type at a given magnification.

    The diameter is ``mag * scale``, where ``scale`` is looked up per object
    type and magnification in the table below.

    Args:
        mag (int): Magnification. Supported values are 20, 40 and 60.
        obj (str): Object type: 'cell', 'nucleus', 'pathogen' or 'pathogen_nucleus'.

    Returns:
        int or float: The diameter (``mag`` multiplied by the matching scale).

    Raises:
        ValueError: If ``obj`` is not a supported object type, or if ``mag`` is
            not a supported magnification for that object type.
    """
    # scale factor per object type and magnification
    scales = {
        'cell': {20: 6, 40: 4.5, 60: 3},
        'nucleus': {20: 3, 40: 2, 60: 1.5},
        'pathogen': {20: 1.5, 40: 1, 60: 1.25},
        'pathogen_nucleus': {20: 0.25, 40: 0.2, 60: 0.2},
    }

    try:
        scale_by_mag = scales[obj]
    except (KeyError, TypeError):
        raise ValueError("Invalid object type") from None

    try:
        scale = scale_by_mag[mag]
    except (KeyError, TypeError):
        # Previously an unsupported magnification left `scale` unassigned and
        # surfaced as an UnboundLocalError; fail with a clear message instead.
        raise ValueError(
            f"Invalid magnification: {mag!r}. Supported magnifications are {sorted(scale_by_mag)}"
        ) from None

    return mag * scale
727
+
728
def _get_object_settings(object_type, settings):
    """
    Build the default settings dictionary for one object type.

    Args:
        object_type (str): 'cell', 'nucleus', 'pathogen' or 'pathogen_nucleus'.
        settings (dict): Must contain 'magnification'; 'nucleus_channel' is
            additionally read when ``object_type`` is 'cell'.

    Returns:
        dict: The settings for the object type. Sizes are derived from the
        diameter returned by ``_get_diam``.
    """
    print(object_type)

    diameter = _get_diam(settings['magnification'], obj=object_type)
    minimum_size = (diameter**2)/10

    object_settings = {
        'refine_masks': False,
        'filter_size': False,
        'filter_dimm': False,
        'diameter': diameter,
        'remove_border_objects': False,
        'minimum_size': minimum_size,
        'maximum_size': minimum_size*50,
        'merge': False,
        'net_avg': True,
        'resample': True,
        'model_name': 'cyto',
    }

    # Object types whose model name does not depend on any other setting.
    fixed_model_names = {'nucleus': 'nuclei', 'pathogen': 'cyto3'}

    if object_type == 'cell':
        # The model choice for cells depends on whether a nucleus channel is configured.
        object_settings['model_name'] = 'cyto' if settings['nucleus_channel'] is None else 'cyto2'
    elif object_type in fixed_model_names:
        object_settings['model_name'] = fixed_model_names[object_type]
    elif object_type == 'pathogen_nucleus':
        object_settings['filter_size'] = True
        object_settings['model_name'] = 'cyto'
    else:
        # NOTE(review): _get_diam raises ValueError for unknown object types
        # before this point, so this branch looks unreachable - confirm.
        print(f'Object type: {object_type} not supported. Supported object types are : cell, nucleus and pathogen')

    print(f'using settings: {object_settings}')
    return object_settings
765
+
766
def _pivot_counts_table(db_path):
    """
    Pivot the 'object_counts' table of an SQLite database into 'pivoted_counts'.

    The 'object_counts' table is read, reshaped to one row per 'file_name' with
    one column per 'count_type' (values taken from 'object_count'), and written
    to the table 'pivoted_counts', replacing it if it already exists. The
    'object_counts' table itself is left untouched.

    Args:
        db_path (str): The path to the SQLite database file.

    Returns:
        None
    """

    def _read_table_to_dataframe(db_path, table_name='object_counts'):
        """
        Read a table from an SQLite database into a pandas DataFrame.

        Parameters:
        - db_path (str): The path to the SQLite database file.
        - table_name (str): The name of the table to read. Default is 'object_counts'.

        Returns:
        - df (pandas.DataFrame): The table data as a pandas DataFrame.
        """
        conn = sqlite3.connect(db_path)
        try:
            # table_name is only ever supplied by this module, never by user input.
            query = f"SELECT * FROM {table_name}"
            return pd.read_sql_query(query, conn)
        finally:
            # Always release the connection, even when the query fails.
            conn.close()

    def _pivot_dataframe(df):
        """
        Pivot the DataFrame.

        Args:
            df (pandas.DataFrame): The input DataFrame.

        Returns:
            pandas.DataFrame: The pivoted DataFrame with NaN values filled with 0.
        """
        pivoted_df = df.pivot(index='file_name', columns='count_type', values='object_count').reset_index()
        # Combinations missing from the source table come out of the pivot as NaN.
        return pivoted_df.fillna(0)

    df = _read_table_to_dataframe(db_path, 'object_counts')
    pivoted_df = _pivot_dataframe(df)

    conn = sqlite3.connect(db_path)
    try:
        # if_exists='replace' drops a previous 'pivoted_counts' table before writing.
        pivoted_df.to_sql('pivoted_counts', conn, if_exists='replace', index=False)
    finally:
        conn.close()
814
+
815
def _get_cellpose_channels_v1(mask_channels, nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim):
    """
    Map object names to ``[0, position]`` pairs based on a list of mask channels.

    Args:
        mask_channels (list): The channels used for mask generation.
        nucleus_chann_dim: Channel of the nucleus, looked up in ``mask_channels``.
        pathogen_chann_dim: Channel of the pathogen, looked up in ``mask_channels``.
        cell_chann_dim: Channel of the cell, looked up in ``mask_channels``.

    Returns:
        dict: For each of 'nucleus', 'pathogen' and 'cell' whose channel occurs
        in ``mask_channels``, the list ``[0, index of that channel in mask_channels]``.
    """
    # NOTE(review): a second function with this same name is defined right after
    # this one in the file and replaces it at import time.
    named_dims = (
        ('nucleus', nucleus_chann_dim),
        ('pathogen', pathogen_chann_dim),
        ('cell', cell_chann_dim),
    )
    return {
        name: [0, mask_channels.index(dim)]
        for name, dim in named_dims
        if dim in mask_channels
    }
824
+
825
def _get_cellpose_channels_v1(cell_channel, nucleus_channel, pathogen_channel):
    """
    Re-index the configured channels and pair each with a second channel index.

    Channels that are not None are ranked by their original channel value; each
    object gets ``[rank, 0]``. When both a cell and a nucleus channel are given,
    the second entry of 'cell' is set to the rank of the nucleus channel.

    Args:
        cell_channel (int or None): Original channel of the cell signal.
        nucleus_channel (int or None): Original channel of the nucleus signal.
        pathogen_channel (int or None): Original channel of the pathogen signal.

    Returns:
        dict: Maps 'cell' / 'nucleus' / 'pathogen' (only those provided) to a
        two-element list as described above.
    """
    candidates = (
        ('cell', cell_channel),
        ('nucleus', nucleus_channel),
        ('pathogen', pathogen_channel),
    )
    # Keep only configured channels, ordered by original channel value
    # (the sort is stable, so ties keep the cell/nucleus/pathogen order).
    ranked = sorted(
        [(name, chan) for name, chan in candidates if chan is not None],
        key=lambda item: item[1],
    )
    print(ranked)

    cellpose_channels = {name: [rank, 0] for rank, (name, _) in enumerate(ranked)}

    if cell_channel is not None and nucleus_channel is not None:
        cellpose_channels['cell'][1] = cellpose_channels['nucleus'][0]

    return cellpose_channels
851
+
852
def _get_cellpose_channels(nucleus_channel, pathogen_channel, cell_channel):
    """
    Build the per-object channel pairs from which channels are configured.

    Only whether each argument is None matters, not its value. Every configured
    object gets ``[0, k]`` where ``k`` counts the configured objects that come
    before it in the order nucleus, pathogen, cell.

    Args:
        nucleus_channel: Nucleus channel, or None if there is none.
        pathogen_channel: Pathogen channel, or None if there is none.
        cell_channel: Cell channel, or None if there is none.

    Returns:
        dict: Maps 'nucleus' / 'pathogen' / 'cell' (only those configured) to
        a two-element list.
    """
    has_nucleus = nucleus_channel is not None
    has_pathogen = pathogen_channel is not None
    has_cell = cell_channel is not None

    cellpose_channels = {}
    if has_nucleus:
        cellpose_channels['nucleus'] = [0, 0]
    if has_pathogen:
        cellpose_channels['pathogen'] = [0, int(has_nucleus)]
    if has_cell:
        cellpose_channels['cell'] = [0, int(has_nucleus) + int(has_pathogen)]
    return cellpose_channels
874
+
875
def annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None, types=['col','col','col']):
    """
    Annotates conditions in a DataFrame based on specified criteria.

    Adds the columns 'host_cells', 'pathogen' (unless ``pathogens`` is None and
    ``pathogen_loc`` is None), 'treatment' and 'condition' to ``df`` in place.

    Args:
        df (pandas.DataFrame): The DataFrame to annotate. Modified in place.
        cells (list, optional): List of host cell types. Defaults to ['HeLa'].
        cell_loc (list, optional): For each host cell type, the collection of
            location values (read from the column named by ``types[0]``) that
            belong to it. If None, every row gets ``cells[0]``. Defaults to None.
        pathogens (list, optional): List of pathogens, or None when the
            experiment has no pathogen. Defaults to ['rh'].
        pathogen_loc (list, optional): Location values for each pathogen, matched
            against the column named by ``types[1]``. If None and ``pathogens``
            is not None, every row gets 'none'. Defaults to None.
        treatments (list, optional): List of treatments. Defaults to ['cm'].
        treatment_loc (list, optional): Location values for each treatment,
            matched against the column named by ``types[2]``. If None, every row
            gets 'cm'. Defaults to None.
        types (list, optional): Names of the columns to match for host cells,
            pathogens and treatments. Defaults to ['col','col','col'].

    Returns:
        pandas.DataFrame: The annotated DataFrame (the same object as ``df``).
    """
    # The list defaults are only read, never mutated, so sharing them across calls is safe.

    def _map_values(row, dict_, type_='col'):
        """
        Return the first key of ``dict_`` whose value collection contains ``row[type_]``.

        Args:
            row: The row containing the value to be mapped.
            dict_ (dict): Maps a label to the collection of values belonging to it.
            type_ (str, optional): The column to read from the row. Defaults to 'col'.

        Returns:
            The matching label if found, otherwise None.
        """
        for values, cols in dict_.items():
            if row[type_] in cols:
                return values
        return None

    if cell_loc is None:
        df['host_cells'] = cells[0]
    else:
        cells_dict = dict(zip(cells, cell_loc))
        df['host_cells'] = df.apply(lambda row: _map_values(row, cells_dict, type_=types[0]), axis=1)

    # Identity test instead of `!= None`: the inequality is evaluated element-wise
    # for array-like inputs and then cannot be used as a single truth value.
    has_pathogens = pathogens is not None

    if pathogen_loc is None:
        if has_pathogens:
            df['pathogen'] = 'none'
    else:
        pathogens_dict = dict(zip(pathogens, pathogen_loc))
        df['pathogen'] = df.apply(lambda row: _map_values(row, pathogens_dict, type_=types[1]), axis=1)

    if treatment_loc is None:
        df['treatment'] = 'cm'
    else:
        treatments_dict = dict(zip(treatments, treatment_loc))
        df['treatment'] = df.apply(lambda row: _map_values(row, treatments_dict, type_=types[2]), axis=1)

    if has_pathogens:
        df['condition'] = df['pathogen']+'_'+df['treatment']
    else:
        df['condition'] = df['treatment']
    return df
932
+
933
+
934
+
935
def _split_data(df, group_by, object_type):
    """
    Aggregate a dataframe per group, separately for numeric and non-numeric columns.

    A column 'prcfo' (``df['prcf'] + '_' + df[object_type]``) is first added to
    ``df`` itself. The frame is then indexed by ``group_by``; numeric columns
    are averaged per index value and, for the remaining columns, the first
    value per index value is kept.

    Parameters:
        df (pandas.DataFrame): The input dataframe. Gains a 'prcfo' column as a side effect.
        group_by (str): The column to index and group by.
        object_type (str): The column whose values are appended to 'prcf' to form 'prcfo'.

    Returns:
        grouped_numeric (pandas.DataFrame): Per-group mean of the numeric columns.
        grouped_non_numeric (pandas.DataFrame): Per-group first value of the non-numeric columns.
    """
    df['prcfo'] = df['prcf'] + '_' + df[object_type]
    indexed = df.set_index(group_by)

    numeric_part = indexed.select_dtypes(include=np.number)
    other_part = indexed.select_dtypes(exclude=np.number)

    numeric_means = numeric_part.groupby(numeric_part.index).mean()
    first_values = other_part.groupby(other_part.index).first()

    return pd.DataFrame(numeric_means), pd.DataFrame(first_values)
959
+
960
def _calculate_recruitment(df, channel):
    """
    Add recruitment ratios, slope placeholders, coordinates and distances to ``df``.

    For the given channel, pathogen intensity statistics are divided by the mean
    intensity of the cell, cytoplasm and nucleus. Independently of ``channel``,
    slope columns are set to 1 and centroid coordinates and centroid distances
    are added for channels 0 to 3, so the weighted-centroid columns of all four
    channels must be present.

    Args:
        df (pandas.DataFrame): The input DataFrame with per-object intensity and
            centroid measurements. Modified in place.
        channel (int): The channel used for the intensity ratios.

    Returns:
        pandas.DataFrame: The same DataFrame with the new columns added.
    """
    compartments = ('cell', 'cytoplasm', 'nucleus')

    # (output column template, pathogen statistic used as numerator)
    ratio_specs = (
        ('pathogen_{}_mean_mean', 'mean_intensity'),
        ('pathogen_{}_q75_mean', 'percentile_75'),
        ('pathogen_outside_{}_mean_mean', 'outside_mean'),
        ('pathogen_outside_{}_q75_mean', 'outside_75_percentile'),
        ('pathogen_periphery_{}_mean_mean', 'periphery_mean'),
    )
    for column_template, pathogen_stat in ratio_specs:
        for compartment in compartments:
            numerator = df[f'pathogen_channel_{channel}_{pathogen_stat}']
            denominator = df[f'{compartment}_channel_{channel}_mean_intensity']
            df[column_template.format(compartment)] = numerator/denominator

    channels = [0,1,2,3]

    # Slope columns are filled with the constant 1.
    for object_type in ('pathogen', 'nucleus'):
        for chan in channels:
            df[f'{object_type}_slope_channel_{chan}'] = 1

    def _distance_to_cell(row, name, chan):
        """Euclidean distance between the centroid of ``name`` and the cell centroid."""
        own = row[f'{name}_coordinates_{chan}']
        cell = row[f'cell_coordinates_{chan}']
        return np.sqrt((own[0] - cell[0])**2 + (own[1] - cell[1])**2)

    for chan in channels:
        # Store each weighted centroid as a two-element list per row.
        for name in ('nucleus', 'pathogen', 'cell', 'cytoplasm'):
            centroid_cols = [
                f'{name}_channel_{chan}_centroid_weighted_local-0',
                f'{name}_channel_{chan}_centroid_weighted_local-1',
            ]
            df[f'{name}_coordinates_{chan}'] = df[centroid_cols].values.tolist()

        for name in ('pathogen', 'nucleus'):
            df[f'{name}_cell_distance_channel_{chan}'] = df.apply(
                lambda row: _distance_to_cell(row, name, chan), axis=1)
    return df
1012
+
1013
def _group_by_well(df):
    """
    Group the DataFrame by well coordinates (plate, row, col), averaging numeric
    columns and keeping the first value of object-dtype columns.

    Parameters:
        df (DataFrame): The input DataFrame to be grouped. Must contain the
            columns 'plate', 'row' and 'col'.

    Returns:
        DataFrame: The grouped DataFrame, indexed by (plate, row, col).
    """
    # Public equivalent of the private DataFrame._get_numeric_data(): numeric
    # columns plus booleans.
    numeric_cols = df.select_dtypes(include=['number', 'bool']).columns
    non_numeric_cols = df.select_dtypes(include=['object']).columns

    # Use the string alias 'mean' rather than the np.mean callable: pandas
    # already dispatches np.mean to its own grouped mean and warns that passing
    # the callable will change meaning in a future version.
    aggregations = {
        **{col: 'mean' for col in numeric_cols},
        **{col: 'first' for col in non_numeric_cols},
    }
    return df.groupby(['plate', 'row', 'col']).agg(aggregations)
1030
+
1031
+
1032
+
1033
+
1034
+ ###################################################
1035
+ # Classify
1036
+ ###################################################
1037
+
1038
class Cache:
    """
    A least-recently-used cache with a maximum size.

    Attributes:
        max_size (int): The maximum number of entries kept in the cache.
        cache (OrderedDict): The entries, ordered from least to most recently used.
    """

    def __init__(self, max_size):
        """
        Create an empty cache.

        Args:
            max_size (int): The maximum number of entries to keep.
        """
        self.cache = OrderedDict()
        self.max_size = max_size

    def get(self, key):
        """
        Return the value stored for ``key`` and mark it as most recently used.

        Args:
            key: The key to look up.

        Returns:
            The cached value, or None if ``key`` is not in the cache.
        """
        if key in self.cache:
            self.cache.move_to_end(key)
            return self.cache[key]
        return None

    def put(self, key, value):
        """
        Store ``value`` under ``key``, evicting the least recently used entry if needed.

        Updating a key that is already cached never evicts another entry.
        With a non-positive ``max_size`` nothing is stored.

        Args:
            key: The key to store under.
            value: The value to store.
        """
        if self.max_size <= 0:
            return
        if key in self.cache:
            # Updating an existing entry does not grow the cache; just refresh its position.
            self.cache.move_to_end(key)
        elif len(self.cache) >= self.max_size:
            self.cache.popitem(last=False)
        self.cache[key] = value
1062
+
1063
class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention module.

    Args:
        d_k (int): The dimension of the key and query vectors.

    Attributes:
        d_k (int): The dimension of the key and query vectors.

    Methods:
        forward(Q, K, V): Performs the forward pass of the attention mechanism.

    """

    # Was spelled `_init__`, which Python never calls as the constructor, so
    # instantiating the module with `d_k` failed.
    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        """
        Performs the forward pass of the attention mechanism.

        Args:
            Q (torch.Tensor): The query tensor of shape (batch_size, seq_len_q, d_k).
            K (torch.Tensor): The key tensor of shape (batch_size, seq_len_k, d_k).
            V (torch.Tensor): The value tensor of shape (batch_size, seq_len_k, d_v).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, seq_len_q, d_v).

        """
        # Similarity of every query with every key, scaled by sqrt(d_k).
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        # Normalise over the key axis so the weights of each query sum to 1.
        attention_probs = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_probs, V)
        return output
1099
+
1100
class SelfAttention(nn.Module):
    """
    Self-Attention module that applies scaled dot-product attention mechanism.

    Queries, keys and values are all linear projections of the same input.

    Args:
        in_channels (int): Number of input channels.
        d_k (int): Dimensionality of the key and query vectors.
    """

    # Was spelled `_init__`, so the layers below were never created on construction.
    def __init__(self, in_channels, d_k):
        super(SelfAttention, self).__init__()
        self.W_q = nn.Linear(in_channels, d_k)
        self.W_k = nn.Linear(in_channels, d_k)
        self.W_v = nn.Linear(in_channels, d_k)
        self.attention = ScaledDotProductAttention(d_k)

    def forward(self, x):
        """
        Forward pass of the SelfAttention module.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is in_channels.

        Returns:
            torch.Tensor: Output tensor with the same leading dimensions as
            ``x`` and a last dimension of d_k.
        """
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        output = self.attention(Q, K, V)
        return output
1131
+
1132
class ScaledDotProductAttention(nn.Module):
    """
    Attention that weights values by the softmax of scaled query-key dot products.
    """

    # Was spelled `_init__`, which is an ordinary method and not the constructor.
    def __init__(self, d_k):
        """
        Initializes the ScaledDotProductAttention module.

        Args:
            d_k (int): The dimension of the key and query vectors.

        """
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        """
        Performs the forward pass of the ScaledDotProductAttention module.

        Args:
            Q (torch.Tensor): The query tensor.
            K (torch.Tensor): The key tensor; its last dimension must match that of Q.
            V (torch.Tensor): The value tensor.

        Returns:
            torch.Tensor: The attention-weighted combination of the values.

        """
        # Dot products are divided by sqrt(d_k) before the softmax over the key axis.
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        attention_probs = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_probs, V)
        return output
1161
+
1162
class SelfAttention(nn.Module):
    """
    Self-Attention module that applies scaled dot-product attention mechanism.

    Args:
        in_channels (int): Number of input channels.
        d_k (int): Dimensionality of the key and query vectors.
    """
    # Was spelled `_init__`, so constructing the module with arguments failed
    # and the projection layers were never registered.
    def __init__(self, in_channels, d_k):
        super(SelfAttention, self).__init__()
        self.W_q = nn.Linear(in_channels, d_k)
        self.W_k = nn.Linear(in_channels, d_k)
        self.W_v = nn.Linear(in_channels, d_k)
        self.attention = ScaledDotProductAttention(d_k)

    def forward(self, x):
        """
        Forward pass of the SelfAttention module.

        Args:
            x (torch.Tensor): Input tensor whose last dimension is in_channels.

        Returns:
            torch.Tensor: Output tensor after applying self-attention mechanism.
        """
        # Queries, keys and values are three separate projections of the same input.
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        output = self.attention(Q, K, V)
        return output
1192
+
1193
+ # Early Fusion Block
1194
class EarlyFusion(nn.Module):
    """
    Early Fusion module for image classification.

    Mixes the input channels into 64 feature maps with a 1x1 convolution.

    Args:
        in_channels (int): Number of input channels.
    """
    # Was spelled `_init__`, so the convolution was never created on construction.
    def __init__(self, in_channels):
        super(EarlyFusion, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)

    def forward(self, x):
        """
        Forward pass of the Early Fusion module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, 64, height, width).
        """
        x = self.conv1(x)
        return x
1217
+
1218
+ # Spatial Attention Mechanism
1219
class SpatialAttention(nn.Module):
    """
    Spatial attention map computed from channel-wise average and maximum.
    """

    # Was spelled `_init__`, so the convolution and sigmoid were never created.
    def __init__(self, kernel_size=7):
        """
        Initializes the SpatialAttention module.

        Args:
            kernel_size (int): The size of the convolutional kernel. Default is 7.
        """
        super(SpatialAttention, self).__init__()
        # padding=kernel_size//2 keeps the spatial size unchanged for odd kernel sizes.
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Performs forward pass of the SpatialAttention module.

        Args:
            x (torch.Tensor): The input tensor of shape (batch_size, channels, height, width).

        Returns:
            torch.Tensor: A single-channel attention map with values between 0 and 1.
        """
        # Collapse the channel axis two ways, then stack the two maps as channels.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
1246
+
1247
+ # Multi-Scale Block with Attention
1248
class MultiScaleBlockWithAttention(nn.Module):
    """
    Multi-scale block with attention module.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.

    Attributes:
        dilated_conv1 (nn.Conv2d): 3x3 convolution (dilation 1, padding 1).
        spatial_attention (nn.Conv2d): 1x1 convolution applied after the ReLU.

    Methods:
        custom_forward: Custom forward method for the module.
        forward: Forward method for the module.
    """

    # Was spelled `_init__`, so the two convolutions were never created on construction.
    def __init__(self, in_channels, out_channels):
        super(MultiScaleBlockWithAttention, self).__init__()
        self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
        self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def custom_forward(self, x):
        """
        Custom forward method for the module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, out_channels, height, width).
        """
        x1 = F.relu(self.dilated_conv1(x), inplace=True)
        x = self.spatial_attention(x1)
        return x

    def forward(self, x):
        """
        Forward method for the module.

        Runs ``custom_forward`` through ``torch.utils.checkpoint.checkpoint``.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        return checkpoint(self.custom_forward, x)
1295
+
1296
+ # Final Classifier
1297
class CustomCellClassifier(nn.Module):
    """
    Small classifier: early fusion, one multi-scale block, global pooling and a linear layer.

    Args:
        num_classes (int): Number of outputs of the final linear layer.
        pathogen_channel: Accepted but not used by this class.
        use_attention: Accepted but not used by this class.
        use_checkpoint (bool): Whether ``forward`` wraps the computation in
            ``torch.utils.checkpoint.checkpoint``.
        dropout_rate: Accepted but not used by this class.
    """

    # Was spelled `_init__`, so none of the sub-modules were created on construction.
    def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
        super(CustomCellClassifier, self).__init__()
        self.early_fusion = EarlyFusion(in_channels=3)

        self.multi_scale_block_1 = MultiScaleBlockWithAttention(in_channels=64, out_channels=64)

        self.fc1 = nn.Linear(64, num_classes)
        self.use_checkpoint = use_checkpoint
        # Explicitly require gradients for all parameters
        for param in self.parameters():
            param.requires_grad = True

    def custom_forward(self, x):
        """
        Compute the class scores for a batch of 3-channel images.

        Args:
            x (torch.Tensor): Input of shape (batch_size, 3, height, width).
                Its ``requires_grad`` flag is switched on in place.

        Returns:
            torch.Tensor: Non-negative scores of shape (batch_size, num_classes)
            (a ReLU is applied to the output of the linear layer).
        """
        x.requires_grad = True
        x = self.early_fusion(x)
        x = self.multi_scale_block_1(x)
        # Global average pooling, then flatten to (batch_size, channels).
        x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1)
        x = F.relu(self.fc1(x), inplace=True)
        return x

    def forward(self, x):
        """
        Run ``custom_forward``, optionally through gradient checkpointing.

        Args:
            x (torch.Tensor): Input of shape (batch_size, 3, height, width).

        Returns:
            torch.Tensor: Scores of shape (batch_size, num_classes).
        """
        if self.use_checkpoint:
            x.requires_grad = True
            return checkpoint(self.custom_forward, x)
        else:
            return self.custom_forward(x)
1324
+
1325
+ #CNN and Transformer class, pick any Torch model.
1326
class TorchModel(nn.Module):
    """
    Binary classifier built on a torchvision backbone selected by name.

    The backbone's own classification head is replaced (see ``get_num_ftrs``)
    and a single-output linear layer is put on top of its features.

    Args:
        model_name (str): Name of a model constructor in ``torchvision.models``.
        pretrained (bool): Whether to load pretrained weights for the backbone.
        dropout_rate (float or None): If given, the rate applied to every
            dropout layer of the backbone and to a dropout layer in front of
            the final linear layer. If None, no extra dropout is added.
        use_checkpoint (bool): Whether to run the backbone through
            ``torch.utils.checkpoint.checkpoint`` in ``forward``.
    """

    # Was spelled `_init__`, so constructing the model never built the backbone.
    def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
        super(TorchModel, self).__init__()
        self.model_name = model_name
        self.use_checkpoint = use_checkpoint
        self.base_model = self.init_base_model(pretrained)

        # Retain layers up to and including the (5): Linear layer for model 'maxvit_t'
        if model_name == 'maxvit_t':
            self.base_model.classifier = nn.Sequential(*list(self.base_model.classifier.children())[:-1])

        if dropout_rate is not None:
            self.apply_dropout_rate(self.base_model, dropout_rate)

        self.num_ftrs = self.get_num_ftrs()
        self.init_spacr_classifier(dropout_rate)

    def apply_dropout_rate(self, model, dropout_rate):
        """Apply dropout rate to all dropout layers in the model."""
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_rate

    def init_base_model(self, pretrained):
        """
        Initialize the base model from torchvision.models.

        Args:
            pretrained (bool): Whether to load pretrained weights.

        Returns:
            torch.nn.Module: The instantiated backbone.

        Raises:
            ValueError: If ``self.model_name`` is not found in ``torchvision.models``.
        """
        model_func = models.__dict__.get(self.model_name, None)
        if not model_func:
            raise ValueError(f"Model {self.model_name} is not recognized.")
        weight_choice = self.get_weight_choice()
        if weight_choice is not None:
            # Honour `pretrained`: previously the default weights were loaded
            # here even when pretrained=False was requested.
            return model_func(weights=weight_choice if pretrained else None)
        else:
            # No weights enum found for this model name: fall back to the legacy keyword.
            return model_func(pretrained=pretrained)

    def get_weight_choice(self):
        """Get weight choice if it exists for the model."""
        weight_enum = None
        wanted = f"{self.model_name}_weights".lower()
        # Case-insensitive search for e.g. 'ResNet50_Weights' given 'resnet50'.
        for attr_name in dir(models):
            if attr_name.lower() == wanted:
                weight_enum = getattr(models, attr_name)
                break
        return weight_enum.DEFAULT if weight_enum else None

    def get_num_ftrs(self):
        """
        Determine the number of features output by the base model.

        Replaces the backbone's ``fc`` (or ``classifier``, except for
        'maxvit_t') with an identity and measures the output width with a
        dummy 1x3x224x224 input.

        Returns:
            int: Size of dimension 1 of the backbone output.
        """
        if hasattr(self.base_model, 'fc'):
            self.base_model.fc = nn.Identity()
        elif hasattr(self.base_model, 'classifier'):
            if self.model_name != 'maxvit_t':
                self.base_model.classifier = nn.Identity()

        # Forward a dummy input and check output size. Run it in eval mode and
        # without autograd so the probe neither updates normalisation
        # statistics with random data nor builds a graph.
        dummy_input = torch.randn(1, 3, 224, 224)
        was_training = self.base_model.training
        self.base_model.eval()
        try:
            with torch.no_grad():
                output = self.base_model(dummy_input)
        finally:
            self.base_model.train(was_training)
        return output.size(1)

    def init_spacr_classifier(self, dropout_rate):
        """Initialize the SPACR classifier."""
        self.use_dropout = dropout_rate is not None
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout_rate)
        self.spacr_classifier = nn.Linear(self.num_ftrs, 1)

    def forward(self, x):
        """
        Define the forward pass of the model.

        Args:
            x (torch.Tensor): Batch of inputs for the backbone.

        Returns:
            torch.Tensor: One logit per sample, flattened to 1-D.
        """
        if self.use_checkpoint:
            x = checkpoint(self.base_model, x)
        else:
            x = self.base_model(x)
        if self.use_dropout:
            x = self.dropout(x)
        logits = self.spacr_classifier(x).flatten()
        return logits
1399
+
1400
class FocalLossWithLogits(nn.Module):
    """
    Focal loss for binary classification, computed from raw logits.

    Args:
        alpha (float): Constant factor applied to the loss. Defaults to 1.
        gamma (float): Exponent of the ``(1 - pt)`` modulating factor. Defaults to 2.
    """

    # Was spelled `_init__`, so alpha and gamma were never set on construction.
    def __init__(self, alpha=1, gamma=2):
        super(FocalLossWithLogits, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, target):
        """
        Compute the mean focal loss.

        Args:
            logits (torch.Tensor): Raw (pre-sigmoid) predictions.
            target (torch.Tensor): Targets with the same shape as ``logits``.

        Returns:
            torch.Tensor: Scalar mean focal loss.
        """
        BCE_loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        # exp(-BCE) recovers the probability assigned to the true class.
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1-pt)**self.gamma * BCE_loss
        return focal_loss.mean()
1411
+
1412
class ResNet(nn.Module):
    """
    Binary classifier: a torchvision ResNet followed by two linear layers.

    Args:
        resnet_type (str): One of 'resnet18', 'resnet34', 'resnet50',
            'resnet101', 'resnet152'.
        dropout_rate (float or None): Dropout rate applied between the two
            linear layers; None disables dropout.
        use_checkpoint (bool): Whether to run the ResNet through
            ``torch.utils.checkpoint.checkpoint`` in ``forward``.
        init_weights (str): 'imagenet' to load the IMAGENET1K_V1 weights,
            'none' for random initialisation.

    Raises:
        ValueError: If ``resnet_type`` or ``init_weights`` is not a supported value.
    """

    # Was spelled `_init__`, so constructing the model never built the network.
    def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
        super(ResNet, self).__init__()

        # The weight enums are taken from the `models` namespace so this class
        # only relies on the `from torchvision import models` import.
        resnet_map = {
            'resnet18': {'func': models.resnet18, 'weights': models.ResNet18_Weights.IMAGENET1K_V1},
            'resnet34': {'func': models.resnet34, 'weights': models.ResNet34_Weights.IMAGENET1K_V1},
            'resnet50': {'func': models.resnet50, 'weights': models.ResNet50_Weights.IMAGENET1K_V1},
            'resnet101': {'func': models.resnet101, 'weights': models.ResNet101_Weights.IMAGENET1K_V1},
            'resnet152': {'func': models.resnet152, 'weights': models.ResNet152_Weights.IMAGENET1K_V1}
        }

        if resnet_type not in resnet_map:
            raise ValueError(f"Invalid resnet_type. Choose from {list(resnet_map.keys())}")

        self.initialize_base(resnet_map[resnet_type], dropout_rate, use_checkpoint, init_weights)

    def initialize_base(self, base_model_dict, dropout_rate, use_checkpoint, init_weights):
        """
        Create the ResNet backbone and the layers on top of it.

        Args:
            base_model_dict (dict): Holds the constructor under 'func' and the
                pretrained weights under 'weights'.
            dropout_rate (float or None): Dropout rate, or None for no dropout.
            use_checkpoint (bool): Whether ``forward`` checkpoints the backbone.
            init_weights (str): 'imagenet' or 'none'.

        Raises:
            ValueError: If ``init_weights`` is neither 'imagenet' nor 'none'.
        """
        if init_weights == 'imagenet':
            self.resnet = base_model_dict['func'](weights=base_model_dict['weights'])
        elif init_weights == 'none':
            self.resnet = base_model_dict['func'](weights=None)
        else:
            raise ValueError("init_weights should be either 'imagenet' or 'none'")

        # The backbone keeps its own 1000-way output layer, which feeds fc1.
        self.fc1 = nn.Linear(1000, 500)
        self.use_dropout = dropout_rate is not None
        self.use_checkpoint = use_checkpoint

        if self.use_dropout:
            self.dropout = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(500, 1)

    def forward(self, x):
        """
        Compute one logit per sample.

        Args:
            x (torch.Tensor): Batch of inputs for the ResNet. Its
                ``requires_grad`` flag is switched on in place.

        Returns:
            torch.Tensor: Logits flattened to 1-D.
        """
        x.requires_grad = True  # Ensure that the tensor has requires_grad set to True

        if self.use_checkpoint:
            x = checkpoint(self.resnet, x)  # Use checkpointing for just the ResNet part
        else:
            x = self.resnet(x)

        x = F.relu(self.fc1(x))

        if self.use_dropout:
            x = self.dropout(x)

        logits = self.fc2(x).flatten()
        return logits
1461
+
1462
def split_my_dataset(dataset, split_ratio=0.1):
    """
    Randomly split a dataset into training and validation subsets.

    The sample indices are shuffled with the ``random`` module; the last
    ``split_ratio`` share of the shuffled indices becomes the validation subset.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to be split.
        split_ratio (float, optional): The ratio of validation samples to total samples. Defaults to 0.1.

    Returns:
        tuple: ``(train_dataset, val_dataset)`` as ``torch.utils.data.Subset`` objects.
    """
    total = len(dataset)
    order = list(range(total))
    # Number of samples that go to the training subset.
    cut = int((1 - split_ratio) * total)
    random.shuffle(order)
    return Subset(dataset, order[:cut]), Subset(dataset, order[cut:])
1481
+
1482
def classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch):
    """
    Calculate classification metrics for binary classification.

    Predictions are thresholded at 0.5 for the accuracy figures. When both
    classes are present, the area under the precision-recall curve and the
    threshold with the highest F1 score are reported as well; otherwise the
    area is NaN and the threshold is 0.5.

    Parameters:
    - all_labels (list): List of true labels (0 or 1).
    - prediction_pos_probs (list): List of predicted positive probabilities.
    - loader_name (str): Name of the data loader.
    - loss (float or tensor): Loss value; either a plain number or an object
      with an ``item()`` method such as a 0-dim tensor.
    - epoch (int): Epoch number.

    Returns:
    - data_df (DataFrame): Single-row DataFrame indexed by '<epoch>_<loader_name>'
      with the columns accuracy, neg_accuracy, pos_accuracy, loss, prauc and
      optimal_threshold.

    Raises:
    - ValueError: If the labels and probabilities have different lengths.
    """

    if len(all_labels) != len(prediction_pos_probs):
        raise ValueError(f"all_labels ({len(all_labels)}) and pred_labels ({len(prediction_pos_probs)}) have different lengths")

    unique_labels = np.unique(all_labels)
    if len(unique_labels) >= 2:
        pr_labels = np.array(all_labels).astype(int)
        precision, recall, thresholds = precision_recall_curve(pr_labels, prediction_pos_probs, pos_label=1)
        pr_auc = auc(recall, precision)
        # precision/recall have one more entry than thresholds; pad so the indices line up.
        thresholds = np.append(thresholds, 0.0)
        # Points with precision + recall == 0 yield NaN, which nanargmax skips.
        with np.errstate(divide='ignore', invalid='ignore'):
            f1_scores = 2 * (precision * recall) / (precision + recall)
        optimal_idx = np.nanargmax(f1_scores)
        optimal_threshold = thresholds[optimal_idx]
    else:
        # The precision-recall curve is undefined with a single class.
        optimal_threshold = 0.5
        pr_auc = np.nan

    pred_labels = [int(p > 0.5) for p in prediction_pos_probs]

    data = {'label': all_labels, 'pred': pred_labels}
    df = pd.DataFrame(data)
    pc_df = df[df['label'] == 1.0]
    nc_df = df[df['label'] == 0.0]
    correct = df[df['label'] == df['pred']]

    def _accuracy(subset):
        """Share of rows in ``subset`` whose prediction matches the label (NaN if empty)."""
        if len(subset) == 0:
            return np.nan
        return len(subset[subset['label'] == subset['pred']]) / len(subset)

    acc_all = len(correct) / len(df) if len(df) > 0 else np.nan
    acc_pc = _accuracy(pc_df)
    acc_nc = _accuracy(nc_df)

    # Accept tensors as well as plain numbers for the loss.
    loss_value = loss.item() if hasattr(loss, 'item') else float(loss)

    data_dict = {'accuracy': acc_all, 'neg_accuracy': acc_nc, 'pos_accuracy': acc_pc, 'loss':loss_value,'prauc':pr_auc, 'optimal_threshold':optimal_threshold}
    data_df = pd.DataFrame(data_dict, index=[str(epoch)+'_'+loader_name])
    return data_df
1533
+
1534
+
1535
+
1536
def compute_irm_penalty(losses, dummy_w, device):
    """
    Compute a pairwise gradient-agreement penalty (used as the IRM penalty).

    Each loss is detached, moved to ``device`` and multiplied by ``dummy_w``;
    the gradient of that product with respect to ``dummy_w`` is taken, and the
    squared dot products of every pair of gradients are summed.

    Args:
        losses (list): Loss tensors, one per environment.
        dummy_w (torch.Tensor): Dummy weight tensor that requires grad.
        device (torch.device): Device the detached losses are moved to.

    Returns:
        The accumulated penalty: ``0.0`` (float) when fewer than two losses
        are given, otherwise a tensor.
    """
    gradients = []
    for loss in losses:
        scaled_loss = loss.clone().detach().requires_grad_(True).to(device) * dummy_w
        gradients.append(grad(scaled_loss, dummy_w, create_graph=True)[0])

    irm_penalty = 0.0
    for grad_a, grad_b in combinations(gradients, 2):
        irm_penalty += (grad_a.dot(grad_b))**2
    return irm_penalty
1554
+
1555
+ #def print_model_summary(base_model, channels, height, width):
1556
+ # """
1557
+ # Prints the summary of a given base model.
1558
+ #
1559
+ # Args:
1560
+ # base_model (torch.nn.Module): The base model to print the summary of.
1561
+ # channels (int): The number of input channels.
1562
+ # height (int): The height of the input.
1563
+ # width (int): The width of the input.
1564
+ #
1565
+ # Returns:
1566
+ # None
1567
+ # """
1568
+ # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
1569
+ # base_model.to(device)
1570
+ # summary(base_model, (channels, height, width))
1571
+ # return
1572
+
1573
def choose_model(model_type, device, init_weights=True, dropout_rate=0, use_checkpoint=False, channels=3, height=224, width=224, chan_dict=None, num_classes=2):
    """
    Choose a model for classification.

    Args:
        model_type (str): One of the TorchVision model names or 'custom'.
        device (str): Accepted for interface compatibility; not used here.
        init_weights (bool, optional): Passed to TorchModel as ``pretrained``. Defaults to True.
        dropout_rate (float, optional): Dropout rate forwarded to the model. Defaults to 0.
        use_checkpoint (bool, optional): Forwarded to the custom model. Defaults to False.
        channels (int, optional): Not used by this function. Defaults to 3.
        height (int, optional): Not used by this function. Defaults to 224.
        width (int, optional): Not used by this function. Defaults to 224.
        chan_dict (dict, optional): Must contain 'pathogen_channel' when
            ``model_type='custom'``. Defaults to None.
        num_classes (int, optional): Number of output classes for the custom model. Defaults to 2.

    Returns:
        torch.nn.Module: The chosen model, or None (after printing the valid
        names) when ``model_type`` is not recognised.

    Raises:
        ValueError: If ``model_type='custom'`` and ``chan_dict`` is None.
    """
    torch_model_types = torchvision.models.list_models(module=torchvision.models)
    model_types = torch_model_types + ['custom']

    if model_type not in model_types:
        print(f'Invalid model_type: {model_type}. Compatible model_types: {model_types}')
        return None

    print(f'\rModel parameters: Architecture: {model_type} init_weights: {init_weights} dropout_rate: {dropout_rate} use_checkpoint: {use_checkpoint}', end='\r', flush=True)

    if model_type == 'custom':
        # Previously this path hit a NameError on pathogen_channel when no
        # chan_dict was supplied; fail with an explicit message instead.
        if chan_dict is None:
            raise ValueError("chan_dict with a 'pathogen_channel' entry is required when model_type='custom'")
        base_model = CustomCellClassifier(num_classes, pathogen_channel=chan_dict['pathogen_channel'], use_attention=True, use_checkpoint=use_checkpoint, dropout_rate=dropout_rate)
    else:
        # model_type was validated above, so it is a TorchVision model name.
        base_model = TorchModel(model_name=model_type, pretrained=init_weights, dropout_rate=dropout_rate)

    print(base_model)

    return base_model
1620
+
1621
def calculate_loss(output, target, loss_type='binary_cross_entropy_with_logits'):
    """
    Compute the training loss for raw model outputs (logits).

    Args:
        output (torch.Tensor): Model logits.
        target (torch.Tensor): Target tensor.
        loss_type (str): 'binary_cross_entropy_with_logits' or 'focal_loss'.

    Returns:
        torch.Tensor: The loss.

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """
    if loss_type == 'binary_cross_entropy_with_logits':
        return F.binary_cross_entropy_with_logits(output, target)
    if loss_type == 'focal_loss':
        focal_loss_fn = FocalLossWithLogits(alpha=1, gamma=2)
        return focal_loss_fn(output, target)
    # Previously an unknown name fell through to an UnboundLocalError on `loss`.
    raise ValueError(f"Unsupported loss_type: {loss_type!r}. Expected 'binary_cross_entropy_with_logits' or 'focal_loss'.")
1628
+
1629
def pick_best_model(src):
    """
    Return the path of the best '.pth' checkpoint in a directory.

    File names are expected to contain '_epoch_<int>_acc_<number>'. Files are
    ranked by accuracy first and epoch second; names that do not match the
    pattern rank as accuracy 0.0 / epoch 0.

    Args:
        src (str): Directory containing the checkpoint files.

    Returns:
        str: ``src`` joined with the best checkpoint's file name.

    Raises:
        FileNotFoundError: If ``src`` contains no '.pth' file.
    """
    pattern = re.compile(r'_epoch_(\d+)_acc_(\d+(?:\.\d+)?)')

    def sort_key(name):
        match = pattern.search(name)
        if match is None:
            return (0.0, 0)  # float primary key keeps tuples comparable
        epoch, accuracy = match.groups()
        return (float(accuracy), int(epoch))

    pth_files = [f for f in os.listdir(src) if f.endswith('.pth')]
    if not pth_files:
        # Previously surfaced as an uninformative IndexError.
        raise FileNotFoundError(f"No .pth model files found in {src}")

    # max() returns the first maximal element, matching a stable descending sort.
    return os.path.join(src, max(pth_files, key=sort_key))
1644
+
1645
def get_paths_from_db(df, png_df, image_type='cell_png'):
    """
    Select the rows of ``png_df`` that belong to the objects indexed in ``df``.

    Args:
        df (DataFrame): Frame whose index holds the object identifiers to keep.
        png_df (DataFrame): Frame with 'png_path' and 'prcfo' columns.
        image_type (str): Pattern that 'png_path' must contain.

    Returns:
        DataFrame: Rows of ``png_df`` whose path contains ``image_type`` and
        whose 'prcfo' value is in the index of ``df``.
    """
    wanted_objects = df.index.tolist()
    path_matches = png_df['png_path'].str.contains(image_type)
    object_matches = png_df['prcfo'].isin(wanted_objects)
    return png_df[path_matches & object_matches]
1649
+
1650
def save_file_lists(dst, data_set, ls):
    """
    Write a list of values to '<dst>/<data_set>.csv' as a single-column CSV.

    Args:
        dst (str): Destination directory.
        data_set (str): Used as both the column header and the file stem.
        ls (list): Values to store, one per row.

    Returns:
        None
    """
    pd.DataFrame(ls, columns=[data_set]).to_csv(f'{dst}/{data_set}.csv', index=False)
1654
+
1655
def augment_single_image(args):
    """
    Write six variants of one image (original, three rotations, two flips).

    Args:
        args (tuple): ``(img_path, dst)`` - the source image path and the
            directory the PNG variants are written to. Output names are
            '<stem>_<variant>.png', where the stem is the base name up to its
            first dot.
    """
    img_path, dst = args
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    filename = os.path.basename(img_path).split('.')[0]

    # (file-name suffix, transformation) in the order the files are written.
    variants = (
        ('original', lambda im: im),
        ('rot_90', lambda im: cv2.rotate(im, cv2.ROTATE_90_CLOCKWISE)),
        ('rot_180', lambda im: cv2.rotate(im, cv2.ROTATE_180)),
        ('rot_270', lambda im: cv2.rotate(im, cv2.ROTATE_90_COUNTERCLOCKWISE)),
        ('flip_hor', lambda im: cv2.flip(im, 1)),
        ('flip_ver', lambda im: cv2.flip(im, 0)),
    )
    for suffix, transform in variants:
        variant_img = transform(img)
        cv2.imwrite(os.path.join(dst, f"{filename}_{suffix}.png"), variant_img)
1682
+
1683
def augment_images(file_paths, dst):
    """
    Augment every image in ``file_paths`` in parallel, writing results to ``dst``.

    Args:
        file_paths (list): Paths of the source images.
        dst (str): Output directory; created when it does not exist.
    """
    destination_missing = not os.path.exists(dst)
    if destination_missing:
        os.makedirs(dst)

    jobs = []
    for img_path in file_paths:
        jobs.append((img_path, dst))

    # One worker per CPU; each job handles a single source image.
    with Pool(cpu_count()) as workers:
        workers.map(augment_single_image, jobs)
1691
+
1692
def augment_classes(dst, nc, pc, generate=True,move=True):
    """
    Augment negative/positive control images and split them into train/test folders.

    Args:
        dst (str): Root output directory.
        nc (list): Paths of the negative-control images.
        pc (list): Paths of the positive-control images.
        generate (bool): When true, write augmented images to '<dst>/aug_nc'
            and '<dst>/aug_pc'.
        move (bool): When true, split the contents of those two folders 90/10
            (fixed random_state=42) and move them into
            '<dst>/aug/{train,test}/{nc,pc}'.

    Returns:
        None
    """
    aug_nc = os.path.join(dst,'aug_nc')
    aug_pc = os.path.join(dst,'aug_pc')

    if generate:
        # The augmentation used to sit behind `if __name__ == '__main__'`, which
        # is never true inside an imported module, so nothing was generated.
        os.makedirs(aug_nc, exist_ok=True)
        augment_images(file_paths=nc, dst=aug_nc)

        os.makedirs(aug_pc, exist_ok=True)
        augment_images(file_paths=pc, dst=aug_pc)

    if move:
        aug = os.path.join(dst,'aug')
        aug_train_nc = os.path.join(aug,'train/nc')
        aug_train_pc = os.path.join(aug,'train/pc')
        aug_test_nc = os.path.join(aug,'test/nc')
        aug_test_pc = os.path.join(aug,'test/pc')

        for folder in (aug_train_nc, aug_train_pc, aug_test_nc, aug_test_pc):
            os.makedirs(folder, exist_ok=True)

        aug_nc_list = [os.path.join(aug_nc, file) for file in os.listdir(aug_nc)]
        aug_pc_list = [os.path.join(aug_pc, file) for file in os.listdir(aug_pc)]

        nc_train_data, nc_test_data = train_test_split(aug_nc_list, test_size=0.1, shuffle=True, random_state=42)
        pc_train_data, pc_test_data = train_test_split(aug_pc_list, test_size=0.1, shuffle=True, random_state=42)

        moves = (
            (nc_train_data, aug_train_nc),
            (nc_test_data, aug_test_nc),
            (pc_train_data, aug_train_pc),
            (pc_test_data, aug_test_pc),
        )
        # Progress is reported against the number of files actually moved,
        # not the number of un-augmented inputs.
        total = sum(len(paths) for paths, _ in moves)
        moved = 0
        for paths, target_dir in moves:
            for path in paths:
                moved += 1
                shutil.move(path, os.path.join(target_dir, os.path.basename(path)))
                print(f'{moved}/{total}', end='\r', flush=True)

        # Each label is now paired with the folder it names.
        print(f'Train nc: {len(os.listdir(aug_train_nc))}, Train pc:{len(os.listdir(aug_train_pc))}, Test nc:{len(os.listdir(aug_test_nc))}, Test pc:{len(os.listdir(aug_test_pc))}')
    return
1742
+
1743
def annotate_predictions(csv_loc):
    """
    Load a prediction CSV and annotate each row with plate/well metadata.

    The CSV must have a 'path' column whose file names look like
    '<plate>_<well>_<field>_<object>.png'; the well is a letter followed by a
    column number (e.g. 'A05').

    Args:
        csv_loc (str): Path of the CSV file.

    Returns:
        DataFrame: The loaded frame with added 'filename', 'plate', 'well',
        'field', 'object' and 'cond' columns. 'cond' is 'nc' for columns 1-3,
        'screen' for plates 1-4 / 'pc' for plates 5-8 when the column is > 3,
        '' for columns below 1, and None for columns > 3 on any other plate.
    """
    df = pd.read_csv(csv_loc)
    df['filename'] = df['path'].apply(lambda x: x.split('/')[-1])
    df[['plate', 'well', 'field', 'object']] = df['filename'].str.split('_', expand=True)
    # regex=False: '.png' must be a literal suffix; as a regex the dot would
    # match any character on pandas versions where regex is the default.
    df['object'] = df['object'].str.replace('.png', '', regex=False)

    def assign_condition(row):
        plate = int(row['plate'])
        col = int(row['well'][1:])

        if col > 3:
            if plate in [1, 2, 3, 4]:
                return 'screen'
            elif plate in [5, 6, 7, 8]:
                return 'pc'
        elif col in [1, 2, 3]:
            return 'nc'
        else:
            return ''

    df['cond'] = df.apply(assign_condition, axis=1)
    return df
1765
+
1766
def init_globals(counter_, lock_):
    """
    Publish a shared counter and lock as the module-level globals
    ``counter`` and ``lock``.

    Args:
        counter_: Object stored as the module global ``counter``.
        lock_: Object stored as the module global ``lock``.
    """
    globals().update(counter=counter_, lock=lock_)
1770
+
1771
def add_images_to_tar(args):
    """
    Write a chunk of image files into a new tar archive.

    Reads the module-level globals ``counter``, ``lock`` and ``total_images``
    for progress reporting; they must be defined before this is called.

    Args:
        args (tuple): ``(paths_chunk, tar_path)`` - the files to add and the
            archive to create. Each file is stored under its base name.

    Returns:
        str: ``tar_path``.
    """
    paths_chunk, tar_path = args
    with tarfile.open(tar_path, 'w') as archive:
        for source_path in paths_chunk:
            member_name = os.path.basename(source_path)
            try:
                archive.add(source_path, arcname=member_name)
                # Count and report under the lock so updates are not interleaved.
                with lock:
                    counter.value += 1
                    print(f"\rProcessed: {counter.value}/{total_images}", end='', flush=True)
            except FileNotFoundError:
                # Missing files are reported and skipped.
                print(f"File not found: {source_path}")
    return tar_path
1785
+
1786
def generate_fraction_map(df, gene_column, min_frequency=0.0, loc='/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/iv.csv'):
    """
    Build a well x gene matrix of read fractions and save it as CSV.

    A 'fraction' column (count / well_read_sum) is added to ``df`` in place.

    Args:
        df (DataFrame): Needs 'count', 'well_read_sum', 'prc' and ``gene_column``.
        gene_column (str): Column holding the gene (or gRNA) identifier.
        min_frequency (float): Columns whose maximum fraction is below this
            value are dropped.
        loc (str or None): Path the matrix is written to. Defaults to the path
            that used to be hard-coded; pass None to skip writing.

    Returns:
        DataFrame: Fractions indexed by 'prc' with one column per gene;
        missing combinations are 0.0.
    """
    df['fraction'] = df['count']/df['well_read_sum']
    genes = df[gene_column].unique().tolist()
    wells = df['prc'].unique().tolist()
    print(len(genes),len(wells))

    # Float dtype up front so the later fillna / max comparisons are numeric.
    independent_variables = pd.DataFrame(columns=genes, index=wells, dtype=float)
    for prc, gene, fraction in zip(df['prc'], df[gene_column], df['fraction']):
        # Later rows overwrite earlier ones for a repeated (well, gene) pair.
        independent_variables.loc[prc, gene] = fraction

    independent_variables = independent_variables.dropna(axis=1, how='all')
    independent_variables = independent_variables.dropna(axis=0, how='all')
    independent_variables = independent_variables.fillna(0.0)

    rare_columns = [col for col in independent_variables.columns if independent_variables[col].max() < min_frequency]
    independent_variables = independent_variables.drop(columns=rare_columns)
    independent_variables.index.name = 'prc'

    if loc is not None:
        independent_variables.to_csv(loc, index=True, header=True, mode='w')
    return independent_variables
1810
+
1811
def fishers_odds(df, threshold=0.5, phenotyp_col='mean_pred'):
    """
    Run Fisher's exact test per mutant column against a binarised phenotype.

    A boolean 'high_phenotype' column (``df[phenotyp_col] < threshold``) is
    added to ``df`` in place. Mutant columns are all columns except the last
    two, minus 'count_prc' and 'mean_pathogen_area'.

    Args:
        df (DataFrame): One row per well with per-mutant columns followed by
            the phenotype column.
        threshold (float): Rows with a phenotype value below this are flagged.
        phenotyp_col (str): Name of the phenotype column.

    Returns:
        DataFrame: Columns 'Mutant', 'OddsRatio', 'PValue' for mutants whose
        contingency table was 2x2, plus 'AdjustedPValue' (Benjamini-Hochberg)
        when at least one p-value was available.
    """
    # Rows whose phenotype score is BELOW the threshold are flagged.
    df['high_phenotype'] = df[phenotyp_col] < threshold

    results = []
    mutants = df.columns[:-2]
    mutants = [item for item in mutants if item not in ['count_prc','mean_pathogen_area']]
    print(f'fishers df')
    display(df)

    # Perform Fisher's exact test for each mutant
    for mutant in mutants:
        contingency_table = pd.crosstab(df[mutant] > 0, df['high_phenotype'])
        if contingency_table.shape == (2, 2):
            odds_ratio, p_value = fisher_exact(contingency_table)
            results.append((mutant, odds_ratio, p_value))
        else:
            # Non-2x2 tables cannot be tested; recorded as NaN and dropped below.
            results.append((mutant, float('nan'), float('nan')))

    results_df = pd.DataFrame(results, columns=['Mutant', 'OddsRatio', 'PValue'])
    # .copy() so the column assignment below writes to an independent frame
    # rather than to a possible view of results_df.
    filtered_results_df = results_df.dropna(subset=['OddsRatio', 'PValue']).copy()

    pvalues = filtered_results_df['PValue'].values

    if len(pvalues) > 0:
        # Benjamini-Hochberg correction
        filtered_results_df['AdjustedPValue'] = multipletests(pvalues, method='fdr_bh')[1]
    else:
        print("No p-values to adjust. Check your data filtering steps.")

    return filtered_results_df
1850
+
1851
def model_metrics(model):
    """
    Print residual-based metrics for a fitted model and show four diagnostic plots.

    Prints RMSE (from ``model.mse_resid``), MAE and the Durbin-Watson
    statistic, then draws residuals-vs-fitted, a residual histogram, a QQ plot
    and a scale-location plot.

    Args:
        model: Fitted regression results object exposing ``mse_resid``,
            ``resid``, ``fittedvalues`` and ``get_influence()``.

    Returns:
        None
    """
    # Calculate additional metrics
    rmse = np.sqrt(model.mse_resid)
    mae = np.mean(np.abs(model.resid))
    durbin_w_value = durbin_watson(model.resid)

    # Display the additional metrics
    print("\nAdditional Metrics:")
    print(f"Root Mean Squared Error (RMSE): {rmse}")
    print(f"Mean Absolute Error (MAE): {mae}")
    print(f"Durbin-Watson: {durbin_w_value}")

    # Residual Plots
    fig, ax = plt.subplots(2, 2, figsize=(15, 12))

    # Residual vs. Fitted
    ax[0, 0].scatter(model.fittedvalues, model.resid, edgecolors = 'k', facecolors = 'none')
    ax[0, 0].set_title('Residuals vs Fitted')
    ax[0, 0].set_xlabel('Fitted values')
    ax[0, 0].set_ylabel('Residuals')

    # Histogram
    sns.histplot(model.resid, kde=True, ax=ax[0, 1])
    ax[0, 1].set_title('Histogram of Residuals')
    ax[0, 1].set_xlabel('Residuals')

    # QQ Plot
    sm.qqplot(model.resid, fit=True, line='45', ax=ax[1, 0])
    ax[1, 0].set_title('QQ Plot')

    # Scale-Location
    standardized_resid = model.get_influence().resid_studentized_internal
    ax[1, 1].scatter(model.fittedvalues, np.sqrt(np.abs(standardized_resid)), edgecolors = 'k', facecolors = 'none')
    ax[1, 1].set_title('Scale-Location')
    ax[1, 1].set_xlabel('Fitted values')
    # Raw string: '\s' is an invalid escape in a normal literal; the text is unchanged.
    ax[1, 1].set_ylabel(r'$\sqrt{|Standardized Residuals|}$')

    plt.tight_layout()
    plt.show()
1891
+
1892
def check_multicollinearity(x):
    """
    Compute the variance inflation factor (VIF) of every predictor column.

    Args:
        x (DataFrame): Predictor matrix, one column per variable.

    Returns:
        DataFrame: Columns 'Variable' (column name) and 'VIF'.
    """
    design = x.values
    vif_values = []
    for column_idx in range(x.shape[1]):
        vif_values.append(variance_inflation_factor(design, column_idx))
    return pd.DataFrame({"Variable": x.columns, "VIF": vif_values})
1898
+
1899
def generate_dependent_variable(df, dv_loc, pc_min=0.95, nc_max=0.05, agg_type='mean'):
    """
    Aggregate per-cell predictions to per-well values and save both levels.

    NOTE: the 'pred' column of the caller's ``df`` is inverted in place
    (``1 - pred``) before filtering.

    Args:
        df (DataFrame): Per-cell data with 'plate', 'pred', 'recruitment' and
            'pathogen_area' columns, plus either 'prc' or 'row' and 'col'.
        dv_loc (str): Directory that receives 'dv_cell.csv' and 'dv_well.csv'.
        pc_min (float): Cells with inverted pred >= this value are kept.
        nc_max (float): Cells with inverted pred <= this value are kept.
        agg_type (str): Aggregation name passed to pandas, or 'q<NN>' for the
            NN-th percentile (e.g. 'q75').

    Returns:
        DataFrame: Per-well aggregates ('pred', 'recruitment', 'count_prc',
        'mean_pathogen_area') sorted by 'count_prc' ascending.
    """
    from .plot import _plot_histograms_and_stats, _plot_plates

    def qstring_to_float(qstr):
        number = int(qstr[1:])  # Remove the "q" and convert the rest to an integer
        return number / 100.0

    print("Unique values in plate:", df['plate'].unique())
    dv_cell_loc = f'{dv_loc}/dv_cell.csv'
    dv_well_loc = f'{dv_loc}/dv_well.csv'

    df['pred'] = 1-df['pred'] # if pc and nc were switched
    # .copy(): the filtered frame gets a new column below; without the copy
    # that assignment targets a slice of the caller's frame.
    df = df[(df['pred'] <= nc_max) | (df['pred'] >= pc_min)].copy()

    if 'prc' not in df.columns:
        df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']

    if agg_type.startswith('q'):
        val = qstring_to_float(agg_type)
        agg_type = lambda x: x.quantile(val)

    # Per-well aggregation of prediction/recruitment plus cell count
    df_grouped = df.groupby('prc').agg(
        pred=('pred', agg_type),
        recruitment=('recruitment', agg_type),
        count_prc=('prc', 'size'),
        mean_pathogen_area=('pathogen_area', 'mean')
    )

    df_cell = df[['prc', 'pred', 'pathogen_area', 'recruitment']]

    df_cell.to_csv(dv_cell_loc, index=True, header=True, mode='w')
    df_grouped.to_csv(dv_well_loc, index=True, header=True, mode='w')
    display(df)
    _plot_histograms_and_stats(df)
    df_grouped = df_grouped.sort_values(by='count_prc', ascending=True)
    display(df_grouped)
    print('pred')
    _plot_plates(df=df_cell, variable='pred', grouping='mean', min_max='allq', cmap='viridis')
    print('recruitment')
    _plot_plates(df=df_cell, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis')

    return df_grouped
1944
+
1945
def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
    """
    Fit a Lasso or Ridge regression of 'pred' on one-hot encoded design factors.

    Args:
        merged_df (DataFrame): Needs 'gene', 'grna', 'plate', 'row', 'column'
            and 'pred' columns.
        alpha_value (float): Regularisation strength.
        reg_type (str): 'lasso' or 'ridge'.

    Returns:
        DataFrame: Columns 'Feature' and 'Coefficient', one row per encoded feature.

    Raises:
        ValueError: If ``reg_type`` is neither 'lasso' nor 'ridge'.
    """
    # Validate first: an unknown type used to fail late with UnboundLocalError.
    if reg_type not in ('lasso', 'ridge'):
        raise ValueError(f"Unsupported reg_type: {reg_type!r}. Expected 'lasso' or 'ridge'.")

    # Separate predictors and response
    X = merged_df[['gene', 'grna', 'plate', 'row', 'column']]
    y = merged_df['pred']

    # One-hot encode the categorical predictors
    encoder = OneHotEncoder(drop='first')  # drop one category to avoid the dummy variable trap
    X_encoded = encoder.fit_transform(X).toarray()
    feature_names = encoder.get_feature_names_out(input_features=X.columns)

    regressor = Ridge(alpha=alpha_value) if reg_type == 'ridge' else Lasso(alpha=alpha_value)
    regressor.fit(X_encoded, y)

    coeff_dict = dict(zip(feature_names, regressor.coef_))
    coeff_df = pd.DataFrame(list(coeff_dict.items()), columns=['Feature', 'Coefficient'])
    return coeff_df
1970
+
1971
def MLR(merged_df, refine_model):
    """
    Fit an OLS model of 'pred' and extract the strongest gene:grna interaction per gene.

    Args:
        merged_df (DataFrame): Needs 'pred', 'gene', 'grna', 'plate', 'row'
            and 'column' columns.
        refine_model (bool): When true, drop outliers (|studentised residual|
            > 3 or Cook's distance above 4/(n - k - 1)) and refit.

    Returns:
        tuple: ``(max_effects, max_effects_pvalues, model, df)`` - per-gene
        largest-magnitude interaction coefficient, its p-value, the final
        fitted model, and a frame with 'effect' / 'p' columns sorted by
        effect (descending) then p (ascending).
    """
    from .plot import _reg_v_plot

    model = smf.ols("pred ~ gene:grna + plate + row + column", merged_df).fit()
    # Display model metrics
    model_metrics(model)

    if refine_model:
        # Flag outliers by studentised residuals and by Cook's distance.
        influence = model.get_influence()
        outliers_resid = np.where(np.abs(influence.resid_studentized_internal) > 3)[0]
        cooks_d, _ = influence.cooks_distance
        cooks_cutoff = 4/(len(merged_df)-merged_df.shape[1]-1)
        outliers_cooks = np.where(cooks_d > cooks_cutoff)[0]
        outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
        merged_df_filtered = merged_df.drop(merged_df.index[outliers])

        display(merged_df_filtered)

        # Refit on the filtered data
        model = smf.ols("pred ~ gene + grna + gene:grna + row + column", merged_df_filtered).fit()
        print("Number of outliers detected by standardized residuals:", len(outliers_resid))
        print("Number of outliers detected by Cook's distance:", len(outliers_cooks))

        # NOTE(review): these two calls are assumed to belong to the refine
        # branch - confirm against the original indentation.
        model_metrics(model)
        print(model.summary())

    def _is_interaction(term):
        return "gene[T." in term and ":grna[T." in term

    interaction_coeffs = {term: coef for term, coef in model.params.items() if _is_interaction(term)}
    interaction_pvalues = {term: pval for term, pval in model.pvalues.items() if _is_interaction(term)}

    # Keep, per gene, the interaction with the largest absolute coefficient.
    max_effects = {}
    max_effects_pvalues = {}
    for term, coef in interaction_coeffs.items():
        gene_name = term.split(":")[0].replace("gene[T.", "").replace("]", "")
        if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(coef):
            max_effects[gene_name] = coef
            max_effects_pvalues[gene_name] = interaction_pvalues[term]

    for gene_name in max_effects:
        print(f"Key: {gene_name}: {max_effects[gene_name]}, p:{max_effects_pvalues[gene_name]}")

    df = pd.DataFrame([max_effects, max_effects_pvalues]).transpose()
    df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
    df = df.sort_values(by=['effect', 'p'], ascending=[False, True])

    _reg_v_plot(df)

    return max_effects, max_effects_pvalues, model, df
2022
+
2023
+ #def normalize_to_dtype(array, q1=2, q2=98, percentiles=None):
2024
+ # if len(array.shape) == 2:
2025
+ # array = np.expand_dims(array, axis=-1)
2026
+ # num_channels = array.shape[-1]
2027
+ # new_stack = np.empty_like(array)
2028
+ # for channel in range(num_channels):
2029
+ # img = array[..., channel]
2030
+ # non_zero_img = img[img > 0]
2031
+ # if non_zero_img.size > 0:
2032
+ # img_min = np.percentile(non_zero_img, q1)
2033
+ # img_max = np.percentile(non_zero_img, q2)
2034
+ # else:
2035
+ # img_min, img_max = (percentiles[channel] if percentiles and channel < len(percentiles)
2036
+ # else (img.min(), img.max()))
2037
+ # new_stack[..., channel] = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
2038
+ # if new_stack.shape[-1] == 1:
2039
+ # new_stack = np.squeeze(new_stack, axis=-1)
2040
+ # return new_stack
2041
+
2042
def get_files_from_dir(dir_path, file_extension="*"):
    """
    List the paths in ``dir_path`` that match a glob pattern.

    Args:
        dir_path (str): Directory to search.
        file_extension (str): Glob pattern joined onto ``dir_path``
            (e.g. '*.tif'). Defaults to '*'.

    Returns:
        list: Matching paths, in the order returned by ``glob``.
    """
    # Imported locally under a private name: the file header binds `glob` to
    # the module (`import glob`), which is not callable.
    from glob import glob as _glob
    return _glob(os.path.join(dir_path, file_extension))
2044
+
2045
def create_circular_mask(h, w, center=None, radius=None):
    """
    Build a boolean mask that is True inside a circle.

    Args:
        h (int): Mask height (rows).
        w (int): Mask width (columns).
        center (tuple, optional): ``(x, y)`` centre; defaults to the image middle.
        radius (number, optional): Circle radius; defaults to the distance from
            the centre to the nearest image edge.

    Returns:
        ndarray: ``(h, w)`` boolean array, True where the distance to the
        centre is <= radius.
    """
    if center is None:
        center = (int(w/2), int(h/2))
    center_x, center_y = center[0], center[1]

    if radius is None:
        radius = min(center_x, center_y, w-center_x, h-center_y)

    rows, cols = np.ogrid[:h, :w]
    distance = np.sqrt((cols - center_x)**2 + (rows-center_y)**2)
    return distance <= radius
2056
+
2057
def apply_mask(image, output_value=0):
    """
    Keep the central circular region of an image and overwrite the rest.

    Args:
        image (ndarray): 2-D image or 3-D image with channels last.
        output_value: Value written to pixels outside the circle. Defaults to 0.

    Returns:
        ndarray: Copy of ``image`` with outside pixels set to ``output_value``.
    """
    height, width = image.shape[:2]
    circle = create_circular_mask(height, width)

    # Replicate the 2-D mask across channels for multi-channel images.
    if image.ndim > 2:
        circle = np.tile(circle[:, :, np.newaxis], (1, 1, image.shape[2]))

    return np.where(circle, image, output_value)
2068
+
2069
def invert_image(image):
    """
    Invert an integer image around the maximum value of its dtype.

    Args:
        image (ndarray): Integer-typed image (``np.iinfo`` is used, so float
            dtypes are not supported).

    Returns:
        ndarray: ``dtype_max - image`` (e.g. ``255 - image`` for uint8).
    """
    return np.iinfo(image.dtype).max - image
2074
+
2075
def _resize_single_image(image, target_height, target_width):
    """Resize one 2-D or channels-last 3-D image, keeping its dtype and value range."""
    if image.ndim == 2:
        out_shape = (target_height, target_width)
    elif image.ndim == 3:
        out_shape = (target_height, target_width, image.shape[-1])
    else:
        # Previously this reused the previous image's shape (or failed with a NameError).
        raise ValueError(f"Expected a 2D or 3D image, got {image.ndim} dimensions")

    resized = resizescikit(image, out_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)
    if resized.shape[-1] == 1:
        resized = np.squeeze(resized)
    return resized


def _resize_single_label(label, target_height, target_width):
    """Resize one label mask with nearest-neighbour sampling so label ids are preserved."""
    return resizescikit(label, (target_height, target_width), order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)


def resize_images_and_labels(images, labels, target_height, target_width, show_example=True):
    """
    Resize images and/or label masks to a common height and width.

    Args:
        images (list or None): 2-D or channels-last 3-D arrays.
        labels (list or None): Label masks. When both lists are given they are
            processed pairwise (truncating to the shorter list).
        target_height (int): Output height.
        target_width (int): Output width.
        show_example (bool): When true, call ``plot_resize`` on the results.

    Returns:
        tuple: ``(resized_images, resized_labels)``; a list is empty when the
        corresponding input is None.

    Raises:
        ValueError: If an image is neither 2-D nor 3-D.
    """
    have_images = images is not None
    have_labels = labels is not None

    resized_images = []
    resized_labels = []

    if have_images and have_labels:
        for image, label in zip(images, labels):
            resized_images.append(_resize_single_image(image, target_height, target_width))
            resized_labels.append(_resize_single_label(label, target_height, target_width))
    elif have_images:
        for image in images:
            resized_images.append(_resize_single_image(image, target_height, target_width))
    elif have_labels:
        for label in labels:
            resized_labels.append(_resize_single_label(label, target_height, target_width))

    if show_example:
        # Imported lazily so the plotting module is only needed when used.
        from .plot import plot_resize

        if have_images and have_labels:
            plot_resize(images, resized_images, labels, resized_labels)
        elif have_images:
            plot_resize(images, resized_images, images, resized_images)
        elif have_labels:
            plot_resize(labels, resized_labels, labels, resized_labels)

    return resized_images, resized_labels
2127
+
2128
def resize_labels_back(labels, orig_dims):
    """
    Resize label masks to their recorded original dimensions.

    Args:
        labels (list): Label masks to resize.
        orig_dims (list): One 2-tuple per mask, passed to the resize call as
            the output shape.
            NOTE(review): the error message calls this (width, height) while
            the resize call treats it as the output shape - confirm the order
            callers use.

    Returns:
        list: Resized masks (nearest-neighbour, dtype preserved).

    Raises:
        ValueError: If the lists differ in length or an entry of ``orig_dims``
            is not a 2-tuple.
    """
    if len(labels) != len(orig_dims):
        raise ValueError("The length of labels and orig_dims must match.")

    restored = []
    for mask, dims in zip(labels, orig_dims):
        dims_ok = isinstance(dims, tuple) and len(dims) == 2
        if not dims_ok:
            raise ValueError("Each element in orig_dims must be a tuple of two integers representing the original dimensions (width, height)")

        resized_mask = resizescikit(mask, dims, order=0, preserve_range=True, anti_aliasing=False)
        restored.append(resized_mask.astype(mask.dtype))

    return restored
2143
+
2144
def calculate_iou(mask1, mask2):
    """
    Intersection-over-union of two masks, after zero-padding them to one shape.

    Args:
        mask1 (ndarray): First mask.
        mask2 (ndarray): Second mask.

    Returns:
        The IoU, or 0 when the union is empty.
    """
    padded1, padded2 = pad_to_same_shape(mask1, mask2)
    intersection = np.logical_and(padded1, padded2).sum()
    union = np.logical_or(padded1, padded2).sum()
    if union != 0:
        return intersection / union
    return 0
2149
+
2150
def match_masks(true_masks, pred_masks, iou_threshold):
    """
    Greedily pair predicted masks with ground-truth masks by IoU.

    Each predicted mask takes the first not-yet-matched true mask whose IoU is
    at least ``iou_threshold``; every true mask is used at most once.

    Args:
        true_masks (list): Ground-truth object masks.
        pred_masks (list): Predicted object masks.
        iou_threshold (float): Minimum IoU for a match.

    Returns:
        list: ``(true_mask, pred_mask)`` tuples.
    """
    matches = []
    used_true_indices = set()

    for pred_mask in pred_masks:
        for idx, true_mask in enumerate(true_masks):
            if idx in used_true_indices:
                continue
            if not (calculate_iou(true_mask, pred_mask) >= iou_threshold):
                continue
            matches.append((true_mask, pred_mask))
            used_true_indices.add(idx)
            break  # this prediction is matched; go to the next one
    return matches
2163
+
2164
def compute_average_precision(matches, num_true_masks, num_pred_masks):
    """
    Compute precision and recall from matched mask pairs.

    Args:
        matches (list): Matched (true, pred) pairs; its length is the TP count.
        num_true_masks (int): Number of ground-truth masks.
        num_pred_masks (int): Number of predicted masks.

    Returns:
        tuple: ``(precision, recall)``; each is 0 when its denominator is 0.
    """
    true_pos = len(matches)
    false_pos = num_pred_masks - true_pos
    false_neg = num_true_masks - true_pos

    precision = 0
    if true_pos + false_pos > 0:
        precision = true_pos / (true_pos + false_pos)

    recall = 0
    if true_pos + false_neg > 0:
        recall = true_pos / (true_pos + false_neg)

    return precision, recall
2171
+
2172
def pad_to_same_shape(mask1, mask2):
    """
    Zero-pad two 2-D masks (bottom/right) so both have the same shape.

    Args:
        mask1 (ndarray): First 2-D mask.
        mask2 (ndarray): Second 2-D mask.

    Returns:
        tuple: ``(padded_mask1, padded_mask2)``, each shaped to the
        element-wise maximum of the two input shapes.
    """
    target_rows = max(mask1.shape[0], mask2.shape[0])
    target_cols = max(mask1.shape[1], mask2.shape[1])

    def _pad(mask):
        widths = ((0, target_rows - mask.shape[0]), (0, target_cols - mask.shape[1]))
        return np.pad(mask, widths, mode='constant', constant_values=0)

    return _pad(mask1), _pad(mask2)
2185
+
2186
def compute_ap_over_iou_thresholds(true_masks, pred_masks, iou_thresholds):
    """
    Integrate precision over recall across a range of IoU thresholds.

    For each threshold the masks are matched and a (precision, recall) pair is
    computed; the pairs are sorted by recall and integrated with the
    trapezoidal rule.

    Args:
        true_masks (list): Ground-truth object masks.
        pred_masks (list): Predicted object masks.
        iou_thresholds (iterable): IoU thresholds to evaluate.

    Returns:
        float: Area under the precision-recall points.

    Raises:
        ValueError: If a computed precision or recall falls outside [0, 1].
    """
    precision_recall_pairs = []
    for iou_threshold in iou_thresholds:
        matches = match_masks(true_masks, pred_masks, iou_threshold)
        precision, recall = compute_average_precision(matches, len(true_masks), len(pred_masks))
        # Check that precision and recall are within the range [0, 1]
        if not 0 <= precision <= 1 or not 0 <= recall <= 1:
            raise ValueError(f'Precision or recall out of bounds. Precision: {precision}, Recall: {recall}')
        precision_recall_pairs.append((precision, recall))

    # Sort by recall values
    precision_recall_pairs.sort(key=lambda pair: pair[1])
    sorted_precisions = [pair[0] for pair in precision_recall_pairs]
    sorted_recalls = [pair[1] for pair in precision_recall_pairs]

    # np.trapz is deprecated in NumPy 2.0 in favour of np.trapezoid; use the
    # new name when available and fall back on older NumPy.
    integrate = getattr(np, 'trapezoid', None) or np.trapz
    return integrate(sorted_precisions, x=sorted_recalls)
2201
+
2202
def compute_segmentation_ap(true_masks, pred_masks, iou_thresholds=np.linspace(0.5, 0.95, 10)):
    """
    Average precision of a predicted segmentation against the ground truth.

    Both inputs are labelled into connected components; the per-object
    bounding-box images are then compared over ``iou_thresholds``.

    Args:
        true_masks (ndarray): Ground-truth mask image.
        pred_masks (ndarray): Predicted mask image.
        iou_thresholds (array-like): IoU thresholds; defaults to 10 values
            from 0.5 to 0.95.

    Returns:
        float: Result of ``compute_ap_over_iou_thresholds``.
    """
    def _object_images(mask_image):
        # One cropped boolean image per connected component.
        return [region.image for region in regionprops(label(mask_image))]

    true_regions = _object_images(true_masks)
    pred_regions = _object_images(pred_masks)
    return compute_ap_over_iou_thresholds(true_regions, pred_regions, iou_thresholds)
2208
+
2209
def jaccard_index(mask1, mask2):
    """
    Jaccard index (intersection over union) of two masks.

    Args:
        mask1 (ndarray): First mask; non-zero entries count as foreground.
        mask2 (ndarray): Second mask, same shape.

    Returns:
        The ratio of intersecting to united foreground pixels. No guard is
        applied for an empty union (0 / 0).
    """
    overlap = np.logical_and(mask1, mask2).sum()
    combined = np.logical_or(mask1, mask2).sum()
    return overlap / combined
2213
+
2214
def dice_coefficient(mask1, mask2):
    """
    Dice coefficient of two masks, treating every value > 0 as foreground.

    Args:
        mask1 (ndarray): First mask.
        mask2 (ndarray): Second mask, same shape.

    Returns:
        ``2 * |A & B| / (|A| + |B|)``, or 1.0 when both masks are empty.
    """
    binary1, binary2 = (np.where(mask > 0, 1, 0) for mask in (mask1, mask2))

    overlap = (binary1 & binary2).sum()
    total = binary1.sum() + binary2.sum()

    # Two empty masks are considered identical.
    if total == 0:
        return 1.0

    return 2.0 * overlap / total
2229
+
2230
def extract_boundaries(mask, dilation_radius=1):
    """
    Extract a boundary band around the foreground of a mask.

    The band is the XOR of the dilated and the eroded binary mask, both using
    a square structuring element of side ``2 * dilation_radius + 1``.

    Args:
        mask (ndarray): Mask; values > 0 are foreground.
        dilation_radius (int): Half-width of the structuring element.

    Returns:
        ndarray: Boolean array, True on the boundary band.
    """
    # Imported locally and called with `structure=`: the scipy.ndimage
    # functions imported in the file header do not accept `footprint=`.
    from scipy.ndimage import binary_dilation as _binary_dilation
    from scipy.ndimage import binary_erosion as _binary_erosion

    binary_mask = (mask > 0).astype(np.uint8)
    side = dilation_radius*2+1
    struct_elem = np.ones((side, side))
    dilated = _binary_dilation(binary_mask, structure=struct_elem)
    eroded = _binary_erosion(binary_mask, structure=struct_elem)
    return dilated ^ eroded
2237
+
2238
def boundary_f1_score(mask_true, mask_pred, dilation_radius=1):
    """
    F1 score between the boundary bands of two masks.

    Args:
        mask_true (ndarray): Ground-truth mask.
        mask_pred (ndarray): Predicted mask.
        dilation_radius (int): Passed to ``extract_boundaries``.

    Returns:
        Harmonic mean of boundary precision and recall; a 1e-6 term in every
        denominator avoids division by zero.
    """
    edges_true = extract_boundaries(mask_true, dilation_radius)
    edges_pred = extract_boundaries(mask_pred, dilation_radius)

    # Boundary pixels present in both masks
    shared = np.logical_and(edges_true, edges_pred)

    precision = np.sum(shared) / (np.sum(edges_pred) + 1e-6)
    recall = np.sum(shared) / (np.sum(edges_true) + 1e-6)

    return 2 * (precision * recall) / (precision + recall + 1e-6)
2254
+
2255
+
2256
+
2257
+ def _remove_noninfected(stack, cell_dim, nucleus_dim, pathogen_dim):
2258
+ """
2259
+ Remove non-infected cells from the stack based on the provided dimensions.
2260
+
2261
+ Args:
2262
+ stack (ndarray): The stack of images.
2263
+ cell_dim (int or None): The dimension index for the cell mask. If None, a zero-filled mask will be used.
2264
+ nucleus_dim (int or None): The dimension index for the nucleus mask. If None, a zero-filled mask will be used.
2265
+ pathogen_dim (int or None): The dimension index for the pathogen mask. If None, a zero-filled mask will be used.
2266
+
2267
+ Returns:
2268
+ ndarray: The updated stack with non-infected cells removed.
2269
+ """
2270
+ if not cell_dim is None:
2271
+ cell_mask = stack[:, :, cell_dim]
2272
+ else:
2273
+ cell_mask = np.zeros_like(stack)
2274
+ if not nucleus_dim is None:
2275
+ nucleus_mask = stack[:, :, nucleus_dim]
2276
+ else:
2277
+ nucleus_mask = np.zeros_like(stack)
2278
+
2279
+ if not pathogen_dim is None:
2280
+ pathogen_mask = stack[:, :, pathogen_dim]
2281
+ else:
2282
+ pathogen_mask = np.zeros_like(stack)
2283
+
2284
+ for cell_label in np.unique(cell_mask)[1:]:
2285
+ cell_region = cell_mask == cell_label
2286
+ labels_in_cell = np.unique(pathogen_mask[cell_region])
2287
+ if len(labels_in_cell) <= 1:
2288
+ cell_mask[cell_region] = 0
2289
+ nucleus_mask[cell_region] = 0
2290
+ if not cell_dim is None:
2291
+ stack[:, :, cell_dim] = cell_mask
2292
+ if not nucleus_dim is None:
2293
+ stack[:, :, nucleus_dim] = nucleus_mask
2294
+ return stack
2295
+
2296
+ def _remove_outside_objects(stack, cell_dim, nucleus_dim, pathogen_dim):
2297
+ """
2298
+ Remove outside objects from the stack based on the provided dimensions.
2299
+
2300
+ Args:
2301
+ stack (ndarray): The stack of images.
2302
+ cell_dim (int): The dimension index of the cell mask in the stack.
2303
+ nucleus_dim (int): The dimension index of the nucleus mask in the stack.
2304
+ pathogen_dim (int): The dimension index of the pathogen mask in the stack.
2305
+
2306
+ Returns:
2307
+ ndarray: The updated stack with outside objects removed.
2308
+ """
2309
+ if not cell_dim is None:
2310
+ cell_mask = stack[:, :, cell_dim]
2311
+ else:
2312
+ return stack
2313
+ nucleus_mask = stack[:, :, nucleus_dim]
2314
+ pathogen_mask = stack[:, :, pathogen_dim]
2315
+ pathogen_labels = np.unique(pathogen_mask)[1:]
2316
+ for pathogen_label in pathogen_labels:
2317
+ pathogen_region = pathogen_mask == pathogen_label
2318
+ cell_in_pathogen_region = np.unique(cell_mask[pathogen_region])
2319
+ cell_in_pathogen_region = cell_in_pathogen_region[cell_in_pathogen_region != 0] # Exclude background
2320
+ if len(cell_in_pathogen_region) == 0:
2321
+ pathogen_mask[pathogen_region] = 0
2322
+ corresponding_nucleus_region = nucleus_mask == pathogen_label
2323
+ nucleus_mask[corresponding_nucleus_region] = 0
2324
+ stack[:, :, cell_dim] = cell_mask
2325
+ stack[:, :, nucleus_dim] = nucleus_mask
2326
+ stack[:, :, pathogen_dim] = pathogen_mask
2327
+ return stack
2328
+
2329
+ def _remove_multiobject_cells(stack, mask_dim, cell_dim, nucleus_dim, pathogen_dim, object_dim):
2330
+ """
2331
+ Remove multi-object cells from the stack.
2332
+
2333
+ Args:
2334
+ stack (ndarray): The stack of images.
2335
+ mask_dim (int): The dimension of the mask in the stack.
2336
+ cell_dim (int): The dimension of the cell in the stack.
2337
+ nucleus_dim (int): The dimension of the nucleus in the stack.
2338
+ pathogen_dim (int): The dimension of the pathogen in the stack.
2339
+ object_dim (int): The dimension of the object in the stack.
2340
+
2341
+ Returns:
2342
+ ndarray: The updated stack with multi-object cells removed.
2343
+ """
2344
+ cell_mask = stack[:, :, mask_dim]
2345
+ nucleus_mask = stack[:, :, nucleus_dim]
2346
+ pathogen_mask = stack[:, :, pathogen_dim]
2347
+ object_mask = stack[:, :, object_dim]
2348
+
2349
+ for cell_label in np.unique(cell_mask)[1:]:
2350
+ cell_region = cell_mask == cell_label
2351
+ labels_in_cell = np.unique(object_mask[cell_region])
2352
+ if len(labels_in_cell) > 2:
2353
+ cell_mask[cell_region] = 0
2354
+ nucleus_mask[cell_region] = 0
2355
+ for pathogen_label in labels_in_cell[1:]: # Skip the first label (0)
2356
+ pathogen_mask[pathogen_mask == pathogen_label] = 0
2357
+
2358
+ stack[:, :, cell_dim] = cell_mask
2359
+ stack[:, :, nucleus_dim] = nucleus_mask
2360
+ stack[:, :, pathogen_dim] = pathogen_mask
2361
+ return stack
2362
+
2363
def merge_touching_objects(mask, threshold=0.25):
    """
    Merge touching objects in a labelled mask based on their shared boundary.

    Every distinct non-zero value in the mask is treated as one object. Two
    objects are merged when the overlap of their dilated footprints exceeds
    `threshold` times the smaller of their two perimeters. The mask is
    modified in place.

    Args:
        mask (ndarray): Mask whose non-zero values identify objects.
        threshold (float, optional): Fraction of the smaller perimeter that
            the shared boundary must exceed. Defaults to 0.25.

    Returns:
        ndarray: The same mask object, with merged labels.

    """
    object_ids = [obj_id for obj_id in np.unique(mask) if obj_id != 0]

    # Perimeter of each object: the pixels a single erosion strips away.
    perimeters = {}
    for obj_id in object_ids:
        perimeters[obj_id] = np.sum(morphology.erosion(mask == obj_id) ^ (mask == obj_id))

    # Size of the contact zone for every ordered pair of touching objects.
    contact_sizes = {}
    foreground_grown = morphology.dilation(mask > 0)
    for obj_id in object_ids:
        grown = morphology.dilation(mask == obj_id)
        reachable = foreground_grown & (grown != 0) & (mask != 0)
        for other_id in np.unique(mask[reachable]):
            if other_id == obj_id:
                continue  # An object always touches itself
            overlap = grown & morphology.dilation(mask == other_id)
            contact_sizes[(obj_id, other_id)] = np.sum(overlap)

    # Fold the second object of each qualifying pair into the first.
    for (keep_id, absorbed_id), contact in contact_sizes.items():
        smaller_perimeter = min(perimeters[keep_id], perimeters[absorbed_id])
        if contact > threshold * smaller_perimeter:
            mask[mask == absorbed_id] = keep_id
    return mask
2399
+
2400
def remove_intensity_objects(image, mask, intensity_threshold, mode):
    """
    Removes objects from the mask based on their mean intensity in the original image.

    The mask is modified in place.

    Args:
        image (ndarray): The original image.
        mask (ndarray): The mask containing labeled objects.
        intensity_threshold (float): The threshold value for mean intensity.
        mode (str): 'low' removes objects whose mean intensity is below the
            threshold; 'high' removes objects whose mean intensity is above it.

    Returns:
        ndarray: The updated mask with objects removed.

    Raises:
        ValueError: If mode is neither 'low' nor 'high'.

    """
    # Validate first so an unknown mode gives a clear error instead of an
    # unbound-variable failure after the measurements were already taken.
    if mode not in ('low', 'high'):
        raise ValueError(f"mode must be 'low' or 'high', got {mode!r}")

    # Calculate the mean intensity of each object in the original image
    props = regionprops_table(mask, image, properties=('label', 'mean_intensity'))
    mean_intensity = props['mean_intensity']
    if mode == 'low':
        rejected = mean_intensity < intensity_threshold
    else:
        rejected = mean_intensity > intensity_threshold

    # Remove these objects from the mask
    mask[np.isin(mask, props['label'][rejected])] = 0
    return mask
2424
+
2425
+ def _filter_closest_to_stat(df, column, n_rows, use_median=False):
2426
+ """
2427
+ Filter the DataFrame to include the closest rows to a statistical measure.
2428
+
2429
+ Args:
2430
+ df (pandas.DataFrame): The input DataFrame.
2431
+ column (str): The column name to calculate the statistical measure.
2432
+ n_rows (int): The number of closest rows to include in the result.
2433
+ use_median (bool, optional): Whether to use the median or mean as the statistical measure.
2434
+ Defaults to False (mean).
2435
+
2436
+ Returns:
2437
+ pandas.DataFrame: The filtered DataFrame with the closest rows to the statistical measure.
2438
+ """
2439
+ if use_median:
2440
+ target_value = df[column].median()
2441
+ else:
2442
+ target_value = df[column].mean()
2443
+ df['diff'] = (df[column] - target_value).abs()
2444
+ result_df = df.sort_values(by='diff').head(n_rows)
2445
+ result_df = result_df.drop(columns=['diff'])
2446
+ return result_df
2447
+
2448
+ def _find_similar_sized_images(file_list):
2449
+ """
2450
+ Find the largest group of images with the most similar size and shape.
2451
+
2452
+ Args:
2453
+ file_list (list): List of file paths to the images.
2454
+
2455
+ Returns:
2456
+ list: List of file paths belonging to the largest group of images with the most similar size and shape.
2457
+ """
2458
+ # Dictionary to hold image sizes and their paths
2459
+ size_to_paths = defaultdict(list)
2460
+ # Iterate over image paths to get their dimensions
2461
+ for path in file_list:
2462
+ img = cv2.imread(path, cv2.IMREAD_UNCHANGED) # Read with unchanged color space to support different image types
2463
+ if img is not None:
2464
+ # Find indices where the image is not padded (non-zero)
2465
+ if img.ndim == 3: # Color image
2466
+ mask = np.any(img != 0, axis=2)
2467
+ else: # Grayscale image
2468
+ mask = img != 0
2469
+ # Find the bounding box of non-zero regions
2470
+ coords = np.argwhere(mask)
2471
+ if coords.size == 0: # Skip images that are completely padded
2472
+ continue
2473
+ y0, x0 = coords.min(axis=0)
2474
+ y1, x1 = coords.max(axis=0) + 1 # Add 1 because slice end index is exclusive
2475
+ # Crop the image to remove padding
2476
+ cropped_img = img[y0:y1, x0:x1]
2477
+ # Get dimensions of the cropped image
2478
+ height, width = cropped_img.shape[:2]
2479
+ aspect_ratio = width / height
2480
+ size_key = (width, height, round(aspect_ratio, 2)) # Group by width, height, and aspect ratio
2481
+ size_to_paths[size_key].append(path)
2482
+ # Find the largest group of images with the most similar size and shape
2483
+ largest_group = max(size_to_paths.values(), key=len)
2484
+ return largest_group
2485
+
2486
def _relabel_parent_with_child_labels(parent_mask, child_mask):
    """
    Relabel parent objects with the labels of the child objects they overlap.

    Parent objects are found by connected-component labelling of
    `parent_mask`; the child labels are used as given. Parents that overlap
    no child are dropped from the new parent mask. Where a relabelled parent
    contains several child labels, those children are collapsed onto the
    smallest of them, and `child_mask` is modified in place.

    Args:
        parent_mask (ndarray): Mask of the parent objects.
        child_mask (ndarray): Labelled mask of the child objects.

    Returns:
        tuple: The relabelled parent mask and the (possibly modified) child mask.

    """
    parent_objects = label(parent_mask, background=0)
    relabeled_parent = np.zeros_like(parent_mask)

    # Give every parent the label of a child it overlaps. When several
    # children overlap one parent, the last child processed wins.
    for child_id in np.unique(child_mask)[1:]:  # Skip background
        touched_parents = np.unique(parent_objects[child_mask == child_id])
        for parent_id in touched_parents:
            if parent_id == 0:  # Skip background
                continue
            relabeled_parent[parent_objects == parent_id] = child_id

    # Collapse multiple children inside one parent onto a single label.
    # NOTE(review): the parent keeps the last child's label while its children
    # are collapsed onto the first one, so the two can differ - confirm intent.
    for parent_id in np.unique(relabeled_parent)[1:]:  # Skip background
        children_here = np.unique(child_mask[relabeled_parent == parent_id])
        children_here = children_here[children_here != 0]  # Exclude background
        if len(children_here) > 1:
            child_mask[np.isin(child_mask, children_here)] = children_here[0]

    return relabeled_parent, child_mask
2532
+
2533
+ def _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=True):
2534
+ """
2535
+ Exclude objects from the masks based on certain criteria.
2536
+
2537
+ Args:
2538
+ cell_mask (ndarray): Mask representing cells.
2539
+ nucleus_mask (ndarray): Mask representing nucleus.
2540
+ pathogen_mask (ndarray): Mask representing pathogens.
2541
+ cytoplasm_mask (ndarray): Mask representing cytoplasm.
2542
+ include_uninfected (bool, optional): Whether to include uninfected cells. Defaults to True.
2543
+
2544
+ Returns:
2545
+ tuple: A tuple containing the filtered cell mask, nucleus mask, pathogen mask, and cytoplasm mask.
2546
+ """
2547
+ # Remove cells with no nucleus or cytoplasm (or pathogen)
2548
+ filtered_cells = np.zeros_like(cell_mask) # Initialize a new mask to store the filtered cells.
2549
+ for cell_label in np.unique(cell_mask): # Iterate over all cell labels in the cell mask.
2550
+ if cell_label == 0: # Skip background
2551
+ continue
2552
+ cell_region = cell_mask == cell_label # Get a mask for the current cell.
2553
+ # Check existence of nucleus, cytoplasm and pathogen in the current cell.
2554
+ has_nucleus = np.any(nucleus_mask[cell_region])
2555
+ has_cytoplasm = np.any(cytoplasm_mask[cell_region])
2556
+ has_pathogen = np.any(pathogen_mask[cell_region])
2557
+ if include_uninfected:
2558
+ if has_nucleus and has_cytoplasm:
2559
+ filtered_cells[cell_region] = cell_label
2560
+ else:
2561
+ if has_nucleus and has_cytoplasm and has_pathogen:
2562
+ filtered_cells[cell_region] = cell_label
2563
+ # Remove objects outside of cells
2564
+ nucleus_mask = nucleus_mask * (filtered_cells > 0)
2565
+ pathogen_mask = pathogen_mask * (filtered_cells > 0)
2566
+ cytoplasm_mask = cytoplasm_mask * (filtered_cells > 0)
2567
+ return filtered_cells, nucleus_mask, pathogen_mask, cytoplasm_mask
2568
+
2569
def _merge_overlapping_objects(mask1, mask2):
    """
    Resolve objects of mask1 that overlap several objects of mask2.

    Each connected component of mask1 that overlaps more than one mask2 label
    is handled in one of two ways. If one mask2 label covers at least 90% of
    the component, the parts of the component lying on the other labels are
    removed from mask1. Otherwise all overlapping mask2 labels are merged
    into the smallest of them. Both masks are modified in place.

    Args:
        mask1 (ndarray): First mask.
        mask2 (ndarray): Second mask.

    Returns:
        tuple: A tuple containing the merged masks (mask1, mask2).
    """
    components = label(mask1)
    for component_id in range(1, np.max(components) + 1):
        component = components == component_id
        overlapped = np.unique(mask2[component])
        overlapped = overlapped[overlapped != 0]
        if len(overlapped) <= 1:
            continue  # Nothing to resolve

        # Share of the component covered by each overlapping mask2 label
        coverage = [
            np.sum(component & (mask2 == other_label)) / np.sum(component) * 100
            for other_label in overlapped
        ]
        dominant_idx = np.argmax(coverage)
        dominant_label = overlapped[dominant_idx]

        if coverage[dominant_idx] >= 90:
            # One label dominates: trim the component off the minor labels.
            for other_label in overlapped:
                if other_label == dominant_label:
                    continue
                mask1[component & (mask2 == other_label)] = 0
        else:
            # No clear winner: fold all overlapping labels into the first.
            mask2[np.isin(mask2, overlapped[1:])] = overlapped[0]
    return mask1, mask2
2598
+
2599
+ def _filter_object(mask, min_value):
2600
+ """
2601
+ Filter objects in a mask based on their frequency.
2602
+
2603
+ Args:
2604
+ mask (ndarray): The input mask.
2605
+ min_value (int): The minimum frequency threshold.
2606
+
2607
+ Returns:
2608
+ ndarray: The filtered mask.
2609
+ """
2610
+ count = np.bincount(mask.ravel())
2611
+ to_remove = np.where(count < min_value)
2612
+ mask[np.isin(mask, to_remove)] = 0
2613
+ return mask
2614
+
2615
+ def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize):
2616
+ """
2617
+ Filter the masks based on various criteria such as size, border objects, merging, and intensity.
2618
+
2619
+ Args:
2620
+ masks (list): List of masks.
2621
+ flows (list): List of flows.
2622
+ refine_masks (bool): Flag indicating whether to refine masks.
2623
+ filter_size (bool): Flag indicating whether to filter based on size.
2624
+ minimum_size (int): Minimum size of objects to keep.
2625
+ maximum_size (int): Maximum size of objects to keep.
2626
+ remove_border_objects (bool): Flag indicating whether to remove border objects.
2627
+ merge (bool): Flag indicating whether to merge adjacent objects.
2628
+ filter_dimm (bool): Flag indicating whether to filter based on intensity.
2629
+ batch (ndarray): Batch of images.
2630
+ moving_avg_q1 (float): Moving average of the first quartile of object intensities.
2631
+ moving_avg_q3 (float): Moving average of the third quartile of object intensities.
2632
+ moving_count (int): Count of moving averages.
2633
+ plot (bool): Flag indicating whether to plot the masks.
2634
+ figuresize (tuple): Size of the figure.
2635
+
2636
+ Returns:
2637
+ list: List of filtered masks.
2638
+ """
2639
+
2640
+ from .plot import plot_masks
2641
+
2642
+ mask_stack = []
2643
+ for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
2644
+ if plot and idx == 0:
2645
+ num_objects = mask_object_count(mask)
2646
+ print(f'Number of objects before filtration: {num_objects}')
2647
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2648
+
2649
+ if filter_size:
2650
+ props = measure.regionprops_table(mask, properties=['label', 'area']) # Measure properties of labeled image regions.
2651
+ valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)] # Select labels of valid size.
2652
+ masks[idx] = np.isin(mask, valid_labels) * mask # Keep only valid objects.
2653
+ if plot and idx == 0:
2654
+ num_objects = mask_object_count(mask)
2655
+ print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size} : {num_objects}')
2656
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2657
+ if remove_border_objects:
2658
+ mask = clear_border(mask)
2659
+ if plot and idx == 0:
2660
+ num_objects = mask_object_count(mask)
2661
+ print(f'Number of objects after removing border objects, : {num_objects}')
2662
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2663
+ if merge:
2664
+ mask = merge_touching_objects(mask, threshold=0.25)
2665
+ if plot and idx == 0:
2666
+ num_objects = mask_object_count(mask)
2667
+ print(f'Number of objects after merging adjacent objects, : {num_objects}')
2668
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2669
+ if filter_dimm:
2670
+ unique_labels = np.unique(mask)
2671
+ if len(unique_labels) == 1 and unique_labels[0] == 0:
2672
+ continue
2673
+ object_intensities = [np.mean(batch[idx, :, :, 1][mask == label]) for label in unique_labels if label != 0]
2674
+ object_q1s = [np.percentile(intensities, 25) for intensities in object_intensities if intensities.size > 0]
2675
+ object_q3s = [np.percentile(intensities, 75) for intensities in object_intensities if intensities.size > 0]
2676
+ if object_q1s:
2677
+ object_q1_mean = np.mean(object_q1s)
2678
+ object_q3_mean = np.mean(object_q3s)
2679
+ moving_avg_q1 = (moving_avg_q1 * moving_count + object_q1_mean) / (moving_count + 1)
2680
+ moving_avg_q3 = (moving_avg_q3 * moving_count + object_q3_mean) / (moving_count + 1)
2681
+ moving_count += 1
2682
+ mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q1, mode='low')
2683
+ mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q3, mode='high')
2684
+ if plot and idx == 0:
2685
+ num_objects = mask_object_count(mask)
2686
+ print(f'Objects after intensity filtration > {moving_avg_q1} and <{moving_avg_q3}: {num_objects}')
2687
+ plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
2688
+ mask_stack.append(mask)
2689
+ return mask_stack
2690
+
2691
+ def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mask_chan):
2692
+ """
2693
+ Filter the DataFrame based on object type, size range, and intensity range.
2694
+
2695
+ Args:
2696
+ df (pandas.DataFrame): The DataFrame to filter.
2697
+ object_type (str): The type of object to filter.
2698
+ size_range (list or None): The range of object sizes to filter.
2699
+ intensity_range (list or None): The range of object intensities to filter.
2700
+ mask_chans (list): The list of mask channels.
2701
+ mask_chan (int): The index of the mask channel to use.
2702
+
2703
+ Returns:
2704
+ pandas.DataFrame: The filtered DataFrame.
2705
+ """
2706
+ if not size_range is None:
2707
+ if isinstance(size_range, list):
2708
+ if isinstance(size_range[0], int):
2709
+ df = df[df[f'{object_type}_area'] > size_range[0]]
2710
+ print(f'After {object_type} minimum area filter: {len(df)}')
2711
+ if isinstance(size_range[1], int):
2712
+ df = df[df[f'{object_type}_area'] < size_range[1]]
2713
+ print(f'After {object_type} maximum area filter: {len(df)}')
2714
+ if not intensity_range is None:
2715
+ if isinstance(intensity_range, list):
2716
+ if isinstance(intensity_range[0], int):
2717
+ df = df[df[f'{object_type}_channel_{mask_chans[mask_chan]}_mean_intensity'] > intensity_range[0]]
2718
+ print(f'After {object_type} minimum mean intensity filter: {len(df)}')
2719
+ if isinstance(intensity_range[1], int):
2720
+ df = df[df[f'{object_type}_channel_{mask_chans[mask_chan]}_mean_intensity'] < intensity_range[1]]
2721
+ print(f'After {object_type} maximum mean intensity filter: {len(df)}')
2722
+ return df
2723
+
2724
+ ###################################################
2725
+ # Classify
2726
+ ###################################################