learning-loop-node 0.9.3__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of learning-loop-node might be problematic. Click here for more details.

Files changed (54)
  1. learning_loop_node/__init__.py +2 -3
  2. learning_loop_node/annotation/annotator_logic.py +2 -2
  3. learning_loop_node/annotation/annotator_node.py +16 -15
  4. learning_loop_node/data_classes/__init__.py +17 -10
  5. learning_loop_node/data_classes/detections.py +7 -2
  6. learning_loop_node/data_classes/general.py +4 -5
  7. learning_loop_node/data_classes/training.py +49 -21
  8. learning_loop_node/data_exchanger.py +85 -139
  9. learning_loop_node/detector/__init__.py +0 -1
  10. learning_loop_node/detector/detector_node.py +10 -13
  11. learning_loop_node/detector/inbox_filter/cam_observation_history.py +4 -7
  12. learning_loop_node/detector/outbox.py +0 -1
  13. learning_loop_node/detector/rest/about.py +1 -0
  14. learning_loop_node/detector/tests/conftest.py +0 -1
  15. learning_loop_node/detector/tests/test_client_communication.py +5 -3
  16. learning_loop_node/detector/tests/test_outbox.py +2 -0
  17. learning_loop_node/detector/tests/testing_detector.py +1 -8
  18. learning_loop_node/globals.py +2 -2
  19. learning_loop_node/helpers/gdrive_downloader.py +1 -1
  20. learning_loop_node/helpers/misc.py +124 -17
  21. learning_loop_node/loop_communication.py +57 -25
  22. learning_loop_node/node.py +62 -135
  23. learning_loop_node/tests/test_downloader.py +8 -7
  24. learning_loop_node/tests/test_executor.py +14 -11
  25. learning_loop_node/tests/test_helper.py +3 -5
  26. learning_loop_node/trainer/downloader.py +1 -1
  27. learning_loop_node/trainer/executor.py +87 -83
  28. learning_loop_node/trainer/io_helpers.py +66 -9
  29. learning_loop_node/trainer/rest/backdoor_controls.py +10 -5
  30. learning_loop_node/trainer/rest/controls.py +3 -1
  31. learning_loop_node/trainer/tests/conftest.py +19 -28
  32. learning_loop_node/trainer/tests/states/test_state_cleanup.py +5 -3
  33. learning_loop_node/trainer/tests/states/test_state_detecting.py +23 -20
  34. learning_loop_node/trainer/tests/states/test_state_download_train_model.py +18 -12
  35. learning_loop_node/trainer/tests/states/test_state_prepare.py +13 -12
  36. learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py +21 -18
  37. learning_loop_node/trainer/tests/states/test_state_train.py +27 -28
  38. learning_loop_node/trainer/tests/states/test_state_upload_detections.py +34 -32
  39. learning_loop_node/trainer/tests/states/test_state_upload_model.py +22 -20
  40. learning_loop_node/trainer/tests/test_errors.py +20 -12
  41. learning_loop_node/trainer/tests/test_trainer_states.py +4 -5
  42. learning_loop_node/trainer/tests/testing_trainer_logic.py +25 -30
  43. learning_loop_node/trainer/trainer_logic.py +80 -590
  44. learning_loop_node/trainer/trainer_logic_generic.py +495 -0
  45. learning_loop_node/trainer/trainer_node.py +27 -106
  46. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/METADATA +1 -1
  47. learning_loop_node-0.10.0.dist-info/RECORD +85 -0
  48. learning_loop_node/converter/converter_logic.py +0 -68
  49. learning_loop_node/converter/converter_node.py +0 -125
  50. learning_loop_node/converter/tests/test_converter.py +0 -55
  51. learning_loop_node/trainer/training_syncronizer.py +0 -52
  52. learning_loop_node-0.9.3.dist-info/RECORD +0 -88
  53. /learning_loop_node/{converter/__init__.py → py.typed} +0 -0
  54. {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/WHEEL +0 -0
@@ -0,0 +1,495 @@
1
+ import asyncio
2
+ import json
3
+ import logging
4
+ import shutil
5
+ import sys
6
+ import time
7
+ from abc import ABC, abstractmethod
8
+ from dataclasses import asdict
9
+ from typing import TYPE_CHECKING, Callable, Coroutine, Dict, List, Optional
10
+
11
+ from fastapi.encoders import jsonable_encoder
12
+
13
+ from ..data_classes import (Context, Errors, Hyperparameter, PretrainedModel, TrainerState, Training, TrainingData,
14
+ TrainingOut, TrainingStateData)
15
+ from ..helpers.misc import create_project_folder, delete_all_training_folders, generate_training, is_valid_uuid4
16
+ from .downloader import TrainingsDownloader
17
+ from .io_helpers import ActiveTrainingIO, EnvironmentVars, LastTrainingIO
18
+
19
+ if TYPE_CHECKING:
20
+ from .trainer_node import TrainerNode
21
+
22
+
23
class TrainerLogicGeneric(ABC):
    """Generic base class for trainer logics.

    Drives a training through the TrainerState machine (see `_training_loop`), persists its progress
    via LastTrainingIO / ActiveTrainingIO and talks to the Learning Loop through the owning TrainerNode.
    Subclasses implement the abstract properties and methods (training, detection, model file handling).
    `_node` and `_last_training_io` must be set by the TrainerNode before the logic is used.
    """

    def __init__(self, model_format: str) -> None:

        # NOTE: model_format is used in the file path for the model on the server:
        # It acts as a key for list of files (cf. _get_latest_model_files)
        # '/{context.organization}/projects/{context.project}/models/{model_id}/{model_format}/file'
        self.model_format: str = model_format
        self.errors = Errors()

        # Task running `_training_loop`; kept so that `stop` can cancel it.
        self.training_task: Optional[asyncio.Task] = None
        # Set in `on_shutdown`; lets `_run` distinguish a shutdown from a regular cancellation.
        self.shutdown_event: asyncio.Event = asyncio.Event()

        self._node: Optional['TrainerNode'] = None  # type: ignore
        self._last_training_io: Optional[LastTrainingIO] = None  # type: ignore

        self._training: Optional[Training] = None
        self._active_training_io: Optional[ActiveTrainingIO] = None
        self._environment_vars = EnvironmentVars()

    # ---------------------------------------- PROPERTIES TO AVOID CHECKING FOR NONE ----------------------------------------

    @property
    def node(self) -> 'TrainerNode':
        """The owning TrainerNode. Asserts that it has been set."""
        assert self._node is not None, 'node should be set by TrainerNode before initialization'
        return self._node

    @property
    def last_training_io(self) -> LastTrainingIO:
        """Persistence helper for the last (possibly incomplete) training. Asserts that it has been set."""
        assert self._last_training_io is not None, 'last_training_io should be set by TrainerNode before initialization'
        return self._last_training_io

    @property
    def active_training_io(self) -> ActiveTrainingIO:
        """IO helper of the currently active training. Asserts that a training was initialized."""
        assert self._active_training_io is not None, 'active_training_io must be set, call `init` first'
        return self._active_training_io

    @property
    def training(self) -> Training:
        """The currently active Training object. Asserts that a training was initialized."""
        assert self._training is not None, 'training must be initialized, call `init` first'
        return self._training

    @property
    def hyperparameter(self) -> Hyperparameter:
        """Hyperparameters of the active training. Asserts that training data and hyperparameters exist."""
        assert self.training_data is not None, 'Training should have data'
        assert self.training_data.hyperparameter is not None, 'Training.data should have hyperparameter'
        return self.training_data.hyperparameter

    # ---------------------------------------- PROPERTIES ----------------------------------------

    @property
    def training_data(self) -> Optional[TrainingData]:
        """Data of the active training, or None if there is no active training or it has no data."""
        if self.training_active and self.training.data:
            return self.training.data
        return None

    @property
    def training_context(self) -> Optional[Context]:
        """Context (organization/project) of the active training, or None if no training is active."""
        if self.training_active:
            return self.training.context
        return None

    @property
    def training_active(self) -> bool:
        """True if both _training and _active_training_io are set.

        Both are set in `_init_new_training` or `_init_from_last_training`.
        """
        return self._training is not None and self._active_training_io is not None

    @property
    def state(self) -> str:
        """Returns the current state of the training. Used solely by the node in send_status().

        Falls back to TrainerState.Idle if no training is active or the training has no state.
        """
        if (not self.training_active) or (self.training.training_state is None):
            return TrainerState.Idle.value
        return self.training.training_state

    @property
    def training_uptime(self) -> Optional[float]:
        """Lifetime of the current Training object in seconds, or None if no training is active.

        The start time is set during initialization of the Training object.
        """
        if self.training_active:
            return time.time() - self.training.start_time
        return None

    @property
    def hyperparameters_for_state_sync(self) -> Optional[Dict]:
        """Used in sync_confusion_matrix and send_status to provide information about the training configuration.

        Returns None if the training has no hyperparameters.
        """
        if self._training and self._training.data and self._training.data.hyperparameter:
            hyperparameter = self._training.data.hyperparameter
            return {
                'resolution': hyperparameter.resolution,
                'flipRl': hyperparameter.flip_rl,
                'flipUd': hyperparameter.flip_ud,
            }
        return None

    @property
    def general_progress(self) -> Optional[float]:
        """Represents the progress for different states, should run from 0 to 100 for each state.

        Note that training_progress and detection_progress need to be implemented in the specific trainer.
        Returns None if no training is active or the current state has no progress information.
        """
        if not self.training_active:
            return None

        t_state = self.training.training_state
        if t_state == TrainerState.DataDownloading:
            return self.node.data_exchanger.progress
        if t_state == TrainerState.TrainingRunning:
            return self.training_progress
        if t_state == TrainerState.Detecting:
            return self.detection_progress

        return None

    # ---------------------------------------- ABSTRACT PROPERTIES ----------------------------------------

    @property
    @abstractmethod
    def training_progress(self) -> Optional[float]:
        """Represents the training progress."""
        raise NotImplementedError

    @property
    @abstractmethod
    def detection_progress(self) -> Optional[float]:
        """Represents the detection progress."""
        raise NotImplementedError

    @property
    @abstractmethod
    def model_architecture(self) -> Optional[str]:
        """Returns the architecture name of the model if available."""
        raise NotImplementedError

    @property
    @abstractmethod
    def provided_pretrained_models(self) -> List[PretrainedModel]:
        """Returns the list of provided pretrained models.

        The names of the models will come back as model_uuid_or_name in the training details.
        """
        raise NotImplementedError

    # ---------------------------------------- METHODS ----------------------------------------

    # NOTE: Trainings are started by the Learning Loop via the begin_training event
    # or by the trainer itself via try_continue_run_if_incomplete.
    # The trainer will then initialize a new training object and start the training loop.
    # Initializing a new training object will create the folder structure for the training.
    # The training loop will then run through the states of the training.

    async def try_continue_run_if_incomplete(self) -> bool:
        """Tries to continue a training if the last training was not finished.

        Returns True if an incomplete training was found and its run was scheduled, False otherwise.
        """
        if not self.training_active and self.last_training_io.exists():
            self._init_from_last_training()
            logging.info('found incomplete training, continuing now.')
            asyncio.get_running_loop().create_task(self._run())
            return True
        return False

    def _init_from_last_training(self) -> None:
        """Initializes a new training object from the last training saved on disk via last_training_io."""
        self._training = self.last_training_io.load()
        assert self._training is not None and self._training.training_folder is not None, 'could not restore training folder'
        self._active_training_io = ActiveTrainingIO(
            self._training.training_folder, self.node.loop_communicator, self._training.context)

    async def begin_training(self, organization: str, project: str, details: Dict) -> None:
        """Called on the `begin_training` event from the Learning Loop.

        Initializes a new training and schedules its run without awaiting it.
        """
        self._init_new_training(Context(organization=organization, project=project), details)
        asyncio.get_running_loop().create_task(self._run())

    def _init_new_training(self, context: Context, details: Dict) -> None:
        """Creates the project folder and a new Training object from `details`.

        Note that details needs the entries 'categories' and 'training_number',
        but also the hyperparameter entries.
        Unless the environment says to keep old trainings, all existing training folders
        of the project are deleted first.
        """
        project_folder = create_project_folder(context)
        if not self._environment_vars.keep_old_trainings:
            delete_all_training_folders(project_folder)
        self._training = generate_training(project_folder, context)
        self._training.set_values_from_data(details)

        self._active_training_io = ActiveTrainingIO(
            self._training.training_folder, self.node.loop_communicator, context)
        logging.info(f'new training initialized: {self._training}')

    async def _run(self) -> None:
        """Runs the training loop as a cancellable task.

        Started either via `begin_training` or `try_continue_run_if_incomplete`.
        If the task is cancelled without the shutdown event being set, the training is marked
        ReadyForCleanup, saved and cleared. If the shutdown event is set, the saved state is kept
        so the training can be continued later. Any other exception is logged.
        """
        self.errors.reset_all()
        try:
            self.training_task = asyncio.get_running_loop().create_task(self._training_loop())
            await self.training_task  # NOTE: Task object is used to potentially cancel the task
        except asyncio.CancelledError:
            if not self.shutdown_event.is_set():
                logging.info('training task was cancelled but not by shutdown event')
                self.training.training_state = TrainerState.ReadyForCleanup
                self.last_training_io.save(self.training)
                await self._clear_training()
        except Exception as e:
            logging.exception(f'Error in train: {e}')

    # ---------------------------------------- TRAINING STATES ----------------------------------------

    async def _training_loop(self) -> None:
        """Cycle through the training states until the training is finished or
        an asyncio.CancelledError is raised.

        Each iteration dispatches on the current state. A failed step leaves the state unchanged
        (see `_perform_state`), so it is retried in the next iteration.
        """
        assert self.training_active

        while self._training is not None:
            tstate = self.training.training_state
            await asyncio.sleep(0.6)  # Note: Required for pytests!

            if tstate == TrainerState.Initialized:  # -> DataDownloading -> DataDownloaded
                await self._perform_state('prepare', TrainerState.DataDownloading,
                                          TrainerState.DataDownloaded, self._prepare)
            elif tstate == TrainerState.DataDownloaded:  # -> TrainModelDownloading -> TrainModelDownloaded
                await self._perform_state('download_model', TrainerState.TrainModelDownloading,
                                          TrainerState.TrainModelDownloaded, self._download_model)
            elif tstate == TrainerState.TrainModelDownloaded:  # -> TrainingRunning -> TrainingFinished
                await self._perform_state('run_training', TrainerState.TrainingRunning,
                                          TrainerState.TrainingFinished, self._train)
            elif tstate == TrainerState.TrainingFinished:  # -> ConfusionMatrixSyncing -> ConfusionMatrixSynced
                await self._perform_state('sync_confusion_matrix', TrainerState.ConfusionMatrixSyncing,
                                          TrainerState.ConfusionMatrixSynced, self._sync_confusion_matrix)
            elif tstate == TrainerState.ConfusionMatrixSynced:  # -> TrainModelUploading -> TrainModelUploaded
                await self._perform_state('upload_model', TrainerState.TrainModelUploading,
                                          TrainerState.TrainModelUploaded, self._upload_model)
            elif tstate == TrainerState.TrainModelUploaded:  # -> Detecting -> Detected
                await self._perform_state('detecting', TrainerState.Detecting,
                                          TrainerState.Detected, self._do_detections)
            elif tstate == TrainerState.Detected:  # -> DetectionUploading -> ReadyForCleanup
                await self._perform_state('upload_detections', TrainerState.DetectionUploading,
                                          TrainerState.ReadyForCleanup, self.active_training_io.upload_detetions)
            elif tstate == TrainerState.ReadyForCleanup:  # -> RESTART or TrainingFinished
                await self._clear_training()
                self._may_restart()

    async def _perform_state(self, error_key: str, state_during: TrainerState, state_after: TrainerState, action: Callable[[], Coroutine], reset_early=False):
        """Runs `action` as one step of the training state machine.

        While `action` runs, the training state is `state_during`.
        - On success, the error stored under `error_key` is reset and the state advances to `state_after`.
          If `action` returns a truthy value, the state advances to ReadyForCleanup instead.
        - On an exception, the error is stored under `error_key` and the previous state is restored,
          so `_training_loop` retries this step.
        - asyncio.CancelledError is re-raised without saving.
        In the non-cancelled cases the training is saved via last_training_io afterwards.
        If `reset_early` is True, the error is reset before `action` runs instead of after its success.
        """
        await asyncio.sleep(0.1)
        logging.info(f'Performing state: {state_during}')
        previous_state = self.training.training_state
        self.training.training_state = state_during
        await asyncio.sleep(0.1)
        if reset_early:
            self.errors.reset(error_key)

        try:
            if await action():
                logging.error('Something went really bad.. cleaning up')
                state_after = TrainerState.ReadyForCleanup
        except asyncio.CancelledError:
            logging.warning(f'CancelledError in {state_during}')
            raise
        except Exception as e:
            self.errors.set(error_key, str(e))
            logging.exception(f'Error in {state_during} - Exception:')
            self.training.training_state = previous_state
        else:
            if not reset_early:
                self.errors.reset(error_key)
            self.training.training_state = state_after
        self.last_training_io.save(self.training)

    async def _prepare(self) -> None:
        """Downloads images to the images_folder and saves annotations to training.data.image_data."""
        self.node.data_exchanger.set_context(self.training.context)
        downloader = TrainingsDownloader(self.node.data_exchanger)
        image_data, skipped_image_count = await downloader.download_training_data(self.training.images_folder)
        assert self.training.data is not None, 'training.data must be set'
        self.training.data.image_data = image_data
        self.training.data.skipped_image_count = skipped_image_count

    async def _download_model(self) -> None:
        """If training is continued, the model is downloaded from the Learning Loop to the training_folder.

        The downloaded model.json file is renamed to base_model.json because a new model.json
        will be created during training. The download is skipped if the base model is not given as a uuid4.
        """
        base_model_uuid = self.training.base_model_uuid_or_name

        # TODO this checks if we continue a training -> make more explicit
        if not base_model_uuid or not is_valid_uuid4(base_model_uuid):
            logging.info(f'skipping model download. No base model provided (in form of uuid): {base_model_uuid}')
            return

        logging.info('loading model from Learning Loop')
        logging.info(f'downloading model {base_model_uuid} as {self.model_format}')
        await self.node.data_exchanger.download_model(self.training.training_folder, self.training.context, base_model_uuid, self.model_format)
        shutil.move(f'{self.training.training_folder}/model.json',
                    f'{self.training.training_folder}/base_model.json')

    async def _sync_confusion_matrix(self) -> None:
        """Synchronizes the confusion matrix with the Learning Loop via the update_training endpoint.

        NOTE: This stage sets the errors explicitly because it may be used inside the training stage.
        Does nothing if `_get_new_best_training_state` returns no new best model.
        Raises if the Learning Loop does not confirm the update with a truthy 'success' entry.
        """
        error_key = 'sync_confusion_matrix'
        try:
            new_best_model = self._get_new_best_training_state()
            if new_best_model and self.training.data:
                new_training = TrainingOut(trainer_id=self.node.uuid,
                                           confusion_matrix=new_best_model.confusion_matrix,
                                           train_image_count=self.training.data.train_image_count(),
                                           test_image_count=self.training.data.test_image_count(),
                                           hyperparameters=self.hyperparameters_for_state_sync)
                await asyncio.sleep(0.1)  # NOTE needed for tests.

                result = await self.node.sio_client.call('update_training', (
                    self.training.context.organization, self.training.context.project, jsonable_encoder(new_training)))
                # .get avoids a bare KeyError when the response lacks 'success'
                if isinstance(result, dict) and result.get('success'):
                    logging.info(f'successfully updated training {asdict(new_training)}')
                    self._on_metrics_published(new_best_model)
                else:
                    raise Exception(f'Error for update_training: Response from loop was : {result}')
        except Exception as e:
            logging.exception('Error during confusion matrix syncronization')
            self.errors.set(error_key, str(e))
            raise
        self.errors.reset(error_key)

    async def _upload_model(self) -> bool:
        """Uploads the latest model to the Learning Loop.

        Returns True if no model could be uploaded. `_perform_state` treats a truthy return value
        as "clean up", so the training moves to ReadyForCleanup instead of continuing with detections.
        On success the new model uuid is stored in training.model_uuid_for_detecting and False is returned.
        """
        new_model_uuid = await self._upload_model_return_new_model_uuid(self.training.context)
        if new_model_uuid is None:
            logging.error('could not upload model - maybe training failed.. cleaning up')
            self.training.training_state = TrainerState.ReadyForCleanup
            return True
        logging.info(f'Successfully uploaded model and received new model id: {new_model_uuid}')
        self.training.model_uuid_for_detecting = new_model_uuid
        return False

    async def _upload_model_return_new_model_uuid(self, context: Context) -> Optional[str]:
        """Upload model files, usually pytorch model (.pt) hyp.yaml and the converted .wts file.

        Note that with the latest trainers the conversion to (.wts) is done by the trainer.
        The conversion from .wts to .engine is done by the detector (needs to be done on target hardware).
        Note that trainer may train with different classes, which is why we send an initial model.json file.

        Formats that were already uploaded (tracked via active_training_io) are skipped.
        Returns the model uuid returned for the last uploaded format, or None if there was nothing
        to upload or an upload did not return a uuid.
        """
        files = await self._get_latest_model_files()
        if not files:
            # None or an empty collection: per the contract of _get_latest_model_files, skip the upload.
            return None

        if isinstance(files, list):
            files = {self.model_format: files}
        assert isinstance(files, dict), f'can only upload model as list or dict, but was {files}'

        already_uploaded_formats = self.active_training_io.load_model_upload_progress()

        # NOTE(review): if every format was already uploaded in a previous run, no uuid is available here
        # and None is returned; the uuid itself is not persisted with the upload progress.
        model_uuid = None
        for file_format in [f for f in files if f not in already_uploaded_formats]:
            _files = files[file_format] + [self._dump_categories_to_json()]
            assert len([f for f in _files if 'model.json' in f]) == 1, "model.json must be included exactly once"

            model_uuid = await self.node.data_exchanger.upload_model_get_uuid(context, _files, self.training.training_number, file_format)
            if model_uuid is None:
                return None

            already_uploaded_formats.append(file_format)
            self.active_training_io.save_model_upload_progress(already_uploaded_formats)

        return model_uuid

    def _dump_categories_to_json(self) -> str:
        """Dumps the categories to a json file and returns the path to the file.

        The file content is `null` if there is no training data.
        """
        content = {'categories': [asdict(c) for c in self.training_data.categories], } if self.training_data else None
        # NOTE(review): fixed path; several trainers on the same host would overwrite each other's file.
        json_path = '/tmp/model.json'
        with open(json_path, 'w') as f:
            json.dump(content, f)
        return json_path

    async def _clear_training(self):
        """Clears the training data after a training has finished.

        Deletes detection files and upload progress, calls the trainer-specific `_clear_training_data`,
        deletes the saved last training, sends a status update and finally drops the Training object.
        """
        self.active_training_io.delete_detections()
        self.active_training_io.delete_detection_upload_progress()
        self.active_training_io.delete_detections_upload_file_index()
        await self._clear_training_data(self.training.training_folder)
        self.last_training_io.delete()

        await self.node.send_status()
        self._training = None

    # ---------------------------------------- OTHER METHODS ----------------------------------------

    async def on_shutdown(self) -> None:
        """Sets the shutdown event and stops a running training.

        Because the shutdown event is set, `_run` keeps the saved training so it can be continued later.
        """
        self.shutdown_event.set()
        await self.stop()

    async def stop(self):
        """Stops the training process by cancelling the training task.

        Does nothing if no training is active. Afterwards `_may_restart` is called.
        """
        if not self.training_active:
            return
        if self.training_task:
            logging.info('cancelling training task')
            if self.training_task.cancel():
                try:
                    await self.training_task
                except asyncio.CancelledError:
                    pass
                logging.info('cancelled training task')
        self._may_restart()

    def _may_restart(self) -> None:
        """If the environment variable RESTART_AFTER_TRAINING is set, the trainer will restart after a training.

        The restart is done by exiting the process (sys.exit(0)).
        """
        if self._environment_vars.restart_after_training:
            logging.info('restarting')
            sys.exit(0)
        else:
            logging.info('not restarting')

    # ---------------------------------------- ABSTRACT METHODS ----------------------------------------

    @abstractmethod
    async def _train(self) -> None:
        """Should be used to execute a training.

        At this point, images are already downloaded to the images_folder and annotations are saved in training.data.image_data.
        If a training is continued, the model is already downloaded.
        The model should be synchronized with the Learning Loop via self._sync_confusion_matrix() every now and then.
        asyncio.CancelledError should be caught and re-raised.
        """
        raise NotImplementedError

    @abstractmethod
    async def _do_detections(self) -> None:
        """Should be used to infer detections of all images and save them to drive.

        active_training_io.save_detections(...) should be used to store the detections.
        asyncio.CancelledError should be caught and re-raised.
        """
        raise NotImplementedError

    @abstractmethod
    def _get_new_best_training_state(self) -> Optional[TrainingStateData]:
        """Is called frequently by `_sync_confusion_matrix` to check if a new "best" model is available.

        Returns None if no new model could be found. Otherwise TrainingStateData(confusion_matrix, meta_information).
        `confusion_matrix` contains a dict of all classes:
            - The classes must be identified by their uuid, not their name.
            - For each class a dict with tp, fp, fn is provided (true positives, false positives, false negatives).
        `meta_information` can hold any data which is helpful for self._on_metrics_published to store weight files etc.
        for later upload via self._get_latest_model_files.
        """
        raise NotImplementedError

    @abstractmethod
    def _on_metrics_published(self, training_state_data: TrainingStateData) -> None:
        """Called after the metrics corresponding to TrainingStateData have been successfully sent to the Learning Loop.

        Receives the TrainingStateData object which was returned by self._get_new_best_training_state.
        If that function returns None, this function is not called.
        The respective files for this model should be stored so they can later be uploaded in _get_latest_model_files.
        """
        raise NotImplementedError

    @abstractmethod
    async def _get_latest_model_files(self) -> Dict[str, List[str]]:
        """Called when the Learning Loop requests to backup the latest model for the training.

        This function is used to __generate and gather__ all files needed for transferring the actual data from the trainer node to the Learning Loop.
        In the simplest implementation this method just renames the weight file (e.g. stored in TrainingStateData.meta_information) into a file name like latest_published_model.

        The function should return a list of file paths which describe the model per format.
        These files must contain all data necessary for the trainer to resume a training (e.g. weight file, hyperparameters, etc.)
        and will be stored in the Learning Loop under the format of this trainer.
        Note: by convention the weightfile should be named "model.<extension>" where extension is the file format of the weightfile.
        For example "model.pt" for pytorch or "model.weights" for darknet/yolo.

        If a trainer can also generate other formats (for example for a detector),
        a dictionary mapping format -> list of files can be returned.

        If the function returns None or an empty list/dict, something went wrong and the model upload will be skipped.
        """
        raise NotImplementedError

    @abstractmethod
    async def _clear_training_data(self, training_folder: str) -> None:
        """Called after a training has finished. Deletes all data that is not needed anymore after a training run.

        This can be old weightfiles or any additional files.
        """
        raise NotImplementedError