learning-loop-node 0.10.12__py3-none-any.whl → 0.10.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of learning-loop-node might be problematic. Click here for more details.
- learning_loop_node/annotation/annotator_node.py +11 -10
- learning_loop_node/data_classes/detections.py +34 -25
- learning_loop_node/data_classes/general.py +27 -17
- learning_loop_node/data_exchanger.py +6 -5
- learning_loop_node/detector/detector_logic.py +10 -4
- learning_loop_node/detector/detector_node.py +80 -54
- learning_loop_node/detector/inbox_filter/relevance_filter.py +9 -3
- learning_loop_node/detector/outbox.py +8 -1
- learning_loop_node/detector/rest/about.py +34 -9
- learning_loop_node/detector/rest/backdoor_controls.py +10 -29
- learning_loop_node/detector/rest/detect.py +27 -19
- learning_loop_node/detector/rest/model_version_control.py +30 -13
- learning_loop_node/detector/rest/operation_mode.py +11 -5
- learning_loop_node/detector/rest/outbox_mode.py +7 -1
- learning_loop_node/helpers/log_conf.py +5 -0
- learning_loop_node/node.py +97 -49
- learning_loop_node/rest.py +55 -0
- learning_loop_node/tests/detector/conftest.py +36 -2
- learning_loop_node/tests/detector/test_client_communication.py +21 -19
- learning_loop_node/tests/detector/test_detector_node.py +86 -0
- learning_loop_node/tests/trainer/conftest.py +4 -4
- learning_loop_node/tests/trainer/states/test_state_detecting.py +8 -9
- learning_loop_node/tests/trainer/states/test_state_download_train_model.py +8 -8
- learning_loop_node/tests/trainer/states/test_state_prepare.py +6 -7
- learning_loop_node/tests/trainer/states/test_state_sync_confusion_matrix.py +21 -18
- learning_loop_node/tests/trainer/states/test_state_train.py +6 -8
- learning_loop_node/tests/trainer/states/test_state_upload_detections.py +7 -9
- learning_loop_node/tests/trainer/states/test_state_upload_model.py +7 -8
- learning_loop_node/tests/trainer/test_errors.py +2 -2
- learning_loop_node/trainer/io_helpers.py +3 -6
- learning_loop_node/trainer/rest/backdoor_controls.py +19 -40
- learning_loop_node/trainer/trainer_logic.py +4 -4
- learning_loop_node/trainer/trainer_logic_generic.py +15 -12
- learning_loop_node/trainer/trainer_node.py +5 -4
- {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/METADATA +16 -15
- {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/RECORD +37 -35
- {learning_loop_node-0.10.12.dist-info → learning_loop_node-0.10.14.dist-info}/WHEEL +0 -0
|
@@ -1,25 +1,50 @@
|
|
|
1
1
|
|
|
2
|
-
|
|
2
|
+
import sys
|
|
3
|
+
from dataclasses import dataclass, field
|
|
4
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
5
|
|
|
4
6
|
from fastapi import APIRouter, Request
|
|
5
7
|
|
|
8
|
+
from ...data_classes import ModelInformation
|
|
9
|
+
|
|
6
10
|
if TYPE_CHECKING:
|
|
7
11
|
from ..detector_node import DetectorNode
|
|
12
|
+
KWONLY_SLOTS = {'kw_only': True, 'slots': True} if sys.version_info >= (3, 10) else {}
|
|
8
13
|
|
|
9
14
|
router = APIRouter()
|
|
10
15
|
|
|
11
16
|
|
|
12
|
-
@
|
|
17
|
+
@dataclass(**KWONLY_SLOTS)
|
|
18
|
+
class AboutResponse:
|
|
19
|
+
operation_mode: str = field(metadata={"description": "The operation mode of the detector node"})
|
|
20
|
+
state: Optional[str] = field(metadata={
|
|
21
|
+
"description": "The state of the detector node",
|
|
22
|
+
"example": "idle, online, detecting"})
|
|
23
|
+
model_info: Optional[ModelInformation] = field(metadata={
|
|
24
|
+
"description": "Information about the model of the detector node"})
|
|
25
|
+
target_model: Optional[str] = field(metadata={"description": "The target model of the detector node"})
|
|
26
|
+
version_control: str = field(metadata={
|
|
27
|
+
"description": "The version control mode of the detector node",
|
|
28
|
+
"example": "follow_loop, specific_version, pause"})
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
@router.get("/about", response_model=AboutResponse)
|
|
13
32
|
async def get_about(request: Request):
|
|
14
33
|
'''
|
|
34
|
+
Get information about the detector node.
|
|
35
|
+
|
|
15
36
|
Example Usage
|
|
16
|
-
|
|
37
|
+
|
|
38
|
+
curl http://hosturl/about
|
|
17
39
|
'''
|
|
18
40
|
app: 'DetectorNode' = request.app
|
|
19
41
|
|
|
20
|
-
|
|
21
|
-
|
|
22
|
-
|
|
23
|
-
|
|
24
|
-
|
|
25
|
-
|
|
42
|
+
response = AboutResponse(
|
|
43
|
+
operation_mode=app.operation_mode.value,
|
|
44
|
+
state=app.status.state,
|
|
45
|
+
model_info=app.detector_logic._model_info, # pylint: disable=protected-access
|
|
46
|
+
target_model=app.target_model.version if app.target_model is not None else None,
|
|
47
|
+
version_control=app.version_control.value
|
|
48
|
+
)
|
|
49
|
+
|
|
50
|
+
return response
|
|
@@ -14,43 +14,24 @@ if TYPE_CHECKING:
|
|
|
14
14
|
router = APIRouter()
|
|
15
15
|
|
|
16
16
|
|
|
17
|
-
@router.
|
|
18
|
-
async def
|
|
17
|
+
@router.post("/reset")
|
|
18
|
+
async def _reset(request: Request):
|
|
19
19
|
'''
|
|
20
|
+
Soft-Reset the detector node.
|
|
21
|
+
|
|
20
22
|
Example Usage
|
|
21
23
|
|
|
22
|
-
curl -X
|
|
24
|
+
curl -X POST http://hosturl/reset
|
|
23
25
|
'''
|
|
24
|
-
state = str(await request.body(), 'utf-8')
|
|
25
|
-
await _switch_socketio(state, request.app)
|
|
26
|
-
|
|
27
|
-
|
|
28
|
-
async def _switch_socketio(state: str, detector_node: 'DetectorNode'):
|
|
29
|
-
if state == 'off':
|
|
30
|
-
logging.info('BC: turning socketio off')
|
|
31
|
-
await detector_node.sio_client.disconnect()
|
|
32
|
-
if state == 'on':
|
|
33
|
-
logging.info('BC: turning socketio on')
|
|
34
|
-
await detector_node.connect_sio()
|
|
35
|
-
|
|
36
|
-
|
|
37
|
-
@router.post("/reset")
|
|
38
|
-
async def _reset(request: Request):
|
|
39
26
|
logging.info('BC: reset')
|
|
27
|
+
detector_node: 'DetectorNode' = request.app
|
|
28
|
+
|
|
40
29
|
try:
|
|
41
30
|
shutil.rmtree(GLOBALS.data_folder, ignore_errors=True)
|
|
42
31
|
os.makedirs(GLOBALS.data_folder, exist_ok=True)
|
|
43
32
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
# restart_path.touch()
|
|
48
|
-
# assert isinstance(request.app, 'DetectorNode')
|
|
49
|
-
await request.app.soft_reload()
|
|
50
|
-
|
|
51
|
-
# assert isinstance(request.app, DetectorNode)
|
|
52
|
-
# request.app.reload(reason='------- reset was called from backdoor controls',)
|
|
53
|
-
except Exception as e:
|
|
54
|
-
logging.error(f'BC: could not reset: {e}')
|
|
33
|
+
await detector_node.soft_reload()
|
|
34
|
+
except Exception:
|
|
35
|
+
logging.exception('BC: could not reset:')
|
|
55
36
|
return False
|
|
56
37
|
return True
|
|
@@ -1,45 +1,53 @@
|
|
|
1
1
|
import logging
|
|
2
|
-
from typing import Optional
|
|
2
|
+
from typing import TYPE_CHECKING, Optional
|
|
3
3
|
|
|
4
4
|
import numpy as np
|
|
5
5
|
from fastapi import APIRouter, File, Header, Request, UploadFile
|
|
6
6
|
from fastapi.responses import JSONResponse
|
|
7
7
|
|
|
8
|
+
from ...data_classes.detections import Detections
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
from ..detector_node import DetectorNode
|
|
12
|
+
|
|
8
13
|
router = APIRouter()
|
|
9
14
|
|
|
10
15
|
|
|
11
|
-
@router.post("/detect")
|
|
16
|
+
@router.post("/detect", response_model=Detections)
|
|
12
17
|
async def http_detect(
|
|
13
18
|
request: Request,
|
|
14
|
-
file: UploadFile = File(
|
|
15
|
-
camera_id: Optional[str] = Header(None),
|
|
16
|
-
mac: Optional[str] = Header(None),
|
|
17
|
-
tags: Optional[str] = Header(None),
|
|
18
|
-
|
|
19
|
+
file: UploadFile = File(..., description='The image file to run detection on'),
|
|
20
|
+
camera_id: Optional[str] = Header(None, description='The camera id (used by learning loop)'),
|
|
21
|
+
mac: Optional[str] = Header(None, description='The camera mac address (used by learning loop)'),
|
|
22
|
+
tags: Optional[str] = Header(None, description='Tags to add to the image (used by learning loop)'),
|
|
23
|
+
source: Optional[str] = Header(None, description='The source of the image (used by learning loop)'),
|
|
24
|
+
autoupload: Optional[str] = Header(None, description='Mode to decide whether to upload the image to the learning loop',
|
|
25
|
+
examples=['filtered', 'all', 'disabled']),
|
|
19
26
|
):
|
|
20
27
|
"""
|
|
21
|
-
|
|
28
|
+
Single image example:
|
|
29
|
+
|
|
30
|
+
curl --request POST -F 'file=@test.jpg' localhost:8004/detect -H 'autoupload: all' -H 'camera-id: front_cam' -H 'source: test' -H 'tags: test,test2'
|
|
22
31
|
|
|
23
|
-
|
|
32
|
+
Multiple images example:
|
|
24
33
|
|
|
25
34
|
for i in `seq 1 10`; do time curl --request POST -F 'file=@test.jpg' localhost:8004/detect; done
|
|
26
35
|
|
|
27
|
-
You can additionally provide the following camera parameters:
|
|
28
|
-
- `autoupload`: configures auto-submission to the learning loop; `filtered` (default), `all`, `disabled` (example curl parameter `-H 'autoupload: all'`)
|
|
29
|
-
- `camera-id`: a string which groups images for submission together (example curl parameter `-H 'camera-id: front_cam'`)
|
|
30
36
|
"""
|
|
31
37
|
try:
|
|
32
38
|
np_image = np.fromfile(file.file, np.uint8)
|
|
33
39
|
except Exception as exc:
|
|
34
|
-
logging.exception(
|
|
40
|
+
logging.exception('Error during reading of image %s.', file.filename)
|
|
35
41
|
raise Exception(f'Uploaded file {file.filename} is no image file.') from exc
|
|
36
42
|
|
|
37
43
|
try:
|
|
38
|
-
|
|
39
|
-
|
|
40
|
-
|
|
41
|
-
|
|
44
|
+
app: 'DetectorNode' = request.app
|
|
45
|
+
detections = await app.get_detections(raw_image=np_image,
|
|
46
|
+
camera_id=camera_id or mac or None,
|
|
47
|
+
tags=tags.split(',') if tags else [],
|
|
48
|
+
source=source,
|
|
49
|
+
autoupload=autoupload)
|
|
42
50
|
except Exception as exc:
|
|
43
|
-
logging.exception(
|
|
51
|
+
logging.exception('Error during detection of image %s.', file.filename)
|
|
44
52
|
raise Exception(f'Error during detection of image {file.filename}.') from exc
|
|
45
|
-
return
|
|
53
|
+
return detections
|
|
@@ -1,7 +1,9 @@
|
|
|
1
1
|
|
|
2
2
|
import os
|
|
3
|
+
import sys
|
|
4
|
+
from dataclasses import dataclass, field
|
|
3
5
|
from enum import Enum
|
|
4
|
-
from typing import TYPE_CHECKING
|
|
6
|
+
from typing import TYPE_CHECKING, List
|
|
5
7
|
|
|
6
8
|
from fastapi import APIRouter, HTTPException, Request
|
|
7
9
|
|
|
@@ -10,6 +12,7 @@ from ...globals import GLOBALS
|
|
|
10
12
|
|
|
11
13
|
if TYPE_CHECKING:
|
|
12
14
|
from ..detector_node import DetectorNode
|
|
15
|
+
KWONLY_SLOTS = {'kw_only': True, 'slots': True} if sys.version_info >= (3, 10) else {}
|
|
13
16
|
|
|
14
17
|
router = APIRouter()
|
|
15
18
|
|
|
@@ -20,9 +23,20 @@ class VersionMode(str, Enum):
|
|
|
20
23
|
Pause = 'pause' # will pause the updates
|
|
21
24
|
|
|
22
25
|
|
|
26
|
+
@dataclass(**KWONLY_SLOTS)
|
|
27
|
+
class ModelVersionResponse:
|
|
28
|
+
current_version: str = field(metadata={"description": "The version of the model currently used by the detector."})
|
|
29
|
+
target_version: str = field(metadata={"description": "The target model version set in the detector."})
|
|
30
|
+
loop_version: str = field(metadata={"description": "The target model version specified by the loop."})
|
|
31
|
+
local_versions: List[str] = field(metadata={"description": "The locally available versions of the model."})
|
|
32
|
+
version_control: str = field(metadata={"description": "The version control mode."})
|
|
33
|
+
|
|
34
|
+
|
|
23
35
|
@router.get("/model_version")
|
|
24
36
|
async def get_version(request: Request):
|
|
25
37
|
'''
|
|
38
|
+
Get information about the model version control and the current model version.
|
|
39
|
+
|
|
26
40
|
Example Usage
|
|
27
41
|
curl http://localhost/model_version
|
|
28
42
|
'''
|
|
@@ -35,28 +49,31 @@ async def get_version(request: Request):
|
|
|
35
49
|
loop_version = app.loop_deployment_target.version if app.loop_deployment_target is not None else 'None'
|
|
36
50
|
|
|
37
51
|
local_versions: list[str] = []
|
|
38
|
-
|
|
39
|
-
local_models = os.listdir(os.path.
|
|
52
|
+
models_path = os.path.join(GLOBALS.data_folder, 'models')
|
|
53
|
+
local_models = os.listdir(models_path) if os.path.exists(models_path) else []
|
|
40
54
|
for model in local_models:
|
|
41
55
|
if model.replace('.', '').isdigit():
|
|
42
56
|
local_versions.append(model)
|
|
43
57
|
|
|
44
|
-
|
|
45
|
-
|
|
46
|
-
|
|
47
|
-
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
|
|
58
|
+
response = ModelVersionResponse(
|
|
59
|
+
current_version=current_version,
|
|
60
|
+
target_version=target_version,
|
|
61
|
+
loop_version=loop_version,
|
|
62
|
+
local_versions=local_versions,
|
|
63
|
+
version_control=app.version_control.value,
|
|
64
|
+
)
|
|
65
|
+
return response
|
|
51
66
|
|
|
52
67
|
|
|
53
68
|
@router.put("/model_version")
|
|
54
69
|
async def put_version(request: Request):
|
|
55
70
|
'''
|
|
71
|
+
Set the model version control mode.
|
|
72
|
+
|
|
56
73
|
Example Usage
|
|
57
|
-
curl -X PUT -d "follow_loop" http://
|
|
58
|
-
curl -X PUT -d "pause" http://
|
|
59
|
-
curl -X PUT -d "13.6" http://
|
|
74
|
+
curl -X PUT -d "follow_loop" http://hosturl/model_version
|
|
75
|
+
curl -X PUT -d "pause" http://hosturl/model_version
|
|
76
|
+
curl -X PUT -d "13.6" http://hosturl/model_version
|
|
60
77
|
'''
|
|
61
78
|
app: 'DetectorNode' = request.app
|
|
62
79
|
content = str(await request.body(), 'utf-8')
|
|
@@ -22,7 +22,10 @@ class OperationMode(str, Enum):
|
|
|
22
22
|
@router.put("/operation_mode")
|
|
23
23
|
async def put_operation_mode(request: Request):
|
|
24
24
|
'''
|
|
25
|
+
Set the operation mode of the detector node.
|
|
26
|
+
|
|
25
27
|
Example Usage
|
|
28
|
+
|
|
26
29
|
curl -X PUT -d "check_for_updates" http://localhost/operation_mode
|
|
27
30
|
curl -X PUT -d "detecting" http://localhost/operation_mode
|
|
28
31
|
'''
|
|
@@ -34,22 +37,25 @@ async def put_operation_mode(request: Request):
|
|
|
34
37
|
raise HTTPException(422, str(exc)) from exc
|
|
35
38
|
node: DetectorNode = request.app
|
|
36
39
|
|
|
37
|
-
logging.info(
|
|
38
|
-
logging.info(
|
|
39
|
-
logging.info(
|
|
40
|
+
logging.info('current node state : %s', node.status.state)
|
|
41
|
+
logging.info('current operation mode : %s', node.operation_mode.value)
|
|
42
|
+
logging.info('target operation mode : %s', target_mode)
|
|
40
43
|
if target_mode == node.operation_mode:
|
|
41
44
|
logging.info('operation mode already set')
|
|
42
45
|
return "OK"
|
|
43
46
|
|
|
44
47
|
await node.set_operation_mode(target_mode)
|
|
45
|
-
logging.info(
|
|
48
|
+
logging.info('operation mode set to : %s', target_mode)
|
|
46
49
|
return "OK"
|
|
47
50
|
|
|
48
51
|
|
|
49
|
-
@router.get("/operation_mode")
|
|
52
|
+
@router.get("/operation_mode", response_class=PlainTextResponse)
|
|
50
53
|
async def get_operation_mode(request: Request):
|
|
51
54
|
'''
|
|
55
|
+
Get the operation mode of the detector node.
|
|
56
|
+
|
|
52
57
|
Example Usage
|
|
58
|
+
|
|
53
59
|
curl http://localhost/operation_mode
|
|
54
60
|
'''
|
|
55
61
|
return PlainTextResponse(request.app.operation_mode.value)
|
|
@@ -6,10 +6,13 @@ from ..outbox import Outbox
|
|
|
6
6
|
router = APIRouter()
|
|
7
7
|
|
|
8
8
|
|
|
9
|
-
@router.get("/outbox_mode")
|
|
9
|
+
@router.get("/outbox_mode", response_class=PlainTextResponse)
|
|
10
10
|
async def get_outbox_mode(request: Request):
|
|
11
11
|
'''
|
|
12
|
+
Get the outbox mode of the detector node.
|
|
13
|
+
|
|
12
14
|
Example Usage
|
|
15
|
+
|
|
13
16
|
curl http://localhost/outbox_mode
|
|
14
17
|
'''
|
|
15
18
|
outbox: Outbox = request.app.outbox
|
|
@@ -19,7 +22,10 @@ async def get_outbox_mode(request: Request):
|
|
|
19
22
|
@router.put("/outbox_mode")
|
|
20
23
|
async def put_outbox_mode(request: Request):
|
|
21
24
|
'''
|
|
25
|
+
Set the outbox mode of the detector node.
|
|
26
|
+
|
|
22
27
|
Example Usage
|
|
28
|
+
|
|
23
29
|
curl -X PUT -d "continuous_upload" http://localhost/outbox_mode
|
|
24
30
|
curl -X PUT -d "stopped" http://localhost/outbox_mode
|
|
25
31
|
'''
|
learning_loop_node/node.py
CHANGED
|
@@ -8,7 +8,6 @@ from datetime import datetime
|
|
|
8
8
|
from typing import Any, Optional
|
|
9
9
|
|
|
10
10
|
import aiohttp
|
|
11
|
-
import socketio
|
|
12
11
|
from aiohttp import TCPConnector
|
|
13
12
|
from fastapi import FastAPI
|
|
14
13
|
from socketio import AsyncClient
|
|
@@ -18,6 +17,11 @@ from .data_exchanger import DataExchanger
|
|
|
18
17
|
from .helpers import log_conf
|
|
19
18
|
from .helpers.misc import ensure_socket_response, read_or_create_uuid
|
|
20
19
|
from .loop_communication import LoopCommunicator
|
|
20
|
+
from .rest import router
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class NodeConnectionError(Exception):
|
|
24
|
+
pass
|
|
21
25
|
|
|
22
26
|
|
|
23
27
|
class Node(FastAPI):
|
|
@@ -41,9 +45,8 @@ class Node(FastAPI):
|
|
|
41
45
|
self.uuid = uuid or read_or_create_uuid(self.name)
|
|
42
46
|
self.needs_login = needs_login
|
|
43
47
|
|
|
44
|
-
self.log = logging.getLogger()
|
|
45
|
-
self.
|
|
46
|
-
self.websocket_url = self.loop_communicator.websocket_url()
|
|
48
|
+
self.log = logging.getLogger('Node')
|
|
49
|
+
self.init_loop_communicator()
|
|
47
50
|
self.data_exchanger = DataExchanger(None, self.loop_communicator)
|
|
48
51
|
|
|
49
52
|
self.startup_datetime = datetime.now()
|
|
@@ -55,6 +58,19 @@ class Node(FastAPI):
|
|
|
55
58
|
'nodeType': node_type}
|
|
56
59
|
|
|
57
60
|
self.repeat_task: Any = None
|
|
61
|
+
self.socket_connection_broken = False
|
|
62
|
+
self._skip_repeat_loop = False
|
|
63
|
+
|
|
64
|
+
self.include_router(router)
|
|
65
|
+
|
|
66
|
+
self.CONNECTED_TO_LOOP = asyncio.Event()
|
|
67
|
+
self.DISCONNECTED_FROM_LOOP = asyncio.Event()
|
|
68
|
+
|
|
69
|
+
self.repeat_loop_lock = asyncio.Lock()
|
|
70
|
+
|
|
71
|
+
def init_loop_communicator(self):
|
|
72
|
+
self.loop_communicator = LoopCommunicator()
|
|
73
|
+
self.websocket_url = self.loop_communicator.websocket_url()
|
|
58
74
|
|
|
59
75
|
@property
|
|
60
76
|
def sio_client(self) -> AsyncClient:
|
|
@@ -66,7 +82,10 @@ class Node(FastAPI):
|
|
|
66
82
|
@asynccontextmanager
|
|
67
83
|
async def lifespan(self, app: FastAPI): # pylint: disable=unused-argument
|
|
68
84
|
try:
|
|
69
|
-
|
|
85
|
+
try:
|
|
86
|
+
await self._on_startup()
|
|
87
|
+
except Exception:
|
|
88
|
+
self.log.exception('Fatal error during startup: %s')
|
|
70
89
|
self.repeat_task = asyncio.create_task(self.repeat_loop())
|
|
71
90
|
yield
|
|
72
91
|
finally:
|
|
@@ -80,13 +99,10 @@ class Node(FastAPI):
|
|
|
80
99
|
|
|
81
100
|
async def _on_startup(self):
|
|
82
101
|
self.log.info('received "startup" lifecycle-event')
|
|
83
|
-
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
self.log.
|
|
87
|
-
await self.loop_communicator.ensure_login()
|
|
88
|
-
self.log.info('create sio client')
|
|
89
|
-
await self.create_sio_client()
|
|
102
|
+
try:
|
|
103
|
+
await self.reconnect_to_loop()
|
|
104
|
+
except Exception:
|
|
105
|
+
self.log.warning('Could not establish sio connection to loop during startup')
|
|
90
106
|
self.log.info('done')
|
|
91
107
|
await self.on_startup()
|
|
92
108
|
|
|
@@ -99,55 +115,93 @@ class Node(FastAPI):
|
|
|
99
115
|
await self.on_shutdown()
|
|
100
116
|
|
|
101
117
|
async def repeat_loop(self) -> None:
|
|
102
|
-
"""NOTE: with the lifespan approach, we cannot use @repeat_every anymore :("""
|
|
103
118
|
while True:
|
|
119
|
+
if self._skip_repeat_loop:
|
|
120
|
+
self.log.debug('node is muted, skipping repeat loop')
|
|
121
|
+
await asyncio.sleep(1)
|
|
122
|
+
continue
|
|
104
123
|
try:
|
|
105
|
-
|
|
124
|
+
async with self.repeat_loop_lock:
|
|
125
|
+
await self._ensure_sio_connection()
|
|
126
|
+
await self.on_repeat()
|
|
106
127
|
except asyncio.CancelledError:
|
|
107
128
|
return
|
|
108
|
-
except Exception
|
|
109
|
-
self.log.exception(
|
|
129
|
+
except Exception:
|
|
130
|
+
self.log.exception('error in repeat loop')
|
|
131
|
+
|
|
110
132
|
await asyncio.sleep(5)
|
|
111
133
|
|
|
112
|
-
async def
|
|
113
|
-
if not self.sio_client.connected:
|
|
114
|
-
self.log.info('Reconnecting to loop via sio'
|
|
115
|
-
|
|
116
|
-
|
|
117
|
-
|
|
118
|
-
|
|
119
|
-
|
|
134
|
+
async def _ensure_sio_connection(self):
|
|
135
|
+
if self.socket_connection_broken or self._sio_client is None or not self.sio_client.connected:
|
|
136
|
+
self.log.info('Reconnecting to loop via sio due to %s',
|
|
137
|
+
'broken connection' if self.socket_connection_broken else 'no connection')
|
|
138
|
+
await self.reconnect_to_loop()
|
|
139
|
+
|
|
140
|
+
async def reconnect_to_loop(self):
|
|
141
|
+
"""Initialize the loop communicator, log in if needed and reconnect to the loop via socket.io."""
|
|
142
|
+
self.init_loop_communicator()
|
|
143
|
+
if self.needs_login:
|
|
144
|
+
await self.loop_communicator.ensure_login(relogin=True)
|
|
145
|
+
try:
|
|
146
|
+
await self._reconnect_socketio()
|
|
147
|
+
except NodeConnectionError:
|
|
148
|
+
self.log.exception('Could not reset sio connection to loop')
|
|
149
|
+
self.socket_connection_broken = True
|
|
150
|
+
raise
|
|
151
|
+
|
|
152
|
+
self.socket_connection_broken = False
|
|
153
|
+
|
|
154
|
+
def set_skip_repeat_loop(self, value: bool):
|
|
155
|
+
self._skip_repeat_loop = value
|
|
156
|
+
self.log.info('node is muted: %s', value)
|
|
120
157
|
|
|
121
158
|
# --------------------------------------------------- SOCKET.IO ---------------------------------------------------
|
|
122
159
|
|
|
123
|
-
async def
|
|
124
|
-
"""Create a socket.io client
|
|
125
|
-
|
|
160
|
+
async def _reconnect_socketio(self):
|
|
161
|
+
"""Create a socket.io client, connect it to the learning loop and register its events.
|
|
162
|
+
The current client is disconnected and deleted if it already exists."""
|
|
163
|
+
|
|
164
|
+
self.log.debug('-------------- Connecting to loop via socket.io -------------------')
|
|
165
|
+
self.log.debug('HTTP Cookies: %s\n', self.loop_communicator.get_cookies())
|
|
166
|
+
|
|
167
|
+
if self._sio_client is not None:
|
|
168
|
+
try:
|
|
169
|
+
await self.sio_client.disconnect()
|
|
170
|
+
self.log.info('disconnected from loop via sio')
|
|
171
|
+
# NOTE: without waiting for the disconnect event, we might disconnect the next connection too early
|
|
172
|
+
await asyncio.wait_for(self.DISCONNECTED_FROM_LOOP.wait(), timeout=5)
|
|
173
|
+
except asyncio.TimeoutError:
|
|
174
|
+
self.log.warning(
|
|
175
|
+
'Did not receive disconnect event from loop within 5 seconds.\nContinuing with new connection...')
|
|
176
|
+
except Exception as e:
|
|
177
|
+
self.log.warning('Could not disconnect from loop via sio: %s.\nIgnoring...', e)
|
|
178
|
+
self._sio_client = None
|
|
126
179
|
|
|
180
|
+
connector = None
|
|
127
181
|
if self.loop_communicator.ssl_cert_path:
|
|
128
|
-
logging.info(
|
|
182
|
+
logging.info('SIO using SSL certificate path: %s', self.loop_communicator.ssl_cert_path)
|
|
129
183
|
ssl_context = ssl.create_default_context(cafile=self.loop_communicator.ssl_cert_path)
|
|
130
184
|
ssl_context.check_hostname = False
|
|
131
185
|
ssl_context.verify_mode = ssl.CERT_REQUIRED
|
|
132
186
|
connector = TCPConnector(ssl=ssl_context)
|
|
133
|
-
self._sio_client = AsyncClient(request_timeout=20,
|
|
134
|
-
http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies(),
|
|
135
|
-
connector=connector))
|
|
136
187
|
|
|
137
|
-
|
|
138
|
-
self.
|
|
139
|
-
http_session=aiohttp.ClientSession(cookies=self.loop_communicator.get_cookies()))
|
|
188
|
+
self._sio_client = AsyncClient(request_timeout=20, http_session=aiohttp.ClientSession(
|
|
189
|
+
cookies=self.loop_communicator.get_cookies(), connector=connector))
|
|
140
190
|
|
|
141
191
|
# pylint: disable=protected-access
|
|
142
|
-
self.
|
|
192
|
+
self._sio_client._trigger_event = ensure_socket_response(self._sio_client._trigger_event)
|
|
143
193
|
|
|
144
194
|
@self._sio_client.event
|
|
145
195
|
async def connect():
|
|
146
196
|
self.log.info('received "connect" via sio from loop.')
|
|
197
|
+
self.CONNECTED_TO_LOOP.set()
|
|
198
|
+
self.DISCONNECTED_FROM_LOOP.clear()
|
|
147
199
|
|
|
148
200
|
@self._sio_client.event
|
|
149
201
|
async def disconnect():
|
|
150
202
|
self.log.info('received "disconnect" via sio from loop.')
|
|
203
|
+
self.DISCONNECTED_FROM_LOOP.set()
|
|
204
|
+
self.CONNECTED_TO_LOOP.clear()
|
|
151
205
|
|
|
152
206
|
@self._sio_client.event
|
|
153
207
|
async def restart():
|
|
@@ -155,21 +209,15 @@ class Node(FastAPI):
|
|
|
155
209
|
sys.exit(0)
|
|
156
210
|
|
|
157
211
|
self.register_sio_events(self._sio_client)
|
|
158
|
-
|
|
159
|
-
async def connect_sio(self):
|
|
160
212
|
try:
|
|
161
|
-
await self.
|
|
162
|
-
except Exception:
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
|
|
167
|
-
|
|
168
|
-
|
|
169
|
-
except socketio.exceptions.ConnectionError: # type: ignore
|
|
170
|
-
self.log.warning('connection error')
|
|
171
|
-
except Exception:
|
|
172
|
-
self.log.exception(f'error while connecting to "{self.websocket_url}". Exception:')
|
|
213
|
+
await self._sio_client.connect(f"{self.websocket_url}", headers=self.sio_headers, socketio_path="/ws/socket.io")
|
|
214
|
+
except Exception as e:
|
|
215
|
+
self.log.exception('Could not connect socketio client to loop')
|
|
216
|
+
raise NodeConnectionError('Could not connect socketio client to loop') from e
|
|
217
|
+
|
|
218
|
+
if not self._sio_client.connected:
|
|
219
|
+
self.log.exception('Could not connect socketio client to loop')
|
|
220
|
+
raise NodeConnectionError('Could not connect socketio client to loop')
|
|
173
221
|
|
|
174
222
|
# --------------------------------------------------- ABSTRACT METHODS ---------------------------------------------------
|
|
175
223
|
|
|
@@ -0,0 +1,55 @@
|
|
|
1
|
+
import logging
|
|
2
|
+
from typing import TYPE_CHECKING
|
|
3
|
+
|
|
4
|
+
from fastapi import APIRouter, HTTPException, Request
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from .node import Node
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
router = APIRouter()
|
|
11
|
+
logger = logging.getLogger('Node.rest')
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@router.put("/debug_logging")
|
|
15
|
+
async def _debug_logging(request: Request) -> str:
|
|
16
|
+
'''
|
|
17
|
+
Example Usage
|
|
18
|
+
|
|
19
|
+
curl -X PUT -d "on" http://localhost:8007/debug_logging
|
|
20
|
+
'''
|
|
21
|
+
state = str(await request.body(), 'utf-8')
|
|
22
|
+
node: 'Node' = request.app
|
|
23
|
+
|
|
24
|
+
if state == 'off':
|
|
25
|
+
logger.info('turning debug logging off')
|
|
26
|
+
node.log.setLevel('INFO')
|
|
27
|
+
return 'off'
|
|
28
|
+
if state == 'on':
|
|
29
|
+
logger.info('turning debug logging on')
|
|
30
|
+
node.log.setLevel('DEBUG')
|
|
31
|
+
return 'on'
|
|
32
|
+
raise HTTPException(status_code=400, detail='Invalid state')
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
@router.put("/socketio")
|
|
36
|
+
async def _socketio(request: Request) -> str:
|
|
37
|
+
'''
|
|
38
|
+
Enable or disable the socketio connection to the learning loop.
|
|
39
|
+
Not intended to be used outside of testing.
|
|
40
|
+
|
|
41
|
+
Example Usage
|
|
42
|
+
|
|
43
|
+
curl -X PUT -d "on" http://hosturl/socketio
|
|
44
|
+
'''
|
|
45
|
+
state = str(await request.body(), 'utf-8')
|
|
46
|
+
node: 'Node' = request.app
|
|
47
|
+
|
|
48
|
+
if state == 'off':
|
|
49
|
+
await node.sio_client.disconnect()
|
|
50
|
+
node.set_skip_repeat_loop(True) # Prevent auto-reconnection
|
|
51
|
+
return 'off'
|
|
52
|
+
if state == 'on':
|
|
53
|
+
node.set_skip_repeat_loop(False) # Allow auto-reconnection (1 sec delay)
|
|
54
|
+
return 'on'
|
|
55
|
+
raise HTTPException(status_code=400, detail='Invalid state')
|