celldetective 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +2 -0
- celldetective/__main__.py +432 -0
- celldetective/datasets/segmentation_annotations/blank +0 -0
- celldetective/datasets/signal_annotations/blank +0 -0
- celldetective/events.py +149 -0
- celldetective/extra_properties.py +100 -0
- celldetective/filters.py +89 -0
- celldetective/gui/__init__.py +20 -0
- celldetective/gui/about.py +44 -0
- celldetective/gui/analyze_block.py +563 -0
- celldetective/gui/btrack_options.py +898 -0
- celldetective/gui/classifier_widget.py +386 -0
- celldetective/gui/configure_new_exp.py +532 -0
- celldetective/gui/control_panel.py +438 -0
- celldetective/gui/gui_utils.py +495 -0
- celldetective/gui/json_readers.py +113 -0
- celldetective/gui/measurement_options.py +1425 -0
- celldetective/gui/neighborhood_options.py +452 -0
- celldetective/gui/plot_signals_ui.py +1042 -0
- celldetective/gui/process_block.py +1055 -0
- celldetective/gui/retrain_segmentation_model_options.py +706 -0
- celldetective/gui/retrain_signal_model_options.py +643 -0
- celldetective/gui/seg_model_loader.py +460 -0
- celldetective/gui/signal_annotator.py +2388 -0
- celldetective/gui/signal_annotator_options.py +340 -0
- celldetective/gui/styles.py +217 -0
- celldetective/gui/survival_ui.py +903 -0
- celldetective/gui/tableUI.py +608 -0
- celldetective/gui/thresholds_gui.py +1300 -0
- celldetective/icons/logo-large.png +0 -0
- celldetective/icons/logo.png +0 -0
- celldetective/icons/signals_icon.png +0 -0
- celldetective/icons/splash-test.png +0 -0
- celldetective/icons/splash.png +0 -0
- celldetective/icons/splash0.png +0 -0
- celldetective/icons/survival2.png +0 -0
- celldetective/icons/vignette_signals2.png +0 -0
- celldetective/icons/vignette_signals2.svg +114 -0
- celldetective/io.py +2050 -0
- celldetective/links/zenodo.json +561 -0
- celldetective/measure.py +1258 -0
- celldetective/models/segmentation_effectors/blank +0 -0
- celldetective/models/segmentation_generic/blank +0 -0
- celldetective/models/segmentation_targets/blank +0 -0
- celldetective/models/signal_detection/blank +0 -0
- celldetective/models/tracking_configs/mcf7.json +68 -0
- celldetective/models/tracking_configs/ricm.json +203 -0
- celldetective/models/tracking_configs/ricm2.json +203 -0
- celldetective/neighborhood.py +717 -0
- celldetective/scripts/analyze_signals.py +51 -0
- celldetective/scripts/measure_cells.py +275 -0
- celldetective/scripts/segment_cells.py +212 -0
- celldetective/scripts/segment_cells_thresholds.py +140 -0
- celldetective/scripts/track_cells.py +206 -0
- celldetective/scripts/train_segmentation_model.py +246 -0
- celldetective/scripts/train_signal_model.py +49 -0
- celldetective/segmentation.py +712 -0
- celldetective/signals.py +2826 -0
- celldetective/tracking.py +974 -0
- celldetective/utils.py +1681 -0
- celldetective-1.0.2.dist-info/LICENSE +674 -0
- celldetective-1.0.2.dist-info/METADATA +192 -0
- celldetective-1.0.2.dist-info/RECORD +66 -0
- celldetective-1.0.2.dist-info/WHEEL +5 -0
- celldetective-1.0.2.dist-info/entry_points.txt +2 -0
- celldetective-1.0.2.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,206 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright © 2022 Laboratoire Adhesion et Inflammation, Authored by Remy Torro.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import os
|
|
7
|
+
import json
|
|
8
|
+
from celldetective.io import auto_load_number_of_frames, load_frames, interpret_tracking_configuration
|
|
9
|
+
from celldetective.utils import extract_experiment_channels, _extract_channel_indices_from_config, _extract_channel_indices, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel, extract_experiment_channels
|
|
10
|
+
from celldetective.measure import drop_tonal_features, measure_features
|
|
11
|
+
from celldetective import track
|
|
12
|
+
from pathlib import Path, PurePath
|
|
13
|
+
from glob import glob
|
|
14
|
+
from shutil import rmtree
|
|
15
|
+
from tqdm import tqdm
|
|
16
|
+
import numpy as np
|
|
17
|
+
import pandas as pd
|
|
18
|
+
import gc
|
|
19
|
+
import os
|
|
20
|
+
from natsort import natsorted
|
|
21
|
+
from art import tprint
|
|
22
|
+
from tifffile import imread
|
|
23
|
+
import threading
|
|
24
|
+
|
|
25
|
+
tprint("Track")

