learning-loop-node 0.9.3-py3-none-any.whl → 0.10.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.
Files changed (54)
  1. learning_loop_node/__init__.py +2 -3
  2. learning_loop_node/annotation/annotator_logic.py +2 -2
  3. learning_loop_node/annotation/annotator_node.py +16 -15
  4. learning_loop_node/data_classes/__init__.py +17 -10
  5. learning_loop_node/data_classes/detections.py +7 -2
  6. learning_loop_node/data_classes/general.py +4 -5
  7. learning_loop_node/data_classes/training.py +49 -21
  8. learning_loop_node/data_exchanger.py +85 -139
  9. learning_loop_node/detector/__init__.py +0 -1
  10. learning_loop_node/detector/detector_node.py +10 -13
  11. learning_loop_node/detector/inbox_filter/cam_observation_history.py +4 -7
  12. learning_loop_node/detector/outbox.py +0 -1
  13. learning_loop_node/detector/rest/about.py +1 -0
  14. learning_loop_node/detector/tests/conftest.py +0 -1
  15. learning_loop_node/detector/tests/test_client_communication.py +5 -3
  16. learning_loop_node/detector/tests/test_outbox.py +2 -0
  17. learning_loop_node/detector/tests/testing_detector.py +1 -8
  18. learning_loop_node/globals.py +2 -2
  19. learning_loop_node/helpers/gdrive_downloader.py +1 -1
  20. learning_loop_node/helpers/misc.py +124 -17
  21. learning_loop_node/loop_communication.py +57 -25
  22. learning_loop_node/node.py +62 -135
  23. learning_loop_node/tests/test_downloader.py +8 -7
  24. learning_loop_node/tests/test_executor.py +14 -11
  25. learning_loop_node/tests/test_helper.py +3 -5
  26. learning_loop_node/trainer/downloader.py +1 -1
  27. learning_loop_node/trainer/executor.py +87 -83
  28. learning_loop_node/trainer/io_helpers.py +66 -9
  29. learning_loop_node/trainer/rest/backdoor_controls.py +10 -5
  30. learning_loop_node/trainer/rest/controls.py +3 -1
  31. learning_loop_node/trainer/tests/conftest.py +19 -28
  32. learning_loop_node/trainer/tests/states/test_state_cleanup.py +5 -3
  33. learning_loop_node/trainer/tests/states/test_state_detecting.py +23 -20
  34. learning_loop_node/trainer/tests/states/test_state_download_train_model.py +18 -12
  35. learning_loop_node/trainer/tests/states/test_state_prepare.py +13 -12
  36. learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py +21 -18
  37. learning_loop_node/trainer/tests/states/test_state_train.py +27 -28
  38. learning_loop_node/trainer/tests/states/test_state_upload_detections.py +34 -32
  39. learning_loop_node/trainer/tests/states/test_state_upload_model.py +22 -20
  40. learning_loop_node/trainer/tests/test_errors.py +20 -12
  41. learning_loop_node/trainer/tests/test_trainer_states.py +4 -5
  42. learning_loop_node/trainer/tests/testing_trainer_logic.py +25 -30
  43. learning_loop_node/trainer/trainer_logic.py +80 -590
  44. learning_loop_node/trainer/trainer_logic_generic.py +495 -0
  45. learning_loop_node/trainer/trainer_node.py +27 -106
  46. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/METADATA +1 -1
  47. learning_loop_node-0.10.0.dist-info/RECORD +85 -0
  48. learning_loop_node/converter/converter_logic.py +0 -68
  49. learning_loop_node/converter/converter_node.py +0 -125
  50. learning_loop_node/converter/tests/test_converter.py +0 -55
  51. learning_loop_node/trainer/training_syncronizer.py +0 -52
  52. learning_loop_node-0.9.3.dist-info/RECORD +0 -88
  53. /learning_loop_node/{converter/__init__.py → py.typed} +0 -0
  54. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/WHEEL +0 -0
@@ -3,405 +3,91 @@ import json
  import logging
  import os
  import shutil
- import time
  from abc import abstractmethod
- from dataclasses import asdict
  from datetime import datetime
- from glob import glob
- from time import perf_counter
- from typing import TYPE_CHECKING, Coroutine, Dict, List, Optional, Union
- from uuid import UUID, uuid4
+ from typing import Coroutine, List, Optional

- import socketio
  from dacite import from_dict
- from fastapi.encoders import jsonable_encoder
- from tqdm import tqdm
-
- from ..data_classes import (BasicModel, Category, Context, Detections, Errors, Hyperparameter, ModelInformation,
-                             PretrainedModel, Training, TrainingData, TrainingError, TrainingState)
- from ..helpers.misc import create_image_folder
- from ..node import Node
- from . import training_syncronizer
- from .downloader import TrainingsDownloader
- from .executor import Executor
- from .io_helpers import ActiveTrainingIO
-
- if TYPE_CHECKING:
-     from .trainer_node import TrainerNode

-
- def is_valid_uuid4(val):
-     try:
-         _ = UUID(str(val)).version
-         return True
-     except ValueError:
-         return False
+ from ..data_classes import Detections, ModelInformation, TrainerState, TrainingError
+ from ..helpers.misc import create_image_folder, create_project_folder, images_for_ids, is_valid_uuid4
+ from .executor import Executor
+ from .trainer_logic_generic import TrainerLogicGeneric


- class TrainerLogic():
+ class TrainerLogic(TrainerLogicGeneric):

      def __init__(self, model_format: str) -> None:
-         self.model_format: str = model_format
+         """This class is the base class for all trainers that use an executor to run training processes.
+         The executor is used to run the training process in a separate process."""
+
+         super().__init__(model_format)
+         self._detection_progress: Optional[float] = None
          self._executor: Optional[Executor] = None
-         self.start_time: Optional[float] = None
-         self.training_task: Optional[asyncio.Task] = None
          self.start_training_task: Optional[Coroutine] = None
-         self.errors = Errors()
-         self.shutdown_event: asyncio.Event = asyncio.Event()
-         self.detection_progress = 0.0
-
-         self._training: Optional[Training] = None
-         self._active_training_io: Optional[ActiveTrainingIO] = None
-         self._node: Optional[TrainerNode] = None
-         self.restart_after_training = os.environ.get('RESTART_AFTER_TRAINING', 'FALSE').lower() in ['true', '1']
-         self.keep_old_trainings = os.environ.get('KEEP_OLD_TRAININGS', 'FALSE').lower() in ['true', '1']
-         self.inference_batch_size = int(os.environ.get('INFERENCE_BATCH_SIZE', '10'))
-         logging.info(f'INFERENCE_BATCH_SIZE: {self.inference_batch_size}')
+         self.inference_batch_size = 10

-     @property
-     def executor(self) -> Executor:
-         assert self._executor is not None, 'executor must be set, call `run_training` first'
-         return self._executor
-
-     @property
-     def training(self) -> Training:
-         assert self._training is not None, 'training must be set, call `init` first'
-         return self._training
+     # ---------------------------------------- IMPLEMENTED ABSTRACT PROPERTIES ----------------------------------------

      @property
-     def active_training_io(self) -> ActiveTrainingIO:
-         assert self._active_training_io is not None, 'active_training_io must be set, call `init` first'
-         return self._active_training_io
+     def detection_progress(self) -> Optional[float]:
+         return self._detection_progress

-     @property
-     def node(self) -> 'TrainerNode':
-         assert self._node is not None, 'node should be set by TrainerNodes before initialization'
-         return self._node
+     # ---------------------------------------- PROPERTIES ----------------------------------------

      @property
-     def is_initialized(self) -> bool:
-         """_training and _active_training_io are set in 'init_new_training' or 'init_from_last_training'"""
-         return self._training is not None and self._active_training_io is not None and self._node is not None
-
-     def init_new_training(self, context: Context, details: Dict) -> None:
-         """Called on `begin_training` event from the Learning Loop.
-         Note that details needs the entries 'categories' and 'training_number'"""
+     def executor(self) -> Executor:
+         assert self._executor is not None, 'executor must be set, call `run_training` first'
+         return self._executor

-         try:
-             project_folder = Node.create_project_folder(context)
-             if not self.keep_old_trainings:
-                 # NOTE: We delete all existing training folders because they are not needed anymore.
-                 TrainerLogic.delete_all_training_folders(project_folder)
-             self._training = TrainerLogic.generate_training(project_folder, context)
-             self._training.data = TrainingData(categories=Category.from_list(details['categories']))
-             self._training.data.hyperparameter = from_dict(data_class=Hyperparameter, data=details)
-             self._training.training_number = details['training_number']
-             self._training.base_model_id = details['id']
-             self._training.training_state = TrainingState.Initialized
-             self._active_training_io = ActiveTrainingIO(self._training.training_folder)
-             logging.info(f'init training: {self._training}')
-         except Exception:
-             logging.exception('Error in init')
-
-     def init_from_last_training(self) -> None:
-         self._training = self.node.last_training_io.load()
-         assert self._training is not None and self._training.training_folder is not None, 'could not restore training folder'
-         self._active_training_io = ActiveTrainingIO(self._training.training_folder)
-
-     async def run(self) -> None:
-         """Called on `begin_training` event from the Learning Loop."""
-
-         self.start_time = time.time()
-         self.errors.reset_all()
-         try:
-             self.training_task = asyncio.get_running_loop().create_task(self._run())
-             await self.training_task # Object is used to potentially cancel the task
-         except asyncio.CancelledError:
-             if not self.shutdown_event.is_set():
-                 logging.info('training task was cancelled but not by shutdown event')
-                 self.training.training_state = TrainingState.ReadyForCleanup
-                 self.node.last_training_io.save(self.training)
-                 await self.clear_training()
-
-         except Exception as e:
-             logging.exception(f'Error in train: {e}')
-         finally:
-             self.start_time = None
-
-     # ---------------------------------------- TRAINING STATES ----------------------------------------
-
-     async def _run(self) -> None:
-         """asyncio.CancelledError is catched in train"""
-
-         if not self.is_initialized:
-             logging.error('could not start training - trainer is not initialized')
-             return
+     # ---------------------------------------- IMPLEMENTED ABSTRACT MEHTODS ----------------------------------------

-         while self._training is not None:
-             tstate = self.training.training_state
-             logging.info(f'STATE LOOP: {tstate}')
-             await asyncio.sleep(0.6) # Note: Required for pytests!
-             if tstate == TrainingState.Initialized: # -> DataDownloading -> DataDownloaded
-                 await self.prepare()
-             elif tstate == TrainingState.DataDownloaded: # -> TrainModelDownloading -> TrainModelDownloaded
-                 await self.download_model()
-             elif tstate == TrainingState.TrainModelDownloaded: # -> TrainingRunning -> TrainingFinished
-                 await self.train()
-             elif tstate == TrainingState.TrainingFinished: # -> ConfusionMatrixSyncing -> ConfusionMatrixSynced
-                 await self.ensure_confusion_matrix_synced()
-             elif tstate == TrainingState.ConfusionMatrixSynced: # -> TrainModelUploading -> TrainModelUploaded
-                 await self.upload_model()
-             elif tstate == TrainingState.TrainModelUploaded: # -> Detecting -> Detected
-                 await self.do_detections()
-             elif tstate == TrainingState.Detected: # -> DetectionUploading -> ReadyForCleanup
-                 await self.upload_detections()
-             elif tstate == TrainingState.ReadyForCleanup: # -> RESTART or TrainingFinished
-                 await self.clear_training()
-                 self.may_restart()
-
-     async def prepare(self) -> None:
-         previous_state = self.training.training_state
-         self.training.training_state = TrainingState.DataDownloading
-         error_key = 'prepare'
-         try:
-             await self._prepare()
-         except asyncio.CancelledError:
-             logging.warning('CancelledError in prepare')
-             raise
-         except Exception as e:
-             logging.exception("Unknown error in 'prepare'. Exception:")
-             self.training.training_state = previous_state
-             self.errors.set(error_key, str(e))
-         else:
-             self.errors.reset(error_key)
-             self.training.training_state = TrainingState.DataDownloaded
-             self.node.last_training_io.save(self.training)
-
-     async def _prepare(self) -> None:
-         self.node.data_exchanger.set_context(self.training.context)
-         downloader = TrainingsDownloader(self.node.data_exchanger)
-         image_data, skipped_image_count = await downloader.download_training_data(self.training.images_folder)
-         assert self.training.data is not None, 'training.data must be set'
-         self.training.data.image_data = image_data
-         self.training.data.skipped_image_count = skipped_image_count
-
-     async def download_model(self) -> None:
-         logging.info('Downloading model')
-         previous_state = self.training.training_state
-         self.training.training_state = TrainingState.TrainModelDownloading
-         error_key = 'download_model'
-         try:
-             await self._download_model()
-         except asyncio.CancelledError:
-             logging.warning('CancelledError in download_model')
-             raise
-         except Exception as e:
-             logging.exception('download_model failed')
-             self.training.training_state = previous_state
-             self.errors.set(error_key, str(e))
-         else:
-             self.errors.reset(error_key)
-             logging.info('download_model_task finished')
-             self.training.training_state = TrainingState.TrainModelDownloaded
-             self.node.last_training_io.save(self.training)
-
-     async def _download_model(self) -> None:
-         model_id = self.training.base_model_id
-         assert model_id is not None, 'model_id must be set'
-         if is_valid_uuid4(
-                 self.training.base_model_id): # TODO this checks if we continue a training -> make more explicit
-             logging.info('loading model from Learning Loop')
-             logging.info(f'downloading model {model_id} as {self.model_format}')
-             await self.node.data_exchanger.download_model(self.training.training_folder, self.training.context, model_id, self.model_format)
-             shutil.move(f'{self.training.training_folder}/model.json',
-                         f'{self.training.training_folder}/base_model.json')
-         else:
-             logging.info(f'base_model_id {model_id} is not a valid uuid4, skipping download')
-
-     async def train(self) -> None:
-         logging.info('Running training')
+     async def _train(self) -> None:
+         previous_state = TrainerState.TrainModelDownloaded
          error_key = 'run_training'
-         # NOTE normally we reset errors after the step was successful. We do not want to display an old error during the whole training.
-         self.errors.reset(error_key)
-         previous_state = self.training.training_state
          self._executor = Executor(self.training.training_folder)
-         self.training.training_state = TrainingState.TrainingRunning
+         self.training.training_state = TrainerState.TrainingRunning
+
          try:
              await self._start_training()
-
              last_sync_time = datetime.now()
+
              while True:
-                 if not self.executor.is_process_running():
+                 await asyncio.sleep(0.1)
+                 if not self.executor.is_running():
                      break
                  if (datetime.now() - last_sync_time).total_seconds() > 5:
                      last_sync_time = datetime.now()
-                     if self.get_executor_error_from_log():
+                     if self._get_executor_error_from_log():
                          break
                      self.errors.reset(error_key)
                      try:
-                         await self.sync_confusion_matrix()
+                         await self._sync_confusion_matrix()
                      except asyncio.CancelledError:
                          logging.warning('CancelledError in run_training')
                          raise
                      except Exception:
-                         pass
-                 else:
-                     await asyncio.sleep(0.1)
+                         logging.error('Error in sync_confusion_matrix (this error is ignored)')

-             error = self.get_executor_error_from_log()
-             if error:
-                 self.errors.set(error_key, error)
+             if error := self._get_executor_error_from_log():
                  raise TrainingError(cause=error)
-             # TODO check if this works:
+
+             # NOTE: This is problematic, because the return code is not 0 when executor was stoppen e.g. via self.stop()
              # if self.executor.return_code != 0:
-             # self.errors.set(error_key, f'Executor return code was {self.executor.return_code}')
-             # raise TrainingError(cause=f'Executor return code was {self.executor.return_code}')
+             # raise TrainingError(cause=f'Executor returned with error code: {self.executor.return_code}')

-         except asyncio.CancelledError:
-             logging.warning('CancelledError in run_training')
-             raise
          except TrainingError:
-             logging.exception('Error in TrainingProcess')
-             if self.executor.is_process_running():
-                 self.executor.stop()
-             self.training.training_state = previous_state
-         except Exception as e:
-             self.errors.set(error_key, f'Could not start training {str(e)}')
+             logging.exception('Exception in trainer_logic._train')
+             await self.executor.stop_and_wait()
              self.training.training_state = previous_state
-             logging.exception('Error in run_training')
-         else:
-             self.training.training_state = TrainingState.TrainingFinished
-             self.node.last_training_io.save(self.training)
-
-     async def _start_training(self):
-         self.start_training_task = None # NOTE: this is used i.e. by tests
-         if self.can_resume():
-             self.start_training_task = self.resume()
-         else:
-             base_model_id = self.training.base_model_id
-             if not is_valid_uuid4(base_model_id): # TODO this check was done earlier!
-                 assert isinstance(base_model_id, str)
-                 # TODO this could be removed here and accessed via self.training.base_model_id
-                 self.start_training_task = self.start_training_from_scratch(base_model_id)
-             else:
-                 self.start_training_task = self.start_training()
-         await self.start_training_task
-
-     async def ensure_confusion_matrix_synced(self):
-         logging.info('Ensure syncing confusion matrix')
-         previous_state = self.training.training_state
-         self.training.training_state = TrainingState.ConfusionMatrixSyncing
-         try:
-             await self.sync_confusion_matrix()
-         except asyncio.CancelledError:
-             logging.warning('CancelledError in run_training')
-             raise
-         except Exception:
-             logging.exception('Error in ensure_confusion_matrix_synced')
-             self.training.training_state = previous_state
-         else:
-             self.training.training_state = TrainingState.ConfusionMatrixSynced
-             self.node.last_training_io.save(self.training)
-
-     async def sync_confusion_matrix(self):
-         logging.info('Syncing confusion matrix')
-         error_key = 'sync_confusion_matrix'
-         try:
-             await training_syncronizer.try_sync_model(self, self.node.uuid, self.node.sio_client)
-         except socketio.exceptions.BadNamespaceError as e: # type: ignore
-             logging.error('Error during confusion matrix syncronization. BadNamespaceError')
-             self.errors.set(error_key, str(e))
-             raise
-         except Exception as e:
-             logging.exception('Error during confusion matrix syncronization')
-             self.errors.set(error_key, str(e))
-             raise
-
-         self.errors.reset(error_key)
-
-     async def upload_model(self) -> None:
-         error_key = 'upload_model'
-         previous_state = self.training.training_state
-         self.training.training_state = TrainingState.TrainModelUploading
-         try:
-             new_model_id = await self._upload_model_return_new_id(self.training.context)
-             if new_model_id is None:
-                 self.training.training_state = TrainingState.ReadyForCleanup
-                 logging.error('could not upload model - maybe training failed.. cleaning up')
-                 return
-             assert new_model_id is not None, 'uploaded_model must be set'
-             logging.info(f'successfully uploaded model and received new model id: {new_model_id}')
-             self.training.model_id_for_detecting = new_model_id
-         except asyncio.CancelledError:
-             logging.warning('CancelledError in upload_model')
              raise
-         except Exception as e:
-             logging.exception('Error in upload_model. Exception:')
-             self.errors.set(error_key, str(e))
-             self.training.training_state = previous_state # TODO... going back is pointless here as it ends in a deadlock ?!
-             # self.training.training_state = TrainingState.ReadyForCleanup
-         else:
-             self.errors.reset(error_key)
-             self.training.training_state = TrainingState.TrainModelUploaded
-             self.node.last_training_io.save(self.training)
-
-     async def _upload_model_return_new_id(self, context: Context) -> Optional[str]:
-         """Upload model files, usually pytorch model (.pt) hyp.yaml and the converted .wts file.
-         Note that with the latest trainers the conversion to (.wts) is done by the trainer.
-         The conversion from .wts to .engine is done by the detector (needs to be done on target hardware).
-         Note that trainer may train with different classes, which is why we send an initial model.json file.
-         """
-         files = await asyncio.get_running_loop().run_in_executor(None, self.get_latest_model_files)
-
-         if files is None:
-             return None
-
-         if isinstance(files, List):
-             files = {self.model_format: files}
-         assert isinstance(files, Dict), f'can only save model as list or dict, but was {files}'
-
-         model_json_path = self.create_model_json_with_categories()
-         already_uploaded_formats = self.active_training_io.load_model_upload_progress()
-
-         new_id = None
-         for file_format in files:
-             if file_format in already_uploaded_formats:
-                 continue
-             _files = files[file_format]
-             # model.json was mandatory in previous versions. Now its forbidden to provide an own model.json file.
-             assert not any(f for f in _files if 'model.json' in f), "Upload 'model.json' not allowed (added automatically)."
-             _files.append(model_json_path)
-             new_id = await self.node.data_exchanger.upload_model_for_training(context, _files, self.training.training_number, file_format)
-             if new_id is None:
-                 return None
-
-             already_uploaded_formats.append(file_format)
-             self.active_training_io.save_model_upload_progress(already_uploaded_formats)
-
-         return new_id
-
-     async def do_detections(self):
-         error_key = 'detecting'
-         previous_state = self.training.training_state
-         try:
-             self.training.training_state = TrainingState.Detecting
-             await self._do_detections()
-         except asyncio.CancelledError:
-             logging.warning('CancelledError in do_detections')
-             raise
-         except Exception as e:
-             self.errors.set(error_key, str(e))
-             logging.exception('Error in do_detections - Exception:')
-             self.training.training_state = previous_state
-         else:
-             self.errors.reset(error_key)
-             self.training.training_state = TrainingState.Detected
-             self.node.last_training_io.save(self.training)

      async def _do_detections(self) -> None:
          context = self.training.context
-         model_id = self.training.model_id_for_detecting
-         assert model_id, 'model_id must be set'
+         model_id = self.training.model_uuid_for_detecting
+         if not model_id:
+             logging.error('model_id is not set! Cannot do detections.')
+             return
          tmp_folder = f'/tmp/model_for_auto_detections_{model_id}_{self.model_format}'

          shutil.rmtree(tmp_folder, ignore_errors=True)
@@ -410,111 +96,57 @@ class TrainerLogic():

          await self.node.data_exchanger.download_model(tmp_folder, context, model_id, self.model_format)
          with open(f'{tmp_folder}/model.json', 'r') as f:
-             content = json.load(f)
-             model_information = from_dict(data_class=ModelInformation, data=content)
+             model_information = from_dict(data_class=ModelInformation, data=json.load(f))

-         project_folder = Node.create_project_folder(context)
+         project_folder = create_project_folder(context)
          image_folder = create_image_folder(project_folder)
          self.node.data_exchanger.set_context(context)
          image_ids = []
          for state, p in zip(['inbox', 'annotate', 'review', 'complete'], [0.1, 0.2, 0.3, 0.4]):
-             self.detection_progress = p
+             self._detection_progress = p
              logging.info(f'fetching image ids of {state}')
-             new_ids = await self.node.data_exchanger.fetch_image_ids(query_params=f'state={state}')
+             new_ids = await self.node.data_exchanger.fetch_image_uuids(query_params=f'state={state}')
              image_ids += new_ids
              logging.info(f'downloading {len(new_ids)} images')
              await self.node.data_exchanger.download_images(new_ids, image_folder)
-         self.detection_progress = 0.42
-         await self.node.data_exchanger.delete_corrupt_images(image_folder)
+         self._detection_progress = 0.42
+         # await delete_corrupt_images(image_folder)

-         images = await asyncio.get_event_loop().run_in_executor(None, TrainerLogic.images_for_ids, image_ids, image_folder)
-         num_images = len(images)
-         logging.info(f'running detections on {num_images} images')
-         batch_size = 200
-         idx = 0
+         images = await asyncio.get_event_loop().run_in_executor(None, images_for_ids, image_ids, image_folder)
          if not images:
-             self.active_training_io.save_detections([], idx)
-         for i in tqdm(range(0, num_images, batch_size), position=0, leave=True):
-             self.detection_progress = 0.5 + (i/num_images)*0.5
-             batch_images = images[i:i+batch_size]
+             self.active_training_io.save_detections([], 0)
+         num_images = len(images)
+
+         for idx, i in enumerate(range(0, num_images, self.inference_batch_size)):
+             self._detection_progress = 0.5 + (i/num_images)*0.5
+             batch_images = images[i:i+self.inference_batch_size]
              batch_detections = await self._detect(model_information, batch_images, tmp_folder)
              self.active_training_io.save_detections(batch_detections, idx)
-             idx += 1

-         return None
+     # ---------------------------------------- METHODS ----------------------------------------

-     async def upload_detections(self):
-         error_key = 'upload_detections'
-         previous_state = self.training.training_state
-         self.training.training_state = TrainingState.DetectionUploading
-         await asyncio.sleep(0.1) # NOTE needed for tests
-         try:
-             json_files = self.active_training_io.get_detection_file_names()
-             if not json_files:
-                 raise Exception()
-             current_json_file_index = self.active_training_io.load_detections_upload_file_index()
-             for i in range(current_json_file_index, len(json_files)):
-                 detections = self.active_training_io.load_detections(i)
-                 logging.info(f'uploading detections {i}/{len(json_files)}')
-                 await self._upload_detections_batched(self.training.context, detections)
-                 self.active_training_io.save_detections_upload_file_index(i+1)
-         except asyncio.CancelledError:
-             logging.warning('CancelledError in upload_detections')
-             raise
-         except Exception as e:
-             self.errors.set(error_key, str(e))
-             logging.exception('Error in upload_detections')
-             self.training.training_state = previous_state
-         else:
-             self.errors.reset(error_key)
-             self.training.training_state = TrainingState.ReadyForCleanup
-             self.node.last_training_io.save(self.training)
-
-     async def _upload_detections_batched(self, context: Context, detections: List[Detections]):
-         batch_size = 10
-         skip_detections = self.active_training_io.load_detection_upload_progress()
-         for i in tqdm(range(skip_detections, len(detections), batch_size), position=0, leave=True):
-             up_progress = i+batch_size
-             batch_detections = detections[i:up_progress]
-             dict_detections = [jsonable_encoder(asdict(detection)) for detection in batch_detections]
-             logging.info(f'uploading detections. File size : {len(json.dumps(dict_detections))}')
-             await self._upload_detections(context, batch_detections, up_progress)
-             skip_detections = up_progress
-
-     async def _upload_detections(self, context: Context, batch_detections: List[Detections], up_progress: int):
-         assert self._active_training_io is not None, 'active_training must be set'
-
-         detections_json = [jsonable_encoder(asdict(detections)) for detections in batch_detections]
-         response = await self.node.loop_communicator.post(
-             f'/{context.organization}/projects/{context.project}/detections', json=detections_json)
-         if response.status_code != 200:
-             msg = f'could not upload detections. {str(response)}'
-             logging.error(msg)
-             raise Exception(msg)
+     async def _start_training(self):
+         self.start_training_task = None # NOTE: this is used i.e. by tests
+         if self._can_resume():
+             self.start_training_task = self._resume()
          else:
-             logging.info('successfully uploaded detections')
-             if up_progress > len(batch_detections):
-                 self._active_training_io.save_detection_upload_progress(0)
+             base_model_uuid_or_name = self.training.base_model_uuid_or_name
+             if not is_valid_uuid4(base_model_uuid_or_name):
+                 self.start_training_task = self._start_training_from_scratch()
              else:
-                 self._active_training_io.save_detection_upload_progress(up_progress)
-
-     async def clear_training(self):
-         self.active_training_io.delete_detections()
-         self.active_training_io.delete_detection_upload_progress()
-         self.active_training_io.delete_detections_upload_file_index()
-         await self.clear_training_data(self.training.training_folder)
-         self.node.last_training_io.delete()
-         # self.training.training_state = TrainingState.TrainingFinished
-         assert self._node is not None
-         await self._node.send_status() # make sure the status is updated before we stop the training
-         self._training = None
+                 self.start_training_task = self._start_training_from_base_model()
+         await self.start_training_task
+
+     # ---------------------------------------- OVERWRITTEN METHODS ----------------------------------------

      async def stop(self) -> None:
          """If executor is running, stop it. Else cancel training task."""
-         if not self.is_initialized:
+         print('===============> stop received in trainer_logic.', flush=True)
+
+         if not self.training_active:
              return
-         if self._executor and self._executor.is_process_running():
-             self.executor.stop()
+         if self._executor and self._executor.is_running():
+             await self.executor.stop_and_wait()
          elif self.training_task:
              logging.info('cancelling training task')
              if self.training_task.cancel():
@@ -523,175 +155,33 @@ class TrainerLogic():
                  except asyncio.CancelledError:
                      pass
                  logging.info('cancelled training task')
-                 self.may_restart()
-
-     async def shutdown(self) -> None:
-         self.shutdown_event.set()
-         await self.stop()
-         await self.stop() # NOTE first stop may only stop training.
-
-     def get_log(self) -> str:
-         return self.executor.get_log()
+                 self._may_restart()

-     def may_restart(self) -> None:
-         if self.restart_after_training:
-             logging.info('restarting')
-             assert self._node is not None
-             self._node.restart()
-         else:
-             logging.info('not restarting')
-
-     @property
-     def general_progress(self) -> Optional[float]:
-         """Represents the progress for different states."""
-         if not self.is_initialized:
-             return None
-
-         t_state = self.training.training_state
-         if t_state == TrainingState.DataDownloading:
-             return self.node.data_exchanger.progress
-         if t_state == TrainingState.TrainingRunning:
-             return self.training_progress
-         if t_state == TrainingState.Detecting:
-             return self.detection_progress
-
-         return None
      # ---------------------------------------- ABSTRACT METHODS ----------------------------------------

-     @property
-     @abstractmethod
-     def training_progress(self) -> Optional[float]:
-         """Represents the training progress."""
-         raise NotImplementedError
-
-     @property
-     @abstractmethod
-     def provided_pretrained_models(self) -> List[PretrainedModel]:
-         raise NotImplementedError
-
-     @property
      @abstractmethod
-     def model_architecture(self) -> Optional[str]:
-         raise NotImplementedError
+     async def _start_training_from_base_model(self) -> None:
+         '''Should be used to start a training on executer, e.g. self.executor.start(cmd).'''

      @abstractmethod
-     async def start_training(self) -> None:
-         '''Should be used to start a training.'''
+     async def _start_training_from_scratch(self) -> None:
+         '''Should be used to start a training from scratch on executer, e.g. self.executor.start(cmd).
+         NOTE base_model_id is now accessible via self.training.base_model_id
+         the id of a pretrained model provided by self.provided_pretrained_models.'''

      @abstractmethod
-     async def start_training_from_scratch(self, base_model_id: str) -> None:
-         '''Should be used to start a training from scratch.
-         base_model_id is the id of a pretrained model provided by self.provided_pretrained_models.'''
-
-     @abstractmethod
-     def can_resume(self) -> bool:
+     def _can_resume(self) -> bool:
          '''Override this method to return True if the trainer can resume training.'''

      @abstractmethod
-     async def resume(self) -> None:
+     async def _resume(self) -> None:
          '''Is called when self.can_resume() returns True.
          One may resume the training on a previously trained model stored by self.on_model_published(basic_model).'''

      @abstractmethod
-     def get_executor_error_from_log(self) -> Optional[str]: # TODO we should allow other options to get the error
+     def _get_executor_error_from_log(self) -> Optional[str]:
          '''Should be used to provide error informations to the Learning Loop by extracting data from self.executor.get_log().'''

-     @abstractmethod
-     def get_new_model(self) -> Optional[BasicModel]:
-         '''Is called frequently in `try_sync_model` to check if a new "best" model is availabe.
-         Returns None if no new model could be found. Otherwise BasicModel(confusion_matrix, meta_information).
-         `confusion_matrix` contains a dict of all classes:
-           - The classes must be identified by their id, not their name.
-           - For each class a dict with tp, fp, fn is provided (true positives, false positives, false negatives).
-         `meta_information` can hold any data which is helpful for self.on_model_published to store weight file etc for later upload via self.get_model_files
-         '''
-
-     @abstractmethod
-     def on_model_published(self, basic_model: BasicModel) -> None:
-         '''Called after a BasicModel has been successfully send to the Learning Loop.
-         The files for this model should be stored.
-         self.get_latest_model_files is used to gather all files needed for transfering the actual data from the trainer node to the Learning Loop.
-         In the simplest implementation this method just renames the weight file (encoded in BasicModel.meta_information) into a file name like latest_published_model
-         '''
-
-     @abstractmethod
-     def get_latest_model_files(self) -> Optional[Union[List[str], Dict[str, List[str]]]]:
-         '''Called when the Learning Loop requests to backup the latest model for the training.
-         Should return a list of file paths which describe the model.
-         These files must contain all data neccessary for the trainer to resume a training (eg. weight file, hyperparameters, etc.)
-         and will be stored in the Learning Loop unter the format of this trainer.
-         Note: by convention the weightfile should be named "model.<extension>" where extension is the file format of the weightfile.
-         For example "model.pt" for pytorch or "model.weights" for darknet/yolo.
-
-         If a trainer can also generate other formats (for example for an detector),
-         a dictionary mapping format -> list of files can be returned.'''
-
      @abstractmethod
      async def _detect(self, model_information: ModelInformation, images: List[str], model_folder: str) -> List[Detections]:
          '''Called to run detections on a list of images.'''
-
-     @abstractmethod
-     async def clear_training_data(self, training_folder: str) -> None:
-         '''Called after a training has finished. Deletes all data that is not needed anymore after a training run.
-         This can be old weightfiles or any additional files.'''
-
-     # ---------------------------------------- HELPER METHODS ----------------------------------------
-
-     @staticmethod
-     def images_for_ids(image_ids, image_folder) -> List[str]:
-         logging.info(f'### Going to get images for {len(image_ids)} images ids')
-         start = perf_counter()
-         images = [img for img in glob(f'{image_folder}/**/*.*', recursive=True)
-                   if os.path.splitext(os.path.basename(img))[0] in image_ids]
-         end = perf_counter()
-         logging.info(f'found {len(images)} images for {len(image_ids)} image ids, which took {end-start:0.2f} seconds')
-         return images
-
-     @staticmethod
-     def generate_training(project_folder: str, context: Context) -> Training:
-         training_uuid = str(uuid4())
-         return Training(
-             id=training_uuid,
-             context=context,
-             project_folder=project_folder,
-             images_folder=create_image_folder(project_folder),
-             training_folder=TrainerLogic.create_training_folder(project_folder, training_uuid)
-         )
-
-     @staticmethod
-     def delete_all_training_folders(project_folder: str):
-         if not os.path.exists(f'{project_folder}/trainings'):
-             return
-         for uuid in os.listdir(f'{project_folder}/trainings'):
-             shutil.rmtree(f'{project_folder}/trainings/{uuid}', ignore_errors=True)
-
-     @staticmethod
-     def create_training_folder(project_folder: str, trainings_id: str) -> str:
-         training_folder = f'{project_folder}/trainings/{trainings_id}'
-         os.makedirs(training_folder, exist_ok=True)
-         return training_folder
-
-     @property
-     def hyperparameters(self) -> Optional[Dict]:
-         if self._training and self._training.data and self._training.data.hyperparameter:
-             information = {}
-             information['resolution'] = self._training.data.hyperparameter.resolution
-             information['flipRl'] = self._training.data.hyperparameter.flip_rl
-             information['flipUd'] = self._training.data.hyperparameter.flip_ud
-             return information
-         return None
-
-     def create_model_json_with_categories(self) -> str:
-         """Remaining fields are filled by the Learning Loop"""
-         if self._training and self._training.data:
-             content = {
-                 'categories': [asdict(c) for c in self._training.data.categories],
-             }
-         else:
-             content = None
-
-         model_json_path = '/tmp/model.json'
-         with open(model_json_path, 'w') as f:
-             json.dump(content, f)
-
-         return model_json_path
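
For trainer subclass authors, the practical impact of this diff is the renamed abstract API: start_training becomes _start_training_from_base_model, start_training_from_scratch loses its base_model_id argument and becomes _start_training_from_scratch, can_resume/resume become _can_resume/_resume, and get_executor_error_from_log becomes _get_executor_error_from_log, while the state machine, confusion-matrix syncing and model upload appear to move into the new TrainerLogicGeneric base class. The sketch below is a minimal, hypothetical subclass based only on the abstract methods visible in this diff; the class name, training commands and log pattern are invented, and the exact Executor.start signature is not shown here (the docstrings merely suggest self.executor.start(cmd)).

# Hypothetical minimal trainer against the 0.10.0 TrainerLogic API (illustration only).
from typing import List, Optional

from learning_loop_node.data_classes import Detections, ModelInformation
from learning_loop_node.trainer.trainer_logic import TrainerLogic


class MinimalTrainerLogic(TrainerLogic):
    # NOTE: TrainerLogicGeneric may define further abstract members that are not visible in this diff.

    def __init__(self) -> None:
        super().__init__(model_format='mocked')  # placeholder model format

    async def _start_training_from_base_model(self) -> None:
        # the docstring suggests starting the process via the executor, e.g. self.executor.start(cmd);
        # the exact signature is not shown in this diff, and the command is invented
        self.executor.start('python train.py --continue base_model.json')

    async def _start_training_from_scratch(self) -> None:
        self.executor.start('python train.py --from-scratch')  # invented command

    def _can_resume(self) -> bool:
        return False  # this toy trainer never resumes

    async def _resume(self) -> None:
        pass  # unreachable while _can_resume() returns False

    def _get_executor_error_from_log(self) -> Optional[str]:
        # scan the executor log for a known failure marker; the marker is an assumption
        log = self.executor.get_log()
        return 'CUDA out of memory' if 'CUDA out of memory' in log else None

    async def _detect(self, model_information: ModelInformation, images: List[str],
                      model_folder: str) -> List[Detections]:
        return []  # a real trainer would run inference with the weights in model_folder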