celldetective 1.0.2__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (56) hide show
  1. celldetective/__main__.py +2 -2
  2. celldetective/events.py +2 -44
  3. celldetective/filters.py +4 -5
  4. celldetective/gui/__init__.py +1 -1
  5. celldetective/gui/analyze_block.py +37 -10
  6. celldetective/gui/btrack_options.py +24 -23
  7. celldetective/gui/classifier_widget.py +62 -19
  8. celldetective/gui/configure_new_exp.py +32 -35
  9. celldetective/gui/control_panel.py +115 -81
  10. celldetective/gui/gui_utils.py +674 -396
  11. celldetective/gui/json_readers.py +7 -6
  12. celldetective/gui/layouts.py +755 -0
  13. celldetective/gui/measurement_options.py +168 -487
  14. celldetective/gui/neighborhood_options.py +322 -270
  15. celldetective/gui/plot_measurements.py +1114 -0
  16. celldetective/gui/plot_signals_ui.py +20 -20
  17. celldetective/gui/process_block.py +449 -169
  18. celldetective/gui/retrain_segmentation_model_options.py +27 -26
  19. celldetective/gui/retrain_signal_model_options.py +25 -24
  20. celldetective/gui/seg_model_loader.py +31 -27
  21. celldetective/gui/signal_annotator.py +2326 -2295
  22. celldetective/gui/signal_annotator_options.py +18 -16
  23. celldetective/gui/styles.py +16 -1
  24. celldetective/gui/survival_ui.py +61 -39
  25. celldetective/gui/tableUI.py +60 -23
  26. celldetective/gui/thresholds_gui.py +68 -66
  27. celldetective/gui/viewers.py +596 -0
  28. celldetective/io.py +234 -23
  29. celldetective/measure.py +37 -32
  30. celldetective/neighborhood.py +495 -27
  31. celldetective/preprocessing.py +683 -0
  32. celldetective/scripts/analyze_signals.py +7 -0
  33. celldetective/scripts/measure_cells.py +12 -0
  34. celldetective/scripts/segment_cells.py +5 -0
  35. celldetective/scripts/track_cells.py +11 -0
  36. celldetective/signals.py +221 -98
  37. celldetective/tracking.py +0 -1
  38. celldetective/utils.py +178 -36
  39. celldetective-1.1.0.dist-info/METADATA +305 -0
  40. celldetective-1.1.0.dist-info/RECORD +80 -0
  41. {celldetective-1.0.2.dist-info → celldetective-1.1.0.dist-info}/top_level.txt +1 -0
  42. tests/__init__.py +0 -0
  43. tests/test_events.py +28 -0
  44. tests/test_filters.py +24 -0
  45. tests/test_io.py +70 -0
  46. tests/test_measure.py +141 -0
  47. tests/test_neighborhood.py +70 -0
  48. tests/test_segmentation.py +93 -0
  49. tests/test_signals.py +135 -0
  50. tests/test_tracking.py +164 -0
  51. tests/test_utils.py +71 -0
  52. celldetective-1.0.2.dist-info/METADATA +0 -192
  53. celldetective-1.0.2.dist-info/RECORD +0 -66
  54. {celldetective-1.0.2.dist-info → celldetective-1.1.0.dist-info}/LICENSE +0 -0
  55. {celldetective-1.0.2.dist-info → celldetective-1.1.0.dist-info}/WHEEL +0 -0
  56. {celldetective-1.0.2.dist-info → celldetective-1.1.0.dist-info}/entry_points.txt +0 -0
tests/test_events.py ADDED
@@ -0,0 +1,28 @@
1
+ import unittest
2
+ from celldetective.events import switch_to_events
3
+
4
class TestEventSwitch(unittest.TestCase):
    """Check the (events, times) pair returned by ``switch_to_events``.

    The fixture describes five cells through four parallel lists, while the
    expected output only holds four entries: one input cell is expected to be
    absent from the result (presumably the class-2 cell -- confirm against
    ``switch_to_events``).
    """

    @classmethod
    def setUpClass(cls):
        # Parallel per-cell inputs, in the positional order the function takes.
        cls.classes = [0, 0, 1, 1, 2]
        cls.event_times = [5., 8.5, -1, -1, -1]
        cls.max_times = [10, 10, 10, 10, 10]
        cls.origin_times = [0, 3, 2, 1, 0]

        # The expected times equal event_time - origin_time for the first two
        # cells (5-0, 8.5-3) and max_time - origin_time for the next two
        # (10-2, 10-1).
        cls.expected_events = [1, 1, 0, 0]
        cls.expected_times = [5., 5.5, 8, 9]

    def test_expected_events(self):
        """Both returned sequences must match the hand-computed expectation."""
        inputs = (
            self.classes,
            self.event_times,
            self.max_times,
            self.origin_times,
        )
        events, times = switch_to_events(*inputs)
        self.assertEqual(events, self.expected_events)
        self.assertEqual(times, self.expected_times)
