celldetective 1.4.2__py3-none-any.whl → 1.5.0b1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152) hide show
  1. celldetective/__init__.py +25 -0
  2. celldetective/__main__.py +62 -43
  3. celldetective/_version.py +1 -1
  4. celldetective/extra_properties.py +477 -399
  5. celldetective/filters.py +192 -97
  6. celldetective/gui/InitWindow.py +541 -411
  7. celldetective/gui/__init__.py +0 -15
  8. celldetective/gui/about.py +44 -39
  9. celldetective/gui/analyze_block.py +120 -84
  10. celldetective/gui/base/__init__.py +0 -0
  11. celldetective/gui/base/channel_norm_generator.py +335 -0
  12. celldetective/gui/base/components.py +249 -0
  13. celldetective/gui/base/feature_choice.py +92 -0
  14. celldetective/gui/base/figure_canvas.py +52 -0
  15. celldetective/gui/base/list_widget.py +133 -0
  16. celldetective/gui/{styles.py → base/styles.py} +92 -36
  17. celldetective/gui/base/utils.py +33 -0
  18. celldetective/gui/base_annotator.py +900 -767
  19. celldetective/gui/classifier_widget.py +6 -22
  20. celldetective/gui/configure_new_exp.py +777 -671
  21. celldetective/gui/control_panel.py +635 -524
  22. celldetective/gui/dynamic_progress.py +449 -0
  23. celldetective/gui/event_annotator.py +2023 -1662
  24. celldetective/gui/generic_signal_plot.py +1292 -944
  25. celldetective/gui/gui_utils.py +899 -1289
  26. celldetective/gui/interactions_block.py +658 -0
  27. celldetective/gui/interactive_timeseries_viewer.py +447 -0
  28. celldetective/gui/json_readers.py +48 -15
  29. celldetective/gui/layouts/__init__.py +5 -0
  30. celldetective/gui/layouts/background_model_free_layout.py +537 -0
  31. celldetective/gui/layouts/channel_offset_layout.py +134 -0
  32. celldetective/gui/layouts/local_correction_layout.py +91 -0
  33. celldetective/gui/layouts/model_fit_layout.py +372 -0
  34. celldetective/gui/layouts/operation_layout.py +68 -0
  35. celldetective/gui/layouts/protocol_designer_layout.py +96 -0
  36. celldetective/gui/pair_event_annotator.py +3130 -2435
  37. celldetective/gui/plot_measurements.py +586 -267
  38. celldetective/gui/plot_signals_ui.py +724 -506
  39. celldetective/gui/preprocessing_block.py +395 -0
  40. celldetective/gui/process_block.py +1678 -1831
  41. celldetective/gui/seg_model_loader.py +580 -473
  42. celldetective/gui/settings/__init__.py +0 -7
  43. celldetective/gui/settings/_cellpose_model_params.py +181 -0
  44. celldetective/gui/settings/_event_detection_model_params.py +95 -0
  45. celldetective/gui/settings/_segmentation_model_params.py +159 -0
  46. celldetective/gui/settings/_settings_base.py +77 -65
  47. celldetective/gui/settings/_settings_event_model_training.py +752 -526
  48. celldetective/gui/settings/_settings_measurements.py +1133 -964
  49. celldetective/gui/settings/_settings_neighborhood.py +574 -488
  50. celldetective/gui/settings/_settings_segmentation_model_training.py +779 -564
  51. celldetective/gui/settings/_settings_signal_annotator.py +329 -305
  52. celldetective/gui/settings/_settings_tracking.py +1304 -1094
  53. celldetective/gui/settings/_stardist_model_params.py +98 -0
  54. celldetective/gui/survival_ui.py +422 -312
  55. celldetective/gui/tableUI.py +1665 -1701
  56. celldetective/gui/table_ops/_maths.py +295 -0
  57. celldetective/gui/table_ops/_merge_groups.py +140 -0
  58. celldetective/gui/table_ops/_merge_one_hot.py +95 -0
  59. celldetective/gui/table_ops/_query_table.py +43 -0
  60. celldetective/gui/table_ops/_rename_col.py +44 -0
  61. celldetective/gui/thresholds_gui.py +382 -179
  62. celldetective/gui/viewers/__init__.py +0 -0
  63. celldetective/gui/viewers/base_viewer.py +700 -0
  64. celldetective/gui/viewers/channel_offset_viewer.py +331 -0
  65. celldetective/gui/viewers/contour_viewer.py +394 -0
  66. celldetective/gui/viewers/size_viewer.py +153 -0
  67. celldetective/gui/viewers/spot_detection_viewer.py +341 -0
  68. celldetective/gui/viewers/threshold_viewer.py +309 -0
  69. celldetective/gui/workers.py +403 -126
  70. celldetective/log_manager.py +92 -0
  71. celldetective/measure.py +1895 -1478
  72. celldetective/napari/__init__.py +0 -0
  73. celldetective/napari/utils.py +1025 -0
  74. celldetective/neighborhood.py +1914 -1448
  75. celldetective/preprocessing.py +1620 -1220
  76. celldetective/processes/__init__.py +0 -0
  77. celldetective/processes/background_correction.py +271 -0
  78. celldetective/processes/compute_neighborhood.py +894 -0
  79. celldetective/processes/detect_events.py +246 -0
  80. celldetective/processes/downloader.py +137 -0
  81. celldetective/processes/measure_cells.py +565 -0
  82. celldetective/processes/segment_cells.py +760 -0
  83. celldetective/processes/track_cells.py +435 -0
  84. celldetective/processes/train_segmentation_model.py +694 -0
  85. celldetective/processes/train_signal_model.py +265 -0
  86. celldetective/processes/unified_process.py +292 -0
  87. celldetective/regionprops/_regionprops.py +358 -317
  88. celldetective/relative_measurements.py +987 -710
  89. celldetective/scripts/measure_cells.py +313 -212
  90. celldetective/scripts/measure_relative.py +90 -46
  91. celldetective/scripts/segment_cells.py +165 -104
  92. celldetective/scripts/segment_cells_thresholds.py +96 -68
  93. celldetective/scripts/track_cells.py +198 -149
  94. celldetective/scripts/train_segmentation_model.py +324 -201
  95. celldetective/scripts/train_signal_model.py +87 -45
  96. celldetective/segmentation.py +844 -749
  97. celldetective/signals.py +3514 -2861
  98. celldetective/tracking.py +30 -15
  99. celldetective/utils/__init__.py +0 -0
  100. celldetective/utils/cellpose_utils/__init__.py +133 -0
  101. celldetective/utils/color_mappings.py +42 -0
  102. celldetective/utils/data_cleaning.py +630 -0
  103. celldetective/utils/data_loaders.py +450 -0
  104. celldetective/utils/dataset_helpers.py +207 -0
  105. celldetective/utils/downloaders.py +235 -0
  106. celldetective/utils/event_detection/__init__.py +8 -0
  107. celldetective/utils/experiment.py +1782 -0
  108. celldetective/utils/image_augmenters.py +308 -0
  109. celldetective/utils/image_cleaning.py +74 -0
  110. celldetective/utils/image_loaders.py +926 -0
  111. celldetective/utils/image_transforms.py +335 -0
  112. celldetective/utils/io.py +62 -0
  113. celldetective/utils/mask_cleaning.py +348 -0
  114. celldetective/utils/mask_transforms.py +5 -0
  115. celldetective/utils/masks.py +184 -0
  116. celldetective/utils/maths.py +351 -0
  117. celldetective/utils/model_getters.py +325 -0
  118. celldetective/utils/model_loaders.py +296 -0
  119. celldetective/utils/normalization.py +380 -0
  120. celldetective/utils/parsing.py +465 -0
  121. celldetective/utils/plots/__init__.py +0 -0
  122. celldetective/utils/plots/regression.py +53 -0
  123. celldetective/utils/resources.py +34 -0
  124. celldetective/utils/stardist_utils/__init__.py +104 -0
  125. celldetective/utils/stats.py +90 -0
  126. celldetective/utils/types.py +21 -0
  127. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/METADATA +1 -1
  128. celldetective-1.5.0b1.dist-info/RECORD +187 -0
  129. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/WHEEL +1 -1
  130. tests/gui/test_new_project.py +129 -117
  131. tests/gui/test_project.py +127 -79
  132. tests/test_filters.py +39 -15
  133. tests/test_notebooks.py +8 -0
  134. tests/test_tracking.py +232 -13
  135. tests/test_utils.py +123 -77
  136. celldetective/gui/base_components.py +0 -23
  137. celldetective/gui/layouts.py +0 -1602
  138. celldetective/gui/processes/compute_neighborhood.py +0 -594
  139. celldetective/gui/processes/downloader.py +0 -111
  140. celldetective/gui/processes/measure_cells.py +0 -360
  141. celldetective/gui/processes/segment_cells.py +0 -499
  142. celldetective/gui/processes/track_cells.py +0 -303
  143. celldetective/gui/processes/train_segmentation_model.py +0 -270
  144. celldetective/gui/processes/train_signal_model.py +0 -108
  145. celldetective/gui/table_ops/merge_groups.py +0 -118
  146. celldetective/gui/viewers.py +0 -1354
  147. celldetective/io.py +0 -3663
  148. celldetective/utils.py +0 -3108
  149. celldetective-1.4.2.dist-info/RECORD +0 -123
  150. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/entry_points.txt +0 -0
  151. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/licenses/LICENSE +0 -0
  152. {celldetective-1.4.2.dist-info → celldetective-1.5.0b1.dist-info}/top_level.txt +0 -0