# Command-line interface: one position folder, one cell population, optional worker-thread count.
parser = argparse.ArgumentParser(description="Track the cells of a position from their segmentation masks",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-p', "--position", required=True, help="Path to the position")
parser.add_argument("--mode", default="target", choices=["target", "effector", "targets", "effectors"], help="Cell population of interest")
parser.add_argument("--threads", default=1, type=int, help="Number of parallel threads")

args = parser.parse_args()
process_arguments = vars(args)
pos = str(process_arguments['position'])
mode = str(process_arguments['mode'])
n_threads = int(process_arguments['threads'])
if n_threads < 1:
    # The frames are later split with np.array_split(indices, n_threads), which needs >= 1 section.
    parser.error(f"--threads must be a positive integer, got {n_threads}.")
|
|
38
|
+
|
|
39
|
+
# Create <position>/output and then <position>/output/tables when missing.
# Paths are built by plain string concatenation, so `pos` is assumed to end with a separator.
for _subfolder_parts in (["output"], ["output", "tables"]):
    _folder = pos + os.sep.join(_subfolder_parts)
    if not os.path.exists(_folder):
        os.mkdir(_folder)
|
|
44
|
+
|
|
45
|
+
# File names used for each population:
# (label folder, tracking-instructions file, napari export, trajectory table).
_population_files = {
    "targets": ("labels_targets", "tracking_instructions_targets.json",
                "napari_target_trajectories.npy", "trajectories_targets.csv"),
    "effectors": ("labels_effectors", "tracking_instructions_effectors.json",
                  "napari_effector_trajectories.npy", "trajectories_effectors.csv"),
}
_mode_key = mode.lower()
if _mode_key in ("target", "targets"):
    label_folder, _instructions_name, napari_name, table_name = _population_files["targets"]
    instruction_file = os.sep.join(["configs", _instructions_name])
elif _mode_key in ("effector", "effectors"):
    label_folder, _instructions_name, napari_name, table_name = _population_files["effectors"]
    instruction_file = os.sep.join(["configs", _instructions_name])
|
|
56
|
+
|
|
57
|
+
# Locate the experiment configuration: config.ini lives two directory levels above the position.
parent1 = Path(pos).parent
expfolder = parent1.parent
config = PurePath(expfolder, Path("config.ini"))
if not os.path.exists(config):
    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    raise FileNotFoundError(f'The configuration file for the experiment could not be located ({config}). Abort.')
print("Configuration file: ", config)

# Read the [MovieSettings] section once, then fetch the calibrations, movie length and frame shape.
movie_settings = ConfigSectionMap(config, "MovieSettings")
movie_prefix = movie_settings["movie_prefix"]
spatial_calibration = float(movie_settings["pxtoum"])
time_calibration = float(movie_settings["frametomin"])
len_movie = float(movie_settings["len_movie"])
shape_x = int(movie_settings["shape_x"])
shape_y = int(movie_settings["shape_y"])

channel_names, channel_indices = extract_experiment_channels(config)
nbr_channels = len(channel_names)
|
|
74
|
+
|
|
75
|
+
# Read the population-specific tracking instructions (bTrack configuration, extra features,
# mask channels, Haralick options, trajectory post-processing); fall back to defaults if absent.
instr_path = PurePath(expfolder, Path(f"{instruction_file}"))
_optional_instruction_keys = ('features', 'mask_channels', 'haralick_options', 'post_processing_options')

if not os.path.exists(instr_path):
    print('No tracking instructions found. Use standard bTrack motion model.')
    btrack_config = interpret_tracking_configuration(None)
    _provided_options = {}
else:
    print(f"Tracking instructions for the {mode} population has been successfully located.")
    with open(instr_path, 'r') as f:
        instructions = json.load(f)
    print("Reading the following instructions: ", instructions)
    btrack_config = interpret_tracking_configuration(instructions['btrack_config_path'])
    _provided_options = instructions

# A missing key yields None, exactly like an explicit null in the JSON file.
features, mask_channels, haralick_options, post_processing_options = (
    _provided_options.get(key) for key in _optional_instruction_keys
)

# Later code concatenates `features` with other lists, so normalise None to an empty list.
if features is None:
    features = []
|
|
113
|
+
|
|
114
|
+
# Collect the per-frame label images (*.tif) of this population, in natural frame order.
label_path = natsorted(glob(pos+f"{label_folder}"+os.sep+"*.tif"))
if len(label_path) > 0:
    print(f"Found {len(label_path)} segmented frames...")
else:
    print("No segmented frames have been found. Please run segmentation first, skipping...", flush=True)
    # Exit with a non-zero status instead of os.abort(): abort() raises SIGABRT without any
    # interpreter cleanup, so the message above could be lost when stdout is piped.
    raise SystemExit(1)
|
|
121
|
+
|
|
122
|
+
# Tonal and texture (Haralick) measurements need the raw movie. If no movie matches the prefix,
# disable both and continue with morphological features only.
try:
    file = glob(pos+os.sep.join(["movie", f"{movie_prefix}*.tif"]))[0]
except IndexError:
    print('Movie could not be found. Check the prefix. If you intended to measure texture or tone, this will not be performed.')
    file = None
    # Must reset the variable actually passed to measure_features (`haralick_options`).
    haralick_options = None
    features = drop_tonal_features(features)

# Prefer the frame count read from the movie itself over the config value, when it is available.
len_movie_auto = auto_load_number_of_frames(file) if file is not None else None
if len_movie_auto is not None:
    len_movie = len_movie_auto

# Frame indices of each channel in the movie stack (indexed below as img_num_channels[:, t]).
img_num_channels = _get_img_num_per_channel(channel_indices, len_movie, nbr_channels)
|
|
136
|
+
|
|
137
|
+
#######################################
# Loop over all frames and find objects
#######################################

# Per-frame measurement tables, appended concurrently by the worker threads.
timestep_dataframes = []
# Exceptions raised inside worker threads: Thread.join() does not propagate them to the caller.
measurement_errors = []


def measure_index(indices):
    """Measure the segmented objects of every frame listed in *indices*.

    For each frame ``t``, the movie channels and the label image are loaded, the requested
    features plus the centroid are measured, the centroid columns are renamed to ``x``/``y``
    and a ``t`` column is added. Each resulting table is appended to ``timestep_dataframes``.

    Any exception is recorded in ``measurement_errors`` and then re-raised, so the main thread
    can stop instead of tracking an incomplete set of frames.
    """
    try:
        for t in tqdm(indices, desc="frame"):

            # Load channels at time t
            img = load_frames(img_num_channels[:, t], file, scale=None, normalize_input=False)
            lbl = imread(label_path[t])

            df_props = measure_features(img, lbl, features=features+['centroid'], border_dist=None,
                                        channels=channel_names, haralick_options=haralick_options, verbose=False,
                                        )
            df_props.rename(columns={'centroid-1': 'x', 'centroid-0': 'y'}, inplace=True)
            df_props['t'] = int(t)
            timestep_dataframes.append(df_props)
    except Exception as err:
        measurement_errors.append(err)
        raise


# Multithreading: split the frame indices into n_threads contiguous chunks, one thread per chunk.
indices = list(range(img_num_channels.shape[1]))
chunks = np.array_split(indices, n_threads)
threads = [threading.Thread(target=measure_index, args=[chunk]) for chunk in chunks]
for th in threads:
    th.start()
for th in threads:
    th.join()

if measurement_errors:
    raise RuntimeError(
        f"Measurement failed in {len(measurement_errors)} thread(s); tracking was not performed."
    ) from measurement_errors[0]
|
|
168
|
+
|
|
169
|
+
# Stack the per-frame tables into one table with a fresh 0..N-1 index.
df = pd.concat(timestep_dataframes).reset_index(drop=True)

# Drop every measurement column whose name contains one of the mask channel names.
if mask_channels is not None:
    cols_to_drop = [col for mc in mask_channels for col in df.columns if mc in col]
    if cols_to_drop:
        df = df.drop(cols_to_drop, axis=1)
|
|
181
|
+
|
|
182
|
+
# Link the per-frame detections into trajectories with bTrack, also requesting napari track data.
tracking_kwargs = dict(
    configuration=btrack_config,
    objects=df,
    spatial_calibration=spatial_calibration,
    channel_names=channel_names,
    return_napari_data=True,
    optimizer_options={'tm_lim': int(12e4)},
    track_kwargs={'step_size': 100},
    clean_trajectories_kwargs=post_processing_options,
    volume=(shape_x, shape_y),
)
trajectories, napari_data = track(None, **tracking_kwargs)
print(trajectories)
print(trajectories.columns)

# Save the napari tracks and the trajectory table side by side in <position>output/tables.
tables_subdir = os.sep.join(['output', 'tables'])
tables_dir = pos + tables_subdir

np.save(tables_dir + os.sep + napari_name, napari_data, allow_pickle=True)
print(f"napari data successfully saved in {tables_dir}")

trajectories.to_csv(tables_dir + os.sep + table_name, index=False)
print(f"Table {table_name} successfully saved in {tables_subdir}")

# Release the tracking outputs before the process ends.
del trajectories, napari_data
gc.collect()
|
|
@@ -0,0 +1,246 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright © 2023 Laboratoire Adhesion et Inflammation, Authored by Remy Torro.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import os
|
|
7
|
+
import shutil
|
|
8
|
+
from glob import glob
|
|
9
|
+
import json
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
import numpy as np
|
|
12
|
+
import random
|
|
13
|
+
|
|
14
|
+
from celldetective.utils import load_image_dataset, normalize_per_channel, augmenter
|
|
15
|
+
from stardist import fill_label_holes
|
|
16
|
+
from art import tprint
|
|
17
|
+
import matplotlib.pyplot as plt
|
|
18
|
+
|
|
19
|
+
def interpolate_nan(array_like):
    """Return a copy of *array_like* with NaN entries linearly interpolated.

    Interpolation runs over the flattened (C-order) index of the array. For a multi-dimensional
    input, neighbours are therefore taken along the last axis and wrap across the other axes.

    Parameters
    ----------
    array_like : numpy.ndarray
        Floating-point array that may contain NaNs. It is not modified.

    Returns
    -------
    numpy.ndarray
        A copy in which every NaN is replaced by ``np.interp`` of the valid values. NaNs before
        the first (after the last) valid value take that first (last) value. If the input has
        no NaN, or no valid value at all, the copy is returned unchanged.
    """
    array = array_like.copy()
    nan_mask = np.isnan(array)

    # Nothing to fill, or nothing to interpolate from (np.interp rejects empty sample points).
    if not nan_mask.any() or nan_mask.all():
        return array

    flat_nan_mask = nan_mask.ravel()
    valid_positions = np.flatnonzero(~flat_nan_mask)
    nan_positions = np.flatnonzero(flat_nan_mask)

    # Boolean indexing walks the array in C order, matching the flat positions above.
    array[nan_mask] = np.interp(nan_positions, valid_positions, array[~nan_mask])
    return array
|
|
32
|
+
|
|
33
|
+
tprint("Train")

# Command-line interface: path to the JSON training instructions and a GPU switch.
parser = argparse.ArgumentParser(description="Train a segmentation model from instructions.",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('-c', "--config", required=True, help="Training instructions")
parser.add_argument('-g', "--use_gpu", required=True, help="Use GPU (True/False)")

args = parser.parse_args()
process_arguments = vars(args)
instructions = str(process_arguments['config'])
# bool() of any non-empty string is True (bool("False") is True), so parse the text explicitly.
use_gpu = str(process_arguments['use_gpu']).strip().lower() in ('true', '1', 'yes', 'y', 'on')
|
|
45
|
+
|
|
46
|
+
# Load the JSON training instructions given on the command line.
if os.path.exists(instructions):
    with open(instructions, 'r') as f:
        training_instructions = json.load(f)
else:
    print('Training instructions could not be found. Abort.', flush=True)
    # Exit with status 1 rather than os.abort(): abort() raises SIGABRT without interpreter
    # cleanup, so the message above may never reach a piped stdout.
    raise SystemExit(1)
|
|
52
|
+
|
|
53
|
+
# Unpack the training instructions. Keys are read in the same order as before, so a missing
# key raises the same KeyError.
model_name, target_directory, model_type, pretrained = (
    training_instructions[key] for key in ('model_name', 'target_directory', 'model_type', 'pretrained')
)

datasets = training_instructions['ds']

(target_channels, normalization_percentile, normalization_clip,
 normalization_values, spatial_calibration) = (
    training_instructions[key]
    for key in ('channel_option', 'normalization_percentile', 'normalization_clip',
                'normalization_values', 'spatial_calibration')
)

validation_split, augmentation_factor = (
    training_instructions[key] for key in ('validation_split', 'augmentation_factor')
)

learning_rate, epochs, batch_size = (
    training_instructions[key] for key in ('learning_rate', 'epochs', 'batch_size')
)
|
|
72
|
+
|
|
73
|
+
|
|
74
|
+
# Load the annotated images and their '*labelled' masks at the requested spatial calibration.
print(f'Datasets: {datasets}')
X, Y = load_image_dataset(datasets, target_channels, train_spatial_calibration=spatial_calibration,
                          mask_suffix='labelled')
print('Dataset loaded...')

# Normalize each channel with the percentile / absolute values from the instructions.
X = normalize_per_channel(X,
                          normalization_percentile_mode=normalization_percentile,
                          normalization_values=normalization_values,
                          normalization_clipping=normalization_clip
                          )

# Only report NaNs here. An interactive per-image preview would block unattended training runs.
n_images_with_nan = sum(bool(np.isnan(x).any()) for x in X)
if n_images_with_nan > 0:
    print(f'WARNING: {n_images_with_nan} of {len(X)} normalized image(s) contain NaN values.')

# Fill holes inside each labelled object of the masks.
Y = [fill_label_holes(y) for y in tqdm(Y)]
|
|
102
|
+
|
|
103
|
+
# Random train/validation split with at least one image on each side.
if len(X) < 2:
    # Explicit check instead of `assert`, which is removed when Python runs with -O.
    raise ValueError("not enough training data")
rng = np.random.RandomState()
ind = rng.permutation(len(X))
# Clamp so that a large validation_split can never leave the training set empty.
n_val = min(len(ind) - 1, max(1, int(round(validation_split * len(ind)))))
ind_train, ind_val = ind[:-n_val], ind[-n_val:]
X_val, Y_val = [X[i] for i in ind_val], [Y[i] for i in ind_val]
X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
print('number of images: %3d' % len(X))
print('- training: %3d' % len(X_trn))
print('- validation: %3d' % len(X_val))
|
|
113
|
+
|
|
114
|
+
if model_type == 'cellpose':

    # Cellpose is trained on augmented images only: draw round(augmentation_factor * N) training
    # images with replacement and augment each of them.
    X_aug = []
    Y_aug = []
    n_aug = max(1, int(round(augmentation_factor * len(X_trn))))
    indices = random.choices(range(len(X_trn)), k=n_aug)
    print('Performing image augmentation pre-training...')
    for i in tqdm(indices):
        x_aug, y_aug = augmenter(X_trn[i], Y_trn[i])
        X_aug.append(x_aug)
        Y_aug.append(y_aug)

    # Channel axis in front for cellpose
    X_aug = [np.moveaxis(x, -1, 0) for x in X_aug]
    X_val = [np.moveaxis(x, -1, 0) for x in X_val]
    print('number of augmented images: %3d' % len(X_aug))

    from cellpose.models import CellposeModel
    from cellpose.io import logger_setup

    logger, log_file = logger_setup()
    print('Pretrained model: ', pretrained)
    if pretrained is not None:
        # The pretrained weights are looked up as <pretrained>/<basename of pretrained>.
        pretrained_path = os.sep.join([pretrained, os.path.split(pretrained)[-1]])
    else:
        pretrained_path = pretrained

    model = CellposeModel(gpu=use_gpu, model_type=None, pretrained_model=pretrained_path, diam_mean=30.0, nchan=X_aug[0].shape[0],)
    model.train(train_data=X_aug, train_labels=Y_aug, normalize=False, channels=None, batch_size=batch_size,
                min_train_masks=1, save_path=target_directory+os.sep+model_name, n_epochs=epochs, model_name=model_name,
                learning_rate=learning_rate, test_data=X_val, test_labels=Y_val)

    # Move the saved weights from <model>/models/ up into <model>/ and remove the emptied folder.
    models_dir = os.sep.join([target_directory, model_name, 'models'])
    saved_models = glob(os.sep.join([models_dir, '*']))
    if not saved_models:
        raise FileNotFoundError(f'No trained Cellpose model was found in {models_dir}.')
    file_to_move = saved_models[0]
    shutil.move(file_to_move, os.sep.join([target_directory, model_name, ''])+os.path.split(file_to_move)[-1])
    os.rmdir(models_dir)

    # Kept for the (currently disabled) diameter-based rescaling of the spatial calibration below.
    diameter = model.diam_labels

    if pretrained is not None and os.path.split(pretrained)[-1] == 'CP_nuclei':
        standard_diameter = 17.0
    else:
        standard_diameter = 30.0

    input_spatial_calibration = spatial_calibration  # *diameter / standard_diameter

    config_inputs = {"channels": target_channels, "diameter": standard_diameter, 'cellprob_threshold': 0., 'flow_threshold': 0.4,
                     'normalization_percentile': normalization_percentile, 'normalization_clip': normalization_clip,
                     'normalization_values': normalization_values, 'model_type': 'cellpose',
                     'spatial_calibration': input_spatial_calibration}
    json_input_config = json.dumps(config_inputs, indent=4)
    with open(os.sep.join([target_directory, model_name, "config_input.json"]), "w") as outfile:
        outfile.write(json_input_config)
|
|
169
|
+
|
|
170
|
+
elif model_type=='stardist':
|
|
171
|
+
|
|
172
|
+
from stardist import calculate_extents, gputools_available
|
|
173
|
+
from stardist.models import Config2D, StarDist2D
|
|
174
|
+
|
|
175
|
+
n_rays = 32
|
|
176
|
+
print(gputools_available())
|
|
177
|
+
|
|
178
|
+
n_channel=X_trn[0].shape[-1]
|
|
179
|
+
|
|
180
|
+
# Predict on subsampled grid for increased efficiency and larger field of view
|
|
181
|
+
grid = (2,2)
|
|
182
|
+
conf = Config2D (
|
|
183
|
+
n_rays = n_rays,
|
|
184
|
+
grid = grid,
|
|
185
|
+
use_gpu = use_gpu,
|
|
186
|
+
n_channel_in = n_channel,
|
|
187
|
+
unet_dropout = 0.0,
|
|
188
|
+
unet_batch_norm = False,
|
|
189
|
+
unet_n_conv_per_depth=2,
|
|
190
|
+
train_learning_rate = learning_rate,
|
|
191
|
+
train_patch_size = (256,256),
|
|
192
|
+
train_epochs = epochs,
|
|
193
|
+
#train_foreground_only=0.9,
|
|
194
|
+
train_loss_weights=(1,0.2),
|
|
195
|
+
train_reduce_lr = {'factor': 0.1, 'patience': 30, 'min_delta': 0},
|
|
196
|
+
unet_n_depth = 3,
|
|
197
|
+
train_batch_size = batch_size,
|
|
198
|
+
)
|
|
199
|
+
|
|
200
|
+
if use_gpu:
|
|
201
|
+
from csbdeep.utils.tf import limit_gpu_memory
|
|
202
|
+
limit_gpu_memory(None, allow_growth=True)
|
|
203
|
+
|
|
204
|
+
if pretrained is None:
|
|
205
|
+
model = StarDist2D(conf, name=model_name, basedir=target_directory)
|
|
206
|
+
else:
|
|
207
|
+
# files_to_copy = glob(os.sep.join([pretrained, '*']))
|
|
208
|
+
# for f in files_to_copy:
|
|
209
|
+
# shutil.copy(f, os.sep.join([target_directory, model_name, os.path.split(f)[-1]]))
|
|
210
|
+
idx=1
|
|
211
|
+
while os.path.exists(os.sep.join([target_directory, model_name])):
|
|
212
|
+
model_name = model_name+f'_{idx}'
|
|
213
|
+
idx+=1
|
|
214
|
+
|
|
215
|
+
shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
|
|
216
|
+
model = StarDist2D(None, name=model_name, basedir=target_directory)
|
|
217
|
+
model.config.train_epochs = epochs
|
|
218
|
+
model.config.train_batch_size = min(len(X_trn),batch_size)
|
|
219
|
+
model.config.train_learning_rate = learning_rate
|
|
220
|
+
|
|
221
|
+
median_size = calculate_extents(list(Y_trn), np.mean)
|
|
222
|
+
fov = np.array(model._axes_tile_overlap('YX'))
|
|
223
|
+
print(f"median object size: {median_size}")
|
|
224
|
+
print(f"network field of view : {fov}")
|
|
225
|
+
if any(median_size > fov):
|
|
226
|
+
print("WARNING: median object size larger than field of view of the neural network.")
|
|
227
|
+
|
|
228
|
+
if augmentation_factor==1.0:
|
|
229
|
+
model.train(X_trn, Y_trn, validation_data=(X_val,Y_val))
|
|
230
|
+
else:
|
|
231
|
+
model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)
|
|
232
|
+
model.optimize_thresholds(X_val,Y_val)
|
|
233
|
+
|
|
234
|
+
config_inputs = {"channels": target_channels, 'normalization_percentile': normalization_percentile,
|
|
235
|
+
'normalization_clip': normalization_clip, 'normalization_values': normalization_values,
|
|
236
|
+
'model_type': 'stardist', 'spatial_calibration': spatial_calibration}
|
|
237
|
+
|
|
238
|
+
json_input_config = json.dumps(config_inputs, indent=4)
|
|
239
|
+
with open(os.sep.join([target_directory, model_name, "config_input.json"]), "w") as outfile:
|
|
240
|
+
outfile.write(json_input_config)
|
|
241
|
+
|
|
242
|
+
print('Done.')
|
|
243
|
+
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
|
|
@@ -0,0 +1,49 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Copyright © 2023 Laboratoire Adhesion et Inflammation, Authored by Remy Torro.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import argparse
|
|
6
|
+
import os
|
|
7
|
+
import json
|
|
8
|
+
from pathlib import Path, PurePath
|
|
9
|
+
from glob import glob
|
|
10
|
+
from tqdm import tqdm
|
|
11
|
+
import numpy as np
|
|
12
|
+
import gc
|
|
13
|
+
from art import tprint
|
|
14
|
+
from celldetective.signals import SignalDetectionModel
|
|
15
|
+
|
|
16
|
+
tprint("Train")

# Single required option: the path to the JSON training instructions.
parser = argparse.ArgumentParser(
    description="Train a signal model from instructions.",
    formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument('-c', "--config", required=True, help="Training instructions")

args = parser.parse_args()
instructions = str(args.config)
|
|
25
|
+
|
|
26
|
+
# Load the training instructions and derive the extra options passed to the signal model.
if not os.path.exists(instructions):
    print('The configuration path is not valid. Abort.', flush=True)
    # Exit with status 1 rather than os.abort(): abort() raises SIGABRT without interpreter
    # cleanup, so the message above may never reach a piped stdout.
    raise SystemExit(1)

with open(instructions, 'r') as f:
    threshold_instructions = json.load(f)

threshold_instructions.update({'n_channels': len(threshold_instructions['channel_option'])})
# Augmentation is enabled only when more samples than the original set are requested.
threshold_instructions.update({'augment': threshold_instructions['augmentation_factor'] > 1.0})
# No held-out test split.
threshold_instructions.update({'test_split': 0.})
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _pick_instructions(keys):
    """Return the entries of ``threshold_instructions`` whose key is in *keys*, in *keys* order."""
    picked = {}
    for key in keys:
        if key in threshold_instructions:
            picked[key] = threshold_instructions[key]
    return picked


# Split the instructions into constructor arguments and training-call arguments.
model_params = _pick_instructions(('pretrained', 'model_signal_length', 'channel_option', 'n_channels', 'label'))
train_params = _pick_instructions(('model_name', 'target_directory', 'channel_option', 'recompile_pretrained',
                                   'test_split', 'augment', 'epochs', 'learning_rate', 'batch_size',
                                   'validation_split', 'normalization_percentile', 'normalization_values',
                                   'normalization_clip'))

print(f'model params {model_params}')
print(f'train params {train_params}')

model = SignalDetectionModel(**model_params)
model.fit_from_directory(threshold_instructions['ds'], **train_params)

print('Done.')
|