datamint 2.3.3__py3-none-any.whl → 2.9.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in the public registry.
Files changed (59)
  1. datamint/__init__.py +1 -3
  2. datamint/api/__init__.py +0 -3
  3. datamint/api/base_api.py +286 -54
  4. datamint/api/client.py +76 -13
  5. datamint/api/endpoints/__init__.py +2 -2
  6. datamint/api/endpoints/annotations_api.py +186 -28
  7. datamint/api/endpoints/deploy_model_api.py +78 -0
  8. datamint/api/endpoints/models_api.py +1 -0
  9. datamint/api/endpoints/projects_api.py +38 -7
  10. datamint/api/endpoints/resources_api.py +227 -100
  11. datamint/api/entity_base_api.py +66 -7
  12. datamint/apihandler/base_api_handler.py +0 -1
  13. datamint/apihandler/dto/annotation_dto.py +2 -0
  14. datamint/client_cmd_tools/datamint_config.py +0 -1
  15. datamint/client_cmd_tools/datamint_upload.py +3 -1
  16. datamint/configs.py +11 -7
  17. datamint/dataset/base_dataset.py +24 -4
  18. datamint/dataset/dataset.py +1 -1
  19. datamint/entities/__init__.py +1 -1
  20. datamint/entities/annotations/__init__.py +13 -0
  21. datamint/entities/{annotation.py → annotations/annotation.py} +81 -47
  22. datamint/entities/annotations/image_classification.py +12 -0
  23. datamint/entities/annotations/image_segmentation.py +252 -0
  24. datamint/entities/annotations/volume_segmentation.py +273 -0
  25. datamint/entities/base_entity.py +100 -6
  26. datamint/entities/cache_manager.py +129 -15
  27. datamint/entities/datasetinfo.py +60 -65
  28. datamint/entities/deployjob.py +18 -0
  29. datamint/entities/project.py +39 -0
  30. datamint/entities/resource.py +310 -46
  31. datamint/lightning/__init__.py +1 -0
  32. datamint/lightning/datamintdatamodule.py +103 -0
  33. datamint/mlflow/__init__.py +65 -0
  34. datamint/mlflow/artifact/__init__.py +1 -0
  35. datamint/mlflow/artifact/datamint_artifacts_repo.py +8 -0
  36. datamint/mlflow/env_utils.py +131 -0
  37. datamint/mlflow/env_vars.py +5 -0
  38. datamint/mlflow/flavors/__init__.py +17 -0
  39. datamint/mlflow/flavors/datamint_flavor.py +150 -0
  40. datamint/mlflow/flavors/model.py +877 -0
  41. datamint/mlflow/lightning/callbacks/__init__.py +1 -0
  42. datamint/mlflow/lightning/callbacks/modelcheckpoint.py +410 -0
  43. datamint/mlflow/models/__init__.py +93 -0
  44. datamint/mlflow/tracking/datamint_store.py +76 -0
  45. datamint/mlflow/tracking/default_experiment.py +27 -0
  46. datamint/mlflow/tracking/fluent.py +91 -0
  47. datamint/utils/env.py +27 -0
  48. datamint/utils/visualization.py +21 -13
  49. datamint-2.9.0.dist-info/METADATA +220 -0
  50. datamint-2.9.0.dist-info/RECORD +73 -0
  51. {datamint-2.3.3.dist-info → datamint-2.9.0.dist-info}/WHEEL +1 -1
  52. datamint-2.9.0.dist-info/entry_points.txt +18 -0
  53. datamint/apihandler/exp_api_handler.py +0 -204
  54. datamint/experiment/__init__.py +0 -1
  55. datamint/experiment/_patcher.py +0 -570
  56. datamint/experiment/experiment.py +0 -1049
  57. datamint-2.3.3.dist-info/METADATA +0 -125
  58. datamint-2.3.3.dist-info/RECORD +0 -54
  59. datamint-2.3.3.dist-info/entry_points.txt +0 -4
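Note: the removal of datamint/experiment/ (including the Experiment class shown in the hunk below) together with the new datamint/mlflow/ package and the expanded entry_points.txt suggests that experiment tracking moved to an MLflow-based integration in 2.9.0. The datamint-specific wiring (tracking store, model flavor, Lightning checkpoint callback) is not visible in this diff, so the sketch below only illustrates the standard MLflow fluent API that such an integration would build on; the experiment name, run name, parameters, and metric names are made up for illustration.

    # Minimal sketch of MLflow's fluent logging API (not datamint-specific).
    import mlflow

    mlflow.set_experiment("my-experiment")  # hypothetical experiment name

    with mlflow.start_run(run_name="baseline"):  # hypothetical run name
        mlflow.log_params({"lr": 1e-3, "epochs": 10})  # hyper-parameters
        for epoch in range(10):
            # per-step metric, analogous to the removed Experiment.log_metric
            mlflow.log_metric("train/loss", 0.1 * (10 - epoch), step=epoch)
        # final metrics, analogous to the removed Experiment.log_summary
        mlflow.log_metrics({"test/accuracy": 0.9, "test/f1": 0.85})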
datamint/experiment/experiment.py (file removed)
@@ -1,1049 +0,0 @@
- import logging
- from datamint.apihandler.api_handler import APIHandler
- from datamint.apihandler.base_api_handler import DatamintException
- from datetime import datetime, timezone
- from typing import List, Dict, Optional, Union, Any, Tuple, IO, Literal
- from collections import defaultdict
- from datamint.dataset.dataset import DatamintDataset
- import os
- import numpy as np
- import heapq
- from datamint.utils import io_utils
-
- _LOGGER = logging.getLogger(__name__)
-
-
- IMPORTANT_METRICS = ['Accuracy', 'Precision', 'Recall', 'F1score', 'Positive Predictive Value', 'Sensitivity']
- IMPORTANT_METRICS = ['test/'+m.lower() for m in IMPORTANT_METRICS]
- METRIC_RENAMER = {
-     'precision': 'Positive Predictive Value',
-     'recall': 'Sensitivity',
- }
-
-
- class TopN:
-     class _Item:
-         def __init__(self, key, item):
-             self.key = key
-             self.item = item
-
-         def __lt__(self, other):
-             return self.key < other.key
-
-         def __eq__(self, other):
-             return self.key == other
-
-     def __init__(self, N, key=lambda x: x, reverse=False):
-         self.N = N
-         self.reverse = reverse
-         self.key = key
-         self.heap = []
-
-     def add(self, item):
-         item_key = float(self.key(item))
-         if self.reverse:
-             item_key = -item_key  # Invert the key to keep the lowest ones
-         if len(self.heap) < self.N:
-             heapq.heappush(self.heap, TopN._Item(item_key, item))
-         else:
-             heapq.heappushpop(self.heap, TopN._Item(item_key, item))
-
-     def __len__(self):
-         return len(self.heap)
-
-     def get_top(self) -> list:
-         sorted_items = sorted(self.heap, key=lambda x: x.key, reverse=True)
-         return [item.item for item in sorted_items]
-
-
- class _DryRunExperimentAPIHandler(APIHandler):
-     """
-     Dry-run implementation of the ExperimentAPIHandler.
-     No data will be uploaded to the platform.
-     """
-
-     def __init__(self, *args, **kwargs):
-         super().__init__(*args, check_connection=False, **kwargs)
-
-     def create_experiment(self, dataset_id: str, name: str, description: str, environment: Dict) -> str:
-         return "dry_run"
-
-     def log_entry(self, exp_id: str, entry: Dict):
-         pass
-
-     def log_summary(self, exp_id: str, result_summary: Dict):
-         pass
-
-     def log_model(self, exp_id: str, *args, **kwargs):
-         return {'id': 'dry_run'}
-
-     def finish_experiment(self, exp_id: str):
-         pass
-
-     def upload_segmentations(self, *args, **kwargs) -> str:
-         return "dry_run"
-
-
- def _get_confidence_callback(pred) -> float:
-     return pred['predicted'][0]['confidence']
-
-
- class Experiment:
-     """
-     Experiment class to log metrics, models, and other information to the platform.
-
-     Args:
-         name (str): Name of the experiment.
-         project_name (str): Name of the project.
-         description (str): Description of the experiment.
-         api_key (str): API key for the platform.
-         root_url (str): Root URL of the platform.
-         dataset_dir (str): Directory to store the datasets.
-         log_enviroment (bool): Log the enviroment information.
-         dry_run (bool): Run in dry-run mode. No data will be uploaded to the platform
-         auto_log (bool): Automatically log the experiment using patching mechanisms.
-         tags (List[str]): Tags to add to the experiment.
-     """
-
-     DATAMINT_DEFAULT_DIR = ".datamint"
-     DATAMINT_DATASETS_DIR = 'datasets'
-
-     def __init__(self,
-                  name: str,
-                  project_name: Optional[str] = None,
-                  description: Optional[str] = None,
-                  api_key: Optional[str] = None,
-                  root_url: Optional[str] = None,
-                  dataset_dir: Optional[str] = None,
-                  log_enviroment: bool = True,
-                  dry_run: bool = False,
-                  auto_log=True,
-                  tags: Optional[List[str]] = None,
-                  allow_existing: bool = False
-                  ) -> None:
-         import torch
-         from ._patcher import initialize_automatic_logging
-         if auto_log:
-             initialize_automatic_logging()
-         self.auto_log = auto_log
-         self.name = name
-         self.dry_run = dry_run
-         if dry_run:
-             self.apihandler = _DryRunExperimentAPIHandler(api_key=api_key, root_url=root_url)
-             _LOGGER.warning("Running in dry-run mode. No data will be uploaded to the platform.")
-         else:
-             self.apihandler = APIHandler(api_key=api_key, root_url=root_url)
-         self.cur_step = None
-         self.cur_epoch = None
-         self.summary_log = defaultdict(dict)
-         self.finish_callbacks = []
-         self.model: torch.nn.Module = None
-         self.model_id = None
-         self.model_hyper_params = None
-         self.is_finished = False
-         self.log_enviroment = log_enviroment
-
-         if dataset_dir is None:
-             # store them in the home directory
-             dataset_dir = os.path.join(os.path.expanduser("~"),
-                                        Experiment.DATAMINT_DEFAULT_DIR)
-             dataset_dir = os.path.join(dataset_dir, Experiment.DATAMINT_DATASETS_DIR)
-
-         if not os.path.exists(dataset_dir):
-             os.makedirs(dataset_dir)
-         self.dataset_dir = dataset_dir
-
-         self.project = self.apihandler.get_project_by_name(project_name)
-         if 'error' in self.project:
-             raise DatamintException(str(self.project))
-         exp_info = self.apihandler.get_experiment_by_name(name, self.project)
-
-         self.project_name = self.project['name']
-         dataset_info = self.apihandler.get_dataset_by_id(self.project['dataset_id'])
-         self.dataset_id = dataset_info['id']
-         self.dataset_info = dataset_info
-
-         if exp_info is None:
-             self._initialize_new_exp(project_name, name, description, tags, log_enviroment)
-         else:
-             if not allow_existing:
-                 raise DatamintException(f"Experiment with name '{name}' already exists for project '{project_name}'.")
-             self._init_from_existing_experiment(project=self.project, exp=exp_info)
-
-         self.time_finished = None
-
-         self.highest_predictions = defaultdict(lambda: TopN(5, key=_get_confidence_callback, reverse=False))
-         self.lowest_predictions = defaultdict(lambda: TopN(5, key=_get_confidence_callback, reverse=True))
-
-         Experiment._set_singleton_experiment(self)
-
-     def _initialize_new_exp(self,
-                             project: Dict,
-                             name: str,
-                             description: str,
-                             tags: Optional[List[str]] = None,
-                             log_enviroment: bool = True):
-         env_info = Experiment.get_enviroment_info() if log_enviroment else {}
-         self.exp_id = self.apihandler.create_experiment(dataset_id=self.dataset_id,
-                                                         name=name,
-                                                         description=description,
-                                                         environment=env_info)
-         self.time_started = datetime.now(timezone.utc)  # FIXME: use created_at field from response
-         if tags is not None:
-             self.apihandler.log_entry(exp_id=self.exp_id,
-                                       entry={'tags': list(tags)})
-
-     def _init_from_existing_experiment(self, project: Dict, exp: Dict):
-         self.exp_id = exp['id']
-
-         # raise error if the experiment is already finished
-         if exp['completed_at'] is not None:
-             project_name = project["name"]
-             raise DatamintException(f"Experiment '{self.name}' from project '{project_name}' is already finished.")
-
-         # example of `exp['created_at']`: 2024-11-01T19:26:12.239Z
-         # example 2: 2024-11-14T17:47:22.363452-03:00
-         self.time_started = datetime.fromisoformat(exp['created_at'].replace('Z', '+00:00'))
-
-     @staticmethod
-     def get_enviroment_info() -> Dict[str, Any]:
-         """
-         Get the enviroment information of the machine such as OS, Python version, etc.
-
-         Returns:
-             Dict: Enviroment information.
-         """
-         import platform
-         import torchvision
-         import psutil
-         import socket
-         import torch
-
-         # find all ip address, removing localhost
-         ip_addresses = [addr.address for iface in psutil.net_if_addrs().values()
-                         for addr in iface if addr.family == socket.AF_INET and not addr.address.startswith('127.0.')]
-         ip_addresses = list(set(ip_addresses))
-         if len(ip_addresses) == 1:
-             ip_addresses = ip_addresses[0]
-
-         # Get the enviroment and machine information, such as OS, Python version, machine name, RAM size, etc.
-         env = {
-             'python_version': platform.python_version(),
-             'torch_version': torch.__version__,
-             'torchvision_version': torchvision.__version__,
-             'numpy_version': np.__version__,
-             'os': platform.system(),
-             'os_version': platform.version(),
-             'os_name': platform.system(),
-             'machine_name': platform.node(),
-             'cpu': platform.processor(),
-             'ram_gb': psutil.virtual_memory().total / (1024. ** 3),
-             'gpu': torch.cuda.get_device_name(0) if torch.cuda.is_available() else None,
-             'gpu_count': torch.cuda.device_count(),
-             'gpu_memory': torch.cuda.get_device_properties(0).total_memory / (1024. ** 3) if torch.cuda.is_available() else None,
-             'processor_count': os.cpu_count(),
-             'processor_name': platform.processor(),
-             'hostname': os.uname().nodename,
-             'ip_address': ip_addresses,
-         }
-
-         return env
-
-     def set_model(self,
-                   model,
-                   hyper_params: Optional[Dict] = None):
-         """
-         Set the model and hyper-parameters of the experiment.
-
-         Args:
-             model (torch.nn.Module): The model to log.
-             hyper_params (Optional[Dict]): The hyper-parameters of the model.
-         """
-         self.model = model
-         self.model_hyper_params = hyper_params
-
-     @staticmethod
-     def _get_dataset_info(apihandler: APIHandler,
-                           dataset_id,
-                           project_name: str) -> Dict:
-         if project_name is not None:
-             project = apihandler.get_project_by_name(project_name)
-             if 'error' in project:
-                 raise ValueError(str(project))
-             dataset_id = project['dataset_id']
-
-         if dataset_id is None:
-             raise ValueError("Either project_name or dataset_id must be provided.")
-
-         return apihandler.get_dataset_by_id(dataset_id)
-
-     @staticmethod
-     def get_singleton_experiment() -> 'Experiment':
-         global _EXPERIMENT
-         return _EXPERIMENT
-
-     @staticmethod
-     def _set_singleton_experiment(experiment: 'Experiment'):
-         global _EXPERIMENT
-         if _EXPERIMENT is not None:
-             _LOGGER.warning(
-                 "There is already an active Experiment. Setting a new Experiment will overwrite the existing one."
-             )
-
-         _EXPERIMENT = experiment
-
-     def _set_step(self, step: Optional[int]) -> int:
-         """
-         Set the current step of the experiment and return it.
-         If step is None, return the current step.
-         """
-         if step is not None:
-             self.cur_step = step
-         return self.cur_step
-
-     def _set_epoch(self, epoch: Optional[int]) -> int:
-         """
-         Set the current epoch of the experiment and return it.
-         If epoch is None, return the current epoch.
-         """
-         if epoch is not None:
-             _LOGGER.debug(f"Setting current epoch to {epoch}")
-             self.cur_epoch = epoch
-         return self.cur_epoch
-
-     def log_metric(self,
-                    name: str,
-                    value: float,
-                    step: int = None,
-                    epoch: int = None,
-                    show_in_summary: bool = False) -> None:
-         """
-         Log a metric to the platform.
-
-         Args:
-             name (str): Arbritary name of the metric.
-             value (float): Value of the metric.
-             step (int): The step of the experiment.
-             epoch (int): The epoch of the experiment.
-             show_in_summary (bool): Show the metric in the summary. Use this to show only important metrics in the summary.
-
-         Example:
-             >>> exp.log_metric('test/sensitivity', 0.9, show_in_summary=True)
-
-         See Also:
-             :py:meth:`~log_metrics`
-         """
-         self.log_metrics({name: value},
-                          step=step,
-                          epoch=epoch,
-                          show_in_summary=show_in_summary)
-
-     def log_metrics(self,
-                     metrics: Dict[str, float],
-                     step=None,
-                     epoch=None,
-                     show_in_summary: bool = False) -> None:
-         """
-         Log multiple metrics to the platform. Handy for logging multiple metrics at once.
-
-         Args:
-             metrics (Dict[str, float]): A dictionary of metrics to log.
-             step (int): The step of the experiment.
-             epoch (int): The epoch of the experiment.
-             show_in_summary (bool): Show the metric in the summary. Use this to show only important metrics in the summary
-
-         Example:
-             >>> exp.log_metrics({'test/loss': 0.1, 'test/accuracy': 0.9}, show_in_summary=True)
-
-         See Also:
-             :py:meth:`~log_metric`
-         """
-         step = self._set_step(step)
-         epoch = self._set_epoch(epoch)
-
-         # Fix nan values
-         for name, value in metrics.items():
-             if np.isnan(value):
-                 _LOGGER.debug(f"Metric {name} has a nan value. Replacing with 'NAN'.")
-                 metrics[name] = 'NAN'
-
-         for name, value in metrics.items():
-             spl_name = name.lower().split('test/', maxsplit=1)
-             if spl_name[-1] in METRIC_RENAMER:
-                 name = spl_name[0] + 'test/' + METRIC_RENAMER[spl_name[-1]]
-
-             if show_in_summary or name.lower() in IMPORTANT_METRICS:
-                 self.add_to_summary({'metrics': {name: value}})
-
-         entry = [{'type': 'metric',
-                   'name': name,
-                   'value': value}
-                  for name, value in metrics.items()]
-
-         for m in entry:
-             if step is not None:
-                 m['step'] = step
-             if epoch is not None:
-                 m['epoch'] = epoch
-
-         self.apihandler.log_entry(exp_id=self.exp_id,
-                                   entry={'logs': entry})
-
-     def add_to_summary(self,
-                        dic: Dict):
-         for key, value in dic.items():
-             if key not in self.summary_log:
-                 self.summary_log[key] = value
-                 continue
-             cur_value = self.summary_log[key]
-             if isinstance(value, dict) and isinstance(cur_value, dict):
-                 self.summary_log[key].update(value)
-             elif isinstance(value, list) and isinstance(cur_value, list):
-                 self.summary_log[key].extend(value)
-             elif isinstance(value, tuple) and isinstance(cur_value, tuple):
-                 self.summary_log[key] += value
-             else:
-                 _LOGGER.warning(f"Key {key} already exists in summary. Overwriting value.")
-                 self.summary_log[key] = value
-
-     def update_summary_metrics(self,
-                                phase: str | None,
-                                f1score: float | None,
-                                accuracy: float | None,
-                                sensitivity: float | None,
-                                ppv: float | None,
-                                ):
-         """
-         Handy method to update the summary with the most common classification metrics.
-
-         Args:
-             phase (str): The phase of the experiment. Can be 'train', 'val', 'test', '', or None.
-             f1score (float): The F1 score.
-             accuracy (float): The accuracy.
-             sensitivity (float): The sensitivity (a.k.a recall).
-             specificity (float): The specificity.
-             ppv (float): The positive predictive value (a.k.a precision).
-         """
-
-         if phase is None:
-             phase = ""
-
-         if phase not in ['train', 'val', 'test', '']:
-             raise ValueError(f"Invalid phase: '{phase}'. Must be one of ['train', 'val', 'test', '']")
-
-         metrics = {}
-         if f1score is not None:
-             metrics[f'{phase}/F1Score'] = f1score
-         if accuracy is not None:
-             metrics[f'{phase}/Accuracy'] = accuracy
-         if sensitivity is not None:
-             metrics[f'{phase}/Sensitivity'] = sensitivity
-         if ppv is not None:
-             metrics[f'{phase}/Positive Predictive Value'] = ppv
-
-         self.add_to_summary({'metrics': metrics})
-
-     def log_summary(self,
-                     result_summary: Dict) -> None:
-         """
-         Log the summary of the experiment. This is what will be shown in the platform summary.
-
-         Args:
-             result_summary (Dict): The summary of the experiment.
-
-         Example:
-             .. code-block:: python
-
-                 exp.log_summary({"metrics": {
-                     "test/F1Score": 0.85,
-                     "test/Accuracy": 0.9,
-                     "test/Sensitivity": 0.92,
-                     "test/Positive Predictive Value": 0.79,
-                 }
-                 })
-         """
-         _LOGGER.debug(f"Logging summary: {result_summary}")
-         self.apihandler.log_summary(exp_id=self.exp_id,
-                                     result_summary=result_summary)
-
-     def log_model(self,
-                   model: Any | str | IO[bytes],
-                   hyper_params: Optional[Dict] = None,
-                   log_model_attributes: bool = True,
-                   torch_save_kwargs: Dict = {}):
-         """
-         Log the model to the platform.
-
-         Args:
-             model (torch.nn.Module | str | IO[bytes]): The model to log. Can be a torch model, a path to a .pt or .pth file, or a BytesIO object.
-             hyper_params (Optional[Dict]): The hyper-parameters of the model. Arbitrary key-value pairs.
-             log_model_attributes (bool): Adds the attributes of the model to the hyper-parameters.
-             torch_save_kwargs (Dict): Additional arguments to pass to `torch.save`.
-
-         Example:
-             .. code-block:: python
-
-                 exp.log_model(model, hyper_params={"num_layers": 3, "pretrained": True})
-
-         """
-         import torch
-         if self.model_id is not None:
-             raise Exception("Model is already logged. Updating the model is not supported.")
-
-         if self.model is None:
-             self.model = model
-             self.model_hyper_params = hyper_params
-
-         if log_model_attributes and isinstance(model, torch.nn.Module):
-             if hyper_params is None:
-                 hyper_params = {}
-             hyper_params['__model_classname'] = model.__class__.__name__
-             # get all attributes of the model that are int, float or string
-             for attr_name, attr_value in model.__dict__.items():
-                 if attr_name.startswith('_'):
-                     continue
-                 if attr_name in ['training']:
-                     continue
-                 if isinstance(attr_value, (int, float, str)):
-                     hyper_params[attr_name] = attr_value
-
-             # Add additional useful information
-             if isinstance(model, torch.nn.Module):
-                 hyper_params.update({
-                     '__num_layers': len(list(model.children())),
-                     '__num_parameters': sum(p.numel() for p in model.parameters()),
-                 })
-
-         self.model_id = self.apihandler.log_model(exp_id=self.exp_id,
-                                                   model=model,
-                                                   hyper_params=hyper_params,
-                                                   torch_save_kwargs=torch_save_kwargs)['id']
-
-     def _add_finish_callback(self, callback):
-         self.finish_callbacks.append(callback)
-
-     def log_dataset_stats(self, dataset: DatamintDataset,
-                           dataset_entry_name: str = 'default'):
-         """
-         Log the statistics of the dataset.
-
-         Args:
-             dataset (DatamintDataset): The dataset to log the statistics.
-             dataset_entry_name (str): The name of the dataset entry.
-                 Used to distinguish between different datasets and dataset splits.
-
-         Example:
-             .. code-block:: python
-
-                 dataset = exp.get_dataset(split='train')
-                 exp.log_dataset_stats(dataset, dataset_entry_name='train')
-         """
-
-         if dataset_entry_name is None:
-             dataset_entry_name = 'default'
-
-         dataset_stats = {
-             'num_samples': len(dataset),
-             'num_frame_labels': len(dataset.frame_labels_set),
-             'num_segmentation_labels': len(dataset.segmentation_labels_set),
-             'frame_label_distribution': dataset.get_framelabel_distribution(normalize=True),
-             'segmentation_label_distribution': dataset.get_segmentationlabel_distribution(normalize=True),
-         }
-
-         keys_to_get = ['updated_at', 'total_resource']
-         dataset_stats.update({k: v for k, v in dataset.metainfo.items() if k in keys_to_get})
-
-         self.add_to_summary({'dataset_stats': dataset_stats})
-         dataset_params_names = ['return_dicom', 'return_metainfo', 'return_segmentations'
-                                 'return_frame_by_frame', 'return_as_semantic_segmentation']
-         dataset_stats['dataset_params'] = {k: getattr(dataset, k) for k in dataset_params_names if hasattr(dataset, k)}
-         dataset_stats['dataset_params']['image_transform'] = repr(dataset.image_transform)
-         dataset_stats['dataset_params']['mask_transform'] = repr(dataset.mask_transform)
-
-         self.apihandler.log_entry(exp_id=self.exp_id,
-                                   entry={'dataset_stats': {dataset_entry_name: dataset_stats}})
-
-     def get_dataset(self, split: str = 'all', **kwargs) -> DatamintDataset:
-         """
-         Get the dataset associated with the experiment's project.
-         The dataset will be downloaded to the directory specified in the constructor (`self.dataset_dir`).
-
-         Args:
-             split (str): The split of the dataset to get. Can be one of ['all', 'train', 'test', 'val'].
-             **kwargs: Additional arguments to pass to the :py:class:`~datamint.dataset.dataset.DatamintDataset` class.
-
-         Returns:
-             DatamintDataset: The dataset object.
-         """
-         if split not in ['all', 'train', 'test', 'val']:
-             raise ValueError(f"Invalid split parameter: '{split}'. Must be one of ['all', 'train', 'test', 'val']")
-
-         params = dict(project_name=self.project_name)
-
-         dataset = DatamintDataset(root=self.dataset_dir,
-                                   api_key=self.apihandler.api_key,
-                                   server_url=self.apihandler.root_url,
-                                   **params,
-                                   **kwargs)
-
-         # infer task
-         if not hasattr(self, 'detected_task') and self.auto_log:
-             self.detected_task = self._detect_machine_learning_task(dataset)
-             self.add_to_summary({'detected_task': self.detected_task})
-
-         if split == 'all':
-             self.log_dataset_stats(dataset, split)
-             return dataset
-
-         # FIXME: samples should be marked as train, test, val previously
-
-         train_split_val = 0.8
-         test_split_val = 0.1
-         indices = list(range(len(dataset)))
-         rs = np.random.RandomState(42)
-         rs.shuffle(indices)
-         train_split_idx = int(train_split_val * len(dataset))
-         test_split_idx = int(np.ceil(test_split_val * len(dataset))) + train_split_idx
-         train_indices = indices[:train_split_idx]
-         test_indices = indices[train_split_idx:test_split_idx]
-         val_indices = indices[test_split_idx:]
-
-         if split == 'train':
-             indices_to_split = train_indices
-         elif split == 'test':
-             indices_to_split = test_indices
-         elif split == 'val':
-             indices_to_split = val_indices
-
-         dataset = dataset.subset(indices_to_split)
-         self.log_dataset_stats(dataset, split)
-         return dataset
-
-     def _detect_machine_learning_task(self, dataset: DatamintDataset) -> str:
-         try:
-             # Detect machine learning task based on the dataset params
-             if dataset.return_as_semantic_segmentation and len(dataset.segmentation_labels_set) > 0:
-                 return 'semantic segmentation'
-             elif dataset.return_segmentations and len(dataset.segmentation_labels_set) > 0:
-                 return 'instance segmentation'
-
-             num_labels = len(dataset.frame_labels_set)  # FIXME: when not frame by frame
-             num_categories = len(dataset.segmentation_labels_set)
-             if num_categories == 0:
-                 if num_labels == 1:
-                     return 'binary classification'
-                 elif num_labels > 1:
-                     return 'multilabel classification'
-             elif num_categories == 1:
-                 if num_labels == 0:
-                     return 'multiclass classification'
-                 return 'multi-task classification'
-             else:
-                 return 'multi-task classification'
-         except Exception as e:
-             _LOGGER.warning(f"Could not detect machine learning task: {e}")
-
-         return 'unknown'
-
-     def _log_predictions(self,
-                          predictions: List[Dict[str, Any]],
-                          dataset_split: Optional[str] = None,
-                          step: Optional[int] = None,
-                          epoch: Optional[int] = None):
-         """
-         Log the predictions of the model.
-
-         Args:
-             predictions (List[Dict[str, Any]]): The predictions to log. See example below.
-             step (Optional[int]): The step of the experiment.
-             epoch (Optional[int]): The epoch of the experiment.
-
-
-         Example:
-             .. code-block:: python
-
-                 predictions = [
-                     {
-                         'resource_id': '123',
-                         # If not provided, it will be assumed predictions are for the whole resource.
-                         'frame_index': 0,
-                         'predicted': [
-                             {
-                                 'identifier': 'has_fracture',
-                                 'value': True,  # Optional
-                                 'confidence': 0.9,  # Optional
-                                 'ground_truth': True  # Optional
-                             },
-                             {
-                                 'identifier': 'tumor',
-                                 'value': segmentation1,  # Optional. numpy array of shape (H, W)
-                                 'confidence': 0.9  # Optional. Can be mask.max()
-                             }
-                         ]
-                     }]
-                 exp.log_predictions(predictions)
-         """
-
-         self._set_step(step)
-         self._set_epoch(epoch)
-
-         entry = {'type': 'prediction',
-                  'predictions': predictions,
-                  'dataset_split': dataset_split,
-                  'step': step,
-                  }
-
-         if dataset_split == 'test':
-             for pred in predictions:
-                 # if prediction is categorical
-                 if pred['prediction_type'] == 'category':
-                     for p in pred['predicted']:
-                         if 'confidence' in p:
-                             self.highest_predictions[p['identifier']].add(pred)
-                             self.lowest_predictions[p['identifier']].add(pred)
-
-         self.apihandler.log_entry(exp_id=self.exp_id,
-                                   entry=entry)
-
-     def log_classification_predictions(self,
-                                        predictions_conf: np.ndarray,
-                                        resource_ids: List[str],
-                                        label_names: List[str],
-                                        dataset_split: Optional[str] = None,
-                                        frame_idxs: Optional[List[int]] = None,
-                                        step: Optional[int] = None,
-                                        epoch: Optional[int] = None,
-                                        add_info: Optional[Dict] = None
-                                        ):
-         """
-         Log the classification predictions of the model.
-
-         Args:
-             predictions_conf (np.ndarray): The predictions of the model. Can have two shapes:
-
-                 - Shape (N, C) where N is the number of samples and C is the number of classes.
-                   Does not need to sum to 1 (i.e., can be multilabel).
-                 - Shape (N,) where N is the number of samples.
-                   In this case, `label_names` should have the same length as the predictions.
-
-             label_names (List[str]): The names of the classes.
-                 If the predictions are shape (N,), this should have the same length as the predictions.
-             resource_ids (List[str]): The resource IDs of the samples.
-             dataset_split (Optional[str]): The dataset split of the predictions.
-             frame_idxs (Optional[List[int]]): The frame indexes of the predictions.
-             step (Optional[int]): The step of the experiment.
-             epoch (Optional[int]): The epoch of the experiment.
-             add_info (Optional[Dict]): Additional information to add to each prediction.
-
-         Example:
-             .. code-block:: python
-
-                 predictions_conf = np.array([[0.9, 0.1], [0.2, 0.8]])
-                 label_names = ['cat', 'dog']
-                 resource_ids = ['123', '456']
-                 exp.log_classification_predictions(predictions_conf, label_names, resource_ids, dataset_split='test')
-         """
-
-         # check predictions shape and lengths
-         if len(predictions_conf) != len(resource_ids):
-             raise ValueError("Length of predictions and resource_ids must be the same.")
-
-         if predictions_conf.ndim == 2:
-             if predictions_conf.shape[1] != len(label_names):
-                 raise ValueError("Number of classes must match the number of columns in predictions_conf.")
-         elif predictions_conf.ndim == 1:
-             if len(label_names) != len(predictions_conf):
-                 raise ValueError("Number of classes must match the length of predictions when predictions are 1D.")
-         else:
-             raise ValueError("Predictions must be 1D or 2D.")
-
-         resources = self.apihandler.get_resources_by_ids(resource_ids)
-
-         predictions = []
-         if predictions_conf.ndim == 2:
-             for res, pred in zip(resources, predictions_conf):
-                 data = {'resource_id': res['id'],
-                         'resource_filename': res['filename'],
-                         'prediction_type': 'category',
-                         'predicted': [{'identifier': label_names[i],
-                                        'confidence': float(pred[i])}
-                                       for i in range(len(pred))]}
-                 if add_info is not None:
-                     data.update(add_info)
-                 predictions.append(data)
-         else:
-             # if predictions are 1D, label_names have the same length
-             for res, pred, label_i in zip(resources, predictions_conf, label_names):
-                 data = {'resource_id': res['id'],
-                         'resource_filename': res['filename'],
-                         'prediction_type': 'category',
-                         'predicted': [{'identifier': label_i,
-                                        'confidence': float(pred)}
-                                       ]}
-                 if add_info is not None:
-                     data.update(add_info)
-                 predictions.append(data)
-
-         if frame_idxs is not None:
-             for pred, frame_idx in zip(predictions, frame_idxs):
-                 pred['frame_index'] = frame_idx
-         self._log_predictions(predictions, step=step, epoch=epoch, dataset_split=dataset_split)
-
-     def log_segmentation_predictions(self,
-                                      resource_id: str | dict,
-                                      predictions: np.ndarray | str,
-                                      label_name: str | dict[int, str],
-                                      frame_index: int | list[int] | None = None,
-                                      threshold: float = 0.5,
-                                      predictions_format: Literal['multi-class', 'probability'] = 'probability'
-                                      ):
-         """
-         Log the segmentation prediction of the model for a single frame
-
-         Args:
-             resource_id: The resource ID of the sample.
-             predictions: The predictions of the model. One binary mask for each class. Can be a numpy array of shape (H, W) or (N,H,W);
-                 Or a path to a png file; Or a path to a .nii/.nii.gz file.
-             label_name: The name of the class or a dictionary mapping pixel values to names.
-                 Example: ``{1: 'Femur', 2: 'Tibia'}`` means that pixel value 1 is 'Femur' and pixel value 2 is 'Tibia'.
-             frame_index: The frame index of the prediction or a list of frame indexes.
-                 If a list, must have the same length as the predictions.
-                 If None,
-             threshold: The threshold to apply to the predictions.
-             predictions_format: The format of the predictions. Can be a probability mask ('probability') or a multi-class mask ('multi-class').
-
-         Example:
-             .. code-block:: python
-
-                 resource_id = '123'
-                 predictions = np.array([[0.1, 0.4], [0.9, 0.2]])
-                 label_name = 'fracture'
-                 exp.log_segmentation_predictions(resource_id, predictions, label_name, threshold=0.5)
-
-             .. code-block:: python
-
-                 resource_id = '456'
-                 predictions = np.array([[0, 1, 2], [1, 2, 0]])  # Multi-class mask with values 0, 1, 2
-                 label_name = {1: 'Femur', 2: 'Tibia'}  # Mapping of pixel values to class names
-                 exp.log_segmentation_predictions(
-                     resource_id,
-                     predictions,
-                     label_name,
-                     predictions_format='multi-class'
-                 )
-         """
-
-         if predictions_format not in ['multi-class', 'probability']:
-             raise ValueError("predictions_format must be 'multi-class' or 'probability'.")
-
-         if isinstance(label_name, dict) and predictions_format != 'multi-class':
-             raise ValueError("If label_name is a dictionary, predictions_format must be 'multi-class'.")
-
-         if isinstance(resource_id, dict):
-             resource_id = resource_id['id']
-
-         if self.model_id is None:
-             raise ValueError("Model is not logged. Cannot log segmentation predictions. see `log_model` method.")
-
-         if isinstance(predictions, str):
-             predictions = io_utils.read_array_normalized(predictions)
-
-         if predictions_format == 'probability':
-             predictions = predictions > threshold
-
-         is_2d_prediction = predictions.ndim == 2
-
-         if predictions.ndim == 4 and predictions.shape[1] == 1:
-             predictions = predictions[:, 0]
-         elif predictions.ndim == 2:
-             predictions = predictions[np.newaxis]
-         elif predictions.ndim != 3:
-             raise ValueError(f"Prediction with shape {predictions.shape} is different than (H, W) and (N,H,W).")
-
-         if frame_index is None:
-             if is_2d_prediction:
-                 raise ValueError("frame_index must be provided when predictions is 2D.")
-             frame_index = list(range(predictions.shape[0]))
-         elif isinstance(frame_index, int):
-             frame_index = [frame_index]
-         else:
-             if len(frame_index) != predictions.shape[0]:
-                 raise ValueError("Length of frame_index must match the first dimension of predictions.")
-
-         new_ann_id = self.apihandler.upload_segmentations(
-             resource_id=resource_id,
-             file_path=predictions.transpose(1, 2, 0),
-             name=label_name,
-             frame_index=frame_index,
-             model_id=self.model_id,
-             worklist_id=self.project['worklist_id'],
-         )
-
-     def log_semantic_seg_predictions(self,
-                                      predictions: np.ndarray | str,
-                                      resource_ids: Union[list[str], str],
-                                      label_names: list[str],
-                                      dataset_split: Optional[str] = None,
-                                      frame_idxs: Optional[list[int]] = None,
-                                      step: Optional[int] = None,
-                                      epoch: Optional[int] = None,
-                                      threshold: float = 0.5
-                                      ):
-         """
-         Log the semantic segmentation predictions of the model.
-
-         Args:
-             predictions (np.ndarray | str): The predictions of the model. A list of numpy arrays of shape (N, C, H, W).
-                 Or a path to a png file; Or a path to a .nii.gz file.
-             label_names (list[str]): The names of the classes. List of strings of size C.
-             resource_ids (list[str]): The resource IDs of the samples.
-             dataset_split (Optional[str]): The dataset split of the predictions.
-             frame_idxs (Optional[list[int]]): The frame indexes of the predictions.
-             step (Optional[int]): The step of the experiment.
-             epoch (Optional[int]): The epoch of the experiment.
-         """
-
-         if isinstance(predictions, str):
-             predictions = io_utils.read_array_normalized(predictions)
-
-         if isinstance(resource_ids, str):
-             resource_ids = [resource_ids] * len(predictions)
-
-         if predictions.ndim != 4:
-             raise ValueError("Predictions must be of shape (N, C, H, W).")
-
-         # check lengths
-         if len(predictions) != len(resource_ids):
-             raise ValueError("Length of predictions and resource_ids must be the same.")
-
-         if frame_idxs is not None:
-             if len(predictions) != len(frame_idxs):
-                 raise ValueError("Length of predictions and frame_idxs must be the same.")
-             # non negative frame indexes
-             if any(fidx < 0 for fidx in frame_idxs):
-                 raise ValueError("Frame indexes must be non-negative.")
-
-         if len(label_names) != predictions.shape[1]:
-             raise ValueError("Number of classes must match the number of columns in predictions.")
-
-         predictions_conf = predictions.max(axis=(2, 3))  # final shape: (N, C)
-
-         # log it as classification predictions
-         self.log_classification_predictions(predictions_conf=predictions_conf,
-                                             label_names=label_names,
-                                             resource_ids=resource_ids,
-                                             dataset_split=dataset_split,
-                                             frame_idxs=frame_idxs,
-                                             step=step,
-                                             epoch=epoch,
-                                             add_info={'origin': 'semantic segmentation'})
-
-         if self.model_id is not None:
-             _LOGGER.info("Uploading segmentation masks to the platform.")
-             # For each frame
-             predictions = predictions > threshold
-             grouped_predictions = defaultdict(list)
-             for fidx, res_id, pred in zip(frame_idxs, resource_ids, predictions):
-                 grouped_predictions[res_id].append((fidx, pred))
-
-             for res_id, list_preds in grouped_predictions.items():
-                 frame_idxs = [fidx for fidx, _ in list_preds]
-                 preds = np.stack([pred for _, pred in list_preds])
-                 for i in range(len(label_names)):
-                     preds_i = preds[:, i]  # get the i-th class predictions
-                     # preds_i.shape: (N, H, W)
-                     new_ann_id = self.apihandler.upload_segmentations(
-                         resource_id=res_id,
-                         file_path=preds_i,
-                         name=label_names[i],
-                         frame_index=frame_idxs,
-                         model_id=self.model_id,
-                         worklist_id=self.project['worklist_id'],
-                     )
-         else:
-             _LOGGER.warning("Model is not logged. Skipping uploading segmentation masks.")
-
-     def finish(self):
-         """
-         Finish the experiment.
-         This will log the summary and finish the experiment.
-         """
-         def _process_toppredictions(top_predictions: Dict[str, TopN], rev: bool) -> Tuple[TopN, Dict]:
-             preds_per_label = {key: values.get_top()
-                                for key, values in top_predictions.items()}
-             # get the highest prediction over all labels
-             preds_combined = TopN(5, key=_get_confidence_callback, reverse=rev)
-             for label_preds in preds_per_label.values():
-                 for pred in label_preds:
-                     preds_combined.add(pred)
-
-             return preds_combined, preds_per_label
-
-         if self.is_finished:
-             _LOGGER.debug("Experiment is already finished.")
-             return
-         _LOGGER.info("Finishing experiment")
-         for callback in self.finish_callbacks:
-             callback(self)
-         self.time_finished = datetime.now(timezone.utc)
-         time_spent_seconds = (self.time_finished - self.time_started).total_seconds()
-
-         ### produce finishing summary ###
-         # time spent
-         self.add_to_summary({'time_spent_seconds': time_spent_seconds})
-
-         # add the most interesting predictions
-         if len(self.highest_predictions) > 0:
-             highest_preds_combined, highest_preds_per_label = _process_toppredictions(self.highest_predictions, False)
-             lowest_preds_combined, lowest_preds_per_label = _process_toppredictions(self.lowest_predictions, True)
-
-             self.add_to_summary({'highest_predictions': {'combined': highest_preds_combined.get_top(),
-                                                          'per_label': highest_preds_per_label
-                                                          }
-                                  }
-                                 )
-
-             self.add_to_summary({'lowest_predictions': {'combined': lowest_preds_combined.get_top(),
-                                                         'per_label': lowest_preds_per_label
-                                                         }
-                                  }
-                                 )
-
-         self.log_summary(result_summary=self.summary_log)
-         # if the model is not already logged, log it
-         if self.model is not None and self.model_id is None:
-             self.log_model(model=self.model, hyper_params=self.model_hyper_params)
-         self.apihandler.finish_experiment(self.exp_id)
-         self.is_finished = True
-
-         _LOGGER.info("Experiment finished and uploaded to the platform.")
-
-
- class _LogHistory:
-     """
-     TODO: integrate this with the Experiment class.
-     """
-
-     def __init__(self):
-         self.history = []
-
-     def append(self, dt: datetime = None, **kwargs):
-         if dt is None:
-             dt = datetime.now(timezone.utc)
-         else:
-             if dt.tzinfo is None:
-                 _LOGGER.warning("No timezone information provided. Assuming UTC.")
-                 dt = dt.replace(tzinfo=timezone.utc)
-
-         item = {
-             # datetime in GMT+0
-             'timestamp': dt.timestamp(),
-             **kwargs
-         }
-         self.history.append(item)
-
-     def get_history(self) -> List[Dict]:
-         return self.history
-
-
- _EXPERIMENT: Experiment = None