learning-loop-node 0.10.12__py3-none-any.whl → 0.10.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of learning-loop-node might be problematic. Click here for more details.

Files changed (37) hide show
  1. learning_loop_node/annotation/annotator_node.py +11 -10
  2. learning_loop_node/data_classes/detections.py +34 -25
  3. learning_loop_node/data_classes/general.py +27 -17
  4. learning_loop_node/data_exchanger.py +6 -5
  5. learning_loop_node/detector/detector_logic.py +10 -4
  6. learning_loop_node/detector/detector_node.py +80 -54
  7. learning_loop_node/detector/inbox_filter/relevance_filter.py +9 -3
  8. learning_loop_node/detector/outbox.py +8 -1
  9. learning_loop_node/detector/rest/about.py +34 -9
  10. learning_loop_node/detector/rest/backdoor_controls.py +10 -29
  11. learning_loop_node/detector/rest/detect.py +27 -19
  12. learning_loop_node/detector/rest/model_version_control.py +30 -13
  13. learning_loop_node/detector/rest/operation_mode.py +11 -5
  14. learning_loop_node/detector/rest/outbox_mode.py +7 -1
  15. learning_loop_node/helpers/log_conf.py +5 -0
  16. learning_loop_node/node.py +97 -49
  17. learning_loop_node/rest.py +55 -0
  18. learning_loop_node/tests/detector/conftest.py +36 -2
  19. learning_loop_node/tests/detector/test_client_communication.py +21 -19
  20. learning_loop_node/tests/detector/test_detector_node.py +86 -0
  21. learning_loop_node/tests/trainer/conftest.py +4 -4
  22. learning_loop_node/tests/trainer/states/test_state_detecting.py +8 -9
  23. learning_loop_node/tests/trainer/states/test_state_download_train_model.py +8 -8
  24. learning_loop_node/tests/trainer/states/test_state_prepare.py +6 -7
  25. learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py +21 -18
  26. learning_loop_node/tests/trainer/states/test_state_train.py +6 -8
  27. learning_loop_node/tests/trainer/states/test_state_upload_detections.py +7 -9
  28. learning_loop_node/tests/trainer/states/test_state_upload_model.py +7 -8
  29. learning_loop_node/tests/trainer/test_errors.py +2 -2
  30. learning_loop_node/trainer/io_helpers.py +3 -6
  31. learning_loop_node/trainer/rest/backdoor_controls.py +19 -40
  32. learning_loop_node/trainer/trainer_logic.py +4 -4
  33. learning_loop_node/trainer/trainer_logic_generic.py +15 -12
  34. learning_loop_node/trainer/trainer_node.py +5 -4
  35. {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/METADATA +16 -15
  36. {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/RECORD +37 -35
  37. {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/WHEEL +0 -0
@@ -1,25 +1,50 @@
1
1
 
2
- from typing import TYPE_CHECKING
2
+ import sys
3
+ from dataclasses import dataclass, field
4
+ from typing import TYPE_CHECKING, Optional
3
5
 
4
6
  from fastapi import APIRouter, Request
5
7
 
8
+ from ...data_classes import ModelInformation
9
+
6
10
  if TYPE_CHECKING:
7
11
  from ..detector_node import DetectorNode
12
+ KWONLY_SLOTS = {'kw_only': True, 'slots': True} if sys.version_info >= (3, 10) else {}
8
13
 
9
14
  router = APIRouter()
10
15
 
11
16
 
12
- @router.get("/about")
17
+ @dataclass(**KWONLY_SLOTS)
18
+ class AboutResponse:
19
+ operation_mode: str = field(metadata={"description": "The operation mode of the detector node"})
20
+ state: Optional[str] = field(metadata={
21
+ "description": "The state of the detector node",
22
+ "example": "idle, online, detecting"})
23
+ model_info: Optional[ModelInformation] = field(metadata={
24
+ "description": "Information about the model of the detector node"})
25
+ target_model: Optional[str] = field(metadata={"description": "The target model of the detector node"})
26
+ version_control: str = field(metadata={
27
+ "description": "The version control mode of the detector node",
28
+ "example": "follow_loop, specific_version, pause"})
29
+
30
+
31
+ @router.get("/about", response_model=AboutResponse)
13
32
  async def get_about(request: Request):
14
33
  '''
34
+ Get information about the detector node.
35
+
15
36
  Example Usage
16
- curl http://localhost/about
37
+
38
+ curl http://hosturl/about
17
39
  '''
18
40
  app: 'DetectorNode' = request.app
19
41
 
20
- return {
21
- 'operation_mode': app.operation_mode.value,
22
- 'state': app.status.state,
23
- 'model_info': app.detector_logic._model_info, # pylint: disable=protected-access
24
- 'target_model': app.target_model.version if app.target_model is not None else 'None',
25
- }
42
+ response = AboutResponse(
43
+ operation_mode=app.operation_mode.value,
44
+ state=app.status.state,
45
+ model_info=app.detector_logic._model_info, # pylint: disable=protected-access
46
+ target_model=app.target_model.version if app.target_model is not None else None,
47
+ version_control=app.version_control.value
48
+ )
49
+
50
+ return response
@@ -14,43 +14,24 @@ if TYPE_CHECKING:
14
14
  router = APIRouter()
15
15
 
16
16
 
17
- @router.put("/socketio")
18
- async def _socketio(request: Request):
17
+ @router.post("/reset")
18
+ async def _reset(request: Request):
19
19
  '''
20
+ Soft-Reset the detector node.
21
+
20
22
  Example Usage
21
23
 
22
- curl -X PUT -d "on" http://localhost:8007/socketio
24
+ curl -X POST http://hosturl/reset
23
25
  '''
24
- state = str(await request.body(), 'utf-8')
25
- await _switch_socketio(state, request.app)
26
-
27
-
28
- async def _switch_socketio(state: str, detector_node: 'DetectorNode'):
29
- if state == 'off':
30
- logging.info('BC: turning socketio off')
31
- await detector_node.sio_client.disconnect()
32
- if state == 'on':
33
- logging.info('BC: turning socketio on')
34
- await detector_node.connect_sio()
35
-
36
-
37
- @router.post("/reset")
38
- async def _reset(request: Request):
39
26
  logging.info('BC: reset')
27
+ detector_node: 'DetectorNode' = request.app
28
+
40
29
  try:
41
30
  shutil.rmtree(GLOBALS.data_folder, ignore_errors=True)
42
31
  os.makedirs(GLOBALS.data_folder, exist_ok=True)
43
32
 
44
- # get file dir
45
- # restart_path = Path(os.path.realpath(__file__)) / 'restart' / 'restart.py'
46
- # restart_path = Path(os.getcwd()).absolute() / 'app_code' / 'restart' / 'restart.py'
47
- # restart_path.touch()
48
- # assert isinstance(request.app, 'DetectorNode')
49
- await request.app.soft_reload()
50
-
51
- # assert isinstance(request.app, DetectorNode)
52
- # request.app.reload(reason='------- reset was called from backdoor controls',)
53
- except Exception as e:
54
- logging.error(f'BC: could not reset: {e}')
33
+ await detector_node.soft_reload()
34
+ except Exception:
35
+ logging.exception('BC: could not reset:')
55
36
  return False
56
37
  return True
@@ -1,45 +1,53 @@
1
1
  import logging
2
- from typing import Optional
2
+ from typing import TYPE_CHECKING, Optional
3
3
 
4
4
  import numpy as np
5
5
  from fastapi import APIRouter, File, Header, Request, UploadFile
6
6
  from fastapi.responses import JSONResponse
7
7
 
8
+ from ...data_classes.detections import Detections
9
+
10
+ if TYPE_CHECKING:
11
+ from ..detector_node import DetectorNode
12
+
8
13
  router = APIRouter()
9
14
 
10
15
 
11
- @router.post("/detect")
16
+ @router.post("/detect", response_model=Detections)
12
17
  async def http_detect(
13
18
  request: Request,
14
- file: UploadFile = File(...),
15
- camera_id: Optional[str] = Header(None),
16
- mac: Optional[str] = Header(None),
17
- tags: Optional[str] = Header(None),
18
- autoupload: Optional[str] = Header(None),
19
+ file: UploadFile = File(..., description='The image file to run detection on'),
20
+ camera_id: Optional[str] = Header(None, description='The camera id (used by learning loop)'),
21
+ mac: Optional[str] = Header(None, description='The camera mac address (used by learning loop)'),
22
+ tags: Optional[str] = Header(None, description='Tags to add to the image (used by learning loop)'),
23
+ source: Optional[str] = Header(None, description='The source of the image (used by learning loop)'),
24
+ autoupload: Optional[str] = Header(None, description='Mode to decide whether to upload the image to the learning loop',
25
+ examples=['filtered', 'all', 'disabled']),
19
26
  ):
20
27
  """
21
- Example Usage
28
+ Single image example:
29
+
30
+ curl --request POST -F 'file=@test.jpg' localhost:8004/detect -H 'autoupload: all' -H 'camera-id: front_cam' -H 'source: test' -H 'tags: test,test2'
22
31
 
23
- curl --request POST -F 'file=@test.jpg' localhost:8004/detect
32
+ Multiple images example:
24
33
 
25
34
  for i in `seq 1 10`; do time curl --request POST -F 'file=@test.jpg' localhost:8004/detect; done
26
35
 
27
- You can additionally provide the following camera parameters:
28
- - `autoupload`: configures auto-submission to the learning loop; `filtered` (default), `all`, `disabled` (example curl parameter `-H 'autoupload: all'`)
29
- - `camera-id`: a string which groups images for submission together (example curl parameter `-H 'camera-id: front_cam'`)
30
36
  """
31
37
  try:
32
38
  np_image = np.fromfile(file.file, np.uint8)
33
39
  except Exception as exc:
34
- logging.exception(f'Error during reading of image {file.filename}.')
40
+ logging.exception('Error during reading of image %s.', file.filename)
35
41
  raise Exception(f'Uploaded file {file.filename} is no image file.') from exc
36
42
 
37
43
  try:
38
- detections = await request.app.get_detections(raw_image=np_image,
39
- camera_id=camera_id or mac or None,
40
- tags=tags.split(',') if tags else [],
41
- autoupload=autoupload,)
44
+ app: 'DetectorNode' = request.app
45
+ detections = await app.get_detections(raw_image=np_image,
46
+ camera_id=camera_id or mac or None,
47
+ tags=tags.split(',') if tags else [],
48
+ source=source,
49
+ autoupload=autoupload)
42
50
  except Exception as exc:
43
- logging.exception(f'Error during detection of image {file.filename}.')
51
+ logging.exception('Error during detection of image %s.', file.filename)
44
52
  raise Exception(f'Error during detection of image {file.filename}.') from exc
45
- return JSONResponse(detections)
53
+ return detections
@@ -1,7 +1,9 @@
1
1
 
2
2
  import os
3
+ import sys
4
+ from dataclasses import dataclass, field
3
5
  from enum import Enum
4
- from typing import TYPE_CHECKING
6
+ from typing import TYPE_CHECKING, List
5
7
 
6
8
  from fastapi import APIRouter, HTTPException, Request
7
9
 
@@ -10,6 +12,7 @@ from ...globals import GLOBALS
10
12
 
11
13
  if TYPE_CHECKING:
12
14
  from ..detector_node import DetectorNode
15
+ KWONLY_SLOTS = {'kw_only': True, 'slots': True} if sys.version_info >= (3, 10) else {}
13
16
 
14
17
  router = APIRouter()
15
18
 
@@ -20,9 +23,20 @@ class VersionMode(str, Enum):
20
23
  Pause = 'pause' # will pause the updates
21
24
 
22
25
 
26
+ @dataclass(**KWONLY_SLOTS)
27
+ class ModelVersionResponse:
28
+ current_version: str = field(metadata={"description": "The version of the model currently used by the detector."})
29
+ target_version: str = field(metadata={"description": "The target model version set in the detector."})
30
+ loop_version: str = field(metadata={"description": "The target model version specified by the loop."})
31
+ local_versions: List[str] = field(metadata={"description": "The locally available versions of the model."})
32
+ version_control: str = field(metadata={"description": "The version control mode."})
33
+
34
+
23
35
  @router.get("/model_version")
24
36
  async def get_version(request: Request):
25
37
  '''
38
+ Get information about the model version control and the current model version.
39
+
26
40
  Example Usage
27
41
  curl http://localhost/model_version
28
42
  '''
@@ -35,28 +49,31 @@ async def get_version(request: Request):
35
49
  loop_version = app.loop_deployment_target.version if app.loop_deployment_target is not None else 'None'
36
50
 
37
51
  local_versions: list[str] = []
38
-
39
- local_models = os.listdir(os.path.join(GLOBALS.data_folder, 'models'))
52
+ models_path = os.path.join(GLOBALS.data_folder, 'models')
53
+ local_models = os.listdir(models_path) if os.path.exists(models_path) else []
40
54
  for model in local_models:
41
55
  if model.replace('.', '').isdigit():
42
56
  local_versions.append(model)
43
57
 
44
- return {
45
- 'current_version': current_version,
46
- 'target_version': target_version,
47
- 'loop_version': loop_version,
48
- 'local_versions': local_versions,
49
- 'version_control': app.version_control.value,
50
- }
58
+ response = ModelVersionResponse(
59
+ current_version=current_version,
60
+ target_version=target_version,
61
+ loop_version=loop_version,
62
+ local_versions=local_versions,
63
+ version_control=app.version_control.value,
64
+ )
65
+ return response
51
66
 
52
67
 
53
68
  @router.put("/model_version")
54
69
  async def put_version(request: Request):
55
70
  '''
71
+ Set the model version control mode.
72
+
56
73
  Example Usage
57
- curl -X PUT -d "follow_loop" http://localhost/model_version
58
- curl -X PUT -d "pause" http://localhost/model_version
59
- curl -X PUT -d "13.6" http://localhost/model_version
74
+ curl -X PUT -d "follow_loop" http://hosturl/model_version
75
+ curl -X PUT -d "pause" http://hosturl/model_version
76
+ curl -X PUT -d "13.6" http://hosturl/model_version
60
77
  '''
61
78
  app: 'DetectorNode' = request.app
62
79
  content = str(await request.body(), 'utf-8')
@@ -22,7 +22,10 @@ class OperationMode(str, Enum):
22
22
  @router.put("/operation_mode")
23
23
  async def put_operation_mode(request: Request):
24
24
  '''
25
+ Set the operation mode of the detector node.
26
+
25
27
  Example Usage
28
+
26
29
  curl -X PUT -d "check_for_updates" http://localhost/operation_mode
27
30
  curl -X PUT -d "detecting" http://localhost/operation_mode
28
31
  '''
@@ -34,22 +37,25 @@ async def put_operation_mode(request: Request):
34
37
  raise HTTPException(422, str(exc)) from exc
35
38
  node: DetectorNode = request.app
36
39
 
37
- logging.info(f'current node state : {node.status.state}')
38
- logging.info(f'current operation mode : {node.operation_mode.value}')
39
- logging.info(f'target operation mode : {target_mode}')
40
+ logging.info('current node state : %s', node.status.state)
41
+ logging.info('current operation mode : %s', node.operation_mode.value)
42
+ logging.info('target operation mode : %s', target_mode)
40
43
  if target_mode == node.operation_mode:
41
44
  logging.info('operation mode already set')
42
45
  return "OK"
43
46
 
44
47
  await node.set_operation_mode(target_mode)
45
- logging.info(f'operation mode set to : {target_mode}')
48
+ logging.info('operation mode set to : %s', target_mode)
46
49
  return "OK"
47
50
 
48
51
 
49
- @router.get("/operation_mode")
52
+ @router.get("/operation_mode", response_class=PlainTextResponse)
50
53
  async def get_operation_mode(request: Request):
51
54
  '''
55
+ Get the operation mode of the detector node.
56
+
52
57
  Example Usage
58
+
53
59
  curl http://localhost/operation_mode
54
60
  '''
55
61
  return PlainTextResponse(request.app.operation_mode.value)
@@ -6,10 +6,13 @@ from ..outbox import Outbox
6
6
  router = APIRouter()
7
7
 
8
8
 
9
- @router.get("/outbox_mode")
9
+ @router.get("/outbox_mode", response_class=PlainTextResponse)
10
10
  async def get_outbox_mode(request: Request):
11
11
  '''
12
+ Get the outbox mode of the detector node.
13
+
12
14
  Example Usage
15
+
13
16
  curl http://localhost/outbox_mode
14
17
  '''
15
18
  outbox: Outbox = request.app.outbox
@@ -19,7 +22,10 @@ async def get_outbox_mode(request: Request):
19
22
  @router.put("/outbox_mode")
20
23
  async def put_outbox_mode(request: Request):
21
24
  '''
25
+ Set the outbox mode of the detector node.
26
+
22
27
  Example Usage
28
+
23
29
  curl -X PUT -d "continuous_upload" http://localhost/outbox_mode
24
30
  curl -X PUT -d "stopped" http://localhost/outbox_mode
25
31
  '''
@@ -23,6 +23,11 @@ LOGGING_CONF = {
23
23
  'level': 'INFO',
24
24
  'propagate': False,
25
25
  },
26
+ 'Node': {
27
+ 'handlers': ['console'],
28
+ 'level': 'INFO',
29
+ 'propagate': False,
30
+ },
26
31
  },
27
32
  }
28
33
 
@@ -8,7 +8,6 @@ from datetime import datetime
8
8
  from typing import Any, Optional
9
9
 
10
10
  import aiohttp
11
- import socketio
12
11
  from aiohttp import TCPConnector
13
12
  from fastapi import FastAPI
14
13
  from socketio import AsyncClient
@@ -18,6 +17,11 @@ from .data_exchanger import DataExchanger
18
17
  from .helpers import log_conf
19
18
  from .helpers.misc import ensure_socket_response, read_or_create_uuid
20
19
  from .loop_communication import LoopCommunicator
20
+ from .rest import router
21
+
22
+
23
+ class NodeConnectionError(Exception):
24
+ pass
21
25
 
22
26
 
23
27
  class Node(FastAPI):
@@ -41,9 +45,8 @@ class Node(FastAPI):
41
45
  self.uuid = uuid or read_or_create_uuid(self.name)
42
46
  self.needs_login = needs_login
43
47
 
44
- self.log = logging.getLogger()
45
- self.loop_communicator = LoopCommunicator()
46
- self.websocket_url = self.loop_communicator.websocket_url()
48
+ self.log = logging.getLogger('Node')
49
+ self.init_loop_communicator()
47
50
  self.data_exchanger = DataExchanger(None, self.loop_communicator)
48
51
 
49
52
  self.startup_datetime = datetime.now()
@@ -55,6 +58,19 @@ class Node(FastAPI):
55
58
  'nodeType': node_type}
