celldetective 1.3.6.post2__py3-none-any.whl → 1.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celldetective/_version.py +1 -1
- celldetective/events.py +4 -0
- celldetective/gui/InitWindow.py +23 -9
- celldetective/gui/control_panel.py +19 -11
- celldetective/gui/generic_signal_plot.py +5 -0
- celldetective/gui/help/DL-segmentation-strategy.json +17 -17
- celldetective/gui/help/Threshold-vs-DL.json +11 -11
- celldetective/gui/help/cell-populations.json +5 -5
- celldetective/gui/help/exp-structure.json +15 -15
- celldetective/gui/help/feature-btrack.json +5 -5
- celldetective/gui/help/neighborhood.json +7 -7
- celldetective/gui/help/prefilter-for-segmentation.json +7 -7
- celldetective/gui/help/preprocessing.json +19 -19
- celldetective/gui/help/propagate-classification.json +7 -7
- celldetective/gui/plot_signals_ui.py +13 -9
- celldetective/gui/process_block.py +63 -14
- celldetective/gui/retrain_segmentation_model_options.py +21 -8
- celldetective/gui/retrain_signal_model_options.py +12 -2
- celldetective/gui/signal_annotator.py +9 -0
- celldetective/gui/signal_annotator2.py +8 -0
- celldetective/gui/styles.py +1 -0
- celldetective/gui/tableUI.py +1 -1
- celldetective/gui/workers.py +136 -0
- celldetective/io.py +53 -27
- celldetective/measure.py +112 -14
- celldetective/scripts/measure_cells.py +10 -35
- celldetective/scripts/segment_cells.py +15 -62
- celldetective/scripts/segment_cells_thresholds.py +1 -2
- celldetective/scripts/track_cells.py +16 -19
- celldetective/segmentation.py +16 -62
- celldetective/signals.py +11 -7
- celldetective/utils.py +587 -67
- {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/METADATA +1 -1
- {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/RECORD +38 -37
- {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/LICENSE +0 -0
- {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/WHEEL +0 -0
- {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/entry_points.txt +0 -0
- {celldetective-1.3.6.post2.dist-info → celldetective-1.3.7.dist-info}/top_level.txt +0 -0
celldetective/signals.py
CHANGED
|
@@ -879,7 +879,6 @@ class SignalDetectionModel(object):
|
|
|
879
879
|
|
|
880
880
|
assert self.model_class.layers[0].input_shape[0] == self.model_reg.layers[0].input_shape[0], f"mismatch between input shape of classification: {self.model_class.layers[0].input_shape[0]} and regression {self.model_reg.layers[0].input_shape[0]} models... Error."
|
|
881
881
|
|
|
882
|
-
|
|
883
882
|
return True
|
|
884
883
|
|
|
885
884
|
def create_models_from_scratch(self):
|
|
@@ -898,7 +897,7 @@ class SignalDetectionModel(object):
|
|
|
898
897
|
|
|
899
898
|
self.model_class = ResNetModelCurrent(n_channels=self.n_channels,
|
|
900
899
|
n_slices=self.n_conv,
|
|
901
|
-
n_classes =
|
|
900
|
+
n_classes = 3,
|
|
902
901
|
dense_collection=self.dense_collection,
|
|
903
902
|
dropout_rate=self.dropout_rate,
|
|
904
903
|
header="classifier",
|
|
@@ -1016,6 +1015,7 @@ class SignalDetectionModel(object):
|
|
|
1016
1015
|
self.loss_class = loss_class
|
|
1017
1016
|
self.show_plots = show_plots
|
|
1018
1017
|
self.channel_option = channel_option
|
|
1018
|
+
|
|
1019
1019
|
assert self.n_channels==len(self.channel_option), f'Mismatch between the channel option and the number of channels of the model...'
|
|
1020
1020
|
|
|
1021
1021
|
if isinstance(self.datasets[0], dict):
|
|
@@ -1075,7 +1075,7 @@ class SignalDetectionModel(object):
|
|
|
1075
1075
|
# If y-class is not one-hot encoded, encode it
|
|
1076
1076
|
if self.y_class_train.shape[-1] != self.n_classes:
|
|
1077
1077
|
self.class_weights = compute_weights(y=self.y_class_train,class_weight="balanced", classes=np.unique(self.y_class_train))
|
|
1078
|
-
self.y_class_train = to_categorical(self.y_class_train)
|
|
1078
|
+
self.y_class_train = to_categorical(self.y_class_train, num_classes=3)
|
|
1079
1079
|
|
|
1080
1080
|
if self.normalize:
|
|
1081
1081
|
self.y_time_train = self.y_time_train.astype(np.float32)/self.model_signal_length
|
|
@@ -1091,7 +1091,7 @@ class SignalDetectionModel(object):
|
|
|
1091
1091
|
self.x_val = pad_to_model_length(self.x_val, self.model_signal_length)
|
|
1092
1092
|
self.y_class_val = validation_data[1]
|
|
1093
1093
|
if self.y_class_val.shape[-1] != self.n_classes:
|
|
1094
|
-
self.y_class_val = to_categorical(self.y_class_val)
|
|
1094
|
+
self.y_class_val = to_categorical(self.y_class_val, num_classes=3)
|
|
1095
1095
|
self.y_time_val = validation_data[2]
|
|
1096
1096
|
if self.normalize:
|
|
1097
1097
|
self.y_time_val = self.y_time_val.astype(np.float32)/self.model_signal_length
|
|
@@ -1111,7 +1111,7 @@ class SignalDetectionModel(object):
|
|
|
1111
1111
|
self.x_test = pad_to_model_length(self.x_test, self.model_signal_length)
|
|
1112
1112
|
self.y_class_test = test_data[1]
|
|
1113
1113
|
if self.y_class_test.shape[-1] != self.n_classes:
|
|
1114
|
-
self.y_class_test = to_categorical(self.y_class_test)
|
|
1114
|
+
self.y_class_test = to_categorical(self.y_class_test, num_classes=3)
|
|
1115
1115
|
self.y_time_test = test_data[2]
|
|
1116
1116
|
if self.normalize:
|
|
1117
1117
|
self.y_time_test = self.y_time_test.astype(np.float32)/self.model_signal_length
|
|
@@ -1320,6 +1320,8 @@ class SignalDetectionModel(object):
|
|
|
1320
1320
|
"""
|
|
1321
1321
|
|
|
1322
1322
|
# if pretrained model
|
|
1323
|
+
self.n_classes = 3
|
|
1324
|
+
|
|
1323
1325
|
if self.pretrained is not None:
|
|
1324
1326
|
# if recompile
|
|
1325
1327
|
if self.recompile_pretrained:
|
|
@@ -1352,6 +1354,7 @@ class SignalDetectionModel(object):
|
|
|
1352
1354
|
# plt.show()
|
|
1353
1355
|
|
|
1354
1356
|
if hasattr(self, 'x_val'):
|
|
1357
|
+
|
|
1355
1358
|
self.history_classifier = self.model_class.fit(x=self.x_train,
|
|
1356
1359
|
y=self.y_class_train,
|
|
1357
1360
|
batch_size=self.batch_size,
|
|
@@ -1740,8 +1743,8 @@ class SignalDetectionModel(object):
|
|
|
1740
1743
|
|
|
1741
1744
|
# Compute class weights and one-hot encode
|
|
1742
1745
|
self.class_weights = compute_weights(self.y_class_set)
|
|
1743
|
-
self.nbr_classes = len(np.unique(self.y_class_set))
|
|
1744
|
-
self.y_class_set = to_categorical(self.y_class_set)
|
|
1746
|
+
self.nbr_classes = 3 #len(np.unique(self.y_class_set))
|
|
1747
|
+
self.y_class_set = to_categorical(self.y_class_set, num_classes=3)
|
|
1745
1748
|
|
|
1746
1749
|
ds = train_test_split(self.x_set,
|
|
1747
1750
|
self.y_time_set,
|
|
@@ -1799,6 +1802,7 @@ class SignalDetectionModel(object):
|
|
|
1799
1802
|
y_class_train_aug = []
|
|
1800
1803
|
|
|
1801
1804
|
counts = [0.,0.,0.]
|
|
1805
|
+
# warning augmentation creates class 2 even if does not exist in data, need to address this
|
|
1802
1806
|
for k in indices:
|
|
1803
1807
|
counts[self.y_class_train[k].argmax()] += 1
|
|
1804
1808
|
aug = augmenter(self.x_train[k],
|