celldetective 1.3.7__py3-none-any.whl → 1.3.7.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,270 @@
1
+ from multiprocessing import Process
2
+ import time
3
+ import os
4
+ import shutil
5
+ from glob import glob
6
+ import json
7
+ from tqdm import tqdm
8
+ import numpy as np
9
+ import random
10
+
11
+ from celldetective.utils import load_image_dataset, augmenter, interpolate_nan
12
+ from celldetective.io import normalize_multichannel
13
+ from stardist import fill_label_holes
14
+ from art import tprint
15
+ from distutils.dir_util import copy_tree
16
+ from csbdeep.utils import save_json
17
+
18
+
19
class TrainSegModelProcess(Process):

	"""
	Background process that trains a cell segmentation model (StarDist or
	Cellpose) from a JSON instruction file.

	Attributes set through ``process_args`` must include at least
	``instructions`` (path to the training-instruction JSON) and ``use_gpu``.
	Completion or failure is reported to the parent through ``queue``
	("finished" / "error").
	"""

	def __init__(self, queue=None, process_args=None, *args, **kwargs):

		super().__init__(*args, **kwargs)

		self.queue = queue
		self.abort = False  # set when setup fails so later stages are skipped

		if process_args is not None:
			for key, value in process_args.items():
				setattr(self, key, value)

		tprint("Train segmentation")
		self.read_instructions()
		if not self.abort:
			# BUGFIX: previously these ran even after a failed instruction load,
			# crashing on missing attributes instead of reporting "error".
			self.extract_training_params()
			self.load_dataset()
			self.split_test_train()

		self.sum_done = 0
		self.t0 = time.time()

	def read_instructions(self):

		"""Load the training-instruction JSON into ``self.training_instructions``."""

		if os.path.exists(self.instructions):
			with open(self.instructions, 'r') as f:
				self.training_instructions = json.load(f)
		else:
			print('Training instructions could not be found. Abort.')
			self.abort_process()

	def run(self):

		"""Child-process entry point: dispatch to the requested trainer."""

		if self.abort:
			return

		if self.model_type == "cellpose":
			self.train_cellpose_model()
		elif self.model_type == "stardist":
			self.train_stardist_model()

		self.queue.put("finished")
		self.queue.close()

	def train_stardist_model(self):

		"""Train (or fine-tune) a StarDist 2D model and write its input config."""

		from stardist import calculate_extents, gputools_available
		from stardist.models import Config2D, StarDist2D

		n_rays = 32
		print(gputools_available())

		n_channel = self.X_trn[0].shape[-1]

		# Predict on subsampled grid for increased efficiency and larger field of view
		grid = (2, 2)
		conf = Config2D(
			n_rays=n_rays,
			grid=grid,
			use_gpu=self.use_gpu,
			n_channel_in=n_channel,
			train_learning_rate=self.learning_rate,
			train_patch_size=(256, 256),
			train_epochs=self.epochs,
			train_reduce_lr={'factor': 0.1, 'patience': 30, 'min_delta': 0},
			train_batch_size=self.batch_size,
			train_steps_per_epoch=int(self.augmentation_factor * len(self.X_trn)),
		)

		if self.use_gpu:
			from csbdeep.utils.tf import limit_gpu_memory
			limit_gpu_memory(None, allow_growth=True)

		model_dir = os.sep.join([self.target_directory, self.model_name])

		if self.pretrained is None:
			model = StarDist2D(conf, name=self.model_name, basedir=self.target_directory)
		else:
			# Transfer learning: copy the pretrained model folder in place while
			# preserving the new training instructions across the copy.
			os.rename(self.instructions, os.sep.join([model_dir, 'temp.json']))
			# BUGFIX: distutils is deprecated and removed in Python 3.12;
			# shutil.copytree(dirs_exist_ok=True) is the supported equivalent.
			shutil.copytree(self.pretrained, model_dir, dirs_exist_ok=True)

			if os.path.exists(os.sep.join([model_dir, 'training_instructions.json'])):
				os.remove(os.sep.join([model_dir, 'training_instructions.json']))
			if os.path.exists(os.sep.join([model_dir, 'config_input.json'])):
				os.remove(os.sep.join([model_dir, 'config_input.json']))
			if os.path.exists(os.sep.join([model_dir, 'logs' + os.sep])):
				shutil.rmtree(os.sep.join([model_dir, 'logs']))
			os.rename(os.sep.join([model_dir, 'temp.json']), os.sep.join([model_dir, 'training_instructions.json']))

			model = StarDist2D(None, name=self.model_name, basedir=self.target_directory)
			model.config.train_epochs = self.epochs
			model.config.train_batch_size = min(len(self.X_trn), self.batch_size)
			model.config.train_learning_rate = self.learning_rate  # perf seems bad if lr is changed in transfer
			model.config.use_gpu = self.use_gpu
			model.config.train_reduce_lr = {'factor': 0.1, 'patience': 10, 'min_delta': 0}
			print(f'{model.config=}')

			save_json(vars(model.config), os.sep.join([model_dir, 'config.json']))

		median_size = calculate_extents(list(self.Y_trn), np.mean)
		fov = np.array(model._axes_tile_overlap('YX'))
		print(f"median object size: {median_size}")
		print(f"network field of view : {fov}")
		if any(median_size > fov):
			print("WARNING: median object size larger than field of view of the neural network.")

		if self.augmentation_factor == 1.0:
			model.train(self.X_trn, self.Y_trn, validation_data=(self.X_val, self.Y_val))
		else:
			model.train(self.X_trn, self.Y_trn, validation_data=(self.X_val, self.Y_val), augmenter=augmenter)
		model.optimize_thresholds(self.X_val, self.Y_val)

		config_inputs = {"channels": self.target_channels, 'normalization_percentile': self.normalization_percentile,
						 'normalization_clip': self.normalization_clip, 'normalization_values': self.normalization_values,
						 'model_type': 'stardist', 'spatial_calibration': self.spatial_calibration,
						 'dataset': {'train': self.files_train, 'validation': self.files_val}}

		json_input_config = json.dumps(config_inputs, indent=4)
		with open(os.sep.join([model_dir, "config_input.json"]), "w") as outfile:
			outfile.write(json_input_config)

	def train_cellpose_model(self):

		"""Augment the training set, train a Cellpose model and write its input config."""

		# Pre-training augmentation: sample (with replacement) augmented copies.
		X_aug = []
		Y_aug = []
		n_val = max(1, int(round(self.augmentation_factor * len(self.X_trn))))
		indices = random.choices(list(np.arange(len(self.X_trn))), k=n_val)
		print('Performing image augmentation pre-training...')
		for i in tqdm(indices):
			x_aug, y_aug = augmenter(self.X_trn[i], self.Y_trn[i])
			X_aug.append(x_aug)
			Y_aug.append(y_aug)

		# Channel axis in front for cellpose
		X_aug = [np.moveaxis(x, -1, 0) for x in X_aug]
		self.X_val = [np.moveaxis(x, -1, 0) for x in self.X_val]
		print('number of augmented images: %3d' % len(X_aug))

		from cellpose.models import CellposeModel
		from cellpose.io import logger_setup

		if not self.use_gpu:
			# NOTE: CellposeModel(gpu=False) already selects the CPU; the former
			# unused `device = torch.device("cpu")` assignment has been removed.
			print('Using CPU for training...')
		else:
			print('Using GPU for training...')

		logger, log_file = logger_setup()
		# BUGFIX: single f-string instead of mixing an f-string with a second
		# positional print argument.
		print(f'Pretrained model: {self.pretrained}')
		if self.pretrained is not None:
			pretrained_path = os.sep.join([self.pretrained, os.path.split(self.pretrained)[-1]])
		else:
			pretrained_path = self.pretrained

		model = CellposeModel(gpu=self.use_gpu, model_type=None, pretrained_model=pretrained_path, diam_mean=30.0, nchan=X_aug[0].shape[0],)
		model.train(train_data=X_aug, train_labels=Y_aug, normalize=False, channels=None, batch_size=self.batch_size,
					min_train_masks=1, save_path=self.target_directory + os.sep + self.model_name, n_epochs=self.epochs,
					model_name=self.model_name, learning_rate=self.learning_rate, test_data=self.X_val, test_labels=self.Y_val)

		# Cellpose saves the weights under <model>/models/<file>; move them up one level.
		file_to_move = glob(os.sep.join([self.target_directory, self.model_name, 'models', '*']))[0]
		shutil.move(file_to_move, os.sep.join([self.target_directory, self.model_name, '']) + os.path.split(file_to_move)[-1])
		os.rmdir(os.sep.join([self.target_directory, self.model_name, 'models']))

		diameter = model.diam_labels

		if self.pretrained is not None and os.path.split(self.pretrained)[-1] == 'CP_nuclei':
			standard_diameter = 17.0
		else:
			standard_diameter = 30.0

		input_spatial_calibration = self.spatial_calibration  # *diameter / standard_diameter

		config_inputs = {"channels": self.target_channels, "diameter": standard_diameter, 'cellprob_threshold': 0., 'flow_threshold': 0.4,
						 'normalization_percentile': self.normalization_percentile, 'normalization_clip': self.normalization_clip,
						 'normalization_values': self.normalization_values, 'model_type': 'cellpose',
						 'spatial_calibration': input_spatial_calibration, 'dataset': {'train': self.files_train, 'validation': self.files_val}}
		json_input_config = json.dumps(config_inputs, indent=4)
		with open(os.sep.join([self.target_directory, self.model_name, "config_input.json"]), "w") as outfile:
			outfile.write(json_input_config)

	def split_test_train(self):

		"""Randomly split the dataset into training and validation subsets."""

		if not len(self.X) > 1:
			print("Not enough training data")
			self.abort_process()
			return  # BUGFIX: previously fell through and crashed on the split below

		rng = np.random.RandomState()
		ind = rng.permutation(len(self.X))
		n_val = max(1, int(round(self.validation_split * len(ind))))
		ind_train, ind_val = ind[:-n_val], ind[-n_val:]
		self.X_val, self.Y_val = [self.X[i] for i in ind_val], [self.Y[i] for i in ind_val]
		self.X_trn, self.Y_trn = [self.X[i] for i in ind_train], [self.Y[i] for i in ind_train]

		self.files_train = [self.filenames[i] for i in ind_train]
		self.files_val = [self.filenames[i] for i in ind_val]

		print('number of images: %3d' % len(self.X))
		print('- training: %3d' % len(self.X_trn))
		print('- validation: %3d' % len(self.X_val))

	def extract_training_params(self):

		"""Unpack the instruction dictionary into instance attributes."""

		self.model_name = self.training_instructions['model_name']
		self.target_directory = self.training_instructions['target_directory']
		self.model_type = self.training_instructions['model_type']
		self.pretrained = self.training_instructions['pretrained']

		self.datasets = self.training_instructions['ds']

		self.target_channels = self.training_instructions['channel_option']
		self.normalization_percentile = self.training_instructions['normalization_percentile']
		self.normalization_clip = self.training_instructions['normalization_clip']
		self.normalization_values = self.training_instructions['normalization_values']
		self.spatial_calibration = self.training_instructions['spatial_calibration']

		self.validation_split = self.training_instructions['validation_split']
		self.augmentation_factor = self.training_instructions['augmentation_factor']

		self.learning_rate = self.training_instructions['learning_rate']
		self.epochs = self.training_instructions['epochs']
		self.batch_size = self.training_instructions['batch_size']

	def load_dataset(self):

		"""Load, normalize (per channel), NaN-interpolate and hole-fill the dataset."""

		print(f'Datasets: {self.datasets}')
		self.X, self.Y, self.filenames = load_image_dataset(self.datasets, self.target_channels, train_spatial_calibration=self.spatial_calibration,
															mask_suffix='labelled')
		print('Dataset loaded...')

		# Per channel, normalization is either percentile-based or value-based;
		# the unused mode is passed as None.
		self.values = []
		self.percentiles = []
		for k in range(len(self.normalization_percentile)):
			if self.normalization_percentile[k]:
				self.percentiles.append(self.normalization_values[k])
				self.values.append(None)
			else:
				self.percentiles.append(None)
				self.values.append(self.normalization_values[k])

		self.X = [normalize_multichannel(x, **{"percentiles": self.percentiles, 'values': self.values, 'clip': self.normalization_clip}) for x in self.X]

		# Interpolate NaNs channel-wise (channel axis last).
		for k in range(len(self.X)):
			x = self.X[k].copy()
			x_interp = np.moveaxis([interpolate_nan(x[:, :, c].copy()) for c in range(x.shape[-1])], 0, -1)
			self.X[k] = x_interp

		self.Y = [fill_label_holes(y) for y in tqdm(self.Y)]

	def end_process(self):

		"""Report successful completion to the parent and stop the process if running."""

		if self.queue is not None:
			self.queue.put("finished")
		if self.is_alive():
			self.terminate()

	def abort_process(self):

		"""Report an error to the parent and stop the process if it is running."""

		self.abort = True
		if self.queue is not None:
			self.queue.put("error")
		# BUGFIX: terminate() raises if the process was never started (e.g. when
		# called from __init__); only terminate a live process.
		if self.is_alive():
			self.terminate()
