learning-loop-node 0.10.12__py3-none-any.whl → 0.10.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of learning-loop-node might be problematic. See the package's registry page for more details.

Files changed (37):
  1. learning_loop_node/annotation/annotator_node.py +11 -10
  2. learning_loop_node/data_classes/detections.py +34 -25
  3. learning_loop_node/data_classes/general.py +27 -17
  4. learning_loop_node/data_exchanger.py +6 -5
  5. learning_loop_node/detector/detector_logic.py +10 -4
  6. learning_loop_node/detector/detector_node.py +80 -54
  7. learning_loop_node/detector/inbox_filter/relevance_filter.py +9 -3
  8. learning_loop_node/detector/outbox.py +8 -1
  9. learning_loop_node/detector/rest/about.py +34 -9
  10. learning_loop_node/detector/rest/backdoor_controls.py +10 -29
  11. learning_loop_node/detector/rest/detect.py +27 -19
  12. learning_loop_node/detector/rest/model_version_control.py +30 -13
  13. learning_loop_node/detector/rest/operation_mode.py +11 -5
  14. learning_loop_node/detector/rest/outbox_mode.py +7 -1
  15. learning_loop_node/helpers/log_conf.py +5 -0
  16. learning_loop_node/node.py +97 -49
  17. learning_loop_node/rest.py +55 -0
  18. learning_loop_node/tests/detector/conftest.py +36 -2
  19. learning_loop_node/tests/detector/test_client_communication.py +21 -19
  20. learning_loop_node/tests/detector/test_detector_node.py +86 -0
  21. learning_loop_node/tests/trainer/conftest.py +4 -4
  22. learning_loop_node/tests/trainer/states/test_state_detecting.py +8 -9
  23. learning_loop_node/tests/trainer/states/test_state_download_train_model.py +8 -8
  24. learning_loop_node/tests/trainer/states/test_state_prepare.py +6 -7
  25. learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py +21 -18
  26. learning_loop_node/tests/trainer/states/test_state_train.py +6 -8
  27. learning_loop_node/tests/trainer/states/test_state_upload_detections.py +7 -9
  28. learning_loop_node/tests/trainer/states/test_state_upload_model.py +7 -8
  29. learning_loop_node/tests/trainer/test_errors.py +2 -2
  30. learning_loop_node/trainer/io_helpers.py +3 -6
  31. learning_loop_node/trainer/rest/backdoor_controls.py +19 -40
  32. learning_loop_node/trainer/trainer_logic.py +4 -4
  33. learning_loop_node/trainer/trainer_logic_generic.py +15 -12
  34. learning_loop_node/trainer/trainer_node.py +5 -4
  35. {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/METADATA +16 -15
  36. {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/RECORD +37 -35
  37. {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/WHEEL +0 -0
@@ -12,6 +12,9 @@ import pytest
12
12
  import socketio
13
13
  import uvicorn
14
14
 
15
+ from learning_loop_node.data_classes import BoxDetection, Detections
16
+ from learning_loop_node.detector.detector_logic import DetectorLogic
17
+
15
18
  from ...detector.detector_node import DetectorNode
16
19
  from ...detector.outbox import Outbox
17
20
  from ...globals import GLOBALS
@@ -38,8 +41,8 @@ def should_have_segmentations(request) -> bool:
38
41
  async def test_detector_node():
39
42
  """Initializes and runs a detector testnode. Note that the running instance and the one the function returns are not the same instances!"""
40
43
 
41
- os.environ['ORGANIZATION'] = 'zauberzeug'
42
- os.environ['PROJECT'] = 'demo'
44
+ os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
45
+ os.environ['LOOP_PROJECT'] = 'demo'
43
46
 
44
47
  detector = TestingDetectorLogic()
45
48
  node = DetectorNode(name='test', detector=detector)
@@ -113,6 +116,37 @@ def get_outbox_files(outbox: Outbox):
113
116
  files = glob(f'{outbox.path}/**/*', recursive=True)
114
117
  return [file for file in files if os.path.isfile(file)]
115
118
 
119
+
120
+ @pytest.fixture
121
+ def mock_detector_logic():
122
+ class MockDetectorLogic(DetectorLogic): # pylint: disable=abstract-method
123
+ def __init__(self):
124
+ super().__init__('mock')
125
+ self.detections = Detections(
126
+ box_detections=[BoxDetection(category_name="test",
127
+ category_id="1",
128
+ confidence=0.9,
129
+ x=0, y=0, width=10, height=10,
130
+ model_name="mock",
131
+ )]
132
+ )
133
+
134
+ @property
135
+ def is_initialized(self):
136
+ return True
137
+
138
+ def evaluate_with_all_info(self, image, tags, source): # pylint: disable=signature-differs
139
+ return self.detections
140
+
141
+ return MockDetectorLogic()
142
+
143
+
144
+ @pytest.fixture
145
+ def detector_node(mock_detector_logic):
146
+ os.environ['LOOP_ORGANIZATION'] = 'test_organization'
147
+ os.environ['LOOP_PROJECT'] = 'test_project'
148
+ return DetectorNode(name="test_node", detector=mock_detector_logic)
149
+
116
150
  # ====================================== REDUNDANT FIXTURES IN ALL CONFTESTS ! ======================================
117
151
 
118
152
 
@@ -93,10 +93,10 @@ async def test_sio_upload(test_detector_node: DetectorNode, sio_client):
93
93
 
94
94
  # NOTE: This test seems to be flaky.
95
95
  async def test_about_endpoint(test_detector_node: DetectorNode):
96
- await asyncio.sleep(3)
96
+ await asyncio.sleep(11)
97
97
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/about', timeout=30)
98
98
 
99
- assert response.status_code == 200
99
+ assert response.status_code == 200, response.content
100
100
  response_dict = json.loads(response.content)
101
101
  assert response_dict['model_info']
102
102
  model_information = ModelInformation.from_dict(response_dict['model_info'])
@@ -108,59 +108,60 @@ async def test_about_endpoint(test_detector_node: DetectorNode):
108
108
 
109
109
 
110
110
  async def test_model_version_api(test_detector_node: DetectorNode):
111
- await asyncio.sleep(3)
111
+ await asyncio.sleep(11)
112
112
 
113
113
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
114
- assert response.status_code == 200
114
+ assert response.status_code == 200, response.content
115
115
  response_dict = json.loads(response.content)
116
+ assert response_dict['version_control'] == 'follow_loop'
116
117
  assert response_dict['current_version'] == '1.1'
117
118
  assert response_dict['target_version'] == '1.1'
118
119
  assert response_dict['loop_version'] == '1.1'
119
120
  assert response_dict['local_versions'] == ['1.1']
120
- assert response_dict['version_control'] == 'follow_loop'
121
121
 
122
122
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='1.0', timeout=30)
123
+ assert response.status_code == 200, response.content
123
124
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
124
- assert response.status_code == 200
125
+ assert response.status_code == 200, response.content
125
126
  response_dict = json.loads(response.content)
127
+ assert response_dict['version_control'] == 'specific_version'
126
128
  assert response_dict['current_version'] == '1.1'
127
129
  assert response_dict['target_version'] == '1.0'
128
130
  assert response_dict['loop_version'] == '1.1'
129
131
  assert response_dict['local_versions'] == ['1.1']
130
- assert response_dict['version_control'] == 'specific_version'
131
132
 
132
133
  await asyncio.sleep(11)
133
-
134
134
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
135
- assert response.status_code == 200
135
+ assert response.status_code == 200, response.content
136
136
  response_dict = json.loads(response.content)
137
+ assert response_dict['version_control'] == 'specific_version'
137
138
  assert response_dict['current_version'] == '1.0'
138
139
  assert response_dict['target_version'] == '1.0'
139
140
  assert response_dict['loop_version'] == '1.1'
140
141
  assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
141
- assert response_dict['version_control'] == 'specific_version'
142
142
 
143
143
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='pause', timeout=30)
144
+ assert response.status_code == 200, response.content
144
145
  await asyncio.sleep(11)
145
146
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
146
- assert response.status_code == 200
147
+ assert response.status_code == 200, response.content
147
148
  response_dict = json.loads(response.content)
149
+ assert response_dict['version_control'] == 'pause'
148
150
  assert response_dict['current_version'] == '1.0'
149
151
  assert response_dict['target_version'] == '1.0'
150
152
  assert response_dict['loop_version'] == '1.1'
151
153
  assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
152
- assert response_dict['version_control'] == 'pause'
153
154
 
154
155
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='follow_loop', timeout=30)
155
156
  await asyncio.sleep(11)
