celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -1,303 +0,0 @@
1
- from multiprocessing import Process
2
- import time
3
- from celldetective.io import auto_load_number_of_frames, _load_frames_to_measure, locate_labels
4
- from celldetective.utils import config_section_to_dict, _get_img_num_per_channel, \
5
- _mask_intensity_measurements, remove_file_if_exists
6
- from pathlib import Path, PurePath
7
- from glob import glob
8
- from tqdm import tqdm
9
- import numpy as np
10
- import gc
11
- import concurrent.futures
12
- import datetime
13
- import os
14
- import json
15
- from celldetective.io import interpret_tracking_configuration
16
- from celldetective.utils import extract_experiment_channels
17
- from celldetective.measure import drop_tonal_features, measure_features
18
- from celldetective.tracking import track
19
- import pandas as pd
20
- from natsort import natsorted
21
- from art import tprint
22
-
23
-
24
- class TrackingProcess(Process):
25
-
26
- def __init__(self, queue=None, process_args=None, *args, **kwargs):
27
-
28
- super().__init__(*args, **kwargs)
29
-
30
- self.queue = queue
31
-
32
- if process_args is not None:
33
- for key, value in process_args.items():
34
- setattr(self, key, value)
35
-
36
-
37
- tprint("Track")
38
- self.timestep_dataframes = []
39
-
40
- # Experiment
41
- self.prepare_folders()
42
-
43
- self.locate_experiment_config()
44
- self.extract_experiment_parameters()
45
- self.read_tracking_instructions()
46
- self.detect_movie_and_labels()
47
- self.detect_channels()
48
-
49
- self.write_log()
50
-
51
- if not self.btrack_option:
52
- self.features = []
53
- self.channel_names = None
54
- self.haralick_options = None
55
-
56
- self.sum_done = 0
57
- self.t0 = time.time()
58
-
59
- def read_tracking_instructions(self):
60
-
61
- instr_path = PurePath(self.exp_dir,Path(f"{self.instruction_file}"))
62
- if os.path.exists(instr_path):
63
- print(f"Tracking instructions for the {self.mode} population have been successfully loaded...")
64
- with open(instr_path, 'r') as f:
65
- self.instructions = json.load(f)
66
-
67
- self.btrack_config = interpret_tracking_configuration(self.instructions['btrack_config_path'])
68
-
69
- if 'features' in self.instructions:
70
- self.features = self.instructions['features']
71
- else:
72
- self.features = None
73
-
74
- if 'mask_channels' in self.instructions:
75
- self.mask_channels = self.instructions['mask_channels']
76
- else:
77
- self.mask_channels = None
78
-
79
- if 'haralick_options' in self.instructions:
80
- self.haralick_options = self.instructions['haralick_options']
81
- else:
82
- self.haralick_options = None
83
-
84
- if 'post_processing_options' in self.instructions:
85
- self.post_processing_options = self.instructions['post_processing_options']
86
- else:
87
- self.post_processing_options = None
88
-
89
- self.btrack_option = True
90
- if 'btrack_option' in self.instructions:
91
- self.btrack_option = self.instructions['btrack_option']
92
- self.search_range = None
93
- if 'search_range' in self.instructions:
94
- self.search_range = self.instructions['search_range']
95
- self.memory = None
96
- if 'memory' in self.instructions:
97
- self.memory = self.instructions['memory']
98
- else:
99
- print('Tracking instructions could not be located... Using a standard bTrack motion model instead...')
100
- self.btrack_config = interpret_tracking_configuration(None)
101
- self.features = None
102
- self.mask_channels = None
103
- self.haralick_options = None
104
- self.post_processing_options = None
105
- self.btrack_option = True
106
- self.memory = None
107
- self.search_range = None
108
-
109
- if self.features is None:
110
- self.features = []
111
-
112
- def detect_channels(self):
113
- self.img_num_channels = _get_img_num_per_channel(self.channel_indices, self.len_movie, self.nbr_channels)
114
-
115
- def write_log(self):
116
-
117
- features_log=f'features: {self.features}'
118
- mask_channels_log=f'mask_channels: {self.mask_channels}'
119
- haralick_option_log=f'haralick_options: {self.haralick_options}'
120
- post_processing_option_log=f'post_processing_options: {self.post_processing_options}'
121
- log_list=[features_log, mask_channels_log, haralick_option_log, post_processing_option_log]
122
- log='\n'.join(log_list)
123
-
124
- with open(self.pos+f'log_{self.mode}.txt', 'a') as f:
125
- f.write(f'{datetime.datetime.now()} TRACK \n')
126
- f.write(log+"\n")
127
-
128
- def prepare_folders(self):
129
-
130
- if not os.path.exists(self.pos+"output"):
131
- os.mkdir(self.pos+"output")
132
-
133
- if not os.path.exists(self.pos+os.sep.join(["output","tables"])):
134
- os.mkdir(self.pos+os.sep.join(["output","tables"]))
135
-
136
- if self.mode.lower()=="target" or self.mode.lower()=="targets":
137
- self.label_folder = "labels_targets"
138
- self.instruction_file = os.sep.join(["configs", "tracking_instructions_targets.json"])
139
- self.napari_name = "napari_target_trajectories.npy"
140
- self.table_name = "trajectories_targets.csv"
141
-
142
- elif self.mode.lower()=="effector" or self.mode.lower()=="effectors":
143
- self.label_folder = "labels_effectors"
144
- self.instruction_file = os.sep.join(["configs","tracking_instructions_effectors.json"])
145
- self.napari_name = "napari_effector_trajectories.npy"
146
- self.table_name = "trajectories_effectors.csv"
147
-
148
- else:
149
- self.label_folder = f"labels_{self.mode}"
150
- self.instruction_file = os.sep.join(["configs",f"tracking_instructions_{self.mode}.json"])
151
- self.napari_name = f"napari_{self.mode}_trajectories.npy"
152
- self.table_name = f"trajectories_{self.mode}.csv"
153
-
154
- def extract_experiment_parameters(self):
155
-
156
- self.movie_prefix = config_section_to_dict(self.config, "MovieSettings")["movie_prefix"]
157
- self.spatial_calibration = float(config_section_to_dict(self.config, "MovieSettings")["pxtoum"])
158
- self.time_calibration = float(config_section_to_dict(self.config, "MovieSettings")["frametomin"])
159
- self.len_movie = float(config_section_to_dict(self.config, "MovieSettings")["len_movie"])
160
- self.shape_x = int(config_section_to_dict(self.config, "MovieSettings")["shape_x"])
161
- self.shape_y = int(config_section_to_dict(self.config, "MovieSettings")["shape_y"])
162
-
163
- self.channel_names, self.channel_indices = extract_experiment_channels(self.exp_dir)
164
- self.nbr_channels = len(self.channel_names)
165
-
166
- def locate_experiment_config(self):
167
-
168
- parent1 = Path(self.pos).parent
169
- self.exp_dir = parent1.parent
170
- self.config = PurePath(self.exp_dir,Path("config.ini"))
171
-
172
- if not os.path.exists(self.config):
173
- print('The configuration file for the experiment was not found...')
174
- self.abort_process()
175
-
176
- def detect_movie_and_labels(self):
177
-
178
- self.label_path = natsorted(glob(self.pos+f"{self.label_folder}"+os.sep+"*.tif"))
179
- if len(self.label_path)>0:
180
- print(f"Found {len(self.label_path)} segmented frames...")
181
- else:
182
- print(f"No segmented frames have been found. Please run segmentation first. Abort...")
183
- self.abort_process()
184
-
185
- try:
186
- self.file = glob(self.pos+f"movie/{self.movie_prefix}*.tif")[0]
187
- except IndexError:
188
- self.file = None
189
- self.haralick_option = None
190
- self.features = drop_tonal_features(self.features)
191
- print('Movie could not be found. Check the prefix.')
192
-
193
- len_movie_auto = auto_load_number_of_frames(self.file)
194
- if len_movie_auto is not None:
195
- self.len_movie = len_movie_auto
196
-
197
- def parallel_job(self, indices):
198
-
199
- props = []
200
-
201
- try:
202
-
203
- for t in tqdm(indices,desc="frame"):
204
-
205
- # Load channels at time t
206
- img = _load_frames_to_measure(self.file, indices=self.img_num_channels[:,t])
207
- lbl = locate_labels(self.pos, population=self.mode, frames=t)
208
- if lbl is None:
209
- continue
210
-
211
- df_props = measure_features(img, lbl, features = self.features+['centroid'], border_dist=None,
212
- channels=self.channel_names, haralick_options=self.haralick_options, verbose=False)
213
- df_props.rename(columns={'centroid-1': 'x', 'centroid-0': 'y'},inplace=True)
214
- df_props['t'] = int(t)
215
-
216
- props.append(df_props)
217
-
218
- self.sum_done+=1/self.len_movie*50
219
- mean_exec_per_step = (time.time() - self.t0) / (self.sum_done*self.len_movie / 50 + 1)
220
- pred_time = (self.len_movie - (self.sum_done*self.len_movie / 50 + 1)) * mean_exec_per_step + 30
221
- self.queue.put([self.sum_done, pred_time])
222
-
223
-
224
- except Exception as e:
225
- print(e)
226
-
227
- return props
228
-
229
- def run(self):
230
-
231
- self.indices = list(range(self.img_num_channels.shape[1]))
232
- chunks = np.array_split(self.indices, self.n_threads)
233
-
234
- self.timestep_dataframes = []
235
- with concurrent.futures.ThreadPoolExecutor(max_workers=self.n_threads) as executor:
236
- results = executor.map(self.parallel_job, chunks)
237
- try:
238
- for i,return_value in enumerate(results):
239
- print(f'Thread {i} completed...')
240
- #print(f"Thread {i} output check: ",return_value)
241
- self.timestep_dataframes.extend(return_value)
242
- except Exception as e:
243
- print("Exception: ", e)
244
-
245
- print('Features successfully measured...')
246
-
247
- df = pd.concat(self.timestep_dataframes)
248
- df = df.replace([np.inf, -np.inf], np.nan)
249
-
250
- df.reset_index(inplace=True, drop=True)
251
- df = _mask_intensity_measurements(df, self.mask_channels)
252
-
253
- # do tracking
254
- if self.btrack_option:
255
- tracker = 'bTrack'
256
- else:
257
- tracker = 'trackpy'
258
-
259
- # do tracking
260
- trajectories, napari_data = track(None,
261
- configuration=self.btrack_config,
262
- objects=df,
263
- spatial_calibration=self.spatial_calibration,
264
- channel_names=self.channel_names,
265
- return_napari_data=True,
266
- optimizer_options = {'tm_lim': int(12e4)},
267
- track_kwargs={'step_size': 100},
268
- clean_trajectories_kwargs=self.post_processing_options,
269
- volume=(self.shape_x, self.shape_y),
270
- btrack_option=self.btrack_option,
271
- search_range=self.search_range,
272
- memory=self.memory,
273
- )
274
- print(f"Tracking successfully performed...")
275
-
276
- # out trajectory table, create POSITION_X_um, POSITION_Y_um, TIME_min (new ones)
277
- # Save napari data
278
- np.save(self.pos+os.sep.join(['output', 'tables', self.napari_name]), napari_data, allow_pickle=True)
279
-
280
- trajectories.to_csv(self.pos+os.sep.join(['output', 'tables', self.table_name]), index=False)
281
- print(f"Trajectory table successfully exported in {os.sep.join(['output', 'tables'])}...")
282
-
283
- remove_file_if_exists(self.pos+os.sep.join(['output', 'tables', self.table_name.replace('.csv','.pkl')]))
284
-
285
- del trajectories; del napari_data;
286
- gc.collect()
287
-
288
- # Send end signal
289
- self.queue.put([100, 0])
290
- time.sleep(1)
291
-
292
- self.queue.put("finished")
293
- self.queue.close()
294
-
295
- def end_process(self):
296
-
297
- self.terminate()
298
- self.queue.put("finished")
299
-
300
- def abort_process(self):
301
-
302
- self.terminate()
303
- self.queue.put("error")
@@ -1,270 +0,0 @@
1
- from multiprocessing import Process
2
- import time
3
- import os
4
- import shutil
5
- from glob import glob
6
- import json
7
- from tqdm import tqdm
8
- import numpy as np
9
- import random
10
-
11
- from celldetective.utils import load_image_dataset, augmenter, interpolate_nan
12
- from celldetective.io import normalize_multichannel
13
- from stardist import fill_label_holes
14
- from art import tprint
15
- from distutils.dir_util import copy_tree
16
- from csbdeep.utils import save_json
17
-
18
-
19
- class TrainSegModelProcess(Process):
20
-
21
- def __init__(self, queue=None, process_args=None, *args, **kwargs):
22
-
23
- super().__init__(*args, **kwargs)
24
-
25
- self.queue = queue
26
-
27
- if process_args is not None:
28
- for key, value in process_args.items():
29
- setattr(self, key, value)
30
-
31
- tprint("Train segmentation")
32
- self.read_instructions()
33
- self.extract_training_params()
34
- self.load_dataset()
35
- self.split_test_train()
36
-
37
- self.sum_done = 0
38
- self.t0 = time.time()
39
-
40
- def read_instructions(self):
41
-
42
- if os.path.exists(self.instructions):
43
- with open(self.instructions, 'r') as f:
44
- self.training_instructions = json.load(f)
45
- else:
46
- print('Training instructions could not be found. Abort.')
47
- self.abort_process()
48
-
49
- def run(self):
50
-
51
- if self.model_type=="cellpose":
52
- self.train_cellpose_model()
53
- elif self.model_type=="stardist":
54
- self.train_stardist_model()
55
-
56
- self.queue.put("finished")
57
- self.queue.close()
58
-
59
- def train_stardist_model(self):
60
-
61
- from stardist import calculate_extents, gputools_available
62
- from stardist.models import Config2D, StarDist2D
63
-
64
- n_rays = 32
65
- print(gputools_available())
66
-
67
- n_channel = self.X_trn[0].shape[-1]
68
-
69
- # Predict on subsampled grid for increased efficiency and larger field of view
70
- grid = (2,2)
71
- conf = Config2D(
72
- n_rays = n_rays,
73
- grid = grid,
74
- use_gpu = self.use_gpu,
75
- n_channel_in = n_channel,
76
- train_learning_rate = self.learning_rate,
77
- train_patch_size = (256,256),
78
- train_epochs = self.epochs,
79
- train_reduce_lr = {'factor': 0.1, 'patience': 30, 'min_delta': 0},
80
- train_batch_size = self.batch_size,
81
- train_steps_per_epoch = int(self.augmentation_factor*len(self.X_trn)),
82
- )
83
-
84
- if self.use_gpu:
85
- from csbdeep.utils.tf import limit_gpu_memory
86
- limit_gpu_memory(None, allow_growth=True)
87
-
88
- if self.pretrained is None:
89
- model = StarDist2D(conf, name=self.model_name, basedir=self.target_directory)
90
- else:
91
- os.rename(self.instructions, os.sep.join([self.target_directory, self.model_name, 'temp.json']))
92
- copy_tree(self.pretrained, os.sep.join([self.target_directory, self.model_name]))
93
-
94
- if os.path.exists(os.sep.join([self.target_directory, self.model_name, 'training_instructions.json'])):
95
- os.remove(os.sep.join([self.target_directory, self.model_name, 'training_instructions.json']))
96
- if os.path.exists(os.sep.join([self.target_directory, self.model_name, 'config_input.json'])):
97
- os.remove(os.sep.join([self.target_directory, self.model_name, 'config_input.json']))
98
- if os.path.exists(os.sep.join([self.target_directory, self.model_name, 'logs'+os.sep])):
99
- shutil.rmtree(os.sep.join([self.target_directory, self.model_name, 'logs']))
100
- os.rename(os.sep.join([self.target_directory, self.model_name, 'temp.json']),os.sep.join([self.target_directory, self.model_name, 'training_instructions.json']))
101
-
102
- #shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
103
- model = StarDist2D(None, name=self.model_name, basedir=self.target_directory)
104
- model.config.train_epochs = self.epochs
105
- model.config.train_batch_size = min(len(self.X_trn),self.batch_size)
106
- model.config.train_learning_rate = self.learning_rate # perf seems bad if lr is changed in transfer
107
- model.config.use_gpu = self.use_gpu
108
- model.config.train_reduce_lr = {'factor': 0.1, 'patience': 10, 'min_delta': 0}
109
- print(f'{model.config=}')
110
-
111
- save_json(vars(model.config), os.sep.join([self.target_directory, self.model_name, 'config.json']))
112
-
113
- median_size = calculate_extents(list(self.Y_trn), np.mean)
114
- fov = np.array(model._axes_tile_overlap('YX'))
115
- print(f"median object size: {median_size}")
116
- print(f"network field of view : {fov}")
117
- if any(median_size > fov):
118
- print("WARNING: median object size larger than field of view of the neural network.")
119
-
120
- if self.augmentation_factor==1.0:
121
- model.train(self.X_trn, self.Y_trn, validation_data=(self.X_val,self.Y_val))
122
- else:
123
- model.train(self.X_trn, self.Y_trn, validation_data=(self.X_val,self.Y_val), augmenter=augmenter)
124
- model.optimize_thresholds(self.X_val,self.Y_val)
125
-
126
- config_inputs = {"channels": self.target_channels, 'normalization_percentile': self.normalization_percentile,
127
- 'normalization_clip': self.normalization_clip, 'normalization_values': self.normalization_values,
128
- 'model_type': 'stardist', 'spatial_calibration': self.spatial_calibration, 'dataset': {'train': self.files_train, 'validation': self.files_val}}
129
-
130
- json_input_config = json.dumps(config_inputs, indent=4)
131
- with open(os.sep.join([self.target_directory, self.model_name, "config_input.json"]), "w") as outfile:
132
- outfile.write(json_input_config)
133
-
134
- def train_cellpose_model(self):
135
-
136
- # do augmentation in place
137
- X_aug = []; Y_aug = [];
138
- n_val = max(1, int(round(self.augmentation_factor * len(self.X_trn))))
139
- indices = random.choices(list(np.arange(len(self.X_trn))), k=n_val)
140
- print('Performing image augmentation pre-training...')
141
- for i in tqdm(indices):
142
- x_aug,y_aug = augmenter(self.X_trn[i], self.Y_trn[i])
143
- X_aug.append(x_aug)
144
- Y_aug.append(y_aug)
145
-
146
- # Channel axis in front for cellpose
147
- X_aug = [np.moveaxis(x,-1,0) for x in X_aug]
148
- self.X_val = [np.moveaxis(x,-1,0) for x in self.X_val]
149
- print('number of augmented images: %3d' % len(X_aug))
150
-
151
- from cellpose.models import CellposeModel
152
- from cellpose.io import logger_setup
153
- import torch
154
-
155
- if not self.use_gpu:
156
- print('Using CPU for training...')
157
- device = torch.device("cpu")
158
- else:
159
- print('Using GPU for training...')
160
-
161
- logger, log_file = logger_setup()
162
- print(f'Pretrained model: ', self.pretrained)
163
- if self.pretrained is not None:
164
- pretrained_path = os.sep.join([self.pretrained,os.path.split(self.pretrained)[-1]])
165
- else:
166
- pretrained_path = self.pretrained
167
-
168
- model = CellposeModel(gpu=self.use_gpu, model_type=None, pretrained_model=pretrained_path, diam_mean=30.0, nchan=X_aug[0].shape[0],)
169
- model.train(train_data=X_aug, train_labels=Y_aug, normalize=False, channels=None, batch_size=self.batch_size,
170
- min_train_masks=1,save_path=self.target_directory+os.sep+self.model_name,n_epochs=self.epochs, model_name=self.model_name, learning_rate=self.learning_rate, test_data = self.X_val, test_labels=self.Y_val)
171
-
172
- file_to_move = glob(os.sep.join([self.target_directory, self.model_name, 'models','*']))[0]
173
- shutil.move(file_to_move, os.sep.join([self.target_directory, self.model_name,''])+os.path.split(file_to_move)[-1])
174
- os.rmdir(os.sep.join([self.target_directory, self.model_name, 'models']))
175
-
176
- diameter = model.diam_labels
177
-
178
- if self.pretrained is not None and os.path.split(self.pretrained)[-1]=='CP_nuclei':
179
- standard_diameter = 17.0
180
- else:
181
- standard_diameter = 30.0
182
-
183
- input_spatial_calibration = self.spatial_calibration #*diameter / standard_diameter
184
-
185
- config_inputs = {"channels": self.target_channels, "diameter": standard_diameter, 'cellprob_threshold': 0., 'flow_threshold': 0.4,
186
- 'normalization_percentile': self.normalization_percentile, 'normalization_clip': self.normalization_clip,
187
- 'normalization_values': self.normalization_values, 'model_type': 'cellpose',
188
- 'spatial_calibration': input_spatial_calibration, 'dataset': {'train': self.files_train, 'validation': self.files_val}}
189
- json_input_config = json.dumps(config_inputs, indent=4)
190
- with open(os.sep.join([self.target_directory, self.model_name, "config_input.json"]), "w") as outfile:
191
- outfile.write(json_input_config)
192
-
193
-
194
- def split_test_train(self):
195
-
196
- if not len(self.X) > 1:
197
- print("Not enough training data")
198
- self.abort_process()
199
-
200
- rng = np.random.RandomState()
201
- ind = rng.permutation(len(self.X))
202
- n_val = max(1, int(round(self.validation_split * len(ind))))
203
- ind_train, ind_val = ind[:-n_val], ind[-n_val:]
204
- self.X_val, self.Y_val = [self.X[i] for i in ind_val] , [self.Y[i] for i in ind_val]
205
- self.X_trn, self.Y_trn = [self.X[i] for i in ind_train], [self.Y[i] for i in ind_train]
206
-
207
- self.files_train = [self.filenames[i] for i in ind_train]
208
- self.files_val = [self.filenames[i] for i in ind_val]
209
-
210
- print('number of images: %3d' % len(self.X))
211
- print('- training: %3d' % len(self.X_trn))
212
- print('- validation: %3d' % len(self.X_val))
213
-
214
- def extract_training_params(self):
215
-
216
- self.model_name = self.training_instructions['model_name']
217
- self.target_directory = self.training_instructions['target_directory']
218
- self.model_type = self.training_instructions['model_type']
219
- self.pretrained = self.training_instructions['pretrained']
220
-
221
- self.datasets = self.training_instructions['ds']
222
-
223
- self.target_channels = self.training_instructions['channel_option']
224
- self.normalization_percentile = self.training_instructions['normalization_percentile']
225
- self.normalization_clip = self.training_instructions['normalization_clip']
226
- self.normalization_values = self.training_instructions['normalization_values']
227
- self.spatial_calibration = self.training_instructions['spatial_calibration']
228
-
229
- self.validation_split = self.training_instructions['validation_split']
230
- self.augmentation_factor = self.training_instructions['augmentation_factor']
231
-
232
- self.learning_rate = self.training_instructions['learning_rate']
233
- self.epochs = self.training_instructions['epochs']
234
- self.batch_size = self.training_instructions['batch_size']
235
-
236
- def load_dataset(self):
237
-
238
- print(f'Datasets: {self.datasets}')
239
- self.X,self.Y,self.filenames = load_image_dataset(self.datasets, self.target_channels, train_spatial_calibration=self.spatial_calibration,
240
- mask_suffix='labelled')
241
- print('Dataset loaded...')
242
-
243
- self.values = []
244
- self.percentiles = []
245
- for k in range(len(self.normalization_percentile)):
246
- if self.normalization_percentile[k]:
247
- self.percentiles.append(self.normalization_values[k])
248
- self.values.append(None)
249
- else:
250
- self.percentiles.append(None)
251
- self.values.append(self.normalization_values[k])
252
-
253
- self.X = [normalize_multichannel(x, **{"percentiles": self.percentiles, 'values': self.values, 'clip': self.normalization_clip}) for x in self.X]
254
-
255
- for k in range(len(self.X)):
256
- x = self.X[k].copy()
257
- x_interp = np.moveaxis([interpolate_nan(x[:,:,c].copy()) for c in range(x.shape[-1])],0,-1)
258
- self.X[k] = x_interp
259
-
260
- self.Y = [fill_label_holes(y) for y in tqdm(self.Y)]
261
-
262
- def end_process(self):
263
-
264
- self.terminate()
265
- self.queue.put("finished")
266
-
267
- def abort_process(self):
268
-
269
- self.terminate()
270
- self.queue.put("error")
@@ -1,108 +0,0 @@
1
- from multiprocessing import Process
2
- import time
3
- import os
4
- import json
5
- from glob import glob
6
- import numpy as np
7
- from art import tprint
8
- from celldetective.signals import SignalDetectionModel
9
- from celldetective.io import locate_signal_model
10
-
11
-
12
- class TrainSignalModelProcess(Process):
13
-
14
- def __init__(self, queue=None, process_args=None, *args, **kwargs):
15
-
16
- super().__init__(*args, **kwargs)
17
-
18
- self.queue = queue
19
-
20
- if process_args is not None:
21
- for key, value in process_args.items():
22
- setattr(self, key, value)
23
-
24
- tprint("Train segmentation")
25
- self.read_instructions()
26
- self.extract_training_params()
27
-
28
-
29
- self.sum_done = 0
30
- self.t0 = time.time()
31
-
32
- def read_instructions(self):
33
-
34
- if os.path.exists(self.instructions):
35
- with open(self.instructions, 'r') as f:
36
- self.training_instructions = json.load(f)
37
- else:
38
- print('Training instructions could not be found. Abort.')
39
- self.abort_process()
40
-
41
- all_classes = []
42
- for d in self.training_instructions["ds"]:
43
- datasets = glob(d+os.sep+"*.npy")
44
- for dd in datasets:
45
- data = np.load(dd, allow_pickle=True)
46
- classes = np.unique([ddd["class"] for ddd in data])
47
- all_classes.extend(classes)
48
- all_classes = np.unique(all_classes)
49
- n_classes = len(all_classes)
50
-
51
- self.model_params = {k:self.training_instructions[k] for k in ('pretrained', 'model_signal_length', 'channel_option', 'n_channels', 'label') if k in self.training_instructions}
52
- self.model_params.update({'n_classes': n_classes})
53
- self.train_params = {k:self.training_instructions[k] for k in ('model_name', 'target_directory', 'channel_option','recompile_pretrained', 'test_split', 'augment', 'epochs', 'learning_rate', 'batch_size', 'validation_split','normalization_percentile','normalization_values','normalization_clip') if k in self.training_instructions}
54
-
55
- def neighborhood_postprocessing(self):
56
-
57
- # if neighborhood of interest in training instructions, write it in config!
58
- if 'neighborhood_of_interest' in self.training_instructions:
59
- if self.training_instructions['neighborhood_of_interest'] is not None:
60
-
61
- model_path = locate_signal_model(self.training_instructions['model_name'], path=None, pairs=True)
62
- complete_path = model_path #+model
63
- complete_path = rf"{complete_path}"
64
- model_config_path = os.sep.join([complete_path,'config_input.json'])
65
- model_config_path = rf"{model_config_path}"
66
-
67
- f = open(model_config_path)
68
- config = json.load(f)
69
- config.update({'neighborhood_of_interest': self.training_instructions['neighborhood_of_interest'], 'reference_population': self.training_instructions['reference_population'], 'neighbor_population': self.training_instructions['neighbor_population']})
70
- json_string = json.dumps(config)
71
- with open(model_config_path, 'w') as outfile:
72
- outfile.write(json_string)
73
-
74
- def run(self):
75
-
76
- model = SignalDetectionModel(**self.model_params)
77
- model.fit_from_directory(self.training_instructions['ds'], **self.train_params)
78
- self.neighborhood_postprocessing()
79
- self.queue.put("finished")
80
- self.queue.close()
81
-
82
-
83
- def extract_training_params(self):
84
-
85
- self.training_instructions.update({'n_channels': len(self.training_instructions['channel_option'])})
86
- if self.training_instructions['augmentation_factor']>1.0:
87
- self.training_instructions.update({'augment': True})
88
- else:
89
- self.training_instructions.update({'augment': False})
90
- self.training_instructions.update({'test_split': 0.})
91
-
92
-
93
- def end_process(self):
94
-
95
- # self.terminate()
96
-
97
- # if self.model_type=="stardist":
98
- # from stardist.models import StarDist2D
99
- # self.model = StarDist2D(None, name=self.model_name, basedir=self.target_directory)
100
- # self.model.optimize_thresholds(self.X_val,self.Y_val)
101
-
102
- self.terminate()
103
- self.queue.put("finished")
104
-
105
- def abort_process(self):
106
-
107
- self.terminate()
108
- self.queue.put("error")