learning-loop-node 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of learning-loop-node might be problematic. Click here for more details.

Files changed (54) hide show
  1. learning_loop_node/__init__.py +2 -3
  2. learning_loop_node/annotation/annotator_logic.py +2 -2
  3. learning_loop_node/annotation/annotator_node.py +16 -15
  4. learning_loop_node/data_classes/__init__.py +17 -10
  5. learning_loop_node/data_classes/detections.py +7 -2
  6. learning_loop_node/data_classes/general.py +4 -5
  7. learning_loop_node/data_classes/training.py +49 -21
  8. learning_loop_node/data_exchanger.py +85 -139
  9. learning_loop_node/detector/__init__.py +0 -1
  10. learning_loop_node/detector/detector_node.py +10 -13
  11. learning_loop_node/detector/inbox_filter/cam_observation_history.py +4 -7
  12. learning_loop_node/detector/outbox.py +0 -1
  13. learning_loop_node/detector/rest/about.py +1 -0
  14. learning_loop_node/detector/tests/conftest.py +0 -1
  15. learning_loop_node/detector/tests/test_client_communication.py +5 -3
  16. learning_loop_node/detector/tests/test_outbox.py +2 -0
  17. learning_loop_node/detector/tests/testing_detector.py +1 -8
  18. learning_loop_node/globals.py +2 -2
  19. learning_loop_node/helpers/gdrive_downloader.py +1 -1
  20. learning_loop_node/helpers/misc.py +124 -17
  21. learning_loop_node/loop_communication.py +57 -25
  22. learning_loop_node/node.py +62 -135
  23. learning_loop_node/tests/test_downloader.py +8 -7
  24. learning_loop_node/tests/test_executor.py +14 -11
  25. learning_loop_node/tests/test_helper.py +3 -5
  26. learning_loop_node/trainer/downloader.py +1 -1
  27. learning_loop_node/trainer/executor.py +87 -83
  28. learning_loop_node/trainer/io_helpers.py +66 -9
  29. learning_loop_node/trainer/rest/backdoor_controls.py +10 -5
  30. learning_loop_node/trainer/rest/controls.py +3 -1
  31. learning_loop_node/trainer/tests/conftest.py +19 -28
  32. learning_loop_node/trainer/tests/states/test_state_cleanup.py +5 -3
  33. learning_loop_node/trainer/tests/states/test_state_detecting.py +23 -20
  34. learning_loop_node/trainer/tests/states/test_state_download_train_model.py +18 -12
  35. learning_loop_node/trainer/tests/states/test_state_prepare.py +13 -12
  36. learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py +21 -18
  37. learning_loop_node/trainer/tests/states/test_state_train.py +27 -28
  38. learning_loop_node/trainer/tests/states/test_state_upload_detections.py +34 -32
  39. learning_loop_node/trainer/tests/states/test_state_upload_model.py +22 -20
  40. learning_loop_node/trainer/tests/test_errors.py +20 -12
  41. learning_loop_node/trainer/tests/test_trainer_states.py +4 -5
  42. learning_loop_node/trainer/tests/testing_trainer_logic.py +25 -30
  43. learning_loop_node/trainer/trainer_logic.py +80 -590
  44. learning_loop_node/trainer/trainer_logic_generic.py +495 -0
  45. learning_loop_node/trainer/trainer_node.py +27 -106
  46. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/METADATA +1 -1
  47. learning_loop_node-0.10.0.dist-info/RECORD +85 -0
  48. learning_loop_node/converter/converter_logic.py +0 -68
  49. learning_loop_node/converter/converter_node.py +0 -125
  50. learning_loop_node/converter/tests/test_converter.py +0 -55
  51. learning_loop_node/trainer/training_syncronizer.py +0 -52
  52. learning_loop_node-0.9.3.dist-info/RECORD +0 -88
  53. /learning_loop_node/{converter/__init__.py → py.typed} +0 -0
  54. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/WHEEL +0 -0
@@ -21,26 +21,28 @@ def cleanup():
21
21
  cleanup_process.communicate()
22
22
 
23
23
 
24
- def test_executor_lifecycle():
24
+ @pytest.mark.asyncio
25
+ async def test_executor_lifecycle():
25
26
  assert_process_is_running('some_executable.sh', False)
26
27
 
