celldetective 1.0.2__py3-none-any.whl

Files changed (66)
  1. celldetective/__init__.py +2 -0
  2. celldetective/__main__.py +432 -0
  3. celldetective/datasets/segmentation_annotations/blank +0 -0
  4. celldetective/datasets/signal_annotations/blank +0 -0
  5. celldetective/events.py +149 -0
  6. celldetective/extra_properties.py +100 -0
  7. celldetective/filters.py +89 -0
  8. celldetective/gui/__init__.py +20 -0
  9. celldetective/gui/about.py +44 -0
  10. celldetective/gui/analyze_block.py +563 -0
  11. celldetective/gui/btrack_options.py +898 -0
  12. celldetective/gui/classifier_widget.py +386 -0
  13. celldetective/gui/configure_new_exp.py +532 -0
  14. celldetective/gui/control_panel.py +438 -0
  15. celldetective/gui/gui_utils.py +495 -0
  16. celldetective/gui/json_readers.py +113 -0
  17. celldetective/gui/measurement_options.py +1425 -0
  18. celldetective/gui/neighborhood_options.py +452 -0
  19. celldetective/gui/plot_signals_ui.py +1042 -0
  20. celldetective/gui/process_block.py +1055 -0
  21. celldetective/gui/retrain_segmentation_model_options.py +706 -0
  22. celldetective/gui/retrain_signal_model_options.py +643 -0
  23. celldetective/gui/seg_model_loader.py +460 -0
  24. celldetective/gui/signal_annotator.py +2388 -0
  25. celldetective/gui/signal_annotator_options.py +340 -0
  26. celldetective/gui/styles.py +217 -0
  27. celldetective/gui/survival_ui.py +903 -0
  28. celldetective/gui/tableUI.py +608 -0
  29. celldetective/gui/thresholds_gui.py +1300 -0
  30. celldetective/icons/logo-large.png +0 -0
  31. celldetective/icons/logo.png +0 -0
  32. celldetective/icons/signals_icon.png +0 -0
  33. celldetective/icons/splash-test.png +0 -0
  34. celldetective/icons/splash.png +0 -0
  35. celldetective/icons/splash0.png +0 -0
  36. celldetective/icons/survival2.png +0 -0
  37. celldetective/icons/vignette_signals2.png +0 -0
  38. celldetective/icons/vignette_signals2.svg +114 -0
  39. celldetective/io.py +2050 -0
  40. celldetective/links/zenodo.json +561 -0
  41. celldetective/measure.py +1258 -0
  42. celldetective/models/segmentation_effectors/blank +0 -0
  43. celldetective/models/segmentation_generic/blank +0 -0
  44. celldetective/models/segmentation_targets/blank +0 -0
  45. celldetective/models/signal_detection/blank +0 -0
  46. celldetective/models/tracking_configs/mcf7.json +68 -0
  47. celldetective/models/tracking_configs/ricm.json +203 -0
  48. celldetective/models/tracking_configs/ricm2.json +203 -0
  49. celldetective/neighborhood.py +717 -0
  50. celldetective/scripts/analyze_signals.py +51 -0
  51. celldetective/scripts/measure_cells.py +275 -0
  52. celldetective/scripts/segment_cells.py +212 -0
  53. celldetective/scripts/segment_cells_thresholds.py +140 -0
  54. celldetective/scripts/track_cells.py +206 -0
  55. celldetective/scripts/train_segmentation_model.py +246 -0
  56. celldetective/scripts/train_signal_model.py +49 -0
  57. celldetective/segmentation.py +712 -0
  58. celldetective/signals.py +2826 -0
  59. celldetective/tracking.py +974 -0
  60. celldetective/utils.py +1681 -0
  61. celldetective-1.0.2.dist-info/LICENSE +674 -0
  62. celldetective-1.0.2.dist-info/METADATA +192 -0
  63. celldetective-1.0.2.dist-info/RECORD +66 -0
  64. celldetective-1.0.2.dist-info/WHEEL +5 -0
  65. celldetective-1.0.2.dist-info/entry_points.txt +2 -0
  66. celldetective-1.0.2.dist-info/top_level.txt +1 -0
celldetective/scripts/track_cells.py
@@ -0,0 +1,206 @@
+ """
+ Copyright © 2022 Laboratoire Adhesion et Inflammation. Authored by Remy Torro.
+ """
+
+ import argparse
+ import os
+ import json
+ from celldetective.io import auto_load_number_of_frames, load_frames, interpret_tracking_configuration
+ from celldetective.utils import extract_experiment_channels, _extract_channel_indices_from_config, _extract_channel_indices, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel
+ from celldetective.measure import drop_tonal_features, measure_features
+ from celldetective import track
+ from pathlib import Path, PurePath
+ from glob import glob
+ from shutil import rmtree
+ from tqdm import tqdm
+ import numpy as np
+ import pandas as pd
+ import gc
+ from natsort import natsorted
+ from art import tprint
+ from tifffile import imread
+ import threading
+
+ tprint("Track")
+
+ parser = argparse.ArgumentParser(description="Track the cells of a given population in a position, with the selected configuration",
+                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('-p', "--position", required=True, help="Path to the position")
+ parser.add_argument("--mode", default="target", choices=["target", "effector", "targets", "effectors"], help="Cell population of interest")
+ parser.add_argument("--threads", default="1", help="Number of parallel threads")
+
+ args = parser.parse_args()
+ process_arguments = vars(args)
+ pos = str(process_arguments['position'])
+ mode = str(process_arguments['mode'])
+ n_threads = int(process_arguments['threads'])
+
+ if not os.path.exists(pos+"output"):
+     os.mkdir(pos+"output")
+
+ if not os.path.exists(pos+os.sep.join(["output", "tables"])):
+     os.mkdir(pos+os.sep.join(["output", "tables"]))
+
+ if mode.lower()=="target" or mode.lower()=="targets":
+     label_folder = "labels_targets"
+     instruction_file = os.sep.join(["configs", "tracking_instructions_targets.json"])
+     napari_name = "napari_target_trajectories.npy"
+     table_name = "trajectories_targets.csv"
+
+ elif mode.lower()=="effector" or mode.lower()=="effectors":
+     label_folder = "labels_effectors"
+     instruction_file = os.sep.join(["configs", "tracking_instructions_effectors.json"])
+     napari_name = "napari_effector_trajectories.npy"
+     table_name = "trajectories_effectors.csv"
+
+ # Locate the experiment config
+ parent1 = Path(pos).parent
+ expfolder = parent1.parent
+ config = PurePath(expfolder, Path("config.ini"))
+ assert os.path.exists(config), 'The configuration file for the experiment could not be located. Abort.'
+ print("Configuration file: ", config)
+
+ # From the experiment config, fetch the spatial calibration and the channel names
+ movie_prefix = ConfigSectionMap(config, "MovieSettings")["movie_prefix"]
+ spatial_calibration = float(ConfigSectionMap(config, "MovieSettings")["pxtoum"])
+ time_calibration = float(ConfigSectionMap(config, "MovieSettings")["frametomin"])
+ len_movie = float(ConfigSectionMap(config, "MovieSettings")["len_movie"])
+ shape_x = int(ConfigSectionMap(config, "MovieSettings")["shape_x"])
+ shape_y = int(ConfigSectionMap(config, "MovieSettings")["shape_y"])
+
+ channel_names, channel_indices = extract_experiment_channels(config)
+ nbr_channels = len(channel_names)
+
+ # From the tracking instructions, fetch the bTrack config, features, Haralick and post-processing options (idea: fetch a custom timeline?)
+ instr_path = PurePath(expfolder, Path(f"{instruction_file}"))
+ if os.path.exists(instr_path):
+     print(f"Tracking instructions for the {mode} population have been successfully located.")
+     with open(instr_path, 'r') as f:
+         instructions = json.load(f)
+     print("Reading the following instructions: ", instructions)
+     btrack_config = interpret_tracking_configuration(instructions['btrack_config_path'])
+
+     if 'features' in instructions:
+         features = instructions['features']
+     else:
+         features = None
+
+     if 'mask_channels' in instructions:
+         mask_channels = instructions['mask_channels']
+     else:
+         mask_channels = None
+
+     if 'haralick_options' in instructions:
+         haralick_options = instructions['haralick_options']
+     else:
+         haralick_options = None
+
+     if 'post_processing_options' in instructions:
+         post_processing_options = instructions['post_processing_options']
+     else:
+         post_processing_options = None
+ else:
+     print('No tracking instructions found. Using the standard bTrack motion model.')
+     btrack_config = interpret_tracking_configuration(None)
+     features = None
+     mask_channels = None
+     haralick_options = None
+     post_processing_options = None
+
+ if features is None:
+     features = []
+
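For reference, a minimal tracking-instructions file compatible with the keys read above could look like the following sketch (hypothetical values; the optional keys can be omitted or set to null, in which case the script falls back to None):

    {
        "btrack_config_path": "celldetective/models/tracking_configs/ricm.json",
        "features": ["area", "eccentricity"],
        "mask_channels": null,
        "haralick_options": null,
        "post_processing_options": null
    }
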
+ # From the position, fetch the segmentation labels
+ label_path = natsorted(glob(pos+f"{label_folder}"+os.sep+"*.tif"))
+ if len(label_path)>0:
+     print(f"Found {len(label_path)} segmented frames...")
+ else:
+     print("No segmented frames were found. Please run the segmentation first. Aborting...")
+     os.abort()
+
+ # The stack is only needed if tonal features or Haralick textures are requested
+ try:
+     file = glob(pos+os.sep.join(["movie", f"{movie_prefix}*.tif"]))[0]
+ except IndexError:
+     print('The movie could not be found. Check the prefix. If you intended to measure texture or tone, these features will be skipped.')
+     file = None
+     haralick_options = None
+     features = drop_tonal_features(features)
+
+ len_movie_auto = auto_load_number_of_frames(file)
+ if len_movie_auto is not None:
+     len_movie = len_movie_auto
+
+ img_num_channels = _get_img_num_per_channel(channel_indices, len_movie, nbr_channels)
+
+ #######################################
+ # Loop over all frames and find objects
+ #######################################
+
+ timestep_dataframes = []
+
+ def measure_index(indices):
+     for t in tqdm(indices, desc="frame"):
+
+         # Load all channels at time t
+         img = load_frames(img_num_channels[:,t], file, scale=None, normalize_input=False)
+         lbl = imread(label_path[t])
+
+         df_props = measure_features(img, lbl, features=features+['centroid'], border_dist=None,
+                                     channels=channel_names, haralick_options=haralick_options, verbose=False,
+                                     )
+         df_props.rename(columns={'centroid-1': 'x', 'centroid-0': 'y'}, inplace=True)
+         df_props['t'] = int(t)
+         timestep_dataframes.append(df_props)
+
+ # Multithreading: split the frame indices into one chunk per thread
+ indices = list(range(img_num_channels.shape[1]))
+ chunks = np.array_split(indices, n_threads)
+ threads = []
+ for i in range(n_threads):
+     thread_i = threading.Thread(target=measure_index, args=[chunks[i]])
+     threads.append(thread_i)
+ for th in threads:
+     th.start()
+ for th in threads:
+     th.join()
+
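A note on the threading design above: the worker threads all append to the shared timestep_dataframes list, which is safe in CPython because list.append is atomic under the GIL, while the NumPy and measurement calls release the GIL and let the threads overlap. Threads may finish out of order, so the concatenated table is not necessarily sorted by frame; each row, however, carries its own 't' column.
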
+ df = pd.concat(timestep_dataframes)
+ df.reset_index(inplace=True, drop=True)
+
+ # Drop the measurement columns that refer to the mask channels
+ if mask_channels is not None:
+     cols_to_drop = []
+     for mc in mask_channels:
+         columns = df.columns
+         col_contains = [mc in c for c in columns]
+         to_remove = np.array(columns)[np.array(col_contains)]
+         cols_to_drop.extend(to_remove)
+     if len(cols_to_drop)>0:
+         df = df.drop(cols_to_drop, axis=1)
+
+ # Do the tracking
+ trajectories, napari_data = track(None,
+                                   configuration=btrack_config,
+                                   objects=df,
+                                   spatial_calibration=spatial_calibration,
+                                   channel_names=channel_names,
+                                   return_napari_data=True,
+                                   optimizer_options={'tm_lim': int(12e4)},
+                                   track_kwargs={'step_size': 100},
+                                   clean_trajectories_kwargs=post_processing_options,
+                                   volume=(shape_x, shape_y),
+                                   )
+ print(trajectories)
+ print(trajectories.columns)
+
+ # TODO: in the output trajectory table, create POSITION_X_um, POSITION_Y_um and TIME_min columns
+ # Save the napari data
+ np.save(pos+os.sep.join(['output', 'tables', napari_name]), napari_data, allow_pickle=True)
+ print(f"napari data successfully saved in {pos+os.sep.join(['output', 'tables'])}")
+
+ trajectories.to_csv(pos+os.sep.join(['output', 'tables', table_name]), index=False)
+ print(f"Table {table_name} successfully saved in {os.sep.join(['output', 'tables'])}")
+
+ del trajectories
+ del napari_data
+ gc.collect()
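Assuming celldetective is installed and the position has already been segmented, the script would be launched along these lines (placeholder path; note that the position path must end with a file separator, since the script concatenates it directly with the output folder names):

    python track_cells.py -p /path/to/experiment/well/position/ --mode target --threads 4
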
celldetective/scripts/train_segmentation_model.py
@@ -0,0 +1,246 @@
+ """
+ Copyright © 2023 Laboratoire Adhesion et Inflammation. Authored by Remy Torro.
+ """
+
+ import argparse
+ import os
+ import shutil
+ from glob import glob
+ import json
+ from tqdm import tqdm
+ import numpy as np
+ import random
+
+ from celldetective.utils import load_image_dataset, normalize_per_channel, augmenter
+ from stardist import fill_label_holes
+ from art import tprint
+ import matplotlib.pyplot as plt
+
+ def interpolate_nan(array_like):
+     """Replace NaN values by linear interpolation over the flattened array."""
+     array = array_like.copy()
+
+     valid_mask = ~np.isnan(array)
+
+     # Flat indices and values of the valid (non-NaN) samples
+     xp = valid_mask.ravel().nonzero()[0]
+     fp = array[valid_mask]
+
+     # Flat indices of the NaN samples, filled by interpolating between the valid ones
+     x = np.isnan(array).ravel().nonzero()[0]
+     array[np.isnan(array)] = np.interp(x, xp, fp)
+
+     return array
+
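As a quick, hypothetical sanity check of this helper (linear interpolation over the flattened array, so NaN gaps are filled from the nearest valid samples):

    import numpy as np
    a = np.array([1.0, np.nan, np.nan, 4.0])
    print(interpolate_nan(a))  # [1. 2. 3. 4.]
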
+ tprint("Train")
+
+ parser = argparse.ArgumentParser(description="Train a segmentation model from instructions.",
+                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('-c', "--config", required=True, help="Training instructions")
+ parser.add_argument('-g', "--use_gpu", required=True, help="Use GPU")
+
+ args = parser.parse_args()
+ process_arguments = vars(args)
+ instructions = str(process_arguments['config'])
+ use_gpu = process_arguments['use_gpu'] not in ("0", "False", "false")  # bool() on a non-empty string is always True
+
+ if os.path.exists(instructions):
+     with open(instructions, 'r') as f:
+         training_instructions = json.load(f)
+ else:
+     print('Training instructions could not be found. Abort.')
+     os.abort()
+
+ model_name = training_instructions['model_name']
+ target_directory = training_instructions['target_directory']
+ model_type = training_instructions['model_type']
+ pretrained = training_instructions['pretrained']
+
+ datasets = training_instructions['ds']
+
+ target_channels = training_instructions['channel_option']
+ normalization_percentile = training_instructions['normalization_percentile']
+ normalization_clip = training_instructions['normalization_clip']
+ normalization_values = training_instructions['normalization_values']
+ spatial_calibration = training_instructions['spatial_calibration']
+
+ validation_split = training_instructions['validation_split']
+ augmentation_factor = training_instructions['augmentation_factor']
+
+ learning_rate = training_instructions['learning_rate']
+ epochs = training_instructions['epochs']
+ batch_size = training_instructions['batch_size']
+
+ # Load the dataset
+ print(f'Datasets: {datasets}')
+ X, Y = load_image_dataset(datasets, target_channels, train_spatial_calibration=spatial_calibration,
+                           mask_suffix='labelled')
+ print('Dataset loaded...')
+
+ # Normalize the images
+ X = normalize_per_channel(X,
+                           normalization_percentile_mode=normalization_percentile,
+                           normalization_values=normalization_values,
+                           normalization_clipping=normalization_clip,
+                           )
+
+ # Debug visualization: show the first channel of each normalized image and check the NaN interpolation
+ for x in X:
+     plt.imshow(x[:,:,0])
+     plt.xlim(0, 1004)
+     plt.ylim(0, 1002)
+     plt.colorbar()
+     plt.pause(2)
+     plt.close()
+     print(x.shape)
+     interp = interpolate_nan(x)
+     print(interp.shape)
+     print(np.any(np.isnan(x).flatten()))
+     print(np.any(np.isnan(interp).flatten()))
+
+ Y = [fill_label_holes(y) for y in tqdm(Y)]
+
+ # Random train/validation split
+ assert len(X) > 1, "not enough training data"
+ rng = np.random.RandomState()
+ ind = rng.permutation(len(X))
+ n_val = max(1, int(round(validation_split * len(ind))))
+ ind_train, ind_val = ind[:-n_val], ind[-n_val:]
+ X_val, Y_val = [X[i] for i in ind_val], [Y[i] for i in ind_val]
+ X_trn, Y_trn = [X[i] for i in ind_train], [Y[i] for i in ind_train]
+ print('number of images: %3d' % len(X))
+ print('- training: %3d' % len(X_trn))
+ print('- validation: %3d' % len(X_val))
+
+ if model_type=='cellpose':
+
+     # Pre-compute the augmented training set (n_aug samples drawn with replacement)
+     X_aug = []; Y_aug = []
+     n_aug = max(1, int(round(augmentation_factor * len(X_trn))))
+     indices = random.choices(list(np.arange(len(X_trn))), k=n_aug)
+     print('Performing image augmentation pre-training...')
+     for i in tqdm(indices):
+         x_aug, y_aug = augmenter(X_trn[i], Y_trn[i])
+         X_aug.append(x_aug)
+         Y_aug.append(y_aug)
+
+     # Channel axis in front for cellpose
+     X_aug = [np.moveaxis(x, -1, 0) for x in X_aug]
+     X_val = [np.moveaxis(x, -1, 0) for x in X_val]
+     print('number of augmented images: %3d' % len(X_aug))
+
+     from cellpose.models import CellposeModel
+     from cellpose.io import logger_setup
+     import torch
+
+     if not use_gpu:
+         device = torch.device("cpu")
+
+     logger, log_file = logger_setup()
+     print('Pretrained model: ', pretrained)
+     if pretrained is not None:
+         # The pretrained model file is expected to sit in a folder of the same name
+         pretrained_path = os.sep.join([pretrained, os.path.split(pretrained)[-1]])
+     else:
+         pretrained_path = pretrained
+
+     model = CellposeModel(gpu=use_gpu, model_type=None, pretrained_model=pretrained_path, diam_mean=30.0, nchan=X_aug[0].shape[0])
+     model.train(train_data=X_aug, train_labels=Y_aug, normalize=False, channels=None, batch_size=batch_size,
+                 min_train_masks=1, save_path=target_directory+os.sep+model_name, n_epochs=epochs, model_name=model_name, learning_rate=learning_rate, test_data=X_val, test_labels=Y_val)
+
+     # cellpose saves the model under a nested 'models' folder; flatten it
+     file_to_move = glob(os.sep.join([target_directory, model_name, 'models', '*']))[0]
+     shutil.move(file_to_move, os.sep.join([target_directory, model_name, ''])+os.path.split(file_to_move)[-1])
+     os.rmdir(os.sep.join([target_directory, model_name, 'models']))
+
+     diameter = model.diam_labels
+
+     if pretrained is not None and os.path.split(pretrained)[-1]=='CP_nuclei':
+         standard_diameter = 17.0
+     else:
+         standard_diameter = 30.0
+
+     input_spatial_calibration = spatial_calibration  # *diameter / standard_diameter
+
+     config_inputs = {"channels": target_channels, "diameter": standard_diameter, 'cellprob_threshold': 0., 'flow_threshold': 0.4,
+                      'normalization_percentile': normalization_percentile, 'normalization_clip': normalization_clip,
+                      'normalization_values': normalization_values, 'model_type': 'cellpose',
+                      'spatial_calibration': input_spatial_calibration}
+     json_input_config = json.dumps(config_inputs, indent=4)
+     with open(os.sep.join([target_directory, model_name, "config_input.json"]), "w") as outfile:
+         outfile.write(json_input_config)
+
+ elif model_type=='stardist':
+
+     from stardist import calculate_extents, gputools_available
+     from stardist.models import Config2D, StarDist2D
+
+     n_rays = 32
+     print(gputools_available())
+
+     n_channel = X_trn[0].shape[-1]
+
+     # Predict on a subsampled grid for increased efficiency and a larger field of view
+     grid = (2,2)
+     conf = Config2D(
+         n_rays=n_rays,
+         grid=grid,
+         use_gpu=use_gpu,
+         n_channel_in=n_channel,
+         unet_dropout=0.0,
+         unet_batch_norm=False,
+         unet_n_conv_per_depth=2,
+         train_learning_rate=learning_rate,
+         train_patch_size=(256,256),
+         train_epochs=epochs,
+         #train_foreground_only=0.9,
+         train_loss_weights=(1, 0.2),
+         train_reduce_lr={'factor': 0.1, 'patience': 30, 'min_delta': 0},
+         unet_n_depth=3,
+         train_batch_size=batch_size,
+     )
+
200
+ if use_gpu:
201
+ from csbdeep.utils.tf import limit_gpu_memory
202
+ limit_gpu_memory(None, allow_growth=True)
203
+
204
+ if pretrained is None:
205
+ model = StarDist2D(conf, name=model_name, basedir=target_directory)
206
+ else:
207
+ # files_to_copy = glob(os.sep.join([pretrained, '*']))
208
+ # for f in files_to_copy:
209
+ # shutil.copy(f, os.sep.join([target_directory, model_name, os.path.split(f)[-1]]))
210
+ idx=1
211
+ while os.path.exists(os.sep.join([target_directory, model_name])):
212
+ model_name = model_name+f'_{idx}'
213
+ idx+=1
214
+
215
+ shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
216
+ model = StarDist2D(None, name=model_name, basedir=target_directory)
217
+ model.config.train_epochs = epochs
218
+ model.config.train_batch_size = min(len(X_trn),batch_size)
219
+ model.config.train_learning_rate = learning_rate
220
+
221
+ median_size = calculate_extents(list(Y_trn), np.mean)
222
+ fov = np.array(model._axes_tile_overlap('YX'))
223
+ print(f"median object size: {median_size}")
224
+ print(f"network field of view : {fov}")
225
+ if any(median_size > fov):
226
+ print("WARNING: median object size larger than field of view of the neural network.")
227
+
228
+ if augmentation_factor==1.0:
229
+ model.train(X_trn, Y_trn, validation_data=(X_val,Y_val))
230
+ else:
231
+ model.train(X_trn, Y_trn, validation_data=(X_val,Y_val), augmenter=augmenter)
232
+ model.optimize_thresholds(X_val,Y_val)
233
+
234
+ config_inputs = {"channels": target_channels, 'normalization_percentile': normalization_percentile,
235
+ 'normalization_clip': normalization_clip, 'normalization_values': normalization_values,
236
+ 'model_type': 'stardist', 'spatial_calibration': spatial_calibration}
237
+
238
+ json_input_config = json.dumps(config_inputs, indent=4)
239
+ with open(os.sep.join([target_directory, model_name, "config_input.json"]), "w") as outfile:
240
+ outfile.write(json_input_config)
241
+
242
+ print('Done.')
243
+
244
+
245
+
246
+
celldetective/scripts/train_signal_model.py
@@ -0,0 +1,49 @@
+ """
+ Copyright © 2023 Laboratoire Adhesion et Inflammation. Authored by Remy Torro.
+ """
+
+ import argparse
+ import os
+ import json
+ from pathlib import Path, PurePath
+ from glob import glob
+ from tqdm import tqdm
+ import numpy as np
+ import gc
+ from art import tprint
+ from celldetective.signals import SignalDetectionModel
+
+ tprint("Train")
+
+ parser = argparse.ArgumentParser(description="Train a signal model from instructions.",
+                                  formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+ parser.add_argument('-c', "--config", required=True, help="Training instructions")
+
+ args = parser.parse_args()
+ process_arguments = vars(args)
+ instructions = str(process_arguments['config'])
+
+ if os.path.exists(instructions):
+     with open(instructions, 'r') as f:
+         threshold_instructions = json.load(f)
+     threshold_instructions.update({'n_channels': len(threshold_instructions['channel_option'])})
+     if threshold_instructions['augmentation_factor'] > 1.0:
+         threshold_instructions.update({'augment': True})
+     else:
+         threshold_instructions.update({'augment': False})
+     threshold_instructions.update({'test_split': 0.})
+ else:
+     print('The configuration path is not valid. Abort.')
+     os.abort()
+
+ model_params = {k: threshold_instructions[k] for k in ('pretrained', 'model_signal_length', 'channel_option', 'n_channels', 'label') if k in threshold_instructions}
+ train_params = {k: threshold_instructions[k] for k in ('model_name', 'target_directory', 'channel_option', 'recompile_pretrained', 'test_split', 'augment', 'epochs', 'learning_rate', 'batch_size', 'validation_split', 'normalization_percentile', 'normalization_values', 'normalization_clip') if k in threshold_instructions}
+
+ print(f'model params {model_params}')
+ print(f'train params {train_params}')
+
+ model = SignalDetectionModel(**model_params)
+ model.fit_from_directory(threshold_instructions['ds'], **train_params)
+
+ print('Done.')
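As with the segmentation script, a hypothetical launch (placeholder path; the instructions file must at least provide the 'ds', 'channel_option' and 'augmentation_factor' keys read above, plus whichever of the model_params and train_params keys apply):

    python train_signal_model.py -c /path/to/signal_training_instructions.json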