@@ -0,0 +1,108 @@
1
+ from multiprocessing import Process
2
+ import time
3
+ import os
4
+ import json
5
+ from glob import glob
6
+ import numpy as np
7
+ from art import tprint
8
+ from celldetective.signals import SignalDetectionModel
9
+ from celldetective.io import locate_signal_model
10
+
11
+
12
class TrainSignalModelProcess(Process):

	"""
	Background process that trains a signal-detection model from a JSON
	instruction file.

	Attributes set through ``process_args`` must include at least
	``instructions`` (path to the training-instruction JSON). Completion or
	failure is reported to the parent through ``queue`` ("finished" / "error").
	"""

	def __init__(self, queue=None, process_args=None, *args, **kwargs):

		super().__init__(*args, **kwargs)

		self.queue = queue
		self.abort = False  # set when setup fails so later stages are skipped

		if process_args is not None:
			for key, value in process_args.items():
				setattr(self, key, value)

		# BUGFIX: the banner previously said "Train segmentation" (copy-paste
		# from the segmentation trainer).
		tprint("Train signal model")
		self.read_instructions()
		if not self.abort:
			self.extract_training_params()

		self.sum_done = 0
		self.t0 = time.time()

	def read_instructions(self):

		"""Load instructions, count classes across datasets, build model/train kwargs."""

		if os.path.exists(self.instructions):
			with open(self.instructions, 'r') as f:
				self.training_instructions = json.load(f)
		else:
			print('Training instructions could not be found. Abort.')
			self.abort_process()
			return  # BUGFIX: previously fell through and crashed on missing attributes

		# Scan every .npy dataset for the union of annotated classes.
		all_classes = []
		for ds_dir in self.training_instructions["ds"]:
			for ds_file in glob(ds_dir + os.sep + "*.npy"):
				data = np.load(ds_file, allow_pickle=True)
				classes = np.unique([entry["class"] for entry in data])
				all_classes.extend(classes)
		n_classes = len(np.unique(all_classes))

		self.model_params = {k: self.training_instructions[k] for k in ('pretrained', 'model_signal_length', 'channel_option', 'n_channels', 'label') if k in self.training_instructions}
		self.model_params.update({'n_classes': n_classes})
		self.train_params = {k: self.training_instructions[k] for k in ('model_name', 'target_directory', 'channel_option', 'recompile_pretrained', 'test_split', 'augment', 'epochs', 'learning_rate', 'batch_size', 'validation_split', 'normalization_percentile', 'normalization_values', 'normalization_clip') if k in self.training_instructions}

	def neighborhood_postprocessing(self):

		"""If a neighborhood of interest was configured, write it into the model's input config."""

		if self.training_instructions.get('neighborhood_of_interest') is not None:

			model_path = locate_signal_model(self.training_instructions['model_name'], path=None, pairs=True)
			model_config_path = os.sep.join([model_path, 'config_input.json'])

			# BUGFIX: the original opened the config without closing the handle.
			with open(model_config_path) as f:
				config = json.load(f)
			config.update({'neighborhood_of_interest': self.training_instructions['neighborhood_of_interest'],
						   'reference_population': self.training_instructions['reference_population'],
						   'neighbor_population': self.training_instructions['neighbor_population']})
			with open(model_config_path, 'w') as outfile:
				outfile.write(json.dumps(config))

	def run(self):

		"""Child-process entry point: fit the model, post-process, report completion."""

		if self.abort:
			return

		model = SignalDetectionModel(**self.model_params)
		model.fit_from_directory(self.training_instructions['ds'], **self.train_params)
		self.neighborhood_postprocessing()
		self.queue.put("finished")
		self.queue.close()

	def extract_training_params(self):

		"""Derive n_channels / augment / test_split from the raw instructions."""

		self.training_instructions.update({'n_channels': len(self.training_instructions['channel_option'])})
		# Augment only when the requested factor actually enlarges the dataset.
		self.training_instructions.update({'augment': self.training_instructions['augmentation_factor'] > 1.0})
		self.training_instructions.update({'test_split': 0.})

	def end_process(self):

		"""Report successful completion to the parent and stop the process if running."""

		if self.queue is not None:
			self.queue.put("finished")
		if self.is_alive():
			self.terminate()

	def abort_process(self):

		"""Report an error to the parent and stop the process if it is running."""

		self.abort = True
		if self.queue is not None:
			self.queue.put("error")
		# BUGFIX: terminate() raises if the process was never started (e.g. when
		# called from __init__); only terminate a live process.
		if self.is_alive():
			self.terminate()