156
157
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
157
- assert response.status_code == 200
158
+ assert response.status_code == 200, response.content
158
159
  response_dict = json.loads(response.content)
160
+ assert response_dict['version_control'] == 'follow_loop'
159
161
  assert response_dict['current_version'] == '1.1'
160
162
  assert response_dict['target_version'] == '1.1'
161
163
  assert response_dict['loop_version'] == '1.1'
162
164
  assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
163
- assert response_dict['version_control'] == 'follow_loop'
164
165
 
165
166
 
166
167
  async def test_rest_outbox_mode(test_detector_node: DetectorNode):
@@ -169,9 +170,9 @@ async def test_rest_outbox_mode(test_detector_node: DetectorNode):
169
170
  def check_switch_to_mode(mode: str):
170
171
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/outbox_mode',
171
172
  data=mode, timeout=30)
172
- assert response.status_code == 200
173
+ assert response.status_code == 200, response.content
173
174
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=30)
174
- assert response.status_code == 200
175
+ assert response.status_code == 200, response.content
175
176
  assert response.content == mode.encode()
176
177
 
177
178
  check_switch_to_mode('stopped')
@@ -185,7 +186,7 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
185
186
  with open(test_image_path, 'rb') as f:
186
187
  image_bytes = f.read()
187
188
 