27
- executor = Executor('/tmp/test_executor/' + str(uuid4()))
28
- cmd = executor.path + '/some_executable.sh'
29
- with open(cmd, 'w') as f:
30
- f.write('while true; do echo "some output"; sleep 1; done')
31
- os.chmod(cmd, 0o755)
28
+ executor = Executor('/tmp/test_executor/' + str(uuid4())+'/')
29
+ cmd = 'bash some_executable.sh'
30
+ executable_path = executor.path+'some_executable.sh'
31
+ with open(executable_path, 'w') as f:
32
+ f.write('/bin/bash -c "while true; do sleep 1; echo some output; done"')
33
+ os.chmod(executable_path, 0o755)
32
34
 
33
- executor.start(cmd)
35
+ await executor.start(cmd)
34
36
 
35
- assert executor.is_process_running()
37
+ assert executor.is_running()
36
38
  assert_process_is_running('some_executable.sh')
37
39
 
38
- sleep(1)
40
+ sleep(5)
39
41
  assert 'some output' in executor.get_log()
40
42
 
41
- executor.stop()
43
+ await executor.stop_and_wait()
42
44
 
43
- assert not executor.is_process_running()
45
+ assert not executor.is_running()
44
46
  sleep(1)
45
47
  assert_process_is_running('some_executable.sh', False)
46
48
 
@@ -48,6 +50,7 @@ def test_executor_lifecycle():
48
50
  def assert_process_is_running(process_name, running=True):
49
51
  if running:
50
52
  for process in psutil.process_iter():
53
+ print(process.name(), process.cmdline())
51
54
  process_name_match = process_name in process.name()
52
55
  process_cmd_match = process_name in str(process.cmdline())
53
56
  if process_name_match or process_cmd_match:
@@ -7,10 +7,8 @@ from glob import glob
7
7
  from typing import Callable
8
8
 
9
9
  from learning_loop_node.data_classes import Context
10
- from learning_loop_node.helpers.misc import create_image_folder
10
+ from learning_loop_node.helpers.misc import create_image_folder, create_project_folder, create_training_folder
11
11
  from learning_loop_node.loop_communication import LoopCommunicator
12
- from learning_loop_node.node import Node
13
- from learning_loop_node.trainer.trainer_logic import TrainerLogic
14
12
 
15
13
 
16
14
  def get_files_in_folder(folder: str):
@@ -65,8 +63,8 @@ def _update_attribute_dict(obj: dict, **kwargs) -> None:
65
63
 
66
64
 
67
65
  def create_needed_folders(training_uuid: str = 'some_uuid'): # pylint: disable=unused-argument