@@ -11,75 +11,95 @@ from tqdm import tqdm
11
11
  import numpy as np
12
12
  import random
13
13
 
14
- from celldetective.utils import load_image_dataset, augmenter, interpolate_nan
15
- from celldetective.io import normalize_multichannel
16
- from stardist import fill_label_holes
14
+ from celldetective.utils.image_augmenters import augmenter
15
+ from celldetective.utils.image_loaders import load_image_dataset
16
+ from celldetective.utils.image_cleaning import interpolate_nan
17
+ from celldetective.utils.normalization import normalize_multichannel
18
+ from celldetective.utils.mask_cleaning import fill_label_holes
17
19
  from art import tprint
18
20
  from distutils.dir_util import copy_tree
19
- from csbdeep.utils import save_json
21
+
22
+
23
+ def save_json(data, fpath, **kwargs):
24
+ with open(fpath, "w") as f:
25
+ f.write(json.dumps(data, **kwargs))
26
+
20
27
 
21
28
  tprint("Train")
22
29
 
23
- parser = argparse.ArgumentParser(description="Train a signal model from instructions.",
24
- formatter_class=argparse.ArgumentDefaultsHelpFormatter)
25
- parser.add_argument('-c',"--config", required=True,help="Training instructions")
26
- parser.add_argument('-g',"--use_gpu", required=True, help="Use GPU")
30
+ parser = argparse.ArgumentParser(
31
+ description="Train a signal model from instructions.",
32
+ formatter_class=argparse.ArgumentDefaultsHelpFormatter,
33
+ )
34
+ parser.add_argument("-c", "--config", required=True, help="Training instructions")
35
+ parser.add_argument("-g", "--use_gpu", required=True, help="Use GPU")
27
36
 
