celldetective 1.3.3.post1__py3-none-any.whl → 1.3.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. celldetective/__main__.py +30 -4
  2. celldetective/_version.py +1 -1
  3. celldetective/extra_properties.py +21 -0
  4. celldetective/filters.py +15 -2
  5. celldetective/gui/InitWindow.py +28 -34
  6. celldetective/gui/analyze_block.py +3 -498
  7. celldetective/gui/classifier_widget.py +1 -1
  8. celldetective/gui/control_panel.py +98 -27
  9. celldetective/gui/generic_signal_plot.py +35 -18
  10. celldetective/gui/gui_utils.py +143 -2
  11. celldetective/gui/layouts.py +7 -6
  12. celldetective/gui/measurement_options.py +3 -11
  13. celldetective/gui/plot_measurements.py +5 -13
  14. celldetective/gui/plot_signals_ui.py +30 -30
  15. celldetective/gui/process_block.py +61 -103
  16. celldetective/gui/signal_annotator.py +50 -32
  17. celldetective/gui/signal_annotator2.py +7 -4
  18. celldetective/gui/styles.py +13 -0
  19. celldetective/gui/survival_ui.py +8 -21
  20. celldetective/gui/tableUI.py +1 -2
  21. celldetective/gui/thresholds_gui.py +0 -6
  22. celldetective/gui/viewers.py +1 -5
  23. celldetective/io.py +31 -4
  24. celldetective/measure.py +8 -5
  25. celldetective/neighborhood.py +0 -2
  26. celldetective/scripts/measure_cells.py +21 -9
  27. celldetective/signals.py +78 -66
  28. celldetective/tracking.py +19 -13
  29. {celldetective-1.3.3.post1.dist-info → celldetective-1.3.4.post1.dist-info}/METADATA +2 -1
  30. {celldetective-1.3.3.post1.dist-info → celldetective-1.3.4.post1.dist-info}/RECORD +35 -35
  31. tests/test_qt.py +5 -3
  32. {celldetective-1.3.3.post1.dist-info → celldetective-1.3.4.post1.dist-info}/LICENSE +0 -0
  33. {celldetective-1.3.3.post1.dist-info → celldetective-1.3.4.post1.dist-info}/WHEEL +0 -0
  34. {celldetective-1.3.3.post1.dist-info → celldetective-1.3.4.post1.dist-info}/entry_points.txt +0 -0
  35. {celldetective-1.3.3.post1.dist-info → celldetective-1.3.4.post1.dist-info}/top_level.txt +0 -0
@@ -23,6 +23,7 @@ class Styles(object):
23
23
 
24
24
  def init_button_styles(self):
25
25
 
26
+
26
27
  self.button_style_sheet = '''
27
28
  QPushButton {
28
29
  background-color: #1565c0;
@@ -147,6 +148,18 @@ class Styles(object):
147
148
  }
148
149
  '''
149
150
 