celldetective/io.py CHANGED
@@ -58,6 +58,7 @@ def extract_experiment_from_well(well_path):
58
58
  >>> well_path = "/path/to/experiment/plate/well"
59
59
  >>> extract_experiment_from_well(well_path)
60
60
  '/path/to/experiment'
61
+
61
62
  """
62
63
 
63
64
  if not well_path.endswith(os.sep):
@@ -94,6 +95,7 @@ def extract_well_from_position(pos_path):
94
95
  >>> pos_path = "/path/to/experiment/plate/well/position"
95
96
  >>> extract_well_from_position(pos_path)
96
97
  '/path/to/experiment/plate/well/'
98
+
97
99
  """
98
100
 
99
101
  if not pos_path.endswith(os.sep):
@@ -129,6 +131,7 @@ def extract_experiment_from_position(pos_path):
129
131
  >>> pos_path = "/path/to/experiment/plate/well/position"
130
132
  >>> extract_experiment_from_position(pos_path)
131
133
  '/path/to/experiment'
134
+
132
135
  """
133
136
 
134
137
  pos_path = pos_path.replace(os.sep, '/')
@@ -187,6 +190,7 @@ def collect_experiment_metadata(pos_path=None, well_path=None):
187
190
  >>> metadata = collect_experiment_metadata(well_path=well_path)
188
191
  >>> metadata["concentration"]
189
192
  10.0
193
+
190
194
  """
