learning-loop-node 0.10.12__py3-none-any.whl → 0.10.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of learning-loop-node might be problematic. Click here for more details.
- learning_loop_node/annotation/annotator_node.py +11 -10
- learning_loop_node/data_classes/detections.py +34 -25
- learning_loop_node/data_classes/general.py +27 -17
- learning_loop_node/data_exchanger.py +6 -5
- learning_loop_node/detector/detector_logic.py +10 -4
- learning_loop_node/detector/detector_node.py +80 -54
- learning_loop_node/detector/inbox_filter/relevance_filter.py +9 -3
- learning_loop_node/detector/outbox.py +8 -1
- learning_loop_node/detector/rest/about.py +34 -9
- learning_loop_node/detector/rest/backdoor_controls.py +10 -29
- learning_loop_node/detector/rest/detect.py +27 -19
- learning_loop_node/detector/rest/model_version_control.py +30 -13
- learning_loop_node/detector/rest/operation_mode.py +11 -5
- learning_loop_node/detector/rest/outbox_mode.py +7 -1
- learning_loop_node/helpers/log_conf.py +5 -0
- learning_loop_node/node.py +97 -49
- learning_loop_node/rest.py +55 -0
- learning_loop_node/tests/detector/conftest.py +36 -2
- learning_loop_node/tests/detector/test_client_communication.py +21 -19
- learning_loop_node/tests/detector/test_detector_node.py +86 -0
- learning_loop_node/tests/trainer/conftest.py +4 -4
- learning_loop_node/tests/trainer/states/test_state_detecting.py +8 -9
- learning_loop_node/tests/trainer/states/test_state_download_train_model.py +8 -8
- learning_loop_node/tests/trainer/states/test_state_prepare.py +6 -7
- learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py +21 -18
- learning_loop_node/tests/trainer/states/test_state_train.py +6 -8
- learning_loop_node/tests/trainer/states/test_state_upload_detections.py +7 -9
- learning_loop_node/tests/trainer/states/test_state_upload_model.py +7 -8
- learning_loop_node/tests/trainer/test_errors.py +2 -2
- learning_loop_node/trainer/io_helpers.py +3 -6
- learning_loop_node/trainer/rest/backdoor_controls.py +19 -40
- learning_loop_node/trainer/trainer_logic.py +4 -4
- learning_loop_node/trainer/trainer_logic_generic.py +15 -12
- learning_loop_node/trainer/trainer_node.py +5 -4
- {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/METADATA +16 -15
- {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/RECORD +37 -35
- {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/WHEEL +0 -0
|
@@ -12,6 +12,9 @@ import pytest
|
|
|
12
12
|
import socketio
|
|
13
13
|
import uvicorn
|
|
14
14
|
|
|
15
|
+
from learning_loop_node.data_classes import BoxDetection, Detections
|
|
16
|
+
from learning_loop_node.detector.detector_logic import DetectorLogic
|
|
17
|
+
|
|
15
18
|
from ...detector.detector_node import DetectorNode
|
|
16
19
|
from ...detector.outbox import Outbox
|
|
17
20
|
from ...globals import GLOBALS
|
|
@@ -38,8 +41,8 @@ def should_have_segmentations(request) -> bool:
|
|
|
38
41
|
async def test_detector_node():
|
|
39
42
|
"""Initializes and runs a detector testnode. Note that the running instance and the one the function returns are not the same instances!"""
|
|
40
43
|
|
|
41
|
-
os.environ['
|
|
42
|
-
os.environ['
|
|
44
|
+
os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
|
|
45
|
+
os.environ['LOOP_PROJECT'] = 'demo'
|
|
43
46
|
|
|
44
47
|
detector = TestingDetectorLogic()
|
|
45
48
|
node = DetectorNode(name='test', detector=detector)
|
|
@@ -113,6 +116,37 @@ def get_outbox_files(outbox: Outbox):
|
|
|
113
116
|
files = glob(f'{outbox.path}/**/*', recursive=True)
|
|
114
117
|
return [file for file in files if os.path.isfile(file)]
|
|
115
118
|
|
|
119
|
+
|
|
120
|
+
@pytest.fixture
|
|
121
|
+
def mock_detector_logic():
|
|
122
|
+
class MockDetectorLogic(DetectorLogic): # pylint: disable=abstract-method
|
|
123
|
+
def __init__(self):
|
|
124
|
+
super().__init__('mock')
|
|
125
|
+
self.detections = Detections(
|
|
126
|
+
box_detections=[BoxDetection(category_name="test",
|
|
127
|
+
category_id="1",
|
|
128
|
+
confidence=0.9,
|
|
129
|
+
x=0, y=0, width=10, height=10,
|
|
130
|
+
model_name="mock",
|
|
131
|
+
)]
|
|
132
|
+
)
|
|
133
|
+
|
|
134
|
+
@property
|
|
135
|
+
def is_initialized(self):
|
|
136
|
+
return True
|
|
137
|
+
|
|
138
|
+
def evaluate_with_all_info(self, image, tags, source): # pylint: disable=signature-differs
|
|
139
|
+
return self.detections
|
|
140
|
+
|
|
141
|
+
return MockDetectorLogic()
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
@pytest.fixture
|
|
145
|
+
def detector_node(mock_detector_logic):
|
|
146
|
+
os.environ['LOOP_ORGANIZATION'] = 'test_organization'
|
|
147
|
+
os.environ['LOOP_PROJECT'] = 'test_project'
|
|
148
|
+
return DetectorNode(name="test_node", detector=mock_detector_logic)
|
|
149
|
+
|
|
116
150
|
# ====================================== REDUNDANT FIXTURES IN ALL CONFTESTS ! ======================================
|
|
117
151
|
|
|
118
152
|
|
|
@@ -93,10 +93,10 @@ async def test_sio_upload(test_detector_node: DetectorNode, sio_client):
|
|
|
93
93
|
|
|
94
94
|
# NOTE: This test seems to be flaky.
|
|
95
95
|
async def test_about_endpoint(test_detector_node: DetectorNode):
|
|
96
|
-
await asyncio.sleep(
|
|
96
|
+
await asyncio.sleep(11)
|
|
97
97
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/about', timeout=30)
|
|
98
98
|
|
|
99
|
-
assert response.status_code == 200
|
|
99
|
+
assert response.status_code == 200, response.content
|
|
100
100
|
response_dict = json.loads(response.content)
|
|
101
101
|
assert response_dict['model_info']
|
|
102
102
|
model_information = ModelInformation.from_dict(response_dict['model_info'])
|
|
@@ -108,59 +108,60 @@ async def test_about_endpoint(test_detector_node: DetectorNode):
|
|
|
108
108
|
|
|
109
109
|
|
|
110
110
|
async def test_model_version_api(test_detector_node: DetectorNode):
|
|
111
|
-
await asyncio.sleep(
|
|
111
|
+
await asyncio.sleep(11)
|
|
112
112
|
|
|
113
113
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
114
|
-
assert response.status_code == 200
|
|
114
|
+
assert response.status_code == 200, response.content
|
|
115
115
|
response_dict = json.loads(response.content)
|
|
116
|
+
assert response_dict['version_control'] == 'follow_loop'
|
|
116
117
|
assert response_dict['current_version'] == '1.1'
|
|
117
118
|
assert response_dict['target_version'] == '1.1'
|
|
118
119
|
assert response_dict['loop_version'] == '1.1'
|
|
119
120
|
assert response_dict['local_versions'] == ['1.1']
|
|
120
|
-
assert response_dict['version_control'] == 'follow_loop'
|
|
121
121
|
|
|
122
122
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='1.0', timeout=30)
|
|
123
|
+
assert response.status_code == 200, response.content
|
|
123
124
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
124
|
-
assert response.status_code == 200
|
|
125
|
+
assert response.status_code == 200, response.content
|
|
125
126
|
response_dict = json.loads(response.content)
|
|
127
|
+
assert response_dict['version_control'] == 'specific_version'
|
|
126
128
|
assert response_dict['current_version'] == '1.1'
|
|
127
129
|
assert response_dict['target_version'] == '1.0'
|
|
128
130
|
assert response_dict['loop_version'] == '1.1'
|
|
129
131
|
assert response_dict['local_versions'] == ['1.1']
|
|
130
|
-
assert response_dict['version_control'] == 'specific_version'
|
|
131
132
|
|
|
132
133
|
await asyncio.sleep(11)
|
|
133
|
-
|
|
134
134
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
135
|
-
assert response.status_code == 200
|
|
135
|
+
assert response.status_code == 200, response.content
|
|
136
136
|
response_dict = json.loads(response.content)
|
|
137
|
+
assert response_dict['version_control'] == 'specific_version'
|
|
137
138
|
assert response_dict['current_version'] == '1.0'
|
|
138
139
|
assert response_dict['target_version'] == '1.0'
|
|
139
140
|
assert response_dict['loop_version'] == '1.1'
|
|
140
141
|
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
|
|
141
|
-
assert response_dict['version_control'] == 'specific_version'
|
|
142
142
|
|
|
143
143
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='pause', timeout=30)
|
|
144
|
+
assert response.status_code == 200, response.content
|
|
144
145
|
await asyncio.sleep(11)
|
|
145
146
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
146
|
-
assert response.status_code == 200
|
|
147
|
+
assert response.status_code == 200, response.content
|
|
147
148
|
response_dict = json.loads(response.content)
|
|
149
|
+
assert response_dict['version_control'] == 'pause'
|
|
148
150
|
assert response_dict['current_version'] == '1.0'
|
|
149
151
|
assert response_dict['target_version'] == '1.0'
|
|
150
152
|
assert response_dict['loop_version'] == '1.1'
|
|
151
153
|
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
|
|
152
|
-
assert response_dict['version_control'] == 'pause'
|
|
153
154
|
|
|
154
155
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='follow_loop', timeout=30)
|
|
155
156
|
await asyncio.sleep(11)
|
|
156
157
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
157
|
-
assert response.status_code == 200
|
|
158
|
+
assert response.status_code == 200, response.content
|
|
158
159
|
response_dict = json.loads(response.content)
|
|
160
|
+
assert response_dict['version_control'] == 'follow_loop'
|
|
159
161
|
assert response_dict['current_version'] == '1.1'
|
|
160
162
|
assert response_dict['target_version'] == '1.1'
|
|
161
163
|
assert response_dict['loop_version'] == '1.1'
|
|
162
164
|
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
|
|
163
|
-
assert response_dict['version_control'] == 'follow_loop'
|
|
164
165
|
|
|
165
166
|
|
|
166
167
|
async def test_rest_outbox_mode(test_detector_node: DetectorNode):
|
|
@@ -169,9 +170,9 @@ async def test_rest_outbox_mode(test_detector_node: DetectorNode):
|
|
|
169
170
|
def check_switch_to_mode(mode: str):
|
|
170
171
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/outbox_mode',
|
|
171
172
|
data=mode, timeout=30)
|
|
172
|
-
assert response.status_code == 200
|
|
173
|
+
assert response.status_code == 200, response.content
|
|
173
174
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=30)
|
|
174
|
-
assert response.status_code == 200
|
|
175
|
+
assert response.status_code == 200, response.content
|
|
175
176
|
assert response.content == mode.encode()
|
|
176
177
|
|
|
177
178
|
check_switch_to_mode('stopped')
|
|
@@ -185,7 +186,7 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
|
|
|
185
186
|
with open(test_image_path, 'rb') as f:
|
|
186
187
|
image_bytes = f.read()
|
|
187
188
|
|
|
188
|
-
for _ in range(
|
|
189
|
+
for _ in range(200):
|
|
189
190
|
test_detector_node.outbox.save(image_bytes)
|
|
190
191
|
|
|
191
192
|
outbox_size_early = len(get_outbox_files(test_detector_node.outbox))
|
|
@@ -193,8 +194,9 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
|
|
|
193
194
|
|
|
194
195
|
# check if api is still responsive
|
|
195
196
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=2)
|
|
196
|
-
assert response.status_code == 200
|
|
197
|
+
assert response.status_code == 200, response.content
|
|
197
198
|
|
|
198
199
|
await asyncio.sleep(5)
|
|
199
200
|
outbox_size_late = len(get_outbox_files(test_detector_node.outbox))
|
|
200
|
-
assert
|
|
201
|
+
assert outbox_size_late > 0, 'The outbox should not be fully cleared, maybe the node was too fast.'
|
|
202
|
+
assert outbox_size_early > outbox_size_late, 'The outbox should have been partially emptied.'
|
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
import numpy as np
|
|
2
|
+
import pytest
|
|
3
|
+
|
|
4
|
+
from learning_loop_node.detector.detector_node import DetectorNode
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
@pytest.mark.asyncio
|
|
8
|
+
async def test_get_detections(detector_node: DetectorNode, monkeypatch):
|
|
9
|
+
# Mock raw image data
|
|
10
|
+
raw_image = np.zeros((100, 100, 3), dtype=np.uint8)
|
|
11
|
+
|
|
12
|
+
# Mock relevance_filter and outbox
|
|
13
|
+
filtered_upload_called = False
|
|
14
|
+
save_called = False
|
|
15
|
+
|
|
16
|
+
save_args = []
|
|
17
|
+
|
|
18
|
+
def mock_filtered_upload(*args, **kwargs): # pylint: disable=unused-argument
|
|
19
|
+
nonlocal filtered_upload_called
|
|
20
|
+
filtered_upload_called = True
|
|
21
|
+
|
|
22
|
+
def mock_save(*args, **kwargs):
|
|
23
|
+
nonlocal save_called
|
|
24
|
+
nonlocal save_args
|
|
25
|
+
save_called = True
|
|
26
|
+
save_args = (args, kwargs)
|
|
27
|
+
|
|
28
|
+
monkeypatch.setattr(detector_node.relevance_filter, "may_upload_detections", mock_filtered_upload)
|
|
29
|
+
monkeypatch.setattr(detector_node.outbox, "save", mock_save)
|
|
30
|
+
|
|
31
|
+
# Test cases
|
|
32
|
+
test_cases = [
|
|
33
|
+
(None, True, False),
|
|
34
|
+
("filtered", True, False),
|
|
35
|
+
("all", False, True),
|
|
36
|
+
("disabled", False, False),
|
|
37
|
+
]
|
|
38
|
+
|
|
39
|
+
expected_save_args = {
|
|
40
|
+
'image': raw_image,
|
|
41
|
+
'detections': detector_node.detector_logic.detections, # type: ignore
|
|
42
|
+
'tags': ['test_tag'],
|
|
43
|
+
'source': 'test_source',
|
|
44
|
+
}
|
|
45
|
+
|
|
46
|
+
for autoupload, expect_filtered, expect_all in test_cases:
|
|
47
|
+
filtered_upload_called = False
|
|
48
|
+
save_called = False
|
|
49
|
+
|
|
50
|
+
result = await detector_node.get_detections(
|
|
51
|
+
raw_image=raw_image,
|
|
52
|
+
camera_id="test_camera",
|
|
53
|
+
tags=["test_tag"],
|
|
54
|
+
source="test_source",
|
|
55
|
+
autoupload=autoupload
|
|
56
|
+
)
|
|
57
|
+
|
|
58
|
+
# Check if detections were processed
|
|
59
|
+
assert result is not None
|
|
60
|
+
assert result.box_detections is not None
|
|
61
|
+
assert len(result.box_detections) == 1
|
|
62
|
+
assert result.box_detections[0].category_name == "test"
|
|
63
|
+
|
|
64
|
+
# Check if the correct upload method was called
|
|
65
|
+
assert filtered_upload_called == expect_filtered
|
|
66
|
+
assert save_called == expect_all
|
|
67
|
+
|
|
68
|
+
if save_called:
|
|
69
|
+
save_pos_args, save_kwargs = save_args # pylint: disable=unbalanced-tuple-unpacking
|
|
70
|
+
expected_values = list(expected_save_args.values())
|
|
71
|
+
assert len(save_pos_args) + len(save_kwargs) == len(expected_values)
|
|
72
|
+
|
|
73
|
+
# Check positional arguments
|
|
74
|
+
for arg, expected in zip(save_pos_args, expected_values[:len(save_pos_args)]):
|
|
75
|
+
if isinstance(arg, (list, np.ndarray)):
|
|
76
|
+
assert np.array_equal(arg, expected)
|
|
77
|
+
else:
|
|
78
|
+
assert arg == expected
|
|
79
|
+
|
|
80
|
+
# Check keyword arguments
|
|
81
|
+
for key, value in save_kwargs.items():
|
|
82
|
+
expected = expected_save_args[key]
|
|
83
|
+
if isinstance(value, (list, np.ndarray)):
|
|
84
|
+
assert np.array_equal(value, expected)
|
|
85
|
+
else:
|
|
86
|
+
assert value == expected
|
|
@@ -1,7 +1,6 @@
|
|
|
1
|
-
from ...globals import GLOBALS
|
|
2
|
-
import shutil
|
|
3
1
|
import logging
|
|
4
2
|
import os
|
|
3
|
+
import shutil
|
|
5
4
|
import socket
|
|
6
5
|
from multiprocessing import log_to_stderr
|
|
7
6
|
|
|
@@ -9,6 +8,7 @@ import icecream
|
|
|
9
8
|
import pytest
|
|
10
9
|
|
|
11
10
|
from ...data_classes import Context
|
|
11
|
+
from ...globals import GLOBALS
|
|
12
12
|
from ...trainer.trainer_node import TrainerNode
|
|
13
13
|
from .testing_trainer_logic import TestingTrainerLogic
|
|
14
14
|
|
|
@@ -23,8 +23,8 @@ icecream.install()
|
|
|
23
23
|
|
|
24
24
|
@pytest.fixture()
|
|
25
25
|
async def test_initialized_trainer_node():
|
|
26
|
-
os.environ['
|
|
27
|
-
os.environ['
|
|
26
|
+
os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
|
|
27
|
+
os.environ['LOOP_PROJECT'] = 'demo'
|
|
28
28
|
|
|
29
29
|
trainer = TestingTrainerLogic()
|
|
30
30
|
node = TrainerNode(name='test', trainer_logic=trainer, uuid='NOD30000-0000-0000-0000-000000000000')
|
|
@@ -7,11 +7,10 @@ from ..state_helper import assert_training_state, create_active_training_file
|
|
|
7
7
|
from ..testing_trainer_logic import TestingTrainerLogic
|
|
8
8
|
|
|
9
9
|
# pylint: disable=protected-access
|
|
10
|
-
error_key = 'detecting'
|
|
11
10
|
|
|
12
11
|
|
|
13
|
-
def
|
|
14
|
-
return trainer.errors.has_error_for(
|
|
12
|
+
def trainer_has_detecting_error(trainer: TrainerLogic):
|
|
13
|
+
return trainer.errors.has_error_for('detecting')
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic):
|
|
@@ -25,7 +24,7 @@ async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogi
|
|
|
25
24
|
await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
|
|
26
25
|
await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001)
|
|
27
26
|
|
|
28
|
-
assert
|
|
27
|
+
assert trainer_has_detecting_error(trainer) is False
|
|
29
28
|
assert trainer.training.training_state == TrainerState.Detected
|
|
30
29
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
31
30
|
assert trainer.active_training_io.detections_exist()
|
|
@@ -37,7 +36,7 @@ async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainer
|
|
|
37
36
|
trainer._init_from_last_training()
|
|
38
37
|
trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242'
|
|
39
38
|
|
|
40
|
-
|
|
39
|
+
trainer._begin_training_task()
|
|
41
40
|
|
|
42
41
|
await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001)
|
|
43
42
|
await trainer.stop()
|
|
@@ -54,13 +53,13 @@ async def test_model_not_downloadable_error(test_initialized_trainer: TestingTra
|
|
|
54
53
|
model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id
|
|
55
54
|
trainer._init_from_last_training()
|
|
56
55
|
|
|
57
|
-
|
|
56
|
+
trainer._begin_training_task()
|
|
58
57
|
|
|
59
|
-
await assert_training_state(trainer.training,
|
|
60
|
-
await assert_training_state(trainer.training,
|
|
58
|
+
await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
|
|
59
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelUploaded, timeout=5, interval=0.001)
|
|
61
60
|
await asyncio.sleep(0.1)
|
|
62
61
|
|
|
63
|
-
assert
|
|
62
|
+
assert trainer_has_detecting_error(trainer)
|
|
64
63
|
assert trainer.training.training_state == TrainerState.TrainModelUploaded
|
|
65
64
|
assert trainer.training.model_uuid_for_detecting == '00000000-0000-0000-0000-000000000000'
|
|
66
65
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -20,8 +20,8 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
|
|
|
20
20
|
trainer._perform_state('download_model',
|
|
21
21
|
TrainerState.TrainModelDownloading,
|
|
22
22
|
TrainerState.TrainModelDownloaded, trainer._download_model))
|
|
23
|
-
await assert_training_state(trainer.training,
|
|
24
|
-
await assert_training_state(trainer.training,
|
|
23
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
|
|
24
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloaded, timeout=10, interval=0.001)
|
|
25
25
|
|
|
26
26
|
assert trainer.training.training_state == TrainerState.TrainModelDownloaded
|
|
27
27
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -34,11 +34,11 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
|
|
|
34
34
|
|
|
35
35
|
async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogic):
|
|
36
36
|
trainer = test_initialized_trainer
|
|
37
|
-
create_active_training_file(trainer, training_state=
|
|
37
|
+
create_active_training_file(trainer, training_state=TrainerState.DataDownloaded)
|
|
38
38
|
trainer._init_from_last_training()
|
|
39
39
|
|
|
40
|
-
|
|
41
|
-
await assert_training_state(trainer.training,
|
|
40
|
+
trainer._begin_training_task()
|
|
41
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
|
|
42
42
|
|
|
43
43
|
await trainer.stop()
|
|
44
44
|
await asyncio.sleep(0.1)
|
|
@@ -53,9 +53,9 @@ async def test_downloading_failed(test_initialized_trainer: TestingTrainerLogic)
|
|
|
53
53
|
base_model_uuid_or_name='00000000-0000-0000-0000-000000000000') # bad model id)
|
|
54
54
|
trainer._init_from_last_training()
|
|
55
55
|
|
|
56
|
-
|
|
57
|
-
await assert_training_state(trainer.training,
|
|
58
|
-
await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=
|
|
56
|
+
trainer._begin_training_task()
|
|
57
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
|
|
58
|
+
await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=10, interval=0.001)
|
|
59
59
|
|
|
60
60
|
assert trainer.errors.has_error_for('download_model')
|
|
61
61
|
assert trainer._training is not None # pylint: disable=protected-access
|
|
@@ -6,11 +6,10 @@ from ..state_helper import assert_training_state, create_active_training_file
|
|
|
6
6
|
from ..testing_trainer_logic import TestingTrainerLogic
|
|
7
7
|
|
|
8
8
|
# pylint: disable=protected-access
|
|
9
|
-
error_key = 'prepare'
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
def
|
|
13
|
-
return trainer.errors.has_error_for(
|
|
11
|
+
def trainer_has_prepare_error(trainer: TrainerLogic):
|
|
12
|
+
return trainer.errors.has_error_for('prepare')
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerLogic):
|
|
@@ -19,7 +18,7 @@ async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerL
|
|
|
19
18
|
trainer._init_from_last_training()
|
|
20
19
|
|
|
21
20
|
await trainer._perform_state('prepare', TrainerState.DataDownloading, TrainerState.DataDownloaded, trainer._prepare)
|
|
22
|
-
assert
|
|
21
|
+
assert trainer_has_prepare_error(trainer) is False
|
|
23
22
|
assert trainer.training.training_state == TrainerState.DataDownloaded
|
|
24
23
|
assert trainer.training.data is not None
|
|
25
24
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -30,7 +29,7 @@ async def test_abort_preparing(test_initialized_trainer: TestingTrainerLogic):
|
|
|
30
29
|
create_active_training_file(trainer)
|
|
31
30
|
trainer._init_from_last_training()
|
|
32
31
|
|
|
33
|
-
|
|
32
|
+
trainer._begin_training_task()
|
|
34
33
|
await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=1, interval=0.001)
|
|
35
34
|
|
|
36
35
|
await trainer.stop()
|
|
@@ -48,9 +47,9 @@ async def test_request_error(test_initialized_trainer: TestingTrainerLogic):
|
|
|
48
47
|
|
|
49
48
|
_ = asyncio.get_running_loop().create_task(trainer._run())
|
|
50
49
|
await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=3, interval=0.001)
|
|
51
|
-
await assert_training_state(trainer.training, TrainerState.Initialized, timeout=
|
|
50
|
+
await assert_training_state(trainer.training, TrainerState.Initialized, timeout=10, interval=0.001)
|
|
52
51
|
|
|
53
|
-
assert
|
|
52
|
+
assert trainer_has_prepare_error(trainer)
|
|
54
53
|
assert trainer._training is not None # pylint: disable=protected-access
|
|
55
54
|
assert trainer.training.training_state == TrainerState.Initialized
|
|
56
55
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -11,11 +11,9 @@ from ..testing_trainer_logic import TestingTrainerLogic
|
|
|
11
11
|
|
|
12
12
|
# pylint: disable=protected-access
|
|
13
13
|
|
|
14
|
-
error_key = 'sync_confusion_matrix'
|
|
15
14
|
|
|
16
|
-
|
|
17
|
-
|
|
18
|
-
return trainer.errors.has_error_for(error_key)
|
|
15
|
+
def trainer_has_sync_confusion_matrix_error(trainer: TrainerLogic):
|
|
16
|
+
return trainer.errors.has_error_for('sync_confusion_matrix')
|
|
19
17
|
|
|
20
18
|
|
|
21
19
|
async def test_nothing_to_sync(test_initialized_trainer: TestingTrainerLogic):
|
|
@@ -26,10 +24,10 @@ async def test_nothing_to_sync(test_initialized_trainer: TestingTrainerLogic):
|
|
|
26
24
|
create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
|
|
27
25
|
trainer._init_from_last_training()
|
|
28
26
|
|
|
29
|
-
|
|
27
|
+
trainer._begin_training_task()
|
|
30
28
|
|
|
31
29
|
await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001)
|
|
32
|
-
assert
|
|
30
|
+
assert trainer_has_sync_confusion_matrix_error(trainer) is False
|
|
33
31
|
assert trainer.training.training_state == TrainerState.ConfusionMatrixSynced
|
|
34
32
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
35
33
|
|
|
@@ -38,16 +36,16 @@ async def test_unsynced_model_available__sync_successful(test_initialized_traine
|
|
|
38
36
|
trainer = test_initialized_trainer_node.trainer_logic
|
|
39
37
|
assert isinstance(trainer, TestingTrainerLogic)
|
|
40
38
|
|
|
41
|
-
await mock_socket_io_call(mocker, test_initialized_trainer_node, {'success': True})
|
|
39
|
+
await mock_socket_io_call(mocker, test_initialized_trainer_node, return_value={'success': True})
|
|
42
40
|
create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
|
|
43
41
|
|
|
44
42
|
trainer._init_from_last_training()
|
|
45
43
|
trainer.has_new_model = True
|
|
46
44
|
|
|
47
|
-
|
|
45
|
+
trainer._begin_training_task()
|
|
48
46
|
await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001)
|
|
49
47
|
|
|
50
|
-
assert
|
|
48
|
+
assert trainer_has_sync_confusion_matrix_error(trainer) is False
|
|
51
49
|
# assert trainer.training.training_state == TrainerState.ConfusionMatrixSynced
|
|
52
50
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
53
51
|
|
|
@@ -56,17 +54,19 @@ async def test_unsynced_model_available__sio_not_connected(test_initialized_trai
|
|
|
56
54
|
trainer = test_initialized_trainer_node.trainer_logic
|
|
57
55
|
assert isinstance(trainer, TestingTrainerLogic)
|
|
58
56
|
|
|
57
|
+
await test_initialized_trainer_node.sio_client.disconnect()
|
|
58
|
+
test_initialized_trainer_node.set_skip_repeat_loop(True)
|
|
59
59
|
create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
|
|
60
60
|
|
|
61
61
|
assert test_initialized_trainer_node.sio_client.connected is False
|
|
62
62
|
trainer.has_new_model = True
|
|
63
63
|
|
|
64
|
-
|
|
64
|
+
trainer._begin_training_task()
|
|
65
65
|
|
|
66
|
-
await assert_training_state(trainer.training,
|
|
67
|
-
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=
|
|
66
|
+
await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001)
|
|
67
|
+
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=10, interval=0.001)
|
|
68
68
|
|
|
69
|
-
assert
|
|
69
|
+
assert trainer_has_sync_confusion_matrix_error(trainer) # Due to sio not being connected, the request will fail
|
|
70
70
|
assert trainer.training.training_state == TrainerState.TrainingFinished
|
|
71
71
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
72
72
|
|
|
@@ -75,17 +75,17 @@ async def test_unsynced_model_available__request_is_not_successful(test_initiali
|
|
|
75
75
|
trainer = test_initialized_trainer_node.trainer_logic
|
|
76
76
|
assert isinstance(trainer, TestingTrainerLogic)
|
|
77
77
|
|
|
78
|
-
await mock_socket_io_call(mocker, test_initialized_trainer_node, {'success': False})
|
|
78
|
+
await mock_socket_io_call(mocker, test_initialized_trainer_node, return_value={'success': False})
|
|
79
79
|
|
|
80
80
|
create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
|
|
81
81
|
|
|
82
82
|
trainer.has_new_model = True
|
|
83
|
-
|
|
83
|
+
trainer._begin_training_task()
|
|
84
84
|
|
|
85
|
-
await assert_training_state(trainer.training,
|
|
86
|
-
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=
|
|
85
|
+
await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001)
|
|
86
|
+
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=10, interval=0.001)
|
|
87
87
|
|
|
88
|
-
assert
|
|
88
|
+
assert trainer_has_sync_confusion_matrix_error(trainer) # Due to sio call failure, the error will be set
|
|
89
89
|
assert trainer.training.training_state == TrainerState.TrainingFinished
|
|
90
90
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
91
91
|
|
|
@@ -100,6 +100,9 @@ async def test_basic_mock(test_initialized_trainer_node: TrainerNode, mocker: Mo
|
|
|
100
100
|
|
|
101
101
|
|
|
102
102
|
async def mock_socket_io_call(mocker, trainer_node: TrainerNode, return_value):
|
|
103
|
+
'''
|
|
104
|
+
Patch the socketio call function to always return the given return_value
|
|
105
|
+
'''
|
|
103
106
|
for _ in range(10):
|
|
104
107
|
if trainer_node.sio_client is None:
|
|
105
108
|
await asyncio.sleep(0.1)
|
|
@@ -14,10 +14,10 @@ async def test_successful_training(test_initialized_trainer: TestingTrainerLogic
|
|
|
14
14
|
create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
|
|
15
15
|
trainer._init_from_last_training()
|
|
16
16
|
|
|
17
|
-
|
|
17
|
+
trainer._begin_training_task()
|
|
18
18
|
|
|
19
19
|
await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
|
|
20
|
-
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=
|
|
20
|
+
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=10, interval=0.01)
|
|
21
21
|
assert trainer.start_training_task is not None
|
|
22
22
|
|
|
23
23
|
assert trainer._executor is not None
|
|
@@ -34,16 +34,15 @@ async def test_stop_running_training(test_initialized_trainer: TestingTrainerLog
|
|
|
34
34
|
create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
|
|
35
35
|
trainer._init_from_last_training()
|
|
36
36
|
|
|
37
|
-
|
|
37
|
+
trainer._begin_training_task()
|
|
38
38
|
|
|
39
39
|
await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
|
|
40
|
-
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=
|
|
40
|
+
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=10, interval=0.01)
|
|
41
41
|
assert trainer.start_training_task is not None
|
|
42
42
|
|
|
43
43
|
await trainer.stop()
|
|
44
44
|
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=2, interval=0.01)
|
|
45
45
|
|
|
46
|
-
assert trainer.training.training_state == TrainerState.TrainingFinished
|
|
47
46
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
48
47
|
|
|
49
48
|
|
|
@@ -55,15 +54,14 @@ async def test_training_can_maybe_resumed(test_initialized_trainer: TestingTrain
|
|
|
55
54
|
trainer._init_from_last_training()
|
|
56
55
|
trainer._can_resume_flag = True
|
|
57
56
|
|
|
58
|
-
|
|
57
|
+
trainer._begin_training_task()
|
|
59
58
|
|
|
60
59
|
await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
|
|
61
|
-
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=
|
|
60
|
+
await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=10, interval=0.001)
|
|
62
61
|
assert trainer.start_training_task is not None
|
|
63
62
|
|
|
64
63
|
assert trainer._executor is not None
|
|
65
64
|
await trainer._executor.stop_and_wait() # NOTE normally a training terminates itself e.g
|
|
66
65
|
await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001)
|
|
67
66
|
|
|
68
|
-
assert trainer.training.training_state == TrainerState.TrainingFinished
|
|
69
67
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -11,11 +11,10 @@ from ..state_helper import assert_training_state, create_active_training_file
|
|
|
11
11
|
from ..testing_trainer_logic import TestingTrainerLogic
|
|
12
12
|
|
|
13
13
|
# pylint: disable=protected-access
|
|
14
|
-
error_key = 'upload_detections'
|
|
15
14
|
|
|
16
15
|
|
|
17
|
-
def
|
|
18
|
-
return trainer.errors.has_error_for(
|
|
16
|
+
def trainer_has_upload_detections_error(trainer: TrainerLogic):
|
|
17
|
+
return trainer.errors.has_error_for('upload_detections')
|
|
19
18
|
|
|
20
19
|
|
|
21
20
|
async def create_valid_detection_file(trainer: TrainerLogic, number_of_entries: int = 1, file_index: int = 0):
|
|
@@ -125,12 +124,11 @@ async def test_bad_status_from_LearningLoop(test_initialized_trainer: TestingTra
|
|
|
125
124
|
trainer._init_from_last_training()
|
|
126
125
|
trainer.active_training_io.save_detections([get_dummy_detections()])
|
|
127
126
|
|
|
128
|
-
|
|
127
|
+
trainer._begin_training_task()
|
|
129
128
|
await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001)
|
|
130
|
-
await assert_training_state(trainer.training, TrainerState.Detected, timeout=
|
|
129
|
+
await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001)
|
|
131
130
|
|
|
132
|
-
assert
|
|
133
|
-
assert trainer.training.training_state == TrainerState.Detected
|
|
131
|
+
assert trainer_has_upload_detections_error(trainer)
|
|
134
132
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
135
133
|
|
|
136
134
|
|
|
@@ -143,7 +141,7 @@ async def test_go_to_cleanup_if_no_detections_exist(test_initialized_trainer: Te
|
|
|
143
141
|
create_active_training_file(trainer, training_state=TrainerState.Detected)
|
|
144
142
|
trainer._init_from_last_training()
|
|
145
143
|
|
|
146
|
-
|
|
144
|
+
trainer._begin_training_task()
|
|
147
145
|
await assert_training_state(trainer.training, TrainerState.ReadyForCleanup, timeout=1, interval=0.001)
|
|
148
146
|
|
|
149
147
|
|
|
@@ -154,7 +152,7 @@ async def test_abort_uploading(test_initialized_trainer: TestingTrainerLogic):
|
|
|
154
152
|
trainer._init_from_last_training()
|
|
155
153
|
await create_valid_detection_file(trainer)
|
|
156
154
|
|
|
157
|
-
|
|
155
|
+
trainer._begin_training_task()
|
|
158
156
|
|
|
159
157
|
await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001)
|
|
160
158
|
|