56
59
 
57
60
  self.repeat_task: Any = None
61
+ self.socket_connection_broken = False
62
+ self._skip_repeat_loop = False
63
+
64
+ self.include_router(router)
65
+
66
+ self.CONNECTED_TO_LOOP = asyncio.Event()
67
+ self.DISCONNECTED_FROM_LOOP = asyncio.Event()
68
+
69
+ self.repeat_loop_lock = asyncio.Lock()
70
+
71
+ def init_loop_communicator(self):
72
+ self.loop_communicator = LoopCommunicator()
73
+ self.websocket_url = self.loop_communicator.websocket_url()
58
74
 
59
75
  @property
60
76
  def sio_client(self) -> AsyncClient:
@@ -66,7 +82,10 @@ class Node(FastAPI):
66
82
  @asynccontextmanager
67
83
  async def lifespan(self, app: FastAPI): # pylint: disable=unused-argument
68
84
  try:
69
- await self._on_startup()
85
+ try:
86
+ await self._on_startup()
87
+ except Exception:
88
+ self.log.exception('Fatal error during startup: %s')
70
89
  self.repeat_task = asyncio.create_task(self.repeat_loop())
71
90
  yield
72
91
  finally:
@@ -80,13 +99,10 @@ class Node(FastAPI):
80
99
 