191
195
 
192
196
  if pos_path is not None:
@@ -289,6 +293,7 @@ def get_config(experiment):
289
293
  >>> config_path = get_config(experiment)
290
294
  >>> print(config_path)
291
295
  '/path/to/experiment/config.ini'
296
+
292
297
  """
293
298
 
294
299
  if not experiment.endswith(os.sep):
@@ -336,6 +341,7 @@ def get_spatial_calibration(experiment):
336
341
  >>> calibration = get_spatial_calibration(experiment)
337
342
  >>> print(calibration)
338
343
  0.325 # pixels-to-micrometers conversion factor
344
+
339
345
  """
340
346
 
341
347
  config = get_config(experiment)
@@ -380,6 +386,7 @@ def get_temporal_calibration(experiment):
380
386
  >>> calibration = get_temporal_calibration(experiment)
381
387
  >>> print(calibration)
382
388
  0.5 # frames-to-minutes conversion factor
389
+
383
390
  """
384
391
 
385
392
  config = get_config(experiment)
@@ -435,6 +442,7 @@ def get_experiment_concentrations(experiment, dtype=str):
435
442
  >>> concentrations = get_experiment_concentrations(experiment, dtype=float)
436
443
  >>> print(concentrations)
437
444
  [0.1, 0.2, 0.5, 1.0]
445
+
438
446
  """