188
- for _ in range(100):
189
+ for _ in range(200):
189
190
  test_detector_node.outbox.save(image_bytes)
190
191
 
191
192
  outbox_size_early = len(get_outbox_files(test_detector_node.outbox))
@@ -193,8 +194,9 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
193
194
 
194
195
  # check if api is still responsive
195
196
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=2)
196
- assert response.status_code == 200
197
+ assert response.status_code == 200, response.content
197
198
 
198
199
  await asyncio.sleep(5)
199
200
  outbox_size_late = len(get_outbox_files(test_detector_node.outbox))
200
- assert outbox_size_early > outbox_size_late > 0, 'The outbox should have been partially emptied.'
201
+ assert outbox_size_late > 0, 'The outbox should not be fully cleared, maybe the node was too fast.'
202
+ assert outbox_size_early > outbox_size_late, 'The outbox should have been partially emptied.'
@@ -0,0 +1,86 @@
1
+ import numpy as np
2
+ import pytest
3
+
4
+ from learning_loop_node.detector.detector_node import DetectorNode
5
+
6
+
7
+ @pytest.mark.asyncio
8
+ async def test_get_detections(detector_node: DetectorNode, monkeypatch):
9
+ # Mock raw image data
10
+ raw_image = np.zeros((100, 100, 3), dtype=np.uint8)
11
+
12
+ # Mock relevance_filter and outbox
13
+ filtered_upload_called = False
14
+ save_called = False
15
+
16
+ save_args = []
17
+
18
+ def mock_filtered_upload(*args, **kwargs): # pylint: disable=unused-argument
19
+ nonlocal filtered_upload_called
20
+ filtered_upload_called = True
21
+
22
+ def mock_save(*args, **kwargs):
23
+ nonlocal save_called
24
+ nonlocal save_args
25
+ save_called = True
26
+ save_args = (args, kwargs)
27
+
28
+ monkeypatch.setattr(detector_node.relevance_filter, "may_upload_detections", mock_filtered_upload)
29
+ monkeypatch.setattr(detector_node.outbox, "save", mock_save)
30
+
31
+ # Test cases
32
+ test_cases = [
33
+ (None, True, False),
34
+ ("filtered", True, False),
35
+ ("all", False, True),
36
+ ("disabled", False, False),
37
+ ]
38
+
39
+ expected_save_args = {
40
+ 'image': raw_image,
41
+ 'detections': detector_node.detector_logic.detections, # type: ignore
42
+ 'tags': ['test_tag'],
43
+ 'source': 'test_source',
44
+ }
45
+
46
+ for autoupload, expect_filtered, expect_all in test_cases:
47
+ filtered_upload_called = False
48
+ save_called = False
49
+
50
+ result = await detector_node.get_detections(
51
+ raw_image=raw_image,
52
+ camera_id="test_camera",
53
+ tags=["test_tag"],
54
+ source="test_source",
55
+ autoupload=autoupload
56
+ )
57
+
58
+ # Check if detections were processed
59
+ assert result is not None
60
+ assert result.box_detections is not None
61
+ assert len(result.box_detections) == 1
62
+ assert result.box_detections[0].category_name == "test"
63
+
64
+ # Check if the correct upload method was called
65
+ assert filtered_upload_called == expect_filtered
66
+ assert save_called == expect_all
67
+
68
+ if save_called:
69
+ save_pos_args, save_kwargs = save_args # pylint: disable=unbalanced-tuple-unpacking
70
+ expected_values = list(expected_save_args.values())
71
+ assert len(save_pos_args) + len(save_kwargs) == len(expected_values)
72
+
73
+ # Check positional arguments
74
+ for arg, expected in zip(save_pos_args, expected_values[:len(save_pos_args)]):
75
+ if isinstance(arg, (list, np.ndarray)):
76
+ assert np.array_equal(arg, expected)
77
+ else:
78
+ assert arg == expected
79
+
80
+ # Check keyword arguments
81
+ for key, value in save_kwargs.items():
82
+ expected = expected_save_args[key]
83
+ if isinstance(value, (list, np.ndarray)):
84
+ assert np.array_equal(value, expected)
85
+ else:
86
+ assert value == expected
@@ -1,7 +1,6 @@
1
- from ...globals import GLOBALS
2
- import shutil
3
1
  import logging
