learning-loop-node 0.10.13__py3-none-any.whl → 0.10.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of learning-loop-node might be problematic. See the registry page for this release for more details.

Files changed (32)
  1. learning_loop_node/annotation/annotator_node.py +11 -10
  2. learning_loop_node/data_classes/detections.py +34 -26
  3. learning_loop_node/data_classes/general.py +27 -17
  4. learning_loop_node/data_exchanger.py +6 -5
  5. learning_loop_node/detector/detector_logic.py +3 -3
  6. learning_loop_node/detector/detector_node.py +21 -15
  7. learning_loop_node/detector/rest/about.py +34 -10
  8. learning_loop_node/detector/rest/backdoor_controls.py +9 -26
  9. learning_loop_node/detector/rest/detect.py +17 -16
  10. learning_loop_node/detector/rest/model_version_control.py +30 -13
  11. learning_loop_node/detector/rest/operation_mode.py +11 -5
  12. learning_loop_node/detector/rest/outbox_mode.py +7 -1
  13. learning_loop_node/node.py +93 -48
  14. learning_loop_node/rest.py +25 -2
  15. learning_loop_node/tests/detector/conftest.py +4 -4
  16. learning_loop_node/tests/detector/test_client_communication.py +21 -19
  17. learning_loop_node/tests/detector/test_detector_node.py +3 -3
  18. learning_loop_node/tests/trainer/conftest.py +4 -4
  19. learning_loop_node/tests/trainer/states/test_state_detecting.py +8 -9
  20. learning_loop_node/tests/trainer/states/test_state_download_train_model.py +8 -8
  21. learning_loop_node/tests/trainer/states/test_state_prepare.py +6 -7
  22. learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py +21 -18
  23. learning_loop_node/tests/trainer/states/test_state_train.py +6 -8
  24. learning_loop_node/tests/trainer/states/test_state_upload_detections.py +7 -9
  25. learning_loop_node/tests/trainer/states/test_state_upload_model.py +7 -8
  26. learning_loop_node/tests/trainer/test_errors.py +2 -2
  27. learning_loop_node/trainer/rest/backdoor_controls.py +19 -40
  28. learning_loop_node/trainer/trainer_logic_generic.py +7 -4
  29. learning_loop_node/trainer/trainer_node.py +4 -3
  30. {learning_loop_node-0.10.13.dist-info → learning_loop_node-0.10.14.dist-info}/METADATA +1 -1
  31. {learning_loop_node-0.10.13.dist-info → learning_loop_node-0.10.14.dist-info}/RECORD +32 -32
  32. {learning_loop_node-0.10.13.dist-info → learning_loop_node-0.10.14.dist-info}/WHEEL +0 -0
@@ -6,10 +6,13 @@ from ..outbox import Outbox
6
6
  router = APIRouter()
7
7
 
8
8
 
9
- @router.get("/outbox_mode")
9
+ @router.get("/outbox_mode", response_class=PlainTextResponse)
10
10
  async def get_outbox_mode(request: Request):
11
11
  '''
12
+ Get the outbox mode of the detector node.
13
+
12
14
  Example Usage
15
+
13
16
  curl http://localhost/outbox_mode
14
17
  '''
15
18
  outbox: Outbox = request.app.outbox
@@ -19,7 +22,10 @@ async def get_outbox_mode(request: Request):
19
22
  @router.put("/outbox_mode")
20
23
  async def put_outbox_mode(request: Request):
21
24
  '''
25
+ Set the outbox mode of the detector node.
26
+
22
27
  Example Usage
28
+
23
29
  curl -X PUT -d "continuous_upload" http://localhost/outbox_mode
24
30
  curl -X PUT -d "stopped" http://localhost/outbox_mode
25
31
  '''
@@ -8,7 +8,6 @@ from datetime import datetime
8
8
  from typing import Any, Optional
9
9
 
10
10
  import aiohttp
11
- import socketio
12
11
  from aiohttp import TCPConnector
13
12
  from fastapi import FastAPI
