celldetective 1.2.1__py3-none-any.whl → 1.2.2.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__main__.py +12 -5
- celldetective/events.py +28 -2
- celldetective/gui/about.py +0 -1
- celldetective/gui/analyze_block.py +3 -18
- celldetective/gui/btrack_options.py +126 -21
- celldetective/gui/classifier_widget.py +67 -111
- celldetective/gui/configure_new_exp.py +37 -4
- celldetective/gui/control_panel.py +14 -30
- celldetective/gui/generic_signal_plot.py +793 -0
- celldetective/gui/gui_utils.py +401 -226
- celldetective/gui/json_readers.py +0 -2
- celldetective/gui/layouts.py +269 -25
- celldetective/gui/measurement_options.py +14 -23
- celldetective/gui/neighborhood_options.py +3 -15
- celldetective/gui/plot_measurements.py +10 -23
- celldetective/gui/plot_signals_ui.py +53 -687
- celldetective/gui/process_block.py +320 -186
- celldetective/gui/retrain_segmentation_model_options.py +30 -47
- celldetective/gui/retrain_signal_model_options.py +5 -14
- celldetective/gui/seg_model_loader.py +129 -113
- celldetective/gui/signal_annotator.py +89 -99
- celldetective/gui/signal_annotator2.py +5 -9
- celldetective/gui/styles.py +32 -0
- celldetective/gui/survival_ui.py +49 -712
- celldetective/gui/tableUI.py +0 -1
- celldetective/gui/thresholds_gui.py +38 -11
- celldetective/gui/viewers.py +6 -7
- celldetective/io.py +60 -82
- celldetective/measure.py +374 -15
- celldetective/neighborhood.py +1 -7
- celldetective/preprocessing.py +2 -4
- celldetective/relative_measurements.py +0 -3
- celldetective/scripts/analyze_signals.py +0 -1
- celldetective/scripts/measure_cells.py +1 -3
- celldetective/scripts/measure_relative.py +1 -2
- celldetective/scripts/segment_cells.py +16 -12
- celldetective/scripts/segment_cells_thresholds.py +17 -10
- celldetective/scripts/track_cells.py +18 -18
- celldetective/scripts/train_segmentation_model.py +1 -2
- celldetective/scripts/train_signal_model.py +0 -3
- celldetective/segmentation.py +1 -1
- celldetective/signals.py +17 -8
- celldetective/tracking.py +2 -1
- celldetective/utils.py +42 -2
- {celldetective-1.2.1.dist-info → celldetective-1.2.2.post1.dist-info}/METADATA +19 -12
- celldetective-1.2.2.post1.dist-info/RECORD +86 -0
- {celldetective-1.2.1.dist-info → celldetective-1.2.2.post1.dist-info}/WHEEL +1 -1
- celldetective/models/segmentation_effectors/primNK_cfse/config_input.json +0 -29
- celldetective/models/segmentation_effectors/primNK_cfse/cp-cfse-transfer +0 -0
- celldetective/models/segmentation_effectors/primNK_cfse/training_instructions.json +0 -37
- celldetective/models/segmentation_effectors/ricm-bimodal/config_input.json +0 -130
- celldetective/models/segmentation_effectors/ricm-bimodal/ricm-bimodal +0 -0
- celldetective/models/segmentation_effectors/ricm-bimodal/training_instructions.json +0 -37
- celldetective-1.2.1.dist-info/RECORD +0 -91
- {celldetective-1.2.1.dist-info → celldetective-1.2.2.post1.dist-info}/LICENSE +0 -0
- {celldetective-1.2.1.dist-info → celldetective-1.2.2.post1.dist-info}/entry_points.txt +0 -0
- {celldetective-1.2.1.dist-info → celldetective-1.2.2.post1.dist-info}/top_level.txt +0 -0
celldetective/measure.py
CHANGED
|
@@ -1,36 +1,26 @@
|
|
|
1
1
|
import math
|
|
2
|
-
import sys
|
|
3
|
-
from collections import defaultdict
|
|
4
2
|
|
|
5
3
|
import numpy as np
|
|
6
4
|
import pandas as pd
|
|
7
|
-
from
|
|
8
|
-
import
|
|
9
|
-
import tifffile
|
|
10
|
-
from lmfit import Parameters, Model, models
|
|
11
|
-
# import lmfit
|
|
5
|
+
from sklearn.metrics import r2_score
|
|
6
|
+
from scipy.optimize import curve_fit
|
|
12
7
|
from scipy import ndimage
|
|
13
|
-
from stardist import fill_label_holes
|
|
14
8
|
from tqdm import tqdm
|
|
15
9
|
from skimage.measure import regionprops_table
|
|
16
|
-
from scipy.ndimage.morphology import distance_transform_edt
|
|
17
10
|
from functools import reduce
|
|
18
11
|
from mahotas.features import haralick
|
|
19
|
-
from scipy.ndimage import zoom
|
|
12
|
+
from scipy.ndimage import zoom
|
|
20
13
|
import os
|
|
21
14
|
import subprocess
|
|
15
|
+
from math import ceil
|
|
22
16
|
|
|
23
|
-
from celldetective.filters import std_filter, gauss_filter
|
|
24
|
-
import datetime
|
|
25
17
|
from skimage.draw import disk as dsk
|
|
26
18
|
|
|
27
|
-
from celldetective.filters import std_filter, gauss_filter
|
|
28
19
|
from celldetective.utils import rename_intensity_column, create_patch_mask, remove_redundant_features, \
|
|
29
|
-
remove_trajectory_measurements, contour_of_instance_segmentation
|
|
20
|
+
remove_trajectory_measurements, contour_of_instance_segmentation, extract_cols_from_query, step_function, interpolate_nan
|
|
30
21
|
from celldetective.preprocessing import field_correction
|
|
31
22
|
import celldetective.extra_properties as extra_properties
|
|
32
23
|
from celldetective.extra_properties import *
|
|
33
|
-
import cv2
|
|
34
24
|
from inspect import getmembers, isfunction
|
|
35
25
|
from skimage.morphology import disk
|
|
36
26
|
|
|
@@ -551,6 +541,8 @@ def compute_haralick_features(img, labels, channels=None, target_channel=0, scal
|
|
|
551
541
|
if len(img.shape)==3:
|
|
552
542
|
img = img[:,:,target_channel]
|
|
553
543
|
|
|
544
|
+
img = interpolate_nan(img)
|
|
545
|
+
|
|
554
546
|
# Rescale image and mask
|
|
555
547
|
img = zoom(img,[scale_factor,scale_factor],order=3).astype(float)
|
|
556
548
|
labels = zoom(labels, [scale_factor,scale_factor],order=0)
|
|
@@ -986,3 +978,370 @@ def blob_detection(image, label, threshold, diameter):
|
|
|
986
978
|
blob_labels[mask_index] = [blobs_filtered.shape[0], spot_intensity['intensity_mean'][0]]
|
|
987
979
|
return blob_labels
|
|
988
980
|
|
|
981
|
+
### Classification ####
|
|
982
|
+
|
|
983
|
+
|
|
984
|
+
def estimate_time(df, class_attr, model='step_function', class_of_interest=[2], r2_threshold=0.5):
    """
    Estimate the timing of an event for cells based on classification status and fit a model to the observed status signal.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame containing tracked data with classification and status columns.
        It must provide 'TRACK_ID', 'FRAME', `class_attr` and the matching status
        column (`class_attr` with 'class' replaced by 'status').
    class_attr : str
        Column name for the classification attribute (e.g., 'class_event').
    model : str or callable, optional
        Model fitted to the status signal (default is 'step_function'). A string is
        resolved with `eval()` in this module's namespace; a callable is used as is.
        The model is called as `model(t, *params)` with two fitted parameters, the
        first one being interpreted as the event time.
    class_of_interest : list, optional
        List of class values that define the cells of interest for analysis (default is [2]).
        The list is only read, never modified.
    r2_threshold : float, optional
        R-squared threshold for determining if the model fit is acceptable (default is 0.5).

    Returns
    -------
    pandas.DataFrame
        A sorted copy of the input (by 'position' if present, then 'TRACK_ID', with a
        fresh index) where, for each track of interest, the time column (`class_attr`
        with 'class' replaced by 't') and `class_attr` have been updated.

    Notes
    -----
    - Tracks are grouped by 'TRACK_ID' (and 'position' when that column exists).
    - If the fit succeeds with R² > r2_threshold, the class is set to 0.0 and the time
      to the first fitted parameter.
    - If the fit fails or R² <= r2_threshold, the class is set to 2.0 and the time to -1.
    - String models are evaluated using `eval()`: only pass trusted model names.

    Example
    -------
    >>> df = estimate_time(df, 'class', model='step_function', class_of_interest=[2], r2_threshold=0.6)

    """

    cols = list(df.columns)
    assert 'TRACK_ID' in cols, 'Please provide tracked data...'
    sort_cols = ['position', 'TRACK_ID'] if 'position' in cols else ['TRACK_ID']

    status_col = class_attr.replace('class', 'status')
    time_col = class_attr.replace('class', 't')

    df = df.sort_values(by=sort_cols, ignore_index=True)
    df = df.reset_index(drop=True)

    tracks_of_interest = df.loc[df[class_attr].isin(class_of_interest)]
    for _, group in tracks_of_interest.groupby(sort_cols):

        indices = group.index

        # Frames without a status value carry no information for the fit
        group_clean = group.dropna(subset=[status_col])
        status_signal = group_clean[status_col].values
        timeline = group_clean['FRAME'].values

        try:
            # NOTE(review): eval() on the model string; never feed it untrusted input.
            # Resolved once per track instead of once per time point.
            model_fn = model if callable(model) else eval(model)
            popt, _ = curve_fit(model_fn, timeline.astype(int), status_signal,
                                p0=[df['FRAME'].max() // 2, 0.8], maxfev=30000)
            values = [model_fn(t, *popt) for t in timeline]
            r2 = r2_score(status_signal, values)
        except Exception as e:
            # Best effort: a failed fit leaves the track flagged as ambiguous
            print(e)
            df.loc[indices, class_attr] = 2.0
            df.loc[indices, time_col] = -1
            continue

        if r2 > float(r2_threshold):
            # Good fit: the first model parameter is the event time
            df.loc[indices, time_col] = popt[0]
            df.loc[indices, class_attr] = 0.0
        else:
            df.loc[indices, time_col] = -1
            df.loc[indices, class_attr] = 2.0

    return df
|
|
1065
|
+
|
|
1066
|
+
|
|
1067
|
+
def interpret_track_classification(df, class_attr, irreversible_event=False, unique_state=False,r2_threshold=0.5):
    """
    Interpret and classify tracked cells based on their status signals.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame containing tracked cell data, including a classification attribute column and other necessary columns.
        If the matching status column is missing, it is added to `df` in place.
    class_attr : str
        Column name for the classification attribute (e.g., 'class') used to determine the state of cells.
    irreversible_event : bool, optional
        If True, classifies irreversible events in the dataset (default is False).
        When set to True, `unique_state` is ignored.
    unique_state : bool, optional
        If True, classifies unique states of cells in the dataset based on a percentile threshold (default is False).
        This option is ignored if `irreversible_event` is set to True.
    r2_threshold : float, optional
        R-squared threshold forwarded to `classify_irreversible_events` when fitting the event model (default is 0.5).

    Returns
    -------
    pandas.DataFrame
        DataFrame with updated classifications for cell trajectories:
        - If `irreversible_event` is True, the result of `classify_irreversible_events`.
        - Else if `unique_state` is True, the result of `classify_unique_states` with a 50th percentile.
        - Otherwise the input DataFrame (with the status column added if it was missing).

    Raises
    ------
    AssertionError
        If the 'TRACK_ID' column is missing in the input DataFrame.

    Example
    -------
    >>> df = interpret_track_classification(df, 'class', irreversible_event=True, r2_threshold=0.7)
    """

    cols = list(df.columns)
    assert 'TRACK_ID' in cols, 'Please provide tracked data...'

    # Without a per-frame status column, the class column itself is used as the status signal
    status_col = class_attr.replace('class', 'status')
    if status_col not in cols:
        df.loc[:, status_col] = df.loc[:, class_attr]

    if irreversible_event:
        # Takes precedence over unique_state; forward the caller's threshold
        return classify_irreversible_events(df, class_attr, r2_threshold=r2_threshold)

    if unique_state:
        return classify_unique_states(df, class_attr, percentile=50)

    return df
|
|
1133
|
+
|
|
1134
|
+
def classify_irreversible_events(df, class_attr, r2_threshold=0.5, percentile_recovery=95):
    """
    Classify irreversible events in a tracked dataset based on the status of cells and transitions.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame containing tracked cell data, including classification and status columns.
    class_attr : str
        Column name for the classification attribute (e.g., 'class') used to update the classification of cell states.
    r2_threshold : float, optional
        R-squared threshold for fitting the model (default is 0.5). Used when estimating the time of transition.
    percentile_recovery : float, optional
        Percentile passed to `classify_unique_states` when revisiting the tracks that are
        still classified as 2 after the time estimation (default is 95).

    Returns
    -------
    pandas.DataFrame
        DataFrame with updated classifications for irreversible events, with the following outcomes:
        - Tracks whose non-NaN status values are all 0 are classified as 1 (no event).
        - Tracks whose non-NaN status values are all 1 are classified as 2 and their status is set to 2.
        - Tracks with any other mix of values are classified as 2 (ambiguous, possible transition).
        - For tracks classified as 2, the time of the event is estimated using `estimate_time`. If successful they are reclassified as 0 (event).
        - Tracks still classified as 2 are revisited with `classify_unique_states` and `percentile_recovery`.

    Notes
    -----
    - Tracks are grouped by 'TRACK_ID' (and 'position' when that column exists).
    - The classification is based on the status column derived from `class_attr` ('class' replaced by 'status').
    - A track with no valid status value satisfies the "all 0" test and is classified as 1.

    Example
    -------
    >>> df = classify_irreversible_events(df, 'class', r2_threshold=0.7)
    """

    cols = list(df.columns)
    assert 'TRACK_ID' in cols, 'Please provide tracked data...'
    sort_cols = ['position', 'TRACK_ID'] if 'position' in cols else ['TRACK_ID']

    stat_col = class_attr.replace('class', 'status')
    time_col = class_attr.replace('class', 't')

    for _, track in df.groupby(sort_cols):

        indices = track.index
        # NaN status frames are ignored when deciding the category of the track
        status_values = track[stat_col].dropna().to_numpy()

        if np.all(status_values == 0):
            # all negative, no event
            df.loc[indices, class_attr] = 1
        elif np.all(status_values == 1):
            # all positive, event already observed
            df.loc[indices, class_attr] = 2
            df.loc[indices, stat_col] = 2
        else:
            # ambiguity, possible transition
            df.loc[indices, class_attr] = 2

    # No event time for the tracks that are not candidates for a transition
    df.loc[df[class_attr] != 2, time_col] = -1
    df = estimate_time(df, class_attr, model='step_function', class_of_interest=[2], r2_threshold=r2_threshold)

    # Revisit class 2 cells to classify as neg/pos with percentile tolerance
    still_ambiguous = df[class_attr] == 2
    df.loc[still_ambiguous, :] = classify_unique_states(df.loc[still_ambiguous, :].copy(), class_attr, percentile_recovery)

    return df
|
|
1208
|
+
|
|
1209
|
+
def classify_unique_states(df, class_attr, percentile=50):
    """
    Classify unique cell states based on percentile values of a status attribute in a tracked dataset.

    Parameters
    ----------
    df : pandas.DataFrame
        DataFrame containing tracked cell data, including classification and status columns,
        as well as 'FRAME' and 't_firstdetection'.
    class_attr : str
        Column name for the classification attribute (e.g., 'class') used to update the classification of cell states.
    percentile : int, optional
        Percentile of the status values used to classify each track (default is 50, the median).

    Returns
    -------
    pandas.DataFrame
        The input DataFrame, updated in place and returned. For each track, the percentile of
        the valid status values is rounded up (`ceil`):
        - A result of 0 (negative to classification) sets the class to 1 and the time to -1.
        - A result of 1 (positive to classification) sets the class to 2 and the time to -1.
        - Any other result, or a track without usable status values, is left untouched.

    Notes
    -----
    - Tracks are grouped by 'TRACK_ID' (and 'position' when that column exists).
    - The classification is based on the status column derived from `class_attr` ('class' replaced by 'status').
    - NaN values in the status column are excluded from the percentile calculation.
    - Only frames at or after the track's first 't_firstdetection' value are considered.
    - The time column is `class_attr` with 'class' replaced by 't'.

    Example
    -------
    >>> df = classify_unique_states(df, 'class', percentile=75)
    """

    cols = list(df.columns)
    assert 'TRACK_ID' in cols, 'Please provide tracked data...'
    sort_cols = ['position', 'TRACK_ID'] if 'position' in cols else ['TRACK_ID']

    stat_col = class_attr.replace('class', 'status')
    time_col = class_attr.replace('class', 't')

    # ceil(percentile of status) -> class value
    class_from_state = {0: 1, 1: 2}

    for _, track in df.groupby(sort_cols):

        indices = track.index
        track_valid = track.dropna(subset=[stat_col])
        status_values = track_valid[stat_col].to_numpy()
        frames = track_valid['FRAME'].to_numpy()
        t_first = track['t_firstdetection'].to_numpy()[0]

        observed = status_values[frames >= t_first]
        if observed.size == 0:
            # Nothing to summarise (also the case when t_first is NaN): leave the track as is
            continue

        perc_status = np.nanpercentile(observed, percentile)
        if np.isnan(perc_status):
            continue

        new_class = class_from_state.get(ceil(perc_status))
        if new_class is not None:
            df.loc[indices, class_attr] = new_class
            df.loc[indices, time_col] = -1

    return df
|
|
1275
|
+
|
|
1276
|
+
def classify_cells_from_query(df, status_attr, query):
    """
    Classify cells in a DataFrame based on a query string, assigning classifications to a specified column.

    Parameters
    ----------
    df : pandas.DataFrame
        The DataFrame containing cell data to be classified. It is modified in place.
    status_attr : str
        The name of the column where the classification results will be stored.
        - Initially, all cells are assigned a value of 0.
    query : str
        A string representing the condition for classifying the cells. The query is applied to the DataFrame using pandas `.query()`.

    Returns
    -------
    pandas.DataFrame or None
        The DataFrame with an updated `status_attr` column:
        - Cells matching the query are classified with a value of 1.
        - Cells that have `NaN` values in any of the columns involved in the query are classified as `NaN`.
        - Cells that do not match the query are classified with a value of 0.
        `None` is returned when the query cannot be evaluated.

    Notes
    -----
    - If the `query` string is empty, a message is printed and the DataFrame is returned with `status_attr` set to 0 everywhere.
    - If the query contains columns that are not found in `df`, the entire `status_attr` column is set to `NaN`.
    - Errors raised while evaluating the query are caught: a message is printed and `None` is returned
      (the `status_attr` column has already been reset to 0 at that point).

    Examples
    --------
    >>> data = {'cell_type': ['A', 'B', 'A', 'B'], 'size': [10, 20, np.nan, 15]}
    >>> df = pd.DataFrame(data)
    >>> classify_cells_from_query(df, 'selected_cells', 'size > 15')
      cell_type  size  selected_cells
    0         A  10.0             0.0
    1         B  20.0             1.0
    2         A   NaN             NaN
    3         B  15.0             0.0
    """

    # Initialize all states to 0
    df[status_attr] = 0

    if query == '':
        print('The provided query is empty...')
        return df

    cols = extract_cols_from_query(query)
    column_names = list(df.columns)
    cols_in_df = all(c in column_names for c in cols)

    try:
        if cols_in_df:
            selection = df.dropna(subset=cols).query(query).index
            null_selection = df[df.loc[:, cols].isna().any(axis=1)].index
            # Set NaN to invalid cells, 1 otherwise
            if len(null_selection) > 0:
                # Upcast explicitly: writing NaN into an integer column relies on
                # implicit upcasting, which pandas deprecates
                df[status_attr] = df[status_attr].astype(float)
                df.loc[null_selection, status_attr] = np.nan
            df.loc[selection, status_attr] = 1
        else:
            df[status_attr] = np.nan

    except Exception as e:
        print(f"The query could not be understood. No filtering was applied. {e}...")
        return None

    return df
|
celldetective/neighborhood.py
CHANGED
|
@@ -1,17 +1,11 @@
|
|
|
1
1
|
import numpy as np
|
|
2
2
|
import pandas as pd
|
|
3
3
|
from tqdm import tqdm
|
|
4
|
-
from skimage.measure import regionprops_table
|
|
5
4
|
from skimage.graph import pixel_graph
|
|
6
|
-
from functools import reduce
|
|
7
|
-
from mahotas.features import haralick
|
|
8
|
-
from scipy.ndimage import zoom
|
|
9
5
|
import os
|
|
10
|
-
import
|
|
11
|
-
from celldetective.utils import contour_of_instance_segmentation, rename_intensity_column, create_patch_mask, remove_redundant_features, extract_identity_col
|
|
6
|
+
from celldetective.utils import contour_of_instance_segmentation, extract_identity_col
|
|
12
7
|
from scipy.spatial.distance import cdist
|
|
13
8
|
from celldetective.io import locate_labels, get_position_pickle, get_position_table
|
|
14
|
-
import re
|
|
15
9
|
|
|
16
10
|
abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], 'celldetective'])
|
|
17
11
|
|
celldetective/preprocessing.py
CHANGED
|
@@ -6,13 +6,11 @@ from tqdm import tqdm
|
|
|
6
6
|
import numpy as np
|
|
7
7
|
import os
|
|
8
8
|
from celldetective.io import get_config, get_experiment_wells, interpret_wells_and_positions, extract_well_name_and_number, get_positions_in_well, extract_position_name, get_position_movie_path, load_frames, auto_load_number_of_frames
|
|
9
|
-
from celldetective.utils import estimate_unreliable_edge, unpad,
|
|
10
|
-
from celldetective.filters import std_filter, gauss_filter
|
|
9
|
+
from celldetective.utils import estimate_unreliable_edge, unpad, ConfigSectionMap, _extract_channel_indices_from_config, _extract_nbr_channels_from_config, _get_img_num_per_channel
|
|
11
10
|
from celldetective.segmentation import filter_image, threshold_image
|
|
12
|
-
from stardist import fill_label_holes
|
|
13
11
|
from csbdeep.io import save_tiff_imagej_compatible
|
|
14
12
|
from gc import collect
|
|
15
|
-
from lmfit import Parameters, Model
|
|
13
|
+
from lmfit import Parameters, Model
|
|
16
14
|
import tifffile.tifffile as tiff
|
|
17
15
|
|
|
18
16
|
def estimate_background_per_condition(experiment, threshold_on_std=1, well_option='*', target_channel="channel_name", frame_range=[0,5], mode="timeseries", activation_protocol=[['gauss',2],['std',4]], show_progress_per_pos=False, show_progress_per_well=True):
|
|
@@ -3,10 +3,7 @@ import numpy as np
|
|
|
3
3
|
from celldetective.utils import derivative, extract_identity_col
|
|
4
4
|
import os
|
|
5
5
|
import subprocess
|
|
6
|
-
from math import ceil
|
|
7
6
|
abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], 'celldetective'])
|
|
8
|
-
import random
|
|
9
|
-
from tqdm import tqdm
|
|
10
7
|
|
|
11
8
|
def measure_pairs(pos, neighborhood_protocol):
|
|
12
9
|
|
|
@@ -6,16 +6,14 @@ import argparse
|
|
|
6
6
|
import os
|
|
7
7
|
import json
|
|
8
8
|
from celldetective.io import auto_load_number_of_frames, load_frames
|
|
9
|
-
from celldetective.utils import extract_experiment_channels,
|
|
9
|
+
from celldetective.utils import extract_experiment_channels, ConfigSectionMap, _get_img_num_per_channel, extract_experiment_channels
|
|
10
10
|
from celldetective.utils import remove_redundant_features, remove_trajectory_measurements
|
|
11
11
|
from celldetective.measure import drop_tonal_features, measure_features, measure_isotropic_intensity
|
|
12
12
|
from pathlib import Path, PurePath
|
|
13
13
|
from glob import glob
|
|
14
|
-
from shutil import rmtree
|
|
15
14
|
from tqdm import tqdm
|
|
16
15
|
import numpy as np
|
|
17
16
|
import pandas as pd
|
|
18
|
-
import gc
|
|
19
17
|
from natsort import natsorted
|
|
20
18
|
from art import tprint
|
|
21
19
|
from tifffile import imread
|
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import argparse
|
|
2
2
|
import os
|
|
3
|
-
import
|
|
4
|
-
from celldetective.relative_measurements import measure_pair_signals_at_position, update_effector_table, extract_neighborhoods_from_pickles
|
|
3
|
+
from celldetective.relative_measurements import measure_pair_signals_at_position, extract_neighborhoods_from_pickles
|
|
5
4
|
from celldetective.utils import ConfigSectionMap, extract_experiment_channels
|
|
6
5
|
|
|
7
6
|
from pathlib import Path, PurePath
|
|
@@ -9,7 +9,7 @@ import json
|
|
|
9
9
|
from stardist.models import StarDist2D
|
|
10
10
|
from cellpose.models import CellposeModel
|
|
11
11
|
from celldetective.io import locate_segmentation_model, auto_load_number_of_frames, load_frames
|
|
12
|
-
from celldetective.utils import interpolate_nan, _estimate_scale_factor, _extract_channel_indices_from_config,
|
|
12
|
+
from celldetective.utils import interpolate_nan, _estimate_scale_factor, _extract_channel_indices_from_config, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel
|
|
13
13
|
from pathlib import Path, PurePath
|
|
14
14
|
from glob import glob
|
|
15
15
|
from shutil import rmtree
|
|
@@ -20,10 +20,7 @@ from csbdeep.io import save_tiff_imagej_compatible
|
|
|
20
20
|
import gc
|
|
21
21
|
from art import tprint
|
|
22
22
|
from scipy.ndimage import zoom
|
|
23
|
-
import threading
|
|
24
23
|
|
|
25
|
-
import matplotlib.pyplot as plt
|
|
26
|
-
import time
|
|
27
24
|
|
|
28
25
|
tprint("Segment")
|
|
29
26
|
|
|
@@ -210,17 +207,24 @@ def segment_index(indices):
|
|
|
210
207
|
del Y_pred;
|
|
211
208
|
gc.collect()
|
|
212
209
|
|
|
210
|
+
|
|
211
|
+
import concurrent.futures
|
|
212
|
+
|
|
213
213
|
# Multithreading
|
|
214
214
|
indices = list(range(img_num_channels.shape[1]))
|
|
215
215
|
chunks = np.array_split(indices, n_threads)
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
216
|
+
|
|
217
|
+
# One worker per chunk, as requested through n_threads
with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
    futures = [executor.submit(segment_index, chunk) for chunk in chunks]
    for future in concurrent.futures.as_completed(futures):
        # An exception raised in a worker is stored on its future; report it
        # instead of letting it vanish, but keep processing the other chunks.
        error = future.exception()
        if error is not None:
            print(f"A segmentation worker failed: {error!r}")