4
2
  import os
3
+ import shutil
5
4
  import socket
6
5
  from multiprocessing import log_to_stderr
7
6
 
@@ -9,6 +8,7 @@ import icecream
9
8
  import pytest
10
9
 
11
10
  from ...data_classes import Context
11
+ from ...globals import GLOBALS
12
12
  from ...trainer.trainer_node import TrainerNode
13
13
  from .testing_trainer_logic import TestingTrainerLogic
14
14
 
@@ -23,8 +23,8 @@ icecream.install()
23
23
 
24
24
  @pytest.fixture()
25
25
  async def test_initialized_trainer_node():
26
- os.environ['ORGANIZATION'] = 'zauberzeug'
27
- os.environ['PROJECT'] = 'demo'
26
+ os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
27
+ os.environ['LOOP_PROJECT'] = 'demo'
28
28
 
29
29
  trainer = TestingTrainerLogic()
30
30
  node = TrainerNode(name='test', trainer_logic=trainer, uuid='NOD30000-0000-0000-0000-000000000000')
@@ -7,11 +7,10 @@ from ..state_helper import assert_training_state, create_active_training_file
7
7
  from ..testing_trainer_logic import TestingTrainerLogic
8
8
 
9
9
  # pylint: disable=protected-access
10
- error_key = 'detecting'
11
10
 
12
11
 
13
- def trainer_has_error(trainer: TrainerLogic):
14
- return trainer.errors.has_error_for(error_key)
12
+ def trainer_has_detecting_error(trainer: TrainerLogic):
13
+ return trainer.errors.has_error_for('detecting')
15
14
 
16
15
 
17
16
  async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic):
@@ -25,7 +24,7 @@ async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogi
25
24
  await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
26
25
  await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001)
27
26
 
28
- assert trainer_has_error(trainer) is False
27
+ assert trainer_has_detecting_error(trainer) is False
29
28
  assert trainer.training.training_state == TrainerState.Detected
30
29
  assert trainer.node.last_training_io.load() == trainer.training