14
13
  from socketio import AsyncClient
@@ -21,6 +20,10 @@ from .loop_communication import LoopCommunicator
21
20
  from .rest import router
22
21
 
23
22
 
23
+ class NodeConnectionError(Exception):
24
+ pass
25
+
26
+
24
27
  class Node(FastAPI):
25
28
 
26
29
  def __init__(self, name: str, uuid: Optional[str] = None, node_type: str = 'node', needs_login: bool = True):
@@ -43,8 +46,7 @@ class Node(FastAPI):
43
46
  self.needs_login = needs_login
44
47
 
45
48
  self.log = logging.getLogger('Node')
46
- self.loop_communicator = LoopCommunicator()
47
- self.websocket_url = self.loop_communicator.websocket_url()
49
+ self.init_loop_communicator()
48
50
  self.data_exchanger = DataExchanger(None, self.loop_communicator)
49
51
 
50
52
  self.startup_datetime = datetime.now()
@@ -56,9 +58,20 @@ class Node(FastAPI):
56
58
  'nodeType': node_type}
57
59
 
58
60
  self.repeat_task: Any = None
61
+ self.socket_connection_broken = False
62
+ self._skip_repeat_loop = False
59
63
 
60
64
  self.include_router(router)
61
65
 
66
+ self.CONNECTED_TO_LOOP = asyncio.Event()
67
+ self.DISCONNECTED_FROM_LOOP = asyncio.Event()
68
+
69
+ self.repeat_loop_lock = asyncio.Lock()
70
+
71
+ def init_loop_communicator(self):
72
+ self.loop_communicator = LoopCommunicator()
73
+ self.websocket_url = self.loop_communicator.websocket_url()
74
+
62
75
  @property
63
76
  def sio_client(self) -> AsyncClient:
64
77
  if self._sio_client is None:
@@ -69,7 +82,10 @@ class Node(FastAPI):
69
82
  @asynccontextmanager
70
83
  async def lifespan(self, app: FastAPI): # pylint: disable=unused-argument
71
84
  try:
72
- await self._on_startup()
85
+ try:
86
+ await self._on_startup()
87
+ except Exception:
88
+ self.log.exception('Fatal error during startup: %s')
73
89
  self.repeat_task = asyncio.create_task(self.repeat_loop())
74
90
  yield
75
91
  finally:
@@ -83,13 +99,10 @@ class Node(FastAPI):
83
99
 
84
100
  async def _on_startup(self):
85
101
  self.log.info('received "startup" lifecycle-event')
86
- # activate_asyncio_warnings()
87
- if self.needs_login:
88
- await self.loop_communicator.backend_ready()
89
- self.log.info('ensuring login')
90
- await self.loop_communicator.ensure_login()
91
- self.log.info('create sio client')
92
- await self.create_sio_client()
102
+ try:
103
+ await self.reconnect_to_loop()
104
+ except Exception:
105
+ self.log.warning('Could not establish sio connection to loop during startup')
93
106
  self.log.info('done')
94
107
  await self.on_startup()
95
108
 
@@ -102,55 +115,93 @@ class Node(FastAPI):
102
115
  await self.on_shutdown()
103
116
 
104
117
  async def repeat_loop(self) -> None:
105
- """NOTE: with the lifespan approach, we cannot use @repeat_every anymore :("""
106
118
  while True:
119
+ if self._skip_repeat_loop:
120
+ self.log.debug('node is muted, skipping repeat loop')
121
+ await asyncio.sleep(1)
122
+ continue
107
123
  try:
108
- await self._on_repeat()
124
+ async with self.repeat_loop_lock:
125
+ await self._ensure_sio_connection()
126
+ await self.on_repeat()
109
127
  except asyncio.CancelledError:
110
128
  return
111
- except Exception as e:
112
- self.log.exception(f'error in repeat loop: {e}')
129
+ except Exception:
130
+ self.log.exception('error in repeat loop')
131
+
113
132
  await asyncio.sleep(5)
114
133
 