25
+
26
+
27
+ if __name__=="__main__":
28
+ unittest.main()
tests/test_filters.py ADDED
@@ -0,0 +1,24 @@
1
+ import unittest
2
+ import numpy as np
3
+ from celldetective.filters import gauss_filter, abs_filter
4
+
5
+
6
class TestFilters(unittest.TestCase):
    """Basic output checks for ``gauss_filter`` and ``abs_filter``.

    All tests share one synthetic 256x256 integer image filled with ones,
    except for a 10x10 square of zeros.
    """

    @classmethod
    def setUpClass(cls):
        image = np.ones((256, 256), dtype=int)
        image[100:110, 100:110] = 0
        cls.img = image
        cls.gauss_sigma = 1.6

    def test_gauss_filter_is_float(self):
        """A pixel of the Gaussian-filtered image is a float, despite the int input."""
        filtered = gauss_filter(self.img, self.gauss_sigma)
        self.assertIsInstance(filtered[0, 0], float)

    def test_gauss_filter_has_same_shape(self):
        """Gaussian filtering keeps the image shape."""
        filtered = gauss_filter(self.img, self.gauss_sigma)
        self.assertEqual(filtered.shape, self.img.shape)

    def test_abs_filter_is_positive(self):
        """Every pixel returned by ``abs_filter`` is non-negative."""
        rectified = abs_filter(self.img)
        self.assertTrue(np.all(rectified >= 0.))