81
100
  async def _on_startup(self):
82
101
  self.log.info('received "startup" lifecycle-event')
83
- # activate_asyncio_warnings()
84
- if self.needs_login:
85
- await self.loop_communicator.backend_ready()
86
- self.log.info('ensuring login')
87
- await self.loop_communicator.ensure_login()
88
- self.log.info('create sio client')
89
- await self.create_sio_client()
102
+ try:
103
+ await self.reconnect_to_loop()
104
+ except Exception:
105
+ self.log.warning('Could not establish sio connection to loop during startup')
90
106
  self.log.info('done')
91
107
  await self.on_startup()
92
108
 
@@ -99,55 +115,93 @@ class Node(FastAPI):
99
115
  await self.on_shutdown()
100
116
 
101
117
  async def repeat_loop(self) -> None:
102
- """NOTE: with the lifespan approach, we cannot use @repeat_every anymore :("""
103
118
  while True:
119
+ if self._skip_repeat_loop:
120
+ self.log.debug('node is muted, skipping repeat loop')
121
+ await asyncio.sleep(1)
122
+ continue
104
123
  try:
105
- await self._on_repeat()
124
+ async with self.repeat_loop_lock:
125
+ await self._ensure_sio_connection()
126
+ await self.on_repeat()
106
127
  except asyncio.CancelledError:
