datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff shows the changes between two publicly available package versions as released to their respective public registries. It is provided for informational purposes only.
Files changed (59)
  1. datamint/__init__.py +1 -3
  2. datamint/api/__init__.py +0 -3
  3. datamint/api/base_api.py +286 -54
  4. datamint/api/client.py +76 -13
  5. datamint/api/endpoints/__init__.py +2 -2
  6. datamint/api/endpoints/annotations_api.py +186 -28
  7. datamint/api/endpoints/deploy_model_api.py +78 -0
  8. datamint/api/endpoints/models_api.py +1 -0
  9. datamint/api/endpoints/projects_api.py +38 -7
  10. datamint/api/endpoints/resources_api.py +227 -100
  11. datamint/api/entity_base_api.py +66 -7
  12. datamint/apihandler/base_api_handler.py +0 -1
  13. datamint/apihandler/dto/annotation_dto.py +2 -0
  14. datamint/client_cmd_tools/datamint_config.py +0 -1
  15. datamint/client_cmd_tools/datamint_upload.py +3 -1
  16. datamint/configs.py +11 -7
  17. datamint/dataset/base_dataset.py +24 -4
  18. datamint/dataset/dataset.py +1 -1
  19. datamint/entities/__init__.py +1 -1
  20. datamint/entities/annotations/__init__.py +13 -0
  21. datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
  22. datamint/entities/annotations/image_classification.py +12 -0
  23. datamint/entities/annotations/image_segmentation.py +252 -0
  24. datamint/entities/annotations/volume_segmentation.py +273 -0
  25. datamint/entities/base_entity.py +100 -6
  26. datamint/entities/cache_manager.py +129 -15
  27. datamint/entities/datasetinfo.py +60 -65
  28. datamint/entities/deployjob.py +18 -0
  29. datamint/entities/project.py +39 -0
  30. datamint/entities/resource.py +310 -46
  31. datamint/lightning/__init__.py +1 -0
  32. datamint/lightning/datamintdatamodule.py +103 -0
  33. datamint/mlflow/__init__.py +65 -0
  34. datamint/mlflow/artifact/__init__.py +1 -0
  35. datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
  36. datamint/mlflow/env_utils.py +131 -0
  37. datamint/mlflow/env_vars.py +5 -0
  38. datamint/mlflow/flavors/__init__.py +17 -0
  39. datamint/mlflow/flavors/datamint_flavor.py +150 -0
  40. datamint/mlflow/flavors/model.py +877 -0
  41. datamint/mlflow/lightning/callbacks/__init__.py +1 -0
  42. datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
  43. datamint/mlflow/models/__init__.py +93 -0
  44. datamint/mlflow/tracking/datamint_store.py +76 -0
  45. datamint/mlflow/tracking/default_experiment.py +27 -0
  46. datamint/mlflow/tracking/fluent.py +91 -0
  47. datamint/utils/env.py +27 -0
  48. datamint/utils/visualization.py +21 -13
  49. datamint-2.9.0.dist-info/METADATA +220 -0
  50. datamint-2.9.0.dist-info/RECORD +73 -0
  51. {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
  52. datamint-2.9.0.dist-info/entry_points.txt +18 -0
  53. datamint/apihandler/exp_api_handler.py +0 -204
  54. datamint/experiment/__init__.py +0 -1
  55. datamint/experiment/_patcher.py +0 -570
  56. datamint/experiment/experiment.py +0 -1049
  57. datamint-2.3.3.dist-info/METADATA +0 -125
  58. datamint-2.3.3.dist-info/RECORD +0 -54
  59. datamint-2.3.3.dist-info/entry_points.txt +0 -4
datamint/experiment/_patcher.py (deleted)
@@ -1,570 +0,0 @@
- from unittest.mock import patch
- import importlib
- from typing import Sequence, Any, Dict
- import logging
- from .experiment import Experiment
- from torch.utils.data import DataLoader
- import torch
- import sys
- import pandas as pd
- import atexit
- from collections import OrderedDict, defaultdict
- import numpy as np
-
- _LOGGER = logging.getLogger(__name__)
-
- IS_INITIALIZED = False
-
-
- def _is_iterable(obj):
-     try:
-         iter(obj)
-         return True
-     except TypeError:
-         return False
-
-
- class Wrapper:
-     class IteratorWrapper:
-         def __init__(self, iterator,
-                      cb_before_first_next,
-                      cb_next_return,
-                      cb_after_iter_return,
-                      original_func,
-                      cb_args,
-                      cb_kwargs) -> None:
-             self.iterator = iterator
-             self.cb_args = cb_args
-             self.cb_kwargs = cb_kwargs
-             self.original_func = original_func
-             self.cb_before_first_next = cb_before_first_next
-             self.cb_next_return = cb_next_return
-             self.cb_after_iter_return = cb_after_iter_return
-             self.first_next = True
-
-         def __iter__(self):
-             self.iterator = iter(self.iterator)
-             return self
-
-         def __next__(self):
-             if self.first_next:
-                 for cb in self.cb_before_first_next:
-                     cb(self.original_func, self.cb_args, self.cb_kwargs, original_iter=self.iterator)
-                 self.first_next = False
-             try:
-                 return_value = next(self.iterator)
-                 for cb in self.cb_next_return:
-                     cb(self.original_func, self.cb_args, self.cb_kwargs, return_value)
-                 return return_value
-             except StopIteration as e:
-                 for cb in self.cb_after_iter_return:
-                     cb(self.original_func, self.cb_args, self.cb_kwargs, original_iter=self.iterator)
-                 self.first_next = True
-                 raise e
-
-         def __getattr__(self, item):
-             if item != '__iter__':
-                 return getattr(self.iterator, item)
-             return self.__iter__
-
-         def __len__(self):
-             return len(self.iterator)
-
-     def __init__(self,
-                  target: str,
-                  cb_before: Sequence[callable] | callable = None,
-                  cb_after: Sequence[callable] | callable = None,
-                  cb_before_first_next: Sequence[callable] | callable = None,
-                  cb_next_return: Sequence[callable] | callable = None,
-                  cb_after_iter_return: Sequence[callable] | callable = None,
-                  ) -> None:
-         self.cb_before = cb_before if cb_before is not None else []
-         self.cb_after = cb_after if cb_after is not None else []
-         self.cb_after_iter_return = cb_after_iter_return if cb_after_iter_return is not None else []
-         self.cb_before_first_next = cb_before_first_next if cb_before_first_next is not None else []
-         self.cb_next_return = cb_next_return if cb_next_return is not None else []
-         if not _is_iterable(self.cb_before):
-             self.cb_before = [self.cb_before]
-         if not _is_iterable(self.cb_after):
-             self.cb_after = [self.cb_after]
-         if not _is_iterable(self.cb_after_iter_return):
-             self.cb_after_iter_return = [self.cb_after_iter_return]
-         if not _is_iterable(self.cb_before_first_next):
-             self.cb_before_first_next = [self.cb_before_first_next]
-         if not _is_iterable(self.cb_next_return):
-             self.cb_next_return = [self.cb_next_return]
-         self.target = target
-         self._patch()
-
-     def _patch(self):
-         def _callback(*args, **kwargs):
-             for cb in self.cb_before:
-                 cb(original, args, kwargs)
-
-             try:
-                 return_value = original(*args, **kwargs)
-                 # if return_value is a generator, wrap it
-                 if len(self.cb_after_iter_return) > 0 and _is_iterable(return_value):
-                     return_value = self._wrap_iterator(return_value,
-                                                        original,
-                                                        args, kwargs)
-
-             except Exception as exception:
-                 # We are assuming the patched function does not return an exception.
-                 # return_value = exception
-                 raise exception
-
-             for cb in self.cb_after:
-                 cb(original, args, kwargs, return_value)
-
-             if isinstance(return_value, Exception):
-                 raise return_value
-
-             return return_value
-
-         original = get_function_from_string(self.target)
-         # Patch the original function with the callback
-         self.patcher = patch(self.target, new=_callback)
-
-     def start(self):
-         self.patcher.start()
-
-     def stop(self):
-         self.patcher.stop()
-
-     def _wrap_iterator(self, iterator, original_func, args, kwargs):
-         return Wrapper.IteratorWrapper(iterator,
-                                        cb_before_first_next=self.cb_before_first_next,
-                                        cb_next_return=self.cb_next_return,
-                                        cb_after_iter_return=self.cb_after_iter_return,
-                                        original_func=original_func,
-                                        cb_args=args,
-                                        cb_kwargs=kwargs)
-
-
- def get_function_from_string(target: str):
-     target_spl = target.split('.')
-     for i in range(len(target_spl)):
-         module_name = '.'.join(target_spl[:-i-1])
-         function_name = '.'.join(target_spl[-i-1:])
-         try:
-             module = importlib.import_module(module_name)
-         except ModuleNotFoundError:
-             continue
-         break
-     else:
-         raise ModuleNotFoundError(f"Module {module_name} not found")
-
-     try:
-         cur_obj = module
-         for objname in function_name.split('.'):
-             cur_obj = getattr(cur_obj, objname)
-     except AttributeError:
-         raise ModuleNotFoundError(f"Module attribute {module_name}.{objname} not found")
-     return cur_obj
-
-
- class PytorchPatcher:
-     class DataLoaderInfo:
-         def __init__(self, dataloader: DataLoader):
-             self.dataloader = dataloader
-             self.times_started_iter = 0  # This includes the current iteration
-             self.is_iterating = False
-             self.metrics = []
-             self.iteration_idx = None
-             self.predictions = defaultdict(list)  # TODO: save to disk
-             self.cur_batch = None
-
-             for obj in [dataloader, dataloader.batch_sampler]:
-                 if hasattr(obj, 'batch_size'):
-                     self.batch_size = obj.batch_size
-                     _LOGGER.debug(f"Found batch size {self.batch_size} for dataloader {dataloader}")
-                     break
-             else:
-                 self.batch_size = None
-                 _LOGGER.debug(f"Could not find batch size for dataloader {dataloader}")
-
-         def __str__(self) -> str:
-             return f"DataLoaderInfo: {self.dataloader}, number_of_times_iterated: {self.times_started_iter}, " \
-                    f"is_iterating: {self.is_iterating}"
-
-         def append_metric(self, name: str, value: float, step, epoch):
-             self.metrics.append([name, value, step, epoch])
-
-     AUTO_LOSS_LOG_INTERVAL = 20
-
-     def __init__(self) -> None:
-         self.dataloaders_info: Dict[Any, PytorchPatcher.DataLoaderInfo] = OrderedDict()
-         self.metrics_association = {}  # Associate metrics with dataloaders
-         self.last_dataloader = None
-         self.exit_with_error = False
-
-     def _dataloader_created(self,
-                             original_obj, func_args, func_kwargs,
-                             return_value):
-
-         dataloader = func_args[0]
-         if dataloader in self.dataloaders_info:
-             _LOGGER.warning("Dataloader already exists")
-         _LOGGER.debug('Adding a new dataloader')
-         self.dataloaders_info[dataloader] = PytorchPatcher.DataLoaderInfo(dataloader)
-
-     def _inc_exp_step(self) -> int:
-         exp = Experiment.get_singleton_experiment()
-         if exp.cur_step is None:
-             exp._set_step(0)
-         else:
-             exp._set_step(exp.cur_step + 1)
-
-         return exp.cur_step
-
-     def _backward_cb(self,
-                      original_obj, func_args, func_kwargs):
-         """
-         This method is a wrapper for the backward method of the Pytorch Tensor class.
-         """
-         loss = func_args[0]
-         cur_step = self._inc_exp_step()
-
-         if cur_step % PytorchPatcher.AUTO_LOSS_LOG_INTERVAL == 0:
-             self._log_metric('loss', loss.item())
-
-     def clf_loss_computed(self,
-                           original_obj, func_args, func_kwargs, return_value):
-         loss = return_value.detach().cpu()
-         # if is not a 0-d tensor, do not log
-         if len(loss.shape) != 0:
-             return
-         loss = loss.item()
-         loss_name = original_obj.__name__
-         cur_step = self._inc_exp_step()
-
-         dataloader = self.get_last_dataloader()
-         if dataloader is not None and self.dataloaders_info[dataloader].is_iterating:
-             self.metrics_association[func_args[0]] = dataloader
-
-         if cur_step % PytorchPatcher.AUTO_LOSS_LOG_INTERVAL == 0:
-             self._log_metric(loss_name, loss, dataloader=dataloader)
-
-     def _classification_loss_computed(self, preds, targets):
-         dataloader = self.get_last_dataloader()
-         if dataloader is not None and self.dataloaders_info[dataloader].is_iterating:
-             dinfo = self.dataloaders_info[dataloader]
-             if not isinstance(dinfo.cur_batch, dict) or 'metainfo' not in dinfo.cur_batch:
-                 _LOGGER.debug(f"No metainfo in batch")
-                 return
-             batch_metainfo = dinfo.cur_batch['metainfo']
-             if 'id' not in batch_metainfo[0]:
-                 _LOGGER.debug("No id in batch metainfo")
-                 return
-             resources_ids = [b['id'] for b in batch_metainfo]
-
-             if len(resources_ids) != len(preds):
-                 _LOGGER.debug(f"Number of predictions ({len(preds)}) and targets ({len(targets)}) do not match")
-                 return
-
-             dinfo.predictions['predictions'].extend(preds)
-             dinfo.predictions['id'].extend(resources_ids)
-             dinfo.predictions['frame_index'].extend([b['frame_index'] for b in batch_metainfo])
-         else:
-             _LOGGER.warning("No dataloader found")
-
-     def bce_with_logits_computed(self,
-                                  original_obj, func_args, func_kwargs, return_value):
-         self.clf_loss_computed(original_obj, func_args, func_kwargs, return_value)
-         preds = func_kwargs['input'] if 'input' in func_kwargs else func_args[0]
-         targets = func_kwargs['target'] if 'target' in func_kwargs else func_args[1]
-         preds = torch.nn.functional.sigmoid(preds).detach().cpu()
-         targets = targets.detach().cpu()
-         self._classification_loss_computed(preds, targets)
-
-     def ce_computed(self,
-                     original_obj, func_args, func_kwargs, return_value):
-         self.clf_loss_computed(original_obj, func_args, func_kwargs, return_value)
-         preds = func_kwargs['input'] if 'input' in func_kwargs else func_args[0]
-         targets = func_kwargs['target'] if 'target' in func_kwargs else func_args[1]
-         preds = preds.detach().cpu()
-         targets = targets.detach().cpu()
-         self._classification_loss_computed(preds, targets)
-
-     def _dataloader_start_iterating_cb(self,
-                                        original_obj, func_args, func_kwargs, original_iter):
-         exp = Experiment.get_singleton_experiment()
-         dataloader = func_args[0]  # self
-         dataloader_info = self.dataloaders_info[dataloader]
-         dataloader_info.is_iterating = True
-         if dataloader_info.iteration_idx is None:
-             dataloader_info.iteration_idx = self._get_dataloader_iteration_idx()+1
-         exp._set_epoch(dataloader_info.times_started_iter)
-         dataloader_info.times_started_iter += 1
-
-         self.last_dataloader = dataloader
-
-         _LOGGER.debug(f'Dataloader is iterating: {dataloader_info}')
-
-     def _dataloader_next(self,
-                          original_obj, func_args, func_kwargs, return_value):
-         dataloader = func_args[0]
-         dataloder_info = self.dataloaders_info[dataloader]
-         dataloder_info.cur_batch = return_value
-
-     def _dataloader_stop_iterating_cb(self,
-                                       original_obj, func_args, func_kwargs, original_iter):
-         dataloader = func_args[0]
-         dinfo = self.dataloaders_info[dataloader]
-         dinfo.is_iterating = False
-         dinfo.cur_batch = None
-         # find the dataloader that is still iterating  # FIXME: For 3 dataloaders being iterating
-         for dloader, dlinfo in self.dataloaders_info.items():
-             if dlinfo.is_iterating:
-                 self.last_dataloader = dloader
-                 break
-         else:
-             _LOGGER.debug("No dataloader is iterating")
-
-         _LOGGER.debug(f'Dataloader stopped iterating: {self.dataloaders_info[dataloader]}')
-
-     def _log_metric(self, name, value,
-                     dataloader=None,
-                     **kwargs):
-         exp = Experiment.get_singleton_experiment()
-         if self.finish_callback not in exp.finish_callbacks:
-             exp._add_finish_callback(self.finish_callback)
-
-         if dataloader is None:
-             dataloader = self.get_last_dataloader()
-             dloader_info = self.dataloaders_info[dataloader]
-             if not dloader_info.is_iterating:
-                 dataloader = None
-         else:
-             dloader_info = self.dataloaders_info[dataloader]
-
-         if dataloader is not None:
-             name = f"dataset{dloader_info.iteration_idx+1}/{name}"
-             dloader_info.append_metric(name, value, exp.cur_step, exp.cur_epoch)
-
-         _LOGGER.debug(f"Logging metric {name} with value {value}")
-         exp.log_metric(name, value, **kwargs)
-
-     def torchmetric_clf_computed(self,
-                                  original_obj, func_args, func_kwargs, return_value):
-         if isinstance(return_value, torch.Tensor):
-             return_value = return_value.item()
-
-         dataloader = self.metrics_association.get(func_args[0], None)
-
-         self._log_metric(func_args[0].__class__.__name__,
-                          value=return_value,
-                          dataloader=dataloader)
-
-     def torchmetric_clf_updated(self,
-                                 original_obj, func_args, func_kwargs):
-         dataloader = self.get_last_dataloader()
-         if dataloader is None or not self.dataloaders_info[dataloader].is_iterating:
-             _LOGGER.debug("Dataloader not found or not iterating")
-             return
-
-         self.metrics_association[func_args[0]] = dataloader
-
-     def _get_dataloader_iteration_idx(self) -> int:
-         dataloader = self.get_last_dataloader()
-         if dataloader is None:
-             return -1
-         return self.dataloaders_info[dataloader].iteration_idx
-
-     def get_last_dataloader(self):
-         return self.last_dataloader
-
-     def _rename_metric(self, metric_name: str, phase: str) -> str:
-         real_metric_name = metric_name.split('/', 1)[-1]
-         if real_metric_name.startswith('Binary') or real_metric_name.startswith('Multiclass') or real_metric_name.startswith('Multilabel'):
-             real_metric_name = real_metric_name.replace('Binary', '').replace(
-                 'Multiclass', '').replace('Multilabel', '')
-
-         if real_metric_name == 'Recall':
-             real_metric_name = 'Sensitivity'
-
-         if real_metric_name == 'Precision':
-             real_metric_name = 'Positive Predictive Value'
-
-         if phase is not None:
-             return f"{phase}/{real_metric_name}"
-
-         return metric_name.split('/', 1)[0] + real_metric_name
-
-     def finish_callback(self, exp: Experiment):
-         # Get the last dataloader with 1 iteration, and assume it is the test dataloader
-         dataloader = None
-         phase = None
-         for dloader, dlinfo in reversed(self.dataloaders_info.items()):
-             if dlinfo.times_started_iter == 1:
-                 dataloader = dloader
-                 phase = 'test'
-                 break
-         else:
-             _LOGGER.debug('No dataloader with 1 iteration found')
-         if dataloader is None:
-             dataloader = self.get_last_dataloader()
-             if len(self.dataloaders_info) > 1:
-                 phase = 'test'
-             else:
-                 _LOGGER.warning("No test dataloader found")
-         if dataloader is None:
-             _LOGGER.warning("No dataloader to log found")
-             return
-
-         dlinfo = self.dataloaders_info[dataloader]
-
-         # log predictions
-         if len(dlinfo.predictions) > 0:
-             if hasattr(dataloader.dataset, 'labels_set'):
-                 exp.log_classification_predictions(predictions_conf=np.array(dlinfo.predictions['predictions']),
-                                                    label_names=dataloader.dataset.labels_set,
-                                                    resource_ids=dlinfo.predictions['id'],
-                                                    dataset_split=phase,
-                                                    frame_idxs=dlinfo.predictions['frame_index'])
-
-         dlinfo_metrics = pd.DataFrame(dlinfo.metrics,
-                                       columns=['name', 'value', 'step', 'epoch'])
-         summary = {'metrics': {}}
-         # only use value from the last epoch
-         dlinfo_metrics = dlinfo_metrics[dlinfo_metrics['epoch'] == dlinfo_metrics['epoch'].max()]
-
-         for metric_name, value in dlinfo_metrics.groupby('name')['value'].mean().items():
-             metric_name = self._rename_metric(metric_name, phase)
-             summary['metrics'][metric_name] = value
-
-         exp.add_to_summary(summary)
-
-     def module_constructed_cb(self,
-                               original_obj, func_args, func_kwargs, value):
-         exp = Experiment.get_singleton_experiment()
-         if exp is not None and exp.model is None:
-             model = func_args[0]
-
-             # check that is not a torchmetrics model
-             if model.__module__.startswith('torchmetrics.'):
-                 return
-             # Not a loss function
-             if model.__module__.startswith('torch.nn.modules.loss'):
-                 return
-             # Not a torchvision transform
-             if model.__module__.startswith('torchvision.transforms'):
-                 return
-             # Not a optimizer
-             if model.__module__.startswith('torch.optim'):
-                 return
-
-             exp.set_model(model)
-             _LOGGER.debug(f'Found user model {model.__class__.__name__}')
-
-     def custom_excepthook(self, exc_type, exc_value, traceback):
-         ORIGINAL_EXCEPTHOOK
-         self.exit_with_error = True
-         # Call the original exception hook
-         ORIGINAL_EXCEPTHOOK(exc_type, exc_value, traceback)
-
-     def at_exit_cb(self):
-         if self.exit_with_error:
-             return
-         exp = Experiment.get_singleton_experiment()
-         if exp is not None and (exp.cur_step is not None or exp.cur_epoch is not None):
-             exp.finish()
-
-
- def initialize_automatic_logging(enable_rich_logging: bool = True):
-     """
-     This function initializes the automatic logging of Pytorch loss using patching.
-     """
-     from rich.logging import RichHandler
-     global IS_INITIALIZED, ORIGINAL_EXCEPTHOOK
-
-     if IS_INITIALIZED == True:
-         return
-     IS_INITIALIZED = True
-
-     # check if RichHandler is already in the handlers
-     if enable_rich_logging and not any(isinstance(h, RichHandler) for h in logging.getLogger().handlers):
-         logging.getLogger().handlers.append(RichHandler())  # set rich logging handler for the root logger
-     # logging.getLogger("datamint").setLevel(logging.INFO)
-
-     pytorch_patcher = PytorchPatcher()
-
-     torchmetrics_clfs_base_metrics = ['Recall', 'Precision', 'AveragePrecision',
-                                       'F1Score', 'Accuracy', 'AUROC',
-                                       'CohenKappa']
-
-     torchmetrics_clf_metrics = [f'Multiclass{m}' for m in torchmetrics_clfs_base_metrics]
-     torchmetrics_clf_metrics += [f'Multilabel{m}' for m in torchmetrics_clfs_base_metrics]
-     torchmetrics_clf_metrics += [f'Binary{m}' for m in torchmetrics_clfs_base_metrics]
-     torchmetrics_clf_metrics = [f'torchmetrics.classification.{m}' for m in torchmetrics_clf_metrics]
-     torchmetrics_detseg_metrics = ['torchmetrics.segmentation.GeneralizedDiceScore',
-                                    'torchmetrics.detection.iou.IntersectionOverUnion',
-                                    'torchmetrics.detection.giou.GeneralizedIntersectionOverUnion']
-
-     torchmetrics_metrics = torchmetrics_clf_metrics + torchmetrics_detseg_metrics
-
-     params = [
-         {
-             'target': ['torch.Tensor.backward', 'torch.tensor.Tensor.backward'],
-             'cb_before': pytorch_patcher._backward_cb
-         },
-         {
-             'target': 'torch.utils.data.DataLoader.__iter__',
-             'cb_before_first_next': pytorch_patcher._dataloader_start_iterating_cb,
-             'cb_after_iter_return': pytorch_patcher._dataloader_stop_iterating_cb,
-             'cb_next_return': pytorch_patcher._dataloader_next,
-         },
-         {
-             'target': 'torch.utils.data.DataLoader.__init__',
-             'cb_after': pytorch_patcher._dataloader_created
-         },
-         {
-             'target': 'torch.nn.functional.nll_loss',
-             'cb_after': pytorch_patcher.clf_loss_computed
-         },
-         {
-             'target': [f'{m}.compute' for m in torchmetrics_metrics],
-             'cb_after': pytorch_patcher.torchmetric_clf_computed
-         },
-         {
-             'target': [f'{m}.update' for m in torchmetrics_metrics],
-             'cb_before': pytorch_patcher.torchmetric_clf_updated
-         },
-         {
-             'target': 'torch.nn.modules.module.Module.__init__',
-             'cb_after': pytorch_patcher.module_constructed_cb
-         },
-         {
-             'target': 'torch.nn.functional.binary_cross_entropy_with_logits',
-             'cb_after': pytorch_patcher.bce_with_logits_computed
-         },
-         {
-             'target': ['torch.nn.functional.binary_cross_entropy', 'torch.nn.functional.cross_entropy'],
-             'cb_after': pytorch_patcher.ce_computed
-         }
-     ]
-
-     # explode the list of targets into individual targets
-     new_params = []
-     for p in params:
-         if isinstance(p['target'], list):
-             for t in p['target']:
-                 new_params.append({**p, 'target': t})
-         else:
-             new_params.append(p)
-     params = new_params
-
-     for p in params:
-         try:
-             Wrapper(**p).start()
-         except Exception as e:
-             _LOGGER.debug(f"Error while patching {p['target']}: {e}")
-
-     try:
-         # Set the custom exception hook
-         ORIGINAL_EXCEPTHOOK = sys.excepthook
-         sys.excepthook = pytorch_patcher.custom_excepthook
-         atexit.register(pytorch_patcher.at_exit_cb)
-     except Exception:
-         _LOGGER.warning("Failed to use atexit.register")