spacr 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- spacr/__init__.py +37 -0
- spacr/__main__.py +15 -0
- spacr/annotate_app.py +495 -0
- spacr/cli.py +203 -0
- spacr/core.py +2250 -0
- spacr/gui_mask_app.py +247 -0
- spacr/gui_measure_app.py +214 -0
- spacr/gui_utils.py +488 -0
- spacr/io.py +2271 -0
- spacr/logger.py +20 -0
- spacr/mask_app.py +818 -0
- spacr/measure.py +1014 -0
- spacr/old_code.py +104 -0
- spacr/plot.py +1273 -0
- spacr/sim.py +1187 -0
- spacr/timelapse.py +576 -0
- spacr/train.py +494 -0
- spacr/umap.py +689 -0
- spacr/utils.py +2726 -0
- spacr/version.py +19 -0
- spacr-0.0.1.dist-info/LICENSE +21 -0
- spacr-0.0.1.dist-info/METADATA +64 -0
- spacr-0.0.1.dist-info/RECORD +26 -0
- spacr-0.0.1.dist-info/WHEEL +5 -0
- spacr-0.0.1.dist-info/entry_points.txt +5 -0
- spacr-0.0.1.dist-info/top_level.txt +1 -0
spacr/utils.py
ADDED
@@ -0,0 +1,2726 @@
import os, re, sqlite3, gc, torch, torchvision, time, random, string, shutil, cv2, tarfile, glob

import numpy as np
from skimage import morphology
from skimage.measure import label, regionprops_table, regionprops
import skimage.measure as measure
from collections import defaultdict
from PIL import Image
import pandas as pd
from statsmodels.stats.outliers_influence import variance_inflation_factor
from statsmodels.stats.stattools import durbin_watson
import statsmodels.formula.api as smf
import statsmodels.api as sm
from statsmodels.stats.multitest import multipletests
from itertools import combinations
from collections import OrderedDict
from functools import reduce
from IPython.display import display, clear_output
from multiprocessing import Pool, cpu_count
from skimage.transform import resize as resizescikit
import torch.nn as nn
import torch.nn.functional as F
#from torchsummary import summary
from torch.utils.checkpoint import checkpoint
from torch.utils.data import Subset
from torch.autograd import grad
from torchvision import models
from skimage.segmentation import clear_border
import seaborn as sns
import matplotlib.pyplot as plt
import scipy.ndimage as ndi
from scipy.stats import fisher_exact
from scipy.ndimage import binary_erosion, binary_dilation
from skimage.exposure import rescale_intensity
from sklearn.metrics import auc, precision_recall_curve
from sklearn.model_selection import train_test_split
from sklearn.linear_model import Lasso, Ridge
from sklearn.preprocessing import OneHotEncoder
from torchvision.models.resnet import ResNet18_Weights, ResNet34_Weights, ResNet50_Weights, ResNet101_Weights, ResNet152_Weights

from .logger import log_function_call

# Imported inside the functions that use them (avoids circular imports):
#from .io import _read_and_join_tables, _save_figure
#from .timelapse import _btrack_track_cells, _trackpy_track_cells
#from .plot import _plot_images_on_grid, plot_masks, _plot_histograms_and_stats, plot_resize, _plot_plates, _reg_v_plot
#from .core import identify_masks

def _convert_cq1_well_id(well_id):
    """
    Converts a linear well index to the CQ1 well format (e.g. 1 -> 'A01').
    Assumes a 384-well layout with 24 columns and rows A-P.

    Args:
        well_id (int): The 1-based linear well index to be converted.

    Returns:
        str: The well ID in CQ1 well format.
    """
    well_id = int(well_id)
    # ASCII code for 'A'
    ascii_A = ord('A')
    # Calculate row and column
    row, col = divmod(well_id - 1, 24)
    # Convert row to letter (A-P) and adjust col to start from 1
    row_letter = chr(ascii_A + row)
    # Format column as two digits
    well_format = f"{row_letter}{col + 1:02d}"
    return well_format
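
# A quick illustration of the mapping above (hypothetical values, assuming the
# 384-well layout the function encodes):
#
#     >>> _convert_cq1_well_id(1)
#     'A01'
#     >>> _convert_cq1_well_id(25)
#     'B01'
#     >>> _convert_cq1_well_id(384)
#     'P24'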

def _get_cellpose_batch_size():
    try:
        # Check if CUDA is available
        if torch.cuda.is_available():
            device_properties = torch.cuda.get_device_properties(0)
            vram_gb = device_properties.total_memory / (1024**3)  # Convert bytes to gigabytes
        else:
            print("CUDA is not available. Please check your installation and GPU.")
            return 8
        # Scale the batch size with available VRAM; the boundaries are inclusive
        # so every VRAM value maps to a batch size
        if vram_gb < 8:
            batch_size = 8
        elif vram_gb < 12:
            batch_size = 16
        elif vram_gb < 24:
            batch_size = 48
        else:
            batch_size = 96
        print(f"Device 0: {device_properties.name}, VRAM: {vram_gb:.2f} GB, cellpose batch size: {batch_size}")
        return batch_size
    except Exception as e:
        print(f"Could not determine batch size ({e}); falling back to 8.")
        return 8

def _extract_filename_metadata(filenames, src, images_by_key, regular_expression, metadata_type='cellvoyager', pick_slice=False, skip_mode='01'):
    """
    Group image arrays in images_by_key by (plate, well, field, channel, mode)
    metadata parsed from the filenames with the provided regular expression.
    """
    for filename in filenames:
        match = regular_expression.match(filename)
        if match:
            try:
                try:
                    plate = match.group('plateID')
                except IndexError:
                    plate = os.path.basename(src)

                well = match.group('wellID')
                field = match.group('fieldID')
                channel = match.group('chanID')
                mode = None

                if well[0].isdigit():
                    well = str(_safe_int_convert(well))
                if field[0].isdigit():
                    field = str(_safe_int_convert(field))
                if channel[0].isdigit():
                    channel = str(_safe_int_convert(channel))

                if metadata_type == 'cq1':
                    orig_well = well
                    well = _convert_cq1_well_id(well)
                    clear_output(wait=True)
                    print(f'\033[KConverted Well ID: {orig_well} to {well}', end='\r', flush=True)

                if pick_slice:
                    try:
                        mode = match.group('AID')
                    except IndexError:
                        mode = '00'  # regex has no 'AID' group; fall back to a default slice id

                    if mode == skip_mode:
                        continue

                key = (plate, well, field, channel, mode)
                with Image.open(os.path.join(src, filename)) as img:
                    images_by_key[key].append(np.array(img))
            except IndexError:
                print(f"Could not extract information from filename {filename} using provided regex")
        else:
            print(f"Filename {filename} did not match provided regex")
            continue

    return images_by_key
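
# Illustrative call (the regex below is a made-up example; real pipelines
# supply their own pattern with named groups plateID/wellID/fieldID/chanID):
#
#     >>> import re
#     >>> from collections import defaultdict
#     >>> regex = re.compile(r'(?P<plateID>.*)_(?P<wellID>[A-P]\d+)_(?P<fieldID>\d+)_(?P<chanID>\d+)\.tif')
#     >>> images_by_key = _extract_filename_metadata(
#     ...     ['plate1_A01_1_1.tif'], '/path/to/src', defaultdict(list), regex)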

def mask_object_count(mask):
    """
    Counts the number of objects in a given mask.

    Parameters:
    - mask: numpy.ndarray
        The mask containing object labels.

    Returns:
    - int
        The number of objects in the mask (the background label 0 is excluded).
    """
    unique_labels = np.unique(mask)
    num_objects = len(unique_labels[unique_labels != 0])
    return num_objects
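
# For example, a labeled mask with background (0) and two objects (1 and 2):
#
#     >>> mask = np.array([[0, 1, 1], [0, 0, 2]])
#     >>> mask_object_count(mask)
#     2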

def _update_database_with_merged_info(db_path, df, table='png_list', columns=['pathogen', 'treatment', 'host_cells', 'condition', 'prcfo']):
    """
    Merges additional columns into the png_list table in the SQLite database and updates it.

    Args:
        db_path (str): The path to the SQLite database file.
        df (pd.DataFrame): DataFrame containing the additional info to be merged.
        table (str): Name of the table to update in the database. Defaults to 'png_list'.
        columns (list): Columns from df to merge into the table. Must include 'prcfo'.
    """
    # Connect to the SQLite database
    conn = sqlite3.connect(db_path)

    # Read the existing table into a DataFrame
    try:
        existing_df = pd.read_sql(f"SELECT * FROM {table}", conn)
    except Exception as e:
        print(f"Failed to read table {table} from database: {e}")
        conn.close()
        return

    if 'prcfo' not in df.columns:
        print('generating prcfo column')
        try:
            df['prcfo'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str) + '_' + df['field'].astype(str) + '_o' + df['object_label'].astype(int).astype(str)
        except Exception:
            print('Merging on object_label failed, trying with cell_id')
            try:
                df['prcfo'] = df['plate'].astype(str) + '_' + df['row'].astype(str) + '_' + df['col'].astype(str) + '_' + df['field'].astype(str) + '_o' + df['cell_id'].astype(int).astype(str)
            except Exception as e:
                print(e)

    # Merge the existing DataFrame with the new info based on the 'prcfo' column
    merged_df = pd.merge(existing_df, df[columns], on='prcfo', how='left')

    # Drop the existing table and replace it with the updated DataFrame
    try:
        conn.execute(f"DROP TABLE IF EXISTS {table}")
        merged_df.to_sql(table, conn, index=False)
        print(f"Table {table} successfully updated in the database.")
    except Exception as e:
        print(f"Failed to update table {table} in the database: {e}")
    finally:
        conn.close()

def _generate_representative_images(db_path, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None, channel_of_interest=1, compartments=['pathogen', 'cytoplasm'], measurement='mean_intensity', nr_imgs=16, channel_indices=[0, 1, 2], um_per_pixel=0.1, scale_bar_length_um=10, plot=False, fontsize=12, show_filename=True, channel_names=None, update_db=True):
    """
    Generates representative images based on the provided parameters.

    Args:
        db_path (str): The path to the SQLite database file.
        cells (list, optional): The list of host cell types. Defaults to ['HeLa'].
        cell_loc (list, optional): The list of location identifiers for host cells. Defaults to None.
        pathogens (list, optional): The list of pathogens. Defaults to ['rh'].
        pathogen_loc (list, optional): The list of location identifiers for pathogens. Defaults to None.
        treatments (list, optional): The list of treatments. Defaults to ['cm'].
        treatment_loc (list, optional): The list of location identifiers for treatments. Defaults to None.
        channel_of_interest (int, optional): The index of the channel of interest. Defaults to 1.
        compartments (list or str, optional): The compartments to compare. Defaults to ['pathogen', 'cytoplasm'].
        measurement (str, optional): The measurement to compare. Defaults to 'mean_intensity'.
        nr_imgs (int, optional): The number of representative images to generate. Defaults to 16.
        channel_indices (list, optional): The indices of the channels to include in the representative images. Defaults to [0, 1, 2].
        um_per_pixel (float, optional): The scale factor for converting pixels to micrometers. Defaults to 0.1.
        scale_bar_length_um (float, optional): The length of the scale bar in micrometers. Defaults to 10.
        plot (bool, optional): Whether to plot the representative images. Defaults to False.
        fontsize (int, optional): The font size for the plot. Defaults to 12.
        show_filename (bool, optional): Whether to show the filename on the plot. Defaults to True.
        channel_names (list, optional): The names of the channels. Defaults to None.
        update_db (bool, optional): Whether to write the merged annotations back to the database. Defaults to True.

    Returns:
        None
    """

    from .io import _read_and_join_tables, _save_figure
    from .plot import _plot_images_on_grid

    df = _read_and_join_tables(db_path)
    df = _annotate_conditions(df, cells, cell_loc, pathogens, pathogen_loc, treatments, treatment_loc)

    if update_db:
        _update_database_with_merged_info(db_path, df, table='png_list', columns=['pathogen', 'treatment', 'host_cells', 'condition', 'prcfo'])

    if isinstance(compartments, list):
        if len(compartments) > 1:
            df['new_measurement'] = df[f'{compartments[0]}_channel_{channel_of_interest}_{measurement}'] / df[f'{compartments[1]}_channel_{channel_of_interest}_{measurement}']
        else:
            df['new_measurement'] = df['cell_area']
    dfs = {condition: df_group for condition, df_group in df.groupby('condition')}
    conditions = df['condition'].dropna().unique().tolist()
    for condition in conditions:
        df_condition = dfs[condition]
        df_condition = _filter_closest_to_stat(df_condition, column='new_measurement', n_rows=nr_imgs, use_median=False)
        png_paths_by_condition = df_condition['png_path'].tolist()
        fig = _plot_images_on_grid(png_paths_by_condition, channel_indices, um_per_pixel, scale_bar_length_um, fontsize, show_filename, channel_names, plot)
        src = os.path.dirname(db_path)
        os.makedirs(src, exist_ok=True)
        _save_figure(fig=fig, src=src, text=condition)
        # Pass a fresh single-channel list here: rebinding channel_indices
        # would clobber the caller-supplied list for subsequent conditions.
        for channel in channel_indices:
            fig = _plot_images_on_grid(png_paths_by_condition, [channel], um_per_pixel, scale_bar_length_um, fontsize, show_filename, channel_names, plot)
            _save_figure(fig, src, text=f'channel_{channel}_{condition}')
            plt.close()

# Adjusted mapping function to infer type from location identifiers
def _map_values(row, values, locs):
    """
    Maps values to a row or column location based on the given locs.

    Args:
        row (dict): The row containing the location identifier.
        values (list): The list of values to be mapped.
        locs (list): The list of location-identifier lists, one per value.

    Returns:
        The mapped value corresponding to the given row or column location, or None if not found.
    """
    if locs:
        value_dict = {loc: value for value, loc_list in zip(values, locs) for loc in loc_list}
        # Determine if we're dealing with row or column based on the first location identifier
        type_ = 'row' if locs[0][0][0] == 'r' else 'col'
        return value_dict.get(row[type_], None)
    return values[0] if values else None

def _annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None):
    """
    Annotates conditions in the given DataFrame based on the provided parameters.

    Args:
        df (pandas.DataFrame): The DataFrame to annotate.
        cells (list, optional): The list of host cell types. Defaults to ['HeLa'].
        cell_loc (list, optional): The list of location identifiers for host cells. Defaults to None.
        pathogens (list, optional): The list of pathogens. Defaults to ['rh'].
        pathogen_loc (list, optional): The list of location identifiers for pathogens. Defaults to None.
        treatments (list, optional): The list of treatments. Defaults to ['cm'].
        treatment_loc (list, optional): The list of location identifiers for treatments. Defaults to None.

    Returns:
        pandas.DataFrame: The annotated DataFrame with the 'host_cells', 'pathogen', 'treatment', and 'condition' columns.
    """

    # Apply mappings or defaults
    df['host_cells'] = [cells[0]] * len(df) if cell_loc is None else df.apply(_map_values, args=(cells, cell_loc), axis=1)
    df['pathogen'] = [pathogens[0]] * len(df) if pathogen_loc is None else df.apply(_map_values, args=(pathogens, pathogen_loc), axis=1)
    df['treatment'] = [treatments[0]] * len(df) if treatment_loc is None else df.apply(_map_values, args=(treatments, treatment_loc), axis=1)

    # Construct condition column
    df['condition'] = df.apply(lambda row: '_'.join(filter(None, [row.get('pathogen'), row.get('treatment')])), axis=1)
    df['condition'] = df['condition'].apply(lambda x: x if x else 'none')
    return df

def normalize_to_dtype(array, q1=2, q2=98, percentiles=None):
    """
    Normalize each image in a stack to its dtype range.

    Parameters:
    - array: numpy array
        The input array (height, width, n_images) to be normalized.
    - q1: int, optional
        The lower percentile value for normalization. Default is 2.
    - q2: int, optional
        The upper percentile value for normalization. Default is 98.
    - percentiles: list of tuples, optional
        A list of (min, max) pairs, one per image in the array.
        If provided, these are used instead of the q1/q2 percentiles.

    Returns:
    - new_stack: numpy array
        The normalized array with the same shape as the input array.
    """
    nimg = array.shape[2]
    new_stack = np.empty_like(array)
    for i in range(nimg):
        img = np.squeeze(array[:, :, i])
        non_zero_img = img[img > 0]
        if percentiles is not None:
            img_min, img_max = percentiles[i]
        elif non_zero_img.size > 0:
            # Percentiles are computed over non-zero pixels only
            img_min = np.percentile(non_zero_img, q1)
            img_max = np.percentile(non_zero_img, q2)
        else:
            # An all-zero image falls back to its own min/max
            img_min, img_max = img.min(), img.max()
        img = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
        new_stack[:, :, i] = img
    return new_stack
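
# A minimal sketch of how this is typically called (synthetic data):
#
#     >>> stack = np.random.randint(0, 4096, size=(64, 64, 3), dtype=np.uint16)
#     >>> norm = normalize_to_dtype(stack, q1=2, q2=98)
#     >>> norm.shape, norm.dtype
#     ((64, 64, 3), dtype('uint16'))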

def _list_endpoint_subdirectories(base_dir):
    """
    Returns a list of endpoint (leaf) subdirectories within the given base directory.

    Args:
        base_dir (str): The base directory to search for subdirectories.

    Returns:
        list: A list of endpoint subdirectories, excluding any named 'figure'.
    """
    endpoint_subdirectories = []
    for root, dirs, _ in os.walk(base_dir):
        if not dirs:
            endpoint_subdirectories.append(root)

    endpoint_subdirectories = [path for path in endpoint_subdirectories if os.path.basename(path) != 'figure']
    return endpoint_subdirectories

def _generate_names(file_name, cell_id, cell_nucleus_ids, cell_pathogen_ids, source_folder, crop_mode='cell'):
    """
    Generate names for the image, folder, and table based on the given parameters.

    Args:
        file_name (str): The name of the file.
        cell_id (numpy.ndarray): An array of cell IDs.
        cell_nucleus_ids (numpy.ndarray): An array of cell nucleus IDs.
        cell_pathogen_ids (numpy.ndarray): An array of cell pathogen IDs.
        source_folder (str): The source folder path.
        crop_mode (str, optional): The crop mode. Defaults to 'cell'.

    Returns:
        tuple: A tuple containing the image name, folder path, and table name.
    """
    non_zero_cell_ids = cell_id[cell_id != 0]
    cell_id_str = "multi" if non_zero_cell_ids.size > 1 else str(non_zero_cell_ids[0]) if non_zero_cell_ids.size == 1 else "none"
    cell_nucleus_ids = cell_nucleus_ids[cell_nucleus_ids != 0]
    cell_nucleus_id_str = "multi" if cell_nucleus_ids.size > 1 else str(cell_nucleus_ids[0]) if cell_nucleus_ids.size == 1 else "none"
    cell_pathogen_ids = cell_pathogen_ids[cell_pathogen_ids != 0]
    cell_pathogen_id_str = "multi" if cell_pathogen_ids.size > 1 else str(cell_pathogen_ids[0]) if cell_pathogen_ids.size == 1 else "none"
    fldr = f"{source_folder}/data/"
    img_name = ""
    if crop_mode == 'nucleus':
        img_name = f"{file_name}_{cell_id_str}_{cell_nucleus_id_str}.png"
        fldr += "single_nucleus/" if cell_nucleus_ids.size == 1 else "multiple_nucleus/" if cell_nucleus_ids.size > 1 else "no_nucleus/"
        fldr += "single_pathogen/" if cell_pathogen_ids.size == 1 else "multiple_pathogens/" if cell_pathogen_ids.size > 1 else "uninfected/"
    elif crop_mode == 'pathogen':
        img_name = f"{file_name}_{cell_id_str}_{cell_pathogen_id_str}.png"
        fldr += "single_nucleus/" if cell_nucleus_ids.size == 1 else "multiple_nucleus/" if cell_nucleus_ids.size > 1 else "no_nucleus/"
        fldr += "infected/" if cell_pathogen_ids.size >= 1 else "uninfected/"
    elif crop_mode == 'cell' or crop_mode == 'cytoplasm':
        img_name = f"{file_name}_{cell_id_str}.png"
        fldr += "single_nucleus/" if cell_nucleus_ids.size == 1 else "multiple_nucleus/" if cell_nucleus_ids.size > 1 else "no_nucleus/"
        fldr += "single_pathogen/" if cell_pathogen_ids.size == 1 else "multiple_pathogens/" if cell_pathogen_ids.size > 1 else "uninfected/"
    parts = file_name.split('_')
    plate = parts[0]
    well = parts[1]
    metadata = f'{plate}_{well}'
    fldr = os.path.join(fldr, metadata)
    table_name = fldr.replace("/", "_")
    return img_name, fldr, table_name

def _find_bounding_box(crop_mask, _id, buffer=10):
    """
    Find the bounding box coordinates for a given object ID in a crop mask.

    Parameters:
    crop_mask (ndarray): The crop mask containing object IDs.
    _id (int): The object ID to find the bounding box for.
    buffer (int, optional): The buffer size to add to the bounding box coordinates. Defaults to 10.

    Returns:
    ndarray: A new mask with the same dimensions as crop_mask, where the bounding box area is filled with the object ID.
    """
    object_indices = np.where(crop_mask == _id)

    # Determine the bounding box coordinates
    y_min, y_max = object_indices[0].min(), object_indices[0].max()
    x_min, x_max = object_indices[1].min(), object_indices[1].max()

    # Add buffer to the bounding box coordinates
    y_min = max(y_min - buffer, 0)
    y_max = min(y_max + buffer, crop_mask.shape[0] - 1)
    x_min = max(x_min - buffer, 0)
    x_max = min(x_max + buffer, crop_mask.shape[1] - 1)

    # Create a new mask with the same dimensions as crop_mask
    new_mask = np.zeros_like(crop_mask)

    # Fill in the bounding box area with the _id
    new_mask[y_min:y_max+1, x_min:x_max+1] = _id

    return new_mask
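
# For instance, with a 5x5 mask containing object 3 at its center and a
# 1-pixel buffer, the returned mask fills the buffered box with the object ID:
#
#     >>> m = np.zeros((5, 5), dtype=int)
#     >>> m[2, 2] = 3
#     >>> _find_bounding_box(m, 3, buffer=1)[1:4, 1:4].min()
#     3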

def _merge_and_save_to_database(morph_df, intensity_df, table_type, source_folder, file_name, experiment, timelapse=False):
    """
    Merges morphology and intensity dataframes, renames columns, adds additional columns, rearranges columns,
    and saves the merged dataframe to an SQLite database.

    Args:
        morph_df (pd.DataFrame): Dataframe containing morphology data.
        intensity_df (pd.DataFrame): Dataframe containing intensity data.
        table_type (str): Type of table to save the merged dataframe to.
        source_folder (str): Path to the source folder.
        file_name (str): Name of the file.
        experiment (str): Name of the experiment.
        timelapse (bool, optional): Indicates if the data is from a timelapse experiment. Defaults to False.

    Raises:
        ValueError: If an invalid table_type is provided or if columns are missing in the dataframe.
    """
    morph_df = _check_integrity(morph_df)
    intensity_df = _check_integrity(intensity_df)
    if len(morph_df) > 0 and len(intensity_df) > 0:
        merged_df = pd.merge(morph_df, intensity_df, on='object_label', how='outer')
        merged_df = merged_df.rename(columns={"label_list_x": "label_list_morphology", "label_list_y": "label_list_intensity"})
        merged_df['file_name'] = file_name
        merged_df['path_name'] = os.path.join(source_folder, file_name + '.npy')
        if timelapse:
            merged_df[['plate', 'row', 'col', 'field', 'timeid', 'prcf']] = merged_df['file_name'].apply(lambda x: pd.Series(_map_wells(x, timelapse)))
        else:
            merged_df[['plate', 'row', 'col', 'field', 'prcf']] = merged_df['file_name'].apply(lambda x: pd.Series(_map_wells(x, timelapse)))
        cols = merged_df.columns.tolist()  # get the list of all columns
        if table_type == 'cell' or table_type == 'cytoplasm':
            column_list = ['object_label', 'plate', 'row', 'col', 'field', 'prcf', 'file_name', 'path_name']
        elif table_type == 'nucleus' or table_type == 'pathogen':
            column_list = ['object_label', 'cell_id', 'plate', 'row', 'col', 'field', 'prcf', 'file_name', 'path_name']
        else:
            raise ValueError(f"Invalid table_type: {table_type}")
        # Check if all columns in column_list are present
        missing_columns = [col for col in column_list if col not in cols]
        if len(missing_columns) == 1 and missing_columns[0] == 'cell_id':
            missing_columns = False
            column_list = ['object_label', 'plate', 'row', 'col', 'field', 'prcf', 'file_name', 'path_name']
        if missing_columns:
            raise ValueError(f"Columns missing in DataFrame: {missing_columns}")
        # Move the identifier columns to the front, preserving the rest of the order
        for i, col in enumerate(column_list):
            cols.insert(i, cols.pop(cols.index(col)))
        merged_df = merged_df[cols]  # rearrange the columns
        if len(merged_df) > 0:
            conn = sqlite3.connect(f'{source_folder}/measurements/measurements.db', timeout=5)
            try:
                merged_df.to_sql(table_type, conn, if_exists='append', index=False)
            except sqlite3.OperationalError as e:
                print("SQLite error:", e)
            finally:
                conn.close()

def _safe_int_convert(value, default=0):
    """
    Converts the given value to an integer if possible, otherwise returns the default value.

    Args:
        value: The value to be converted to an integer.
        default: The default value to be returned if the conversion fails. Default is 0.

    Returns:
        The converted integer value if successful, otherwise the default value.
    """
    try:
        return int(value)
    except ValueError:
        print(f'Could not convert {value} to int, falling back to {default}', end='\r', flush=True)
        return default

def _map_wells(file_name, timelapse=False):
    """
    Maps the components of a file name to plate, row, column, field, and timeid (if timelapse is True).

    Args:
        file_name (str): The name of the file.
        timelapse (bool, optional): Indicates whether the file is part of a timelapse sequence. Defaults to False.

    Returns:
        tuple: The mapped values for plate, row, column, field, (timeid,) and prcf.
    """
    try:
        parts = file_name.split('_')
        plate = 'p' + parts[0]
        well = parts[1]
        field = 'f' + str(_safe_int_convert(parts[2]))
        if timelapse:
            timeid = 't' + str(_safe_int_convert(parts[3]))
        if well[0].isalpha():
            row = 'r' + str(string.ascii_uppercase.index(well[0]) + 1)
            column = 'c' + str(int(well[1:]))
        else:
            row, column = well, well
        if timelapse:
            prcf = '_'.join([plate, row, column, field, timeid])
        else:
            prcf = '_'.join([plate, row, column, field])
    except Exception as e:
        print(f"Error processing filename: {file_name}")
        print(f"Error: {e}")
        plate, row, column, field, timeid, prcf = ('error',) * 6
    if timelapse:
        return plate, row, column, field, timeid, prcf
    else:
        return plate, row, column, field, prcf
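
# Example of the mapping for a typical "plate_well_field" filename (values
# chosen for illustration):
#
#     >>> _map_wells('1_A01_2')
#     ('p1', 'r1', 'c1', 'f2', 'p1_r1_c1_f2')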

def _map_wells_png(file_name, timelapse=False):
    """
    Maps the components of a PNG file name to plate, row, column, field, object ID,
    and timeid (if timelapse is True).

    Args:
        file_name (str): The name of the file.
        timelapse (bool, optional): Indicates whether the file is part of a timelapse sequence. Defaults to False.

    Returns:
        tuple: A tuple containing the mapped components of the file name.
    """
    try:
        root, ext = os.path.splitext(file_name)
        parts = root.split('_')
        plate = 'p' + parts[0]
        well = parts[1]
        field = 'f' + str(_safe_int_convert(parts[2]))
        if timelapse:
            timeid = 't' + str(_safe_int_convert(parts[3]))
        object_id = 'o' + str(_safe_int_convert(parts[-1], default='none'))
        if well[0].isalpha():
            row = 'r' + str(string.ascii_uppercase.index(well[0]) + 1)
            column = 'c' + str(_safe_int_convert(well[1:]))
        else:
            row, column = well, well
        if timelapse:
            prcfo = '_'.join([plate, row, column, field, timeid, object_id])
        else:
            prcfo = '_'.join([plate, row, column, field, object_id])
    except Exception as e:
        print(f"Error processing filename: {file_name}")
        print(f"Error: {e}")
        # Include timeid so the timelapse return below cannot hit an unbound name
        plate, row, column, field, timeid, object_id, prcfo = ('error',) * 7
    if timelapse:
        return plate, row, column, field, timeid, prcfo, object_id
    else:
        return plate, row, column, field, prcfo, object_id

def _check_integrity(df):
    """
    Check the integrity of the DataFrame and perform necessary modifications:
    duplicate column names are de-duplicated, and all 'label' columns are
    collapsed into a single 'object_label' plus a stringified 'label_list'.

    Args:
        df (pandas.DataFrame): The input DataFrame.

    Returns:
        pandas.DataFrame: The modified DataFrame.
    """
    df.columns = [col + f'_{i}' if df.columns.tolist().count(col) > 1 and i != 0 else col for i, col in enumerate(df.columns)]
    label_cols = [col for col in df.columns if 'label' in col]
    df['label_list'] = df[label_cols].values.tolist()
    df['object_label'] = df['label_list'].apply(lambda x: x[0])
    df = df.drop(columns=label_cols)
    df['label_list'] = df['label_list'].astype(str)
    return df

def _get_percentiles(array, q1=2, q2=98):
    """
    Calculate the percentiles of each image in the given array.

    Parameters:
    - array: numpy.ndarray
        The input array containing the images (height, width, n_images).
    - q1: float, optional
        The lower percentile value to calculate. Default is 2.
    - q2: float, optional
        The upper percentile value to calculate. Default is 98.

    Returns:
    - percentiles: list
        A list of [min, max] pairs, one per image in the array.
    """
    nimg = array.shape[2]
    percentiles = []
    for v in range(nimg):
        img = np.squeeze(array[:, :, v])
        non_zero_img = img[img > 0]
        if non_zero_img.size > 0:
            # Percentiles are computed over non-zero pixels only
            img_min = np.percentile(non_zero_img, q1)
            img_max = np.percentile(non_zero_img, q2)
        else:
            # An all-zero image falls back to its own min/max
            img_min, img_max = img.min(), img.max()
        percentiles.append([img_min, img_max])
    return percentiles

def _crop_center(img, cell_mask, new_width, new_height, normalize=(2, 98)):
    """
    Crop the image around the center of the cell mask.

    Parameters:
    - img: numpy.ndarray
        The input image (height, width, channels).
    - cell_mask: numpy.ndarray
        The binary mask of the cell.
    - new_width: int
        The desired width of the cropped image.
    - new_height: int
        The desired height of the cropped image.
    - normalize: tuple, optional
        The normalization range for the image pixel values. Default is (2, 98).

    Returns:
    - img: numpy.ndarray
        The cropped image.
    """
    # Convert all non-zero values in mask to 1
    cell_mask[cell_mask != 0] = 1
    mask_3d = np.repeat(cell_mask[:, :, np.newaxis], img.shape[2], axis=2).astype(img.dtype)  # Create 3D mask
    img = np.multiply(img, mask_3d).astype(img.dtype)  # Zero out pixel values outside of the mask
    centroid = np.round(ndi.center_of_mass(cell_mask)).astype(int)  # Compute centroid of the mask
    # Pad the image and mask to ensure the crop will not go out of bounds
    pad_width = max(new_width, new_height)
    img = np.pad(img, ((pad_width, pad_width), (pad_width, pad_width), (0, 0)), mode='constant')
    cell_mask = np.pad(cell_mask, ((pad_width, pad_width), (pad_width, pad_width)), mode='constant')
    # Update centroid coordinates due to padding
    centroid += pad_width
    # Compute bounding box
    start_y = max(0, centroid[0] - new_height // 2)
    end_y = min(start_y + new_height, img.shape[0])
    start_x = max(0, centroid[1] - new_width // 2)
    end_x = min(start_x + new_width, img.shape[1])
    # Crop to bounding box
    img = img[start_y:end_y, start_x:end_x, :]
    return img
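
# A sketch of typical usage (synthetic inputs):
#
#     >>> img = np.random.rand(100, 100, 2)
#     >>> mask = np.zeros((100, 100), dtype=int)
#     >>> mask[40:60, 40:60] = 1
#     >>> _crop_center(img, mask, new_width=32, new_height=32).shape
#     (32, 32, 2)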

def _masks_to_masks_stack(masks):
    """
    Convert a list of masks into a stack of masks.

    Args:
        masks (list): A list of masks.

    Returns:
        list: A stack of masks.
    """
    return list(masks)

def _get_diam(mag, obj):
    if mag not in (20, 40, 60):
        raise ValueError("Invalid magnification; expected 20, 40 or 60")
    if obj == 'cell':
        if mag == 20:
            scale = 6
        elif mag == 40:
            scale = 4.5
        elif mag == 60:
            scale = 3
    elif obj == 'nucleus':
        if mag == 20:
            scale = 3
        elif mag == 40:
            scale = 2
        elif mag == 60:
            scale = 1.5
    elif obj == 'pathogen':
        if mag == 20:
            scale = 1.5
        elif mag == 40:
            scale = 1
        elif mag == 60:
            scale = 1.25
    elif obj == 'pathogen_nucleus':
        if mag == 20:
            scale = 0.25
        elif mag == 40:
            scale = 0.2
        elif mag == 60:
            scale = 0.2
    else:
        raise ValueError("Invalid object type")
    diameter = mag * scale
    return diameter
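
# For example, a cell at 40x magnification gets an expected diameter of
# 40 * 4.5 = 180 pixels:
#
#     >>> _get_diam(40, 'cell')
#     180.0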

def _get_object_settings(object_type, settings):

    object_settings = {}
    object_settings['refine_masks'] = False
    object_settings['filter_size'] = False
    object_settings['filter_dimm'] = False
    print(f'Generating object settings for: {object_type}')
    object_settings['diameter'] = _get_diam(settings['magnification'], obj=object_type)
    object_settings['remove_border_objects'] = False
    object_settings['minimum_size'] = (object_settings['diameter']**2) / 10
    object_settings['maximum_size'] = object_settings['minimum_size'] * 50
    object_settings['merge'] = False
    object_settings['net_avg'] = True
    object_settings['resample'] = True
    object_settings['model_name'] = 'cyto'

    if object_type == 'cell':
        if settings['nucleus_channel'] is None:
            object_settings['model_name'] = 'cyto'
        else:
            object_settings['model_name'] = 'cyto2'

    elif object_type == 'nucleus':
        object_settings['model_name'] = 'nuclei'

    elif object_type == 'pathogen':
        object_settings['model_name'] = 'cyto3'

    elif object_type == 'pathogen_nucleus':
        object_settings['filter_size'] = True
        object_settings['model_name'] = 'cyto'

    else:
        print(f'Object type: {object_type} not supported. Supported object types are: cell, nucleus and pathogen')
    print(f'using settings: {object_settings}')

    return object_settings

def _pivot_counts_table(db_path):

    def _read_table_to_dataframe(db_path, table_name='object_counts'):
        """
        Read a table from an SQLite database into a pandas DataFrame.

        Parameters:
        - db_path (str): The path to the SQLite database file.
        - table_name (str): The name of the table to read. Default is 'object_counts'.

        Returns:
        - df (pandas.DataFrame): The table data as a pandas DataFrame.
        """
        # Connect to the SQLite database
        conn = sqlite3.connect(db_path)
        # Read the entire table into a pandas DataFrame
        query = f"SELECT * FROM {table_name}"
        df = pd.read_sql_query(query, conn)
        # Close the connection
        conn.close()
        return df

    def _pivot_dataframe(df):
        """
        Pivot the DataFrame so there is one row per file and one column per count type.

        Args:
            df (pandas.DataFrame): The input DataFrame.

        Returns:
            pandas.DataFrame: The pivoted DataFrame with NaN values filled with 0.
        """
        # Pivot the DataFrame
        pivoted_df = df.pivot(index='file_name', columns='count_type', values='object_count').reset_index()
        # The pivot can introduce NaN values for missing data; fill those with 0
        pivoted_df = pivoted_df.fillna(0)
        return pivoted_df

    # Read the original 'object_counts' table
    df = _read_table_to_dataframe(db_path, 'object_counts')
    # Pivot the DataFrame to have one row per filename and a column for each object type
    pivoted_df = _pivot_dataframe(df)
    # Reconnect to the SQLite database and write the result to a 'pivoted_counts' table;
    # if_exists='replace' overwrites the table if it already exists
    conn = sqlite3.connect(db_path)
    pivoted_df.to_sql('pivoted_counts', conn, if_exists='replace', index=False)
    conn.close()

def _get_cellpose_channels_v1(mask_channels, nucleus_chann_dim, pathogen_chann_dim, cell_chann_dim):
    cellpose_channels = {}
    if nucleus_chann_dim in mask_channels:
        cellpose_channels['nucleus'] = [0, mask_channels.index(nucleus_chann_dim)]
    if pathogen_chann_dim in mask_channels:
        cellpose_channels['pathogen'] = [0, mask_channels.index(pathogen_chann_dim)]
    if cell_chann_dim in mask_channels:
        cellpose_channels['cell'] = [0, mask_channels.index(cell_chann_dim)]
    return cellpose_channels

def _get_cellpose_channels_v2(cell_channel, nucleus_channel, pathogen_channel):
    # Legacy variant with a different signature; named _v2 so it does not
    # shadow the _v1 definition above.
    # Initialize a dictionary to hold the new indices for the specified channels
    cellpose_channels = {}

    # Keep track of the channels in their new order
    new_channel_order = []

    # Add each channel to the new order list if it is not None
    if cell_channel is not None:
        new_channel_order.append(('cell', cell_channel))
    if nucleus_channel is not None:
        new_channel_order.append(('nucleus', nucleus_channel))
    if pathogen_channel is not None:
        new_channel_order.append(('pathogen', pathogen_channel))

    # Sort the list based on the original channel indices to maintain the original order
    new_channel_order.sort(key=lambda x: x[1])
    print(new_channel_order)
    # Assign new indices based on the sorted order
    for new_index, (channel_name, _) in enumerate(new_channel_order):
        cellpose_channels[channel_name] = [new_index, 0]

    if cell_channel is not None and nucleus_channel is not None:
        cellpose_channels['cell'][1] = cellpose_channels['nucleus'][0]

    return cellpose_channels

def _get_cellpose_channels(nucleus_channel, pathogen_channel, cell_channel):
    cellpose_channels = {}
    if nucleus_channel is not None:
        cellpose_channels['nucleus'] = [0, 0]

    if pathogen_channel is not None:
        if nucleus_channel is not None:
            cellpose_channels['pathogen'] = [0, 1]
        else:
            cellpose_channels['pathogen'] = [0, 0]

    if cell_channel is not None:
        if nucleus_channel is not None:
            if pathogen_channel is not None:
                cellpose_channels['cell'] = [0, 2]
            else:
                cellpose_channels['cell'] = [0, 1]
        elif pathogen_channel is not None:
            cellpose_channels['cell'] = [0, 1]
        else:
            cellpose_channels['cell'] = [0, 0]
    return cellpose_channels
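
# With all three channels supplied, the mapping reflects their stacking order:
#
#     >>> _get_cellpose_channels(nucleus_channel=0, pathogen_channel=1, cell_channel=2)
#     {'nucleus': [0, 0], 'pathogen': [0, 1], 'cell': [0, 2]}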

def annotate_conditions(df, cells=['HeLa'], cell_loc=None, pathogens=['rh'], pathogen_loc=None, treatments=['cm'], treatment_loc=None, types=['col', 'col', 'col']):
    """
    Annotates conditions in a DataFrame based on specified criteria.

    Args:
        df (pandas.DataFrame): The DataFrame to annotate.
        cells (list, optional): List of host cell types. Defaults to ['HeLa'].
        cell_loc (list, optional): List of corresponding location lists for each host cell type. Defaults to None.
        pathogens (list, optional): List of pathogens. Defaults to ['rh'].
        pathogen_loc (list, optional): List of corresponding location lists for each pathogen. Defaults to None.
        treatments (list, optional): List of treatments. Defaults to ['cm'].
        treatment_loc (list, optional): List of corresponding location lists for each treatment. Defaults to None.
        types (list, optional): List of column types for host cells, pathogens, and treatments. Defaults to ['col','col','col'].

    Returns:
        pandas.DataFrame: The annotated DataFrame.
    """

    # Function to apply to each row
    def _map_values(row, dict_, type_='col'):
        """
        Maps the value in a row's location column to its corresponding key in a dictionary.

        Args:
            row (dict): The row containing the values to be mapped.
            dict_ (dict): The dictionary containing the mapping values.
            type_ (str, optional): The type of mapping to perform. Defaults to 'col'.

        Returns:
            str: The mapped value if found, otherwise None.
        """
        for values, cols in dict_.items():
            if row[type_] in cols:
                return values
        return None

    if cell_loc is None:
        df['host_cells'] = cells[0]
    else:
        cells_dict = dict(zip(cells, cell_loc))
        df['host_cells'] = df.apply(lambda row: _map_values(row, cells_dict, type_=types[0]), axis=1)
    if pathogen_loc is None:
        if pathogens is not None:
            df['pathogen'] = 'none'
    else:
        pathogens_dict = dict(zip(pathogens, pathogen_loc))
        df['pathogen'] = df.apply(lambda row: _map_values(row, pathogens_dict, type_=types[1]), axis=1)
    if treatment_loc is None:
        df['treatment'] = 'cm'
    else:
        treatments_dict = dict(zip(treatments, treatment_loc))
        df['treatment'] = df.apply(lambda row: _map_values(row, treatments_dict, type_=types[2]), axis=1)
    if pathogens is not None:
        df['condition'] = df['pathogen'] + '_' + df['treatment']
    else:
        df['condition'] = df['treatment']
    return df

def _split_data(df, group_by, object_type):
    """
    Splits the input dataframe into numeric and non-numeric parts, groups them by the specified column,
    and returns the grouped dataframes.

    Parameters:
    df (pandas.DataFrame): The input dataframe.
    group_by (str): The column name to group the dataframes by.
    object_type (str): The column name concatenated with 'prcf' to create a new column 'prcfo'.

    Returns:
    grouped_numeric (pandas.DataFrame): The grouped dataframe containing numeric columns (mean per group).
    grouped_non_numeric (pandas.DataFrame): The grouped dataframe containing non-numeric columns (first per group).
    """
    df['prcfo'] = df['prcf'] + '_' + df[object_type]
    df = df.set_index(group_by, inplace=False)

    df_numeric = df.select_dtypes(include=np.number)
    df_non_numeric = df.select_dtypes(exclude=np.number)

    grouped_numeric = df_numeric.groupby(df_numeric.index).mean()
    grouped_non_numeric = df_non_numeric.groupby(df_non_numeric.index).first()

    return pd.DataFrame(grouped_numeric), pd.DataFrame(grouped_non_numeric)

def _calculate_recruitment(df, channel):
    """
    Calculate recruitment metrics based on intensity values in different compartments.

    Args:
        df (pandas.DataFrame): The input DataFrame containing intensity values in different channels.
        channel (int): The channel number.

    Returns:
        pandas.DataFrame: The DataFrame with calculated recruitment metrics.
    """
    df['pathogen_cell_mean_mean'] = df[f'pathogen_channel_{channel}_mean_intensity'] / df[f'cell_channel_{channel}_mean_intensity']
    df['pathogen_cytoplasm_mean_mean'] = df[f'pathogen_channel_{channel}_mean_intensity'] / df[f'cytoplasm_channel_{channel}_mean_intensity']
    df['pathogen_nucleus_mean_mean'] = df[f'pathogen_channel_{channel}_mean_intensity'] / df[f'nucleus_channel_{channel}_mean_intensity']

    df['pathogen_cell_q75_mean'] = df[f'pathogen_channel_{channel}_percentile_75'] / df[f'cell_channel_{channel}_mean_intensity']
    df['pathogen_cytoplasm_q75_mean'] = df[f'pathogen_channel_{channel}_percentile_75'] / df[f'cytoplasm_channel_{channel}_mean_intensity']
    df['pathogen_nucleus_q75_mean'] = df[f'pathogen_channel_{channel}_percentile_75'] / df[f'nucleus_channel_{channel}_mean_intensity']

    df['pathogen_outside_cell_mean_mean'] = df[f'pathogen_channel_{channel}_outside_mean'] / df[f'cell_channel_{channel}_mean_intensity']
    df['pathogen_outside_cytoplasm_mean_mean'] = df[f'pathogen_channel_{channel}_outside_mean'] / df[f'cytoplasm_channel_{channel}_mean_intensity']
    df['pathogen_outside_nucleus_mean_mean'] = df[f'pathogen_channel_{channel}_outside_mean'] / df[f'nucleus_channel_{channel}_mean_intensity']

    df['pathogen_outside_cell_q75_mean'] = df[f'pathogen_channel_{channel}_outside_75_percentile'] / df[f'cell_channel_{channel}_mean_intensity']
    df['pathogen_outside_cytoplasm_q75_mean'] = df[f'pathogen_channel_{channel}_outside_75_percentile'] / df[f'cytoplasm_channel_{channel}_mean_intensity']
    df['pathogen_outside_nucleus_q75_mean'] = df[f'pathogen_channel_{channel}_outside_75_percentile'] / df[f'nucleus_channel_{channel}_mean_intensity']

    df['pathogen_periphery_cell_mean_mean'] = df[f'pathogen_channel_{channel}_periphery_mean'] / df[f'cell_channel_{channel}_mean_intensity']
    df['pathogen_periphery_cytoplasm_mean_mean'] = df[f'pathogen_channel_{channel}_periphery_mean'] / df[f'cytoplasm_channel_{channel}_mean_intensity']
    df['pathogen_periphery_nucleus_mean_mean'] = df[f'pathogen_channel_{channel}_periphery_mean'] / df[f'nucleus_channel_{channel}_mean_intensity']

    channels = [0, 1, 2, 3]
    # Placeholder slope columns (currently constant)
    object_type = 'pathogen'
    for chan in channels:
        df[f'{object_type}_slope_channel_{chan}'] = 1

    object_type = 'nucleus'
    for chan in channels:
        df[f'{object_type}_slope_channel_{chan}'] = 1

    for chan in channels:
        df[f'nucleus_coordinates_{chan}'] = df[[f'nucleus_channel_{chan}_centroid_weighted_local-0', f'nucleus_channel_{chan}_centroid_weighted_local-1']].values.tolist()
        df[f'pathogen_coordinates_{chan}'] = df[[f'pathogen_channel_{chan}_centroid_weighted_local-0', f'pathogen_channel_{chan}_centroid_weighted_local-1']].values.tolist()
        df[f'cell_coordinates_{chan}'] = df[[f'cell_channel_{chan}_centroid_weighted_local-0', f'cell_channel_{chan}_centroid_weighted_local-1']].values.tolist()
        df[f'cytoplasm_coordinates_{chan}'] = df[[f'cytoplasm_channel_{chan}_centroid_weighted_local-0', f'cytoplasm_channel_{chan}_centroid_weighted_local-1']].values.tolist()

        df[f'pathogen_cell_distance_channel_{chan}'] = df.apply(lambda row: np.sqrt((row[f'pathogen_coordinates_{chan}'][0] - row[f'cell_coordinates_{chan}'][0])**2 +
                                                                                    (row[f'pathogen_coordinates_{chan}'][1] - row[f'cell_coordinates_{chan}'][1])**2), axis=1)
        df[f'nucleus_cell_distance_channel_{chan}'] = df.apply(lambda row: np.sqrt((row[f'nucleus_coordinates_{chan}'][0] - row[f'cell_coordinates_{chan}'][0])**2 +
                                                                                   (row[f'nucleus_coordinates_{chan}'][1] - row[f'cell_coordinates_{chan}'][1])**2), axis=1)
    return df

def _group_by_well(df):
    """
    Group the DataFrame by well coordinates (plate, row, col), applying the mean to numeric columns
    and taking the first value for non-numeric columns.

    Parameters:
    df (DataFrame): The input DataFrame to be grouped.

    Returns:
    DataFrame: The grouped DataFrame.
    """
    numeric_cols = df._get_numeric_data().columns
    non_numeric_cols = df.select_dtypes(include=['object']).columns

    # Apply mean to numeric columns and 'first' to non-numeric ones
    df_grouped = df.groupby(['plate', 'row', 'col']).agg({**{col: 'mean' for col in numeric_cols}, **{col: 'first' for col in non_numeric_cols}})
    return df_grouped

###################################################
# Classify
###################################################

class Cache:
    """
    A simple LRU cache with a maximum size.

    Attributes:
        max_size (int): The maximum size of the cache.
        cache (OrderedDict): The cache data structure.
    """

    def __init__(self, max_size):
        self.cache = OrderedDict()
        self.max_size = max_size

    def get(self, key):
        if key in self.cache:
            # Re-insert to mark the key as most recently used
            value = self.cache.pop(key)
            self.cache[key] = value
            return value
        return None

    def put(self, key, value):
        if len(self.cache) >= self.max_size:
            # Evict the least recently used entry
            self.cache.popitem(last=False)
        self.cache[key] = value
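
# Minimal usage sketch of the LRU behaviour:
#
#     >>> c = Cache(max_size=2)
#     >>> c.put('a', 1); c.put('b', 2)
#     >>> c.get('a')      # 'a' becomes most recently used
#     1
#     >>> c.put('c', 3)   # evicts 'b', the least recently used
#     >>> c.get('b') is None
#     True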

class ScaledDotProductAttention(nn.Module):
    """
    Scaled Dot-Product Attention module.

    Args:
        d_k (int): The dimension of the key and query vectors.

    Attributes:
        d_k (int): The dimension of the key and query vectors.

    Methods:
        forward(Q, K, V): Performs the forward pass of the attention mechanism.
    """

    def __init__(self, d_k):
        super(ScaledDotProductAttention, self).__init__()
        self.d_k = d_k

    def forward(self, Q, K, V):
        """
        Performs the forward pass of the attention mechanism.

        Args:
            Q (torch.Tensor): The query tensor of shape (batch_size, seq_len_q, d_k).
            K (torch.Tensor): The key tensor of shape (batch_size, seq_len_k, d_k).
            V (torch.Tensor): The value tensor of shape (batch_size, seq_len_v, d_k).

        Returns:
            torch.Tensor: The output tensor of shape (batch_size, seq_len_q, d_k).
        """
        # Attention(Q, K, V) = softmax(Q K^T / sqrt(d_k)) V
        scores = torch.matmul(Q, K.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        attention_probs = F.softmax(scores, dim=-1)
        output = torch.matmul(attention_probs, V)
        return output

class SelfAttention(nn.Module):
    """
    Self-Attention module that applies the scaled dot-product attention mechanism.

    Args:
        in_channels (int): Number of input channels.
        d_k (int): Dimensionality of the key and query vectors.
    """

    def __init__(self, in_channels, d_k):
        super(SelfAttention, self).__init__()
        self.W_q = nn.Linear(in_channels, d_k)
        self.W_k = nn.Linear(in_channels, d_k)
        self.W_v = nn.Linear(in_channels, d_k)
        self.attention = ScaledDotProductAttention(d_k)

    def forward(self, x):
        """
        Forward pass of the SelfAttention module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, d_k).
        """
        Q = self.W_q(x)
        K = self.W_k(x)
        V = self.W_v(x)
        output = self.attention(Q, K, V)
        return output
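
# Quick shape check (illustrative):
#
#     >>> attn = SelfAttention(in_channels=128, d_k=64)
#     >>> x = torch.randn(4, 10, 128)   # (batch, seq, features)
#     >>> attn(x).shape
#     torch.Size([4, 10, 64])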
|
1192
|
+
|
1193
|
+

# Early Fusion Block
class EarlyFusion(nn.Module):
    """
    Early fusion module for image classification.

    Args:
        in_channels (int): Number of input channels.
    """
    def __init__(self, in_channels):
        super(EarlyFusion, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=1, stride=1)

    def forward(self, x):
        """
        Forward pass of the EarlyFusion module.

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, in_channels, height, width).

        Returns:
            torch.Tensor: Output tensor of shape (batch_size, 64, height, width).
        """
        x = self.conv1(x)
        return x

# Spatial Attention Mechanism
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        """
        Initializes the SpatialAttention module.

        Args:
            kernel_size (int): The size of the convolutional kernel. Default is 7.
        """
        super(SpatialAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Performs the forward pass of the SpatialAttention module.

        Args:
            x (torch.Tensor): The input tensor.

        Returns:
            torch.Tensor: The attention map produced by spatial attention.
        """
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x = torch.cat([avg_out, max_out], dim=1)
        x = self.conv1(x)
        return self.sigmoid(x)
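
# Illustrative usage sketch (shapes are assumptions): a spatial attention map
# is typically multiplied element-wise with the feature map it was computed
# from, re-weighting informative spatial locations.
def _example_spatial_attention():
    sa = SpatialAttention(kernel_size=7)
    feats = torch.randn(2, 64, 32, 32)  # (batch, channels, H, W)
    attn_map = sa(feats)                # (batch, 1, H, W), values in (0, 1)
    return feats * attn_map             # re-weighted features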

# Multi-Scale Block with Attention
class MultiScaleBlockWithAttention(nn.Module):
    """
    Multi-scale block with attention module.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.

    Attributes:
        dilated_conv1 (nn.Conv2d): Dilated convolution layer.
        spatial_attention (nn.Conv2d): Spatial attention layer.

    Methods:
        custom_forward: Custom forward method for the module.
        forward: Forward method for the module.
    """

    def __init__(self, in_channels, out_channels):
        super(MultiScaleBlockWithAttention, self).__init__()
        self.dilated_conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, dilation=1, padding=1)
        self.spatial_attention = nn.Conv2d(out_channels, out_channels, kernel_size=1)

    def custom_forward(self, x):
        """
        Custom forward method for the module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        x1 = F.relu(self.dilated_conv1(x), inplace=True)
        x = self.spatial_attention(x1)
        return x

    def forward(self, x):
        """
        Forward method for the module.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            torch.Tensor: Output tensor.
        """
        return checkpoint(self.custom_forward, x)

# Final Classifier
class CustomCellClassifier(nn.Module):
    def __init__(self, num_classes, pathogen_channel, use_attention, use_checkpoint, dropout_rate):
        super(CustomCellClassifier, self).__init__()
        self.early_fusion = EarlyFusion(in_channels=3)

        self.multi_scale_block_1 = MultiScaleBlockWithAttention(in_channels=64, out_channels=64)

        self.fc1 = nn.Linear(64, num_classes)
        self.use_checkpoint = use_checkpoint
        # Explicitly require gradients for all parameters
        for param in self.parameters():
            param.requires_grad = True

    def custom_forward(self, x):
        x.requires_grad = True
        x = self.early_fusion(x)
        x = self.multi_scale_block_1(x)
        x = F.adaptive_avg_pool2d(x, (1, 1)).view(x.size(0), -1)
        x = F.relu(self.fc1(x), inplace=True)
        return x

    def forward(self, x):
        if self.use_checkpoint:
            x.requires_grad = True
            return checkpoint(self.custom_forward, x)
        else:
            return self.custom_forward(x)
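
# Illustrative usage sketch (argument values are assumptions): the classifier
# expects 3-channel image batches because EarlyFusion is constructed with
# in_channels=3.
def _example_custom_cell_classifier():
    model = CustomCellClassifier(num_classes=2, pathogen_channel=1,
                                 use_attention=True, use_checkpoint=False,
                                 dropout_rate=0.3)
    x = torch.randn(4, 3, 64, 64)  # (batch, channels, H, W)
    logits = model(x)              # (4, 2)
    return logits.shape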

# CNN and Transformer class, pick any Torch model.
class TorchModel(nn.Module):
    def __init__(self, model_name='resnet50', pretrained=True, dropout_rate=None, use_checkpoint=False):
        super(TorchModel, self).__init__()
        self.model_name = model_name
        self.use_checkpoint = use_checkpoint
        self.base_model = self.init_base_model(pretrained)

        # Retain layers up to and including the (5): Linear layer for model 'maxvit_t'
        if model_name == 'maxvit_t':
            self.base_model.classifier = nn.Sequential(*list(self.base_model.classifier.children())[:-1])

        if dropout_rate is not None:
            self.apply_dropout_rate(self.base_model, dropout_rate)

        self.num_ftrs = self.get_num_ftrs()
        self.init_spacr_classifier(dropout_rate)

    def apply_dropout_rate(self, model, dropout_rate):
        """Apply the dropout rate to all dropout layers in the model."""
        for module in model.modules():
            if isinstance(module, nn.Dropout):
                module.p = dropout_rate

    def init_base_model(self, pretrained):
        """Initialize the base model from torchvision.models."""
        model_func = models.__dict__.get(self.model_name, None)
        if not model_func:
            raise ValueError(f"Model {self.model_name} is not recognized.")
        weight_choice = self.get_weight_choice()
        if weight_choice is not None:
            return model_func(weights=weight_choice)
        else:
            return model_func(pretrained=pretrained)

    def get_weight_choice(self):
        """Get the default weight enum if one exists for the model."""
        weight_enum = None
        for attr_name in dir(models):
            if attr_name.lower() == f"{self.model_name}_weights".lower():
                weight_enum = getattr(models, attr_name)
                break
        # Note: when a weight enum exists, DEFAULT weights are returned and the
        # `pretrained` flag passed to init_base_model is not consulted.
        return weight_enum.DEFAULT if weight_enum else None

    def get_num_ftrs(self):
        """Determine the number of features output by the base model."""
        if hasattr(self.base_model, 'fc'):
            self.base_model.fc = nn.Identity()
        elif hasattr(self.base_model, 'classifier'):
            if self.model_name != 'maxvit_t':
                self.base_model.classifier = nn.Identity()

        # Forward a dummy input and check the output size
        dummy_input = torch.randn(1, 3, 224, 224)
        output = self.base_model(dummy_input)
        return output.size(1)

    def init_spacr_classifier(self, dropout_rate):
        """Initialize the SPACR classifier."""
        self.use_dropout = dropout_rate is not None
        if self.use_dropout:
            self.dropout = nn.Dropout(dropout_rate)
        self.spacr_classifier = nn.Linear(self.num_ftrs, 1)

    def forward(self, x):
        """Define the forward pass of the model."""
        if self.use_checkpoint:
            x = checkpoint(self.base_model, x)
        else:
            x = self.base_model(x)
        if self.use_dropout:
            x = self.dropout(x)
        logits = self.spacr_classifier(x).flatten()
        return logits

class FocalLossWithLogits(nn.Module):
    def __init__(self, alpha=1, gamma=2):
        super(FocalLossWithLogits, self).__init__()
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, logits, target):
        BCE_loss = F.binary_cross_entropy_with_logits(logits, target, reduction='none')
        pt = torch.exp(-BCE_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * BCE_loss
        return focal_loss.mean()
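
# Illustrative sketch (values assumed): focal loss down-weights easy examples
# via FL = alpha * (1 - pt)**gamma * BCE, where pt is the model's probability
# for the true class.
def _example_focal_loss():
    criterion = FocalLossWithLogits(alpha=1, gamma=2)
    logits = torch.randn(8)
    targets = torch.randint(0, 2, (8,)).float()
    return criterion(logits, targets)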

class ResNet(nn.Module):
    def __init__(self, resnet_type='resnet50', dropout_rate=None, use_checkpoint=False, init_weights='imagenet'):
        super(ResNet, self).__init__()

        resnet_map = {
            'resnet18': {'func': models.resnet18, 'weights': ResNet18_Weights.IMAGENET1K_V1},
            'resnet34': {'func': models.resnet34, 'weights': ResNet34_Weights.IMAGENET1K_V1},
            'resnet50': {'func': models.resnet50, 'weights': ResNet50_Weights.IMAGENET1K_V1},
            'resnet101': {'func': models.resnet101, 'weights': ResNet101_Weights.IMAGENET1K_V1},
            'resnet152': {'func': models.resnet152, 'weights': ResNet152_Weights.IMAGENET1K_V1}
        }

        if resnet_type not in resnet_map:
            raise ValueError(f"Invalid resnet_type. Choose from {list(resnet_map.keys())}")

        self.initialize_base(resnet_map[resnet_type], dropout_rate, use_checkpoint, init_weights)

    def initialize_base(self, base_model_dict, dropout_rate, use_checkpoint, init_weights):
        if init_weights == 'imagenet':
            self.resnet = base_model_dict['func'](weights=base_model_dict['weights'])
        elif init_weights == 'none':
            self.resnet = base_model_dict['func'](weights=None)
        else:
            raise ValueError("init_weights should be either 'imagenet' or 'none'")

        self.fc1 = nn.Linear(1000, 500)
        self.use_dropout = dropout_rate is not None
        self.use_checkpoint = use_checkpoint

        if self.use_dropout:
            self.dropout = nn.Dropout(dropout_rate)

        self.fc2 = nn.Linear(500, 1)

    def forward(self, x):
        x.requires_grad = True  # Ensure that the tensor has requires_grad set to True

        if self.use_checkpoint:
            x = checkpoint(self.resnet, x)  # Use checkpointing for just the ResNet part
        else:
            x = self.resnet(x)

        x = F.relu(self.fc1(x))

        if self.use_dropout:
            x = self.dropout(x)

        logits = self.fc2(x).flatten()
        return logits

def split_my_dataset(dataset, split_ratio=0.1):
    """
    Splits a dataset into training and validation subsets.

    Args:
        dataset (torch.utils.data.Dataset): The dataset to be split.
        split_ratio (float, optional): The ratio of validation samples to total samples. Defaults to 0.1.

    Returns:
        tuple: A tuple containing the training dataset and validation dataset.
    """
    num_samples = len(dataset)
    indices = list(range(num_samples))
    split_idx = int((1 - split_ratio) * num_samples)
    random.shuffle(indices)
    train_indices, val_indices = indices[:split_idx], indices[split_idx:]
    train_dataset = Subset(dataset, train_indices)
    val_dataset = Subset(dataset, val_indices)
    return train_dataset, val_dataset
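
# Illustrative usage sketch (assumes any sequence-like dataset): a 90/10 split
# with the default split_ratio=0.1.
def _example_split_my_dataset():
    train_ds, val_ds = split_my_dataset(list(range(100)), split_ratio=0.1)
    return len(train_ds), len(val_ds)  # (90, 10)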

def classification_metrics(all_labels, prediction_pos_probs, loader_name, loss, epoch):
    """
    Calculate classification metrics for binary classification.

    Parameters:
    - all_labels (list): List of true labels.
    - prediction_pos_probs (list): List of predicted positive probabilities.
    - loader_name (str): Name of the data loader.
    - loss (torch.Tensor): Scalar loss tensor (.item() is called on it).
    - epoch (int): Epoch number.

    Returns:
    - data_df (DataFrame): DataFrame containing the calculated metrics.
    """

    if len(all_labels) != len(prediction_pos_probs):
        raise ValueError(f"all_labels ({len(all_labels)}) and pred_labels ({len(prediction_pos_probs)}) have different lengths")

    unique_labels = np.unique(all_labels)
    if len(unique_labels) >= 2:
        pr_labels = np.array(all_labels).astype(int)
        precision, recall, thresholds = precision_recall_curve(pr_labels, prediction_pos_probs, pos_label=1)
        pr_auc = auc(recall, precision)
        thresholds = np.append(thresholds, 0.0)
        f1_scores = 2 * (precision * recall) / (precision + recall)
        optimal_idx = np.nanargmax(f1_scores)
        optimal_threshold = thresholds[optimal_idx]
        pred_labels = [int(p > 0.5) for p in prediction_pos_probs]
    if len(unique_labels) < 2:
        optimal_threshold = 0.5
        pred_labels = [int(p > optimal_threshold) for p in prediction_pos_probs]
        pr_auc = np.nan
    data = {'label': all_labels, 'pred': pred_labels}
    df = pd.DataFrame(data)
    pc_df = df[df['label'] == 1.0]
    nc_df = df[df['label'] == 0.0]
    correct = df[df['label'] == df['pred']]
    acc_all = len(correct) / len(df)
    if len(pc_df) > 0:
        correct_pc = pc_df[pc_df['label'] == pc_df['pred']]
        acc_pc = len(correct_pc) / len(pc_df)
    else:
        acc_pc = np.nan
    if len(nc_df) > 0:
        correct_nc = nc_df[nc_df['label'] == nc_df['pred']]
        acc_nc = len(correct_nc) / len(nc_df)
    else:
        acc_nc = np.nan
    data_dict = {'accuracy': acc_all, 'neg_accuracy': acc_nc, 'pos_accuracy': acc_pc, 'loss': loss.item(), 'prauc': pr_auc, 'optimal_threshold': optimal_threshold}
    data_df = pd.DataFrame(data_dict, index=[str(epoch) + '_' + loader_name])
    return data_df

def compute_irm_penalty(losses, dummy_w, device):
    """
    Computes the Invariant Risk Minimization (IRM) penalty.

    Args:
        losses (list): A list of scalar losses, one per environment.
        dummy_w (torch.Tensor): A scalar dummy weight tensor with requires_grad=True.
        device (torch.device): The device to perform computations on.

    Returns:
        torch.Tensor: The computed IRM penalty.
    """
    weighted_losses = [loss.clone().detach().requires_grad_(True).to(device) * dummy_w for loss in losses]
    gradients = [grad(w_loss, dummy_w, create_graph=True)[0] for w_loss in weighted_losses]
    irm_penalty = 0.0
    for g1, g2 in combinations(gradients, 2):
        # torch.dot requires 1-D tensors; the elementwise product summed is the
        # same inner product and also works for scalar (0-d) gradients.
        irm_penalty += (g1 * g2).sum() ** 2
    return irm_penalty
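
# Illustrative usage sketch (values assumed): per-environment losses are scaled
# by a scalar dummy weight, and the penalty is built from pairwise products of
# the resulting gradients with respect to that weight.
def _example_irm_penalty():
    dummy_w = torch.tensor(1.0, requires_grad=True)
    losses = [torch.tensor(0.8), torch.tensor(1.1)]
    return compute_irm_penalty(losses, dummy_w, torch.device('cpu'))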

#def print_model_summary(base_model, channels, height, width):
#    """
#    Prints the summary of a given base model.
#
#    Args:
#        base_model (torch.nn.Module): The base model to print the summary of.
#        channels (int): The number of input channels.
#        height (int): The height of the input.
#        width (int): The width of the input.
#
#    Returns:
#        None
#    """
#    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
#    base_model.to(device)
#    summary(base_model, (channels, height, width))
#    return

def choose_model(model_type, device, init_weights=True, dropout_rate=0, use_checkpoint=False, channels=3, height=224, width=224, chan_dict=None, num_classes=2):
    """
    Choose a model for classification.

    Args:
        model_type (str): The type of model to choose. Can be one of the pre-defined TorchVision models or 'custom' for a custom model.
        device (str): The device to use for model inference.
        init_weights (bool, optional): Whether to initialize the model with pre-trained weights. Defaults to True.
        dropout_rate (float, optional): The dropout rate to use in the model. Defaults to 0.
        use_checkpoint (bool, optional): Whether to use checkpointing during model training. Defaults to False.
        channels (int, optional): The number of input channels for the model. Defaults to 3.
        height (int, optional): The height of the input images for the model. Defaults to 224.
        width (int, optional): The width of the input images for the model. Defaults to 224.
        chan_dict (dict, optional): A dictionary containing channel information for custom models. Defaults to None.
        num_classes (int, optional): The number of output classes for the model. Defaults to 2.

    Returns:
        torch.nn.Module: The chosen model.
    """

    torch_model_types = torchvision.models.list_models(module=torchvision.models)
    model_types = torch_model_types + ['custom']

    if chan_dict is not None:
        pathogen_channel = chan_dict['pathogen_channel']
        nucleus_channel = chan_dict['nucleus_channel']
        protein_channel = chan_dict['protein_channel']

    if model_type not in model_types:
        print(f'Invalid model_type: {model_type}. Compatible model_types: {model_types}')
        return

    print(f'\rModel parameters: Architecture: {model_type} init_weights: {init_weights} dropout_rate: {dropout_rate} use_checkpoint: {use_checkpoint}', end='\r', flush=True)

    if model_type == 'custom':
        # Note: 'custom' requires chan_dict to define pathogen_channel
        base_model = CustomCellClassifier(num_classes, pathogen_channel=pathogen_channel, use_attention=True, use_checkpoint=use_checkpoint, dropout_rate=dropout_rate)
        #base_model = CustomCellClassifier(num_classes=2, pathogen_channel=pathogen_channel, nucleus_channel=nucleus_channel, protein_channel=protein_channel, dropout_rate=dropout_rate, use_checkpoint=use_checkpoint)
    elif model_type in torch_model_types:
        base_model = TorchModel(model_name=model_type, pretrained=init_weights, dropout_rate=dropout_rate)
    else:
        print(f'Compatible model_types: {model_types}')
        raise ValueError(f"Invalid model_type: {model_type}")

    print(base_model)

    return base_model

def calculate_loss(output, target, loss_type='binary_cross_entropy_with_logits'):
    if loss_type == 'binary_cross_entropy_with_logits':
        loss = F.binary_cross_entropy_with_logits(output, target)
    elif loss_type == 'focal_loss':
        focal_loss_fn = FocalLossWithLogits(alpha=1, gamma=2)
        loss = focal_loss_fn(output, target)
    else:
        raise ValueError(f"Unsupported loss_type: {loss_type}")
    return loss

def pick_best_model(src):
    all_files = os.listdir(src)
    pth_files = [f for f in all_files if f.endswith('.pth')]
    pattern = re.compile(r'_epoch_(\d+)_acc_(\d+(?:\.\d+)?)')

    def sort_key(x):
        match = pattern.search(x)
        if not match:
            return (0.0, 0)  # Make the primary sorting key a float for consistency
        g1, g2 = match.groups()
        return (float(g2), int(g1))  # Primary sort by accuracy (g2), secondary sort by epoch (g1)

    sorted_files = sorted(pth_files, key=sort_key, reverse=True)
    best_model = sorted_files[0]
    return os.path.join(src, best_model)
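
# Illustrative sketch (hypothetical filename): the sort key parses checkpoint
# names of the form '<name>_epoch_<N>_acc_<score>.pth'.
def _example_pick_best_model_pattern():
    pattern = re.compile(r'_epoch_(\d+)_acc_(\d+(?:\.\d+)?)')
    match = pattern.search('model_epoch_12_acc_0.95.pth')
    return match.groups()  # ('12', '0.95')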

def get_paths_from_db(df, png_df, image_type='cell_png'):
    objects = df.index.tolist()
    filtered_df = png_df[png_df['png_path'].str.contains(image_type) & png_df['prcfo'].isin(objects)]
    return filtered_df

def save_file_lists(dst, data_set, ls):
    df = pd.DataFrame(ls, columns=[data_set])
    df.to_csv(f'{dst}/{data_set}.csv', index=False)
    return

def augment_single_image(args):
    img_path, dst = args
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    filename = os.path.basename(img_path).split('.')[0]

    # Original image
    cv2.imwrite(os.path.join(dst, f"{filename}_original.png"), img)

    # 90 degree rotation
    img_rot_90 = cv2.rotate(img, cv2.ROTATE_90_CLOCKWISE)
    cv2.imwrite(os.path.join(dst, f"{filename}_rot_90.png"), img_rot_90)

    # 180 degree rotation
    img_rot_180 = cv2.rotate(img, cv2.ROTATE_180)
    cv2.imwrite(os.path.join(dst, f"{filename}_rot_180.png"), img_rot_180)

    # 270 degree rotation
    img_rot_270 = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
    cv2.imwrite(os.path.join(dst, f"{filename}_rot_270.png"), img_rot_270)

    # Horizontal flip
    img_flip_hor = cv2.flip(img, 1)
    cv2.imwrite(os.path.join(dst, f"{filename}_flip_hor.png"), img_flip_hor)

    # Vertical flip
    img_flip_ver = cv2.flip(img, 0)
    cv2.imwrite(os.path.join(dst, f"{filename}_flip_ver.png"), img_flip_ver)

def augment_images(file_paths, dst):
    if not os.path.exists(dst):
        os.makedirs(dst)

    args_list = [(img_path, dst) for img_path in file_paths]

    with Pool(cpu_count()) as pool:
        pool.map(augment_single_image, args_list)

def augment_classes(dst, nc, pc, generate=True, move=True):
    aug_nc = os.path.join(dst, 'aug_nc')
    aug_pc = os.path.join(dst, 'aug_pc')
    all_ = len(nc) + len(pc)
    if generate == True:
        os.makedirs(aug_nc, exist_ok=True)
        # This guard only passes when the module is run as a script
        if __name__ == '__main__':
            augment_images(file_paths=nc, dst=aug_nc)

        os.makedirs(aug_pc, exist_ok=True)
        if __name__ == '__main__':
            augment_images(file_paths=pc, dst=aug_pc)

    if move == True:
        aug = os.path.join(dst, 'aug')
        aug_train_nc = os.path.join(aug, 'train/nc')
        aug_train_pc = os.path.join(aug, 'train/pc')
        aug_test_nc = os.path.join(aug, 'test/nc')
        aug_test_pc = os.path.join(aug, 'test/pc')

        os.makedirs(aug_train_nc, exist_ok=True)
        os.makedirs(aug_train_pc, exist_ok=True)
        os.makedirs(aug_test_nc, exist_ok=True)
        os.makedirs(aug_test_pc, exist_ok=True)

        aug_nc_list = [os.path.join(aug_nc, file) for file in os.listdir(aug_nc)]
        aug_pc_list = [os.path.join(aug_pc, file) for file in os.listdir(aug_pc)]

        nc_train_data, nc_test_data = train_test_split(aug_nc_list, test_size=0.1, shuffle=True, random_state=42)
        pc_train_data, pc_test_data = train_test_split(aug_pc_list, test_size=0.1, shuffle=True, random_state=42)

        i = 0
        for path in nc_train_data:
            i += 1
            shutil.move(path, os.path.join(aug_train_nc, os.path.basename(path)))
            print(f'{i}/{all_}', end='\r', flush=True)
        for path in nc_test_data:
            i += 1
            shutil.move(path, os.path.join(aug_test_nc, os.path.basename(path)))
            print(f'{i}/{all_}', end='\r', flush=True)
        for path in pc_train_data:
            i += 1
            shutil.move(path, os.path.join(aug_train_pc, os.path.basename(path)))
            print(f'{i}/{all_}', end='\r', flush=True)
        for path in pc_test_data:
            i += 1
            shutil.move(path, os.path.join(aug_test_pc, os.path.basename(path)))
            print(f'{i}/{all_}', end='\r', flush=True)
        print(f'Train nc: {len(os.listdir(aug_train_nc))}, Train pc: {len(os.listdir(aug_train_pc))}, Test nc: {len(os.listdir(aug_test_nc))}, Test pc: {len(os.listdir(aug_test_pc))}')
    return

def annotate_predictions(csv_loc):
    df = pd.read_csv(csv_loc)
    df['filename'] = df['path'].apply(lambda x: x.split('/')[-1])
    df[['plate', 'well', 'field', 'object']] = df['filename'].str.split('_', expand=True)
    df['object'] = df['object'].str.replace('.png', '', regex=False)

    def assign_condition(row):
        plate = int(row['plate'])
        col = int(row['well'][1:])

        if col > 3:
            if plate in [1, 2, 3, 4]:
                return 'screen'
            elif plate in [5, 6, 7, 8]:
                return 'pc'
        elif col in [1, 2, 3]:
            return 'nc'
        else:
            return ''

    df['cond'] = df.apply(assign_condition, axis=1)
    return df

def init_globals(counter_, lock_):
    global counter, lock
    counter = counter_
    lock = lock_

def add_images_to_tar(args):
    global counter, lock, total_images
    paths_chunk, tar_path = args
    with tarfile.open(tar_path, 'w') as tar:
        for img_path in paths_chunk:
            arcname = os.path.basename(img_path)
            try:
                tar.add(img_path, arcname=arcname)
                with lock:
                    counter.value += 1
                    print(f"\rProcessed: {counter.value}/{total_images}", end='', flush=True)
            except FileNotFoundError:
                print(f"File not found: {img_path}")
    return tar_path

def generate_fraction_map(df, gene_column, min_frequency=0.0):
    df['fraction'] = df['count'] / df['well_read_sum']
    genes = df[gene_column].unique().tolist()
    wells = df['prc'].unique().tolist()
    print(len(genes), len(wells))
    independent_variables = pd.DataFrame(columns=genes, index=wells)
    for index, row in df.iterrows():
        prc = row['prc']
        gene = row[gene_column]
        fraction = row['fraction']
        independent_variables.loc[prc, gene] = fraction
    independent_variables = independent_variables.dropna(axis=1, how='all')
    independent_variables = independent_variables.dropna(axis=0, how='all')
    independent_variables['sum'] = independent_variables.sum(axis=1)
    #sums = independent_variables['sum'].unique().tolist()
    #print(sums)
    #independent_variables = independent_variables[(independent_variables['sum'] == 0.0) | (independent_variables['sum'] == 1.0)]
    independent_variables = independent_variables.fillna(0.0)
    independent_variables = independent_variables.drop(columns=[col for col in independent_variables.columns if independent_variables[col].max() < min_frequency])
    independent_variables = independent_variables.drop('sum', axis=1)
    independent_variables.index.name = 'prc'
    loc = '/mnt/data/CellVoyager/20x/tsg101/crispr_screen/all/measurements/iv.csv'
    independent_variables.to_csv(loc, index=True, header=True, mode='w')
    return independent_variables

def fishers_odds(df, threshold=0.5, phenotyp_col='mean_pred'):
    # Binarize based on the phenotype score: values below the threshold are
    # treated as the 'high phenotype' group
    df['high_phenotype'] = df[phenotyp_col] < threshold

    results = []
    mutants = df.columns[:-2]
    mutants = [item for item in mutants if item not in ['count_prc', 'mean_pathogen_area']]
    print('fishers df')
    display(df)
    # Perform Fisher's exact test for each mutant
    for mutant in mutants:
        contingency_table = pd.crosstab(df[mutant] > 0, df['high_phenotype'])
        if contingency_table.shape == (2, 2):  # Check for 2x2 shape
            odds_ratio, p_value = fisher_exact(contingency_table)
            results.append((mutant, odds_ratio, p_value))
        else:
            # Handle non-2x2 tables by appending NaN placeholders
            results.append((mutant, float('nan'), float('nan')))

    # Convert results to a DataFrame for easier handling
    results_df = pd.DataFrame(results, columns=['Mutant', 'OddsRatio', 'PValue'])
    # Remove rows with undefined odds ratios or p-values
    filtered_results_df = results_df.dropna(subset=['OddsRatio', 'PValue'])

    pvalues = filtered_results_df['PValue'].values

    # Check if the pvalues array is empty
    if len(pvalues) > 0:
        # Apply Benjamini-Hochberg correction
        adjusted_pvalues = multipletests(pvalues, method='fdr_bh')[1]
        # Add adjusted p-values back to the dataframe
        filtered_results_df['AdjustedPValue'] = adjusted_pvalues
        # Filter significant results
        significant_mutants = filtered_results_df[filtered_results_df['AdjustedPValue'] < 0.05]
    else:
        print("No p-values to adjust. Check your data filtering steps.")
        significant_mutants = pd.DataFrame()  # return an empty DataFrame in this case

    return filtered_results_df

def model_metrics(model):

    # Calculate additional metrics
    rmse = np.sqrt(model.mse_resid)
    mae = np.mean(np.abs(model.resid))
    durbin_w_value = durbin_watson(model.resid)

    # Display the additional metrics
    print("\nAdditional Metrics:")
    print(f"Root Mean Squared Error (RMSE): {rmse}")
    print(f"Mean Absolute Error (MAE): {mae}")
    print(f"Durbin-Watson: {durbin_w_value}")

    # Residual plots
    fig, ax = plt.subplots(2, 2, figsize=(15, 12))

    # Residuals vs. fitted
    ax[0, 0].scatter(model.fittedvalues, model.resid, edgecolors='k', facecolors='none')
    ax[0, 0].set_title('Residuals vs Fitted')
    ax[0, 0].set_xlabel('Fitted values')
    ax[0, 0].set_ylabel('Residuals')

    # Histogram
    sns.histplot(model.resid, kde=True, ax=ax[0, 1])
    ax[0, 1].set_title('Histogram of Residuals')
    ax[0, 1].set_xlabel('Residuals')

    # QQ plot
    sm.qqplot(model.resid, fit=True, line='45', ax=ax[1, 0])
    ax[1, 0].set_title('QQ Plot')

    # Scale-Location
    standardized_resid = model.get_influence().resid_studentized_internal
    ax[1, 1].scatter(model.fittedvalues, np.sqrt(np.abs(standardized_resid)), edgecolors='k', facecolors='none')
    ax[1, 1].set_title('Scale-Location')
    ax[1, 1].set_xlabel('Fitted values')
    ax[1, 1].set_ylabel(r'$\sqrt{|Standardized Residuals|}$')  # raw string avoids an invalid escape sequence

    plt.tight_layout()
    plt.show()

def check_multicollinearity(x):
    """Checks multicollinearity of the predictors by computing the VIF."""
    vif_data = pd.DataFrame()
    vif_data["Variable"] = x.columns
    vif_data["VIF"] = [variance_inflation_factor(x.values, i) for i in range(x.shape[1])]
    return vif_data
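
# Illustrative usage sketch (random data): VIF values near 1 indicate little
# collinearity, while large values flag nearly redundant predictors.
def _example_check_multicollinearity():
    rng = np.random.default_rng(0)
    x = pd.DataFrame({'a': rng.normal(size=100), 'b': rng.normal(size=100)})
    x['c'] = x['a'] * 0.9 + rng.normal(scale=0.1, size=100)  # nearly collinear with 'a'
    return check_multicollinearity(x)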

def generate_dependent_variable(df, dv_loc, pc_min=0.95, nc_max=0.05, agg_type='mean'):

    from .plot import _plot_histograms_and_stats, _plot_plates

    def qstring_to_float(qstr):
        number = int(qstr[1:])  # Remove the "q" and convert the rest to an integer
        return number / 100.0

    print("Unique values in plate:", df['plate'].unique())
    dv_cell_loc = f'{dv_loc}/dv_cell.csv'
    dv_well_loc = f'{dv_loc}/dv_well.csv'

    df['pred'] = 1 - df['pred']  # if you switched pc and nc
    df = df[(df['pred'] <= nc_max) | (df['pred'] >= pc_min)]

    if 'prc' not in df.columns:
        df['prc'] = df['plate'] + '_' + df['row'] + '_' + df['col']

    if agg_type.startswith('q'):
        val = qstring_to_float(agg_type)
        agg_type = lambda x: x.quantile(val)

    # Aggregate for the mean prediction and total count
    df_grouped = df.groupby('prc').agg(
        pred=('pred', agg_type),
        recruitment=('recruitment', agg_type),
        count_prc=('prc', 'size'),
        #count_above_95=('pred', lambda x: (x > 0.95).sum()),
        mean_pathogen_area=('pathogen_area', 'mean')
    )

    df_cell = df[['prc', 'pred', 'pathogen_area', 'recruitment']]

    df_cell.to_csv(dv_cell_loc, index=True, header=True, mode='w')
    df_grouped.to_csv(dv_well_loc, index=True, header=True, mode='w')
    display(df)
    _plot_histograms_and_stats(df)
    df_grouped = df_grouped.sort_values(by='count_prc', ascending=True)
    display(df_grouped)
    print('pred')
    _plot_plates(df=df_cell, variable='pred', grouping='mean', min_max='allq', cmap='viridis')
    print('recruitment')
    _plot_plates(df=df_cell, variable='recruitment', grouping='mean', min_max='allq', cmap='viridis')

    return df_grouped

def lasso_reg(merged_df, alpha_value=0.01, reg_type='lasso'):
    # Separate predictors and response
    X = merged_df[['gene', 'grna', 'plate', 'row', 'column']]
    y = merged_df['pred']

    # One-hot encode the categorical predictors
    encoder = OneHotEncoder(drop='first')  # drop one category to avoid the dummy variable trap
    X_encoded = encoder.fit_transform(X).toarray()
    feature_names = encoder.get_feature_names_out(input_features=X.columns)

    if reg_type == 'ridge':
        # Fit ridge regression
        ridge = Ridge(alpha=alpha_value)
        ridge.fit(X_encoded, y)
        coefficients = ridge.coef_
        coeff_dict = dict(zip(feature_names, ridge.coef_))

    if reg_type == 'lasso':
        # Fit lasso regression
        lasso = Lasso(alpha=alpha_value)
        lasso.fit(X_encoded, y)
        coefficients = lasso.coef_
        coeff_dict = dict(zip(feature_names, lasso.coef_))

    coeff_df = pd.DataFrame(list(coeff_dict.items()), columns=['Feature', 'Coefficient'])
    return coeff_df

def MLR(merged_df, refine_model):

    from .plot import _reg_v_plot

    #model = smf.ols("pred ~ gene + grna + gene:grna + plate + row + column", merged_df).fit()
    model = smf.ols("pred ~ gene:grna + plate + row + column", merged_df).fit()
    # Display model metrics and summary
    model_metrics(model)

    if refine_model:
        # Filter outliers
        std_resid = model.get_influence().resid_studentized_internal
        outliers_resid = np.where(np.abs(std_resid) > 3)[0]
        (c, p) = model.get_influence().cooks_distance
        outliers_cooks = np.where(c > 4 / (len(merged_df) - merged_df.shape[1] - 1))[0]
        outliers = reduce(np.union1d, (outliers_resid, outliers_cooks))
        merged_df_filtered = merged_df.drop(merged_df.index[outliers])

        display(merged_df_filtered)

        # Refit the model with the filtered data
        model = smf.ols("pred ~ gene + grna + gene:grna + row + column", merged_df_filtered).fit()
        print("Number of outliers detected by standardized residuals:", len(outliers_resid))
        print("Number of outliers detected by Cook's distance:", len(outliers_cooks))

        model_metrics(model)
        print(model.summary())

    # Extract interaction coefficients and determine the maximum effect size
    interaction_coeffs = {key: val for key, val in model.params.items() if "gene[T." in key and ":grna[T." in key}
    interaction_pvalues = {key: val for key, val in model.pvalues.items() if "gene[T." in key and ":grna[T." in key}

    max_effects = {}
    max_effects_pvalues = {}
    for key, val in interaction_coeffs.items():
        gene_name = key.split(":")[0].replace("gene[T.", "").replace("]", "")
        if gene_name not in max_effects or abs(max_effects[gene_name]) < abs(val):
            max_effects[gene_name] = val
            max_effects_pvalues[gene_name] = interaction_pvalues[key]

    for key in max_effects:
        print(f"Key: {key}: {max_effects[key]}, p:{max_effects_pvalues[key]}")

    df = pd.DataFrame([max_effects, max_effects_pvalues])
    df = df.transpose()
    df = df.rename(columns={df.columns[0]: 'effect', df.columns[1]: 'p'})
    df = df.sort_values(by=['effect', 'p'], ascending=[False, True])

    _reg_v_plot(df)

    return max_effects, max_effects_pvalues, model, df

#def normalize_to_dtype(array, q1=2, q2=98, percentiles=None):
#    if len(array.shape) == 2:
#        array = np.expand_dims(array, axis=-1)
#    num_channels = array.shape[-1]
#    new_stack = np.empty_like(array)
#    for channel in range(num_channels):
#        img = array[..., channel]
#        non_zero_img = img[img > 0]
#        if non_zero_img.size > 0:
#            img_min = np.percentile(non_zero_img, q1)
#            img_max = np.percentile(non_zero_img, q2)
#        else:
#            img_min, img_max = (percentiles[channel] if percentiles and channel < len(percentiles)
#                                else (img.min(), img.max()))
#        new_stack[..., channel] = rescale_intensity(img, in_range=(img_min, img_max), out_range='dtype')
#    if new_stack.shape[-1] == 1:
#        new_stack = np.squeeze(new_stack, axis=-1)
#    return new_stack

def get_files_from_dir(dir_path, file_extension="*"):
    # glob is imported as a module, so call glob.glob here
    return glob.glob(os.path.join(dir_path, file_extension))

def create_circular_mask(h, w, center=None, radius=None):
    if center is None:  # use the middle of the image
        center = (int(w / 2), int(h / 2))
    if radius is None:  # use the smallest distance between the center and image walls
        radius = min(center[0], center[1], w - center[0], h - center[1])

    Y, X = np.ogrid[:h, :w]
    dist_from_center = np.sqrt((X - center[0])**2 + (Y - center[1])**2)

    mask = dist_from_center <= radius
    return mask

def apply_mask(image, output_value=0):
    h, w = image.shape[:2]  # Assuming image is grayscale or RGB
    mask = create_circular_mask(h, w)

    # If the image has more than one channel, repeat the mask for each channel
    if len(image.shape) > 2:
        mask = np.repeat(mask[:, :, np.newaxis], image.shape[2], axis=2)

    # Apply the mask - set pixels outside of the mask to output_value
    masked_image = np.where(mask, image, output_value)
    return masked_image
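
# Illustrative usage sketch (toy array): pixels outside the inscribed circle
# are replaced with output_value.
def _example_apply_mask():
    image = np.ones((5, 5), dtype=np.uint8) * 255
    return apply_mask(image, output_value=0)  # corner pixels become 0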

def invert_image(image):
    # The maximum value depends on the image dtype (e.g., 255 for uint8)
    max_value = np.iinfo(image.dtype).max
    inverted_image = max_value - image
    return inverted_image

def resize_images_and_labels(images, labels, target_height, target_width, show_example=True):

    from .plot import plot_resize

    resized_images = []
    resized_labels = []
    if images is not None and labels is not None:
        for image, label in zip(images, labels):

            if image.ndim == 2:
                image_shape = (target_height, target_width)
            elif image.ndim == 3:
                image_shape = (target_height, target_width, image.shape[-1])

            resized_image = resizescikit(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)
            resized_label = resizescikit(label, (target_height, target_width), order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)

            if resized_image.shape[-1] == 1:
                resized_image = np.squeeze(resized_image)

            resized_images.append(resized_image)
            resized_labels.append(resized_label)

    elif images is not None:
        for image in images:

            if image.ndim == 2:
                image_shape = (target_height, target_width)
            elif image.ndim == 3:
                image_shape = (target_height, target_width, image.shape[-1])

            resized_image = resizescikit(image, image_shape, preserve_range=True, anti_aliasing=True).astype(image.dtype)

            if resized_image.shape[-1] == 1:
                resized_image = np.squeeze(resized_image)

            resized_images.append(resized_image)

    elif labels is not None:
        for label in labels:
            resized_label = resizescikit(label, (target_height, target_width), order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)
            resized_labels.append(resized_label)

    if show_example:
        if images is not None and labels is not None:
            plot_resize(images, resized_images, labels, resized_labels)
        elif images is not None:
            plot_resize(images, resized_images, images, resized_images)
        elif labels is not None:
            plot_resize(labels, resized_labels, labels, resized_labels)

    return resized_images, resized_labels

def resize_labels_back(labels, orig_dims):
    resized_labels = []

    if len(labels) != len(orig_dims):
        raise ValueError("The length of labels and orig_dims must match.")

    for label, dims in zip(labels, orig_dims):
        # Ensure dims is a tuple of two integers; skimage's resize expects the
        # output shape as (height, width)
        if not isinstance(dims, tuple) or len(dims) != 2:
            raise ValueError("Each element in orig_dims must be a tuple of two integers representing the original dimensions (height, width)")

        resized_label = resizescikit(label, dims, order=0, preserve_range=True, anti_aliasing=False).astype(label.dtype)
        resized_labels.append(resized_label)

    return resized_labels

def calculate_iou(mask1, mask2):
    mask1, mask2 = pad_to_same_shape(mask1, mask2)
    intersection = np.logical_and(mask1, mask2).sum()
    union = np.logical_or(mask1, mask2).sum()
    return intersection / union if union != 0 else 0
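
# Illustrative sketch (toy masks): two 2x2 masks overlapping in one pixel out
# of a three-pixel union give IoU = 1/3.
def _example_calculate_iou():
    mask1 = np.array([[1, 1], [0, 0]], dtype=bool)
    mask2 = np.array([[1, 0], [1, 0]], dtype=bool)
    return calculate_iou(mask1, mask2)  # 1 intersecting pixel / 3 union pixels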

def match_masks(true_masks, pred_masks, iou_threshold):
    matches = []
    matched_true_masks_indices = set()  # Use a set to store indices of matched true masks

    for pred_mask in pred_masks:
        for true_mask_index, true_mask in enumerate(true_masks):
            if true_mask_index not in matched_true_masks_indices:
                iou = calculate_iou(true_mask, pred_mask)
                if iou >= iou_threshold:
                    matches.append((true_mask, pred_mask))
                    matched_true_masks_indices.add(true_mask_index)  # Store the index of the matched true mask
                    break  # Move on to the next predicted mask
    return matches

def compute_average_precision(matches, num_true_masks, num_pred_masks):
    TP = len(matches)
    FP = num_pred_masks - TP
    FN = num_true_masks - TP
    precision = TP / (TP + FP) if TP + FP > 0 else 0
    recall = TP / (TP + FN) if TP + FN > 0 else 0
    return precision, recall

def pad_to_same_shape(mask1, mask2):
    # Find the shape differences
    shape_diff = np.array([max(mask1.shape[0], mask2.shape[0]) - mask1.shape[0],
                           max(mask1.shape[1], mask2.shape[1]) - mask1.shape[1]])
    pad_mask1 = ((0, shape_diff[0]), (0, shape_diff[1]))
    shape_diff = np.array([max(mask1.shape[0], mask2.shape[0]) - mask2.shape[0],
                           max(mask1.shape[1], mask2.shape[1]) - mask2.shape[1]])
    pad_mask2 = ((0, shape_diff[0]), (0, shape_diff[1]))

    padded_mask1 = np.pad(mask1, pad_mask1, mode='constant', constant_values=0)
    padded_mask2 = np.pad(mask2, pad_mask2, mode='constant', constant_values=0)

    return padded_mask1, padded_mask2

def compute_ap_over_iou_thresholds(true_masks, pred_masks, iou_thresholds):
    precision_recall_pairs = []
    for iou_threshold in iou_thresholds:
        matches = match_masks(true_masks, pred_masks, iou_threshold)
        precision, recall = compute_average_precision(matches, len(true_masks), len(pred_masks))
        # Check that precision and recall are within the range [0, 1]
        if not 0 <= precision <= 1 or not 0 <= recall <= 1:
            raise ValueError(f'Precision or recall out of bounds. Precision: {precision}, Recall: {recall}')
        precision_recall_pairs.append((precision, recall))

    # Sort by recall values
    precision_recall_pairs = sorted(precision_recall_pairs, key=lambda x: x[1])
    sorted_precisions = [p[0] for p in precision_recall_pairs]
    sorted_recalls = [p[1] for p in precision_recall_pairs]
    return np.trapz(sorted_precisions, x=sorted_recalls)

def compute_segmentation_ap(true_masks, pred_masks, iou_thresholds=np.linspace(0.5, 0.95, 10)):
    true_mask_labels = label(true_masks)
    pred_mask_labels = label(pred_masks)
    true_mask_regions = [region.image for region in regionprops(true_mask_labels)]
    pred_mask_regions = [region.image for region in regionprops(pred_mask_labels)]
    return compute_ap_over_iou_thresholds(true_mask_regions, pred_mask_regions, iou_thresholds)
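
# Illustrative sketch (toy binary masks): connected components are labelled in
# each mask and matched across the IoU thresholds 0.5-0.95. Note that
# region.image crops each object to its bounding box, so matching reflects
# shape overlap rather than absolute position.
def _example_segmentation_ap():
    true = np.zeros((10, 10), dtype=int)
    true[1:4, 1:4] = 1   # one 3x3 object
    pred = np.zeros((10, 10), dtype=int)
    pred[1:3, 1:4] = 1   # one 2x3 object; cropped-shape IoU ~0.67, so it
    return compute_segmentation_ap(true, pred)  # matches only at lower thresholds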

def jaccard_index(mask1, mask2):
    intersection = np.logical_and(mask1, mask2)
    union = np.logical_or(mask1, mask2)
    return np.sum(intersection) / np.sum(union)

def dice_coefficient(mask1, mask2):
    # Convert to binary masks
    mask1 = np.where(mask1 > 0, 1, 0)
    mask2 = np.where(mask2 > 0, 1, 0)

    # Calculate intersection and total
    intersection = np.sum(mask1 & mask2)
    total = np.sum(mask1) + np.sum(mask2)

    # Handle the case where both masks are empty
    if total == 0:
        return 1.0

    # Return the Dice coefficient
    return 2.0 * intersection / total
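
# Illustrative sketch (toy masks): Dice and Jaccard are related by
# D = 2J / (1 + J), so the two metrics always rank mask pairs identically.
def _example_dice_vs_jaccard():
    mask1 = np.array([[1, 1], [0, 0]])
    mask2 = np.array([[1, 0], [1, 0]])
    j = jaccard_index(mask1, mask2)     # 1/3
    d = dice_coefficient(mask1, mask2)  # 2*(1/3)/(1 + 1/3) = 0.5
    return j, d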

def extract_boundaries(mask, dilation_radius=1):
    binary_mask = (mask > 0).astype(np.uint8)
    struct_elem = np.ones((dilation_radius * 2 + 1, dilation_radius * 2 + 1))
    # binary_dilation/binary_erosion come from scipy.ndimage, whose keyword is
    # `structure` (the skimage equivalents use `footprint`)
    dilated = binary_dilation(binary_mask, structure=struct_elem)
    eroded = binary_erosion(binary_mask, structure=struct_elem)
    boundary = dilated ^ eroded
    return boundary

def boundary_f1_score(mask_true, mask_pred, dilation_radius=1):
    # extract_boundaries returns object boundaries thickened by dilation_radius
    boundary_true = extract_boundaries(mask_true, dilation_radius)
    boundary_pred = extract_boundaries(mask_pred, dilation_radius)

    # Calculate the intersection of the boundaries
    intersection = np.logical_and(boundary_true, boundary_pred)

    # Calculate precision and recall for boundary detection
    precision = np.sum(intersection) / (np.sum(boundary_pred) + 1e-6)
    recall = np.sum(intersection) / (np.sum(boundary_true) + 1e-6)

    # Calculate the F1 score as the harmonic mean of precision and recall
    f1 = 2 * (precision * recall) / (precision + recall + 1e-6)

    return f1
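
# Illustrative sketch (toy mask): identical masks yield a boundary F1 close to
# 1.0 (slightly below because of the 1e-6 smoothing terms).
def _example_boundary_f1():
    mask = np.zeros((10, 10), dtype=np.uint8)
    mask[2:7, 2:7] = 1
    return boundary_f1_score(mask, mask, dilation_radius=1)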

def _remove_noninfected(stack, cell_dim, nucleus_dim, pathogen_dim):
    """
    Remove non-infected cells from the stack based on the provided dimensions.

    Args:
        stack (ndarray): The stack of images.
        cell_dim (int or None): The dimension index for the cell mask. If None, a zero-filled mask will be used.
        nucleus_dim (int or None): The dimension index for the nucleus mask. If None, a zero-filled mask will be used.
        pathogen_dim (int or None): The dimension index for the pathogen mask. If None, a zero-filled mask will be used.

    Returns:
        ndarray: The updated stack with non-infected cells removed.
    """
    if cell_dim is not None:
        cell_mask = stack[:, :, cell_dim]
    else:
        cell_mask = np.zeros_like(stack[:, :, 0])  # 2D zero mask matching one channel
    if nucleus_dim is not None:
        nucleus_mask = stack[:, :, nucleus_dim]
    else:
        nucleus_mask = np.zeros_like(stack[:, :, 0])

    if pathogen_dim is not None:
        pathogen_mask = stack[:, :, pathogen_dim]
    else:
        pathogen_mask = np.zeros_like(stack[:, :, 0])

    for cell_label in np.unique(cell_mask)[1:]:
        cell_region = cell_mask == cell_label
        labels_in_cell = np.unique(pathogen_mask[cell_region])
        if len(labels_in_cell) <= 1:
            cell_mask[cell_region] = 0
            nucleus_mask[cell_region] = 0
    if cell_dim is not None:
        stack[:, :, cell_dim] = cell_mask
    if nucleus_dim is not None:
        stack[:, :, nucleus_dim] = nucleus_mask
    return stack

def _remove_outside_objects(stack, cell_dim, nucleus_dim, pathogen_dim):
    """
    Remove pathogen objects that fall outside of cells from the stack.

    Args:
        stack (ndarray): The stack of images.
        cell_dim (int): The dimension index of the cell mask in the stack.
        nucleus_dim (int): The dimension index of the nucleus mask in the stack.
        pathogen_dim (int): The dimension index of the pathogen mask in the stack.

    Returns:
        ndarray: The updated stack with outside objects removed.
    """
    if cell_dim is not None:
        cell_mask = stack[:, :, cell_dim]
    else:
        return stack
    nucleus_mask = stack[:, :, nucleus_dim]
    pathogen_mask = stack[:, :, pathogen_dim]
    pathogen_labels = np.unique(pathogen_mask)[1:]
    for pathogen_label in pathogen_labels:
        pathogen_region = pathogen_mask == pathogen_label
        cell_in_pathogen_region = np.unique(cell_mask[pathogen_region])
        cell_in_pathogen_region = cell_in_pathogen_region[cell_in_pathogen_region != 0]  # Exclude background
        if len(cell_in_pathogen_region) == 0:
            pathogen_mask[pathogen_region] = 0
            corresponding_nucleus_region = nucleus_mask == pathogen_label
            nucleus_mask[corresponding_nucleus_region] = 0
    stack[:, :, cell_dim] = cell_mask
    stack[:, :, nucleus_dim] = nucleus_mask
    stack[:, :, pathogen_dim] = pathogen_mask
    return stack

def _remove_multiobject_cells(stack, mask_dim, cell_dim, nucleus_dim, pathogen_dim, object_dim):
    """
    Remove multi-object cells from the stack.

    Args:
        stack (ndarray): The stack of images.
        mask_dim (int): The dimension of the mask in the stack.
        cell_dim (int): The dimension of the cell in the stack.
        nucleus_dim (int): The dimension of the nucleus in the stack.
        pathogen_dim (int): The dimension of the pathogen in the stack.
        object_dim (int): The dimension of the object in the stack.

    Returns:
        ndarray: The updated stack with multi-object cells removed.
    """
    cell_mask = stack[:, :, mask_dim]
    nucleus_mask = stack[:, :, nucleus_dim]
    pathogen_mask = stack[:, :, pathogen_dim]
    object_mask = stack[:, :, object_dim]

    for cell_label in np.unique(cell_mask)[1:]:
        cell_region = cell_mask == cell_label
        labels_in_cell = np.unique(object_mask[cell_region])
        if len(labels_in_cell) > 2:
            cell_mask[cell_region] = 0
            nucleus_mask[cell_region] = 0
            for pathogen_label in labels_in_cell[1:]:  # Skip the first label (0)
                pathogen_mask[pathogen_mask == pathogen_label] = 0

    stack[:, :, cell_dim] = cell_mask
    stack[:, :, nucleus_dim] = nucleus_mask
    stack[:, :, pathogen_dim] = pathogen_mask
    return stack

def merge_touching_objects(mask, threshold=0.25):
    """
    Merges touching objects in a labeled mask based on the fraction of their shared boundary.

    Args:
        mask (ndarray): Labeled mask representing objects.
        threshold (float, optional): Threshold value for merging objects. Defaults to 0.25.

    Returns:
        ndarray: Merged mask.
    """
    perimeters = {}
    labels = np.unique(mask)
    # Calculate the perimeter of each object
    for label in labels:
        if label != 0:  # Ignore background
            edges = morphology.erosion(mask == label) ^ (mask == label)
            perimeters[label] = np.sum(edges)
    # Detect touching objects and find the shared boundary
    shared_perimeters = {}
    dilated = morphology.dilation(mask > 0)
    for label in labels:
        if label != 0:  # Ignore background
            # Find the objects that this object is touching
            dilated_label = morphology.dilation(mask == label)
            touching_labels = np.unique(mask[dilated & (dilated_label != 0) & (mask != 0)])
            for touching_label in touching_labels:
                if touching_label != label:  # Exclude the object itself
                    shared_boundary = dilated_label & morphology.dilation(mask == touching_label)
                    shared_perimeters[(label, touching_label)] = np.sum(shared_boundary)
    # Merge objects whose shared boundary exceeds the threshold fraction of the smaller perimeter
    for (label1, label2), shared_perimeter in shared_perimeters.items():
        if shared_perimeter > threshold * min(perimeters[label1], perimeters[label2]):
            mask[mask == label2] = label1  # Merge label2 into label1
    return mask
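
# Illustrative sketch (toy labeled mask): two adjacent rectangles share a long
# boundary relative to their perimeters, so they are merged into one label.
def _example_merge_touching_objects():
    mask = np.zeros((6, 8), dtype=int)
    mask[1:5, 1:4] = 1
    mask[1:5, 4:7] = 2  # touches object 1 along a full edge
    merged = merge_touching_objects(mask.copy(), threshold=0.25)
    return np.unique(merged)  # expected: array([0, 1])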

def remove_intensity_objects(image, mask, intensity_threshold, mode):
    """
    Removes objects from the mask based on their mean intensity in the original image.

    Args:
        image (ndarray): The original image.
        mask (ndarray): The mask containing labeled objects.
        intensity_threshold (float): The threshold value for mean intensity.
        mode (str): The mode for intensity comparison. Can be 'low' or 'high'.

    Returns:
        ndarray: The updated mask with objects removed.
    """
    # Calculate the mean intensity of each object in the original image
    props = regionprops_table(mask, image, properties=('label', 'mean_intensity'))
    # Find the labels of the objects whose mean intensity crosses the threshold
    if mode == 'low':
        labels_to_remove = props['label'][props['mean_intensity'] < intensity_threshold]
    if mode == 'high':
        labels_to_remove = props['label'][props['mean_intensity'] > intensity_threshold]
    # Remove these objects from the mask
    mask[np.isin(mask, labels_to_remove)] = 0
    return mask
|
2424
|
+
|
2425
|
+
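# Usage sketch (toy arrays are assumptions for demonstration): with
# mode='low', the dim object (mean intensity 10) is removed while the bright
# object (mean intensity 100) survives:
#
# >>> img = np.zeros((6, 6), dtype=float)
# >>> lbl = np.zeros((6, 6), dtype=int)
# >>> lbl[1:3, 1:3] = 1; img[1:3, 1:3] = 10.0    # dim object
# >>> lbl[3:5, 3:5] = 2; img[3:5, 3:5] = 100.0   # bright object
# >>> out = remove_intensity_objects(img, lbl.copy(), intensity_threshold=50, mode='low')
# >>> np.unique(out)   # expected: array([0, 2])
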
def _filter_closest_to_stat(df, column, n_rows, use_median=False):
    """
    Filter the DataFrame to include the rows closest to a statistical measure.

    Args:
        df (pandas.DataFrame): The input DataFrame.
        column (str): The column name to calculate the statistical measure.
        n_rows (int): The number of closest rows to include in the result.
        use_median (bool, optional): Whether to use the median or mean as the statistical measure.
            Defaults to False (mean).

    Returns:
        pandas.DataFrame: The filtered DataFrame with the rows closest to the statistical measure.
    """
    if use_median:
        target_value = df[column].median()
    else:
        target_value = df[column].mean()
    # Sort by absolute distance to the target without mutating the caller's DataFrame
    diff = (df[column] - target_value).abs()
    result_df = df.loc[diff.sort_values().index[:n_rows]]
    return result_df

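# Usage sketch (toy DataFrame assumed for demonstration): with a mean of 4.0,
# the two rows closest to it (areas 3.0 and 2.0) are returned:
#
# >>> toy_df = pd.DataFrame({'area': [1.0, 2.0, 3.0, 10.0]})
# >>> _filter_closest_to_stat(toy_df, column='area', n_rows=2)
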
def _find_similar_sized_images(file_list):
    """
    Find the largest group of images with the most similar size and shape.

    Args:
        file_list (list): List of file paths to the images.

    Returns:
        list: List of file paths belonging to the largest group of images with the most similar size and shape.
    """
    # Dictionary to hold image sizes and their paths
    size_to_paths = defaultdict(list)
    # Iterate over image paths to get their dimensions
    for path in file_list:
        img = cv2.imread(path, cv2.IMREAD_UNCHANGED)  # Read with unchanged color space to support different image types
        if img is not None:
            # Find indices where the image is not padded (non-zero)
            if img.ndim == 3:  # Color image
                mask = np.any(img != 0, axis=2)
            else:  # Grayscale image
                mask = img != 0
            # Find the bounding box of non-zero regions
            coords = np.argwhere(mask)
            if coords.size == 0:  # Skip images that are completely padded
                continue
            y0, x0 = coords.min(axis=0)
            y1, x1 = coords.max(axis=0) + 1  # Add 1 because slice end index is exclusive
            # Crop the image to remove padding
            cropped_img = img[y0:y1, x0:x1]
            # Get dimensions of the cropped image
            height, width = cropped_img.shape[:2]
            aspect_ratio = width / height
            size_key = (width, height, round(aspect_ratio, 2))  # Group by width, height, and aspect ratio
            size_to_paths[size_key].append(path)
    if not size_to_paths:  # No readable, non-empty images found
        return []
    # Find the largest group of images with the most similar size and shape
    largest_group = max(size_to_paths.values(), key=len)
    return largest_group

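# Usage sketch (the folder path is hypothetical): select the largest
# consistently sized group of images, e.g. before stacking them into a batch:
#
# >>> paths = glob.glob('/path/to/images/*.png')
# >>> keep = _find_similar_sized_images(paths)
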
def _relabel_parent_with_child_labels(parent_mask, child_mask):
    """
    Relabels the parent mask based on overlapping child labels.

    Args:
        parent_mask (ndarray): Binary mask representing the parent objects.
        child_mask (ndarray): Labeled mask representing the child objects.

    Returns:
        tuple: A tuple containing the relabeled parent mask and the (possibly updated) child mask.
    """
    # Label parent mask to identify unique objects
    parent_labels = label(parent_mask, background=0)
    # Use the original child mask labels directly, without relabeling
    child_labels = child_mask

    # Create a new parent mask for updated labels
    parent_mask_new = np.zeros_like(parent_mask)

    # Directly relabel parent cells based on overlapping child labels
    unique_child_labels = np.unique(child_labels)[1:]  # Skip background
    for child_label in unique_child_labels:
        child_area_mask = (child_labels == child_label)
        overlapping_parent_label = np.unique(parent_labels[child_area_mask])

        # Since each parent is assumed to overlap with exactly one child,
        # directly set the parent label to the child label where overlap occurs
        for parent_label in overlapping_parent_label:
            if parent_label != 0:  # Skip background
                parent_mask_new[parent_labels == parent_label] = child_label

    # For parents containing multiple children (e.g. multinucleated cells),
    # standardize all child labels within the parent to the first child label
    for parent_label in np.unique(parent_mask_new)[1:]:  # Skip background
        parent_area_mask = (parent_mask_new == parent_label)
        child_labels_in_parent = np.unique(child_mask[parent_area_mask])
        child_labels_in_parent = child_labels_in_parent[child_labels_in_parent != 0]  # Exclude background

        if len(child_labels_in_parent) > 1:
            # Standardize to the first child label within this parent
            first_child_label = child_labels_in_parent[0]
            for child_label in child_labels_in_parent:
                child_mask[child_mask == child_label] = first_child_label

    return parent_mask_new, child_mask

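# Usage sketch (toy masks assumed for demonstration): the single parent object
# inherits the label of the child (e.g. nucleus) it contains:
#
# >>> cells = np.zeros((6, 6), dtype=int); cells[1:5, 1:5] = 1
# >>> nuclei = np.zeros((6, 6), dtype=int); nuclei[2:4, 2:4] = 7
# >>> new_cells, nuclei = _relabel_parent_with_child_labels(cells, nuclei)
# >>> np.unique(new_cells)   # expected: array([0, 7])
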
def _exclude_objects(cell_mask, nucleus_mask, pathogen_mask, cytoplasm_mask, include_uninfected=True):
    """
    Exclude objects from the masks based on certain criteria.

    Args:
        cell_mask (ndarray): Mask representing cells.
        nucleus_mask (ndarray): Mask representing nuclei.
        pathogen_mask (ndarray): Mask representing pathogens.
        cytoplasm_mask (ndarray): Mask representing cytoplasm.
        include_uninfected (bool, optional): Whether to include uninfected cells. Defaults to True.

    Returns:
        tuple: A tuple containing the filtered cell mask, nucleus mask, pathogen mask, and cytoplasm mask.
    """
    # Remove cells with no nucleus or cytoplasm (and, optionally, no pathogen)
    filtered_cells = np.zeros_like(cell_mask)  # Initialize a new mask to store the filtered cells.
    for cell_label in np.unique(cell_mask):  # Iterate over all cell labels in the cell mask.
        if cell_label == 0:  # Skip background
            continue
        cell_region = cell_mask == cell_label  # Get a mask for the current cell.
        # Check for the presence of a nucleus, cytoplasm and pathogen in the current cell.
        has_nucleus = np.any(nucleus_mask[cell_region])
        has_cytoplasm = np.any(cytoplasm_mask[cell_region])
        has_pathogen = np.any(pathogen_mask[cell_region])
        if include_uninfected:
            if has_nucleus and has_cytoplasm:
                filtered_cells[cell_region] = cell_label
        else:
            if has_nucleus and has_cytoplasm and has_pathogen:
                filtered_cells[cell_region] = cell_label
    # Remove objects outside of cells
    nucleus_mask = nucleus_mask * (filtered_cells > 0)
    pathogen_mask = pathogen_mask * (filtered_cells > 0)
    cytoplasm_mask = cytoplasm_mask * (filtered_cells > 0)
    return filtered_cells, nucleus_mask, pathogen_mask, cytoplasm_mask

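# Usage sketch (toy masks assumed for demonstration): cell 2 has no nucleus,
# so it is dropped even when uninfected cells are kept:
#
# >>> cell = np.zeros((4, 8), dtype=int); cell[:, :4] = 1; cell[:, 4:] = 2
# >>> nuc = np.zeros_like(cell); nuc[1:3, 1:3] = 1   # nucleus only in cell 1
# >>> cyto = cell.copy()                              # cytoplasm everywhere
# >>> path = np.zeros_like(cell)                      # no pathogens
# >>> c, n, p, cy = _exclude_objects(cell, nuc, path, cyto, include_uninfected=True)
# >>> np.unique(c)   # expected: array([0, 1])
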
def _merge_overlapping_objects(mask1, mask2):
    """
    Merge overlapping objects in two masks.

    Args:
        mask1 (ndarray): First mask.
        mask2 (ndarray): Second mask.

    Returns:
        tuple: A tuple containing the merged masks (mask1, mask2).
    """
    labeled_1 = label(mask1)
    num_1 = np.max(labeled_1)
    for m1_id in range(1, num_1 + 1):
        current_1_mask = labeled_1 == m1_id
        overlapping_2_labels = np.unique(mask2[current_1_mask])
        overlapping_2_labels = overlapping_2_labels[overlapping_2_labels != 0]
        if len(overlapping_2_labels) > 1:
            # Percentage of the mask1 object covered by each overlapping mask2 object
            overlap_percentages = [np.sum(current_1_mask & (mask2 == m2_label)) / np.sum(current_1_mask) * 100 for m2_label in overlapping_2_labels]
            max_overlap_label = overlapping_2_labels[np.argmax(overlap_percentages)]
            max_overlap_percentage = max(overlap_percentages)
            if max_overlap_percentage >= 90:
                # One mask2 object dominates: trim the mask1 object where the
                # minor mask2 objects overlap it
                for m2_label in overlapping_2_labels:
                    if m2_label != max_overlap_label:
                        mask1[(current_1_mask) & (mask2 == m2_label)] = 0
            else:
                # No dominant object: merge all overlapping mask2 objects into one label
                for m2_label in overlapping_2_labels[1:]:
                    mask2[mask2 == m2_label] = overlapping_2_labels[0]
    return mask1, mask2

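# Usage sketch (toy masks assumed for demonstration): one mask1 object spans
# two mask2 objects at 50% each, so neither dominates (< 90%) and the mask2
# labels are merged:
#
# >>> m1 = np.zeros((4, 8), dtype=int); m1[:, 1:7] = 1
# >>> m2 = np.zeros((4, 8), dtype=int); m2[:, 1:4] = 3; m2[:, 4:7] = 5
# >>> m1, m2 = _merge_overlapping_objects(m1, m2)
# >>> np.unique(m2)   # expected: array([0, 3])
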
def _filter_object(mask, min_value):
    """
    Filter objects in a mask based on their pixel count.

    Args:
        mask (ndarray): The input labeled mask.
        min_value (int): The minimum pixel count threshold.

    Returns:
        ndarray: The filtered mask.
    """
    count = np.bincount(mask.ravel())
    to_remove = np.where(count < min_value)[0]  # Labels whose pixel count is below the threshold
    mask[np.isin(mask, to_remove)] = 0
    return mask

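# Usage sketch (toy mask assumed for demonstration): the single-pixel object
# falls below min_value and is cleared:
#
# >>> m = np.zeros((4, 4), dtype=int); m[0, 0] = 1; m[2:4, 2:4] = 2
# >>> np.unique(_filter_object(m, min_value=2))   # expected: array([0, 2])
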
def _filter_cp_masks(masks, flows, filter_size, minimum_size, maximum_size, remove_border_objects, merge, filter_dimm, batch, moving_avg_q1, moving_avg_q3, moving_count, plot, figuresize):
    """
    Filter the masks based on various criteria such as size, border objects, merging, and intensity.

    Args:
        masks (list): List of masks.
        flows (list): List of flows.
        filter_size (bool): Flag indicating whether to filter based on size.
        minimum_size (int): Minimum size of objects to keep.
        maximum_size (int): Maximum size of objects to keep.
        remove_border_objects (bool): Flag indicating whether to remove border objects.
        merge (bool): Flag indicating whether to merge adjacent objects.
        filter_dimm (bool): Flag indicating whether to filter based on intensity.
        batch (ndarray): Batch of images.
        moving_avg_q1 (float): Moving average of the first quartile of object intensities.
        moving_avg_q3 (float): Moving average of the third quartile of object intensities.
        moving_count (int): Count of moving averages.
        plot (bool): Flag indicating whether to plot the masks.
        figuresize (tuple): Size of the figure.

    Returns:
        list: List of filtered masks.
    """

    from .plot import plot_masks

    mask_stack = []
    for idx, (mask, flow, image) in enumerate(zip(masks, flows[0], batch)):
        if plot and idx == 0:
            num_objects = mask_object_count(mask)
            print(f'Number of objects before filtration: {num_objects}')
            plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)

        if filter_size:
            props = measure.regionprops_table(mask, properties=['label', 'area'])  # Measure properties of labeled image regions.
            valid_labels = props['label'][np.logical_and(props['area'] > minimum_size, props['area'] < maximum_size)]  # Select labels of valid size.
            mask = np.isin(mask, valid_labels) * mask  # Keep only valid objects; reassign so later steps use the filtered mask.
            masks[idx] = mask
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
                print(f'Number of objects after size filtration >{minimum_size} and <{maximum_size}: {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
        if remove_border_objects:
            mask = clear_border(mask)
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
                print(f'Number of objects after removing border objects: {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
        if merge:
            mask = merge_touching_objects(mask, threshold=0.25)
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
                print(f'Number of objects after merging adjacent objects: {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
        if filter_dimm:
            unique_labels = np.unique(mask)
            if len(unique_labels) == 1 and unique_labels[0] == 0:
                continue
            # Per-object pixel intensities from channel 1, used to compute each
            # object's intensity quartiles
            object_intensities = [batch[idx, :, :, 1][mask == lbl] for lbl in unique_labels if lbl != 0]
            object_q1s = [np.percentile(intensities, 25) for intensities in object_intensities if intensities.size > 0]
            object_q3s = [np.percentile(intensities, 75) for intensities in object_intensities if intensities.size > 0]
            if object_q1s:
                object_q1_mean = np.mean(object_q1s)
                object_q3_mean = np.mean(object_q3s)
                moving_avg_q1 = (moving_avg_q1 * moving_count + object_q1_mean) / (moving_count + 1)
                moving_avg_q3 = (moving_avg_q3 * moving_count + object_q3_mean) / (moving_count + 1)
                moving_count += 1
            mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q1, mode='low')
            mask = remove_intensity_objects(batch[idx, :, :, 1], mask, intensity_threshold=moving_avg_q3, mode='high')
            if plot and idx == 0:
                num_objects = mask_object_count(mask)
                print(f'Objects after intensity filtration >{moving_avg_q1} and <{moving_avg_q3}: {num_objects}')
                plot_masks(batch=image, masks=mask, flows=flow, cmap='inferno', figuresize=figuresize, nr=1, file_type='.npz', print_object_number=True)
        mask_stack.append(mask)
    return mask_stack

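# Call sketch (all inputs hypothetical): `masks` and `flows` as returned by a
# Cellpose-style model, `batch` an (N, H, W, C) image array. The quartile
# moving averages start at 0 and are updated per image when filter_dimm is
# enabled:
#
# >>> filtered = _filter_cp_masks(masks, flows, filter_size=True,
# ...                             minimum_size=100, maximum_size=100000,
# ...                             remove_border_objects=True, merge=False,
# ...                             filter_dimm=False, batch=batch,
# ...                             moving_avg_q1=0, moving_avg_q3=0,
# ...                             moving_count=0, plot=False, figuresize=(10, 10))
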
def _object_filter(df, object_type, size_range, intensity_range, mask_chans, mask_chan):
    """
    Filter the DataFrame based on object type, size range, and intensity range.

    Args:
        df (pandas.DataFrame): The DataFrame to filter.
        object_type (str): The type of object to filter.
        size_range (list or None): The range of object sizes to filter.
        intensity_range (list or None): The range of object intensities to filter.
        mask_chans (list): The list of mask channels.
        mask_chan (int): The index of the mask channel to use.

    Returns:
        pandas.DataFrame: The filtered DataFrame.
    """
    if size_range is not None:
        if isinstance(size_range, list):
            if isinstance(size_range[0], int):
                df = df[df[f'{object_type}_area'] > size_range[0]]
                print(f'After {object_type} minimum area filter: {len(df)}')
            if isinstance(size_range[1], int):
                df = df[df[f'{object_type}_area'] < size_range[1]]
                print(f'After {object_type} maximum area filter: {len(df)}')
    if intensity_range is not None:
        if isinstance(intensity_range, list):
            if isinstance(intensity_range[0], int):
                df = df[df[f'{object_type}_channel_{mask_chans[mask_chan]}_mean_intensity'] > intensity_range[0]]
                print(f'After {object_type} minimum mean intensity filter: {len(df)}')
            if isinstance(intensity_range[1], int):
                df = df[df[f'{object_type}_channel_{mask_chans[mask_chan]}_mean_intensity'] < intensity_range[1]]
                print(f'After {object_type} maximum mean intensity filter: {len(df)}')
    return df

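# Usage sketch (toy DataFrame assumed for demonstration): only the middle row
# passes both the area and mean-intensity windows:
#
# >>> toy = pd.DataFrame({'cell_area': [50, 500, 5000],
# ...                     'cell_channel_1_mean_intensity': [10, 200, 3000]})
# >>> _object_filter(toy, 'cell', size_range=[100, 10000],
# ...                intensity_range=[100, 1000], mask_chans=[1], mask_chan=0)
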
###################################################
# Classify
###################################################