31
30
  assert trainer.active_training_io.detections_exist()
@@ -37,7 +36,7 @@ async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainer
37
36
  trainer._init_from_last_training()
38
37
  trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242'
39
38
 
40
- _ = asyncio.get_running_loop().create_task(trainer._run())
39
+ trainer._begin_training_task()
41
40
 
42
41
  await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001)
43
42
  await trainer.stop()
@@ -54,13 +53,13 @@ async def test_model_not_downloadable_error(test_initialized_trainer: TestingTra
54
53
  model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id
55
54
  trainer._init_from_last_training()
56
55
 
57
- _ = asyncio.get_running_loop().create_task(trainer._run())
56
+ trainer._begin_training_task()
58
57
 
59
- await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001)
60
- await assert_training_state(trainer.training, 'train_model_uploaded', timeout=1, interval=0.001)
58
+ await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
59
+ await assert_training_state(trainer.training, TrainerState.TrainModelUploaded, timeout=5, interval=0.001)
61
60
  await asyncio.sleep(0.1)
62
61
 
63
- assert trainer_has_error(trainer)
62
+ assert trainer_has_detecting_error(trainer)
64
63
  assert trainer.training.training_state == TrainerState.TrainModelUploaded
65
64
  assert trainer.training.model_uuid_for_detecting == '00000000-0000-0000-0000-000000000000'
66
65
  assert trainer.node.last_training_io.load() == trainer.training
@@ -20,8 +20,8 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
20
20
  trainer._perform_state('download_model',
21
21
  TrainerState.TrainModelDownloading,
22
22
  TrainerState.TrainModelDownloaded, trainer._download_model))
23
- await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001)
24
- await assert_training_state(trainer.training, 'train_model_downloaded', timeout=1, interval=0.001)
23
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
24
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloaded, timeout=10, interval=0.001)
25
25
 
26
26
  assert trainer.training.training_state == TrainerState.TrainModelDownloaded
27
27
  assert trainer.node.last_training_io.load() == trainer.training
@@ -34,11 +34,11 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
34
34
 
35
35
  async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogic):
36
36
  trainer = test_initialized_trainer
37
- create_active_training_file(trainer, training_state='data_downloaded')
37
+ create_active_training_file(trainer, training_state=TrainerState.DataDownloaded)
38
38
  trainer._init_from_last_training()
39
39
 
40
- _ = asyncio.get_running_loop().create_task(trainer._run())
41
- await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001)
40
+ trainer._begin_training_task()
41
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
42
42
 
43
43
  await trainer.stop()
44
44
  await asyncio.sleep(0.1)
@@ -53,9 +53,9 @@ async def test_downloading_failed(test_initialized_trainer: TestingTrainerLogic)
53
53
  base_model_uuid_or_name='00000000-0000-0000-0000-000000000000') # bad model id)
54
54
  trainer._init_from_last_training()
55
55
 
56
- _ = asyncio.get_running_loop().create_task(trainer._run())
57
- await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001)
58
- await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=1, interval=0.001)
56
+ trainer._begin_training_task()
57
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
58
+ await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=10, interval=0.001)
59
59
 
60
60
  assert trainer.errors.has_error_for('download_model')
61
61
  assert trainer._training is not None # pylint: disable=protected-access
@@ -6,11 +6,10 @@ from ..state_helper import assert_training_state, create_active_training_file
6
6
  from ..testing_trainer_logic import TestingTrainerLogic
7
7
 
8
8
  # pylint: disable=protected-access
9
- error_key = 'prepare'
10
9
 
11
10
 
12
- def trainer_has_error(trainer: TrainerLogic):
13
- return trainer.errors.has_error_for(error_key)
11
+ def trainer_has_prepare_error(trainer: TrainerLogic):
12
+ return trainer.errors.has_error_for('prepare')
14
13
 
15
14
 
16
15
  async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerLogic):
@@ -19,7 +18,7 @@ async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerL
19
18
  trainer._init_from_last_training()
20
19
 
21
20
  await trainer._perform_state('prepare', TrainerState.DataDownloading, TrainerState.DataDownloaded, trainer._prepare)
