learning-loop-node 0.9.3-py3-none-any.whl → 0.10.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of learning-loop-node might be problematic.
- learning_loop_node/__init__.py +2 -3
- learning_loop_node/annotation/annotator_logic.py +2 -2
- learning_loop_node/annotation/annotator_node.py +16 -15
- learning_loop_node/data_classes/__init__.py +17 -10
- learning_loop_node/data_classes/detections.py +7 -2
- learning_loop_node/data_classes/general.py +4 -5
- learning_loop_node/data_classes/training.py +49 -21
- learning_loop_node/data_exchanger.py +85 -139
- learning_loop_node/detector/__init__.py +0 -1
- learning_loop_node/detector/detector_node.py +10 -13
- learning_loop_node/detector/inbox_filter/cam_observation_history.py +4 -7
- learning_loop_node/detector/outbox.py +0 -1
- learning_loop_node/detector/rest/about.py +1 -0
- learning_loop_node/detector/tests/conftest.py +0 -1
- learning_loop_node/detector/tests/test_client_communication.py +5 -3
- learning_loop_node/detector/tests/test_outbox.py +2 -0
- learning_loop_node/detector/tests/testing_detector.py +1 -8
- learning_loop_node/globals.py +2 -2
- learning_loop_node/helpers/gdrive_downloader.py +1 -1
- learning_loop_node/helpers/misc.py +124 -17
- learning_loop_node/loop_communication.py +57 -25
- learning_loop_node/node.py +62 -135
- learning_loop_node/tests/test_downloader.py +8 -7
- learning_loop_node/tests/test_executor.py +14 -11
- learning_loop_node/tests/test_helper.py +3 -5
- learning_loop_node/trainer/downloader.py +1 -1
- learning_loop_node/trainer/executor.py +87 -83
- learning_loop_node/trainer/io_helpers.py +66 -9
- learning_loop_node/trainer/rest/backdoor_controls.py +10 -5
- learning_loop_node/trainer/rest/controls.py +3 -1
- learning_loop_node/trainer/tests/conftest.py +19 -28
- learning_loop_node/trainer/tests/states/test_state_cleanup.py +5 -3
- learning_loop_node/trainer/tests/states/test_state_detecting.py +23 -20
- learning_loop_node/trainer/tests/states/test_state_download_train_model.py +18 -12
- learning_loop_node/trainer/tests/states/test_state_prepare.py +13 -12
- learning_loop_node/trainer/tests/states/test_state_sync_confusion_matrix.py +21 -18
- learning_loop_node/trainer/tests/states/test_state_train.py +27 -28
- learning_loop_node/trainer/tests/states/test_state_upload_detections.py +34 -32
- learning_loop_node/trainer/tests/states/test_state_upload_model.py +22 -20
- learning_loop_node/trainer/tests/test_errors.py +20 -12
- learning_loop_node/trainer/tests/test_trainer_states.py +4 -5
- learning_loop_node/trainer/tests/testing_trainer_logic.py +25 -30
- learning_loop_node/trainer/trainer_logic.py +80 -590
- learning_loop_node/trainer/trainer_logic_generic.py +495 -0
- learning_loop_node/trainer/trainer_node.py +27 -106
- {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/METADATA +1 -1
- learning_loop_node-0.10.0.dist-info/RECORD +85 -0
- learning_loop_node/converter/converter_logic.py +0 -68
- learning_loop_node/converter/converter_node.py +0 -125
- learning_loop_node/converter/tests/test_converter.py +0 -55
- learning_loop_node/trainer/training_syncronizer.py +0 -52
- learning_loop_node-0.9.3.dist-info/RECORD +0 -88
- /learning_loop_node/{converter/__init__.py → py.typed} +0 -0
- {learning_loop_node-0.9.3.dist-info → learning_loop_node-0.10.0.dist-info}/WHEEL +0 -0
--- /dev/null
+++ b/learning_loop_node/trainer/trainer_logic_generic.py
@@ -0,0 +1,495 @@
+import asyncio
+import json
+import logging
+import shutil
+import sys
+import time
+from abc import ABC, abstractmethod
+from dataclasses import asdict
+from typing import TYPE_CHECKING, Callable, Coroutine, Dict, List, Optional
+
+from fastapi.encoders import jsonable_encoder
+
+from ..data_classes import (Context, Errors, Hyperparameter, PretrainedModel, TrainerState, Training, TrainingData,
+                            TrainingOut, TrainingStateData)
+from ..helpers.misc import create_project_folder, delete_all_training_folders, generate_training, is_valid_uuid4
+from .downloader import TrainingsDownloader
+from .io_helpers import ActiveTrainingIO, EnvironmentVars, LastTrainingIO
+
+if TYPE_CHECKING:
+    from .trainer_node import TrainerNode
+
+
+class TrainerLogicGeneric(ABC):
+
+    def __init__(self, model_format: str) -> None:
+
+        # NOTE: model_format is used in the file path for the model on the server:
+        # It acts as a key for list of files (cf. _get_latest_model_files)
+        # '/{context.organization}/projects/{context.project}/models/{model_id}/{model_format}/file'
+        self.model_format: str = model_format
+        self.errors = Errors()
+
+        self.training_task: Optional[asyncio.Task] = None
+        self.shutdown_event: asyncio.Event = asyncio.Event()
+
+        self._node: Optional['TrainerNode'] = None  # type: ignore
+        self._last_training_io: Optional[LastTrainingIO] = None  # type: ignore
+
+        self._training: Optional[Training] = None
+        self._active_training_io: Optional[ActiveTrainingIO] = None
+        self._environment_vars = EnvironmentVars()
+
+    # ---------------------------------------- PROPERTIES TO AVOID CHECKING FOR NONE ----------------------------------------
+
+    @property
+    def node(self) -> 'TrainerNode':
+        assert self._node is not None, 'node should be set by TrainerNode before initialization'
+        return self._node
+
+    @property
+    def last_training_io(self) -> LastTrainingIO:
+        assert self._last_training_io is not None, 'last_training_io should be set by TrainerNode before initialization'
+        return self._last_training_io
+
+    @property
+    def active_training_io(self) -> ActiveTrainingIO:
+        assert self._active_training_io is not None, 'active_training_io must be set, call `init` first'
+        return self._active_training_io
+
+    @property
+    def training(self) -> Training:
+        assert self._training is not None, 'training must be initialized, call `init` first'
+        return self._training
+
+    @property
+    def hyperparameter(self) -> Hyperparameter:
+        assert self.training_data is not None, 'Training should have data'
+        assert self.training_data.hyperparameter is not None, 'Training.data should have hyperparameter'
+        return self.training_data.hyperparameter
+
+    # ---------------------------------------- PROPERTIES ----------------------------------------
+
+    @property
+    def training_data(self) -> Optional[TrainingData]:
+        if self.training_active and self.training.data:
+            return self.training.data
+        return None
+
+    @property
+    def training_context(self) -> Optional[Context]:
+        if self.training_active:
+            return self.training.context
+        return None
+
+    @property
+    def training_active(self) -> bool:
+        """_training and _active_training_io are set in 'init_new_training' or 'init_from_last_training'.
+        """
+        return self._training is not None and self._active_training_io is not None
+
+    @property
+    def state(self) -> str:
+        """Returns the current state of the training. Used solely by the node in send_status().
+        """
+        if (not self.training_active) or (self.training.training_state is None):
+            return TrainerState.Idle.value
+        return self.training.training_state
+
+    @property
+    def training_uptime(self) -> Optional[float]:
+        """Livetime of current Training object. Start time is set during initialization of Training object.
+        """
+        if self.training_active:
+            return time.time() - self.training.start_time
+        return None
+
+    @property
+    def hyperparameters_for_state_sync(self) -> Optional[Dict]:
+        """Used in sync_confusion_matrix and send_status to provide information about the training configuration.
+        """
+        if self._training and self._training.data and self._training.data.hyperparameter:
+            information = {}
+            information['resolution'] = self._training.data.hyperparameter.resolution
+            information['flipRl'] = self._training.data.hyperparameter.flip_rl
+            information['flipUd'] = self._training.data.hyperparameter.flip_ud
+            return information
+        return None
+
+    @property
+    def general_progress(self) -> Optional[float]:
+        """Represents the progress for different states, should run from 0 to 100 for each state.
+        Note that training_progress and detection_progress need to be implemented in the specific trainer.
+        """
+        if not self.training_active:
+            return None
+
+        t_state = self.training.training_state
+        if t_state == TrainerState.DataDownloading:
+            return self.node.data_exchanger.progress
+        if t_state == TrainerState.TrainingRunning:
+            return self.training_progress
+        if t_state == TrainerState.Detecting:
+            return self.detection_progress
+
+        return None
+
+    # ---------------------------------------- ABSTRACT PROPERTIES ----------------------------------------
+
+    @property
+    @abstractmethod
+    def training_progress(self) -> Optional[float]:
+        """Represents the training progress."""
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def detection_progress(self) -> Optional[float]:
+        """Represents the detection progress."""
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def model_architecture(self) -> Optional[str]:
+        """Returns the architecture name of the model if available"""
+        raise NotImplementedError
+
+    @property
+    @abstractmethod
+    def provided_pretrained_models(self) -> List[PretrainedModel]:
+        """Returns the list of provided pretrained models.
+        The names of the models will come back as model_uuid_or_name in the training details.
+        """
+        raise NotImplementedError
+
+    # ---------------------------------------- METHODS ----------------------------------------
+
+    # NOTE: Trainings are started by the Learning Loop via the begin_training event
+    # or by the trainer itself via try_continue_run_if_incomplete.
+    # The trainer will then initialize a new training object and start the training loop.
+    # Initializing a new training object will create the folder structure for the training.
+    # The training loop will then run through the states of the training.
+
+    async def try_continue_run_if_incomplete(self) -> bool:
+        """Tries to continue a training if the last training was not finished.
+        """
+        if not self.training_active and self.last_training_io.exists():
+            self._init_from_last_training()
+            logging.info('found incomplete training, continuing now.')
+            asyncio.get_event_loop().create_task(self._run())
+            return True
+        return False
+
+    def _init_from_last_training(self) -> None:
+        """Initializes a new training object from the last training saved on disc via last_training_io.
+        """
+        self._training = self.last_training_io.load()
+        assert self._training is not None and self._training.training_folder is not None, 'could not restore training folder'
+        self._active_training_io = ActiveTrainingIO(
+            self._training.training_folder, self.node.loop_communicator, self._training.context)
+
+    async def begin_training(self, organization: str, project: str, details: Dict) -> None:
+        """Called on `begin_training` event from the Learning Loop.
+        """
+        self._init_new_training(Context(organization=organization, project=project), details)
+        asyncio.get_event_loop().create_task(self._run())
+
+    def _init_new_training(self, context: Context, details: Dict) -> None:
+        """Called on `begin_training` event from the Learning Loop.
+        Note that details needs the entries 'categories' and 'training_number',
+        but also the hyperparameter entries.
+        """
+        project_folder = create_project_folder(context)
+        if not self._environment_vars.keep_old_trainings:
+            delete_all_training_folders(project_folder)
+        self._training = generate_training(project_folder, context)
+        self._training.set_values_from_data(details)
+
+        self._active_training_io = ActiveTrainingIO(
+            self._training.training_folder, self.node.loop_communicator, context)
+        logging.info(f'new training initialized: {self._training}')
+
+    async def _run(self) -> None:
+        """Called on `begin_training` event from the Learning Loop.
+        Either via `begin_training` or `try_continue_run_if_incomplete`.
+        """
+        self.errors.reset_all()
+        try:
+            self.training_task = asyncio.get_running_loop().create_task(self._training_loop())
+            await self.training_task  # NOTE: Task object is used to potentially cancel the task
+        except asyncio.CancelledError:
+            if not self.shutdown_event.is_set():
+                logging.info('training task was cancelled but not by shutdown event')
+                self.training.training_state = TrainerState.ReadyForCleanup
+                self.last_training_io.save(self.training)
+                await self._clear_training()
+        except Exception as e:
+            logging.exception(f'Error in train: {e}')
+
+    # ---------------------------------------- TRAINING STATES ----------------------------------------
+
+    async def _training_loop(self) -> None:
+        """Cycle through the training states until the training is finished or
+        an asyncio.CancelledError is raised.
+        """
+        assert self.training_active
+
+        while self._training is not None:
+            tstate = self.training.training_state
+            await asyncio.sleep(0.6)  # Note: Required for pytests!
+
+            if tstate == TrainerState.Initialized:  # -> DataDownloading -> DataDownloaded
+                await self._perform_state('prepare', TrainerState.DataDownloading, TrainerState.DataDownloaded, self._prepare)
+            elif tstate == TrainerState.DataDownloaded:  # -> TrainModelDownloading -> TrainModelDownloaded
+                await self._perform_state('download_model', TrainerState.TrainModelDownloading, TrainerState.TrainModelDownloaded, self._download_model)
+            elif tstate == TrainerState.TrainModelDownloaded:  # -> TrainingRunning -> TrainingFinished
+                await self._perform_state('run_training', TrainerState.TrainingRunning, TrainerState.TrainingFinished, self._train)
+            elif tstate == TrainerState.TrainingFinished:  # -> ConfusionMatrixSyncing -> ConfusionMatrixSynced
+                await self._perform_state('sync_confusion_matrix', TrainerState.ConfusionMatrixSyncing, TrainerState.ConfusionMatrixSynced, self._sync_confusion_matrix)
+            elif tstate == TrainerState.ConfusionMatrixSynced:  # -> TrainModelUploading -> TrainModelUploaded
+                await self._perform_state('upload_model', TrainerState.TrainModelUploading, TrainerState.TrainModelUploaded, self._upload_model)
+            elif tstate == TrainerState.TrainModelUploaded:  # -> Detecting -> Detected
+                await self._perform_state('detecting', TrainerState.Detecting, TrainerState.Detected, self._do_detections)
+            elif tstate == TrainerState.Detected:  # -> DetectionUploading -> ReadyForCleanup
+                await self._perform_state('upload_detections', TrainerState.DetectionUploading, TrainerState.ReadyForCleanup, self.active_training_io.upload_detetions)
+            elif tstate == TrainerState.ReadyForCleanup:  # -> RESTART or TrainingFinished
+                await self._clear_training()
+                self._may_restart()
+
+    async def _perform_state(self, error_key: str, state_during: TrainerState, state_after: TrainerState, action: Callable[[], Coroutine], reset_early=False):
+        await asyncio.sleep(0.1)
+        logging.info(f'Performing state: {state_during}')
+        previous_state = self.training.training_state
+        self.training.training_state = state_during
+        await asyncio.sleep(0.1)
+        if reset_early:
+            self.errors.reset(error_key)
+
+        try:
+            if await action():
+                logging.error('Something went really bad.. cleaning up')
+                state_after = TrainerState.ReadyForCleanup
+        except asyncio.CancelledError:
+            logging.warning(f'CancelledError in {state_during}')
+            raise
+        except Exception as e:
+            self.errors.set(error_key, str(e))
+            logging.exception(f'Error in {state_during} - Exception:')
+            self.training.training_state = previous_state
+        else:
+            if not reset_early:
+                self.errors.reset(error_key)
+            self.training.training_state = state_after
+            self.last_training_io.save(self.training)
+
+    async def _prepare(self) -> None:
+        """Downloads images to the images_folder and saves annotations to training.data.image_data.
+        """
+        self.node.data_exchanger.set_context(self.training.context)
+        downloader = TrainingsDownloader(self.node.data_exchanger)
+        image_data, skipped_image_count = await downloader.download_training_data(self.training.images_folder)
+        assert self.training.data is not None, 'training.data must be set'
+        self.training.data.image_data = image_data
+        self.training.data.skipped_image_count = skipped_image_count
+
+    async def _download_model(self) -> None:
+        """If training is continued, the model is downloaded from the Learning Loop to the training_folder.
+        The downloaded model.json file is renamed to base_model.json because a new model.json will be created during training.
+        """
+        base_model_uuid = self.training.base_model_uuid_or_name
+
+        # TODO this checks if we continue a training -> make more explicit
+        if not base_model_uuid or not is_valid_uuid4(base_model_uuid):
+            logging.info(f'skipping model download. No base model provided (in form of uuid): {base_model_uuid}')
+            return
+
+        logging.info('loading model from Learning Loop')
+        logging.info(f'downloading model {base_model_uuid} as {self.model_format}')
+        await self.node.data_exchanger.download_model(self.training.training_folder, self.training.context, base_model_uuid, self.model_format)
+        shutil.move(f'{self.training.training_folder}/model.json',
+                    f'{self.training.training_folder}/base_model.json')
+
+    async def _sync_confusion_matrix(self) -> None:
+        """Syncronizes the confusion matrix with the Learning Loop via the update_training endpoint.
+        NOTE: This stage sets the errors explicitly because it may be used inside the training stage.
+        """
+        error_key = 'sync_confusion_matrix'
+        try:
+            new_best_model = self._get_new_best_training_state()
+            if new_best_model and self.training.data:
+                new_training = TrainingOut(trainer_id=self.node.uuid,
+                                           confusion_matrix=new_best_model.confusion_matrix,
+                                           train_image_count=self.training.data.train_image_count(),
+                                           test_image_count=self.training.data.test_image_count(),
+                                           hyperparameters=self.hyperparameters_for_state_sync)
+                await asyncio.sleep(0.1)  # NOTE needed for tests.
+
+                result = await self.node.sio_client.call('update_training', (
+                    self.training.context.organization, self.training.context.project, jsonable_encoder(new_training)))
+                if isinstance(result, dict) and result['success']:
+                    logging.info(f'successfully updated training {asdict(new_training)}')
+                    self._on_metrics_published(new_best_model)
+                else:
+                    raise Exception(f'Error for update_training: Response from loop was : {result}')
+        except Exception as e:
+            logging.exception('Error during confusion matrix syncronization')
+            self.errors.set(error_key, str(e))
+            raise
+        self.errors.reset(error_key)
+
+    async def _upload_model(self) -> None:
+        """Uploads the latest model to the Learning Loop.
+        """
+        new_model_uuid = await self._upload_model_return_new_model_uuid(self.training.context)
+        if new_model_uuid is None:
+            self.training.training_state = TrainerState.ReadyForCleanup
+            logging.error('could not upload model - maybe training failed.. cleaning up')
+        logging.info(f'Successfully uploaded model and received new model id: {new_model_uuid}')
+        self.training.model_uuid_for_detecting = new_model_uuid
+
+    async def _upload_model_return_new_model_uuid(self, context: Context) -> Optional[str]:
+        """Upload model files, usually pytorch model (.pt) hyp.yaml and the converted .wts file.
+        Note that with the latest trainers the conversion to (.wts) is done by the trainer.
+        The conversion from .wts to .engine is done by the detector (needs to be done on target hardware).
+        Note that trainer may train with different classes, which is why we send an initial model.json file."""
+
+        files = await self._get_latest_model_files()
+        if files is None:
+            return None
+
+        if isinstance(files, List):
+            files = {self.model_format: files}
+        assert isinstance(files, Dict), f'can only upload model as list or dict, but was {files}'
+
+        already_uploaded_formats = self.active_training_io.load_model_upload_progress()
+
+        model_uuid = None
+        for file_format in [f for f in files if f not in already_uploaded_formats]:
+            _files = files[file_format] + [self._dump_categories_to_json()]
+            assert len([f for f in _files if 'model.json' in f]) == 1, "model.json must be included exactly once"
+
+            model_uuid = await self.node.data_exchanger.upload_model_get_uuid(context, _files, self.training.training_number, file_format)
+            if model_uuid is None:
+                return None
+
+            already_uploaded_formats.append(file_format)
+            self.active_training_io.save_model_upload_progress(already_uploaded_formats)
+
+        return model_uuid
+
+    def _dump_categories_to_json(self) -> str:
+        """Dumps the categories to a json file and returns the path to the file.
+        """
+        content = {'categories': [asdict(c) for c in self.training_data.categories], } if self.training_data else None
+        json_path = '/tmp/model.json'
+        with open(json_path, 'w') as f:
+            json.dump(content, f)
+        return json_path
+
+    async def _clear_training(self):
+        """Clears the training data after a training has finished.
+        """
+        self.active_training_io.delete_detections()
+        self.active_training_io.delete_detection_upload_progress()
+        self.active_training_io.delete_detections_upload_file_index()
+        await self._clear_training_data(self.training.training_folder)
+        self.last_training_io.delete()
+
+        await self.node.send_status()
+        self._training = None
+
+    # ---------------------------------------- OTHER METHODS ----------------------------------------
+
+    async def on_shutdown(self) -> None:
+        self.shutdown_event.set()
+        await self.stop()
+        await self.stop()
+
+    async def stop(self):
+        """Stops the training process by canceling training task.
+        """
+        if not self.training_active:
+            return
+        if self.training_task:
+            logging.info('cancelling training task')
+            if self.training_task.cancel():
+                try:
+                    await self.training_task
+                except asyncio.CancelledError:
+                    pass
+                logging.info('cancelled training task')
+                self._may_restart()
+
+    def _may_restart(self) -> None:
+        """If the environment variable RESTART_AFTER_TRAINING is set, the trainer will restart after a training.
+        """
+        if self._environment_vars.restart_after_training:
+            logging.info('restarting')
+            sys.exit(0)
+        else:
+            logging.info('not restarting')
+    # ---------------------------------------- ABSTRACT METHODS ----------------------------------------
+
+    @abstractmethod
+    async def _train(self) -> None:
+        """Should be used to execute a training.
+        At this point, images are already downloaded to the images_folder and annotations are saved in training.data.image_data.
+        If a training is continued, the model is already downloaded.
+        The model should be synchronized with the Learning Loop via self._sync_confusion_matrix() every now and then.
+        asyncio.CancelledError should be catched and re-raised.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def _do_detections(self) -> None:
+        """Should be used to infer detections of all images and save them to drive.
+        active_training_io.save_detections(...) should be used to store the detections.
+        asyncio.CancelledError should be catched and re-raised.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _get_new_best_training_state(self) -> Optional[TrainingStateData]:
+        """Is called frequently by `_sync_confusion_matrix` to check if a new "best" model is availabe.
+        Returns None if no new model could be found. Otherwise TrainingStateData(confusion_matrix, meta_information).
+        `confusion_matrix` contains a dict of all classes:
+        - The classes must be identified by their uuid, not their name.
+        - For each class a dict with tp, fp, fn is provided (true positives, false positives, false negatives).
+        `meta_information` can hold any data which is helpful for self._on_metrics_published to store weight file etc for later upload via self.get_model_files
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    def _on_metrics_published(self, training_state_data: TrainingStateData) -> None:
+        """Called after the metrics corresponding to TrainingStateData have been successfully send to the Learning Loop.
+        Receives the TrainingStateData object which was returned by self._get_new_best_training_state.
+        If above function returns None, this function is not called.
+        The respective files for this model should be stored so they can be later uploaded in get_latest_model_files.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def _get_latest_model_files(self) -> Dict[str, List[str]]:
+        """Called when the Learning Loop requests to backup the latest model for the training.
+        This function is used to __generate and gather__ all files needed for transfering the actual data from the trainer node to the Learning Loop.
+        In the simplest implementation this method just renames the weight file (e.g. stored in TrainingStateData.meta_information) into a file name like latest_published_model
+
+        The function should return a list of file paths which describe the model per format.
+        These files must contain all data neccessary for the trainer to resume a training (eg. weight file, hyperparameters, etc.)
+        and will be stored in the Learning Loop unter the format of this trainer.
+        Note: by convention the weightfile should be named "model.<extension>" where extension is the file format of the weightfile.
+        For example "model.pt" for pytorch or "model.weights" for darknet/yolo.
+
+        If a trainer can also generate other formats (for example for an detector),
+        a dictionary mapping format -> list of files can be returned.
+
+        If the function returns an empty dict, something went wrong and the model upload will be skipped.
+        """
+        raise NotImplementedError
+
+    @abstractmethod
+    async def _clear_training_data(self, training_folder: str) -> None:
+        """Called after a training has finished. Deletes all data that is not needed anymore after a training run.
+        This can be old weightfiles or any additional files.
+        """
+        raise NotImplementedError