28
37
  args = parser.parse_args()
29
38
  process_arguments = vars(args)
30
- instructions = str(process_arguments['config'])
31
- use_gpu = bool(process_arguments['use_gpu'])
39
+ instructions = str(process_arguments["config"])
40
+ use_gpu = bool(process_arguments["use_gpu"])
32
41
 
33
42
  if os.path.exists(instructions):
34
- with open(instructions, 'r') as f:
35
- training_instructions = json.load(f)
43
+ with open(instructions, "r") as f:
44
+ training_instructions = json.load(f)
36
45
  else:
37
- print('Training instructions could not be found. Abort.')
38
- os.abort()
46
+ print("Training instructions could not be found. Abort.")
47
+ os.abort()
39
48
 
40
- model_name = training_instructions['model_name']
41
- target_directory = training_instructions['target_directory']
42
- model_type = training_instructions['model_type']
43
- pretrained = training_instructions['pretrained']
49
+ model_name = training_instructions["model_name"]
50
+ target_directory = training_instructions["target_directory"]
51
+ model_type = training_instructions["model_type"]
52
+ pretrained = training_instructions["pretrained"]
44
53
 
45
- datasets = training_instructions['ds']
54
+ datasets = training_instructions["ds"]
46
55
 
47
- target_channels = training_instructions['channel_option']
48
- normalization_percentile = training_instructions['normalization_percentile']
49
- normalization_clip = training_instructions['normalization_clip']
50
- normalization_values = training_instructions['normalization_values']
51
- spatial_calibration = training_instructions['spatial_calibration']
56
+ target_channels = training_instructions["channel_option"]
57
+ normalization_percentile = training_instructions["normalization_percentile"]
58
+ normalization_clip = training_instructions["normalization_clip"]
59
+ normalization_values = training_instructions["normalization_values"]
60
+ spatial_calibration = training_instructions["spatial_calibration"]
52
61
 