22
- assert trainer_has_error(trainer) is False
21
+ assert trainer_has_prepare_error(trainer) is False
23
22
  assert trainer.training.training_state == TrainerState.DataDownloaded
24
23
  assert trainer.training.data is not None
25
24
  assert trainer.node.last_training_io.load() == trainer.training
@@ -30,7 +29,7 @@ async def test_abort_preparing(test_initialized_trainer: TestingTrainerLogic):
30
29
  create_active_training_file(trainer)
31
30
  trainer._init_from_last_training()
32
31
 
33
- _ = asyncio.get_running_loop().create_task(trainer._run())
32
+ trainer._begin_training_task()
34
33
  await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=1, interval=0.001)
35
34
 
36
35
  await trainer.stop()
@@ -48,9 +47,9 @@ async def test_request_error(test_initialized_trainer: TestingTrainerLogic):
48
47
 
49
48
  _ = asyncio.get_running_loop().create_task(trainer._run())
50
49
  await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=3, interval=0.001)
51
- await assert_training_state(trainer.training, TrainerState.Initialized, timeout=3, interval=0.001)
50
+ await assert_training_state(trainer.training, TrainerState.Initialized, timeout=10, interval=0.001)
52
51
 
53
- assert trainer_has_error(trainer)
52
+ assert trainer_has_prepare_error(trainer)
54
53
  assert trainer._training is not None # pylint: disable=protected-access
55
54
  assert trainer.training.training_state == TrainerState.Initialized
56
55
  assert trainer.node.last_training_io.load() == trainer.training
@@ -11,11 +11,9 @@ from ..testing_trainer_logic import TestingTrainerLogic
11
11
 
12
12
  # pylint: disable=protected-access
13
13
 
14
- error_key = 'sync_confusion_matrix'
15
14
 
16
-
17
- def trainer_has_error(trainer: TrainerLogic):
18
- return trainer.errors.has_error_for(error_key)
15
+ def trainer_has_sync_confusion_matrix_error(trainer: TrainerLogic):
16
+ return trainer.errors.has_error_for('sync_confusion_matrix')
19
17
 
20
18
 
21
19
  async def test_nothing_to_sync(test_initialized_trainer: TestingTrainerLogic):
@@ -26,10 +24,10 @@ async def test_nothing_to_sync(test_initialized_trainer: TestingTrainerLogic):
26
24
  create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
27
25
  trainer._init_from_last_training()
28
26
 
29
- _ = asyncio.get_running_loop().create_task(trainer._run())
27
+ trainer._begin_training_task()
30
28
 
31
29
  await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001)
32
- assert trainer_has_error(trainer) is False
30
+ assert trainer_has_sync_confusion_matrix_error(trainer) is False
33
31
  assert trainer.training.training_state == TrainerState.ConfusionMatrixSynced
34
32
  assert trainer.node.last_training_io.load() == trainer.training
35
33
 
@@ -38,16 +36,16 @@ async def test_unsynced_model_available__sync_successful(test_initialized_traine
38
36
  trainer = test_initialized_trainer_node.trainer_logic
39
37
  assert isinstance(trainer, TestingTrainerLogic)
40
38
 
41
- await mock_socket_io_call(mocker, test_initialized_trainer_node, {'success': True})
39
+ await mock_socket_io_call(mocker, test_initialized_trainer_node, return_value={'success': True})
42
40
  create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
43
41
 
44
42
  trainer._init_from_last_training()
45
43
  trainer.has_new_model = True
46
44
 
47
- _ = asyncio.get_running_loop().create_task(trainer._run())
45
+ trainer._begin_training_task()
48
46
  await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSynced, timeout=1, interval=0.001)
49
47
 
50
- assert trainer_has_error(trainer) is False
48
+ assert trainer_has_sync_confusion_matrix_error(trainer) is False
51
49
  # assert trainer.training.training_state == TrainerState.ConfusionMatrixSynced
52
50
  assert trainer.node.last_training_io.load() == trainer.training
53
51
 