|
|
219
|
+
|
|
220
|
+
# threads = []
|
|
221
|
+
# for i in range(n_threads):
|
|
222
|
+
# thread_i = threading.Thread(target=segment_index, args=[chunks[i]])
|
|
223
|
+
# threads.append(thread_i)
|
|
224
|
+
# for th in threads:
|
|
225
|
+
# th.start()
|
|
226
|
+
# for th in threads:
|
|
227
|
+
# th.join()
|
|
224
228
|
|
|
225
229
|
print('Done.')
|
|
226
230
|
|
|
@@ -7,7 +7,7 @@ import os
|
|
|
7
7
|
import json
|
|
8
8
|
from celldetective.io import auto_load_number_of_frames, load_frames
|
|
9
9
|
from celldetective.segmentation import segment_frame_from_thresholds
|
|
10
|
-
from celldetective.utils import
|
|
10
|
+
from celldetective.utils import _extract_channel_indices_from_config, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel, extract_experiment_channels
|
|
11
11
|
from pathlib import Path, PurePath
|
|
12
12
|
from glob import glob
|
|
13
13
|
from shutil import rmtree
|
|
@@ -16,7 +16,6 @@ import numpy as np
|
|
|
16
16
|
from csbdeep.io import save_tiff_imagej_compatible
|
|
17
17
|
import gc
|
|
18
18
|
from art import tprint
|
|
19
|
-
import threading
|
|
20
19
|
|
|
21
20
|
tprint("Segment")
|
|
22
21
|
|
|
@@ -120,17 +119,25 @@ def segment_index(indices):
|
|
|
120
119
|
del mask;
|
|
121
120
|
gc.collect()
|
|
122
121
|
|
|
122
|
+
import concurrent.futures
|
|
123
|
+
|
|
123
124
|
# Multithreading
|
|
124
125
|
indices = list(range(img_num_channels.shape[1]))
|
|
125
126
|
chunks = np.array_split(indices, n_threads)
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
127
|
+
|
|
128
|
+
# One worker per chunk, as requested through n_threads
with concurrent.futures.ThreadPoolExecutor(max_workers=n_threads) as executor:
    futures = [executor.submit(segment_index, chunk) for chunk in chunks]
    for future in concurrent.futures.as_completed(futures):
        # An exception raised in a worker is stored on its future; report it
        # instead of letting it vanish, but keep processing the other chunks.
        error = future.exception()
        if error is not None:
            print(f"A segmentation worker failed: {error!r}")
|
|
130
|
+
|
|
131
|
+
# indices = list(range(img_num_channels.shape[1]))
|
|
132
|
+
# chunks = np.array_split(indices, n_threads)
|
|
133
|
+
# threads = []
|
|
134
|
+
# for i in range(n_threads):
|
|
135
|
+
# thread_i = threading.Thread(target=segment_index, args=[chunks[i]])
|
|
136
|
+
# threads.append(thread_i)
|
|
137
|
+
# for th in threads:
|
|
138
|
+
# th.start()
|
|
139
|
+
# for th in threads:
|
|
140
|
+
# th.join()
|
|
134
141
|
|
|
135
142
|
print('Done.')
|
|
136
143
|
gc.collect()
|