53
- validation_split = training_instructions['validation_split']
54
- augmentation_factor = training_instructions['augmentation_factor']
62
+ validation_split = training_instructions["validation_split"]
63
+ augmentation_factor = training_instructions["augmentation_factor"]
55
64
 
56
- learning_rate = training_instructions['learning_rate']
57
- epochs = training_instructions['epochs']
58
- batch_size = training_instructions['batch_size']
65
+ learning_rate = training_instructions["learning_rate"]
66
+ epochs = training_instructions["epochs"]
67
+ batch_size = training_instructions["batch_size"]
59
68
 
60
69
 
61
70
  # Load dataset
62
- print(f'Datasets: {datasets}')
63
- X,Y,filenames = load_image_dataset(datasets, target_channels, train_spatial_calibration=spatial_calibration,
64
- mask_suffix='labelled')
65
- print('Dataset loaded...')
71
+ print(f"Datasets: {datasets}")
72
+ X, Y, filenames = load_image_dataset(
73
+ datasets,
74
+ target_channels,
75
+ train_spatial_calibration=spatial_calibration,
76
+ mask_suffix="labelled",
77
+ )
78
+ print("Dataset loaded...")
66
79
 
67
80
  values = []
68
81
  percentiles = []
69
82
  for k in range(len(normalization_percentile)):
70
- if normalization_percentile[k]:
71
- percentiles.append(normalization_values[k])
72
- values.append(None)
73
- else:
74
- percentiles.append(None)
75
- values.append(normalization_values[k])
76
-
77
- X = [normalize_multichannel(x, **{"percentiles": percentiles, 'values': values, 'clip': normalization_clip}) for x in X]
83
+ if normalization_percentile[k]:
84
+ percentiles.append(normalization_values[k])
85
+ values.append(None)
86
+ else:
87
+ percentiles.append(None)
88
+ values.append(normalization_values[k])
89
+
90
+ X = [
91
+ normalize_multichannel(
92
+ x, **{"percentiles": percentiles, "values": values, "clip": normalization_clip}
93
+ )
94
+ for x in X
95
+ ]
78
96
 
79
97
  for k in range(len(X)):
80
- x = X[k].copy()
81
- x_interp = np.moveaxis([interpolate_nan(x[:,:,c].copy()) for c in range(x.shape[-1])],0,-1)
82
- X[k] = x_interp
98
+ x = X[k].copy()
99
+ x_interp = np.moveaxis(
100
+ [interpolate_nan(x[:, :, c].copy()) for c in range(x.shape[-1])], 0, -1
101
+ )
102
+ X[k] = x_interp
83
103
 
84
104
  Y = [fill_label_holes(y) for y in tqdm(Y)]
85
105
 
@@ -88,165 +108,268 @@ rng = np.random.RandomState()
88
108
  ind = rng.permutation(len(X))
89
109
  n_val = max(1, int(round(validation_split * len(ind))))
90
110
  ind_train, ind_val = ind[:-n_val], ind[-n_val:]
91
- X_val, Y_val = [X[i] for i in ind_val] , [Y[i] for i in ind_val]
111
+ X_val, Y_val = [X[i] for i in ind_val], [Y[i] for i in ind_val]
92
112
  X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
93
113
 
94
114
  files_train = [filenames[i] for i in ind_train]
95
115
  files_val = [filenames[i] for i in ind_val]
96
116
 
