learning-loop-node 0.10.13__py3-none-any.whl → 0.10.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of learning-loop-node might be problematic. Click here for more details.
- learning_loop_node/annotation/annotator_node.py +11 -10
- learning_loop_node/data_classes/detections.py +34 -26
- learning_loop_node/data_classes/general.py +27 -17
- learning_loop_node/data_exchanger.py +6 -5
- learning_loop_node/detector/detector_logic.py +3 -3
- learning_loop_node/detector/detector_node.py +21 -15
- learning_loop_node/detector/rest/about.py +34 -10
- learning_loop_node/detector/rest/backdoor_controls.py +9 -26
- learning_loop_node/detector/rest/detect.py +17 -16
- learning_loop_node/detector/rest/model_version_control.py +30 -13
- learning_loop_node/detector/rest/operation_mode.py +11 -5
- learning_loop_node/detector/rest/outbox_mode.py +7 -1
- learning_loop_node/node.py +93 -48
- learning_loop_node/rest.py +25 -2
- learning_loop_node/tests/detector/conftest.py +4 -4
- learning_loop_node/tests/detector/test_client_communication.py +21 -19
- learning_loop_node/tests/detector/test_detector_node.py +3 -3
- learning_loop_node/tests/trainer/conftest.py +4 -4
- learning_loop_node/tests/trainer/states/test_state_detecting.py +8 -9
- learning_loop_node/tests/trainer/states/test_state_download_train_model.py +8 -8
- learning_loop_node/tests/trainer/states/test_state_prepare.py +6 -7
- learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py +21 -18
- learning_loop_node/tests/trainer/states/test_state_train.py +6 -8
- learning_loop_node/tests/trainer/states/test_state_upload_detections.py +7 -9
- learning_loop_node/tests/trainer/states/test_state_upload_model.py +7 -8
- learning_loop_node/tests/trainer/test_errors.py +2 -2
- learning_loop_node/trainer/rest/backdoor_controls.py +19 -40
- learning_loop_node/trainer/trainer_logic_generic.py +7 -4
- learning_loop_node/trainer/trainer_node.py +4 -3
- {learning_loop_node-0.10.13.dist-info → learning_loop_node-0.10.14.dist-info}/METADATA +1 -1
- {learning_loop_node-0.10.13.dist-info → learning_loop_node-0.10.14.dist-info}/RECORD +32 -32
- {learning_loop_node-0.10.13.dist-info → learning_loop_node-0.10.14.dist-info}/WHEEL +0 -0
|
@@ -6,10 +6,13 @@ from ..outbox import Outbox
|
|
|
6
6
|
router = APIRouter()
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
@router.get("/outbox_mode")
|
|
9
|
+
@router.get("/outbox_mode", response_class=PlainTextResponse)
|
|
10
10
|
async def get_outbox_mode(request: Request):
|
|
11
11
|
'''
|
|
12
|
+
Get the outbox mode of the detector node.
|
|
13
|
+
|
|
12
14
|
Example Usage
|
|
15
|
+
|
|
13
16
|
curl http://localhost/outbox_mode
|
|
14
17
|
'''
|
|
15
18
|
outbox: Outbox = request.app.outbox
|
|
@@ -19,7 +22,10 @@ async def get_outbox_mode(request: Request):
|
|
|
19
22
|
@router.put("/outbox_mode")
|
|
20
23
|
async def put_outbox_mode(request: Request):
|
|
21
24
|
'''
|
|
25
|
+
Set the outbox mode of the detector node.
|
|
26
|
+
|
|
22
27
|
Example Usage
|
|
28
|
+
|
|
23
29
|
curl -X PUT -d "continuous_upload" http://localhost/outbox_mode
|
|
24
30
|
curl -X PUT -d "stopped" http://localhost/outbox_mode
|
|
25
31
|
'''
|
learning_loop_node/node.py
CHANGED
|
@@ -8,7 +8,6 @@ from datetime import datetime
|
|
|
8
8
|
from typing import Any, Optional
|
|
9
9
|
|
|
10
10
|
import aiohttp
|
|
11
|
-
import socketio
|
|
12
11
|
from aiohttp import TCPConnector
|
|
13
12
|
from fastapi import FastAPI
|
|
14
13
|
from socketio import AsyncClient
|
|
@@ -21,6 +20,10 @@ from .loop_communication import LoopCommunicator
|
|
|
21
20
|
from .rest import router
|
|
22
21
|
|
|
23
22
|
|
|
23
|
+
class NodeConnectionError(Exception):
|
|
24
|
+
pass
|
|
25
|
+
|
|
26
|
+
|
|
24
27
|
class Node(FastAPI):
|
|
25
28
|
|
|
26
29
|
def __init__(self, name: str, uuid: Optional[str] = None, node_type: str = 'node', needs_login: bool = True):
|
|
@@ -43,8 +46,7 @@ class Node(FastAPI):
|
|
|
43
46
|
self.needs_login = needs_login
|
|
44
47
|
|
|
45
48
|
self.log = logging.getLogger('Node')
|
|
46
|
-
self.
|
|
47
|
-
self.websocket_url = self.loop_communicator.websocket_url()
|
|
49
|
+
self.init_loop_communicator()
|
|
48
50
|
self.data_exchanger = DataExchanger(None, self.loop_communicator)
|
|
49
51
|
|
|
50
52
|
self.startup_datetime = datetime.now()
|
|
@@ -56,9 +58,20 @@ class Node(FastAPI):
|
|
|
56
58
|
'nodeType': node_type}
|
|
57
59
|
|
|
58
60
|
self.repeat_task: Any = None
|
|
61
|
+
self.socket_connection_broken = False
|
|
62
|
+
self._skip_repeat_loop = False
|
|
59
63
|
|
|
60
64
|
self.include_router(router)
|
|
61
65
|
|
|
66
|
+
self.CONNECTED_TO_LOOP = asyncio.Event()
|
|
67
|
+
self.DISCONNECTED_FROM_LOOP = asyncio.Event()
|
|
68
|
+
|
|
69
|
+
self.repeat_loop_lock = asyncio.Lock()
|
|
70
|
+
|
|
71
|
+
def init_loop_communicator(self):
|
|
72
|
+
self.loop_communicator = LoopCommunicator()
|
|
73
|
+
self.websocket_url = self.loop_communicator.websocket_url()
|
|
74
|
+
|
|
62
75
|
@property
|
|
63
76
|
def sio_client(self) -> AsyncClient:
|
|
64
77
|
if self._sio_client is None:
|
|
@@ -69,7 +82,10 @@ class Node(FastAPI):
|
|
|
69
82
|
@asynccontextmanager
|
|
70
83
|
async def lifespan(self, app: FastAPI): # pylint: disable=unused-argument
|
|
71
84
|
try:
|
|
72
|
-
|
|
85
|
+
try:
|
|
86
|
+
await self._on_startup()
|
|
87
|
+
except Exception:
|
|
88
|
+
self.log.exception('Fatal error during startup: %s')
|
|
73
89
|
self.repeat_task = asyncio.create_task(self.repeat_loop())
|
|
74
90
|
yield
|
|
75
91
|
finally:
|
|
@@ -83,13 +99,10 @@ class Node(FastAPI):
|
|
|
83
99
|
|
|
84
100
|
async def _on_startup(self):
|
|
85
101
|
self.log.info('received "startup" lifecycle-event')
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
self.log.
|
|
90
|
-
await self.loop_communicator.ensure_login()
|
|
91
|
-
self.log.info('create sio client')
|
|
92
|
-
await self.create_sio_client()
|
|
102
|
+
try:
|
|
103
|
+
await self.reconnect_to_loop()
|
|
104
|
+
except Exception:
|
|
105
|
+
self.log.warning('Could not establish sio connection to loop during startup')
|
|
93
106
|
self.log.info('done')
|
|
94
107
|
await self.on_startup()
|
|
95
108
|
|
|
@@ -102,55 +115,93 @@ class Node(FastAPI):
|
|
|
102
115
|
await self.on_shutdown()
|
|
103
116
|
|
|
104
117
|
async def repeat_loop(self) -> None:
|
|
105
|
-
"""NOTE: with the lifespan approach, we cannot use @repeat_every anymore :("""
|
|
106
118
|
while True:
|
|
119
|
+
if self._skip_repeat_loop:
|
|
120
|
+
self.log.debug('node is muted, skipping repeat loop')
|
|
121
|
+
await asyncio.sleep(1)
|
|
122
|
+
continue
|
|
107
123
|
try:
|
|
108
|
-
|
|
124
|
+
async with self.repeat_loop_lock:
|
|
125
|
+
await self._ensure_sio_connection()
|
|
126
|
+
await self.on_repeat()
|
|
109
127
|
except asyncio.CancelledError:
|
|
110
128
|
return
|
|
111
|
-
except Exception
|
|
112
|
-
self.log.exception(
|
|
129
|
+
except Exception:
|
|
130
|
+
self.log.exception('error in repeat loop')
|
|
131
|
+
|
|
113
132
|
await asyncio.sleep(5)
|
|
114
133
|
|
|
115
|
-
async def
|
|
116
|
-
if not self.sio_client.connected:
|
|
117
|
-
self.log.info('Reconnecting to loop via sio'
|
|
118
|
-
|
|
119
|
-
|
|
120
|
-
|
|
121
|
-
|
|
122
|
-
|
|
134
|
+
async def _ensure_sio_connection(self):
|
|
135
|
+
if self.socket_connection_broken or self._sio_client is None or not self.sio_client.connected:
|
|
136
|
+
self.log.info('Reconnecting to loop via sio due to %s',
|
|
137
|
+
'broken connection' if self.socket_connection_broken else 'no connection')
|
|
138
|
+
await self.reconnect_to_loop()
|
|
139
|
+
|
|
140
|
+
async def reconnect_to_loop(self):
|
|
141
|
+
"""Initialize the loop communicator, log in if needed and reconnect to the loop via socket.io."""
|
|
142
|
+
self.init_loop_communicator()
|
|
143
|
+
if self.needs_login:
|
|
144
|
+
await self.loop_communicator.ensure_login(relogin=True)
|
|
145
|
+
try:
|
|
146
|
+
await self._reconnect_socketio()
|
|
147
|
+
except NodeConnectionError:
|
|
148
|
+
self.log.exception('Could not reset sio connection to loop')
|
|
149
|
+
self.socket_connection_broken = True
|
|
150
|
+
raise
|
|
151
|
+
|
|
152
|
+
self.socket_connection_broken = False
|
|
153
|
+
|
|
154
|
+
def set_skip_repeat_loop(self, value: bool):
|
|
155
|
+
self._skip_repeat_loop = value
|
|
156
|
+
self.log.info('node is muted: %s', value)
|
|
123
157
|
|
|
124
158
|
# --------------------------------------------------- SOCKET.IO ---------------------------------------------------
|
|
125
159
|
|
|
126
|
-
async def
|
|
127
|
-
"""Create a socket.io client
|
|
128
|
-
|
|
160
|
+
async def _reconnect_socketio(self):
|
|
161
|
+
"""Create a socket.io client, connect it to the learning loop and register its events.
|
|
162
|
+
The current client is disconnected and deleted if it already exists."""
|
|
163
|
+
|
|
164
|
+
self.log.debug('-------------- Connecting to loop via socket.io -------------------')
|
|
165
|
+
self.log.debug('HTTP Cookies: %s\n', self.loop_communicator.get_cookies())
|
|
129
166
|
|
|
167
|
+
if self._sio_client is not None:
|
|
168
|
+
try:
|
|
169
|
+
await self.sio_client.disconnect()
|
|
170
|
+
self.log.info('disconnected from loop via sio')
|
|
171
|
+
# NOTE: without waiting for the disconnect event, we might disconnect the next connection too early
|
|
172
|
+
await asyncio.wait_for(self.DISCONNECTED_FROM_LOOP.wait(), timeout=5)
|
|
173
|
+
except asyncio.TimeoutError:
|
|
174
|
+
self.log.warning(
|
|
175
|
+
'Did not receive disconnect event from loop within 5 seconds.\nContinuing with new connection...')
|
|
176
|
+
except Exception as e:
|
|
177
|
+
self.log.warning('Could not disconnect from loop via sio: %s.\nIgnoring...', e)
|
|
178
|
+
self._sio_client = None
|
|
179
|
+
|
|
180
|
+
connector = None
|
|
130
181
|
if self.loop_communicator.ssl_cert_path:
|
|
131
|
-
logging.info(
|
|
182
|
+
logging.info('SIO using SSL certificate path: %s', self.loop_communicator.ssl_cert_path)
|
|
132
183
|
ssl_context = ssl.create_default_context(cafile=self.loop_communicator.ssl_cert_path)
|
|
133
184
|
ssl_context.check_hostname = False
|
|
134
185
|
ssl_context.verify_mode = ssl.CERT_REQUIRED
|
|
135
186
|
connector = TCPConnector(ssl=ssl_context)
|
|
136
|
-
self._sio_client = AsyncClient(request_timeout=20,
|
|
137
|
-
http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies(),
|
|
138
|
-
connector=connector))
|
|
139
187
|
|
|
140
|
-
|
|
141
|
-
self.
|
|
142
|
-
http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies()))
|
|
188
|
+
self._sio_client = AsyncClient(request_timeout=20, http_session=aiohttp.ClientSession(
|
|
189
|
+
cookies=self.loop_communicator.get_cookies(), connector=connector))
|
|
143
190
|
|
|
144
191
|
# pylint: disable=protected-access
|
|
145
|
-
self.
|
|
192
|
+
self._sio_client._trigger_event = ensure_socket_response(self._sio_client._trigger_event)
|
|
146
193
|
|
|
147
194
|
@self._sio_client.event
|
|
148
195
|
async def connect():
|
|
149
196
|
self.log.info('received "connect" via sio from loop.')
|
|
197
|
+
self.CONNECTED_TO_LOOP.set()
|
|
198
|
+
self.DISCONNECTED_FROM_LOOP.clear()
|
|
150
199
|
|
|
151
200
|
@self._sio_client.event
|
|
152
201
|
async def disconnect():
|
|
153
202
|
self.log.info('received "disconnect" via sio from loop.')
|
|
203
|
+
self.DISCONNECTED_FROM_LOOP.set()
|
|
204
|
+
self.CONNECTED_TO_LOOP.clear()
|
|
154
205
|
|
|
155
206
|
@self._sio_client.event
|
|
156
207
|
async def restart():
|
|
@@ -158,21 +209,15 @@ class Node(FastAPI):
|
|
|
158
209
|
sys.exit(0)
|
|
159
210
|
|
|
160
211
|
self.register_sio_events(self._sio_client)
|
|
161
|
-
|
|
162
|
-
async def connect_sio(self):
|
|
163
212
|
try:
|
|
164
|
-
await self.
|
|
165
|
-
except Exception:
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
except socketio.exceptions.ConnectionError: # type: ignore
|
|
173
|
-
self.log.warning('connection error')
|
|
174
|
-
except Exception:
|
|
175
|
-
self.log.exception(f'error while connecting to "{self.websocket_url}". Exception:')
|
|
213
|
+
await self._sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
|
|
214
|
+
except Exception as e:
|
|
215
|
+
self.log.exception('Could not connect socketio client to loop')
|
|
216
|
+
raise NodeConnectionError('Could not connect socketio client to loop') from e
|
|
217
|
+
|
|
218
|
+
if not self._sio_client.connected:
|
|
219
|
+
self.log.exception('Could not connect socketio client to loop')
|
|
220
|
+
raise NodeConnectionError('Could not connect socketio client to loop')
|
|
176
221
|
|
|
177
222
|
# --------------------------------------------------- ABSTRACT METHODS ---------------------------------------------------
|
|
178
223
|
|
learning_loop_node/rest.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
|
1
1
|
import logging
|
|
2
2
|
from typing import TYPE_CHECKING
|
|
3
3
|
|
|
4
|
-
from fastapi import APIRouter,
|
|
4
|
+
from fastapi import APIRouter, HTTPException, Request
|
|
5
5
|
|
|
6
6
|
if TYPE_CHECKING:
|
|
7
7
|
from .node import Node
|
|
@@ -12,7 +12,7 @@ logger = logging.getLogger('Node.rest')
|
|
|
12
12
|
|
|
13
13
|
|
|
14
14
|
@router.put("/debug_logging")
|
|
15
|
-
async def _debug_logging(request: Request):
|
|
15
|
+
async def _debug_logging(request: Request) -> str:
|
|
16
16
|
'''
|
|
17
17
|
Example Usage
|
|
18
18
|
|
|
@@ -30,3 +30,26 @@ async def _debug_logging(request: Request):
|
|
|
30
30
|
node.log.setLevel('DEBUG')
|
|
31
31
|
return 'on'
|
|
32
32
|
raise HTTPException(status_code=400, detail='Invalid state')
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@router.put("/socketio")
|
|
36
|
+
async def _socketio(request: Request) -> str:
|
|
37
|
+
'''
|
|
38
|
+
Enable or disable the socketio connection to the learning loop.
|
|
39
|
+
Not intended to be used outside of testing.
|
|
40
|
+
|
|
41
|
+
Example Usage
|
|
42
|
+
|
|
43
|
+
curl -X PUT -d "on" http://hosturl/socketio
|
|
44
|
+
'''
|
|
45
|
+
state = str(await request.body(), 'utf-8')
|
|
46
|
+
node: 'Node' = request.app
|
|
47
|
+
|
|
48
|
+
if state == 'off':
|
|
49
|
+
await node.sio_client.disconnect()
|
|
50
|
+
node.set_skip_repeat_loop(True) # Prevent auto-reconnection
|
|
51
|
+
return 'off'
|
|
52
|
+
if state == 'on':
|
|
53
|
+
node.set_skip_repeat_loop(False) # Allow auto-reconnection (1 sec delay)
|
|
54
|
+
return 'on'
|
|
55
|
+
raise HTTPException(status_code=400, detail='Invalid state')
|
|
@@ -41,8 +41,8 @@ def should_have_segmentations(request) -> bool:
|
|
|
41
41
|
async def test_detector_node():
|
|
42
42
|
"""Initializes and runs a detector testnode. Note that the running instance and the one the function returns are not the same instances!"""
|
|
43
43
|
|
|
44
|
-
os.environ['
|
|
45
|
-
os.environ['
|
|
44
|
+
os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
|
|
45
|
+
os.environ['LOOP_PROJECT'] = 'demo'
|
|
46
46
|
|
|
47
47
|
detector = TestingDetectorLogic()
|
|
48
48
|
node = DetectorNode(name='test', detector=detector)
|
|
@@ -143,8 +143,8 @@ def mock_detector_logic():
|
|
|
143
143
|
|
|
144
144
|
@pytest.fixture
|
|
145
145
|
def detector_node(mock_detector_logic):
|
|
146
|
-
os.environ['
|
|
147
|
-
os.environ['
|
|
146
|
+
os.environ['LOOP_ORGANIZATION'] = 'test_organization'
|
|
147
|
+
os.environ['LOOP_PROJECT'] = 'test_project'
|
|
148
148
|
return DetectorNode(name="test_node", detector=mock_detector_logic)
|
|
149
149
|
|
|
150
150
|
# ====================================== REDUNDANT FIXTURES IN ALL CONFTESTS ! ======================================
|
|
@@ -93,10 +93,10 @@ async def test_sio_upload(test_detector_node: DetectorNode, sio_client):
|
|
|
93
93
|
|
|
94
94
|
# NOTE: This test seems to be flaky.
|
|
95
95
|
async def test_about_endpoint(test_detector_node: DetectorNode):
|
|
96
|
-
await asyncio.sleep(
|
|
96
|
+
await asyncio.sleep(11)
|
|
97
97
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/about', timeout=30)
|
|
98
98
|
|
|
99
|
-
assert response.status_code == 200
|
|
99
|
+
assert response.status_code == 200, response.content
|
|
100
100
|
response_dict = json.loads(response.content)
|
|
101
101
|
assert response_dict['model_info']
|
|
102
102
|
model_information = ModelInformation.from_dict(response_dict['model_info'])
|
|
@@ -108,59 +108,60 @@ async def test_about_endpoint(test_detector_node: DetectorNode):
|
|
|
108
108
|
|
|
109
109
|
|
|
110
110
|
async def test_model_version_api(test_detector_node: DetectorNode):
|
|
111
|
-
await asyncio.sleep(
|
|
111
|
+
await asyncio.sleep(11)
|
|
112
112
|
|
|
113
113
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
114
|
-
assert response.status_code == 200
|
|
114
|
+
assert response.status_code == 200, response.content
|
|
115
115
|
response_dict = json.loads(response.content)
|
|
116
|
+
assert response_dict['version_control'] == 'follow_loop'
|
|
116
117
|
assert response_dict['current_version'] == '1.1'
|
|
117
118
|
assert response_dict['target_version'] == '1.1'
|
|
118
119
|
assert response_dict['loop_version'] == '1.1'
|
|
119
120
|
assert response_dict['local_versions'] == ['1.1']
|
|
120
|
-
assert response_dict['version_control'] == 'follow_loop'
|
|
121
121
|
|
|
122
122
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='1.0', timeout=30)
|
|
123
|
+
assert response.status_code == 200, response.content
|
|
123
124
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
124
|
-
assert response.status_code == 200
|
|
125
|
+
assert response.status_code == 200, response.content
|
|
125
126
|
response_dict = json.loads(response.content)
|
|
127
|
+
assert response_dict['version_control'] == 'specific_version'
|
|
126
128
|
assert response_dict['current_version'] == '1.1'
|
|
127
129
|
assert response_dict['target_version'] == '1.0'
|
|
128
130
|
assert response_dict['loop_version'] == '1.1'
|
|
129
131
|
assert response_dict['local_versions'] == ['1.1']
|
|
130
|
-
assert response_dict['version_control'] == 'specific_version'
|
|
131
132
|
|
|
132
133
|
await asyncio.sleep(11)
|
|
133
|
-
|
|
134
134
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
135
|
-
assert response.status_code == 200
|
|
135
|
+
assert response.status_code == 200, response.content
|
|
136
136
|
response_dict = json.loads(response.content)
|
|
137
|
+
assert response_dict['version_control'] == 'specific_version'
|
|
137
138
|
assert response_dict['current_version'] == '1.0'
|
|
138
139
|
assert response_dict['target_version'] == '1.0'
|
|
139
140
|
assert response_dict['loop_version'] == '1.1'
|
|
140
141
|
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
|
|
141
|
-
assert response_dict['version_control'] == 'specific_version'
|
|
142
142
|
|
|
143
143
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='pause', timeout=30)
|
|
144
|
+
assert response.status_code == 200, response.content
|
|
144
145
|
await asyncio.sleep(11)
|
|
145
146
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
146
|
-
assert response.status_code == 200
|
|
147
|
+
assert response.status_code == 200, response.content
|
|
147
148
|
response_dict = json.loads(response.content)
|
|
149
|
+
assert response_dict['version_control'] == 'pause'
|
|
148
150
|
assert response_dict['current_version'] == '1.0'
|
|
149
151
|
assert response_dict['target_version'] == '1.0'
|
|
150
152
|
assert response_dict['loop_version'] == '1.1'
|
|
151
153
|
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
|
|
152
|
-
assert response_dict['version_control'] == 'pause'
|
|
153
154
|
|
|
154
155
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='follow_loop', timeout=30)
|
|
155
156
|
await asyncio.sleep(11)
|
|
156
157
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
|
|
157
|
-
assert response.status_code == 200
|
|
158
|
+
assert response.status_code == 200, response.content
|
|
158
159
|
response_dict = json.loads(response.content)
|
|
160
|
+
assert response_dict['version_control'] == 'follow_loop'
|
|
159
161
|
assert response_dict['current_version'] == '1.1'
|
|
160
162
|
assert response_dict['target_version'] == '1.1'
|
|
161
163
|
assert response_dict['loop_version'] == '1.1'
|
|
162
164
|
assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
|
|
163
|
-
assert response_dict['version_control'] == 'follow_loop'
|
|
164
165
|
|
|
165
166
|
|
|
166
167
|
async def test_rest_outbox_mode(test_detector_node: DetectorNode):
|
|
@@ -169,9 +170,9 @@ async def test_rest_outbox_mode(test_detector_node: DetectorNode):
|
|
|
169
170
|
def check_switch_to_mode(mode: str):
|
|
170
171
|
response = requests.put(f'http://localhost:{GLOBALS.detector_port}/outbox_mode',
|
|
171
172
|
data=mode, timeout=30)
|
|
172
|
-
assert response.status_code == 200
|
|
173
|
+
assert response.status_code == 200, response.content
|
|
173
174
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=30)
|
|
174
|
-
assert response.status_code == 200
|
|
175
|
+
assert response.status_code == 200, response.content
|
|
175
176
|
assert response.content == mode.encode()
|
|
176
177
|
|
|
177
178
|
check_switch_to_mode('stopped')
|
|
@@ -185,7 +186,7 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
|
|
|
185
186
|
with open(test_image_path, 'rb') as f:
|
|
186
187
|
image_bytes = f.read()
|
|
187
188
|
|
|
188
|
-
for _ in range(
|
|
189
|
+
for _ in range(200):
|
|
189
190
|
test_detector_node.outbox.save(image_bytes)
|
|
190
191
|
|
|
191
192
|
outbox_size_early = len(get_outbox_files(test_detector_node.outbox))
|
|
@@ -193,8 +194,9 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
|
|
|
193
194
|
|
|
194
195
|
# check if api is still responsive
|
|
195
196
|
response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=2)
|
|
196
|
-
assert response.status_code == 200
|
|
197
|
+
assert response.status_code == 200, response.content
|
|
197
198
|
|
|
198
199
|
await asyncio.sleep(5)
|
|
199
200
|
outbox_size_late = len(get_outbox_files(test_detector_node.outbox))
|
|
200
|
-
assert
|
|
201
|
+
assert outbox_size_late > 0, 'The outbox should not be fully cleared, maybe the node was too fast.'
|
|
202
|
+
assert outbox_size_early > outbox_size_late, 'The outbox should have been partially emptied.'
|
|
@@ -57,9 +57,9 @@ async def test_get_detections(detector_node: DetectorNode, monkeypatch):
|
|
|
57
57
|
|
|
58
58
|
# Check if detections were processed
|
|
59
59
|
assert result is not None
|
|
60
|
-
assert
|
|
61
|
-
assert len(result
|
|
62
|
-
assert result
|
|
60
|
+
assert result.box_detections is not None
|
|
61
|
+
assert len(result.box_detections) == 1
|
|
62
|
+
assert result.box_detections[0].category_name == "test"
|
|
63
63
|
|
|
64
64
|
# Check if the correct upload method was called
|
|
65
65
|
assert filtered_upload_called == expect_filtered
|
|
@@ -1,7 +1,6 @@
|
|
|
1
|
-
from ...globals import GLOBALS
|
|
2
|
-
import shutil
|
|
3
1
|
import logging
|
|
4
2
|
import os
|
|
3
|
+
import shutil
|
|
5
4
|
import socket
|
|
6
5
|
from multiprocessing import log_to_stderr
|
|
7
6
|
|
|
@@ -9,6 +8,7 @@ import icecream
|
|
|
9
8
|
import pytest
|
|
10
9
|
|
|
11
10
|
from ...data_classes import Context
|
|
11
|
+
from ...globals import GLOBALS
|
|
12
12
|
from ...trainer.trainer_node import TrainerNode
|
|
13
13
|
from .testing_trainer_logic import TestingTrainerLogic
|
|
14
14
|
|
|
@@ -23,8 +23,8 @@ icecream.install()
|
|
|
23
23
|
|
|
24
24
|
@pytest.fixture()
|
|
25
25
|
async def test_initialized_trainer_node():
|
|
26
|
-
os.environ['
|
|
27
|
-
os.environ['
|
|
26
|
+
os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
|
|
27
|
+
os.environ['LOOP_PROJECT'] = 'demo'
|
|
28
28
|
|
|
29
29
|
trainer = TestingTrainerLogic()
|
|
30
30
|
node = TrainerNode(name='test', trainer_logic=trainer, uuid='NOD30000-0000-0000-0000-000000000000')
|
|
@@ -7,11 +7,10 @@ from ..state_helper import assert_training_state, create_active_training_file
|
|
|
7
7
|
from ..testing_trainer_logic import TestingTrainerLogic
|
|
8
8
|
|
|
9
9
|
# pylint: disable=protected-access
|
|
10
|
-
error_key = 'detecting'
|
|
11
10
|
|
|
12
11
|
|
|
13
|
-
def
|
|
14
|
-
return trainer.errors.has_error_for(
|
|
12
|
+
def trainer_has_detecting_error(trainer: TrainerLogic):
|
|
13
|
+
return trainer.errors.has_error_for('detecting')
|
|
15
14
|
|
|
16
15
|
|
|
17
16
|
async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic):
|
|
@@ -25,7 +24,7 @@ async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogi
|
|
|
25
24
|
await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
|
|
26
25
|
await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001)
|
|
27
26
|
|
|
28
|
-
assert
|
|
27
|
+
assert trainer_has_detecting_error(trainer) is False
|
|
29
28
|
assert trainer.training.training_state == TrainerState.Detected
|
|
30
29
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
31
30
|
assert trainer.active_training_io.detections_exist()
|
|
@@ -37,7 +36,7 @@ async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainer
|
|
|
37
36
|
trainer._init_from_last_training()
|
|
38
37
|
trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242'
|
|
39
38
|
|
|
40
|
-
|
|
39
|
+
trainer._begin_training_task()
|
|
41
40
|
|
|
42
41
|
await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001)
|
|
43
42
|
await trainer.stop()
|
|
@@ -54,13 +53,13 @@ async def test_model_not_downloadable_error(test_initialized_trainer: TestingTra
|
|
|
54
53
|
model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id
|
|
55
54
|
trainer._init_from_last_training()
|
|
56
55
|
|
|
57
|
-
|
|
56
|
+
trainer._begin_training_task()
|
|
58
57
|
|
|
59
|
-
await assert_training_state(trainer.training,
|
|
60
|
-
await assert_training_state(trainer.training,
|
|
58
|
+
await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
|
|
59
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelUploaded, timeout=5, interval=0.001)
|
|
61
60
|
await asyncio.sleep(0.1)
|
|
62
61
|
|
|
63
|
-
assert
|
|
62
|
+
assert trainer_has_detecting_error(trainer)
|
|
64
63
|
assert trainer.training.training_state == TrainerState.TrainModelUploaded
|
|
65
64
|
assert trainer.training.model_uuid_for_detecting == '00000000-0000-0000-0000-000000000000'
|
|
66
65
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -20,8 +20,8 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
|
|
|
20
20
|
trainer._perform_state('download_model',
|
|
21
21
|
TrainerState.TrainModelDownloading,
|
|
22
22
|
TrainerState.TrainModelDownloaded, trainer._download_model))
|
|
23
|
-
await assert_training_state(trainer.training,
|
|
24
|
-
await assert_training_state(trainer.training,
|
|
23
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
|
|
24
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloaded, timeout=10, interval=0.001)
|
|
25
25
|
|
|
26
26
|
assert trainer.training.training_state == TrainerState.TrainModelDownloaded
|
|
27
27
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -34,11 +34,11 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
|
|
|
34
34
|
|
|
35
35
|
async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogic):
|
|
36
36
|
trainer = test_initialized_trainer
|
|
37
|
-
create_active_training_file(trainer, training_state=
|
|
37
|
+
create_active_training_file(trainer, training_state=TrainerState.DataDownloaded)
|
|
38
38
|
trainer._init_from_last_training()
|
|
39
39
|
|
|
40
|
-
|
|
41
|
-
await assert_training_state(trainer.training,
|
|
40
|
+
trainer._begin_training_task()
|
|
41
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
|
|
42
42
|
|
|
43
43
|
await trainer.stop()
|
|
44
44
|
await asyncio.sleep(0.1)
|
|
@@ -53,9 +53,9 @@ async def test_downloading_failed(test_initialized_trainer: TestingTrainerLogic)
|
|
|
53
53
|
base_model_uuid_or_name='00000000-0000-0000-0000-000000000000') # bad model id)
|
|
54
54
|
trainer._init_from_last_training()
|
|
55
55
|
|
|
56
|
-
|
|
57
|
-
await assert_training_state(trainer.training,
|
|
58
|
-
await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=
|
|
56
|
+
trainer._begin_training_task()
|
|
57
|
+
await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
|
|
58
|
+
await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=10, interval=0.001)
|
|
59
59
|
|
|
60
60
|
assert trainer.errors.has_error_for('download_model')
|
|
61
61
|
assert trainer._training is not None # pylint: disable=protected-access
|
|
@@ -6,11 +6,10 @@ from ..state_helper import assert_training_state, create_active_training_file
|
|
|
6
6
|
from ..testing_trainer_logic import TestingTrainerLogic
|
|
7
7
|
|
|
8
8
|
# pylint: disable=protected-access
|
|
9
|
-
error_key = 'prepare'
|
|
10
9
|
|
|
11
10
|
|
|
12
|
-
def
|
|
13
|
-
return trainer.errors.has_error_for(
|
|
11
|
+
def trainer_has_prepare_error(trainer: TrainerLogic):
|
|
12
|
+
return trainer.errors.has_error_for('prepare')
|
|
14
13
|
|
|
15
14
|
|
|
16
15
|
async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerLogic):
|
|
@@ -19,7 +18,7 @@ async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerL
|
|
|
19
18
|
trainer._init_from_last_training()
|
|
20
19
|
|
|
21
20
|
await trainer._perform_state('prepare', TrainerState.DataDownloading, TrainerState.DataDownloaded, trainer._prepare)
|
|
22
|
-
assert
|
|
21
|
+
assert trainer_has_prepare_error(trainer) is False
|
|
23
22
|
assert trainer.training.training_state == TrainerState.DataDownloaded
|
|
24
23
|
assert trainer.training.data is not None
|
|
25
24
|
assert trainer.node.last_training_io.load() == trainer.training
|
|
@@ -30,7 +29,7 @@ async def test_abort_preparing(test_initialized_trainer: TestingTrainerLogic):
|
|
|
30
29
|
create_active_training_file(trainer)
|
|
31
30
|
trainer._init_from_last_training()
|
|
32
31
|
|
|
33
|
-
|
|
32
|
+
trainer._begin_training_task()
|
|
34
33
|
await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=1, interval=0.001)
|
|
35
34
|
|
|
36
35
|
await trainer.stop()
|
|
@@ -48,9 +47,9 @@ async def test_request_error(test_initialized_trainer: TestingTrainerLogic):
|
|
|
48
47
|
|
|
49
48
|
_ = asyncio.get_running_loop().create_task(trainer._run())
|
|
50
49
|
await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=3, interval=0.001)
|
|
51
|
-
await assert_training_state(trainer.training, TrainerState.Initialized, timeout=
|
|
50
|
+
await assert_training_state(trainer.training, TrainerState.Initialized, timeout=10, interval=0.001)
|
|
52
51
|
|
|
53
|
-
assert
|
|
52
|
+
assert trainer_has_prepare_error(trainer)
|
|
54
53
|
assert trainer._training is not None # pylint: disable=protected-access
|
|
55
54
|
assert trainer.training.training_state == TrainerState.Initialized
|
|
56
55
|
assert trainer.node.last_training_io.load() == trainer.training
|