22
+
23
+ if __name__=="__main__":
24
+ unittest.main()
tests/test_io.py ADDED
@@ -0,0 +1,70 @@
1
+ import unittest
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import os
5
+
6
+ # class TestPatchMask(unittest.TestCase):
7
+
8
+ # @classmethod
9
+ # def setUpClass(self):
10
+ # self.radius = 3
11
+
12
+ # def test_correct_shape(self):
13
+ # self.patch = create_patch_mask(self.radius, self.radius)
14
+ # self.assertEqual(self.patch.shape,(3,3))
15
+
16
+ # def test_correct_ring(self):
17
+ # self.patch = create_patch_mask(5, 5,radius=[1,2])
18
+ # self.assertFalse(self.patch[2,2])
19
+
20
+ # class TestRemoveRedundantFeatures(unittest.TestCase):
21
+
22
+ # @classmethod
23
+ # def setUpClass(self):
24
+ # self.list_a = ['feat1','feat2','feat3','feat4','intensity_mean']
25
+ # self.list_b = ['feat5','feat2','feat1','feat6','test_channel_mean']
26
+ # self.expected = ['feat3','feat4']
27
+
28
+ # def test_remove_red_features(self):
29
+ # self.assertEqual(remove_redundant_features(self.list_a, self.list_b, channel_names=['test_channel']), self.expected)
30
+
31
+
32
+ # class TestExtractChannelIndices(unittest.TestCase):
33
+
34
+ # @classmethod
35
+ # def setUpClass(self):
36
+ # self.channels = ['ch1','ch2','ch3','ch4']
37
+ # self.required_channels = ['ch4','ch2']
38
+ # self.expected_indices = [3,1]
39
+
40
+ # def test_extracted_channels_are_correct(self):
41
+ # self.assertEqual(list(_extract_channel_indices(self.channels, self.required_channels)), self.expected_indices)
42
+
43
+
44
+ # class TestImgIndexPerChannel(unittest.TestCase):
45
+
46
+ # @classmethod
47
+ # def setUpClass(self):
48
+ # self.channels_indices = [1]
49
+ # self.len_movie = 5
50
+ # self.nbr_channels = 3
51
+ # self.expected_indices = [1,4,7,10,13]
52
+
53
+ # def test_index_sequence_is_correct(self):
54
+ # self.assertEqual(list(_get_img_num_per_channel(self.channels_indices, self.len_movie, self.nbr_channels)[0]), self.expected_indices)
55
+
56
+
57
+ # class TestSplitArrayByRatio(unittest.TestCase):
58
+
59
+ # @classmethod
60
+ # def setUpClass(self):
61
+ # self.array_length = 100
62
+ # self.array = np.ones(self.array_length)
63
+
64
+ # def test_ratio_split_is_correct(self):
65
+ # split_array = split_by_ratio(self.array,0.5,0.25,0.1)
66
+ # self.assertTrue(np.all([len(split_array[0])==50, len(split_array[1])==25, len(split_array[2])==10]))
67
+
68
+
69
+ # if __name__=="__main__":
70
+ # unittest.main()
tests/test_measure.py ADDED
@@ -0,0 +1,141 @@
1
+ import unittest
2
+ import pandas as pd
3
+ import numpy as np
4
+ from celldetective.measure import measure_features, measure_isotropic_intensity, drop_tonal_features
5
+
6
class TestFeatureMeasurement(unittest.TestCase):
    """Check ``measure_features`` on a uniform mock frame with two objects.

    Three measurement tables are built once for the whole class: with an
    image and features, without an image, and with an image but no features.

    To do: test spot detection, fluo normalization and peripheral measurements
    """

    @classmethod
    def setUpClass(cls):
        # Simple mock data, 100px*100px, one channel, value is one, uniform.
        cls.frame = np.ones((100, 100, 1), dtype=float)

        # Two objects in the label map: a 5x5 square (label 1, 25 px) and a
        # 10x10 square (label 2, 100 px).
        cls.labels = np.zeros((100, 100), dtype=int)
        cls.labels[50:55, 50:55] = 1
        cls.labels[0:10, 0:10] = 2

        cls.feature_measurements = measure_features(
            cls.frame,
            cls.labels,
            features=['intensity_mean', 'area', ],
            channels=['test_channel']
        )

        cls.feature_measurements_no_image = measure_features(
            None,
            cls.labels,
            features=['intensity_mean', 'area', ],
            channels=None
        )

        cls.feature_measurements_no_features = measure_features(
            cls.frame,
            cls.labels,
            features=None,
            channels=['test_channel'],
        )

    # With image
    def test_measure_yields_table(self):
        self.assertIsInstance(self.feature_measurements, pd.DataFrame)

    def test_two_objects(self):
        self.assertEqual(len(self.feature_measurements), 2)

    def test_channel_named_correctly(self):
        self.assertIn('test_channel_mean', list(self.feature_measurements.columns))

    def test_intensity_is_one(self):
        self.assertTrue(np.all([v == 1.0 for v in self.feature_measurements['test_channel_mean'].values]))

    def test_area_first_is_twenty_five(self):
        self.assertEqual(self.feature_measurements['area'].values[0], 25)

    def test_area_second_is_hundred(self):
        self.assertEqual(self.feature_measurements['area'].values[1], 100)

    # Without image
    # These two methods used to reuse the names of the "with image" tests
    # above; the later definition replaced the earlier one in the class body,
    # so the "with image" variants never ran. Distinct names keep both sets.
    def test_measure_yields_table_without_image(self):
        self.assertIsInstance(self.feature_measurements_no_image, pd.DataFrame)

    def test_two_objects_without_image(self):
        self.assertEqual(len(self.feature_measurements_no_image), 2)

    def test_channel_not_in_table(self):
        self.assertNotIn('test_channel_mean', list(self.feature_measurements_no_image.columns))

    # With no features
    def test_only_one_measurement(self):
        # unittest assertion instead of a bare ``assert`` (which is stripped
        # under ``python -O`` and gives no diagnostic on failure).
        cols = list(self.feature_measurements_no_features.columns)
        self.assertEqual(cols, ['class_id'])
