datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- datamint/__init__.py +1 -3
- datamint/api/__init__.py +0 -3
- datamint/api/base_api.py +286 -54
- datamint/api/client.py +76 -13
- datamint/api/endpoints/__init__.py +2 -2
- datamint/api/endpoints/annotations_api.py +186 -28
- datamint/api/endpoints/deploy_model_api.py +78 -0
- datamint/api/endpoints/models_api.py +1 -0
- datamint/api/endpoints/projects_api.py +38 -7
- datamint/api/endpoints/resources_api.py +227 -100
- datamint/api/entity_base_api.py +66 -7
- datamint/apihandler/base_api_handler.py +0 -1
- datamint/apihandler/dto/annotation_dto.py +2 -0
- datamint/client_cmd_tools/datamint_config.py +0 -1
- datamint/client_cmd_tools/datamint_upload.py +3 -1
- datamint/configs.py +11 -7
- datamint/dataset/base_dataset.py +24 -4
- datamint/dataset/dataset.py +1 -1
- datamint/entities/__init__.py +1 -1
- datamint/entities/annotations/__init__.py +13 -0
- datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
- datamint/entities/annotations/image_classification.py +12 -0
- datamint/entities/annotations/image_segmentation.py +252 -0
- datamint/entities/annotations/volume_segmentation.py +273 -0
- datamint/entities/base_entity.py +100 -6
- datamint/entities/cache_manager.py +129 -15
- datamint/entities/datasetinfo.py +60 -65
- datamint/entities/deployjob.py +18 -0
- datamint/entities/project.py +39 -0
- datamint/entities/resource.py +310 -46
- datamint/lightning/__init__.py +1 -0
- datamint/lightning/datamintdatamodule.py +103 -0
- datamint/mlflow/__init__.py +65 -0
- datamint/mlflow/artifact/__init__.py +1 -0
- datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
- datamint/mlflow/env_utils.py +131 -0
- datamint/mlflow/env_vars.py +5 -0
- datamint/mlflow/flavors/__init__.py +17 -0
- datamint/mlflow/flavors/datamint_flavor.py +150 -0
- datamint/mlflow/flavors/model.py +877 -0
- datamint/mlflow/lightning/callbacks/__init__.py +1 -0
- datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
- datamint/mlflow/models/__init__.py +93 -0
- datamint/mlflow/tracking/datamint_store.py +76 -0
- datamint/mlflow/tracking/default_experiment.py +27 -0
- datamint/mlflow/tracking/fluent.py +91 -0
- datamint/utils/env.py +27 -0
- datamint/utils/visualization.py +21 -13
- datamint-2.9.0.dist-info/METADATA +220 -0
- datamint-2.9.0.dist-info/RECORD +73 -0
- {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
- datamint-2.9.0.dist-info/entry_points.txt +18 -0
- datamint/apihandler/exp_api_handler.py +0 -204
- datamint/experiment/__init__.py +0 -1
- datamint/experiment/_patcher.py +0 -570
- datamint/experiment/experiment.py +0 -1049
- datamint-2.3.3.dist-info/METADATA +0 -125
- datamint-2.3.3.dist-info/RECORD +0 -54
- datamint-2.3.3.dist-info/entry_points.txt +0 -4
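
The most significant structural change in this release is the removal of the `datamint/experiment/` package (the `Experiment` tracking class, its `_patcher` auto-logging module, and `exp_api_handler.py`), with MLflow-based tracking added under `datamint/mlflow/` instead. For orientation only, below is a minimal sketch of how the removed `Experiment` API was used, assembled from the docstring examples in the deleted `experiment.py` shown further down; the import path, project name, and placeholder model are illustrative assumptions, not taken from this diff.

```python
# Hypothetical sketch of the experiment-tracking API removed in 2.9.0,
# reconstructed from docstring examples in the deleted experiment.py.
# Import path, project name, and the dummy model are assumptions for illustration.
import torch
from datamint.experiment import Experiment  # module removed in 2.9.0

model = torch.nn.Linear(16, 2)  # placeholder model, not part of the diff

exp = Experiment(name="baseline-run", project_name="my-project")

# Metrics could be logged one at a time or in batches (log_metric / log_metrics).
exp.log_metric('test/sensitivity', 0.9, show_in_summary=True)
exp.log_metrics({'test/loss': 0.1, 'test/accuracy': 0.9}, show_in_summary=True)

# A torch model and its hyper-parameters could be attached to the experiment.
exp.log_model(model, hyper_params={"num_layers": 3, "pretrained": True})

# finish() logs the summary and marks the experiment as completed.
exp.finish()
```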
datamint/experiment/experiment.py

@@ -1,1049 +0,0 @@
-import logging
-from datamint.apihandler.api_handler import APIHandler
-from datamint.apihandler.base_api_handler import DatamintException
-from datetime import datetime, timezone
-from typing import List, Dict, Optional, Union, Any, Tuple, IO, Literal
-from collections import defaultdict
-from datamint.dataset.dataset import DatamintDataset
-import os
-import numpy as np
-import heapq
-from datamint.utils import io_utils
-
-_LOGGER = logging.getLogger(__name__)
-
-
-IMPORTANT_METRICS = ['Accuracy', 'Precision', 'Recall', 'F1score', 'Positive Predictive Value', 'Sensitivity']
-IMPORTANT_METRICS = ['test/'+m.lower() for m in IMPORTANT_METRICS]
-METRIC_RENAMER = {
-    'precision': 'Positive Predictive Value',
-    'recall': 'Sensitivity',
-}
-
-
-class TopN:
-    class _Item:
-        def __init__(self, key, item):
-            self.key = key
-            self.item = item
-
-        def __lt__(self, other):
-            return self.key < other.key
-
-        def __eq__(self, other):
-            return self.key == other
-
-    def __init__(self, N, key=lambda x: x, reverse=False):
-        self.N = N
-        self.reverse = reverse
-        self.key = key
-        self.heap = []
-
-    def add(self, item):
-        item_key = float(self.key(item))
-        if self.reverse:
-            item_key = -item_key  # Invert the key to keep the lowest ones
-        if len(self.heap) < self.N:
-            heapq.heappush(self.heap, TopN._Item(item_key, item))
-        else:
-            heapq.heappushpop(self.heap, TopN._Item(item_key, item))
-
-    def __len__(self):
-        return len(self.heap)
-
-    def get_top(self) -> list:
-        sorted_items = sorted(self.heap, key=lambda x: x.key, reverse=True)
-        return [item.item for item in sorted_items]
-
-
-class _DryRunExperimentAPIHandler(APIHandler):
-    """
-    Dry-run implementation of the ExperimentAPIHandler.
-    No data will be uploaded to the platform.
-    """
-
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, check_connection=False, **kwargs)
-
-    def create_experiment(self, dataset_id: str, name: str, description: str, environment: Dict) -> str:
-        return "dry_run"
-
-    def log_entry(self, exp_id: str, entry: Dict):
-        pass
-
-    def log_summary(self, exp_id: str, result_summary: Dict):
-        pass
-
-    def log_model(self, exp_id: str, *args, **kwargs):
-        return {'id': 'dry_run'}
-
-    def finish_experiment(self, exp_id: str):
-        pass
-
-    def upload_segmentations(self, *args, **kwargs) -> str:
-        return "dry_run"
-
-
-def _get_confidence_callback(pred) -> float:
-    return pred['predicted'][0]['confidence']
-
-
-class Experiment:
-    """
-    Experiment class to log metrics, models, and other information to the platform.
-
-    Args:
-        name (str): Name of the experiment.
-        project_name (str): Name of the project.
-        description (str): Description of the experiment.
-        api_key (str): API key for the platform.
-        root_url (str): Root URL of the platform.
-        dataset_dir (str): Directory to store the datasets.
-        log_enviroment (bool): Log the enviroment information.
-        dry_run (bool): Run in dry-run mode. No data will be uploaded to the platform
-        auto_log (bool): Automatically log the experiment using patching mechanisms.
-        tags (List[str]): Tags to add to the experiment.
-    """
-
-    DATAMINT_DEFAULT_DIR = ".datamint"
-    DATAMINT_DATASETS_DIR = 'datasets'
-
-    def __init__(self,
-                 name: str,
-                 project_name: Optional[str] = None,
-                 description: Optional[str] = None,
-                 api_key: Optional[str] = None,
-                 root_url: Optional[str] = None,
-                 dataset_dir: Optional[str] = None,
-                 log_enviroment: bool = True,
-                 dry_run: bool = False,
-                 auto_log=True,
-                 tags: Optional[List[str]] = None,
-                 allow_existing: bool = False
-                 ) -> None:
-        import torch
-        from ._patcher import initialize_automatic_logging
-        if auto_log:
-            initialize_automatic_logging()
-        self.auto_log = auto_log
-        self.name = name
-        self.dry_run = dry_run
-        if dry_run:
-            self.apihandler = _DryRunExperimentAPIHandler(api_key=api_key, root_url=root_url)
-            _LOGGER.warning("Running in dry-run mode. No data will be uploaded to the platform.")
-        else:
-            self.apihandler = APIHandler(api_key=api_key, root_url=root_url)
-        self.cur_step = None
-        self.cur_epoch = None
-        self.summary_log = defaultdict(dict)
-        self.finish_callbacks = []
-        self.model: torch.nn.Module = None
-        self.model_id = None
-        self.model_hyper_params = None
-        self.is_finished = False
-        self.log_enviroment = log_enviroment
-
-        if dataset_dir is None:
-            # store them in the home directory
-            dataset_dir = os.path.join(os.path.expanduser("~"),
-                                       Experiment.DATAMINT_DEFAULT_DIR)
-            dataset_dir = os.path.join(dataset_dir, Experiment.DATAMINT_DATASETS_DIR)
-
-        if not os.path.exists(dataset_dir):
-            os.makedirs(dataset_dir)
-        self.dataset_dir = dataset_dir
-
-        self.project = self.apihandler.get_project_by_name(project_name)
-        if 'error' in self.project:
-            raise DatamintException(str(self.project))
-        exp_info = self.apihandler.get_experiment_by_name(name, self.project)
-
-        self.project_name = self.project['name']
-        dataset_info = self.apihandler.get_dataset_by_id(self.project['dataset_id'])
-        self.dataset_id = dataset_info['id']
-        self.dataset_info = dataset_info
-
-        if exp_info is None:
-            self._initialize_new_exp(project_name, name, description, tags, log_enviroment)
-        else:
-            if not allow_existing:
-                raise DatamintException(f"Experiment with name '{name}' already exists for project '{project_name}'.")
-            self._init_from_existing_experiment(project=self.project, exp=exp_info)
-
-        self.time_finished = None
-
-        self.highest_predictions = defaultdict(lambda: TopN(5, key=_get_confidence_callback, reverse=False))
-        self.lowest_predictions = defaultdict(lambda: TopN(5, key=_get_confidence_callback, reverse=True))
-
-        Experiment._set_singleton_experiment(self)
-
-    def _initialize_new_exp(self,
-                            project: Dict,
-                            name: str,
-                            description: str,
-                            tags: Optional[List[str]] = None,
-                            log_enviroment: bool = True):
-        env_info = Experiment.get_enviroment_info() if log_enviroment else {}
-        self.exp_id = self.apihandler.create_experiment(dataset_id=self.dataset_id,
-                                                        name=name,
-                                                        description=description,
-                                                        environment=env_info)
-        self.time_started = datetime.now(timezone.utc)  # FIXME: use created_at field from response
-        if tags is not None:
-            self.apihandler.log_entry(exp_id=self.exp_id,
-                                      entry={'tags': list(tags)})
-
-    def _init_from_existing_experiment(self, project: Dict, exp: Dict):
-        self.exp_id = exp['id']
-
-        # raise error if the experiment is already finished
-        if exp['completed_at'] is not None:
-            project_name = project["name"]
-            raise DatamintException(f"Experiment '{self.name}' from project '{project_name}' is already finished.")
-
-        # example of `exp['created_at']`: 2024-11-01T19:26:12.239Z
-        # example 2: 2024-11-14T17:47:22.363452-03:00
-        self.time_started = datetime.fromisoformat(exp['created_at'].replace('Z', '+00:00'))
-
-    @staticmethod
-    def get_enviroment_info() -> Dict[str, Any]:
-        """
-        Get the enviroment information of the machine such as OS, Python version, etc.
-
-        Returns:
-            Dict: Enviroment information.
-        """
-        import platform
-        import torchvision
-        import psutil
-        import socket
-        import torch
-
-        # find all ip address, removing localhost
-        ip_addresses = [addr.address for iface in psutil.net_if_addrs().values()
-                        for addr in iface if addr.family == socket.AF_INET and not addr.address.startswith('127.0.')]
-        ip_addresses = list(set(ip_addresses))
-        if len(ip_addresses) == 1:
-            ip_addresses = ip_addresses[0]
-
-        # Get the enviroment and machine information, such as OS, Python version, machine name, RAM size, etc.
-        env = {
-            'python_version': platform.python_version(),
-            'torch_version': torch.__version__,
-            'torchvision_version': torchvision.__version__,
-            'numpy_version': np.__version__,
-            'os': platform.system(),
-            'os_version': platform.version(),
-            'os_name': platform.system(),
-            'machine_name': platform.node(),
-            'cpu': platform.processor(),
-            'ram_gb': psutil.virtual_memory().total / (1024. ** 3),
-            'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
-            'gpu_count': torch.cuda.device_count(),
-            'gpu_memory': torch.cuda.get_device_properties(0).total_memory / (1024. ** 3) if torch.cuda.is_available() else None,
-            'processor_count': os.cpu_count(),
-            'processor_name': platform.processor(),
-            'hostname': os.uname().nodename,
-            'ip_address': ip_addresses,
-        }
-
-        return env
-
-    def set_model(self,
-                  model,
-                  hyper_params: Optional[Dict] = None):
-        """
-        Set the model and hyper-parameters of the experiment.
-
-        Args:
-            model (torch.nn.Module): The model to log.
-            hyper_params (Optional[Dict]): The hyper-parameters of the model.
-        """
-        self.model = model
-        self.model_hyper_params = hyper_params
-
-    @staticmethod
-    def _get_dataset_info(apihandler: APIHandler,
-                          dataset_id,
-                          project_name: str) -> Dict:
-        if project_name is not None:
-            project = apihandler.get_project_by_name(project_name)
-            if 'error' in project:
-                raise ValueError(str(project))
-            dataset_id = project['dataset_id']
-
-        if dataset_id is None:
-            raise ValueError("Either project_name or dataset_id must be provided.")
-
-        return apihandler.get_dataset_by_id(dataset_id)
-
-    @staticmethod
-    def get_singleton_experiment() -> 'Experiment':
-        global _EXPERIMENT
-        return _EXPERIMENT
-
-    @staticmethod
-    def _set_singleton_experiment(experiment: 'Experiment'):
-        global _EXPERIMENT
-        if _EXPERIMENT is not None:
-            _LOGGER.warning(
-                "There is already an active Experiment. Setting a new Experiment will overwrite the existing one."
-            )
-
-        _EXPERIMENT = experiment
-
-    def _set_step(self, step: Optional[int]) -> int:
-        """
-        Set the current step of the experiment and return it.
-        If step is None, return the current step.
-        """
-        if step is not None:
-            self.cur_step = step
-        return self.cur_step
-
-    def _set_epoch(self, epoch: Optional[int]) -> int:
-        """
-        Set the current epoch of the experiment and return it.
-        If epoch is None, return the current epoch.
-        """
-        if epoch is not None:
-            _LOGGER.debug(f"Setting current epoch to {epoch}")
-            self.cur_epoch = epoch
-        return self.cur_epoch
-
-    def log_metric(self,
-                   name: str,
-                   value: float,
-                   step: int = None,
-                   epoch: int = None,
-                   show_in_summary: bool = False) -> None:
-        """
-        Log a metric to the platform.
-
-        Args:
-            name (str): Arbritary name of the metric.
-            value (float): Value of the metric.
-            step (int): The step of the experiment.
-            epoch (int): The epoch of the experiment.
-            show_in_summary (bool): Show the metric in the summary. Use this to show only important metrics in the summary.
-
-        Example:
-            >>> exp.log_metric('test/sensitivity', 0.9, show_in_summary=True)
-
-        See Also:
-            :py:meth:`~log_metrics`
-        """
-        self.log_metrics({name: value},
-                         step=step,
-                         epoch=epoch,
-                         show_in_summary=show_in_summary)
-
-    def log_metrics(self,
-                    metrics: Dict[str, float],
-                    step=None,
-                    epoch=None,
-                    show_in_summary: bool = False) -> None:
-        """
-        Log multiple metrics to the platform. Handy for logging multiple metrics at once.
-
-        Args:
-            metrics (Dict[str, float]): A dictionary of metrics to log.
-            step (int): The step of the experiment.
-            epoch (int): The epoch of the experiment.
-            show_in_summary (bool): Show the metric in the summary. Use this to show only important metrics in the summary
-
-        Example:
-            >>> exp.log_metrics({'test/loss': 0.1, 'test/accuracy': 0.9}, show_in_summary=True)
-
-        See Also:
-            :py:meth:`~log_metric`
-        """
-        step = self._set_step(step)
-        epoch = self._set_epoch(epoch)
-
-        # Fix nan values
-        for name, value in metrics.items():
-            if np.isnan(value):
-                _LOGGER.debug(f"Metric {name} has a nan value. Replacing with 'NAN'.")
-                metrics[name] = 'NAN'
-
-        for name, value in metrics.items():
-            spl_name = name.lower().split('test/', maxsplit=1)
-            if spl_name[-1] in METRIC_RENAMER:
-                name = spl_name[0] + 'test/' + METRIC_RENAMER[spl_name[-1]]
-
-            if show_in_summary or name.lower() in IMPORTANT_METRICS:
-                self.add_to_summary({'metrics': {name: value}})
-
-        entry = [{'type': 'metric',
-                  'name': name,
-                  'value': value}
-                 for name, value in metrics.items()]
-
-        for m in entry:
-            if step is not None:
-                m['step'] = step
-            if epoch is not None:
-                m['epoch'] = epoch
-
-        self.apihandler.log_entry(exp_id=self.exp_id,
-                                  entry={'logs': entry})
-
-    def add_to_summary(self,
-                       dic: Dict):
-        for key, value in dic.items():
-            if key not in self.summary_log:
-                self.summary_log[key] = value
-                continue
-            cur_value = self.summary_log[key]
-            if isinstance(value, dict) and isinstance(cur_value, dict):
-                self.summary_log[key].update(value)
-            elif isinstance(value, list) and isinstance(cur_value, list):
-                self.summary_log[key].extend(value)
-            elif isinstance(value, tuple) and isinstance(cur_value, tuple):
-                self.summary_log[key] += value
-            else:
-                _LOGGER.warning(f"Key {key} already exists in summary. Overwriting value.")
-                self.summary_log[key] = value
-
-    def update_summary_metrics(self,
-                               phase: str | None,
-                               f1score: float | None,
-                               accuracy: float | None,
-                               sensitivity: float | None,
-                               ppv: float | None,
-                               ):
-        """
-        Handy method to update the summary with the most common classification metrics.
-
-        Args:
-            phase (str): The phase of the experiment. Can be 'train', 'val', 'test', '', or None.
-            f1score (float): The F1 score.
-            accuracy (float): The accuracy.
-            sensitivity (float): The sensitivity (a.k.a recall).
-            specificity (float): The specificity.
-            ppv (float): The positive predictive value (a.k.a precision).
-        """
-
-        if phase is None:
-            phase = ""
-
-        if phase not in ['train', 'val', 'test', '']:
-            raise ValueError(f"Invalid phase: '{phase}'. Must be one of ['train', 'val', 'test', '']")
-
-        metrics = {}
-        if f1score is not None:
-            metrics[f'{phase}/F1Score'] = f1score
-        if accuracy is not None:
-            metrics[f'{phase}/Accuracy'] = accuracy
-        if sensitivity is not None:
-            metrics[f'{phase}/Sensitivity'] = sensitivity
-        if ppv is not None:
-            metrics[f'{phase}/Positive Predictive Value'] = ppv
-
-        self.add_to_summary({'metrics': metrics})
-
-    def log_summary(self,
-                    result_summary: Dict) -> None:
-        """
-        Log the summary of the experiment. This is what will be shown in the platform summary.
-
-        Args:
-            result_summary (Dict): The summary of the experiment.
-
-        Example:
-            .. code-block:: python
-
-                exp.log_summary({"metrics": {
-                    "test/F1Score": 0.85,
-                    "test/Accuracy": 0.9,
-                    "test/Sensitivity": 0.92,
-                    "test/Positive Predictive Value": 0.79,
-                    }
-                })
-        """
-        _LOGGER.debug(f"Logging summary: {result_summary}")
-        self.apihandler.log_summary(exp_id=self.exp_id,
-                                    result_summary=result_summary)
-
-    def log_model(self,
-                  model: Any | str | IO[bytes],
-                  hyper_params: Optional[Dict] = None,
-                  log_model_attributes: bool = True,
-                  torch_save_kwargs: Dict = {}):
-        """
-        Log the model to the platform.
-
-        Args:
-            model (torch.nn.Module | str | IO[bytes]): The model to log. Can be a torch model, a path to a .pt or .pth file, or a BytesIO object.
-            hyper_params (Optional[Dict]): The hyper-parameters of the model. Arbitrary key-value pairs.
-            log_model_attributes (bool): Adds the attributes of the model to the hyper-parameters.
-            torch_save_kwargs (Dict): Additional arguments to pass to `torch.save`.
-
-        Example:
-            .. code-block:: python
-
-                exp.log_model(model, hyper_params={"num_layers": 3, "pretrained": True})
-
-        """
-        import torch
-        if self.model_id is not None:
-            raise Exception("Model is already logged. Updating the model is not supported.")
-
-        if self.model is None:
-            self.model = model
-            self.model_hyper_params = hyper_params
-
-        if log_model_attributes and isinstance(model, torch.nn.Module):
-            if hyper_params is None:
-                hyper_params = {}
-            hyper_params['__model_classname'] = model.__class__.__name__
-            # get all attributes of the model that are int, float or string
-            for attr_name, attr_value in model.__dict__.items():
-                if attr_name.startswith('_'):
-                    continue
-                if attr_name in ['training']:
-                    continue
-                if isinstance(attr_value, (int, float, str)):
-                    hyper_params[attr_name] = attr_value
-
-        # Add additional useful information
-        if isinstance(model, torch.nn.Module):
-            hyper_params.update({
-                '__num_layers': len(list(model.children())),
-                '__num_parameters': sum(p.numel() for p in model.parameters()),
-            })
-
-        self.model_id = self.apihandler.log_model(exp_id=self.exp_id,
-                                                  model=model,
-                                                  hyper_params=hyper_params,
-                                                  torch_save_kwargs=torch_save_kwargs)['id']
-
-    def _add_finish_callback(self, callback):
-        self.finish_callbacks.append(callback)
-
-    def log_dataset_stats(self, dataset: DatamintDataset,
-                          dataset_entry_name: str = 'default'):
-        """
-        Log the statistics of the dataset.
-
-        Args:
-            dataset (DatamintDataset): The dataset to log the statistics.
-            dataset_entry_name (str): The name of the dataset entry.
-                Used to distinguish between different datasets and dataset splits.
-
-        Example:
-            .. code-block:: python
-
-                dataset = exp.get_dataset(split='train')
-                exp.log_dataset_stats(dataset, dataset_entry_name='train')
-        """
-
-        if dataset_entry_name is None:
-            dataset_entry_name = 'default'
-
-        dataset_stats = {
-            'num_samples': len(dataset),
-            'num_frame_labels': len(dataset.frame_labels_set),
-            'num_segmentation_labels': len(dataset.segmentation_labels_set),
-            'frame_label_distribution': dataset.get_framelabel_distribution(normalize=True),
-            'segmentation_label_distribution': dataset.get_segmentationlabel_distribution(normalize=True),
-        }
-
-        keys_to_get = ['updated_at', 'total_resource']
-        dataset_stats.update({k: v for k, v in dataset.metainfo.items() if k in keys_to_get})
-
-        self.add_to_summary({'dataset_stats': dataset_stats})
-        dataset_params_names = ['return_dicom', 'return_metainfo', 'return_segmentations'
-                                'return_frame_by_frame', 'return_as_semantic_segmentation']
-        dataset_stats['dataset_params'] = {k: getattr(dataset, k) for k in dataset_params_names if hasattr(dataset, k)}
-        dataset_stats['dataset_params']['image_transform'] = repr(dataset.image_transform)
-        dataset_stats['dataset_params']['mask_transform'] = repr(dataset.mask_transform)
-
-        self.apihandler.log_entry(exp_id=self.exp_id,
-                                  entry={'dataset_stats': {dataset_entry_name: dataset_stats}})
-
-    def get_dataset(self, split: str = 'all', **kwargs) -> DatamintDataset:
-        """
-        Get the dataset associated with the experiment's project.
-        The dataset will be downloaded to the directory specified in the constructor (`self.dataset_dir`).
-
-        Args:
-            split (str): The split of the dataset to get. Can be one of ['all', 'train', 'test', 'val'].
-            **kwargs: Additional arguments to pass to the :py:class:`~datamint.dataset.dataset.DatamintDataset` class.
-
-        Returns:
-            DatamintDataset: The dataset object.
-        """
-        if split not in ['all', 'train', 'test', 'val']:
-            raise ValueError(f"Invalid split parameter: '{split}'. Must be one of ['all', 'train', 'test', 'val']")
-
-        params = dict(project_name=self.project_name)
-
-        dataset = DatamintDataset(root=self.dataset_dir,
-                                  api_key=self.apihandler.api_key,
-                                  server_url=self.apihandler.root_url,
-                                  **params,
-                                  **kwargs)
-
-        # infer task
-        if not hasattr(self, 'detected_task') and self.auto_log:
-            self.detected_task = self._detect_machine_learning_task(dataset)
-            self.add_to_summary({'detected_task': self.detected_task})
-
-        if split == 'all':
-            self.log_dataset_stats(dataset, split)
-            return dataset
-
-        # FIXME: samples should be marked as train, test, val previously
-
-        train_split_val = 0.8
-        test_split_val = 0.1
-        indices = list(range(len(dataset)))
-        rs = np.random.RandomState(42)
-        rs.shuffle(indices)
-        train_split_idx = int(train_split_val * len(dataset))
-        test_split_idx = int(np.ceil(test_split_val * len(dataset))) + train_split_idx
-        train_indices = indices[:train_split_idx]
-        test_indices = indices[train_split_idx:test_split_idx]
-        val_indices = indices[test_split_idx:]
-
-        if split == 'train':
-            indices_to_split = train_indices
-        elif split == 'test':
-            indices_to_split = test_indices
-        elif split == 'val':
-            indices_to_split = val_indices
-
-        dataset = dataset.subset(indices_to_split)
-        self.log_dataset_stats(dataset, split)
-        return dataset
-
-    def _detect_machine_learning_task(self, dataset: DatamintDataset) -> str:
-        try:
-            # Detect machine learning task based on the dataset params
-            if dataset.return_as_semantic_segmentation and len(dataset.segmentation_labels_set) > 0:
-                return 'semantic segmentation'
-            elif dataset.return_segmentations and len(dataset.segmentation_labels_set) > 0:
-                return 'instance segmentation'
-
-            num_labels = len(dataset.frame_labels_set)  # FIXME: when not frame by frame
-            num_categories = len(dataset.segmentation_labels_set)
-            if num_categories == 0:
-                if num_labels == 1:
-                    return 'binary classification'
-                elif num_labels > 1:
-                    return 'multilabel classification'
-            elif num_categories == 1:
-                if num_labels == 0:
-                    return 'multiclass classification'
-                return 'multi-task classification'
-            else:
-                return 'multi-task classification'
-        except Exception as e:
-            _LOGGER.warning(f"Could not detect machine learning task: {e}")
-
-        return 'unknown'
-
-    def _log_predictions(self,
-                         predictions: List[Dict[str, Any]],
-                         dataset_split: Optional[str] = None,
-                         step: Optional[int] = None,
-                         epoch: Optional[int] = None):
-        """
-        Log the predictions of the model.
-
-        Args:
-            predictions (List[Dict[str, Any]]): The predictions to log. See example below.
-            step (Optional[int]): The step of the experiment.
-            epoch (Optional[int]): The epoch of the experiment.
-
-
-        Example:
-            .. code-block:: python
-
-                predictions = [
-                    {
-                        'resource_id': '123',
-                        # If not provided, it will be assumed predictions are for the whole resource.
-                        'frame_index': 0,
-                        'predicted': [
-                            {
-                                'identifier': 'has_fracture',
-                                'value': True,  # Optional
-                                'confidence': 0.9,  # Optional
-                                'ground_truth': True  # Optional
-                            },
-                            {
-                                'identifier': 'tumor',
-                                'value': segmentation1,  # Optional. numpy array of shape (H, W)
-                                'confidence': 0.9  # Optional. Can be mask.max()
-                            }
-                        ]
-                    }]
-                exp.log_predictions(predictions)
-        """
-
-        self._set_step(step)
-        self._set_epoch(epoch)
-
-        entry = {'type': 'prediction',
-                 'predictions': predictions,
-                 'dataset_split': dataset_split,
-                 'step': step,
-                 }
-
-        if dataset_split == 'test':
-            for pred in predictions:
-                # if prediction is categorical
-                if pred['prediction_type'] == 'category':
-                    for p in pred['predicted']:
-                        if 'confidence' in p:
-                            self.highest_predictions[p['identifier']].add(pred)
-                            self.lowest_predictions[p['identifier']].add(pred)
-
-        self.apihandler.log_entry(exp_id=self.exp_id,
-                                  entry=entry)
-
-    def log_classification_predictions(self,
-                                       predictions_conf: np.ndarray,
-                                       resource_ids: List[str],
-                                       label_names: List[str],
-                                       dataset_split: Optional[str] = None,
-                                       frame_idxs: Optional[List[int]] = None,
-                                       step: Optional[int] = None,
-                                       epoch: Optional[int] = None,
-                                       add_info: Optional[Dict] = None
-                                       ):
-        """
-        Log the classification predictions of the model.
-
-        Args:
-            predictions_conf (np.ndarray): The predictions of the model. Can have two shapes:
-
-                - Shape (N, C) where N is the number of samples and C is the number of classes.
-                  Does not need to sum to 1 (i.e., can be multilabel).
-                - Shape (N,) where N is the number of samples.
-                  In this case, `label_names` should have the same length as the predictions.
-
-            label_names (List[str]): The names of the classes.
-                If the predictions are shape (N,), this should have the same length as the predictions.
-            resource_ids (List[str]): The resource IDs of the samples.
-            dataset_split (Optional[str]): The dataset split of the predictions.
-            frame_idxs (Optional[List[int]]): The frame indexes of the predictions.
-            step (Optional[int]): The step of the experiment.
-            epoch (Optional[int]): The epoch of the experiment.
-            add_info (Optional[Dict]): Additional information to add to each prediction.
-
-        Example:
-            .. code-block:: python
-
-                predictions_conf = np.array([[0.9, 0.1], [0.2, 0.8]])
-                label_names = ['cat', 'dog']
-                resource_ids = ['123', '456']
-                exp.log_classification_predictions(predictions_conf, label_names, resource_ids, dataset_split='test')
-        """
-
-        # check predictions shape and lengths
-        if len(predictions_conf) != len(resource_ids):
-            raise ValueError("Length of predictions and resource_ids must be the same.")
-
-        if predictions_conf.ndim == 2:
-            if predictions_conf.shape[1] != len(label_names):
-                raise ValueError("Number of classes must match the number of columns in predictions_conf.")
-        elif predictions_conf.ndim == 1:
-            if len(label_names) != len(predictions_conf):
-                raise ValueError("Number of classes must match the length of predictions when predictions are 1D.")
-        else:
-            raise ValueError("Predictions must be 1D or 2D.")
-
-        resources = self.apihandler.get_resources_by_ids(resource_ids)
-
-        predictions = []
-        if predictions_conf.ndim == 2:
-            for res, pred in zip(resources, predictions_conf):
-                data = {'resource_id': res['id'],
-                        'resource_filename': res['filename'],
-                        'prediction_type': 'category',
-                        'predicted': [{'identifier': label_names[i],
-                                       'confidence': float(pred[i])}
-                                      for i in range(len(pred))]}
-                if add_info is not None:
-                    data.update(add_info)
-                predictions.append(data)
-        else:
-            # if predictions are 1D, label_names have the same length
-            for res, pred, label_i in zip(resources, predictions_conf, label_names):
-                data = {'resource_id': res['id'],
-                        'resource_filename': res['filename'],
-                        'prediction_type': 'category',
-                        'predicted': [{'identifier': label_i,
-                                       'confidence': float(pred)}
-                                      ]}
-                if add_info is not None:
-                    data.update(add_info)
-                predictions.append(data)
-
-        if frame_idxs is not None:
-            for pred, frame_idx in zip(predictions, frame_idxs):
-                pred['frame_index'] = frame_idx
-        self._log_predictions(predictions, step=step, epoch=epoch, dataset_split=dataset_split)
-
-    def log_segmentation_predictions(self,
-                                     resource_id: str | dict,
-                                     predictions: np.ndarray | str,
-                                     label_name: str | dict[int, str],
-                                     frame_index: int | list[int] | None = None,
-                                     threshold: float = 0.5,
-                                     predictions_format: Literal['multi-class', 'probability'] = 'probability'
-                                     ):
-        """
-        Log the segmentation prediction of the model for a single frame
-
-        Args:
-            resource_id: The resource ID of the sample.
-            predictions: The predictions of the model. One binary mask for each class. Can be a numpy array of shape (H, W) or (N,H,W);
-                Or a path to a png file; Or a path to a .nii/.nii.gz file.
-            label_name: The name of the class or a dictionary mapping pixel values to names.
-                Example: ``{1: 'Femur', 2: 'Tibia'}`` means that pixel value 1 is 'Femur' and pixel value 2 is 'Tibia'.
-            frame_index: The frame index of the prediction or a list of frame indexes.
-                If a list, must have the same length as the predictions.
-                If None,
-            threshold: The threshold to apply to the predictions.
-            predictions_format: The format of the predictions. Can be a probability mask ('probability') or a multi-class mask ('multi-class').
-
-        Example:
-            .. code-block:: python
-
-                resource_id = '123'
-                predictions = np.array([[0.1, 0.4], [0.9, 0.2]])
-                label_name = 'fracture'
-                exp.log_segmentation_predictions(resource_id, predictions, label_name, threshold=0.5)
-
-            .. code-block:: python
-
-                resource_id = '456'
-                predictions = np.array([[0, 1, 2], [1, 2, 0]])  # Multi-class mask with values 0, 1, 2
-                label_name = {1: 'Femur', 2: 'Tibia'}  # Mapping of pixel values to class names
-                exp.log_segmentation_predictions(
-                    resource_id,
-                    predictions,
-                    label_name,
-                    predictions_format='multi-class'
-                )
-        """
-
-        if predictions_format not in ['multi-class', 'probability']:
-            raise ValueError("predictions_format must be 'multi-class' or 'probability'.")
-
-        if isinstance(label_name, dict) and predictions_format!='multi-class':
-            raise ValueError("If label_name is a dictionary, predictions_format must be 'multi-class'.")
-
-        if isinstance(resource_id, dict):
-            resource_id = resource_id['id']
-
-        if self.model_id is None:
-            raise ValueError("Model is not logged. Cannot log segmentation predictions. see `log_model` method.")
-
-        if isinstance(predictions, str):
-            predictions = io_utils.read_array_normalized(predictions)
-
-        if predictions_format == 'probability':
-            predictions = predictions > threshold
-
-        is_2d_prediction = predictions.ndim == 2
-
-        if predictions.ndim == 4 and predictions.shape[1] == 1:
-            predictions = predictions[:, 0]
-        elif predictions.ndim == 2:
-            predictions = predictions[np.newaxis]
-        elif predictions.ndim != 3:
-            raise ValueError(f"Prediction with shape {predictions.shape} is different than (H, W) and (N,H,W).")
-
-        if frame_index is None:
-            if is_2d_prediction:
-                raise ValueError("frame_index must be provided when predictions is 2D.")
-            frame_index = list(range(predictions.shape[0]))
-        elif isinstance(frame_index, int):
-            frame_index = [frame_index]
-        else:
-            if len(frame_index) != predictions.shape[0]:
-                raise ValueError("Length of frame_index must match the first dimension of predictions.")
-
-        new_ann_id = self.apihandler.upload_segmentations(
-            resource_id=resource_id,
-            file_path=predictions.transpose(1, 2, 0),
-            name=label_name,
-            frame_index=frame_index,
-            model_id=self.model_id,
-            worklist_id=self.project['worklist_id'],
-        )
-
-    def log_semantic_seg_predictions(self,
-                                     predictions: np.ndarray | str,
-                                     resource_ids: Union[list[str], str],
-                                     label_names: list[str],
-                                     dataset_split: Optional[str] = None,
-                                     frame_idxs: Optional[list[int]] = None,
-                                     step: Optional[int] = None,
-                                     epoch: Optional[int] = None,
-                                     threshold: float = 0.5
-                                     ):
-        """
-        Log the semantic segmentation predictions of the model.
-
-        Args:
-            predictions (np.ndarray | str): The predictions of the model. A list of numpy arrays of shape (N, C, H, W).
-                Or a path to a png file; Or a path to a .nii.gz file.
-            label_names (list[str]): The names of the classes. List of strings of size C.
-            resource_ids (list[str]): The resource IDs of the samples.
-            dataset_split (Optional[str]): The dataset split of the predictions.
-            frame_idxs (Optional[list[int]]): The frame indexes of the predictions.
-            step (Optional[int]): The step of the experiment.
-            epoch (Optional[int]): The epoch of the experiment.
-        """
-
-        if isinstance(predictions, str):
-            predictions = io_utils.read_array_normalized(predictions)
-
-        if isinstance(resource_ids, str):
-            resource_ids = [resource_ids] * len(predictions)
-
-        if predictions.ndim != 4:
-            raise ValueError("Predictions must be of shape (N, C, H, W).")
-
-        # check lengths
-        if len(predictions) != len(resource_ids):
-            raise ValueError("Length of predictions and resource_ids must be the same.")
-
-        if frame_idxs is not None:
-            if len(predictions) != len(frame_idxs):
-                raise ValueError("Length of predictions and frame_idxs must be the same.")
-            # non negative frame indexes
-            if any(fidx < 0 for fidx in frame_idxs):
-                raise ValueError("Frame indexes must be non-negative.")
-
-        if len(label_names) != predictions.shape[1]:
-            raise ValueError("Number of classes must match the number of columns in predictions.")
-
-        predictions_conf = predictions.max(axis=(2, 3))  # final shape: (N, C)
-
-        # log it as classification predictions
-        self.log_classification_predictions(predictions_conf=predictions_conf,
-                                            label_names=label_names,
-                                            resource_ids=resource_ids,
-                                            dataset_split=dataset_split,
-                                            frame_idxs=frame_idxs,
-                                            step=step,
-                                            epoch=epoch,
-                                            add_info={'origin': 'semantic segmentation'})
-
-        if self.model_id is not None:
-            _LOGGER.info("Uploading segmentation masks to the platform.")
-            # For each frame
-            predictions = predictions > threshold
-            grouped_predictions = defaultdict(list)
-            for fidx, res_id, pred in zip(frame_idxs, resource_ids, predictions):
-                grouped_predictions[res_id].append((fidx, pred))
-
-            for res_id, list_preds in grouped_predictions.items():
-                frame_idxs = [fidx for fidx, _ in list_preds]
-                preds = np.stack([pred for _, pred in list_preds])
-                for i in range(len(label_names)):
-                    preds_i = preds[:, i]  # get the i-th class predictions
-                    # preds_i.shape: (N, H, W)
-                    new_ann_id = self.apihandler.upload_segmentations(
-                        resource_id=res_id,
-                        file_path=preds_i,
-                        name=label_names[i],
-                        frame_index=frame_idxs,
-                        model_id=self.model_id,
-                        worklist_id=self.project['worklist_id'],
-                    )
-        else:
-            _LOGGER.warning("Model is not logged. Skipping uploading segmentation masks.")
-
-    def finish(self):
-        """
-        Finish the experiment.
-        This will log the summary and finish the experiment.
-        """
-        def _process_toppredictions(top_predictions: Dict[str, TopN], rev: bool) -> Tuple[TopN, Dict]:
-            preds_per_label = {key: values.get_top()
-                               for key, values in top_predictions.items()}
-            # get the highest prediction over all labels
-            preds_combined = TopN(5, key=_get_confidence_callback, reverse=rev)
-            for label_preds in preds_per_label.values():
-                for pred in label_preds:
-                    preds_combined.add(pred)
-
-            return preds_combined, preds_per_label
-
-        if self.is_finished:
-            _LOGGER.debug("Experiment is already finished.")
-            return
-        _LOGGER.info("Finishing experiment")
-        for callback in self.finish_callbacks:
-            callback(self)
-        self.time_finished = datetime.now(timezone.utc)
-        time_spent_seconds = (self.time_finished - self.time_started).total_seconds()
-
-        ### produce finishing summary ###
-        # time spent
-        self.add_to_summary({'time_spent_seconds': time_spent_seconds})
-
-        # add the most interesting predictions
-        if len(self.highest_predictions) > 0:
-            highest_preds_combined, highest_preds_per_label = _process_toppredictions(self.highest_predictions, False)
-            lowest_preds_combined, lowest_preds_per_label = _process_toppredictions(self.lowest_predictions, True)
-
-            self.add_to_summary({'highest_predictions': {'combined': highest_preds_combined.get_top(),
-                                                         'per_label': highest_preds_per_label
-                                                         }
-                                 }
-                                )
-
-            self.add_to_summary({'lowest_predictions': {'combined': lowest_preds_combined.get_top(),
-                                                        'per_label': lowest_preds_per_label
-                                                        }
-                                }
-                               )
-
-        self.log_summary(result_summary=self.summary_log)
-        # if the model is not already logged, log it
-        if self.model is not None and self.model_id is None:
-            self.log_model(model=self.model, hyper_params=self.model_hyper_params)
-        self.apihandler.finish_experiment(self.exp_id)
-        self.is_finished = True
-
-        _LOGGER.info("Experiment finished and uploaded to the platform.")
-
-
-class _LogHistory:
-    """
-    TODO: integrate this with the Experiment class.
-    """
-
-    def __init__(self):
-        self.history = []
-
-    def append(self, dt: datetime = None, **kwargs):
-        if dt is None:
-            dt = datetime.now(timezone.utc)
-        else:
-            if dt.tzinfo is None:
-                _LOGGER.warning("No timezone information provided. Assuming UTC.")
-                dt = dt.replace(tzinfo=timezone.utc)
-
-        item = {
-            # datetime in GMT+0
-            'timestamp': dt.timestamp(),
-            **kwargs
-        }
-        self.history.append(item)
-
-    def get_history(self) -> List[Dict]:
-        return self.history
-
-
-_EXPERIMENT: Experiment = None