@@ -56,17 +54,19 @@ async def test_unsynced_model_available__sio_not_connected(test_initialized_trai
56
54
  trainer = test_initialized_trainer_node.trainer_logic
57
55
  assert isinstance(trainer, TestingTrainerLogic)
58
56
 
57
+ await test_initialized_trainer_node.sio_client.disconnect()
58
+ test_initialized_trainer_node.set_skip_repeat_loop(True)
59
59
  create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
60
60
 
61
61
  assert test_initialized_trainer_node.sio_client.connected is False
62
62
  trainer.has_new_model = True
63
63
 
64
- _ = asyncio.get_running_loop().create_task(trainer._run())
64
+ trainer._begin_training_task()
65
65
 
66
- await assert_training_state(trainer.training, 'confusion_matrix_syncing', timeout=1, interval=0.001)
67
- await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001)
66
+ await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001)
67
+ await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=10, interval=0.001)
68
68
 
69
- assert trainer_has_error(trainer)
69
+ assert trainer_has_sync_confusion_matrix_error(trainer) # Due to sio not being connected, the request will fail
70
70
  assert trainer.training.training_state == TrainerState.TrainingFinished
71
71
  assert trainer.node.last_training_io.load() == trainer.training
72
72
 
@@ -75,17 +75,17 @@ async def test_unsynced_model_available__request_is_not_successful(test_initiali
75
75
  trainer = test_initialized_trainer_node.trainer_logic
76
76
  assert isinstance(trainer, TestingTrainerLogic)
77
77
 
78
- await mock_socket_io_call(mocker, test_initialized_trainer_node, {'success': False})
78
+ await mock_socket_io_call(mocker, test_initialized_trainer_node, return_value={'success': False})
79
79
 
80
80
  create_active_training_file(trainer, training_state=TrainerState.TrainingFinished)
81
81
 
82
82
  trainer.has_new_model = True
83
- _ = asyncio.get_running_loop().create_task(trainer._run())
83
+ trainer._begin_training_task()
84
84
 
85
- await assert_training_state(trainer.training, 'confusion_matrix_syncing', timeout=1, interval=0.001)
86
- await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001)
85
+ await assert_training_state(trainer.training, TrainerState.ConfusionMatrixSyncing, timeout=1, interval=0.001)
86
+ await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=10, interval=0.001)
87
87
 
88
- assert trainer_has_error(trainer)
88
+ assert trainer_has_sync_confusion_matrix_error(trainer) # Due to sio call failure, the error will be set
89
89
  assert trainer.training.training_state == TrainerState.TrainingFinished
90
90
  assert trainer.node.last_training_io.load() == trainer.training
91
91
 
@@ -100,6 +100,9 @@ async def test_basic_mock(test_initialized_trainer_node: TrainerNode, mocker: Mo
100
100
 
101
101
 
102
102
  async def mock_socket_io_call(mocker, trainer_node: TrainerNode, return_value):
103
+ '''
104
+ Patch the socketio call function to always return the given return_value
105
+ '''
103
106
  for _ in range(10):
104
107
  if trainer_node.sio_client is None:
105
108
  await asyncio.sleep(0.1)
@@ -14,10 +14,10 @@ async def test_successful_training(test_initialized_trainer: TestingTrainerLogic
14
14
  create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
15
15
  trainer._init_from_last_training()
16
16
 
17
- _ = asyncio.get_running_loop().create_task(trainer._run())
17
+ trainer._begin_training_task()
18
18
 
19
19
  await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
20
- await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01)
20
+ await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=10, interval=0.01)
21
21
  assert trainer.start_training_task is not None
22
22
 
23
23
  assert trainer._executor is not None
@@ -34,16 +34,15 @@ async def test_stop_running_training(test_initialized_trainer: TestingTrainerLog
34
34
  create_active_training_file(trainer, training_state=TrainerState.TrainModelDownloaded)
35
35
  trainer._init_from_last_training()
36
36
 
37
- _ = asyncio.get_running_loop().create_task(trainer._run())
37
+ trainer._begin_training_task()
38
38
 
39
39
  await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
40
- await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.01)
40
+ await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=10, interval=0.01)
41
41
  assert trainer.start_training_task is not None