77
+
78
+
79
class TestIsotropicMeasurement(unittest.TestCase):
    """Test that isotropic intensity measurements behave as expected on a fake image.

    The frame is uniformly one, with a 21x21 px square of zeros centred on
    the single measured position (50, 50).
    """

    @classmethod
    def setUpClass(cls):
        # Simple mock data, 100px*100px, one channel, value is one.
        cls.frame = np.ones((100, 100, 1), dtype=float)
        # Square (21*21px) of value 0. in the middle.
        cls.frame[40:61, 40:61, 0] = 0.
        cls.positions = pd.DataFrame([{'TRACK_ID': 0, 'POSITION_X': 50, 'POSITION_Y': 50, 'FRAME': 0, 'class_id': 0}])

        # Radius 9 stays inside the zero square (half-width 10); radius 20
        # reaches past it into the region of ones.
        cls.inner_radius = 9
        cls.upper_radius = 20
        # Half-diagonal of the zero square (10*sqrt(2)) plus a margin, so a
        # ring starting here lies entirely outside the square.
        cls.safe_upper_radius = int(21 // 2 * np.sqrt(2)) + 2

        cls.iso_measurements = measure_isotropic_intensity(
            cls.positions,
            cls.frame,
            channels=['test_channel'],
            intensity_measurement_radii=[cls.inner_radius, cls.upper_radius],
            operations=['mean'],
        )
        cls.iso_measurements_ring = measure_isotropic_intensity(
            cls.positions,
            cls.frame,
            channels=['test_channel'],
            intensity_measurement_radii=[[cls.safe_upper_radius, cls.safe_upper_radius + 3]],
            operations=['mean'],
        )

    def test_measure_yields_table(self):
        self.assertIsInstance(self.iso_measurements, pd.DataFrame)

    def test_intensity_zero_in_small_circle(self):
        self.assertEqual(self.iso_measurements[f'test_channel_circle_{self.inner_radius}_mean'].values[0], 0.)

    def test_intensity_greater_than_zero_in_intermediate_circle(self):
        self.assertGreater(self.iso_measurements[f'test_channel_circle_{self.upper_radius}_mean'].values[0], 0.)

    def test_ring_measurement_avoids_zero(self):
        # Read the ring column from the table returned by the ring call.
        # The test previously read it from ``iso_measurements``, which can only
        # contain that column if the measured function mutates the shared
        # positions table in place.
        ring_column = f'test_channel_ring_{self.safe_upper_radius}_{self.safe_upper_radius + 3}_mean'
        self.assertEqual(self.iso_measurements_ring[ring_column].values[0], 1.0)
128
+
129
class TestDropTonal(unittest.TestCase):
    """Check that ``drop_tonal_features`` removes the intensity-based entries."""

    @classmethod
    def setUpClass(cls):
        cls.features = ['area', 'intensity_mean', 'intensity_max']

    def test_drop_tonal(self):
        """Only 'area' is left once both 'intensity_*' features are dropped."""
        expected = ['area']
        self.features_processed = drop_tonal_features(self.features)
        self.assertEqual(self.features_processed, expected)
138
+
139
+
140
+ if __name__=="__main__":
141
+ unittest.main()
tests/test_neighborhood.py ADDED
@@ -0,0 +1,70 @@
1
+ import unittest
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import os
5
+
6
+ # class TestPatchMask(unittest.TestCase):
7
+
8
+ # @classmethod
9
+ # def setUpClass(self):
10
+ # self.radius = 3
11
+
12
+ # def test_correct_shape(self):
13
+ # self.patch = create_patch_mask(self.radius, self.radius)
14
+ # self.assertEqual(self.patch.shape,(3,3))
15
+
16
+ # def test_correct_ring(self):
17
+ # self.patch = create_patch_mask(5, 5,radius=[1,2])
18
+ # self.assertFalse(self.patch[2,2])
19
+
20
+ # class TestRemoveRedundantFeatures(unittest.TestCase):
21
+
22
+ # @classmethod
23
+ # def setUpClass(self):
24
+ # self.list_a = ['feat1','feat2','feat3','feat4','intensity_mean']
25
+ # self.list_b = ['feat5','feat2','feat1','feat6','test_channel_mean']
26
+ # self.expected = ['feat3','feat4']
27
+
28
+ # def test_remove_red_features(self):
29
+ # self.assertEqual(remove_redundant_features(self.list_a, self.list_b, channel_names=['test_channel']), self.expected)
30
+
31
+
32
+ # class TestExtractChannelIndices(unittest.TestCase):
33
+
34
+ # @classmethod
35
+ # def setUpClass(self):
36
+ # self.channels = ['ch1','ch2','ch3','ch4']
37
+ # self.required_channels = ['ch4','ch2']
38
+ # self.expected_indices = [3,1]
39
+
40
+ # def test_extracted_channels_are_correct(self):
41
+ # self.assertEqual(list(_extract_channel_indices(self.channels, self.required_channels)), self.expected_indices)
42
+
43
+
44
+ # class TestImgIndexPerChannel(unittest.TestCase):
45
+
46
+ # @classmethod
47
+ # def setUpClass(self):
48
+ # self.channels_indices = [1]
49
+ # self.len_movie = 5
50
+ # self.nbr_channels = 3
51
+ # self.expected_indices = [1,4,7,10,13]
52
+
53
+ # def test_index_sequence_is_correct(self):
54
+ # self.assertEqual(list(_get_img_num_per_channel(self.channels_indices, self.len_movie, self.nbr_channels)[0]), self.expected_indices)
55
+
56
+
57
+ # class TestSplitArrayByRatio(unittest.TestCase):
58
+
59
+ # @classmethod
60
+ # def setUpClass(self):
61
+ # self.array_length = 100
62
+ # self.array = np.ones(self.array_length)
63
+
64
+ # def test_ratio_split_is_correct(self):
65
+ # split_array = split_by_ratio(self.array,0.5,0.25,0.1)
66
+ # self.assertTrue(np.all([len(split_array[0])==50, len(split_array[1])==25, len(split_array[2])==10]))
67
+
68
+
69
+ # if __name__=="__main__":
70
+ # unittest.main()
tests/test_segmentation.py ADDED
@@ -0,0 +1,93 @@
1
+ import unittest
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import os
5
+ import json
6
+ from tifffile import imread
7
+ from celldetective.segmentation import segment, segment_frame_from_thresholds
8
+ from tensorflow.keras.metrics import BinaryIoU
9
+
10
+ TEST_IMAGE_FILENAME = os.path.join(os.path.dirname(__file__), os.sep.join(['assets','sample.tif']))
11
+ TEST_LABEL_FILENAME = os.path.join(os.path.dirname(__file__), os.sep.join(['assets','sample_labelled.tif']))
12
+ TEST_CONFIG_FILENAME = os.path.join(os.path.dirname(__file__), os.sep.join(['assets','sample.json']))
13
+
14
class TestDLMCF7Segmentation(unittest.TestCase):
    """Compare model-based segmentation of the sample image with its reference labels.

    Each test segments a stack made of three copies of the sample image with
    a named model, checks that the first two frames get identical labels, and
    requires a foreground IoU above 0.9 against the reference label image.
    """

    @classmethod
    def setUpClass(cls):
        cls.img = imread(TEST_IMAGE_FILENAME)
        cls.label_true = imread(TEST_LABEL_FILENAME)
        # Three copies of the same image; axis 1 of the stacked array is moved
        # to the last position (presumably the channel axis -- confirm).
        cls.stack = np.moveaxis([cls.img, cls.img, cls.img], 1, -1)
        with open(TEST_CONFIG_FILENAME) as config_file:
            cls.config = json.load(config_file)
        cls.channels = cls.config['channels']
        cls.spatial_calibration = cls.config['spatial_calibration']

    def test_correct_segmentation_with_multimodal_model(self):
        labels = segment(
            self.stack,
            "MCF7_bf_pi_cfse_h",
            channels=self.channels,
            spatial_calibration=self.spatial_calibration,
            view_on_napari=False,
            use_gpu=False,
        )
        # Identical input frames must give identical label images.
        np.testing.assert_array_equal(labels[0], labels[1])

        # Collapse instance labels to a foreground/background mask.
        truth = self.label_true.copy().astype(float)
        truth[truth > 0] = 1.
        self.binary_label_true = truth

        prediction = labels[0].copy().astype(float)
        prediction[prediction > 0] = 1.

        metric = BinaryIoU(target_class_ids=[1])
        metric.update_state(truth, prediction)
        score = metric.result().numpy()

        self.assertGreater(score, 0.9)

    def test_correct_segmentation_with_transferred_model(self):
        labels = segment(
            self.stack,
            "MCF7_h_versatile",
            channels=self.channels,
            spatial_calibration=self.spatial_calibration,
            view_on_napari=False,
            use_gpu=True,
            time_flat_normalization=False,
            time_flat_percentiles=(0.0, 99.99),
        )
        # Identical input frames must give identical label images.
        np.testing.assert_array_equal(labels[0], labels[1])

        # Collapse instance labels to a foreground/background mask.
        truth = self.label_true.copy().astype(float)
        truth[truth > 0] = 1.
        self.binary_label_true = truth

        prediction = labels[0].copy().astype(float)
        prediction[prediction > 0] = 1.

        metric = BinaryIoU(target_class_ids=[1])
        metric.update_state(truth, prediction)
        score = metric.result().numpy()

        self.assertGreater(score, 0.9)
61
+
62
+
63
class TestThresholdMCF7Segmentation(unittest.TestCase):
    """Compare threshold-based segmentation of the sample image with its reference labels."""

    @classmethod
    def setUpClass(cls):
        cls.img = imread(TEST_IMAGE_FILENAME)
        cls.label_true = imread(TEST_LABEL_FILENAME)
        with open(TEST_CONFIG_FILENAME) as config_file:
            cls.config = json.load(config_file)
        cls.channels = cls.config['channels']
        cls.spatial_calibration = cls.config['spatial_calibration']

    def test_correct_segmentation_with_threshold(self):
        """The thresholded mask must reach a foreground IoU above 0.7."""
        # The first axis of the image is moved last before segmentation.
        frame = np.moveaxis(self.img, 0, -1)
        label = segment_frame_from_thresholds(
            frame,
            target_channel=3,
            thresholds=[8000, 1.0E10],
            equalize_reference=None,
            filters=[['variance', 4], ['gauss', 2]],
            marker_min_distance=13,
            marker_footprint_size=34,
            marker_footprint=None,
            feature_queries=["area < 80"],
            channel_names=None,
        )

        # Collapse instance labels to a foreground/background mask.
        truth = self.label_true.copy().astype(float)
        truth[truth > 0] = 1.
        self.binary_label_true = truth

        prediction = label.copy().astype(float)
        prediction[prediction > 0] = 1.

        metric = BinaryIoU(target_class_ids=[1])
        metric.update_state(truth, prediction)
        score = metric.result().numpy()

        self.assertGreater(score, 0.7)
90
+
91
+
92
+ if __name__=="__main__":
93
+ unittest.main()
tests/test_signals.py ADDED
@@ -0,0 +1,135 @@
1
+ import unittest
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ import os
5
+ import random
6
+ import pandas as pd
7
+ import shutil
8
+
9
def sigmoid(t, t0, dt, A, offset):
    """Evaluate a logistic curve at ``t``.

    Computes ``A / (1 + exp(-(t - t0) / dt)) + offset``; ``t`` may be a scalar
    or a NumPy array, in which case the result is element-wise.
    """
    exponent = -(t - t0) / dt
    return A / (1 + np.exp(exponent)) + offset
11
+
12
def generate_fake_signal_data(n_signals):
    """Build a table of ``n_signals`` noisy sigmoid tracks with class labels.

    For every track an amplitude, slope, sigmoid centre, cut length and noise
    level are drawn at random; the noisy sigmoid is sampled on 100 time
    points and truncated to the cut length. The label columns are:

    * ``cclass`` 2 and ``t0`` -1 when the centre is <= 0,
    * ``cclass`` 0 and ``t0`` equal to the centre when 0 < centre <= 100,
    * ``cclass`` 1 and ``t0`` -1 otherwise, or whenever the centre is not
      before the cut.

    Returns a ``pandas.DataFrame`` with one row per (track, frame) and the
    columns TRACK_ID, POSITION_X, POSITION_Y, FRAME, signal, t0, cclass.
    """
    timeline = np.linspace(0, 100, 100)
    amplitudes = list(np.linspace(2000, 3000, 100))
    slopes = list(np.linspace(0.5, 5, 100))
    means = list(np.linspace(-100, 200, 100))
    random_cut = list(np.linspace(25, 200, 176, dtype=int))
    noise_levels = list(np.linspace(1, 100, 100, dtype=int))

    rows = []
    for track_id in range(n_signals):

        # The order of the random draws is kept so seeded runs are unchanged.
        amplitude = random.sample(amplitudes, k=1)[0]
        slope = random.sample(slopes, k=1)[0]
        centre = random.sample(means, k=1)[0]
        cut = random.sample(random_cut, k=1)[0]
        noise_level = random.sample(noise_levels, k=1)[0]

        if centre <= 0.:
            cclass, t0 = 2, -1
        elif (centre > 0) and (centre <= 100):
            cclass, t0 = 0, centre
        else:
            cclass, t0 = 1, -1

        noise = [random.random() * noise_level for _ in range(len(timeline))]
        signal = (sigmoid(timeline, centre, slope, amplitude, 0) + noise)[:cut]

        # The transition is not before the cut: override the label above.
        if centre >= cut:
            cclass, t0 = 1, -1

        rows.extend(
            {'TRACK_ID': track_id, 'POSITION_X': 0., 'POSITION_Y': 0., 'FRAME': frame, 'signal': value, 't0': t0, 'cclass': cclass}
            for frame, value in enumerate(signal)
        )

    return pd.DataFrame(rows)
53
+
54
def export_set(trajectories, name='set.npy', output_folder='.'):
    """Save the tracks of ``trajectories`` as a list of per-track dicts.

    Each dict maps every column of the table to that track's values (as a
    NumPy array) and adds two scalar entries taken from the track's first
    row: ``"time_of_interest"`` (column ``t0``) and ``"class"`` (column
    ``cclass``). The list is written with ``np.save`` to
    ``output_folder`` + ``os.sep`` + ``name``.
    """
    training_set = []
    track_ids = np.unique(trajectories["TRACK_ID"].to_numpy())

    for track_id in track_ids:
        track_rows = trajectories.loc[trajectories["TRACK_ID"] == track_id]
        entry = {column: track_rows[column].to_numpy() for column in trajectories.columns}
        entry["time_of_interest"] = track_rows["t0"].to_numpy()[0]
        entry["class"] = track_rows["cclass"].to_numpy()[0]
        training_set.append(entry)

    destination = os.sep.join([output_folder, name])
    np.save(destination, training_set)
70
+
71
+
72
class TestCreateSignalModel(unittest.TestCase):
    """Smoke test: a ``SignalDetectionModel`` can be constructed without raising."""

    def test_create_model(self):
        # Imported inside the test, as in the rest of this file, so importing
        # the module does not pull in the model code.
        from celldetective.signals import SignalDetectionModel

        model_settings = dict(
            channel_option=["signal"],
            model_signal_length=128,
            n_channels=1,
            n_conv=2,
            n_classes=3,
            dense_collection=512,
            dropout_rate=0.1,
            label='test',
        )
        model = SignalDetectionModel(**model_settings)
88
+
89
+
90
class TestTrainSignalModel(unittest.TestCase):
    """Smoke test: a ``SignalDetectionModel`` trains for one epoch on fake signals.

    ``setUpClass`` writes 300 synthetic tracks to ``temp/set.npy`` and builds
    the model. The ``temp`` directory is removed in ``tearDownClass`` so it
    does not outlive the run when training fails.
    """

    @classmethod
    def setUpClass(cls):
        from celldetective.signals import SignalDetectionModel

        cls.trajectories = generate_fake_signal_data(300)
        os.makedirs('temp', exist_ok=True)
        try:
            export_set(cls.trajectories, name='set.npy', output_folder='temp')
            cls.model = SignalDetectionModel(
                channel_option=["signal"],
                model_signal_length=128,
                n_channels=1,
                n_conv=2,
                n_classes=3,
                dense_collection=512,
                dropout_rate=0.1,
                label='test',
            )
        except Exception:
            # tearDownClass is not run when setUpClass raises, so clean up
            # here before propagating the error.
            shutil.rmtree('temp', ignore_errors=True)
            raise

    @classmethod
    def tearDownClass(cls):
        # Previously the directory was removed on the last line of the test
        # body, which is skipped whenever training raises.
        shutil.rmtree('temp', ignore_errors=True)

    def test_train_signal_model(self):
        self.model.fit_from_directory(
            ['temp'],
            normalize=True,
            normalization_percentile=None,
            normalization_values=None,
            normalization_clip=None,
            channel_option=["signal"],
            target_directory='temp',
            augment=False,
            model_name='None',
            validation_split=0.2,
            test_split=0.1,
            batch_size=16,
            epochs=1,
            recompile_pretrained=False,
            learning_rate=0.01,
            show_plots=False,
        )
133
+
134
+ if __name__=="__main__":
135
+ unittest.main()