439
447
 
440
448
  config = get_config(experiment)
@@ -489,6 +497,7 @@ def get_experiment_cell_types(experiment, dtype=str):
489
497
  >>> cell_types = get_experiment_cell_types(experiment, dtype=str)
490
498
  >>> print(cell_types)
491
499
  ['TypeA', 'TypeB', 'TypeC', 'TypeD']
500
+
492
501
  """
493
502
 
494
503
  config = get_config(experiment)
@@ -540,6 +549,7 @@ def get_experiment_antibodies(experiment, dtype=str):
540
549
 
541
550
  >>> get_experiment_antibodies("path/to/experiment2", dtype=int)
542
551
  array([0, 1, 2])
552
+
543
553
  """
544
554
 
545
555
  config = get_config(experiment)
@@ -594,6 +604,7 @@ def get_experiment_pharmaceutical_agents(experiment, dtype=str):
594
604
  >>> antibodies = get_experiment_antibodies(experiment, dtype=str)
595
605
  >>> print(antibodies)
596
606
  ['AntibodyA', 'AntibodyB', 'AntibodyC', 'AntibodyD']
607
+
597
608
  """
598
609
 
599
610
  config = get_config(experiment)
@@ -702,6 +713,7 @@ def extract_well_name_and_number(well):
702
713
  >>> well_path = "another/path/W1"
703
714
  >>> extract_well_name_and_number(well_path)
704
715
  ('W1', 1)
716
+
705
717
  """