107
128
  return
108
- except Exception as e:
109
- self.log.exception(f'error in repeat loop: {e}')
129
+ except Exception:
130
+ self.log.exception('error in repeat loop')
131
+
110
132
  await asyncio.sleep(5)
111
133
 
112
- async def _on_repeat(self):
113
- if not self.sio_client.connected:
114
- self.log.info('Reconnecting to loop via sio')
115
- await self.connect_sio()
116
- if not self.sio_client.connected:
117
- self.log.warning('Could not connect to loop via sio')
118
- return
119
- await self.on_repeat()
134
+ async def _ensure_sio_connection(self):
135
+ if self.socket_connection_broken or self._sio_client is None or not self.sio_client.connected:
136
+ self.log.info('Reconnecting to loop via sio due to %s',
137
+ 'broken connection' if self.socket_connection_broken else 'no connection')
138
+ await self.reconnect_to_loop()
139
+
140
+ async def reconnect_to_loop(self):
141
+ """Initialize the loop communicator, log in if needed and reconnect to the loop via socket.io."""
142
+ self.init_loop_communicator()
143
+ if self.needs_login:
144
+ await self.loop_communicator.ensure_login(relogin=True)
145
+ try:
146
+ await self._reconnect_socketio()
147
+ except NodeConnectionError:
148
+ self.log.exception('Could not reset sio connection to loop')
149
+ self.socket_connection_broken = True
150
+ raise
151
+
152
+ self.socket_connection_broken = False
153
+
154
+ def set_skip_repeat_loop(self, value: bool):
155
+ self._skip_repeat_loop = value
156
+ self.log.info('node is muted: %s', value)
120
157
 