97
- print('number of images: %3d' % len(X))
98
- print('- training: %3d' % len(X_trn))
99
- print('- validation: %3d' % len(X_val))
100
-
101
- if model_type=='cellpose':
102
-
103
- # do augmentation in place
104
- X_aug = []; Y_aug = [];
105
- n_val = max(1, int(round(augmentation_factor * len(X_trn))))
106
- indices = random.choices(list(np.arange(len(X_trn))), k=n_val)
107
- print('Performing image augmentation pre-training...')
108
- for i in tqdm(indices):
109
- x_aug,y_aug = augmenter(X_trn[i], Y_trn[i])
110
- X_aug.append(x_aug)
111
- Y_aug.append(y_aug)
112
-
113
- # Channel axis in front for cellpose
114
- X_aug = [np.moveaxis(x,-1,0) for x in X_aug]
115
- X_val = [np.moveaxis(x,-1,0) for x in X_val]
116
- print('number of augmented images: %3d' % len(X_aug))
117
-
118
- from cellpose.models import CellposeModel
119
- from cellpose.io import logger_setup
120
- import torch
121
-
122
- if not use_gpu:
123
- print('Using CPU for training...')
124
- device = torch.device("cpu")
125
- else:
126
- print('Using GPU for training...')
127
-
128
- diam_mean = 30.0
129
- logger, log_file = logger_setup()
130
- print(f'Pretrained model: ',pretrained)
131
- if pretrained is not None:
132
- if pretrained.endswith('CP_nuclei'):
133
- diam_mean = 17.0
134
- pretrained_path = os.sep.join([pretrained,os.path.split(pretrained)[-1]])
135
- else:
136
- pretrained_path = pretrained
137
-
138
- model = CellposeModel(gpu=use_gpu, model_type=None, pretrained_model=pretrained_path, diam_mean=diam_mean, nchan=X_aug[0].shape[0],)
139
- model.train(train_data=X_aug, train_labels=Y_aug, normalize=False, channels=None, batch_size=batch_size,
140
- min_train_masks=1,save_path=target_directory+os.sep+model_name,n_epochs=epochs, model_name=model_name, learning_rate=learning_rate, test_data = X_val, test_labels=Y_val)
141
-
142
- file_to_move = glob(os.sep.join([target_directory, model_name, 'models','*']))[0]
143
- shutil.move(file_to_move, os.sep.join([target_directory, model_name,''])+os.path.split(file_to_move)[-1])
144
- os.rmdir(os.sep.join([target_directory, model_name, 'models']))
145
-
146
- diameter = model.diam_labels
147
-
148
- if pretrained is not None and os.path.split(pretrained)[-1]=='CP_nuclei':
149
- standard_diameter = 17.0
150
- else:
151
- standard_diameter = 30.0
152
-
153
- input_spatial_calibration = spatial_calibration #*diameter / standard_diameter
154
-
155
- config_inputs = {"channels": target_channels, "diameter": standard_diameter, 'cellprob_threshold': 0., 'flow_threshold': 0.4,
156
- 'normalization_percentile': normalization_percentile, 'normalization_clip': normalization_clip,
157
- 'normalization_values': normalization_values, 'model_type': 'cellpose',
158
- 'spatial_calibration': input_spatial_calibration, 'cell_size_um': round(diameter*input_spatial_calibration,4), 'dataset': {'train': files_train, 'validation': files_val}}
159
- json_input_config = json.dumps(config_inputs, indent=4)
160
- with open(os.sep.join([target_directory, model_name, "config_input.json"]), "w") as outfile:
161
- outfile.write(json_input_config)
162
-
163
- elif model_type=='stardist':
164
-
165
- from stardist import calculate_extents, gputools_available
166
- from stardist.models import Config2D, StarDist2D
167
-
168
- n_rays = 32
169
- print(gputools_available())
170
-
171
- n_channel=X_trn[0].shape[-1]
172
-
173
- # Predict on subsampled grid for increased efficiency and larger field of view
174
- grid = (2,2)
175
- conf = Config2D(
176
- n_rays = n_rays,
177
- grid = grid,
178
- use_gpu = use_gpu,
179
- n_channel_in = n_channel,
180
- train_learning_rate = learning_rate,
181
- train_patch_size = (256,256),
182
- train_epochs = epochs,
183
- train_reduce_lr = {'factor': 0.1, 'patience': 30, 'min_delta': 0},
184
- train_batch_size = batch_size,
185
- train_steps_per_epoch = int(augmentation_factor*len(X_trn)),
186
- )
187
-
188
- if use_gpu:
189
- from csbdeep.utils.tf import limit_gpu_memory
190
- limit_gpu_memory(None, allow_growth=True)
191
-
192
- if pretrained is None:
193
- model = StarDist2D(conf, name=model_name, basedir=target_directory)
194
- else:
195
-
196
- os.rename(instructions, os.sep.join([target_directory, model_name, 'temp.json']))
197
- copy_tree(pretrained, os.sep.join([target_directory, model_name]))
198
-
199
- if os.path.exists(os.sep.join([target_directory, model_name, 'training_instructions.json'])):
200
- os.remove(os.sep.join([target_directory, model_name, 'training_instructions.json']))
201
- if os.path.exists(os.sep.join([target_directory, model_name, 'config_input.json'])):
202
- os.remove(os.sep.join([target_directory, model_name, 'config_input.json']))
203
- if os.path.exists(os.sep.join([target_directory, model_name, 'logs'+os.sep])):
204
- shutil.rmtree(os.sep.join([target_directory, model_name, 'logs']))
205
- os.rename(os.sep.join([target_directory, model_name, 'temp.json']),os.sep.join([target_directory, model_name, 'training_instructions.json']))
206
-
207
- #shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
208
- model = StarDist2D(None, name=model_name, basedir=target_directory)
209
- model.config.train_epochs = epochs
210
- model.config.train_batch_size = min(len(X_trn),batch_size)
211
- model.config.train_learning_rate = learning_rate # perf seems bad if lr is changed in transfer
212
- model.config.use_gpu = use_gpu
213
- model.config.train_reduce_lr = {'factor': 0.1, 'patience': 10, 'min_delta': 0}
214
- print(f'{model.config=}')
215
-
216
- save_json(vars(model.config), os.sep.join([target_directory, model_name, 'config.json']))
217
-
218
- median_size = calculate_extents(list(Y_trn), np.mean)
219
- fov = np.array(model._axes_tile_overlap('YX'))
220
- print(f"median object size: {median_size}")
221
- print(f"network field of view : {fov}")
222
- if any(median_size > fov):
223
- print("WARNING: median object size larger than field of view of the neural network.")
224
-
225
- if augmentation_factor==1.0:
226
- model.train(X_trn, Y_trn, validation_data=(X_val,Y_val))
227
- else:
228
- model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)
229
- model.optimize_thresholds(X_val,Y_val)
230
-
231
- config_inputs = {"channels": target_channels, 'normalization_percentile': normalization_percentile,
232
- 'normalization_clip': normalization_clip, 'normalization_values': normalization_values,
233
- 'model_type': 'stardist', 'spatial_calibration': spatial_calibration,'cell_size_um': median_size * spatial_calibration, 'dataset': {'train': files_train, 'validation': files_val}}
234
-
235
- def make_json_safe(obj):
236
- if isinstance(obj, np.ndarray):
237
- return obj.tolist() # convert to list
238
- if isinstance(obj, (np.int64, np.int32)):
239
- return int(obj)
240
- if isinstance(obj, (np.float32, np.float64)):
241
- return float(obj)
242
- return str(obj) # fallback
243
-
244
- json_input_config = json.dumps(config_inputs, indent=4, default=make_json_safe)
245
- with open(os.sep.join([target_directory, model_name, "config_input.json"]), "w") as outfile:
246
- outfile.write(json_input_config)
247
-
248
- print('Done.')
249
-
250
-
251
-
252
-
117
+ print("number of images: %3d" % len(X))
118
+ print("- training: %3d" % len(X_trn))
119
+ print("- validation: %3d" % len(X_val))
120
+
121
+ if model_type == "cellpose":
122
+
123
+ # do augmentation in place
124
+ X_aug = []
125
+ Y_aug = []
126
+ n_val = max(1, int(round(augmentation_factor * len(X_trn))))
127
+ indices = random.choices(list(np.arange(len(X_trn))), k=n_val)
128
+ print("Performing image augmentation pre-training...")
129
+ for i in tqdm(indices):
130
+ x_aug, y_aug = augmenter(X_trn[i], Y_trn[i])
131
+ X_aug.append(x_aug)
132
+ Y_aug.append(y_aug)
133
+
134
+ # Channel axis in front for cellpose_utils
135
+ X_aug = [np.moveaxis(x, -1, 0) for x in X_aug]
136
+ X_val = [np.moveaxis(x, -1, 0) for x in X_val]
137
+ print("number of augmented images: %3d" % len(X_aug))
138
+
139
+ from cellpose.models import CellposeModel
140
+ from cellpose.io import logger_setup
141
+ import torch
142
+
143
+ if not use_gpu:
144
+ print("Using CPU for training...")
145
+ device = torch.device("cpu")
146
+ else:
147
+ print("Using GPU for training...")
148
+
149
+ diam_mean = 30.0
150
+ logger, log_file = logger_setup()
151
+ print(f"Pretrained model: ", pretrained)
152
+ if pretrained is not None:
153
+ if pretrained.endswith("CP_nuclei"):
154
+ diam_mean = 17.0
155
+ pretrained_path = os.sep.join([pretrained, os.path.split(pretrained)[-1]])
156
+ else:
157
+ pretrained_path = pretrained
158
+
159
+ model = CellposeModel(
160
+ gpu=use_gpu,
161
+ model_type=None,
162
+ pretrained_model=pretrained_path,
163
+ diam_mean=diam_mean,
164
+ nchan=X_aug[0].shape[0],
165
+ )
166
+ for name, module in model.net.named_children():
167
+ print(name, type(module))
168
+
169
+ # Freeze parts of the UNET (if we loaded a pretrained model)
170
+ if pretrained is not None:
171
+ for param in model.net.downsample.parameters():
172
+ param.requires_grad = False
173
+
174
+ # Optional: freeze style branch (recommended unless you are training on very different imaging domains)
175
+ for param in model.net.make_style.parameters():
176
+ param.requires_grad = False
177
+
178
+ # Keep decoder (upsampling path) trainable
179
+ for param in model.net.upsample.parameters():
180
+ param.requires_grad = True
181
+
182
+ # Keep output head trainable
183
+ for param in model.net.output.parameters():
184
+ param.requires_grad = True
185
+
186
+ # Unfreeze all output heads (version-safe)
187
+ output_heads = ["output", "output_conv", "flow", "prob"]
188
+ for head_name in output_heads:
189
+ if hasattr(model.net, head_name):
190
+ for param in getattr(model.net, head_name).parameters():
191
+ param.requires_grad = True
192
+
193
+ # Now train normally (Cellpose will internally skip frozen params)
194
+ model.train(
195
+ train_data=X_aug,
196
+ train_labels=Y_aug,
197
+ normalize=False,
198
+ channels=None,
199
+ batch_size=batch_size,
200
+ min_train_masks=1,
201
+ save_path=target_directory + os.sep + model_name,
202
+ n_epochs=epochs,
203
+ model_name=model_name,
204
+ learning_rate=learning_rate,
205
+ test_data=X_val,
206
+ test_labels=Y_val,
207
+ )
208
+
209
+ file_to_move = glob(os.sep.join([target_directory, model_name, "models", "*"]))[0]
210
+ shutil.move(
211
+ file_to_move,
212
+ os.sep.join([target_directory, model_name, ""])
213
+ + os.path.split(file_to_move)[-1],
214
+ )
215
+ os.rmdir(os.sep.join([target_directory, model_name, "models"]))
216
+
217
+ diameter = model.diam_labels
218
+
219
+ if pretrained is not None and os.path.split(pretrained)[-1] == "CP_nuclei":
220
+ standard_diameter = 17.0
221
+ else:
222
+ standard_diameter = 30.0
223
+
224
+ input_spatial_calibration = spatial_calibration # *diameter / standard_diameter
225
+
226
+ config_inputs = {
227
+ "channels": target_channels,
228
+ "diameter": standard_diameter,
229
+ "cellprob_threshold": 0.0,
230
+ "flow_threshold": 0.4,
231
+ "normalization_percentile": normalization_percentile,
232
+ "normalization_clip": normalization_clip,
233
+ "normalization_values": normalization_values,
234
+ "model_type": "cellpose",
235
+ "spatial_calibration": input_spatial_calibration,
236
+ "cell_size_um": round(diameter * input_spatial_calibration, 4),
237
+ "dataset": {"train": files_train, "validation": files_val},
238
+ }
239
+ json_input_config = json.dumps(config_inputs, indent=4)
240
+ with open(
241
+ os.sep.join([target_directory, model_name, "config_input.json"]), "w"
242
+ ) as outfile:
243
+ outfile.write(json_input_config)
244
+
245
+ elif model_type == "stardist":
246
+
247
+ from stardist import calculate_extents, gputools_available
248
+ from stardist.models import Config2D, StarDist2D
249
+
250
+ n_rays = 32
251
+ print(gputools_available())
252
+
253
+ n_channel = X_trn[0].shape[-1]
254
+
255
+ # Predict on subsampled grid for increased efficiency and larger field of view
256
+ grid = (2, 2)
257
+ conf = Config2D(
258
+ n_rays=n_rays,
259
+ grid=grid,
260
+ use_gpu=use_gpu,
261
+ n_channel_in=n_channel,
262
+ train_learning_rate=learning_rate,
263
+ train_patch_size=(256, 256),
264
+ train_epochs=epochs,
265
+ train_reduce_lr={"factor": 0.1, "patience": 30, "min_delta": 0},
266
+ train_batch_size=batch_size,
267
+ train_steps_per_epoch=int(augmentation_factor * len(X_trn)),
268
+ )
269
+
270
+ if use_gpu:
271
+ from csbdeep.utils.tf import limit_gpu_memory
272
+
273
+ limit_gpu_memory(None, allow_growth=True)
274
+
275
+ if pretrained is None:
276
+ model = StarDist2D(conf, name=model_name, basedir=target_directory)
277
+ else:
278
+
279
+ os.rename(
280
+ instructions, os.sep.join([target_directory, model_name, "temp.json"])
281
+ )
282
+ copy_tree(pretrained, os.sep.join([target_directory, model_name]))
283
+
284
+ if os.path.exists(
285
+ os.sep.join([target_directory, model_name, "training_instructions.json"])
286
+ ):
287
+ os.remove(
288
+ os.sep.join(
289
+ [target_directory, model_name, "training_instructions.json"]
290
+ )
291
+ )
292
+ if os.path.exists(
293
+ os.sep.join([target_directory, model_name, "config_input.json"])
294
+ ):
295
+ os.remove(os.sep.join([target_directory, model_name, "config_input.json"]))
296
+ if os.path.exists(os.sep.join([target_directory, model_name, "logs" + os.sep])):
297
+ shutil.rmtree(os.sep.join([target_directory, model_name, "logs"]))
298
+ os.rename(
299
+ os.sep.join([target_directory, model_name, "temp.json"]),
300
+ os.sep.join([target_directory, model_name, "training_instructions.json"]),
301
+ )
302
+
303
+ # shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
304
+ model = StarDist2D(None, name=model_name, basedir=target_directory)
305
+ model.config.train_epochs = epochs
306
+ model.config.train_batch_size = min(len(X_trn), batch_size)
307
+ model.config.train_learning_rate = (
308
+ learning_rate # perf seems bad if lr is changed in transfer
309
+ )
310
+ model.config.use_gpu = use_gpu
311
+ model.config.train_reduce_lr = {"factor": 0.1, "patience": 10, "min_delta": 0}
312
+ print(f"{model.config=}")
313
+
314
+ save_json(
315
+ vars(model.config),
316
+ os.sep.join([target_directory, model_name, "config.json"]),
317
+ )
318
+
319
+ median_size = calculate_extents(list(Y_trn), np.mean)
320
+ fov = np.array(model._axes_tile_overlap("YX"))
321
+ print(f"median object size: {median_size}")
322
+ print(f"network field of view : {fov}")
323
+ if any(median_size > fov):
324
+ print(
325
+ "WARNING: median object size larger than field of view of the neural network."
326
+ )
327
+
328
+ if pretrained is not None:
329
+
330
+ mod = model.keras_model
331
+ encoder_depth = len(mod.layers) // 2
332
+
333
+ for layer in mod.layers[:encoder_depth]:
334
+ layer.trainable = False
335
+
336
+ # Keep decoder trainable
337
+ for layer in mod.layers[encoder_depth:]:
338
+ layer.trainable = True
339
+
340
+ if augmentation_factor == 1.0:
341
+ model.train(X_trn, Y_trn, validation_data=(X_val, Y_val))
342
+ else:
343
+ model.train(X_trn, Y_trn, validation_data=(X_val, Y_val), augmenter=augmenter)
344
+ model.optimize_thresholds(X_val, Y_val)
345
+
346
+ if isinstance(median_size, list):
347
+ median_size = np.mean(median_size)
348
+
349
+ config_inputs = {
350
+ "channels": target_channels,
351
+ "normalization_percentile": normalization_percentile,
352
+ "normalization_clip": normalization_clip,
353
+ "normalization_values": normalization_values,
354
+ "model_type": "stardist",
355
+ "spatial_calibration": spatial_calibration,
356
+ "cell_size_um": median_size * spatial_calibration,
357
+ "dataset": {"train": files_train, "validation": files_val},
358
+ }
359
+
360
+ def make_json_safe(obj):
361
+ if isinstance(obj, np.ndarray):
362
+ return obj.tolist() # convert to list
363
+ if isinstance(obj, (np.int64, np.int32)):
364
+ return int(obj)
365
+ if isinstance(obj, (np.float32, np.float64)):
366
+ return float(obj)
367
+ return str(obj) # fallback
368
+
369
+ json_input_config = json.dumps(config_inputs, indent=4, default=make_json_safe)
370
+ with open(
371
+ os.sep.join([target_directory, model_name, "config_input.json"]), "w"
372
+ ) as outfile:
373
+ outfile.write(json_input_config)
374
+
375
+ print("Done.")