115
- async def _on_repeat(self):
116
- if not self.sio_client.connected:
117
- self.log.info('Reconnecting to loop via sio')
118
- await self.connect_sio()
119
- if not self.sio_client.connected:
120
- self.log.warning('Could not connect to loop via sio')
121
- return
122
- await self.on_repeat()
134
+ async def _ensure_sio_connection(self):
135
+ if self.socket_connection_broken or self._sio_client is None or not self.sio_client.connected:
136
+ self.log.info('Reconnecting to loop via sio due to %s',
137
+ 'broken connection' if self.socket_connection_broken else 'no connection')
138
+ await self.reconnect_to_loop()
139
+
140
+ async def reconnect_to_loop(self):
141
+ """Initialize the loop communicator, log in if needed and reconnect to the loop via socket.io."""
142
+ self.init_loop_communicator()
143
+ if self.needs_login:
144
+ await self.loop_communicator.ensure_login(relogin=True)
145
+ try:
146
+ await self._reconnect_socketio()
147
+ except NodeConnectionError:
148
+ self.log.exception('Could not reset sio connection to loop')
149
+ self.socket_connection_broken = True
150
+ raise
151
+
152
+ self.socket_connection_broken = False
153
+
154
+ def set_skip_repeat_loop(self, value: bool):
155
+ self._skip_repeat_loop = value
156
+ self.log.info('node is muted: %s', value)
123
157
 
124
158
  # --------------------------------------------------- SOCKET.IO ---------------------------------------------------
125
159
 
126
- async def create_sio_client(self):
127
- """Create a socket.io client that communicates with the learning loop and register the events.
128
- Note: The method is called in startup and soft restart of detector, so the _sio_client should always be available."""
160
+ async def _reconnect_socketio(self):
161
+ """Create a socket.io client, connect it to the learning loop and register its events.
162
+ The current client is disconnected and deleted if it already exists."""
163
+
164
+ self.log.debug('-------------- Connecting to loop via socket.io -------------------')
165
+ self.log.debug('HTTP Cookies: %s\n', self.loop_communicator.get_cookies())
129
166
 
167
+ if self._sio_client is not None:
168
+ try:
169
+ await self.sio_client.disconnect()
170
+ self.log.info('disconnected from loop via sio')
171
+ # NOTE: without waiting for the disconnect event, we might disconnect the next connection too early
172
+ await asyncio.wait_for(self.DISCONNECTED_FROM_LOOP.wait(), timeout=5)
173
+ except asyncio.TimeoutError:
174
+ self.log.warning(
175
+ 'Did not receive disconnect event from loop within 5 seconds.\nContinuing with new connection...')
176
+ except Exception as e:
177
+ self.log.warning('Could not disconnect from loop via sio: %s.\nIgnoring...', e)
178
+ self._sio_client = None
179
+
180
+ connector = None
130
181
  if self.loop_communicator.ssl_cert_path:
131
- logging.info(f'SIO using SSL certificate path: {self.loop_communicator.ssl_cert_path}')
182
+ logging.info('SIO using SSL certificate path: %s', self.loop_communicator.ssl_cert_path)
132
183
  ssl_context = ssl.create_default_context(cafile=self.loop_communicator.ssl_cert_path)
133
184
  ssl_context.check_hostname = False
134
185
  ssl_context.verify_mode = ssl.CERT_REQUIRED
135
186
  connector = TCPConnector(ssl=ssl_context)
136
- self._sio_client = AsyncClient(request_timeout=20,
137
- http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies(),
138
- connector=connector))
139
187
 
140
- else:
141
- self._sio_client = AsyncClient(request_timeout=20,
142
- http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies()))
188
+ self._sio_client = AsyncClient(request_timeout=20, http_session=aiohttp.ClientSession(
189
+ cookies=self.loop_communicator.get_cookies(), connector=connector))
143
190
 
144
191
  # pylint: disable=protected-access
