datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported public registry. It is provided for informational purposes only and reflects the changes between the two versions as published in that registry.
- datamint/__init__.py +1 -3
- datamint/api/__init__.py +0 -3
- datamint/api/base_api.py +286 -54
- datamint/api/client.py +76 -13
- datamint/api/endpoints/__init__.py +2 -2
- datamint/api/endpoints/annotations_api.py +186 -28
- datamint/api/endpoints/deploy_model_api.py +78 -0
- datamint/api/endpoints/models_api.py +1 -0
- datamint/api/endpoints/projects_api.py +38 -7
- datamint/api/endpoints/resources_api.py +227 -100
- datamint/api/entity_base_api.py +66 -7
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/client_cmd_tools/datamint_config.py +0 -1
- datamint/client_cmd_tools/datamint_upload.py +3 -1
- datamint/configs.py +11 -7
- datamint/dataset/base_dataset.py +24 -4
- datamint/dataset/dataset.py +1 -1
- datamint/entities/__init__.py +1 -1
- datamint/entities/annotations/__init__.py +13 -0
- datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
- datamint/entities/annotations/image_classification.py +12 -0
- datamint/entities/annotations/image_segmentation.py +252 -0
- datamint/entities/annotations/volume_segmentation.py +273 -0
- datamint/entities/base_entity.py +100 -6
- datamint/entities/cache_manager.py +129 -15
- datamint/entities/datasetinfo.py +60 -65
- datamint/entities/deployjob.py +18 -0
- datamint/entities/project.py +39 -0
- datamint/entities/resource.py +310 -46
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +65 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +131 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/flavors/__init__.py +17 -0
- datamint/mlflow/flavors/datamint_flavor.py +150 -0
- datamint/mlflow/flavors/model.py +877 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
- datamint/mlflow/models/__init__.py +93 -0
- datamint/mlflow/tracking/datamint_store.py +76 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +91 -0
- datamint/utils/env.py +27 -0
- datamint/utils/visualization.py +21 -13
- datamint-2.9.0.dist-info/METADATA +220 -0
- datamint-2.9.0.dist-info/RECORD +73 -0
- {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
- datamint-2.9.0.dist-info/entry_points.txt +18 -0
- datamint/apihandler/exp_api_handler.py +0 -204
- datamint/experiment/__init__.py +0 -1
- datamint/experiment/_patcher.py +0 -570
- datamint/experiment/experiment.py +0 -1049
- datamint-2.3.3.dist-info/METADATA +0 -125
- datamint-2.3.3.dist-info/RECORD +0 -54
- datamint-2.3.3.dist-info/entry_points.txt +0 -4
datamint/experiment/_patcher.py
DELETED
@@ -1,570 +0,0 @@
from unittest.mock import patch
import importlib
from typing import Sequence, Any, Dict
import logging
from .experiment import Experiment
from torch.utils.data import DataLoader
import torch
import sys
import pandas as pd
import atexit
from collections import OrderedDict, defaultdict
import numpy as np

_LOGGER = logging.getLogger(__name__)

IS_INITIALIZED = False


def _is_iterable(obj):
    try:
        iter(obj)
        return True
    except TypeError:
        return False


class Wrapper:
    class IteratorWrapper:
        def __init__(self, iterator,
                     cb_before_first_next,
                     cb_next_return,
                     cb_after_iter_return,
                     original_func,
                     cb_args,
                     cb_kwargs) -> None:
            self.iterator = iterator
            self.cb_args = cb_args
            self.cb_kwargs = cb_kwargs
            self.original_func = original_func
            self.cb_before_first_next = cb_before_first_next
            self.cb_next_return = cb_next_return
            self.cb_after_iter_return = cb_after_iter_return
            self.first_next = True

        def __iter__(self):
            self.iterator = iter(self.iterator)
            return self

        def __next__(self):
            if self.first_next:
                for cb in self.cb_before_first_next:
                    cb(self.original_func, self.cb_args, self.cb_kwargs, original_iter=self.iterator)
                self.first_next = False
            try:
                return_value = next(self.iterator)
                for cb in self.cb_next_return:
                    cb(self.original_func, self.cb_args, self.cb_kwargs, return_value)
                return return_value
            except StopIteration as e:
                for cb in self.cb_after_iter_return:
                    cb(self.original_func, self.cb_args, self.cb_kwargs, original_iter=self.iterator)
                self.first_next = True
                raise e

        def __getattr__(self, item):
            if item != '__iter__':
                return getattr(self.iterator, item)
            return self.__iter__

        def __len__(self):
            return len(self.iterator)

    def __init__(self,
                 target: str,
                 cb_before: Sequence[callable] | callable = None,
                 cb_after: Sequence[callable] | callable = None,
                 cb_before_first_next: Sequence[callable] | callable = None,
                 cb_next_return: Sequence[callable] | callable = None,
                 cb_after_iter_return: Sequence[callable] | callable = None,
                 ) -> None:
        self.cb_before = cb_before if cb_before is not None else []
        self.cb_after = cb_after if cb_after is not None else []
        self.cb_after_iter_return = cb_after_iter_return if cb_after_iter_return is not None else []
        self.cb_before_first_next = cb_before_first_next if cb_before_first_next is not None else []
        self.cb_next_return = cb_next_return if cb_next_return is not None else []
        if not _is_iterable(self.cb_before):
            self.cb_before = [self.cb_before]
        if not _is_iterable(self.cb_after):
            self.cb_after = [self.cb_after]
        if not _is_iterable(self.cb_after_iter_return):
            self.cb_after_iter_return = [self.cb_after_iter_return]
        if not _is_iterable(self.cb_before_first_next):
            self.cb_before_first_next = [self.cb_before_first_next]
        if not _is_iterable(self.cb_next_return):
            self.cb_next_return = [self.cb_next_return]
        self.target = target
        self._patch()

    def _patch(self):
        def _callback(*args, **kwargs):
            for cb in self.cb_before:
                cb(original, args, kwargs)

            try:
                return_value = original(*args, **kwargs)
                # if return_value is a generator, wrap it
                if len(self.cb_after_iter_return) > 0 and _is_iterable(return_value):
                    return_value = self._wrap_iterator(return_value,
                                                       original,
                                                       args, kwargs)

            except Exception as exception:
                # We are assuming the patched function does not return an exception.
                # return_value = exception
                raise exception

            for cb in self.cb_after:
                cb(original, args, kwargs, return_value)

            if isinstance(return_value, Exception):
                raise return_value

            return return_value

        original = get_function_from_string(self.target)
        # Patch the original function with the callback
        self.patcher = patch(self.target, new=_callback)

    def start(self):
        self.patcher.start()

    def stop(self):
        self.patcher.stop()

    def _wrap_iterator(self, iterator, original_func, args, kwargs):
        return Wrapper.IteratorWrapper(iterator,
                                       cb_before_first_next=self.cb_before_first_next,
                                       cb_next_return=self.cb_next_return,
                                       cb_after_iter_return=self.cb_after_iter_return,
                                       original_func=original_func,
                                       cb_args=args,
                                       cb_kwargs=kwargs)


def get_function_from_string(target: str):
    target_spl = target.split('.')
    for i in range(len(target_spl)):
        module_name = '.'.join(target_spl[:-i-1])
        function_name = '.'.join(target_spl[-i-1:])
        try:
            module = importlib.import_module(module_name)
        except ModuleNotFoundError:
            continue
        break
    else:
        raise ModuleNotFoundError(f"Module {module_name} not found")

    try:
        cur_obj = module
        for objname in function_name.split('.'):
            cur_obj = getattr(cur_obj, objname)
    except AttributeError:
        raise ModuleNotFoundError(f"Module attribute {module_name}.{objname} not found")
    return cur_obj


class PytorchPatcher:
    class DataLoaderInfo:
        def __init__(self, dataloader: DataLoader):
            self.dataloader = dataloader
            self.times_started_iter = 0  # This includes the current iteration
            self.is_iterating = False
            self.metrics = []
            self.iteration_idx = None
            self.predictions = defaultdict(list)  # TODO: save to disk
            self.cur_batch = None

            for obj in [dataloader, dataloader.batch_sampler]:
                if hasattr(obj, 'batch_size'):
                    self.batch_size = obj.batch_size
                    _LOGGER.debug(f"Found batch size {self.batch_size} for dataloader {dataloader}")
                    break
            else:
                self.batch_size = None
                _LOGGER.debug(f"Could not find batch size for dataloader {dataloader}")

        def __str__(self) -> str:
            return f"DataLoaderInfo: {self.dataloader}, number_of_times_iterated: {self.times_started_iter}, " \
                f"is_iterating: {self.is_iterating}"

        def append_metric(self, name: str, value: float, step, epoch):
            self.metrics.append([name, value, step, epoch])

    AUTO_LOSS_LOG_INTERVAL = 20

    def __init__(self) -> None:
        self.dataloaders_info: Dict[Any, PytorchPatcher.DataLoaderInfo] = OrderedDict()
        self.metrics_association = {}  # Associate metrics with dataloaders
        self.last_dataloader = None
        self.exit_with_error = False

    def _dataloader_created(self,
                            original_obj, func_args, func_kwargs,
                            return_value):

        dataloader = func_args[0]
        if dataloader in self.dataloaders_info:
            _LOGGER.warning("Dataloader already exists")
        _LOGGER.debug('Adding a new dataloader')
        self.dataloaders_info[dataloader] = PytorchPatcher.DataLoaderInfo(dataloader)

    def _inc_exp_step(self) -> int:
        exp = Experiment.get_singleton_experiment()
        if exp.cur_step is None:
            exp._set_step(0)
        else:
            exp._set_step(exp.cur_step + 1)

        return exp.cur_step

    def _backward_cb(self,
                     original_obj, func_args, func_kwargs):
        """
        This method is a wrapper for the backward method of the Pytorch Tensor class.
        """
        loss = func_args[0]
        cur_step = self._inc_exp_step()

        if cur_step % PytorchPatcher.AUTO_LOSS_LOG_INTERVAL == 0:
            self._log_metric('loss', loss.item())

    def clf_loss_computed(self,
                          original_obj, func_args, func_kwargs, return_value):
        loss = return_value.detach().cpu()
        # if is not a 0-d tensor, do not log
        if len(loss.shape) != 0:
            return
        loss = loss.item()
        loss_name = original_obj.__name__
        cur_step = self._inc_exp_step()

        dataloader = self.get_last_dataloader()
        if dataloader is not None and self.dataloaders_info[dataloader].is_iterating:
            self.metrics_association[func_args[0]] = dataloader

        if cur_step % PytorchPatcher.AUTO_LOSS_LOG_INTERVAL == 0:
            self._log_metric(loss_name, loss, dataloader=dataloader)

    def _classification_loss_computed(self, preds, targets):
        dataloader = self.get_last_dataloader()
        if dataloader is not None and self.dataloaders_info[dataloader].is_iterating:
            dinfo = self.dataloaders_info[dataloader]
            if not isinstance(dinfo.cur_batch, dict) or 'metainfo' not in dinfo.cur_batch:
                _LOGGER.debug(f"No metainfo in batch")
                return
            batch_metainfo = dinfo.cur_batch['metainfo']
            if 'id' not in batch_metainfo[0]:
                _LOGGER.debug("No id in batch metainfo")
                return
            resources_ids = [b['id'] for b in batch_metainfo]

            if len(resources_ids) != len(preds):
                _LOGGER.debug(f"Number of predictions ({len(preds)}) and targets ({len(targets)}) do not match")
                return

            dinfo.predictions['predictions'].extend(preds)
            dinfo.predictions['id'].extend(resources_ids)
            dinfo.predictions['frame_index'].extend([b['frame_index'] for b in batch_metainfo])
        else:
            _LOGGER.warning("No dataloader found")

    def bce_with_logits_computed(self,
                                 original_obj, func_args, func_kwargs, return_value):
        self.clf_loss_computed(original_obj, func_args, func_kwargs, return_value)
        preds = func_kwargs['input'] if 'input' in func_kwargs else func_args[0]
        targets = func_kwargs['target'] if 'target' in func_kwargs else func_args[1]
        preds = torch.nn.functional.sigmoid(preds).detach().cpu()
        targets = targets.detach().cpu()
        self._classification_loss_computed(preds, targets)

    def ce_computed(self,
                    original_obj, func_args, func_kwargs, return_value):
        self.clf_loss_computed(original_obj, func_args, func_kwargs, return_value)
        preds = func_kwargs['input'] if 'input' in func_kwargs else func_args[0]
        targets = func_kwargs['target'] if 'target' in func_kwargs else func_args[1]
        preds = preds.detach().cpu()
        targets = targets.detach().cpu()
        self._classification_loss_computed(preds, targets)

    def _dataloader_start_iterating_cb(self,
                                       original_obj, func_args, func_kwargs, original_iter):
        exp = Experiment.get_singleton_experiment()
        dataloader = func_args[0]  # self
        dataloader_info = self.dataloaders_info[dataloader]
        dataloader_info.is_iterating = True
        if dataloader_info.iteration_idx is None:
            dataloader_info.iteration_idx = self._get_dataloader_iteration_idx()+1
        exp._set_epoch(dataloader_info.times_started_iter)
        dataloader_info.times_started_iter += 1

        self.last_dataloader = dataloader

        _LOGGER.debug(f'Dataloader is iterating: {dataloader_info}')

    def _dataloader_next(self,
                         original_obj, func_args, func_kwargs, return_value):
        dataloader = func_args[0]
        dataloder_info = self.dataloaders_info[dataloader]
        dataloder_info.cur_batch = return_value

    def _dataloader_stop_iterating_cb(self,
                                      original_obj, func_args, func_kwargs, original_iter):
        dataloader = func_args[0]
        dinfo = self.dataloaders_info[dataloader]
        dinfo.is_iterating = False
        dinfo.cur_batch = None
        # find the dataloader that is still iterating  # FIXME: For 3 dataloaders being iterating
        for dloader, dlinfo in self.dataloaders_info.items():
            if dlinfo.is_iterating:
                self.last_dataloader = dloader
                break
        else:
            _LOGGER.debug("No dataloader is iterating")

        _LOGGER.debug(f'Dataloader stopped iterating: {self.dataloaders_info[dataloader]}')

    def _log_metric(self, name, value,
                    dataloader=None,
                    **kwargs):
        exp = Experiment.get_singleton_experiment()
        if self.finish_callback not in exp.finish_callbacks:
            exp._add_finish_callback(self.finish_callback)

        if dataloader is None:
            dataloader = self.get_last_dataloader()
            dloader_info = self.dataloaders_info[dataloader]
            if not dloader_info.is_iterating:
                dataloader = None
        else:
            dloader_info = self.dataloaders_info[dataloader]

        if dataloader is not None:
            name = f"dataset{dloader_info.iteration_idx+1}/{name}"
            dloader_info.append_metric(name, value, exp.cur_step, exp.cur_epoch)

        _LOGGER.debug(f"Logging metric {name} with value {value}")
        exp.log_metric(name, value, **kwargs)

    def torchmetric_clf_computed(self,
                                 original_obj, func_args, func_kwargs, return_value):
        if isinstance(return_value, torch.Tensor):
            return_value = return_value.item()

        dataloader = self.metrics_association.get(func_args[0], None)

        self._log_metric(func_args[0].__class__.__name__,
                         value=return_value,
                         dataloader=dataloader)

    def torchmetric_clf_updated(self,
                                original_obj, func_args, func_kwargs):
        dataloader = self.get_last_dataloader()
        if dataloader is None or not self.dataloaders_info[dataloader].is_iterating:
            _LOGGER.debug("Dataloader not found or not iterating")
            return

        self.metrics_association[func_args[0]] = dataloader

    def _get_dataloader_iteration_idx(self) -> int:
        dataloader = self.get_last_dataloader()
        if dataloader is None:
            return -1
        return self.dataloaders_info[dataloader].iteration_idx

    def get_last_dataloader(self):
        return self.last_dataloader

    def _rename_metric(self, metric_name: str, phase: str) -> str:
        real_metric_name = metric_name.split('/', 1)[-1]
        if real_metric_name.startswith('Binary') or real_metric_name.startswith('Multiclass') or real_metric_name.startswith('Multilabel'):
            real_metric_name = real_metric_name.replace('Binary', '').replace(
                'Multiclass', '').replace('Multilabel', '')

        if real_metric_name == 'Recall':
            real_metric_name = 'Sensitivity'

        if real_metric_name == 'Precision':
            real_metric_name = 'Positive Predictive Value'

        if phase is not None:
            return f"{phase}/{real_metric_name}"

        return metric_name.split('/', 1)[0] + real_metric_name

    def finish_callback(self, exp: Experiment):
        # Get the last dataloader with 1 iteration, and assume it is the test dataloader
        dataloader = None
        phase = None
        for dloader, dlinfo in reversed(self.dataloaders_info.items()):
            if dlinfo.times_started_iter == 1:
                dataloader = dloader
                phase = 'test'
                break
        else:
            _LOGGER.debug('No dataloader with 1 iteration found')
        if dataloader is None:
            dataloader = self.get_last_dataloader()
            if len(self.dataloaders_info) > 1:
                phase = 'test'
            else:
                _LOGGER.warning("No test dataloader found")
        if dataloader is None:
            _LOGGER.warning("No dataloader to log found")
            return

        dlinfo = self.dataloaders_info[dataloader]

        # log predictions
        if len(dlinfo.predictions) > 0:
            if hasattr(dataloader.dataset, 'labels_set'):
                exp.log_classification_predictions(predictions_conf=np.array(dlinfo.predictions['predictions']),
                                                   label_names=dataloader.dataset.labels_set,
                                                   resource_ids=dlinfo.predictions['id'],
                                                   dataset_split=phase,
                                                   frame_idxs=dlinfo.predictions['frame_index'])

        dlinfo_metrics = pd.DataFrame(dlinfo.metrics,
                                      columns=['name', 'value', 'step', 'epoch'])
        summary = {'metrics': {}}
        # only use value from the last epoch
        dlinfo_metrics = dlinfo_metrics[dlinfo_metrics['epoch'] == dlinfo_metrics['epoch'].max()]

        for metric_name, value in dlinfo_metrics.groupby('name')['value'].mean().items():
            metric_name = self._rename_metric(metric_name, phase)
            summary['metrics'][metric_name] = value

        exp.add_to_summary(summary)

    def module_constructed_cb(self,
                              original_obj, func_args, func_kwargs, value):
        exp = Experiment.get_singleton_experiment()
        if exp is not None and exp.model is None:
            model = func_args[0]

            # check that is not a torchmetrics model
            if model.__module__.startswith('torchmetrics.'):
                return
            # Not a loss function
            if model.__module__.startswith('torch.nn.modules.loss'):
                return
            # Not a torchvision transform
            if model.__module__.startswith('torchvision.transforms'):
                return
            # Not a optimizer
            if model.__module__.startswith('torch.optim'):
                return

            exp.set_model(model)
            _LOGGER.debug(f'Found user model {model.__class__.__name__}')

    def custom_excepthook(self, exc_type, exc_value, traceback):
        ORIGINAL_EXCEPTHOOK
        self.exit_with_error = True
        # Call the original exception hook
        ORIGINAL_EXCEPTHOOK(exc_type, exc_value, traceback)

    def at_exit_cb(self):
        if self.exit_with_error:
            return
        exp = Experiment.get_singleton_experiment()
        if exp is not None and (exp.cur_step is not None or exp.cur_epoch is not None):
            exp.finish()


def initialize_automatic_logging(enable_rich_logging: bool = True):
    """
    This function initializes the automatic logging of Pytorch loss using patching.
    """
    from rich.logging import RichHandler
    global IS_INITIALIZED, ORIGINAL_EXCEPTHOOK

    if IS_INITIALIZED == True:
        return
    IS_INITIALIZED = True

    # check if RichHandler is already in the handlers
    if enable_rich_logging and not any(isinstance(h, RichHandler) for h in logging.getLogger().handlers):
        logging.getLogger().handlers.append(RichHandler())  # set rich logging handler for the root logger
    # logging.getLogger("datamint").setLevel(logging.INFO)

    pytorch_patcher = PytorchPatcher()

    torchmetrics_clfs_base_metrics = ['Recall', 'Precision', 'AveragePrecision',
                                      'F1Score', 'Accuracy', 'AUROC',
                                      'CohenKappa']

    torchmetrics_clf_metrics = [f'Multiclass{m}' for m in torchmetrics_clfs_base_metrics]
    torchmetrics_clf_metrics += [f'Multilabel{m}' for m in torchmetrics_clfs_base_metrics]
    torchmetrics_clf_metrics += [f'Binary{m}' for m in torchmetrics_clfs_base_metrics]
    torchmetrics_clf_metrics = [f'torchmetrics.classification.{m}' for m in torchmetrics_clf_metrics]
    torchmetrics_detseg_metrics = ['torchmetrics.segmentation.GeneralizedDiceScore',
                                   'torchmetrics.detection.iou.IntersectionOverUnion',
                                   'torchmetrics.detection.giou.GeneralizedIntersectionOverUnion']

    torchmetrics_metrics = torchmetrics_clf_metrics + torchmetrics_detseg_metrics

    params = [
        {
            'target': ['torch.Tensor.backward', 'torch.tensor.Tensor.backward'],
            'cb_before': pytorch_patcher._backward_cb
        },
        {
            'target': 'torch.utils.data.DataLoader.__iter__',
            'cb_before_first_next': pytorch_patcher._dataloader_start_iterating_cb,
            'cb_after_iter_return': pytorch_patcher._dataloader_stop_iterating_cb,
            'cb_next_return': pytorch_patcher._dataloader_next,
        },
        {
            'target': 'torch.utils.data.DataLoader.__init__',
            'cb_after': pytorch_patcher._dataloader_created
        },
        {
            'target': 'torch.nn.functional.nll_loss',
            'cb_after': pytorch_patcher.clf_loss_computed
        },
        {
            'target': [f'{m}.compute' for m in torchmetrics_metrics],
            'cb_after': pytorch_patcher.torchmetric_clf_computed
        },
        {
            'target': [f'{m}.update' for m in torchmetrics_metrics],
            'cb_before': pytorch_patcher.torchmetric_clf_updated
        },
        {
            'target': 'torch.nn.modules.module.Module.__init__',
            'cb_after': pytorch_patcher.module_constructed_cb
        },
        {
            'target': 'torch.nn.functional.binary_cross_entropy_with_logits',
            'cb_after': pytorch_patcher.bce_with_logits_computed
        },
        {
            'target': ['torch.nn.functional.binary_cross_entropy', 'torch.nn.functional.cross_entropy'],
            'cb_after': pytorch_patcher.ce_computed
        }
    ]

    # explode the list of targets into individual targets
    new_params = []
    for p in params:
        if isinstance(p['target'], list):
            for t in p['target']:
                new_params.append({**p, 'target': t})
        else:
            new_params.append(p)
    params = new_params

    for p in params:
        try:
            Wrapper(**p).start()
        except Exception as e:
            _LOGGER.debug(f"Error while patching {p['target']}: {e}")

    try:
        # Set the custom exception hook
        ORIGINAL_EXCEPTHOOK = sys.excepthook
        sys.excepthook = pytorch_patcher.custom_excepthook
        atexit.register(pytorch_patcher.at_exit_cb)
    except Exception:
        _LOGGER.warning("Failed to use atexit.register")