42
42
 
43
43
  await trainer.stop()
44
44
  await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=2, interval=0.01)
45
45
 
46
- assert trainer.training.training_state == TrainerState.TrainingFinished
47
46
  assert trainer.node.last_training_io.load() == trainer.training
48
47
 
49
48
 
@@ -55,15 +54,14 @@ async def test_training_can_maybe_resumed(test_initialized_trainer: TestingTrain
55
54
  trainer._init_from_last_training()
56
55
  trainer._can_resume_flag = True
57
56
 
58
- _ = asyncio.get_running_loop().create_task(trainer._run())
57
+ trainer._begin_training_task()
59
58
 
60
59
  await condition(lambda: trainer._executor and trainer._executor.is_running(), timeout=1, interval=0.01)
61
- await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=1, interval=0.001)
60
+ await assert_training_state(trainer.training, TrainerState.TrainingRunning, timeout=10, interval=0.001)
62
61
  assert trainer.start_training_task is not None
63
62
 
64
63
  assert trainer._executor is not None
65
64
  await trainer._executor.stop_and_wait() # NOTE normally a training terminates itself e.g
66
65
  await assert_training_state(trainer.training, TrainerState.TrainingFinished, timeout=1, interval=0.001)
67
66
 
68
- assert trainer.training.training_state == TrainerState.TrainingFinished
69
67
  assert trainer.node.last_training_io.load() == trainer.training
@@ -11,11 +11,10 @@ from ..state_helper import assert_training_state, create_active_training_file
11
11
  from ..testing_trainer_logic import TestingTrainerLogic
12
12
 
13
13
  # pylint: disable=protected-access
14
- error_key = 'upload_detections'
15
14
 
16
15
 
17
- def trainer_has_error(trainer: TrainerLogic):
18
- return trainer.errors.has_error_for(error_key)
16
+ def trainer_has_upload_detections_error(trainer: TrainerLogic):
17
+ return trainer.errors.has_error_for('upload_detections')
19
18
 
20
19
 
21
20
  async def create_valid_detection_file(trainer: TrainerLogic, number_of_entries: int = 1, file_index: int = 0):
@@ -125,12 +124,11 @@ async def test_bad_status_from_LearningLoop(test_initialized_trainer: TestingTra
125
124
  trainer._init_from_last_training()
126
125
  trainer.active_training_io.save_detections([get_dummy_detections()])
127
126
 
128
- _ = asyncio.get_running_loop().create_task(trainer._run())
127
+ trainer._begin_training_task()
129
128
  await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001)
130
- await assert_training_state(trainer.training, TrainerState.Detected, timeout=1, interval=0.001)
129
+ await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001)
131
130
 
132
- assert trainer_has_error(trainer)
133
- assert trainer.training.training_state == TrainerState.Detected
131
+ assert trainer_has_upload_detections_error(trainer)
134
132
  assert trainer.node.last_training_io.load() == trainer.training
135
133
 
136
134
 
@@ -143,7 +141,7 @@ async def test_go_to_cleanup_if_no_detections_exist(test_initialized_trainer: Te
143
141
  create_active_training_file(trainer, training_state=TrainerState.Detected)
144
142
  trainer._init_from_last_training()
145
143
 
146
- _ = asyncio.get_running_loop().create_task(trainer._run())
144
+ trainer._begin_training_task()
147
145
  await assert_training_state(trainer.training, TrainerState.ReadyForCleanup, timeout=1, interval=0.001)
148
146
 
149
147
 
@@ -154,7 +152,7 @@ async def test_abort_uploading(test_initialized_trainer: TestingTrainerLogic):
154
152
  trainer._init_from_last_training()
155
153
  await create_valid_detection_file(trainer)
156
154
 
157
- _ = asyncio.get_running_loop().create_task(trainer._run())
155
+ trainer._begin_training_task()
158
156
 
159
157
  await assert_training_state(trainer.training, TrainerState.DetectionUploading, timeout=1, interval=0.001)
160
158