68
- project_folder = Node.create_project_folder(
66
+ project_folder = create_project_folder(
69
67
  Context(organization='zauberzeug', project='pytest'))
70
68
  image_folder = create_image_folder(project_folder)
71
- training_folder = TrainerLogic.create_training_folder(project_folder, training_uuid)
69
+ training_folder = create_training_folder(project_folder, training_uuid)
72
70
  return project_folder, image_folder, training_folder
@@ -12,7 +12,7 @@ class TrainingsDownloader():
12
12
  self.data_exchanger = data_exchanger
13
13
 
14
14
  async def download_training_data(self, image_folder: str) -> Tuple[List[Dict], int]:
15
- image_ids = await self.data_exchanger.fetch_image_ids(query_params=self.data_query_params)
15
+ image_ids = await self.data_exchanger.fetch_image_uuids(query_params=self.data_query_params)
16
16
  image_data, skipped_image_count = await self.download_images_and_annotations(image_ids, image_folder)
17
17
  return (image_data, skipped_image_count)
18
18
 
@@ -1,105 +1,109 @@
1
-
2
- import ctypes
1
+ import asyncio
3
2
  import logging
4
3
  import os
5
- import signal
6
- import subprocess
7
- from sys import platform
4
+ import shlex
5
+ from io import BufferedWriter
8
6
  from typing import List, Optional
9
7
 
10
- import psutil
11
8
 
9
+ class Executor:
10
+ def __init__(self, base_path: str, log_name='last_training.log') -> None:
11
+ """An executor that runs a command in a separate async subprocess.
12
+ The log of the process is written to the file `log_name` (default 'last_training.log') in the base_path.
13
+ The process is executed in the base_path directory.
14
+ The process should be awaited to finish using `wait` or stopped using `stop_and_wait` to
15
+ avoid zombie processes and close the log file."""
12
16
 
13
- def create_signal_handler(sig=signal.SIGTERM):
14
- if platform == "linux" or platform == "linux2":
15
- # "The system will send a signal to the child once the parent exits for any reason (even sigkill)."
16
- # https://stackoverflow.com/a/19448096
17
- libc = ctypes.CDLL("libc.so.6")
17
+ self.path = base_path
18
+ self.log_file_path = f'{self.path}/{log_name}'
19
+ self.log_file: None | BufferedWriter = None
20
+ self._process: Optional[asyncio.subprocess.Process] = None # pylint: disable=no-member
21
+ os.makedirs(self.path, exist_ok=True)
18
22
 
19
- def callable_():
20
- os.setsid()
21
- return libc.prctl(1, sig)
23
+ def _get_running_process(self) -> Optional[asyncio.subprocess.Process]: # pylint: disable=no-member
24
+ """Get the running process if available."""
25
+ if self._process is not None and self._process.returncode is None:
26
+ return self._process
27
+ return None
22
28
 
23
- return callable_
24
- return os.setsid
29
+ async def start(self, cmd: str, env: Optional[dict[str, str]] = None) -> None:
30
+ """Start the process with the given command and environment variables."""
25
31
 
32
+ full_env = os.environ.copy()
33
+ if env is not None:
34
+ full_env.update(env)
26
35
 
27
- class Executor:
28
- def __init__(self, base_path: str) -> None:
29
- self.path = base_path
30
- os.makedirs(self.path, exist_ok=True)
31
- self.process: Optional[subprocess.Popen[bytes]] = None
32
-
33
- def start(self, cmd: str):
34
- with open(f'{self.path}/last_training.log', 'a') as f:
35
- f.write(f'\nStarting executor with command: {cmd}\n')
36
- # pylint: disable=subprocess-popen-preexec-fn
37
- self.process = subprocess.Popen(
38
- f'cd {self.path}; {cmd} >> last_training.log 2>&1',
39
- shell=True,
40
- stdout=subprocess.PIPE,
41
- stderr=subprocess.PIPE,
42
- executable='/bin/bash',
43
- preexec_fn=create_signal_handler(),
44
- )
36
+ logging.info(f'Starting executor with command: {cmd} in {self.path} - logging to {self.log_file_path}')
37
+ self.log_file = open(self.log_file_path, 'ab')
45
38
 
46
- def is_process_running(self):
47
- if self.process is None:
48
- return False
39
+ self._process = await asyncio.create_subprocess_exec(
40
+ *shlex.split(cmd),
41
+ cwd=self.path,
42
+ stdout=self.log_file,
43
+ stderr=asyncio.subprocess.STDOUT, # Merge stderr with stdout
44
+ env=full_env
45
+ )
49
46
 
50
- if self.process.poll() is not None:
51
- return False
47
+ def is_running(self) -> bool:
48
+ """Check if the process is still running."""
49
+ return self._process is not None and self._process.returncode is None
52
50
 
53
- try:
54
- psutil.Process(self.process.pid)
55
- except psutil.NoSuchProcess:
56
- # self.process.terminate() # TODO does this make sense?
57
- # self.process = None
58
- return False
51
+ def terminate(self) -> None:
52
+ """Terminate the process."""
59
53
 
60
- return True
54
+ if process := self._get_running_process():
55
+ try:
56
+ process.terminate()
57
+ return
58
+ except ProcessLookupError:
59
+ logging.error('No process to terminate')
60
+ self._process = None
61
61
 
62
- def get_log(self) -> str:
63
- try:
64
- with open(f'{self.path}/last_training.log') as f:
65
- return f.read()
66
- except Exception:
67
- return ''
62
+ async def wait(self) -> Optional[int]:
63
+ """Wait for the process to finish. Returns the return code of the process or None if no process is running."""
68
64
 
69
- def get_log_by_lines(self, since_last_start=False) -> List[str]: # TODO do not read whole log again
70
- try:
71
- with open(f'{self.path}/last_training.log') as f:
72
- lines = f.readlines()
73
- if since_last_start:
74
- lines_since_last_start = []
75
- for line in reversed(lines):
76
- lines_since_last_start.append(line)
77
- if line.startswith('Starting executor with command:'):
78
- break
79
- return list(reversed(lines_since_last_start))
80
- return lines
81
- except Exception:
82
- return []
65
+ if not self._process:
66
+ logging.info('No process to wait for')
67
+ return None
83
68
 
84
- def stop(self):
85
- if self.process is None:
86
- logging.info('no process running ... nothing to stop')
87
- return
69
+ return_code = await self._process.wait()
88
70
 
89
- logging.info('terminating process')
71
+ self.close_log()
72
+ self._process = None
90
73
 
91
- try:
92
- os.killpg(os.getpgid(self.process.pid), signal.SIGTERM)
93
- except ProcessLookupError:
94
- pass
74
+ return return_code
95
75
 
96
- self.process.terminate()
97
- _, _ = self.process.communicate(timeout=3)
76
+ async def stop_and_wait(self) -> Optional[int]:
77
+ """Terminate the process and wait for it to finish. Returns the return code of the process."""
98
78
 
99
- @property
100
- def return_code(self):
101
- if not self.process:
102
- return None
103
- if self.is_process_running():
79
+ if not self.is_running():
80
+ logging.info('No process to stop')
104
81
  return None
105
- return self.process.poll()
82
+
83
+ self.terminate()
84
+ return await self.wait()
85
+
86
+ # -------------------------------------------------------------------------------------------- LOGGING
87
+
88
+ def get_log(self) -> str:
89
+ """Get the log of the process as a string."""
90
+ if not os.path.exists(self.log_file_path):
91
+ return ''
92
+ with open(self.log_file_path, 'r') as f:
93
+ return f.read()
94
+
95
+ def get_log_by_lines(self, tail: Optional[int] = None) -> List[str]:
96
+ """Get the log of the process as a list of lines."""
97
+ if not os.path.exists(self.log_file_path):
98
+ return []
99
+ with open(self.log_file_path) as f:
100
+ lines = f.readlines()
101
+ if tail is not None:
102
+ lines = lines[-tail:]
103
+ return lines
104
+
105
+ def close_log(self):
106
+ """Close the log file."""
107
+ if self.log_file is not None:
108
+ self.log_file.close()
109
+ self.log_file = None
@@ -1,5 +1,6 @@
1
1
 
2
2
  import json
3
+ import logging
3
4
  import os
4
5
  from dataclasses import asdict
5
6
  from pathlib import Path
@@ -8,8 +9,19 @@ from typing import List
8
9
  from dacite import from_dict
9
10
  from fastapi.encoders import jsonable_encoder
10
11
 
11
- from ..data_classes import Detections, Training
12
+ from ..data_classes import Context, Detections, Training
12
13
  from ..globals import GLOBALS
14
+ from ..loop_communication import LoopCommunicator
15
+
16
+
17
+ class EnvironmentVars:
18
+ def __init__(self) -> None:
19
+ self.restart_after_training = os.environ.get(
20
+ 'RESTART_AFTER_TRAINING', 'FALSE').lower() in ['true', '1']
21
+ self.keep_old_trainings = os.environ.get(
22
+ 'KEEP_OLD_TRAININGS', 'FALSE').lower() in ['true', '1']
23
+ self.inference_batch_size = int(
24
+ os.environ.get('INFERENCE_BATCH_SIZE', '10'))
13
25
 
14
26
 
15
27
  class LastTrainingIO:
@@ -35,13 +47,16 @@ class LastTrainingIO:
35
47
 
36
48
  class ActiveTrainingIO:
37
49
 
38
- @staticmethod
39
- def create_mocked_training_io() -> 'ActiveTrainingIO':
40
- training_folder = ''
41
- return ActiveTrainingIO(training_folder)
50
+ # @staticmethod
51
+ # def create_mocked_training_io() -> 'ActiveTrainingIO':
52
+ # training_folder = ''
53
+ # return ActiveTrainingIO(training_folder)
42
54
 
43
- def __init__(self, training_folder: str):
55
+ def __init__(self, training_folder: str, loop_communicator: LoopCommunicator, context: Context) -> None:
44
56
  self.training_folder = training_folder
57
+ self.loop_communicator = loop_communicator
58
+ self.context = context
59
+
45
60
  self.mup_path = f'{training_folder}/model_uploading_progress.txt'
46
61
  # string with placeholder for index
47
62
  self.det_path = f'{training_folder}' + '/detections_{0}.json'
@@ -63,13 +78,16 @@ class ActiveTrainingIO:
63
78
 
64
79
  # detections
65
80
 
66
- def get_detection_file_names(self) -> List[Path]:
81
+ def _get_detection_file_names(self) -> List[Path]:
67
82
  files = [f for f in Path(self.training_folder).iterdir()
68
83
  if f.is_file() and f.name.startswith('detections_')]
69
84
  if not files:
70
85
  return []
71
86
  return files
72
87
 
88
+ def get_number_of_detection_files(self) -> int:
89
+ return len(self._get_detection_file_names())
90
+
73
91
  # TODO: saving and uploading multiple files is not tested!
74
92
  def save_detections(self, detections: List[Detections], index: int = 0) -> None:
75
93
  with open(self.det_path.format(index), 'w') as f:
@@ -81,11 +99,11 @@ class ActiveTrainingIO:
81
99
  return [from_dict(data_class=Detections, data=d) for d in dict_list]
82
100
 
83
101
  def delete_detections(self) -> None:
84
- for file in self.get_detection_file_names():
102
+ for file in self._get_detection_file_names():
85
103
  os.remove(Path(self.training_folder) / file)
86
104
 
87
105
  def detections_exist(self) -> bool:
88
- return bool(self.get_detection_file_names())
106
+ return bool(self._get_detection_file_names())
89
107
 
90
108
  # detections upload file index
91
109
 
@@ -124,3 +142,42 @@ class ActiveTrainingIO:
124
142
 
125
143
  def detection_upload_progress_exist(self) -> bool:
126
144
  return os.path.exists(self.dup_path)
145
+
146
+ async def upload_detetions(self):
147
+ num_files = self.get_number_of_detection_files()
148
+ print(f'num_files: {num_files}', flush=True)
149
+ if not num_files:
150
+ logging.error('no detection files found')
151
+ return
152
+ current_json_file_index = self.load_detections_upload_file_index()
153
+ for i in range(current_json_file_index, num_files):
154
+ detections = self.load_detections(i)
155
+ logging.info(f'uploading detections {i}/{num_files}')
156
+ await self._upload_detections_batched(self.context, detections)
157
+ self.save_detections_upload_file_index(i+1)
158
+
159
+ async def _upload_detections_batched(self, context: Context, detections: List[Detections]):
160
+ batch_size = 10
161
+ skip_detections = self.load_detection_upload_progress()
162
+ for i in range(skip_detections, len(detections), batch_size):
163
+ up_progress = i+batch_size
164
+ batch_detections = detections[i:up_progress]
165
+ dict_detections = [jsonable_encoder(asdict(detection)) for detection in batch_detections]
166
+ logging.info(f'uploading detections. File size : {len(json.dumps(dict_detections))}')
167
+ await self._upload_detections(context, batch_detections, up_progress)
168
+ skip_detections = up_progress
169
+
170
+ async def _upload_detections(self, context: Context, batch_detections: List[Detections], up_progress: int):
171
+ detections_json = [jsonable_encoder(asdict(detections)) for detections in batch_detections]
172
+ response = await self.loop_communicator.post(
173
+ f'/{context.organization}/projects/{context.project}/detections', json=detections_json)
174
+ if response.status_code != 200:
175
+ msg = f'could not upload detections. {str(response)}'
176
+ logging.error(msg)
177
+ raise Exception(msg)
178
+
179
+ logging.info('successfully uploaded detections')
180
+ if up_progress > len(batch_detections):
181
+ self.save_detection_upload_progress(0)
182
+ else:
183
+ self.save_detection_upload_progress(up_progress)
@@ -5,10 +5,10 @@ import logging
5
5
  from dataclasses import asdict
6
6
  from typing import TYPE_CHECKING, Dict
7
7
 
8
- from dacite import from_dict
9
8
  from fastapi import APIRouter, HTTPException, Request
10
9
 
11
10
  from ...data_classes import ErrorConfiguration, NodeState
11
+ from ..trainer_logic import TrainerLogic
12
12
 
13
13
  if TYPE_CHECKING:
14
14
  from ..trainer_node import TrainerNode
@@ -95,7 +95,9 @@ async def add_steps(request: Request):
95
95
  trainer_node = trainer_node_from_request(request)
96
96
  trainer_logic = trainer_node.trainer_logic # NOTE: is MockTrainerLogic which has 'provide_new_model' and 'current_iteration'
97
97
 
98
- if not trainer_logic._executor or not trainer_logic._executor.is_process_running(): # pylint: disable=protected-access
98
+ assert isinstance(trainer_logic, TrainerLogic), 'trainer_logic is not TrainerLogic'
99
+
100
+ if not trainer_logic._executor or not trainer_logic._executor.is_running(): # pylint: disable=protected-access
99
101
  training = trainer_logic._training # pylint: disable=protected-access
100
102
  logging.error(f'cannot add steps when not running, state: {training.training_state if training else "None"}')
101
103
  raise HTTPException(status_code=409, detail="trainer is not running")
@@ -109,7 +111,7 @@ async def add_steps(request: Request):
109
111
  for _ in range(steps):
110
112
  try:
111
113
  logging.warning('calling sync_confusion_matrix')
112
- await trainer_logic.sync_confusion_matrix()
114
+ await trainer_logic._sync_confusion_matrix() # pylint: disable=protected-access
113
115
  except Exception:
114
116
  pass # Tests can force synchronization to fail, error state is reported to backend
115
117
  trainer_logic.provide_new_model = previous_state # type: ignore
@@ -119,11 +121,14 @@ async def add_steps(request: Request):
119
121
 
120
122
  @router.post("/kill_training_process")
121
123
  async def kill_process(request: Request):
124
+
122
125
  # pylint: disable=protected-access
123
126
  trainer_node = trainer_node_from_request(request)
124
- if not trainer_node.trainer_logic._executor or not trainer_node.trainer_logic._executor.is_process_running():
127
+ trainer_logic = trainer_node.trainer_logic
128
+ assert isinstance(trainer_logic, TrainerLogic), 'trainer_logic is not TrainerLogic'
129
+ if not trainer_logic._executor or not trainer_logic._executor.is_running():
125
130
  raise HTTPException(status_code=409, detail="trainer is not running")
126
- trainer_node.trainer_logic._executor.stop()
131
+ await trainer_logic._executor.stop_and_wait()
127
132
 
128
133
 
129
134
  @router.post("/force_status_update")
@@ -7,6 +7,8 @@ from learning_loop_node.trainer.trainer_logic import TrainerLogic
7
7
 
8
8
  router = APIRouter()
9
9
 
10
+ # pylint: disable=protected-access
11
+
10
12
 
11
13
  @router.post("/controls/detect/{organization}/{project}/{version}")
12
14
  async def operation_mode(organization: str, project: str, version: str, request: Request):
@@ -22,5 +24,5 @@ async def operation_mode(organization: str, project: str, version: str, request:
22
24
  model_id = next(m for m in models if m['version'] == version)['id']
23
25
  logging.info(model_id)
24
26
  trainer: TrainerLogic = request.app.trainer
25
- await trainer.do_detections()
27
+ await trainer._do_detections()
26
28
  return "OK"
@@ -10,6 +10,8 @@ from learning_loop_node.data_classes import Context
10
10
  from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic
11
11
  from learning_loop_node.trainer.trainer_node import TrainerNode
12
12
 
13
+ # pylint: disable=protected-access
14
+
13
15
  logging.basicConfig(level=logging.INFO)
14
16
  # show output from uvicorn server https://stackoverflow.com/a/66132186/364388
15
17
  log_to_stderr(logging.INFO)
@@ -24,16 +26,14 @@ async def test_initialized_trainer_node():
24
26
 
25
27
  trainer = TestingTrainerLogic()
26
28
  node = TrainerNode(name='test', trainer_logic=trainer, uuid='NOD30000-0000-0000-0000-000000000000')
27
- trainer._node = node # pylint: disable=protected-access
28
- trainer.init_new_training(context=Context(organization='zauberzeug', project='demo'),
29
- details={'categories': [],
30
- 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project
31
- 'training_number': 0,
32
- 'resolution': 800,
33
- 'flip_rl': False,
34
- 'flip_ud': False})
35
-
36
- # pylint: disable=protected-access
29
+ trainer._node = node
30
+ trainer._init_new_training(context=Context(organization='zauberzeug', project='demo'),
31
+ details={'categories': [],
32
+ 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project
33
+ 'training_number': 0,
34
+ 'resolution': 800,
35
+ 'flip_rl': False,
36
+ 'flip_ud': False})
37
37
  await node._on_startup()
38
38
  yield node
39
39
  await node._on_shutdown()
@@ -44,19 +44,17 @@ async def test_initialized_trainer():
44
44
 
45
45
  trainer = TestingTrainerLogic()
46
46
  node = TrainerNode(name='test', trainer_logic=trainer, uuid='NODE-000-0000-0000-0000-000000000000')
47
- # pylint: disable=protected-access
48
- await node._on_startup()
49
- trainer._node = node # pylint: disable=protected-access
50
- trainer.init_new_training(context=Context(organization='zauberzeug', project='demo'),
51
- details={'categories': [],
52
- 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project
53
- 'training_number': 0,
54
- 'resolution': 800,
55
- 'flip_rl': False,
56
- 'flip_ud': False})
57
47
 
48
+ await node._on_startup()
49
+ trainer._node = node
50
+ trainer._init_new_training(context=Context(organization='zauberzeug', project='demo'),
51
+ details={'categories': [],
52
+ 'id': '917d5c7f-403d-7e92-f95f-577f79c2273a', # version 1.2 of demo project
53
+ 'training_number': 0,
54
+ 'resolution': 800,
55
+ 'flip_rl': False,
56
+ 'flip_ud': False})
58
57
  yield trainer
59
- # await node._on_shutdown()
60
58
  try:
61
59
  await node._on_shutdown()
62
60
  except Exception:
@@ -66,10 +64,3 @@ async def test_initialized_trainer():
66
64
  def is_port_in_use(port):
67
65
  with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
68
66
  return s.connect_ex(('localhost', port)) == 0
69
-
70
-
71
- # @pytest.fixture(autouse=True, scope='session')
72
- # def initialize_active_training():
73
- # from learning_loop_node.trainer import active_training_module
74
- # active_training_module.init('00000000-0000-0000-0000-000000000000')
75
- # yield
@@ -1,11 +1,13 @@
1
1
  from learning_loop_node.trainer.tests.state_helper import create_active_training_file
2
2
  from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic
3
3
 
4
+ # pylint: disable=protected-access
5
+
4
6
 
5
7
  async def test_cleanup_successfull(test_initialized_trainer: TestingTrainerLogic):
6
8
  trainer = test_initialized_trainer
7
9
  create_active_training_file(trainer, training_state='ready_for_cleanup')
8
- trainer.init_from_last_training()
10
+ trainer._init_from_last_training()
9
11
  trainer.active_training_io.save_detections(detections=[])
10
12
 
11
13
  trainer.active_training_io.save_detection_upload_progress(count=42)
@@ -16,9 +18,9 @@ async def test_cleanup_successfull(test_initialized_trainer: TestingTrainerLogic
16
18
  assert trainer.active_training_io.detection_upload_progress_exist() is True
17
19
  assert trainer.active_training_io.detections_upload_file_index_exists() is True
18
20
 
19
- await trainer.clear_training()
21
+ await trainer._clear_training()
20
22
 
21
- assert trainer._training is None # pylint: disable=protected-access
23
+ assert trainer._training is None
22
24
  assert trainer.node.last_training_io.exists() is False
23
25
  assert trainer.active_training_io.detections_exist() is False
24
26
  assert trainer.active_training_io.detection_upload_progress_exist() is False
@@ -1,11 +1,12 @@
1
1
  import asyncio
2
2
 
3
3
  from learning_loop_node.conftest import get_dummy_detections
4
- from learning_loop_node.data_classes import TrainingState
4
+ from learning_loop_node.data_classes import TrainerState
5
5
  from learning_loop_node.trainer.tests.state_helper import assert_training_state, create_active_training_file
6
6
  from learning_loop_node.trainer.tests.testing_trainer_logic import TestingTrainerLogic
7
7
  from learning_loop_node.trainer.trainer_logic import TrainerLogic
8
8
 
9
+ # pylint: disable=protected-access
9
10
  error_key = 'detecting'
10
11
 
11
12
 
@@ -13,60 +14,62 @@ def trainer_has_error(trainer: TrainerLogic):
13
14
  return trainer.errors.has_error_for(error_key)
14
15
 
15
16
 
16
- async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic): # TODO Flaky test
17
+ async def test_successful_detecting(test_initialized_trainer: TestingTrainerLogic): # NOTE was a flaky test
17
18
  trainer = test_initialized_trainer
18
19
  create_active_training_file(trainer, training_state='train_model_uploaded',
19
- model_id_for_detecting='917d5c7f-403d-7e92-f95f-577f79c2273a')
20
+ model_uuid_for_detecting='917d5c7f-403d-7e92-f95f-577f79c2273a')
20
21
  # trainer.load_active_training()
21
- _ = asyncio.get_running_loop().create_task(trainer.do_detections())
22
+ _ = asyncio.get_running_loop().create_task(
23
+ trainer._perform_state('do_detections', TrainerState.Detecting, TrainerState.Detected, trainer._do_detections))
22
24
 
23
- await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001)
24
- await assert_training_state(trainer.training, 'detected', timeout=10, interval=0.001)
25
+ await assert_training_state(trainer.training, TrainerState.Detecting, timeout=1, interval=0.001)
26
+ await assert_training_state(trainer.training, TrainerState.Detected, timeout=10, interval=0.001)
25
27
 