706
718
 
707
719
  split_well_path = well.split(os.sep)
@@ -740,6 +752,7 @@ def extract_position_name(pos):
740
752
  >>> pos_path = "another/path/positionA"
741
753
  >>> extract_position_name(pos_path)
742
754
  'positionA'
755
+
743
756
  """
744
757
 
745
758
  split_pos_path = pos.split(os.sep)
@@ -890,6 +903,7 @@ def get_position_movie_path(pos, prefix=''):
890
903
  >>> pos_path = "nonexistent/path"
891
904
  >>> get_position_movie_path(pos_path)
892
905
  None
906
+
893
907
  """
894
908
 
895
909
 
@@ -961,6 +975,7 @@ def load_experiment_tables(experiment, population='targets', well_option='*', po
961
975
  Use pickle files for faster loading:
962
976
 
963
977
  >>> df = load_experiment_tables("experiment_01", load_pickle=True)
978
+
964
979
  """
965
980
 
966
981
  config = get_config(experiment)
@@ -1171,6 +1186,7 @@ def locate_labels(position, population='target', frames=None):
1171
1186
  Load multiple specific frames:
1172
1187
 
1173
1188
  >>> labels = locate_labels("/path/to/position", population="target", frames=[0, 1, 2])
1189
+
1174
1190
  """
1175
1191
 
1176
1192
  if not position.endswith(os.sep):
@@ -1241,6 +1257,7 @@ def fix_missing_labels(position, population='target', prefix='Aligned'):
1241
1257
  -------
1242
1258
  None
1243
1259
  The function creates new label files in the corresponding folder for any frames missing label files.
1260
+
1244
1261
  """
1245
1262
 
1246
1263
  if not position.endswith(os.sep):
@@ -1414,6 +1431,7 @@ def auto_load_number_of_frames(stack_path):
1414
1431
  >>> frames = auto_load_number_of_frames(None)
1415
1432
  >>> print(frames)
1416
1433
  None
1434
+
1417
1435
  """
1418
1436
 
1419
1437
  if stack_path is None:
@@ -1510,6 +1528,7 @@ def parse_isotropic_radii(string):
1510
1528
  - It identifies ranges using square brackets and assumes that ranges are always
1511
1529
  two consecutive values.
1512
1530
  - Non-integer sections of the string are ignored.
1531
+
1513
1532
  """
1514
1533
 
1515
1534
  sections = re.split(',| ', string)
@@ -1618,6 +1637,7 @@ def interpret_tracking_configuration(config):
1618
1637
 
1619
1638
  >>> interpret_tracking_configuration(None)
1620
1639
  '/path/to/default/config.json'
1640
+
1621
1641
  """
1622
1642
 
1623
1643
  if isinstance(config, str):
@@ -1792,6 +1812,7 @@ def locate_signal_model(name, path=None, pairs=False):
1792
1812
 
1793
1813
  >>> locate_signal_model("remote_model")
1794
1814
  'path/to/celldetective/models/signal_detection/remote_model/'
1815
+
1795
1816
  """
1796
1817
 
1797
1818
  main_dir = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], "celldetective"])
@@ -1859,6 +1880,7 @@ def locate_pair_signal_model(name, path=None):
1859
1880
 
1860
1881
  >>> locate_pair_signal_model("custom_model", path="/additional/models/")
1861
1882
  '/additional/models/custom_model/'
1883
+
1862
1884
  """
1863
1885
 
1864
1886
 
@@ -1937,6 +1959,7 @@ def relabel_segmentation(labels, df, exclude_nans=True, column_labels={'track':
1937
1959
  ... }
1938
1960
  >>> new_labels = relabel_segmentation(labels, df, column_labels=column_labels, exclude_nans=True)
1939
1961
  Done.
1962
+
1940
1963
  """
1941
1964
 
1942
1965
  n_threads = threads
@@ -2037,6 +2060,7 @@ def control_tracks(position, prefix="Aligned", population="target", relabel=True
2037
2060
  Example
2038
2061
  -------
2039
2062
  >>> control_tracks("/path/to/data/position_1", prefix="Aligned", population="target", relabel=True, flush_memory=True, threads=4)
2063
+
2040
2064
  """
2041
2065
 
2042
2066
  if not position.endswith(os.sep):
@@ -2089,6 +2113,7 @@ def tracks_to_btrack(df, exclude_nans=False):
2089
2113
  Example
2090
2114
  -------
2091
2115
  >>> data, properties, graph = tracks_to_btrack(df, exclude_nans=True)
2116
+
2092
2117
  """
2093
2118
 
2094
2119
  graph = {}
@@ -2268,7 +2293,7 @@ def view_tracks_in_napari(position, population, stack=None, labels=None, relabel
2268
2293
  new_cell['TRACK_ID'] = value_under
2269
2294
  df = pd.concat([df, new_cell], ignore_index=True)
2270
2295
 
2271
- relabel = np.amax(df['TRACK_ID'].unique()) + 1
2296
+ relabel = np.amax(viewer.layers['segmentation'].data) + 1
2272
2297
  for f in viewer.layers['segmentation'].data[int(frame):]:
2273
2298
  if target_track_id!=0:
2274
2299
  f[np.where(f==target_track_id)] = relabel