celldetective 1.1.1.post4__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/__init__.py +2 -1
- celldetective/extra_properties.py +62 -34
- celldetective/gui/__init__.py +1 -0
- celldetective/gui/analyze_block.py +2 -1
- celldetective/gui/classifier_widget.py +8 -7
- celldetective/gui/control_panel.py +50 -6
- celldetective/gui/layouts.py +5 -4
- celldetective/gui/neighborhood_options.py +10 -8
- celldetective/gui/plot_signals_ui.py +39 -11
- celldetective/gui/process_block.py +413 -95
- celldetective/gui/retrain_segmentation_model_options.py +17 -4
- celldetective/gui/retrain_signal_model_options.py +106 -6
- celldetective/gui/signal_annotator.py +25 -5
- celldetective/gui/signal_annotator2.py +2708 -0
- celldetective/gui/signal_annotator_options.py +3 -1
- celldetective/gui/survival_ui.py +15 -6
- celldetective/gui/tableUI.py +235 -39
- celldetective/io.py +537 -421
- celldetective/measure.py +919 -969
- celldetective/models/pair_signal_detection/blank +0 -0
- celldetective/neighborhood.py +426 -354
- celldetective/relative_measurements.py +648 -0
- celldetective/scripts/analyze_signals.py +1 -1
- celldetective/scripts/measure_cells.py +28 -8
- celldetective/scripts/measure_relative.py +103 -0
- celldetective/scripts/segment_cells.py +5 -5
- celldetective/scripts/track_cells.py +4 -1
- celldetective/scripts/train_segmentation_model.py +23 -18
- celldetective/scripts/train_signal_model.py +33 -0
- celldetective/signals.py +402 -8
- celldetective/tracking.py +8 -2
- celldetective/utils.py +93 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/METADATA +8 -8
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/RECORD +38 -34
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/WHEEL +1 -1
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/LICENSE +0 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/entry_points.txt +0 -0
- {celldetective-1.1.1.post4.dist-info → celldetective-1.2.0.dist-info}/top_level.txt +0 -0
celldetective/scripts/measure_cells.py

@@ -72,6 +72,8 @@ if os.path.exists(instr_path):
     print("Reading the following instructions: ", instructions)
     if 'background_correction' in instructions:
         background_correction = instructions['background_correction']
+    else:
+        background_correction = None
 
     if 'features' in instructions:
         features = instructions['features']

@@ -145,14 +147,19 @@ except IndexError:
 # Load trajectories, add centroid if not in trajectory
 trajectories = pos+os.sep.join(['output','tables', table_name])
 if os.path.exists(trajectories):
+    print('trajectory exists...')
     trajectories = pd.read_csv(trajectories)
     if 'TRACK_ID' not in list(trajectories.columns):
         do_iso_intensities = False
         intensity_measurement_radii = None
         if clear_previous:
-            trajectories = remove_trajectory_measurements(trajectories, column_labels)
+            print('No TRACK_ID... Clear previous measurements...')
+            trajectories = None #remove_trajectory_measurements(trajectories, column_labels)
+            do_features = True
+            features += ['centroid']
     else:
         if clear_previous:
+            print('TRACK_ID found... Clear previous measurements...')
             trajectories = remove_trajectory_measurements(trajectories, column_labels)
         else:
             trajectories = None

@@ -161,11 +168,11 @@ else:
     do_iso_intensities = False
 
 
-if (features is not None) and (trajectories is not None):
-    features = remove_redundant_features(features, 
-                                         trajectories.columns,
-                                         channel_names=channel_names
-                                         )
+# if (features is not None) and (trajectories is not None):
+# 	features = remove_redundant_features(features, 
+# 										 trajectories.columns,
+# 										 channel_names=channel_names
+# 										 )
 
 len_movie_auto = auto_load_number_of_frames(file)
 if len_movie_auto is not None:

@@ -187,6 +194,7 @@ if label_path is None:
 else:
     do_features = True
 
+
 #######################################
 # Loop over all frames and find objects
 #######################################

@@ -196,6 +204,9 @@ if trajectories is None:
     print('Use features as a substitute for the trajectory table.')
     if 'label' not in features:
         features.append('label')
+
+
+
 features_log=f'features: {features}'
 border_distances_log=f'border_distances: {border_distances}'
 haralick_options_log=f'haralick_options: {haralick_options}'

@@ -237,7 +248,7 @@ def measure_index(indices):
     column_labels = {'track': "ID", 'time': column_labels['time'], 'x': column_labels['x'],
                      'y': column_labels['y']}
     feature_table.rename(columns={'centroid-1': 'POSITION_X', 'centroid-0': 'POSITION_Y'}, inplace=True)
-
+
     if do_iso_intensities:
         iso_table = measure_isotropic_intensity(positions_at_t, img, channels=channel_names, intensity_measurement_radii=intensity_measurement_radii, column_labels=column_labels, operations=isotropic_operations, verbose=False)
 

@@ -249,11 +260,20 @@ def measure_index(indices):
     elif do_features:
         measurements_at_t = positions_at_t.merge(feature_table, how='outer', on='class_id',suffixes=('', '_delme'))
         measurements_at_t = measurements_at_t[[c for c in measurements_at_t.columns if not c.endswith('_delme')]]
-
+
+        center_of_mass_x_cols = [c for c in list(measurements_at_t.columns) if c.endswith('centre_of_mass_x')]
+        center_of_mass_y_cols = [c for c in list(measurements_at_t.columns) if c.endswith('centre_of_mass_y')]
+        for c in center_of_mass_x_cols:
+            measurements_at_t.loc[:,c.replace('_x','_POSITION_X')] = measurements_at_t[c] + measurements_at_t['POSITION_X']
+        for c in center_of_mass_y_cols:
+            measurements_at_t.loc[:,c.replace('_y','_POSITION_Y')] = measurements_at_t[c] + measurements_at_t['POSITION_Y']
+        measurements_at_t = measurements_at_t.drop(columns = center_of_mass_x_cols+center_of_mass_y_cols)
+
     if measurements_at_t is not None:
         measurements_at_t[column_labels['time']] = t
         timestep_dataframes.append(measurements_at_t)
 
+
 # Multithreading
 indices = list(range(img_num_channels.shape[1]))
 chunks = np.array_split(indices, n_threads)
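The new block at the end of `measure_index` converts per-channel centre-of-mass offsets into absolute image coordinates by adding the object's centroid, then drops the relative columns. A minimal standalone sketch of that transformation (the `ch1_*` column names and values are illustrative, not from the package):

```python
import pandas as pd

# Toy measurement table: centroid position plus a channel-specific
# centre-of-mass offset, mimicking the columns handled in the diff.
df = pd.DataFrame({
    'POSITION_X': [10.0, 40.0],
    'POSITION_Y': [5.0, 12.0],
    'ch1_centre_of_mass_x': [1.5, -0.5],
    'ch1_centre_of_mass_y': [0.2, 2.0],
})

x_cols = [c for c in df.columns if c.endswith('centre_of_mass_x')]
y_cols = [c for c in df.columns if c.endswith('centre_of_mass_y')]
for c in x_cols:
    # relative offset + centroid -> absolute X, stored under a *_POSITION_X name
    df[c.replace('_x', '_POSITION_X')] = df[c] + df['POSITION_X']
for c in y_cols:
    df[c.replace('_y', '_POSITION_Y')] = df[c] + df['POSITION_Y']
df = df.drop(columns=x_cols + y_cols)  # keep only the absolute coordinates
print(df)
```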
celldetective/scripts/measure_relative.py (new file)

@@ -0,0 +1,103 @@
+import argparse
+import os
+import json
+from celldetective.relative_measurements import measure_pair_signals_at_position, update_effector_table, extract_neighborhoods_from_pickles
+from celldetective.utils import ConfigSectionMap, extract_experiment_channels
+
+from pathlib import Path, PurePath
+
+import pandas as pd
+
+from art import tprint
+
+
+tprint("Measure pairs")
+
+parser = argparse.ArgumentParser(description="Measure features and intensities in a multichannel timeseries.",
+                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+parser.add_argument('-p', "--position", required=True, help="Path to the position")
+
+args = parser.parse_args()
+process_arguments = vars(args)
+pos = str(process_arguments['position'])
+
+instruction_file = os.sep.join(['configs', "neighborhood_instructions.json"])
+
+# Locate experiment config
+parent1 = Path(pos).parent
+expfolder = parent1.parent
+config = PurePath(expfolder, Path("config.ini"))
+assert os.path.exists(config), 'The configuration file for the experiment could not be located. Abort.'
+print("Configuration file: ", config)
+
+# from exp config fetch spatial calib, channel names
+movie_prefix = ConfigSectionMap(config, "MovieSettings")["movie_prefix"]
+spatial_calibration = float(ConfigSectionMap(config, "MovieSettings")["pxtoum"])
+time_calibration = float(ConfigSectionMap(config, "MovieSettings")["frametomin"])
+len_movie = float(ConfigSectionMap(config, "MovieSettings")["len_movie"])
+channel_names, channel_inneigh_protocoles = extract_experiment_channels(config)
+nbr_channels = len(channel_names)
+
+# from tracking instructions, fetch btrack config, features, haralick, clean_traj, idea: fetch custom timeline?
+instr_path = PurePath(expfolder, Path(f"{instruction_file}"))
+previous_pair_table_path = pos + os.sep.join(['output', 'tables', 'trajectories_pairs.csv'])
+
+# if os.path.exists(instr_path):
+# 	print(f"Neighborhood instructions has been successfully located.")
+# 	with open(instr_path, 'r') as f:
+# 		instructions = json.load(f)
+# 		print("Reading the following instructions: ", instructions)
+
+# 	if 'distance' in instructions:
+# 		distance = instructions['distance'][0]
+# 	else:
+# 		distance = None
+# else:
+# 	print('No measurement instructions found')
+# 	os.abort()
+
+previous_neighborhoods = []
+associated_reference_population = []
+
+# if distance is None:
+# 	print('No measurement could be performed. Check your inputs.')
+# 	print('Done.')
+# 	os.abort()
+# 	#distance = 0
+# else:
+neighborhoods_to_measure = extract_neighborhoods_from_pickles(pos)
+all_df_pairs = []
+if os.path.exists(previous_pair_table_path):
+    df_0 = pd.read_csv(previous_pair_table_path)
+    previous_neighborhoods = [c.replace('status_','') for c in list(df_0.columns) if c.startswith('status_neighborhood')]
+    for n in previous_neighborhoods:
+        associated_reference_population.append(df_0.loc[~df_0['status_'+n].isnull(),'reference_population'].values[0])
+    print(f'{previous_neighborhoods=} {associated_reference_population=}')
+    all_df_pairs.append(df_0)
+for k,neigh_protocol in enumerate(neighborhoods_to_measure):
+    if neigh_protocol['description'] not in previous_neighborhoods:
+        df_pairs = measure_pair_signals_at_position(pos, neigh_protocol)
+        print(f'{df_pairs=}')
+        if 'REFERENCE_ID' in list(df_pairs.columns):
+            all_df_pairs.append(df_pairs)
+    elif neigh_protocol['description'] in previous_neighborhoods and neigh_protocol['reference'] != associated_reference_population[previous_neighborhoods.index(neigh_protocol['description'])]:
+        df_pairs = measure_pair_signals_at_position(pos, neigh_protocol)
+        if 'REFERENCE_ID' in list(df_pairs.columns):
+            all_df_pairs.append(df_pairs)
+
+print(f'{len(all_df_pairs)} neighborhood measurements sets were computed...')
+
+if len(all_df_pairs)>1:
+    print('Merging...')
+    df_pairs = all_df_pairs[0]
+    for i in range(1,len(all_df_pairs)):
+        cols = [c1 for c1,c2 in zip(list(df_pairs.columns), list(all_df_pairs[i].columns)) if c1==c2]
+        df_pairs = pd.merge(df_pairs.round(decimals=6), all_df_pairs[i].round(decimals=6), how="outer", on=cols)
+elif len(all_df_pairs)==1:
+    df_pairs = all_df_pairs[0]
+
+print('Writing table...')
+df_pairs = df_pairs.sort_values(by=['reference_population', 'neighbor_population', 'REFERENCE_ID', 'NEIGHBOR_ID', 'FRAME'])
+df_pairs.to_csv(previous_pair_table_path, index=False)
+print('Done.')
+
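The merge step at the end of the new script outer-joins each pair table onto the accumulated result using the columns the two tables share, rounding both sides first so floating-point keys still match. A small self-contained sketch of that pattern (toy column names; the script itself compares the column lists positionally rather than by membership):

```python
import pandas as pd

# Two hypothetical pair tables: same identity columns, different measurements.
a = pd.DataFrame({'REFERENCE_ID': [1, 2], 'FRAME': [0, 0], 'distance_A': [0.1111119, 2.5]})
b = pd.DataFrame({'REFERENCE_ID': [1, 2], 'FRAME': [0, 0], 'distance_B': [3.3, 4.4]})

# Join on every column name present in both tables; round first so
# nearly-identical float keys do not silently fail to match.
common = [c for c in a.columns if c in b.columns]
merged = pd.merge(a.round(decimals=6), b.round(decimals=6), how='outer', on=common)
print(merged)
```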
celldetective/scripts/segment_cells.py

@@ -129,10 +129,10 @@ img_num_channels = _get_img_num_per_channel(channel_indices, int(len_movie), nbr
 
 # If everything OK, prepare output, load models
 print('Erasing previous segmentation folder.')
-if os.path.exists(
-    rmtree(
-os.mkdir(
-print(f'Folder {
+if os.path.exists(pos+label_folder):
+    rmtree(pos+label_folder)
+os.mkdir(pos+label_folder)
+print(f'Folder {pos+label_folder} successfully generated.')
 log=f'segmentation model: {modelname}\n'
 with open(pos+f'log_{mode}.json', 'a') as f:
     f.write(f'{datetime.datetime.now()} SEGMENT \n')

@@ -203,7 +203,7 @@ def segment_index(indices):
     if Y_pred.shape != template.shape[:2]:
         Y_pred = resize(Y_pred, template.shape[:2], order=0)
 
-    save_tiff_imagej_compatible(os.sep.join([
+    save_tiff_imagej_compatible(pos+os.sep.join([label_folder,f"{str(t).zfill(4)}.tif"]), Y_pred, axes='YX')
 
     del f;
     del template;
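Both corrected call sites now build their output paths from `pos` explicitly. The wipe-and-recreate pattern used for the label folder, isolated into a helper (the helper name and demo path are hypothetical):

```python
import os
from shutil import rmtree

def reset_output_folder(path):
    # Erase any previous segmentation output, then start from an empty folder,
    # mirroring the rmtree/os.mkdir sequence in the corrected script.
    if os.path.exists(path):
        rmtree(path)
    os.mkdir(path)
    print(f'Folder {path} successfully generated.')

reset_output_folder('labels_demo')
```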
celldetective/scripts/track_cells.py

@@ -9,7 +9,7 @@ import json
 from celldetective.io import auto_load_number_of_frames, load_frames, interpret_tracking_configuration
 from celldetective.utils import extract_experiment_channels, _extract_channel_indices_from_config, _extract_channel_indices, ConfigSectionMap, _extract_nbr_channels_from_config, _get_img_num_per_channel, extract_experiment_channels
 from celldetective.measure import drop_tonal_features, measure_features
-from celldetective import track
+from celldetective.tracking import track
 from pathlib import Path, PurePath
 from glob import glob
 from shutil import rmtree

@@ -213,5 +213,8 @@ print(f"napari data successfully saved in {pos+os.sep.join(['output', 'tables'])
 trajectories.to_csv(pos+os.sep.join(['output', 'tables', table_name]), index=False)
 print(f"Table {table_name} successfully saved in {os.sep.join(['output', 'tables'])}")
 
+if os.path.exists(pos+os.sep.join(['output', 'tables', table_name.replace('.csv','.pkl')])):
+    os.remove(pos+os.sep.join(['output', 'tables', table_name.replace('.csv','.pkl')]))
+
 del trajectories; del napari_data;
 gc.collect()
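The three added lines remove a `.pkl` companion whenever the CSV table is rewritten, so nothing downstream can load a stale cached copy. The same guard as a standalone helper (the function name is hypothetical):

```python
import os

def drop_stale_pickle(csv_path):
    # A regenerated CSV invalidates any pickle cached alongside it;
    # delete the .pkl rather than risk readers picking up old data.
    pkl_path = csv_path.replace('.csv', '.pkl')
    if os.path.exists(pkl_path):
        os.remove(pkl_path)
```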
celldetective/scripts/train_segmentation_model.py

@@ -16,7 +16,8 @@ from celldetective.io import normalize_multichannel
 from stardist import fill_label_holes
 from art import tprint
 import matplotlib.pyplot as plt
-
+from distutils.dir_util import copy_tree
+from csbdeep.utils import save_json
 
 tprint("Train")
 

@@ -169,22 +170,17 @@ elif model_type=='stardist':
 
     # Predict on subsampled grid for increased efficiency and larger field of view
     grid = (2,2)
-    conf = Config2D
+    conf = Config2D(
         n_rays = n_rays,
         grid = grid,
         use_gpu = use_gpu,
         n_channel_in = n_channel,
-        unet_dropout = 0.0,
-        unet_batch_norm = False,
-        unet_n_conv_per_depth=2,
         train_learning_rate = learning_rate,
         train_patch_size = (256,256),
         train_epochs = epochs,
-        #train_foreground_only=0.9,
-        train_loss_weights=(1,0.2),
         train_reduce_lr = {'factor': 0.1, 'patience': 30, 'min_delta': 0},
-        unet_n_depth = 3,
         train_batch_size = batch_size,
+        train_steps_per_epoch = int(augmentation_factor*len(X_trn)),
     )
 
     if use_gpu:

@@ -194,19 +190,28 @@ elif model_type=='stardist':
     if pretrained is None:
         model = StarDist2D(conf, name=model_name, basedir=target_directory)
     else:
-
-
-
-
-
-        model_name
-
-
-
+
+        os.rename(instructions, os.sep.join([target_directory, model_name, 'temp.json']))
+        copy_tree(pretrained, os.sep.join([target_directory, model_name]))
+
+        if os.path.exists(os.sep.join([target_directory, model_name, 'training_instructions.json'])):
+            os.remove(os.sep.join([target_directory, model_name, 'training_instructions.json']))
+        if os.path.exists(os.sep.join([target_directory, model_name, 'config_input.json'])):
+            os.remove(os.sep.join([target_directory, model_name, 'config_input.json']))
+        if os.path.exists(os.sep.join([target_directory, model_name, 'logs'+os.sep])):
+            shutil.rmtree(os.sep.join([target_directory, model_name, 'logs']))
+        os.rename(os.sep.join([target_directory, model_name, 'temp.json']),os.sep.join([target_directory, model_name, 'training_instructions.json']))
+
+        #shutil.copytree(pretrained, os.sep.join([target_directory, model_name]))
         model = StarDist2D(None, name=model_name, basedir=target_directory)
         model.config.train_epochs = epochs
         model.config.train_batch_size = min(len(X_trn),batch_size)
-        model.config.train_learning_rate = learning_rate
+        model.config.train_learning_rate = learning_rate # perf seems bad if lr is changed in transfer
+        model.config.use_gpu = use_gpu
+        model.config.train_reduce_lr = {'factor': 0.1, 'patience': 10, 'min_delta': 0}
+        print(f'{model.config=}')
+
+        save_json(vars(model.config), os.sep.join([target_directory, model_name, 'config.json']))
 
     median_size = calculate_extents(list(Y_trn), np.mean)
     fov = np.array(model._axes_tile_overlap('YX'))
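The rewritten `else:` branch implements transfer learning by copying the pretrained model folder into the new model directory, reloading it with `config=None` so StarDist reads the copied configuration and weights, then overriding a few training fields and persisting the result. A condensed sketch of that flow, assuming the placeholder paths and hyperparameters point at a real StarDist model folder:

```python
import os
from distutils.dir_util import copy_tree
from csbdeep.utils import save_json
from stardist.models import StarDist2D

# Placeholder inputs (assumptions, not values from the package):
pretrained = '/path/to/pretrained_model'
target_directory = '/path/to/models'
model_name = 'transfer_demo'
epochs, batch_size, use_gpu = 50, 8, False

model_dir = os.sep.join([target_directory, model_name])
copy_tree(pretrained, model_dir)  # seed the new model with the pretrained weights

# config=None tells StarDist to load the config/weights found in model_dir.
model = StarDist2D(None, name=model_name, basedir=target_directory)
model.config.train_epochs = epochs
model.config.train_batch_size = batch_size
model.config.use_gpu = use_gpu

# Persist the adjusted config next to the weights so later reloads see it.
save_json(vars(model.config), os.sep.join([model_dir, 'config.json']))
```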
celldetective/scripts/train_signal_model.py

@@ -12,6 +12,7 @@ import numpy as np
 import gc
 from art import tprint
 from celldetective.signals import SignalDetectionModel
+from celldetective.io import locate_signal_model
 
 tprint("Train")
 

@@ -36,14 +37,46 @@ else:
     print('The configuration path is not valid. Abort.')
     os.abort()
 
+all_classes = []
+for d in threshold_instructions["ds"]:
+    datasets = glob(d+os.sep+"*.npy")
+    for dd in datasets:
+        data = np.load(dd, allow_pickle=True)
+        classes = np.unique([ddd["class"] for ddd in data])
+        all_classes.extend(classes)
+all_classes = np.unique(all_classes)
+print(all_classes,len(all_classes))
+
+n_classes = len(all_classes)
 
 model_params = {k:threshold_instructions[k] for k in ('pretrained', 'model_signal_length', 'channel_option', 'n_channels', 'label') if k in threshold_instructions}
+model_params.update({'n_classes': n_classes})
+
 train_params = {k:threshold_instructions[k] for k in ('model_name', 'target_directory', 'channel_option','recompile_pretrained', 'test_split', 'augment', 'epochs', 'learning_rate', 'batch_size', 'validation_split','normalization_percentile','normalization_values','normalization_clip') if k in threshold_instructions}
 
 print(f'model params {model_params}')
 print(f'train params {train_params}')
 
 model = SignalDetectionModel(**model_params)
+print(threshold_instructions['ds'])
 model.fit_from_directory(threshold_instructions['ds'], **train_params)
 
+
+# if neighborhood of interest in training instructions, write it in config!
+if 'neighborhood_of_interest' in threshold_instructions:
+    if threshold_instructions['neighborhood_of_interest'] is not None:
+
+        model_path = locate_signal_model(threshold_instructions['model_name'], path=None, pairs=True)
+        complete_path = model_path #+model
+        complete_path = rf"{complete_path}"
+        model_config_path = os.sep.join([complete_path,'config_input.json'])
+        model_config_path = rf"{model_config_path}"
+
+        f = open(model_config_path)
+        config = json.load(f)
+        config.update({'neighborhood_of_interest': threshold_instructions['neighborhood_of_interest'], 'reference_population': threshold_instructions['reference_population'], 'neighbor_population': threshold_instructions['neighbor_population']})
+        json_string = json.dumps(config)
+        with open(model_config_path, 'w') as outfile:
+            outfile.write(json_string)
+
 print('Done.')
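The class-counting preamble added to this script derives `n_classes` from the data rather than hard-coding it, by scanning every `.npy` annotation set for distinct `"class"` values. The same scan as a reusable function (the function name and example paths are illustrative):

```python
import os
from glob import glob
import numpy as np

def count_signal_classes(dataset_dirs):
    # Collect the distinct class labels across all .npy annotation files;
    # each file holds an array of dicts with a "class" entry, as in the script.
    all_classes = []
    for d in dataset_dirs:
        for path in glob(d + os.sep + "*.npy"):
            data = np.load(path, allow_pickle=True)
            all_classes.extend(np.unique([sample["class"] for sample in data]))
    return np.unique(all_classes)

# n_classes = len(count_signal_classes(['/path/to/set1', '/path/to/set2']))
```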