121
158
  # --------------------------------------------------- SOCKET.IO ---------------------------------------------------
122
159
 
123
- async def create_sio_client(self):
124
- """Create a socket.io client that communicates with the learning loop and register the events.
125
- Note: The method is called in startup and soft restart of detector, so the _sio_client should always be available."""
160
+ async def _reconnect_socketio(self):
161
+ """Create a socket.io client, connect it to the learning loop and register its events.
162
+ The current client is disconnected and deleted if it already exists."""
163
+
164
+ self.log.debug('-------------- Connecting to loop via socket.io -------------------')
165
+ self.log.debug('HTTP Cookies: %s\n', self.loop_communicator.get_cookies())
166
+
167
+ if self._sio_client is not None:
168
+ try:
169
+ await self.sio_client.disconnect()
170
+ self.log.info('disconnected from loop via sio')
171
+ # NOTE: without waiting for the disconnect event, we might disconnect the next connection too early
172
+ await asyncio.wait_for(self.DISCONNECTED_FROM_LOOP.wait(), timeout=5)
173
+ except asyncio.TimeoutError:
174
+ self.log.warning(
175
+ 'Did not receive disconnect event from loop within 5 seconds.\nContinuing with new connection...')
176
+ except Exception as e:
177
+ self.log.warning('Could not disconnect from loop via sio: %s.\nIgnoring...', e)
178
+ self._sio_client = None
126
179
 
