learning-loop-node 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of learning-loop-node might be problematic; see the package's registry page for more details.

Files changed (54):
  1. learning_loop_node/__init__.py +2 -3
  2. learning_loop_node/annotation/annotator_logic.py +2 -2
  3. learning_loop_node/annotation/annotator_node.py +16 -15
  4. learning_loop_node/data_classes/__init__.py +17 -10
  5. learning_loop_node/data_classes/detections.py +7 -2
  6. learning_loop_node/data_classes/general.py +4 -5
  7. learning_loop_node/data_classes/training.py +49 -21
  8. learning_loop_node/data_exchanger.py +85 -139
  9. learning_loop_node/detector/__init__.py +0 -1
  10. learning_loop_node/detector/detector_node.py +10 -13
  11. learning_loop_node/detector/inbox_filter/cam_observation_history.py +4 -7
  12. learning_loop_node/detector/outbox.py +0 -1
  13. learning_loop_node/detector/rest/about.py +1 -0
  14. learning_loop_node/detector/tests/conftest.py +0 -1
  15. learning_loop_node/detector/tests/test_client_communication.py +5 -3
  16. learning_loop_node/detector/tests/test_outbox.py +2 -0
  17. learning_loop_node/detector/tests/testing_detector.py +1 -8
  18. learning_loop_node/globals.py +2 -2
  19. learning_loop_node/helpers/gdrive_downloader.py +1 -1
  20. learning_loop_node/helpers/misc.py +124 -17
  21. learning_loop_node/loop_communication.py +57 -25
  22. learning_loop_node/node.py +62 -135
  23. learning_loop_node/tests/test_downloader.py +8 -7
  24. learning_loop_node/tests/test_executor.py +14 -11
  25. learning_loop_node/tests/test_helper.py +3 -5
  26. learning_loop_node/trainer/downloader.py +1 -1
  27. learning_loop_node/trainer/executor.py +87 -83
  28. learning_loop_node/trainer/io_helpers.py +66 -9
  29. learning_loop_node/trainer/rest/backdoor_controls.py +10 -5
  30. learning_loop_node/trainer/rest/controls.py +3 -1
  31. learning_loop_node/trainer/tests/conftest.py +19 -28
  32. learning_loop_node/trainer/tests/states/test_state_cleanup.py +5 -3
  33. learning_loop_node/trainer/tests/states/test_state_detecting.py +23 -20
  34. learning_loop_node/trainer/tests/states/test_state_download_train_model.py +18 -12
  35. learning_loop_node/trainer/tests/states/test_state_prepare.py +13 -12
  36. learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py +21 -18
  37. learning_loop_node/trainer/tests/states/test_state_train.py +27 -28
  38. learning_loop_node/trainer/tests/states/test_state_upload_detections.py +34 -32
  39. learning_loop_node/trainer/tests/states/test_state_upload_model.py +22 -20
  40. learning_loop_node/trainer/tests/test_errors.py +20 -12
  41. learning_loop_node/trainer/tests/test_trainer_states.py +4 -5
  42. learning_loop_node/trainer/tests/testing_trainer_logic.py +25 -30
  43. learning_loop_node/trainer/trainer_logic.py +80 -590
  44. learning_loop_node/trainer/trainer_logic_generic.py +495 -0
  45. learning_loop_node/trainer/trainer_node.py +27 -106
  46. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/METADATA +1 -1
  47. learning_loop_node-0.10.0.dist-info/RECORD +85 -0
  48. learning_loop_node/converter/converter_logic.py +0 -68
  49. learning_loop_node/converter/converter_node.py +0 -125
  50. learning_loop_node/converter/tests/test_converter.py +0 -55
  51. learning_loop_node/trainer/training_syncronizer.py +0 -52
  52. learning_loop_node-0.9.3.dist-info/RECORD +0 -88
  53. /learning_loop_node/{converter/__init__.py → py.typed} +0 -0
  54. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/WHEEL +0 -0
@@ -1,14 +1,21 @@
1
1
  """original copied from https://quantlane.com/blog/ensure-asyncio-task-exceptions-get-logged/"""
2
2
  import asyncio
3
3
  import functools
4
+ import json
4
5
  import logging
5
6
  import os
7
+ import shutil
8
+ import sys
6
9
  from dataclasses import asdict
10
+ from glob import glob
11
+ from time import perf_counter
7
12
  from typing import Any, Coroutine, List, Optional, Tuple, TypeVar