26
28
  assert trainer_has_error(trainer) is False
27
- assert trainer.training.training_state == 'detected'
29
+ assert trainer.training.training_state == TrainerState.Detected
28
30
  assert trainer.node.last_training_io.load() == trainer.training
29
31
  assert trainer.active_training_io.detections_exist()
30
32
 
31
33
 
32
34
  async def test_detecting_can_be_aborted(test_initialized_trainer: TestingTrainerLogic):
33
35
  trainer = test_initialized_trainer
34
- create_active_training_file(trainer, training_state=TrainingState.TrainModelUploaded)
35
- trainer.init_from_last_training()
36
- trainer.training.model_id_for_detecting = '12345678-bobo-7e92-f95f-424242424242'
36
+ create_active_training_file(trainer, training_state=TrainerState.TrainModelUploaded)
37
+ trainer._init_from_last_training()
38
+ trainer.training.model_uuid_for_detecting = '12345678-bobo-7e92-f95f-424242424242'
37
39
 
38
- _ = asyncio.get_running_loop().create_task(trainer.run())
40
+ _ = asyncio.get_running_loop().create_task(trainer._run())
39
41
 
40
- await assert_training_state(trainer.training, 'detecting', timeout=5, interval=0.001)
42
+ await assert_training_state(trainer.training, TrainerState.Detecting, timeout=5, interval=0.001)
41
43
  await trainer.stop()