180
+ connector = None
127
181
  if self.loop_communicator.ssl_cert_path:
128
- logging.info(f'SIO using SSL certificate path: {self.loop_communicator.ssl_cert_path}')
182
+ logging.info('SIO using SSL certificate path: %s', self.loop_communicator.ssl_cert_path)
129
183
  ssl_context = ssl.create_default_context(cafile=self.loop_communicator.ssl_cert_path)
130
184
  ssl_context.check_hostname = False
131
185
  ssl_context.verify_mode = ssl.CERT_REQUIRED
132
186
  connector = TCPConnector(ssl=ssl_context)
133
- self._sio_client = AsyncClient(request_timeout=20,
134
- http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies(),
135
- connector=connector))
136
187
 
137
- else:
138
- self._sio_client = AsyncClient(request_timeout=20,
139
- http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies()))
188
+ self._sio_client = AsyncClient(request_timeout=20, http_session=aiohttp.ClientSession(
189
+ cookies=self.loop_communicator.get_cookies(), connector=connector))
140
190
 
141
191
  # pylint: disable=protected-access
142
- self.sio_client._trigger_event = ensure_socket_response(self.sio_client._trigger_event)
192
+ self._sio_client._trigger_event = ensure_socket_response(self._sio_client._trigger_event)
143
193
 