13
+ from uuid import UUID, uuid4
8
14
 
9
15
  import pynvml
10
16
 
11
- from ..data_classes import SocketResponse
17
+ from ..data_classes import Context, SocketResponse, Training
18
+ from ..globals import GLOBALS
12
19
 
13
20
  T = TypeVar('T')
14
21
 
@@ -48,7 +55,7 @@ def _handle_task_result(task: asyncio.Task,
48
55
  logger.exception(message, *message_args)
49
56
 
50
57
 
51
- def get_free_memory_mb() -> float: # TODO check if this is used
58
+ def get_free_memory_mb() -> float: # NOTE used by yolov5
52
59
  pynvml.nvmlInit()
53
60
  h = pynvml.nvmlDeviceGetHandleByIndex(0)
54
61
  info = pynvml.nvmlDeviceGetMemoryInfo(h)
@@ -56,16 +63,33 @@ def get_free_memory_mb() -> float: # TODO check if this is used
56
63
  return free
57
64
 
58
65
 
66
+ async def is_valid_image(filename: str, check_jpeg: bool) -> bool:
67
+ if not os.path.isfile(filename) or os.path.getsize(filename) == 0:
68
+ return False
69
+ if not check_jpeg:
70
+ return True
71
+
72
+ info = await asyncio.create_subprocess_shell(f'jpeginfo -c {filename}',
73
+ stdout=asyncio.subprocess.PIPE,
74
+ stderr=asyncio.subprocess.PIPE)
75
+ out, _ = await info.communicate()
76
+ return "OK" in out.decode()
77
+
78
+
79
+ async def delete_corrupt_images(image_folder: str, check_jpeg: bool = False) -> None:
80
+ logging.info('deleting corrupt images')
81
+ n_deleted = 0
82
+ for image in glob(f'{image_folder}/*.jpg'):
83
+ if not await is_valid_image(image, check_jpeg):
84
+ logging.debug(f' deleting image {image}')
85
+ os.remove(image)
86
+ n_deleted += 1
87
+
88
+ logging.info(f'deleted {n_deleted} images')
89
+
90
+
59
91
  def create_resource_paths(organization_name: str, project_name: str, image_ids: List[str]) -> Tuple[List[str], List[str]]:
60
- # TODO: experimental:
61
92
  return [f'/{organization_name}/projects/{project_name}/images/{id}/main' for id in image_ids], image_ids
62
- # if not image_ids:
63
- # return [], []
64
- # url_ids: List[Tuple(str, str)] = [(f'/{organization_name}/projects/{project_name}/images/{id}/main', id)
65
- # for id in image_ids]
66
- # urls, ids = list(map(list, zip(*url_ids)))
67
-
68
- # return urls, ids
69
93
 
70
94
 
71
95
  def create_image_folder(project_folder: str) -> str:
@@ -74,6 +98,24 @@ def create_image_folder(project_folder: str) -> str:
74
98
  return image_folder
75
99
 
76
100
 
101
+ def read_or_create_uuid(identifier: str) -> str:
102
+ identifier = identifier.lower().replace(' ', '_')
103
+ uuids = {}
104
+ os.makedirs(GLOBALS.data_folder, exist_ok=True)
105
+ file_path = f'{GLOBALS.data_folder}/uuids.json'
106
+ if os.path.exists(file_path):
107
+ with open(file_path, 'r') as f:
108
+ uuids = json.load(f)
109
+
110
+ uuid = uuids.get(identifier, None)
111
+ if not uuid:
112
+ uuid = str(uuid4())
113
+ uuids[identifier] = uuid
114
+ with open(file_path, 'w') as f:
115
+ json.dump(uuids, f)
116
+ return uuid
117
+
118
+
77
119
  def ensure_socket_response(func):
78
120
  """Decorator to ensure that the return value of a socket.io event handler is a SocketResponse.
79
121
 
@@ -90,20 +132,85 @@ def ensure_socket_response(func):
90
132
 
91
133
  if isinstance(value, str):
92
134
  return asdict(SocketResponse.for_success(value))
93
- elif isinstance(value, bool):
135
+ if isinstance(value, bool):
94
136
  return asdict(SocketResponse.from_bool(value))
95
- elif isinstance(value, SocketResponse):
137
+ if isinstance(value, SocketResponse):
96
138
  return value
97
- elif (args[0] in ['connect', 'disconnect', 'connect_error']):
139
+ if (args[0] in ['connect', 'disconnect', 'connect_error']):
98
140
  return value
99
- elif value is None:
141
+ if value is None:
100
142
  return None
101
- else:
102
- raise Exception(
103
- f"Return type for sio must be str, bool, SocketResponse or None', but was {type(value)}'")
143
+
144
+ raise Exception(
145
+ f"Return type for sio must be str, bool, SocketResponse or None', but was {type(value)}'")
104
146
  except Exception as e:
105
147
  logging.exception(f'An error occured for {args[0]}')
106
148
 
107
149
  return asdict(SocketResponse.for_failure(str(e)))
108
150
 
109
151
  return wrapper_ensure_socket_response
152
+
153
+
154
+ def is_valid_uuid4(val):
155
+ if not val:
156
+ return False
157
+ try:
158
+ _ = UUID(str(val)).version
159
+ return True
160
+ except ValueError:
161
+ return False
162
+
163
+
164
+ def create_project_folder(context: Context) -> str:
165
+ project_folder = f'{GLOBALS.data_folder}/{context.organization}/{context.project}'
166
+ os.makedirs(project_folder, exist_ok=True)
167
+ return project_folder
168
+
169
+
170
+ def activate_asyncio_warnings() -> None:
171
+ '''Produce warnings for coroutines which take too long on the main loop and hence clog the event loop'''
172
+ try:
173
+ if sys.version_info.major >= 3 and sys.version_info.minor >= 7: # most
174
+ loop = asyncio.get_running_loop()
175
+ else:
176
+ loop = asyncio.get_event_loop()
177
+
178
+ loop.set_debug(True)
179
+ loop.slow_callback_duration = 0.2
180
+ logging.info('activated asyncio warnings')
181
+ except Exception:
182
+ logging.exception('could not activate asyncio warnings. Exception:')
183
+
184
+
185
+ def images_for_ids(image_ids, image_folder) -> List[str]:
186
+ logging.info(f'### Going to get images for {len(image_ids)} images ids')
187
+ start = perf_counter()
188
+ images = [img for img in glob(f'{image_folder}/**/*.*', recursive=True)
189
+ if os.path.splitext(os.path.basename(img))[0] in image_ids]
190
+ end = perf_counter()
191
+ logging.info(f'found {len(images)} images for {len(image_ids)} image ids, which took {end-start:0.2f} seconds')
192
+ return images
193
+
194
+
195
+ def generate_training(project_folder: str, context: Context) -> Training:
196
+ training_uuid = str(uuid4())
197
+ return Training(
198
+ id=training_uuid,
199
+ context=context,
200
+ project_folder=project_folder,
201
+ images_folder=create_image_folder(project_folder),
202
+ training_folder=create_training_folder(project_folder, training_uuid)
203
+ )
204
+
205
+
206
+ def delete_all_training_folders(project_folder: str):
207
+ if not os.path.exists(f'{project_folder}/trainings'):
208
+ return
209
+ for uuid in os.listdir(f'{project_folder}/trainings'):
210
+ shutil.rmtree(f'{project_folder}/trainings/{uuid}', ignore_errors=True)
211
+
212
+
213
+ def create_training_folder(project_folder: str, trainings_id: str) -> str:
214
+ training_folder = f'{project_folder}/trainings/{trainings_id}'
215
+ os.makedirs(training_folder, exist_ok=True)
216
+ return training_folder
@@ -1,6 +1,6 @@
1
1
  import asyncio
2
2
  import logging
3
- from typing import List, Optional
3
+ from typing import Awaitable, Callable, List, Optional
4
4
 
5
5
  import httpx
6
6
  from httpx import Cookies, Timeout
@@ -24,21 +24,21 @@ class LoopCommunicator():
24
24
  self.project: str = environment_reader.project() # used by mock_detector
25
25
  self.base_url: str = f'http{"s" if "learning-loop.ai" in host else ""}://' + host
26
26
  self.async_client: httpx.AsyncClient = httpx.AsyncClient(base_url=self.base_url, timeout=Timeout(60.0))
27
+ self.async_client.cookies.clear()
27
28
 
28
29
  logging.info(f'Loop interface initialized with base_url: {self.base_url} / user: {self.username}')
29
30
 
30
- # @property
31
- # def project_path(self): # TODO: remove?
32
- # return f'/{self.organization}/projects/{self.project}'
31
+ def websocket_url(self) -> str:
32
+ return f'ws{"s" if "learning-loop.ai" in self.host else ""}://' + self.host
33
33
 
34
- async def ensure_login(self) -> None:
34
+ async def ensure_login(self, relogin=False) -> None:
35
35
  """aiohttp client session needs to be created on the event loop"""
36
36
 
37
37
  assert not self.async_client.is_closed, 'async client must not be used after shutdown'
38
- if not self.async_client.cookies.keys():
38
+ if not self.async_client.cookies.keys() or relogin:
39
+ self.async_client.cookies.clear()
39
40
  response = await self.async_client.post('/api/login', data={'username': self.username, 'password': self.password})
40
41
  if response.status_code != 200:
41
- self.async_client.cookies.clear()
42
42
  logging.info(f'Login failed with response: {response}')
43
43
  raise LoopCommunicationException('Login failed with response: ' + str(response))
44
44
  self.async_client.cookies.update(response.cookies)
@@ -50,8 +50,9 @@ class LoopCommunicator():
50
50
  if response.status_code != 200:
51
51
  logging.info(f'Logout failed with response: {response}')
52
52
  raise LoopCommunicationException('Logout failed with response: ' + str(response))
53
+ self.async_client.cookies.clear()
53
54
 
54
- async def get_cookies(self) -> Cookies:
55
+ def get_cookies(self) -> Cookies:
55
56
  return self.async_client.cookies
56
57
 
57
58
  async def shutdown(self):
@@ -70,37 +71,68 @@ class LoopCommunicator():
70
71
  logging.info(f'backend not ready: {e}')
71
72
  await asyncio.sleep(10)
72
73
 
74
+ async def retry_on_401(self, func: Callable[..., Awaitable[httpx.Response]], *args, **kwargs) -> httpx.Response:
75
+ response = await func(*args, **kwargs)
76
+ if response.status_code == 401:
77
+ await self.ensure_login(relogin=True)
78
+ response = await func(*args, **kwargs)
79
+ return response
80
+
73
81
  async def get(self, path: str, requires_login: bool = True, api_prefix: str = '/api') -> httpx.Response:
74
82
  if requires_login:
75
83
  await self.ensure_login()
84
+ return await self.retry_on_401(self._get, path, api_prefix)
85
+ else:
86
+ return await self._get(path, api_prefix)
87
+
88
+ async def _get(self, path: str, api_prefix: str) -> httpx.Response:
76
89
  return await self.async_client.get(api_prefix+path)
77
90
 
78
- async def put(self, path, files: Optional[List[str]]=None, requires_login=True, api_prefix='/api', **kwargs) -> httpx.Response:
91
+ async def put(self, path: str, files: Optional[List[str]] = None, requires_login: bool = True, api_prefix: str = '/api', **kwargs) -> httpx.Response:
79
92
  if requires_login:
80
93
  await self.ensure_login()
94
+ return await self.retry_on_401(self._put, path, files, api_prefix, **kwargs)
95
+ else:
96
+ return await self._put(path, files, api_prefix, **kwargs)
97
+
98
+ async def _put(self, path: str, files: Optional[List[str]], api_prefix: str, **kwargs) -> httpx.Response:
81
99
  if files is None:
82
100
  return await self.async_client.put(api_prefix+path, **kwargs)
83
-
84
- file_list = [('files', open(f, 'rb')) for f in files] # TODO: does this properly close the files after upload?
85
- return await self.async_client.put(api_prefix+path, files=file_list)
86
101
 
87
- async def post(self, path, requires_login=True, api_prefix='/api', **kwargs) -> httpx.Response:
102
+ file_handles = []
103
+ for f in files:
104
+ try:
105
+ file_handles.append(open(f, 'rb'))
106
+ except FileNotFoundError:
107
+ for fh in file_handles:
108
+ fh.close() # Ensure all files are closed
109
+ return httpx.Response(404, content=b'File not found')
110
+
111
+ try:
112
+ file_list = [('files', fh) for fh in file_handles] # Use file handles
113
+ response = await self.async_client.put(api_prefix+path, files=file_list)
114
+ finally:
115
+ for fh in file_handles:
116
+ fh.close() # Ensure all files are closed
117
+
118
+ return response
119
+
120
+ async def post(self, path: str, requires_login: bool = True, api_prefix: str = '/api', **kwargs) -> httpx.Response:
88
121
  if requires_login:
89
122
  await self.ensure_login()
123
+ return await self.retry_on_401(self._post, path, api_prefix, **kwargs)
124
+ else:
125
+ return await self._post(path, api_prefix, **kwargs)
126
+
127
+ async def _post(self, path, api_prefix='/api', **kwargs) -> httpx.Response:
90
128
  return await self.async_client.post(api_prefix+path, **kwargs)
91
129
 
92
- async def delete(self, path, requires_login=True, api_prefix='/api', **kwargs) -> httpx.Response:
130
+ async def delete(self, path: str, requires_login: bool = True, api_prefix: str = '/api', **kwargs) -> httpx.Response:
93
131
  if requires_login:
94
132
  await self.ensure_login()
95
- return await self.async_client.delete(api_prefix+path, **kwargs)
96
-
97
- # --------------------------------- unused?! --------------------------------- #TODO remove?
133
+ return await self.retry_on_401(self._delete, path, api_prefix, **kwargs)
134
+ else:
135
+ return await self._delete(path, api_prefix, **kwargs)
98
136
 
99
- # def get_data(self, path):
100
- # return asyncio.get_event_loop().run_until_complete(self._get_data_async(path))
101
-
102
- # async def _get_data_async(self, path) -> bytes:
103
- # response = await self.get(f'{self.project_path}{path}')
104
- # if response.status_code != 200:
105
- # raise LoopCommunicationException('bad response: ' + str(response))
106
- # return response.content
137
+ async def _delete(self, path, api_prefix, **kwargs) -> httpx.Response:
138
+ return await self.async_client.delete(api_prefix+path, **kwargs)
@@ -1,58 +1,58 @@
1
1
  import asyncio
2
- import json
3
2
  import logging
4
- import os
5
3
  import sys
6
4
  from abc import abstractmethod
5
+ from contextlib import asynccontextmanager
7
6
  from datetime import datetime
8
- from typing import Optional
9
- from uuid import uuid4
7
+ from typing import Any, Optional
10
8
 
11
9
  import aiohttp
12
10
  import socketio
13
11
  from fastapi import FastAPI
14
- from fastapi_utils.tasks import repeat_every
15
12
  from socketio import AsyncClient
16
13
 
17
- from .data_classes import Context, NodeState, NodeStatus
14
+ from .data_classes import NodeStatus
18
15
  from .data_exchanger import DataExchanger
19
- from .globals import GLOBALS
20
- from .helpers import environment_reader, log_conf
21
- from .helpers.misc import ensure_socket_response
16
+ from .helpers import log_conf
17
+ from .helpers.misc import activate_asyncio_warnings, ensure_socket_response, read_or_create_uuid
22
18
  from .loop_communication import LoopCommunicator
23
19
 
24
20
 
25
21
  class Node(FastAPI):
26
22
 
27
- def __init__(self, name: str, uuid: Optional[str] = None):
23
+ def __init__(self, name: str, uuid: Optional[str] = None, node_type: str = 'node', needs_login: bool = True):
28
24
  """Base class for all nodes. A node is a process that communicates with the zauberzeug learning loop.
25
+ This class provides the basic functionality to connect to the learning loop via socket.io and to exchange data.
29
26
 
30
27
  Args:
31
28
  name (str): The name of the node. This name is used to generate a uuid.
32
29
  uuid (Optional[str]): The uuid of the node. If None, a uuid is generated based on the name
33
30
  and stored in f'{GLOBALS.data_folder}/uuids.json'.
34
- From the second run, the uuid is recovered based on the name of the node. Defaults to None.
31
+ From the second run, the uuid is recovered based on the name of the node.
32
+ needs_login (bool): If True, the node will try to login to the learning loop.
35
33
  """
36
34
 
37
- super().__init__()
35
+ super().__init__(lifespan=self.lifespan)
38
36
  log_conf.init()
39
37
 
38
+ self.name = name
39
+ self.uuid = uuid or read_or_create_uuid(self.name)
40
+ self.needs_login = needs_login
41
+
40
42
  self.log = logging.getLogger()
41
43
  self.loop_communicator = LoopCommunicator()
44
+ self.websocket_url = self.loop_communicator.websocket_url()
42
45
  self.data_exchanger = DataExchanger(None, self.loop_communicator)
43
46
 
44
- host = environment_reader.host(default='learning-loop.ai')
45
- self.ws_url = f'ws{"s" if "learning-loop.ai" in host else ""}://' + host
46
-
47
- self.name = name
48
- self.uuid = self.read_or_create_uuid(self.name) if uuid is None else uuid
49
- self.startup_time = datetime.now()
47
+ self.startup_datetime = datetime.now()
50
48
  self._sio_client: Optional[AsyncClient] = None
51
49
  self.status = NodeStatus(id=self.uuid, name=self.name)
52
- # NOTE this is can be set to False for Nodes which do not need to authenticate with the backend (like the DetectorNode)
53
- self.needs_login = True
54
- self._setup_sio_headers()
55
- self._register_lifecycle_events()
50
+
51
+ self.sio_headers = {'organization': self.loop_communicator.organization,
52
+ 'project': self.loop_communicator.project,
53
+ 'nodeType': node_type}
54
+
55
+ self.repeat_task: Any = None
56
56
 
57
57
  @property
58
58
  def sio_client(self) -> AsyncClient:
@@ -60,52 +60,25 @@ class Node(FastAPI):
60
60
  raise Exception('sio_client not yet initialized')
61
61
  return self._sio_client
62
62
 
63
- def sio_is_initialized(self) -> bool:
64
- return self._sio_client is not None
65
-
66
- # --------------------------------------------------- INIT ---------------------------------------------------
67
-
68
- def read_or_create_uuid(self, identifier: str) -> str:
69
- identifier = identifier.lower().replace(' ', '_')
70
- uuids = {}
71
- os.makedirs(GLOBALS.data_folder, exist_ok=True)
72
- file_path = f'{GLOBALS.data_folder}/uuids.json'
73
- if os.path.exists(file_path):
74
- with open(file_path, 'r') as f:
75
- uuids = json.load(f)
76
-
77
- uuid = uuids.get(identifier, None)
78
- if not uuid:
79
- uuid = str(uuid4())
80
- uuids[identifier] = uuid
81
- with open(file_path, 'w') as f:
82
- json.dump(uuids, f)
83
- return uuid
84
-
85
- def _setup_sio_headers(self) -> None:
86
- self.sio_headers = {'organization': self.loop_communicator.organization,
87
- 'project': self.loop_communicator.project,
88
- 'nodeType': self.get_node_type()}
89
-
90
63
  # --------------------------------------------------- APPLICATION LIFECYCLE ---------------------------------------------------
91
-
92
- def _register_lifecycle_events(self):
93
- @self.on_event("startup")
94
- async def startup():
64
+ @asynccontextmanager
65
+ async def lifespan(self, app: FastAPI): # pylint: disable=unused-argument
66
+ try:
95
67
  await self._on_startup()
96
-
97
- @self.on_event("shutdown") # NOTE only used for developent ?!
98
- async def shutdown():
68
+ self.repeat_task = asyncio.create_task(self.repeat_loop())
69
+ yield
70
+ finally:
99
71
  await self._on_shutdown()
100
-
101
- @self.on_event("startup")
102
- @repeat_every(seconds=5, raise_exceptions=False, wait_first=False)
103
- async def ensure_connected() -> None:
104
- await self._on_repeat()
72
+ if self.repeat_task is not None:
73
+ self.repeat_task.cancel()
74
+ try:
75
+ await self.repeat_task
76
+ except asyncio.CancelledError:
77
+ pass
105
78
 
106
79
  async def _on_startup(self):
107
80
  self.log.info('received "startup" lifecycle-event')
108
- Node._activate_asyncio_warnings()
81
+ # activate_asyncio_warnings()
109
82
  if self.needs_login:
110
83
  await self.loop_communicator.backend_ready()
111
84
  self.log.info('ensuring login')
@@ -123,10 +96,18 @@ class Node(FastAPI):
123
96
  self.log.info('successfully disconnected from loop.')
124
97
  await self.on_shutdown()
125
98
 
99
+ async def repeat_loop(self) -> None:
100
+ """NOTE: with the lifespan approach, we cannot use @repeat_every anymore :("""
101
+ while True:
102
+ try:
103
+ await self._on_repeat()
104
+ except asyncio.CancelledError:
105
+ return
106
+ except Exception as e:
107
+ self.log.exception(f'error in repeat loop: {e}')
108
+ await asyncio.sleep(5)
109
+
126
110
  async def _on_repeat(self):
127
- while not self.sio_is_initialized():
128
- self.log.info('Waiting for sio client to be initialized')
129
- await asyncio.sleep(1)
130
111
  if not self.sio_client.connected:
131
112
  self.log.info('Reconnecting to loop via sio')
132
113
  await self.connect_sio()
@@ -138,8 +119,11 @@ class Node(FastAPI):
138
119
  # --------------------------------------------------- SOCKET.IO ---------------------------------------------------
139
120
 
140
121
  async def create_sio_client(self):
141
- cookies = await self.loop_communicator.get_cookies()
142
- self._sio_client = AsyncClient(request_timeout=20, http_session=aiohttp.ClientSession(cookies=cookies))
122
+ """Create a socket.io client that communicates with the learning loop and register the events.
123
+ Note: The method is called in startup and soft restart of detector, so the _sio_client should always be available."""
124
+
125
+ self._sio_client = AsyncClient(request_timeout=20,
126
+ http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies()))
143
127
 
144
128
  # pylint: disable=protected-access
145
129
  self.sio_client._trigger_event = ensure_socket_response(self.sio_client._trigger_event)
@@ -147,72 +131,39 @@ class Node(FastAPI):
147
131
  @self._sio_client.event
148
132
  async def connect():
149
133
  self.log.info('received "connect" via sio from loop.')
150
- self.status = NodeStatus(id=self.uuid, name=self.name)
151
- state = await self.get_state()
152
- try:
153
- await self._update_send_state(state)
154
- except:
155
- self.log.exception('Error sending state. Exception:')
156
- raise
157
134
 
158
135
  @self._sio_client.event
159
136
  async def disconnect():
160
137
  self.log.info('received "disconnect" via sio from loop.')
161
- await self._update_send_state(NodeState.Offline)
162
138
 
163
139
  @self._sio_client.event
164
140
  async def restart():
165
- self.log.info('received "restart" via sio from loop.')
166
- self.restart()
141
+ self.log.info('received "restart" via sio from loop -> restarting node.')
142
+ sys.exit(0)
167
143
 
168
144
  self.register_sio_events(self._sio_client)
169
145
 
170
146
  async def connect_sio(self):
171
- if not self.sio_is_initialized():
172
- self.log.warning('sio client not yet initialized')
173
- return
174
147
  try:
175
148
  await self.sio_client.disconnect()
176
149
  except Exception:
177
150
  pass
178
151
 
179
- self.log.info(f'(re)connecting to Learning Loop at {self.ws_url}')
152
+ self.log.info(f'(re)connecting to Learning Loop at {self.websocket_url}')
180
153
  try:
181
- await self.sio_client.connect(f"{self.ws_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
154
+ await self.sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
182
155
  self.log.info('connected to Learning Loop')
183
156
  except socketio.exceptions.ConnectionError: # type: ignore
184
157
  self.log.warning('connection error')
185
158
  except Exception:
186
- self.log.exception(f'error while connecting to "{self.ws_url}". Exception:')
187
-
188
- async def _update_send_state(self, state: NodeState):
189
- self.status.state = state
190
- if self.status.state != NodeState.Offline:
191
- await self.send_status()
159
+ self.log.exception(f'error while connecting to "{self.websocket_url}". Exception:')
192
160
 
193
161
  # --------------------------------------------------- ABSTRACT METHODS ---------------------------------------------------
194
162
 
195
- @abstractmethod
196
- def register_sio_events(self, sio_client: AsyncClient):
197
- """Register socket.io events for the communication with the learning loop.
198
- The events: connect and disconnect are already registered and should not be overwritten."""
199
-
200
- @abstractmethod
201
- async def send_status(self):
202
- """Send the current status to the learning loop.
203
- Note that currently this method is also used to react to the response of the learning loop."""
204
-
205
- @abstractmethod
206
- async def get_state(self) -> NodeState:
207
- """Return the current state of the node."""
208
-
209
- @abstractmethod
210
- def get_node_type(self):
211
- pass
212
-
213
163
  @abstractmethod
214
164
  async def on_startup(self):
215
- """This method is called when the node is started."""
165
+ """This method is called when the node is started.
166
+ Note: In this method the sio connection is not yet established!"""
216
167
 
217
168
  @abstractmethod
218
169
  async def on_shutdown(self):
@@ -221,32 +172,8 @@ class Node(FastAPI):
221
172
  @abstractmethod
222
173
  async def on_repeat(self):
223
174
  """This method is called every 10 seconds."""
224
- # --------------------------------------------------- SHARED FUNCTIONS ---------------------------------------------------
225
-
226
- def restart(self):
227
- """Restart the node."""
228
- self.log.info('restarting node')
229
- sys.exit(0)
230
-
231
- # --------------------------------------------------- HELPER ---------------------------------------------------
232
-
233
- @staticmethod
234
- def create_project_folder(context: Context) -> str:
235
- project_folder = f'{GLOBALS.data_folder}/{context.organization}/{context.project}'
236
- os.makedirs(project_folder, exist_ok=True)
237
- return project_folder
238
175
 
239
- @staticmethod
240
- def _activate_asyncio_warnings() -> None:
241
- '''Produce warnings for coroutines which take too long on the main loop and hence clog the event loop'''
242
- try:
243
- if sys.version_info.major >= 3 and sys.version_info.minor >= 7: # most
244
- loop = asyncio.get_running_loop()
245
- else:
246
- loop = asyncio.get_event_loop()
247
-
248
- loop.set_debug(True)
249
- loop.slow_callback_duration = 0.2
250
- logging.info('activated asyncio warnings')
251
- except Exception:
252
- logging.exception('could not activate asyncio warnings. Exception:')
176
+ @abstractmethod
177
+ def register_sio_events(self, sio_client: AsyncClient):
178
+ """Register (additional) socket.io events for the communication with the learning loop.
179
+ The events: connect, disconnect and restart are already registered and should not be overwritten."""
@@ -2,9 +2,10 @@ import os
2
2
  import shutil
3
3
 
4
4
  from learning_loop_node.data_classes import Context
5
- from learning_loop_node.data_exchanger import DataExchanger, check_jpeg
5
+ from learning_loop_node.data_exchanger import DataExchanger
6
6
  from learning_loop_node.globals import GLOBALS
7
7
 
8
+ from ..helpers.misc import delete_corrupt_images
8
9
  from . import test_helper
9
10
 
10
11
 
@@ -33,26 +34,26 @@ async def test_download_model(data_exchanger: DataExchanger):
33
34
 
34
35
  # pylint: disable=redefined-outer-name
35
36
  async def test_fetching_image_ids(data_exchanger: DataExchanger):
36
- ids = await data_exchanger.fetch_image_ids()
37
+ ids = await data_exchanger.fetch_image_uuids()
37
38
  assert len(ids) == 3
38
39
 
39
40
 
40
41
  async def test_download_images(data_exchanger: DataExchanger):
41
42
  _, image_folder, _ = test_helper.create_needed_folders()
42
- image_ids = await data_exchanger.fetch_image_ids()
43
+ image_ids = await data_exchanger.fetch_image_uuids()
43
44
  await data_exchanger.download_images(image_ids, image_folder)
44
45
  files = test_helper.get_files_in_folder(GLOBALS.data_folder)
45
46
  assert len(files) == 3
46
47
 
47
48
 
48
49
  async def test_download_training_data(data_exchanger: DataExchanger):
49
- image_ids = await data_exchanger.fetch_image_ids()
50
+ image_ids = await data_exchanger.fetch_image_uuids()
50
51
  image_data = await data_exchanger.download_images_data(image_ids)
51
52
  assert len(image_data) == 3
52
53
 
53
54
 
54
55
  async def test_removal_of_corrupted_images(data_exchanger: DataExchanger):
55
- image_ids = await data_exchanger.fetch_image_ids()
56
+ image_ids = await data_exchanger.fetch_image_uuids()
56
57
 
57
58
  shutil.rmtree('/tmp/img_folder', ignore_errors=True)
58
59
  os.makedirs('/tmp/img_folder', exist_ok=True)
@@ -65,7 +66,7 @@ async def test_removal_of_corrupted_images(data_exchanger: DataExchanger):
65
66
  with open('/tmp/img_folder/c1.jpg', 'w') as f:
66
67
  f.write('I am no image')
67
68
 
68
- await data_exchanger.delete_corrupt_images('/tmp/img_folder')
69
+ await delete_corrupt_images('/tmp/img_folder', True)
69
70
 
70
- assert len(os.listdir('/tmp/img_folder')) == num_images if check_jpeg else num_images - 1
71
+ assert len(os.listdir('/tmp/img_folder')) == num_images if data_exchanger.check_jpeg else num_images - 1
71
72
  shutil.rmtree('/tmp/img_folder', ignore_errors=True)