151
+ self.menu_check_style = '''
152
+ QCheckBox {
153
+ font-size: 10px;
154
+ padding-left: 10px;
155
+ padding-top: 5px;
156
+ }
157
+ QCheckBox::indicator:unchecked:hover {
158
+ background-color : rgba(189, 189, 189, 1);
159
+ opacity : 0.3;
160
+ }
161
+ '''
162
+
150
163
  self.button_add = '''
151
164
  QPushButton {
152
165
  background-color: transparent;
@@ -39,8 +39,8 @@ class ConfigSurvival(QWidget, Styles):
39
39
  self.float_validator = QDoubleValidator()
40
40
  self.auto_close = False
41
41
 
42
- self.well_option = self.parent_window.parent_window.well_list.currentIndex()
43
- self.position_option = self.parent_window.parent_window.position_list.currentIndex()
42
+ self.well_option = self.parent_window.parent_window.well_list.getSelectedIndices()
43
+ self.position_option = self.parent_window.parent_window.position_list.getSelectedIndices()
44
44
  self.interpret_pos_location()
45
45
  #self.config_path = self.exp_dir + self.config_name
46
46
 
@@ -62,15 +62,10 @@ class ConfigSurvival(QWidget, Styles):
62
62
 
63
63
  """
64
64
 
65
- if self.well_option==len(self.wells):
66
- self.well_indices = np.arange(len(self.wells))
67
- else:
68
- self.well_indices = np.array([self.well_option],dtype=int)
69
-
70
- if self.position_option==0:
65
+ self.well_indices = self.parent_window.parent_window.well_list.getSelectedIndices()
66
+ self.position_indices = self.parent_window.parent_window.position_list.getSelectedIndices()
67
+ if self.position_indices==[]:
71
68
  self.position_indices = None
72
- else:
73
- self.position_indices = np.array([self.position_option],dtype=int)
74
69
 
75
70
 
76
71
  def populate_widget(self):
@@ -243,18 +238,10 @@ class ConfigSurvival(QWidget, Styles):
243
238
 
244
239
  """
245
240
 
246
- self.well_option = self.parent_window.parent_window.well_list.currentIndex()
247
- if self.well_option==len(self.wells):
248
- wo = '*'
249
- else:
250
- wo = self.well_option
251
- self.position_option = self.parent_window.parent_window.position_list.currentIndex()
252
- if self.position_option==0:
253
- po = '*'
254
- else:
255
- po = self.position_option - 1
241
+ self.well_option = self.parent_window.parent_window.well_list.getSelectedIndices()
242
+ self.position_option = self.parent_window.parent_window.position_list.getSelectedIndices()
256
243
 
257
- self.df, self.df_pos_info = load_experiment_tables(self.exp_dir, well_option=wo, position_option=po, population=self.cbs[0].currentText(), return_pos_info=True)
244
+ self.df, self.df_pos_info = load_experiment_tables(self.exp_dir, well_option=self.well_option, position_option=self.position_option, population=self.cbs[0].currentText(), return_pos_info=True)
258
245
  if self.df is None:
259
246
  msgBox = QMessageBox()
260
247
  msgBox.setIcon(QMessageBox.Warning)
@@ -1,5 +1,5 @@
1
1
  from PyQt5.QtWidgets import QRadioButton, QButtonGroup, QMainWindow, QTableView, QAction, QMenu,QFileDialog, QLineEdit, QHBoxLayout, QWidget, QPushButton, QVBoxLayout, QComboBox, QLabel, QCheckBox, QMessageBox
2
- from PyQt5.QtCore import Qt, QAbstractTableModel
2
+ from PyQt5.QtCore import Qt
3
3
  from PyQt5.QtGui import QBrush, QColor, QDoubleValidator
4
4
  import pandas as pd
5
5
  import matplotlib.pyplot as plt
@@ -17,7 +17,6 @@ from superqt import QColormapComboBox, QLabeledSlider, QSearchableComboBox
17
17
  from superqt.fonticon import icon
18
18
  from fonticon_mdi6 import MDI6
19
19
  from math import floor
20
- import re
21
20
 
22
21
  from matplotlib import colormaps
23
22
 
@@ -1,12 +1,6 @@
1
- import math
2
-
3
- import skimage
4
1
  from PyQt5.QtWidgets import QAction, QMenu, QMainWindow, QMessageBox, QLabel, QWidget, QFileDialog, QHBoxLayout, \
5
2
  QGridLayout, QLineEdit, QScrollArea, QVBoxLayout, QComboBox, QPushButton, QApplication, QPushButton, QRadioButton, QButtonGroup
6
3
  from PyQt5.QtGui import QDoubleValidator, QIntValidator
7
- from matplotlib.patches import Circle
8
- from scipy import ndimage
9
- from skimage.morphology import disk
10
4
 
11
5
  from celldetective.filters import std_filter, gauss_filter
12
6
  from celldetective.gui.gui_utils import center_window, FigureCanvas, ListWidget, FilterChoice, color_from_class, help_generic
@@ -22,11 +22,7 @@ from fonticon_mdi6 import MDI6
22
22
  from matplotlib_scalebar.scalebar import ScaleBar
23
23
  import gc
24
24
  from celldetective.utils import mask_edges
25
- from scipy.ndimage import shift, grey_dilation
26
- from skimage.feature import blob_dog
27
- import math
28
- from skimage.morphology import disk
29
- from matplotlib.patches import Circle
25
+ from scipy.ndimage import shift
30
26
 
31
27
  class StackVisualizer(QWidget, Styles):
32
28
 
celldetective/io.py CHANGED
@@ -69,7 +69,7 @@ def collect_experiment_metadata(pos_path=None, well_path=None):
69
69
  antibodies = get_experiment_antibodies(experiment)
70
70
  pharmaceutical_agents = get_experiment_pharmaceutical_agents(experiment)
71
71
 
72
- return {"pos_path": pos_path, "pos_name": pos_name, "well_path": well_path, "well_name": well_name, "well_nbr": well_nbr, "experiment": experiment, "antibody": antibodies[idx], "concentration": concentrations[idx], "cell_type": cell_types[idx], "pharmaceutical_agent": pharmaceutical_agents[idx]}
72
+ return {"pos_path": pos_path, "position": pos_path, "pos_name": pos_name, "well_path": well_path, "well_name": well_name, "well_nbr": well_nbr, "experiment": experiment, "antibody": antibodies[idx], "concentration": concentrations[idx], "cell_type": cell_types[idx], "pharmaceutical_agent": pharmaceutical_agents[idx]}
73
73
 
74
74
 
75
75
  def get_experiment_wells(experiment):
@@ -691,7 +691,7 @@ def locate_stack(position, prefix='Aligned'):
691
691
 
692
692
  return stack
693
693
 
694
- def locate_labels(position, population='target'):
694
+ def locate_labels(position, population='target', frames=None):
695
695
 
696
696
  """
697
697
 
@@ -732,7 +732,33 @@ def locate_labels(position, population='target'):
732
732
  label_path = natsorted(glob(position + os.sep.join(["labels_targets", "*.tif"])))
733
733
  elif population.lower() == "effector" or population.lower() == "effectors":
734
734
  label_path = natsorted(glob(position + os.sep.join(["labels_effectors", "*.tif"])))
735
- labels = np.array([imread(i.replace('\\', '/')) for i in label_path])
735
+
736
+ label_names = [os.path.split(lbl)[-1] for lbl in label_path]
737
+
738
+ if frames is None:
739
+
740
+ labels = np.array([imread(i.replace('\\', '/')) for i in label_path])
741
+
742
+ elif isinstance(frames, (int,float, np.int_)):
743
+
744
+ tzfill = str(int(frames)).zfill(4)
745
+ idx = label_names.index(f"{tzfill}.tif")
746
+ if idx==-1:
747
+ labels = None
748
+ else:
749
+ labels = np.array(imread(label_path[idx].replace('\\', '/')))
750
+
751
+ elif isinstance(frames, (list,np.ndarray)):
752
+ labels = []
753
+ for f in frames:
754
+ tzfill = str(int(f)).zfill(4)
755
+ idx = label_names.index(f"{tzfill}.tif")
756
+ if idx==-1:
757
+ labels.append(None)
758
+ else:
759
+ labels.append(np.array(imread(label_path[idx].replace('\\', '/'))))
760
+ else:
761
+ print('Frames argument must be None, int or list...')
736
762
 
737
763
  return labels
738
764
 
@@ -783,7 +809,8 @@ def fix_missing_labels(position, population='target', prefix='Aligned'):
783
809
  to_create = all_frames
784
810
  to_create = [str(x).zfill(4)+'.tif' for x in to_create]
785
811
  for file in to_create:
786
- imwrite(os.sep.join([path, file]), template)
812
+ save_tiff_imagej_compatible(os.sep.join([path, file]), template.astype(np.int16), axes='YX')
813
+ #imwrite(os.sep.join([path, file]), template.astype(int))
787
814
 
788
815
 
789
816
  def locate_stack_and_labels(position, prefix='Aligned', population="target"):
celldetective/measure.py CHANGED
@@ -25,8 +25,6 @@ from celldetective.extra_properties import *
25
25
  from inspect import getmembers, isfunction
26
26
  from skimage.morphology import disk
27
27
 
28
- import matplotlib.pyplot as plt
29
-
30
28
  abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], 'celldetective'])
31
29
 
32
30
  def measure(stack=None, labels=None, trajectories=None, channel_names=None,
@@ -417,9 +415,10 @@ def measure_features(img, label, features=['area', 'intensity_mean'], channels=N
417
415
  if haralick_options is not None:
418
416
  try:
419
417
  df_haralick = compute_haralick_features(img, label, channels=channels, **haralick_options)
420
- df_haralick = df_haralick.rename(columns={"cell_id": "label"})
421
- df_props = df_props.merge(df_haralick, how='outer', on='label', suffixes=('_delme', ''))
422
- df_props = df_props[[c for c in df_props.columns if not c.endswith('_delme')]]
418
+ if df_haralick is not None:
419
+ df_haralick = df_haralick.rename(columns={"cell_id": "label"})
420
+ df_props = df_props.merge(df_haralick, how='outer', on='label', suffixes=('_delme', ''))
421
+ df_props = df_props[[c for c in df_props.columns if not c.endswith('_delme')]]
423
422
  except Exception as e:
424
423
  print(e)
425
424
  pass
@@ -520,6 +519,10 @@ def compute_haralick_features(img, labels, channels=None, target_channel=0, scal
520
519
  if len(img.shape)==3:
521
520
  img = img[:,:,target_channel]
522
521
 
522
+ # Routine to skip black frames
523
+ if np.percentile(img.flatten(),99.9)==0.0:
524
+ return None
525
+
523
526
  img = interpolate_nan(img)
524
527
 
525
528
  # Rescale image and mask
@@ -7,8 +7,6 @@ from celldetective.utils import contour_of_instance_segmentation, extract_identi
7
7
  from scipy.spatial.distance import cdist
8
8
  from celldetective.io import locate_labels, get_position_pickle, get_position_table
9
9
 
10
- import matplotlib.pyplot as plt
11
-
12
10
  abs_path = os.sep.join([os.path.split(os.path.dirname(os.path.realpath(__file__)))[0], 'celldetective'])
13
11
 
14
12
 
@@ -5,7 +5,7 @@ Copright © 2022 Laboratoire Adhesion et Inflammation, Authored by Remy Torro.
5
5
  import argparse
6
6
  import os
7
7
  import json
8
- from celldetective.io import auto_load_number_of_frames, load_frames
8
+ from celldetective.io import auto_load_number_of_frames, load_frames, fix_missing_labels, locate_labels
9
9
  from celldetective.utils import extract_experiment_channels, ConfigSectionMap, _get_img_num_per_channel, extract_experiment_channels
10
10
  from celldetective.utils import remove_redundant_features, remove_trajectory_measurements
11
11
  from celldetective.measure import drop_tonal_features, measure_features, measure_isotropic_intensity
@@ -16,7 +16,6 @@ import numpy as np
16
16
  import pandas as pd
17
17
  from natsort import natsorted
18
18
  from art import tprint
19
- from tifffile import imread
20
19
  import threading
21
20
  import datetime
22
21
 
@@ -165,6 +164,8 @@ else:
165
164
  features += ['centroid']
166
165
  do_iso_intensities = False
167
166
 
167
+ # if 'centroid' not in features:
168
+ # features += ['centroid']
168
169
 
169
170
  # if (features is not None) and (trajectories is not None):
170
171
  # features = remove_redundant_features(features,
@@ -176,6 +177,12 @@ len_movie_auto = auto_load_number_of_frames(file)
176
177
  if len_movie_auto is not None:
177
178
  len_movie = len_movie_auto
178
179
 
180
+ if label_path is not None and file is not None:
181
+ test = len(label_path)==len_movie
182
+ if not test:
183
+ fix_missing_labels(pos, population=mode, prefix=movie_prefix)
184
+ label_path = natsorted(glob(os.sep.join([pos, label_folder, '*.tif'])))
185
+
179
186
  img_num_channels = _get_img_num_per_channel(channel_indices, len_movie, nbr_channels)
180
187
 
181
188
 
@@ -203,6 +210,8 @@ if trajectories is None:
203
210
  if 'label' not in features:
204
211
  features.append('label')
205
212
 
213
+ if label_path is not None:
214
+ label_names = [os.path.split(lbl)[-1] for lbl in label_path]
206
215
 
207
216
 
208
217
  features_log=f'features: {features}'
@@ -228,7 +237,10 @@ def measure_index(indices):
228
237
  img = load_frames(img_num_channels[:,t], file, scale=None, normalize_input=False)
229
238
 
230
239
  if label_path is not None:
231
- lbl = imread(label_path[t])
240
+
241
+ lbl = locate_labels(pos, population=mode, frames=t)
242
+ if lbl is None:
243
+ continue
232
244
 
233
245
  if trajectories is not None:
234
246
 
@@ -286,16 +298,16 @@ for th in threads:
286
298
 
287
299
 
288
300
  if len(timestep_dataframes)>0:
301
+
289
302
  df = pd.concat(timestep_dataframes)
290
- df.reset_index(inplace=True, drop=True)
291
303
 
292
- if trajectories is None:
304
+ if trajectories is not None:
305
+ df = df.sort_values(by=[column_labels['track'],column_labels['time']])
306
+ df = df.dropna(subset=[column_labels['track']])
307
+ else:
293
308
  df['ID'] = np.arange(len(df))
294
309
 
295
- if column_labels['track'] in df.columns:
296
- df = df.sort_values(by=[column_labels['track'], column_labels['time']])
297
- else:
298
- df = df.sort_values(by=column_labels['time'])
310
+ df = df.reset_index(drop=True)
299
311
 
300
312
  df.to_csv(pos+os.sep.join(["output", "tables", table_name]), index=False)
301
313
  print(f'Measurements successfully written in table {pos+os.sep.join(["output", "tables", table_name])}')
celldetective/signals.py CHANGED
@@ -27,7 +27,7 @@ from natsort import natsorted
27
27
  from glob import glob
28
28
  import random
29
29
  from celldetective.utils import color_from_status, color_from_class
30
- from math import floor, ceil
30
+ from math import floor
31
31
  from scipy.optimize import curve_fit
32
32
  import time
33
33
  import math
@@ -193,73 +193,74 @@ def analyze_signals(trajectories, model, interpolate_na=True,
193
193
  signals[i,max(frames):,j] = signal[-1]
194
194
 
195
195
  model = SignalDetectionModel(pretrained=complete_path)
196
+ if not model.pretrained is None:
196
197
 
197
- classes = model.predict_class(signals)
198
- times_recast = model.predict_time_of_interest(signals)
198
+ classes = model.predict_class(signals)
199
+ times_recast = model.predict_time_of_interest(signals)
199
200
 
200
- if label is None:
201
- class_col = 'class'
202
- time_col = 't0'
203
- status_col = 'status'
204
- else:
205
- class_col = 'class_'+label
206
- time_col = 't_'+label
207
- status_col = 'status_'+label
201
+ if label is None:
202
+ class_col = 'class'
203
+ time_col = 't0'
204
+ status_col = 'status'
205
+ else:
206
+ class_col = 'class_'+label
207
+ time_col = 't_'+label
208
+ status_col = 'status_'+label
208
209
 
209
- for i,(tid,group) in enumerate(trajectories.groupby(column_labels['track'])):
210
- indices = group.index
211
- trajectories.loc[indices,class_col] = classes[i]
212
- trajectories.loc[indices,time_col] = times_recast[i]
213
- print('Done.')
210
+ for i,(tid,group) in enumerate(trajectories.groupby(column_labels['track'])):
211
+ indices = group.index
212
+ trajectories.loc[indices,class_col] = classes[i]
213
+ trajectories.loc[indices,time_col] = times_recast[i]
214
+ print('Done.')
214
215
 
215
- for tid, group in trajectories.groupby(column_labels['track']):
216
-
217
- indices = group.index
218
- t0 = group[time_col].to_numpy()[0]
219
- cclass = group[class_col].to_numpy()[0]
220
- timeline = group[column_labels['time']].to_numpy()
221
- status = np.zeros_like(timeline)
222
- if t0 > 0:
223
- status[timeline>=t0] = 1.
224
- if cclass==2:
225
- status[:] = 2
226
- if cclass>2:
227
- status[:] = 42
228
- status_color = [color_from_status(s) for s in status]
229
- class_color = [color_from_class(cclass) for i in range(len(status))]
230
-
231
- trajectories.loc[indices, status_col] = status
232
- trajectories.loc[indices, 'status_color'] = status_color
233
- trajectories.loc[indices, 'class_color'] = class_color
234
-
235
- if plot_outcome:
236
- fig,ax = plt.subplots(1,len(selected_signals), figsize=(10,5))
237
- for i,s in enumerate(selected_signals):
238
- for k,(tid,group) in enumerate(trajectories.groupby(column_labels['track'])):
239
- cclass = group[class_col].to_numpy()[0]
240
- t0 = group[time_col].to_numpy()[0]
241
- timeline = group[column_labels['time']].to_numpy()
242
- if cclass==0:
243
- if len(selected_signals)>1:
244
- ax[i].plot(timeline - t0, group[s].to_numpy(),c='tab:blue',alpha=0.1)
245
- else:
246
- ax.plot(timeline - t0, group[s].to_numpy(),c='tab:blue',alpha=0.1)
247
- if len(selected_signals)>1:
248
- for a,s in zip(ax,selected_signals):
249
- a.set_title(s)
250
- a.set_xlabel(r'time - t$_0$ [frame]')
251
- a.spines['top'].set_visible(False)
252
- a.spines['right'].set_visible(False)
253
- else:
254
- ax.set_title(s)
255
- ax.set_xlabel(r'time - t$_0$ [frame]')
256
- ax.spines['top'].set_visible(False)
257
- ax.spines['right'].set_visible(False)
258
- plt.tight_layout()
259
- if output_dir is not None:
260
- plt.savefig(output_dir+'signal_collapse.png',bbox_inches='tight',dpi=300)
261
- plt.pause(3)
262
- plt.close()
216
+ for tid, group in trajectories.groupby(column_labels['track']):
217
+
218
+ indices = group.index
219
+ t0 = group[time_col].to_numpy()[0]
220
+ cclass = group[class_col].to_numpy()[0]
221
+ timeline = group[column_labels['time']].to_numpy()
222
+ status = np.zeros_like(timeline)
223
+ if t0 > 0:
224
+ status[timeline>=t0] = 1.
225
+ if cclass==2:
226
+ status[:] = 2
227
+ if cclass>2:
228
+ status[:] = 42
229
+ status_color = [color_from_status(s) for s in status]
230
+ class_color = [color_from_class(cclass) for i in range(len(status))]
231
+
232
+ trajectories.loc[indices, status_col] = status
233
+ trajectories.loc[indices, 'status_color'] = status_color
234
+ trajectories.loc[indices, 'class_color'] = class_color
235
+
236
+ if plot_outcome:
237
+ fig,ax = plt.subplots(1,len(selected_signals), figsize=(10,5))
238
+ for i,s in enumerate(selected_signals):
239
+ for k,(tid,group) in enumerate(trajectories.groupby(column_labels['track'])):
240
+ cclass = group[class_col].to_numpy()[0]
241
+ t0 = group[time_col].to_numpy()[0]
242
+ timeline = group[column_labels['time']].to_numpy()
243
+ if cclass==0:
244
+ if len(selected_signals)>1:
245
+ ax[i].plot(timeline - t0, group[s].to_numpy(),c='tab:blue',alpha=0.1)
246
+ else:
247
+ ax.plot(timeline - t0, group[s].to_numpy(),c='tab:blue',alpha=0.1)
248
+ if len(selected_signals)>1:
249
+ for a,s in zip(ax,selected_signals):
250
+ a.set_title(s)
251
+ a.set_xlabel(r'time - t$_0$ [frame]')
252
+ a.spines['top'].set_visible(False)
253
+ a.spines['right'].set_visible(False)
254
+ else:
255
+ ax.set_title(s)
256
+ ax.set_xlabel(r'time - t$_0$ [frame]')
257
+ ax.spines['top'].set_visible(False)
258
+ ax.spines['right'].set_visible(False)
259
+ plt.tight_layout()
260
+ if output_dir is not None:
261
+ plt.savefig(output_dir+'signal_collapse.png',bbox_inches='tight',dpi=300)
262
+ plt.pause(3)
263
+ plt.close()
263
264
 
264
265
  return trajectories
265
266
 
@@ -800,8 +801,12 @@ class SignalDetectionModel(object):
800
801
 
801
802
 
802
803
  if self.pretrained is not None:
803
- print(f"Load pretrained models from {path}...")
804
- self.load_pretrained_model()
804
+ print(f"Load pretrained models from {pretrained}...")
805
+ test = self.load_pretrained_model()
806
+ if test is None:
807
+ self.pretrained = None
808
+ print('Pretrained model could not be loaded. Check the log for error. Abort...')
809
+ return None
805
810
  else:
806
811
  print("Create models from scratch...")
807
812
  self.create_models_from_scratch()
@@ -828,6 +833,9 @@ class SignalDetectionModel(object):
828
833
  - The configuration file is expected to be named "config_input.json" and located in the same directory as the models.
829
834
  """
830
835
 
836
+ if self.pretrained.endswith(os.sep):
837
+ self.pretrained = os.sep.join(self.pretrained.split(os.sep)[:-1])
838
+
831
839
  try:
832
840
  self.model_class = load_model(os.sep.join([self.pretrained,"classifier.h5"]),compile=False)
833
841
  self.model_class.load_weights(os.sep.join([self.pretrained,"classifier.h5"]))
@@ -843,6 +851,9 @@ class SignalDetectionModel(object):
843
851
  print(f"Error {e}...")
844
852
  self.model_reg = None
845
853
 
854
+ if self.model_class is None and self.model_reg is None:
855
+ return None
856
+
846
857
  # load config
847
858
  with open(os.sep.join([self.pretrained,"config_input.json"])) as config_file:
848
859
  model_config = json.load(config_file)
@@ -867,6 +878,7 @@ class SignalDetectionModel(object):
867
878
 
868
879
  assert self.model_class.layers[0].input_shape[0] == self.model_reg.layers[0].input_shape[0], f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
869
880
 
881
+ return True
870
882
 
871
883
  def create_models_from_scratch(self):
872
884
 
celldetective/tracking.py CHANGED
@@ -959,22 +959,28 @@ def write_first_detection_class(tab, column_labels={'track': "TRACK_ID", 'time':
959
959
  indices = track_group.index
960
960
  detection = track_group[column_labels['x']].values
961
961
  timeline = track_group[column_labels['time']].values
962
- if len(timeline)>2:
963
- dt = timeline[1] - timeline[0]
964
- if np.any(detection==detection):
965
- t_first = timeline[detection==detection][0]
966
- cclass = 0
967
- if t_first<=0:
968
- t_first = -1
969
- cclass = 2
970
- else:
971
- t_first = float(t_first) - float(dt)
972
- else:
962
+ dt = 1
963
+
964
+ # Initialize
965
+ cclass = 2; t_first = np.nan;
966
+
967
+ if np.any(detection==detection):
968
+ t_first = timeline[detection==detection][0]
969
+ cclass = 0
970
+ if t_first<=0:
973
971
  t_first = -1
974
972
  cclass = 2
973
+ else:
974
+ t_first = float(t_first) - float(dt)
975
+ if t_first==0:
976
+ t_first += 0.01
977
+ else:
978
+ t_first = -1
979
+ cclass = 2
980
+
981
+ tab.loc[indices, 'class_firstdetection'] = cclass
982
+ tab.loc[indices, 't_firstdetection'] = t_first
975
983
 
976
- tab.loc[indices, 'class_firstdetection'] = cclass
977
- tab.loc[indices, 't_firstdetection'] = t_first
978
984
  return tab
979
985
 
980
986
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: celldetective
3
- Version: 1.3.3.post1
3
+ Version: 1.3.4.post1
4
4
  Summary: description
5
5
  Home-page: http://github.com/remyeltorro/celldetective
6
6
  Author: Rémy Torro
@@ -41,6 +41,7 @@ Requires-Dist: pytest
41
41
  Requires-Dist: pytest-qt
42
42
  Requires-Dist: h5py
43
43
  Requires-Dist: cliffs_delta
44
+ Requires-Dist: requests
44
45
 
45
46
  # Celldetective
46
47