145
- self.sio_client._trigger_event = ensure_socket_response(self.sio_client._trigger_event)
192
+ self._sio_client._trigger_event = ensure_socket_response(self._sio_client._trigger_event)
146
193
 
147
194
  @self._sio_client.event
148
195
  async def connect():
149
196
  self.log.info('received "connect" via sio from loop.')
197
+ self.CONNECTED_TO_LOOP.set()
198
+ self.DISCONNECTED_FROM_LOOP.clear()
150
199
 
151
200
  @self._sio_client.event
152
201
  async def disconnect():
153
202
  self.log.info('received "disconnect" via sio from loop.')
203
+ self.DISCONNECTED_FROM_LOOP.set()
204
+ self.CONNECTED_TO_LOOP.clear()
154
205
 
155
206
  @self._sio_client.event
156
207
  async def restart():
@@ -158,21 +209,15 @@ class Node(FastAPI):
158
209
  sys.exit(0)
159
210
 
160
211
  self.register_sio_events(self._sio_client)
161
-
162
- async def connect_sio(self):
163
212
  try:
164
- await self.sio_client.disconnect()
165
- except Exception:
166
- pass
167
-
168
- self.log.info(f'(re)connecting to Learning Loop at {self.websocket_url}')
169
- try:
170
- await self.sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
171
- self.log.info('connected to Learning Loop')
172
- except socketio.exceptions.ConnectionError: # type: ignore
173
- self.log.warning('connection error')
174
- except Exception:
175
- self.log.exception(f'error while connecting to "{self.websocket_url}". Exception:')
213
+ await self._sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
214
+ except Exception as e:
215
+ self.log.exception('Could not connect socketio client to loop')
216
+ raise NodeConnectionError('Could not connect socketio client to loop') from e
217
+
218
+ if not self._sio_client.connected:
219
+ self.log.exception('Could not connect socketio client to loop')
220
+ raise NodeConnectionError('Could not connect socketio client to loop')
176
221
 
177
222
  # --------------------------------------------------- ABSTRACT METHODS ---------------------------------------------------
178
223
 
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from typing import TYPE_CHECKING
3
3
 
4
- from fastapi import APIRouter, Request, HTTPException
4
+ from fastapi import APIRouter, HTTPException, Request
5
5
 
6
6
  if TYPE_CHECKING:
7
7
  from .node import Node
@@ -12,7 +12,7 @@ logger = logging.getLogger('Node.rest')
12
12
 
13
13
 
14
14
  @router.put("/debug_logging")