42
44
  await asyncio.sleep(0.1)
43
45
 
44
- assert trainer._training is None # pylint: disable=protected-access
46
+ assert trainer._training is None
45
47
  assert trainer.active_training_io.detections_exist() is False
46
48
  assert trainer.node.last_training_io.exists() is False
47
49
 
48
50
 
49
51
  async def test_model_not_downloadable_error(test_initialized_trainer: TestingTrainerLogic):
50
52
  trainer = test_initialized_trainer
51
- create_active_training_file(trainer, training_state='train_model_uploaded',
52
- model_id_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id
53
- trainer.init_from_last_training()
53
+ create_active_training_file(trainer, training_state=TrainerState.TrainModelUploaded,
54
+ model_uuid_for_detecting='00000000-0000-0000-0000-000000000000') # bad model id
55
+ trainer._init_from_last_training()
54
56
 
55
- _ = asyncio.get_running_loop().create_task(trainer.run())
57
+ _ = asyncio.get_running_loop().create_task(trainer._run())
56
58
 
57
59
  await assert_training_state(trainer.training, 'detecting', timeout=1, interval=0.001)
58
60
  await assert_training_state(trainer.training, 'train_model_uploaded', timeout=1, interval=0.001)
61
+ await asyncio.sleep(0.1)
59
62
 
60
63
  assert trainer_has_error(trainer)
61
- assert trainer.training.training_state == 'train_model_uploaded'
62
- assert trainer.training.model_id_for_detecting == '00000000-0000-0000-0000-000000000000'
64
+ assert trainer.training.training_state == TrainerState.TrainModelUploaded
65
+ assert trainer.training.model_uuid_for_detecting == '00000000-0000-0000-0000-000000000000'
63
66
  assert trainer.node.last_training_io.load() == trainer.training
64
67
 
65
68
 
66
69
  def test_save_load_detections(test_initialized_trainer: TestingTrainerLogic):
67
70
  trainer = test_initialized_trainer
68
71
  create_active_training_file(trainer)
69
- trainer.init_from_last_training()
72
+ trainer._init_from_last_training()
70
73
 
71
74
  detections = [get_dummy_detections(), get_dummy_detections()]
72
75