144
194
  @self._sio_client.event
145
195
  async def connect():
146
196
  self.log.info('received "connect" via sio from loop.')
197
+ self.CONNECTED_TO_LOOP.set()
198
+ self.DISCONNECTED_FROM_LOOP.clear()
147
199
 
148
200
  @self._sio_client.event
149
201
  async def disconnect():
150
202
  self.log.info('received "disconnect" via sio from loop.')
203
+ self.DISCONNECTED_FROM_LOOP.set()
204
+ self.CONNECTED_TO_LOOP.clear()
151
205
 
152
206
  @self._sio_client.event
153
207
  async def restart():
@@ -155,21 +209,15 @@ class Node(FastAPI):
155
209
  sys.exit(0)
156
210
 
157
211
  self.register_sio_events(self._sio_client)
158
-
159
- async def connect_sio(self):
160
212
  try:
161
- await self.sio_client.disconnect()
162
- except Exception:
163
- pass
164
-
165
- self.log.info(f'(re)connecting to Learning Loop at {self.websocket_url}')
166
- try:
167
- await self.sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
168
- self.log.info('connected to Learning Loop')
169
- except socketio.exceptions.ConnectionError: # type: ignore
170
- self.log.warning('connection error')
171
- except Exception:
172
- self.log.exception(f'error while connecting to "{self.websocket_url}". Exception:')
213
+ await self._sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
214
+ except Exception as e:
215
+ self.log.exception('Could not connect socketio client to loop')
216
+ raise NodeConnectionError('Could not connect socketio client to loop') from e
217
+
218
+ if not self._sio_client.connected:
219
+ self.log.exception('Could not connect socketio client to loop')
220
+ raise NodeConnectionError('Could not connect socketio client to loop')
173
221
 
174
222
  # --------------------------------------------------- ABSTRACT METHODS ---------------------------------------------------
175
223
 
@@ -0,0 +1,55 @@
1
+ import logging
2
+ from typing import TYPE_CHECKING
3
+
4
+ from fastapi import APIRouter, HTTPException, Request
5
+
6
+ if TYPE_CHECKING:
7
+ from .node import Node
8
+
9
+
10
+ router = APIRouter()
11
+ logger = logging.getLogger('Node.rest')
12
+
13
+
14
+ @router.put("/debug_logging")
15
+ async def _debug_logging(request: Request) -> str:
16
+ '''
17
+ Example Usage
18
+
19
+ curl -X PUT -d "on" http://localhost:8007/debug_logging
20
+ '''
21
+ state = str(await request.body(), 'utf-8')
22
+ node: 'Node' = request.app
23
+
24
+ if state == 'off':
25
+ logger.info('turning debug logging off')
26
+ node.log.setLevel('INFO')
27
+ return 'off'
28
+ if state == 'on':
29
+ logger.info('turning debug logging on')
30
+ node.log.setLevel('DEBUG')
31
+ return 'on'
32
+ raise HTTPException(status_code=400, detail='Invalid state')
33
+
34
+
35
+ @router.put("/socketio")
36
+ async def _socketio(request: Request) -> str:
37
+ '''
38
+ Enable or disable the socketio connection to the learning loop.
39
+ Not intended to be used outside of testing.
40
+
41
+ Example Usage
42
+
43
+ curl -X PUT -d "on" http://hosturl/socketio
44
+ '''
45
+ state = str(await request.body(), 'utf-8')
46
+ node: 'Node' = request.app
47
+
48
+ if state == 'off':
49
+ await node.sio_client.disconnect()
50
+ node.set_skip_repeat_loop(True) # Prevent auto-reconnection
51
+ return 'off'
52
+ if state == 'on':
53
+ node.set_skip_repeat_loop(False) # Allow auto-reconnection (1 sec delay)
54
+ return 'on'
55
+ raise HTTPException(status_code=400, detail='Invalid state')