15
- async def _debug_logging(request: Request):
15
+ async def _debug_logging(request: Request) -> str:
16
16
  '''
17
17
  Example Usage
18
18
 
@@ -30,3 +30,26 @@ async def _debug_logging(request: Request):
30
30
  node.log.setLevel('DEBUG')
31
31
  return 'on'
32
32
  raise HTTPException(status_code=400, detail='Invalid state')
33
+
34
+
35
+ @router.put("/socketio")
36
+ async def _socketio(request: Request) -> str:
37
+ '''
38
+ Enable or disable the socketio connection to the learning loop.
39
+ Not intended to be used outside of testing.
40
+
41
+ Example Usage
42
+
43
+ curl -X PUT -d "on" http://hosturl/socketio
44
+ '''
45
+ state = str(await request.body(), 'utf-8')
46
+ node: 'Node' = request.app
47
+
48
+ if state == 'off':
49
+ await node.sio_client.disconnect()
50
+ node.set_skip_repeat_loop(True) # Prevent auto-reconnection
51
+ return 'off'
52
+ if state == 'on':
53
+ node.set_skip_repeat_loop(False) # Allow auto-reconnection (1 sec delay)
54
+ return 'on'
55
+ raise HTTPException(status_code=400, detail='Invalid state')
@@ -41,8 +41,8 @@ def should_have_segmentations(request) -> bool:
41
41
  async def test_detector_node():
42
42
  """Initializes and runs a detector testnode. Note that the running instance and the one the function returns are not the same instances!"""
43
43
 
44
- os.environ['ORGANIZATION'] = 'zauberzeug'
45
- os.environ['PROJECT'] = 'demo'
44
+ os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
45
+ os.environ['LOOP_PROJECT'] = 'demo'
46
46
 
47
47
  detector = TestingDetectorLogic()
48
48
  node = DetectorNode(name='test', detector=detector)
@@ -143,8 +143,8 @@ def mock_detector_logic():
143
143
 
144
144
  @pytest.fixture
145
145
  def detector_node(mock_detector_logic):
146
- os.environ['ORGANIZATION'] = 'test_organization'
147
- os.environ['PROJECT'] = 'test_project'
146
+ os.environ['LOOP_ORGANIZATION'] = 'test_organization'
147
+ os.environ['LOOP_PROJECT'] = 'test_project'
148
148
  return DetectorNode(name="test_node", detector=mock_detector_logic)
149
149
 
150
150
  # ====================================== REDUNDANT FIXTURES IN ALL CONFTESTS ! ======================================
@@ -93,10 +93,10 @@ async def test_sio_upload(test_detector_node: DetectorNode, sio_client):
93
93
 
94
94
  # NOTE: This test seems to be flaky.
95
95
  async def test_about_endpoint(test_detector_node: DetectorNode):
96
- await asyncio.sleep(3)
96
+ await asyncio.sleep(11)
97
97
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/about', timeout=30)
98
98
 
99
- assert response.status_code == 200
99
+ assert response.status_code == 200, response.content
100
100
  response_dict = json.loads(response.content)
101
101
  assert response_dict['model_info']
102
102
  model_information = ModelInformation.from_dict(response_dict['model_info'])
@@ -108,59 +108,60 @@ async def test_about_endpoint(test_detector_node: DetectorNode):
108
108
 
109
109
 
110
110
  async def test_model_version_api(test_detector_node: DetectorNode):
111
- await asyncio.sleep(3)
111
+ await asyncio.sleep(11)
112
112
 
113
113
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
114
- assert response.status_code == 200
114
+ assert response.status_code == 200, response.content
115
115
  response_dict = json.loads(response.content)
116
+ assert response_dict['version_control'] == 'follow_loop'
116
117
  assert response_dict['current_version'] == '1.1'
117
118
  assert response_dict['target_version'] == '1.1'
118
119
  assert response_dict['loop_version'] == '1.1'
119
120
  assert response_dict['local_versions'] == ['1.1']
120
- assert response_dict['version_control'] == 'follow_loop'
121
121
 
122
122
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='1.0', timeout=30)
123
+ assert response.status_code == 200, response.content
123
124
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
124
- assert response.status_code == 200
125
+ assert response.status_code == 200, response.content
125
126
  response_dict = json.loads(response.content)
127
+ assert response_dict['version_control'] == 'specific_version'
126
128
  assert response_dict['current_version'] == '1.1'
127
129
  assert response_dict['target_version'] == '1.0'
128
130
  assert response_dict['loop_version'] == '1.1'
129
131
  assert response_dict['local_versions'] == ['1.1']
130
- assert response_dict['version_control'] == 'specific_version'
131
132
 
132
133
  await asyncio.sleep(11)
133
-
134
134
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
135
- assert response.status_code == 200
135
+ assert response.status_code == 200, response.content
136
136
  response_dict = json.loads(response.content)
137
+ assert response_dict['version_control'] == 'specific_version'
137
138
  assert response_dict['current_version'] == '1.0'
138
139
  assert response_dict['target_version'] == '1.0'
139
140
  assert response_dict['loop_version'] == '1.1'
140
141
  assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
141
- assert response_dict['version_control'] == 'specific_version'
142
142
 
143
143
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='pause', timeout=30)
144
+ assert response.status_code == 200, response.content
144
145
  await asyncio.sleep(11)
145
146
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
146
- assert response.status_code == 200
147
+ assert response.status_code == 200, response.content
147
148
  response_dict = json.loads(response.content)
149
+ assert response_dict['version_control'] == 'pause'
148
150
  assert response_dict['current_version'] == '1.0'
149
151
  assert response_dict['target_version'] == '1.0'
150
152
  assert response_dict['loop_version'] == '1.1'
151
153
  assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
152
- assert response_dict['version_control'] == 'pause'
153
154
 
154
155
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/model_version', data='follow_loop', timeout=30)
155
156
  await asyncio.sleep(11)
156
157
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/model_version', timeout=30)
157
- assert response.status_code == 200
158
+ assert response.status_code == 200, response.content
158
159
  response_dict = json.loads(response.content)
160
+ assert response_dict['version_control'] == 'follow_loop'
159
161
  assert response_dict['current_version'] == '1.1'
160
162
  assert response_dict['target_version'] == '1.1'
161
163
  assert response_dict['loop_version'] == '1.1'
162
164
  assert set(response_dict['local_versions']) == set(['1.1', '1.0'])
163
- assert response_dict['version_control'] == 'follow_loop'
164
165
 
165
166
 
166
167
  async def test_rest_outbox_mode(test_detector_node: DetectorNode):
@@ -169,9 +170,9 @@ async def test_rest_outbox_mode(test_detector_node: DetectorNode):
169
170
  def check_switch_to_mode(mode: str):
170
171
  response = requests.put(f'http://localhost:{GLOBALS.detector_port}/outbox_mode',
171
172
  data=mode, timeout=30)
172
- assert response.status_code == 200
173
+ assert response.status_code == 200, response.content
173
174
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=30)
174
- assert response.status_code == 200
175
+ assert response.status_code == 200, response.content
175
176
  assert response.content == mode.encode()
176
177
 
177
178
  check_switch_to_mode('stopped')
@@ -185,7 +186,7 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
185
186
  with open(test_image_path, 'rb') as f:
186
187
  image_bytes = f.read()
187
188
 
188
- for _ in range(100):
189
+ for _ in range(200):
189
190
  test_detector_node.outbox.save(image_bytes)
190
191
 
191
192
  outbox_size_early = len(get_outbox_files(test_detector_node.outbox))
@@ -193,8 +194,9 @@ async def test_api_responsive_during_large_upload(test_detector_node: DetectorNo
193
194
 
194
195
  # check if api is still responsive
195
196
  response = requests.get(f'http://localhost:{GLOBALS.detector_port}/outbox_mode', timeout=2)
196
- assert response.status_code == 200
197
+ assert response.status_code == 200, response.content
197
198
 
198
199
  await asyncio.sleep(5)
199
200
  outbox_size_late = len(get_outbox_files(test_detector_node.outbox))
200
- assert outbox_size_early > outbox_size_late > 0, 'The outbox should have been partially emptied.'
201
+ assert outbox_size_late > 0, 'The outbox should not be fully cleared, maybe the node was too fast.'
202
+ assert outbox_size_early > outbox_size_late, 'The outbox should have been partially emptied.'
@@ -57,9 +57,9 @@ async def test_get_detections(detector_node: DetectorNode, monkeypatch):
57
57
 
58
58
  # Check if detections were processed
59
59
  assert result is not None
60
- assert "box_detections" in result
61
- assert len(result["box_detections"]) == 1
62
- assert result["box_detections"][0]["category_name"] == "test"
60
+ assert result.box_detections is not None
61
+ assert len(result.box_detections) == 1
62
+ assert result.box_detections[0].category_name == "test"
63
63
 
64
64
  # Check if the correct upload method was called
65
65
  assert filtered_upload_called == expect_filtered
@@ -1,7 +1,6 @@
1
- from ...globals import GLOBALS
2
- import shutil
3
1
  import logging
4
2
  import os
3
+ import shutil
5
4
  import socket
6
5
  from multiprocessing import log_to_stderr
7
6
 
@@ -9,6 +8,7 @@ import icecream
9
8
  import pytest
10
9
 
11
10
  from ...data_classes import Context
11
+ from ...globals import GLOBALS
12
12
  from ...trainer.trainer_node import TrainerNode
13
13
  from .testing_trainer_logic import TestingTrainerLogic
14
14
 
@@ -23,8 +23,8 @@ icecream.install()
23
23
 
24
24
  @pytest.fixture()
25
25
  async def test_initialized_trainer_node():
26
- os.environ['ORGANIZATION'] = 'zauberzeug'
27
- os.environ['PROJECT'] = 'demo'
26
+ os.environ['LOOP_ORGANIZATION'] = 'zauberzeug'
27
+ os.environ['LOOP_PROJECT'] = 'demo'
28
28
 
29
29
  trainer = TestingTrainerLogic()
30
30
  node = TrainerNode(name='test', trainer_logic=trainer, uuid='NOD30000-0000-0000-0000-000000000000')
@@ -7,11 +7,10 @@ from ..state_helper import assert_training_state, create_active_training_file
7
7
  from ..testing_trainer_logic import TestingTrainerLogic
8
8
 
9
9
  # pylint: disable=protected-access
10
- error_key = 'detecting'
11
10
 
12
11
 
13
- def trainer_has_error(trainer: TrainerLogic):
14
- return trainer.errors.has_error_for(error_key)
12
+ def trainer_has_detecting_error(trainer: TrainerLogic):
13
+ return trainer.errors.has_error_for('detecting')
15
14
 
16
15
 
17
16
  async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic):
@@ -25,7 +24,7 @@ async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogi
25
24
  await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
26
25
  await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001)
27
26
 
28
- assert trainer_has_error(trainer) is False
27
+ assert trainer_has_detecting_error(trainer) is False
29
28
  assert trainer.training.training_state == TrainerState.Detected
30
29
  assert trainer.node.last_training_io.load() == trainer.training
31
30
  assert trainer.active_training_io.detections_exist()
@@ -37,7 +36,7 @@ async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainer
37
36
  trainer._init_from_last_training()
38
37
  trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242'
39
38
 
40
- _ = asyncio.get_running_loop().create_task(trainer._run())
39
+ trainer._begin_training_task()
41
40
 
42
41
  await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001)
43
42
  await trainer.stop()
@@ -54,13 +53,13 @@ async def test_model_not_downloadable_error(test_initialized_trainer: TestingTra
54
53
  model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id
55
54
  trainer._init_from_last_training()
56
55
 
57
- _ = asyncio.get_running_loop().create_task(trainer._run())
56
+ trainer._begin_training_task()
58
57
 
59
- await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001)
60
- await assert_training_state(trainer.training, 'train_model_uploaded', timeout=1, interval=0.001)
58
+ await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
59
+ await assert_training_state(trainer.training, TrainerState.TrainModelUploaded, timeout=5, interval=0.001)
61
60
  await asyncio.sleep(0.1)
62
61
 
63
- assert trainer_has_error(trainer)
62
+ assert trainer_has_detecting_error(trainer)
64
63
  assert trainer.training.training_state == TrainerState.TrainModelUploaded
65
64
  assert trainer.training.model_uuid_for_detecting == '00000000-0000-0000-0000-000000000000'
66
65
  assert trainer.node.last_training_io.load() == trainer.training
@@ -20,8 +20,8 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
20
20
  trainer._perform_state('download_model',
21
21
  TrainerState.TrainModelDownloading,
22
22
  TrainerState.TrainModelDownloaded, trainer._download_model))
23
- await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001)
24
- await assert_training_state(trainer.training, 'train_model_downloaded', timeout=1, interval=0.001)
23
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
24
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloaded, timeout=10, interval=0.001)
25
25
 
26
26
  assert trainer.training.training_state == TrainerState.TrainModelDownloaded
27
27
  assert trainer.node.last_training_io.load() == trainer.training
@@ -34,11 +34,11 @@ async def test_downloading_is_successful(test_initialized_trainer: TestingTraine
34
34
 
35
35
  async def test_abort_download_model(test_initialized_trainer: TestingTrainerLogic):
36
36
  trainer = test_initialized_trainer
37
- create_active_training_file(trainer, training_state='data_downloaded')
37
+ create_active_training_file(trainer, training_state=TrainerState.DataDownloaded)
38
38
  trainer._init_from_last_training()
39
39
 
40
- _ = asyncio.get_running_loop().create_task(trainer._run())
41
- await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001)
40
+ trainer._begin_training_task()
41
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
42
42
 
43
43
  await trainer.stop()
44
44
  await asyncio.sleep(0.1)
@@ -53,9 +53,9 @@ async def test_downloading_failed(test_initialized_trainer: TestingTrainerLogic)
53
53
  base_model_uuid_or_name='00000000-0000-0000-0000-000000000000') # bad model id)
54
54
  trainer._init_from_last_training()
55
55
 
56
- _ = asyncio.get_running_loop().create_task(trainer._run())
57
- await assert_training_state(trainer.training, 'train_model_downloading', timeout=1, interval=0.001)
58
- await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=1, interval=0.001)
56
+ trainer._begin_training_task()
57
+ await assert_training_state(trainer.training, TrainerState.TrainModelDownloading, timeout=1, interval=0.001)
58
+ await assert_training_state(trainer.training, TrainerState.DataDownloaded, timeout=10, interval=0.001)
59
59
 
60
60
  assert trainer.errors.has_error_for('download_model')
61
61
  assert trainer._training is not None # pylint: disable=protected-access
@@ -6,11 +6,10 @@ from ..state_helper import assert_training_state, create_active_training_file
6
6
  from ..testing_trainer_logic import TestingTrainerLogic
7
7
 
8
8
  # pylint: disable=protected-access
9
- error_key = 'prepare'
10
9
 
11
10
 
12
- def trainer_has_error(trainer: TrainerLogic):
13
- return trainer.errors.has_error_for(error_key)
11
+ def trainer_has_prepare_error(trainer: TrainerLogic):
12
+ return trainer.errors.has_error_for('prepare')
14
13
 
15
14
 
16
15
  async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerLogic):
@@ -19,7 +18,7 @@ async def test_preparing_is_successful(test_initialized_trainer: TestingTrainerL
19
18
  trainer._init_from_last_training()
20
19
 
21
20
  await trainer._perform_state('prepare', TrainerState.DataDownloading, TrainerState.DataDownloaded, trainer._prepare)
22
- assert trainer_has_error(trainer) is False
21
+ assert trainer_has_prepare_error(trainer) is False
23
22
  assert trainer.training.training_state == TrainerState.DataDownloaded
24
23
  assert trainer.training.data is not None
25
24
  assert trainer.node.last_training_io.load() == trainer.training
@@ -30,7 +29,7 @@ async def test_abort_preparing(test_initialized_trainer: TestingTrainerLogic):
30
29
  create_active_training_file(trainer)
31
30
  trainer._init_from_last_training()
32
31
 
33
- _ = asyncio.get_running_loop().create_task(trainer._run())
32
+ trainer._begin_training_task()
34
33
  await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=1, interval=0.001)
35
34
 
36
35
  await trainer.stop()
@@ -48,9 +47,9 @@ async def test_request_error(test_initialized_trainer: TestingTrainerLogic):
48
47
 
49
48
  _ = asyncio.get_running_loop().create_task(trainer._run())
50
49
  await assert_training_state(trainer.training, TrainerState.DataDownloading, timeout=3, interval=0.001)
51
- await assert_training_state(trainer.training, TrainerState.Initialized, timeout=3, interval=0.001)
50
+ await assert_training_state(trainer.training, TrainerState.Initialized, timeout=10, interval=0.001)
52
51
 
53
- assert trainer_has_error(trainer)
52
+ assert trainer_has_prepare_error(trainer)
54
53
  assert trainer._training is not None # pylint: disable=protected-access
55
54
  assert trainer.training.training_state == TrainerState.Initialized
56
55
  assert trainer.node.last_training_io.load() == trainer.training