simcats-datasets 2.4.0__tar.gz → 2.5.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (28)
  1. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/PKG-INFO +1 -1
  2. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/pyproject.toml +1 -1
  3. simcats_datasets-2.5.0/simcats_datasets/__init__.py +2 -0
  4. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/generation/_create_dataset.py +68 -30
  5. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/loading/_load_dataset.py +43 -14
  6. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/loading/load_ground_truth.py +7 -0
  7. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/loading/pytorch.py +69 -27
  8. simcats_datasets-2.5.0/simcats_datasets/support_functions/pytorch_format_output.py +169 -0
  9. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets.egg-info/PKG-INFO +1 -1
  10. simcats_datasets-2.4.0/simcats_datasets/__init__.py +0 -2
  11. simcats_datasets-2.4.0/simcats_datasets/support_functions/pytorch_format_output.py +0 -170
  12. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/LICENSE +0 -0
  13. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/README.md +0 -0
  14. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/setup.cfg +0 -0
  15. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/setup.py +0 -0
  16. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/generation/__init__.py +0 -0
  17. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/generation/_create_simulated_dataset.py +0 -0
  18. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/loading/__init__.py +0 -0
  19. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/support_functions/__init__.py +0 -0
  20. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/support_functions/_json_encoders.py +0 -0
  21. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/support_functions/clip_line_to_rectangle.py +0 -0
  22. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/support_functions/convert_lines.py +0 -0
  23. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/support_functions/data_preprocessing.py +0 -0
  24. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/support_functions/get_lead_transition_labels.py +0 -0
  25. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets.egg-info/SOURCES.txt +0 -0
  26. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets.egg-info/dependency_links.txt +0 -0
  27. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets.egg-info/requires.txt +0 -0
  28. {simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets.egg-info/top_level.txt +0 -0
{simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: simcats-datasets
-Version: 2.4.0
+Version: 2.5.0
 Summary: SimCATS-Datasets is a Python package that simplifies the creation and loading of SimCATS datasets.
 Author-email: Fabian Hader <f.hader@fz-juelich.de>, Fabian Fuchs <f.fuchs@fz-juelich.de>, Karin Havemann <k.havemann@fz-juelich.de>, Sarah Fleitmann <s.fleitmann@fz-juelich.de>, Jan Vogelbruch <j.vogelbruch@fz-juelich.de>
 License: GNU GENERAL PUBLIC LICENSE
{simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "simcats-datasets"
-version = "2.4.0" # also change in docs/source/conf.py and __init__
+version = "2.5.0" # also change in docs/source/conf.py and __init__
 license = { file="LICENSE" }
 authors = [
     { name="Fabian Hader", email="f.hader@fz-juelich.de" },
simcats_datasets-2.5.0/simcats_datasets/__init__.py
@@ -0,0 +1,2 @@
+__all__ = []
+__version__ = "2.5.0"
{simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/generation/_create_dataset.py
@@ -16,7 +16,8 @@ __all__ = []
 
 
 def create_dataset(dataset_path: str,
-                   csds: List[np.ndarray],
+                   csds: Optional[List[np.ndarray]] = None,
+                   sensor_scans: Optional[List[np.ndarray]] = None,
                    occupations: Optional[List[np.ndarray]] = None,
                    tct_masks: Optional[List[np.ndarray]] = None,
                    ct_by_dot_masks: Optional[List[np.ndarray]] = None,
@@ -27,6 +28,7 @@ def create_dataset(dataset_path: str,
                    max_len_line_labels_chunk: Optional[int] = None,
                    max_len_metadata_chunk: Optional[int] = None,
                    dtype_csd: np.dtype = np.float32,
+                   dtype_sensor_scan: np.dtype = np.float32,
                    dtype_occ: np.dtype = np.float32,
                    dtype_tct: np.dtype = np.uint8,
                    dtype_ct_by_dot: np.dtype = np.uint8,
@@ -35,7 +37,10 @@
 
     Args:
         dataset_path: The path where the new (v2) HDF5 dataset will be stored.
-        csds: The list of CSDs to use for creating the dataset.
+        csds: The list of CSDs to use for creating the dataset. A dataset can have either CSDs or sensor scans, but
+            never both. Default is None.
+        sensor_scans: The list of sensor scans to use for creating the dataset. A dataset can have either CSDs or sensor
+            scans, but never both. Default is None.
         occupations: List of occupations to use for creating the dataset. Defaults to None.
         tct_masks: List of TCT masks to use for creating the dataset. Defaults to None.
         ct_by_dot_masks: List of CT by dot masks to use for creating the dataset. Defaults to None.
@@ -43,15 +48,17 @@
         line_labels: List of line labels to use for creating the dataset. Defaults to None.
         metadata: List of metadata to use for creating the dataset. Defaults to None.
         max_len_line_coordinates_chunk: The expected maximal length for line coordinates in number of float values (each
-            line requires 4 floats). If None, it is set to the largest value of the CSD shape. Default is None.
+            line requires 4 floats). If None, it is set to the largest value of the CSD (or sensor scan) shape. Default
+            is None.
         max_len_line_labels_chunk: The expected maximal length for line labels in number of uint8/char values (each line
             label, encoded as utf-8 json, should require at most 80 chars). If None, it is set to the largest value of
-            the CSD shape * 20 (matching with allowed number of line coords). Default is None.
+            the CSD (or sensor scan) shape * 20 (matching with allowed number of line coords). Default is None.
         max_len_metadata_chunk: The expected maximal length for metadata in number of uint8/char values (each metadata
             dict, encoded as utf-8 json, should require at most 8000 chars, expected rather something like 4000, but
             could get larger for dot jumps metadata of high resolution scans). If None, it is set to 8000. Default is
             None.
         dtype_csd: Specifies the dtype to be used for saving CSDs. Default is np.float32.
+        dtype_sensor_scan: Specifies the dtype to be used for saving sensor scans. Default is np.float32.
         dtype_occ: Specifies the dtype to be used for saving Occupations. Default is np.float32.
         dtype_tct: Specifies the dtype to be used for saving TCTs. Default is np.uint8.
         dtype_ct_by_dot: Specifies the dtype to be used for saving CT by dot masks. Default is np.uint8.
@@ -60,27 +67,57 @@
     # Create path where the dataset will be saved (if folder doesn't exist already)
     Path(dirname(dataset_path)).mkdir(parents=True, exist_ok=True)
 
+    # check if the dataset to be created is a csd or sensor_scan dataset
+    if csds is not None and sensor_scans is None:
+        csd_dataset = True
+    elif csds is None and sensor_scans is not None:
+        csd_dataset = False
+    else:
+        raise ValueError("A dataset can contain either CSDs or sensor scans but never both! Exactly one of the two has "
+                         "to be None.")
+
     with h5py.File(dataset_path, "a") as hdf5_file:
         # get the number of total ids. This is especially required if a large dataset is loaded and saved step by step
-        num_ids = len(csds)
+        if csd_dataset:
+            num_ids = len(csds)
+        else:
+            num_ids = len(sensor_scans)
 
-        # process CSDs
-        # save an example CSD to get shape and dtype
-        temp_csd = csds[0].copy()
-        # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to load
-        # one image at a time during training)
-        ds = hdf5_file.require_dataset(name='csds', shape=(0, *temp_csd.shape), dtype=dtype_csd,
-                                       maxshape=(None, *temp_csd.shape))
+        # get a temp copy of a csd or sensor scan (to get the shape) and retrieve the corresponding HDF5 dataset
+        if csd_dataset:
+            # process CSDs
+            # save an example CSD to get shape and dtype
+            temp_data = csds[0].copy()
+            # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+            # load one image at a time during training)
+            ds = hdf5_file.require_dataset(name='csds',
+                                           shape=(0, *temp_data.shape),
+                                           dtype=dtype_csd,
+                                           maxshape=(None, *temp_data.shape))
+        else:
+            # process sensor scans
+            # save an example sensor scan to get shape and dtype
+            temp_data = sensor_scans[0].copy()
+            # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+            # load one image at a time during training)
+            ds = hdf5_file.require_dataset(name='sensor_scans',
+                                           shape=(0, *temp_data.shape),
+                                           dtype=dtype_sensor_scan,
+                                           maxshape=(None, *temp_data.shape))
         # determine index offset if there is already data in the dataset
         id_offset = ds.shape[0]
         # resize datasets to fit new data
         ds.resize(ds.shape[0] + num_ids, axis=0)
-        ds[id_offset:] = np.array(csds).astype(dtype_csd)
+        # Add new CSDs or sensor scans to the dataset
+        if csd_dataset:
+            ds[id_offset:] = np.array(csds).astype(dtype_csd)
+        else:
+            ds[id_offset:] = np.array(sensor_scans).astype(dtype_sensor_scan)
         if occupations is not None:
            if len(occupations) != num_ids:
                raise ValueError(
-                    f"Number of new occupation arrays ({len(occupations)}) does not match the number of new CSDs "
-                    f"({num_ids}).")
+                    f"Number of new occupation arrays ({len(occupations)}) does not match the number of new CSDs or "
+                    f"sensor scans ({num_ids}).")
            # process Occupations
            # save an example occ to get shape
            temp_occ = occupations[0].copy()
@@ -91,15 +128,15 @@ def create_dataset(dataset_path: str,
            if ds.shape[0] != id_offset:
                raise ValueError(
                    f"Number of already stored occupation arrays ({ds.shape[0]}) does not match the number of already "
-                    f"stored CSDs ({id_offset}).")
+                    f"stored CSDs or sensor scans ({id_offset}).")
            # resize datasets to fit new data
            ds.resize(ds.shape[0] + num_ids, axis=0)
            ds[id_offset:] = np.array(occupations).astype(dtype_occ)
        if tct_masks is not None:
            if len(tct_masks) != num_ids:
                raise ValueError(
-                    f"Number of new TCT mask arrays ({len(tct_masks)}) does not match the number of new CSDs "
-                    f"({num_ids}).")
+                    f"Number of new TCT mask arrays ({len(tct_masks)}) does not match the number of new CSDs or sensor "
+                    f"scans ({num_ids}).")
            # process tct masks
            # save an example tct to get shape and dtype
            temp_tct = tct_masks[0].copy()
@@ -110,7 +147,7 @@
            if ds.shape[0] != id_offset:
                raise ValueError(
                    f"Number of already stored TCT mask arrays ({ds.shape[0]}) does not match the number of already "
-                    f"stored CSDs ({id_offset}).")
+                    f"stored CSDs or sensor scans ({id_offset}).")
            # resize datasets to fit new data
            ds.resize(ds.shape[0] + num_ids, axis=0)
            ds[id_offset:] = np.array(tct_masks).astype(dtype_tct)
@@ -118,7 +155,7 @@
            if len(ct_by_dot_masks) != num_ids:
                raise ValueError(
                    f"Number of new CT by dot mask arrays ({len(ct_by_dot_masks)}) does not match the number of new "
-                    f"CSDs ({num_ids}).")
+                    f"CSDs or sensor scans ({num_ids}).")
            # process tct masks
            # save an example tct to get shape and dtype
            temp_ct_by_dot = ct_by_dot_masks[0].copy()
@@ -129,7 +166,7 @@
            if ds.shape[0] != id_offset:
                raise ValueError(
                    f"Number of already stored CT by dot mask arrays ({ds.shape[0]}) does not match the number of "
-                    f"already stored CSDs ({id_offset}).")
+                    f"already stored CSDs or sensor scans ({id_offset}).")
            # resize datasets to fit new data
            ds.resize(ds.shape[0] + num_ids, axis=0)
            ds[id_offset:] = np.array(ct_by_dot_masks).astype(dtype_tct)
@@ -137,11 +174,11 @@
            if len(line_coordinates) != num_ids:
                raise ValueError(
                    f"Number of new line coordinates ({len(line_coordinates)}) does not match the number of new "
-                    f"CSDs ({num_ids}).")
+                    f"CSDs or sensor scans ({num_ids}).")
            # retrieve fixed length for chunks
            if max_len_line_coordinates_chunk is None:
                # calculate max expected length (max_number_of_lines * 4 entries, max number estimated as max(shape)/4)
-                max_len = max(temp_csd.shape)
+                max_len = max(temp_data.shape)
            else:
                max_len = max_len_line_coordinates_chunk
            # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
@@ -151,7 +188,7 @@
            if ds.shape[0] != id_offset:
                raise ValueError(
                    f"Number of already stored line coordinates ({ds.shape[0]}) does not match the number of already "
-                    f"stored CSDs ({id_offset}).")
+                    f"stored CSDs or sensor scans ({id_offset}).")
            # resize datasets to fit new data
            ds.resize(ds.shape[0] + num_ids, axis=0)
            # process line coordinates
@@ -163,13 +200,13 @@
        if line_labels is not None:
            if len(line_labels) != num_ids:
                raise ValueError(
-                    f"Number of new line labels ({len(line_labels)}) does not match the number of new CSDs "
-                    f"({num_ids}).")
+                    f"Number of new line labels ({len(line_labels)}) does not match the number of new CSDs or sensor "
+                    f"scans ({num_ids}).")
            # retrieve fixed length for chunks
            if max_len_line_labels_chunk is None:
                # calculate max expected length (max_number_of_lines * 80 uint8 numbers, max number estimated as
                # max(shape)/4)
-                max_len = max(temp_csd.shape) * 20
+                max_len = max(temp_data.shape) * 20
            else:
                max_len = max_len_line_labels_chunk
            # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
@@ -179,7 +216,7 @@
            if ds.shape[0] != id_offset:
                raise ValueError(
                    f"Number of already stored line labels ({ds.shape[0]}) does not match the number of already stored "
-                    f"CSDs ({id_offset}).")
+                    f"CSDs or sensor scans ({id_offset}).")
            # resize datasets to fit new data
            ds.resize(ds.shape[0] + num_ids, axis=0)
            # process line labels
@@ -193,7 +230,8 @@
        if metadata is not None:
            if len(metadata) != num_ids:
                raise ValueError(
-                    f"Number of new metadata ({len(metadata)}) does not match the number of new CSDs ({num_ids}).")
+                    f"Number of new metadata ({len(metadata)}) does not match the number of new CSDs or sensor scans "
+                    f"({num_ids}).")
            # retrieve fixed length for chunks
            if max_len_metadata_chunk is None:
                # set len to 8000 uint8 numbers, that should already include some extra safety (expected smth. like
@@ -208,7 +246,7 @@
            if ds.shape[0] != id_offset:
                raise ValueError(
                    f"Number of already stored metadata ({ds.shape[0]}) does not match the number of already stored "
-                    f"CSDs ({id_offset}).")
+                    f"CSDs or sensor scans ({id_offset}).")
            # resize datasets to fit new data
            ds.resize(ds.shape[0] + num_ids, axis=0)
            # process metadata
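
Taken together, these hunks make csds and sensor_scans mutually exclusive inputs. A minimal usage sketch (the file path and array shapes are made up for illustration, and it is assumed that create_dataset is re-exported from simcats_datasets.generation, as the private module name _create_dataset suggests):

    import numpy as np
    from simcats_datasets.generation import create_dataset

    # ten hypothetical 100x100 sensor scans
    scans = [np.random.rand(100, 100) for _ in range(10)]

    # stored under the "sensor_scans" HDF5 dataset; passing both csds and
    # sensor_scans (or neither) raises the ValueError introduced above
    create_dataset(dataset_path="example_sensor_scans.h5", sensor_scans=scans)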
{simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/loading/_load_dataset.py
@@ -13,11 +13,11 @@ from typing import List, Tuple, Union
 
 import h5py
 import numpy as np
-from tqdm import tqdm
 
 
 def load_dataset(file: Union[str, h5py.File],
-                 load_csds=True,
+                 load_csds: bool = True,
+                 load_sensor_scans: bool = False,
                  load_occupations: bool = False,
                  load_tct_masks: bool = False,
                  load_ct_by_dot_masks: bool = False,
@@ -34,12 +34,15 @@ def load_dataset(file: Union[str, h5py.File],
            dataset. If a path is supplied, load_dataset will open the file itself. If you want to do multiple
            consecutive loads from the same file (e.g. for using th PyTorch SimcatsDataset without preloading), consider
            initializing the file object yourself and passing it, to improve the performance.
-        load_csds: Determines if csds should be loaded. Default is True.
+        load_csds: Determines if CSDs should be loaded. A dataset can have either CSDs or sensor scans, but never both.
+            Default is True.
+        load_sensor_scans: Determines if sensor scans should be loaded. A dataset can have either CSDs or sensor scans,
+            but never both. Default is False.
        load_occupations: Determines if occupation data should be loaded. Default is False.
        load_tct_masks: Determines if lead transition masks should be loaded. Default is False.
        load_ct_by_dot_masks: Determines if charge transition labeled by affected dot masks should be loaded. This
            requires that ct_by_dot_masks have been added to the dataset. If a dataset has been created using
-            create_simulated_dataset, these masks can be added afterwards using add_ct_by_dot_masks_to_dataset, mainly
+            create_simulated_dataset, these masks can be added afterward using add_ct_by_dot_masks_to_dataset, mainly
            to avoid recalculating them multiple times (for example for machine learning purposes). Default is False.
        load_line_coords: Determines if lead transition definitions using start and end points should be loaded. Default
            is False.
@@ -56,13 +59,15 @@
 
    Returns:
        namedtuple: The namedtuple can be unpacked like every normal tuple, or instead accessed by field names. \n
-        Depending on what has been enabled, the following data is included in the named tuple: \n
-        - field 'csds': List containing all CSDs as numpy arrays. The list is sorted by the id of the CSDs (if no
-          specific_ids are provided, else the order is given by specific_ids).
+        Depending on what has been enabled, the following data is included in the named tuple (all lists are sorted by
+        the id of the CSDs or sensor_scans if no specific_ids are provided, else the order is given by specific_ids): \n
+        - field 'csds': List containing all CSDs as numpy arrays.
+        - field 'sensor_scans': List containing all sensor scans as numpy arrays.
        - field 'occupations': List containing numpy arrays with occupations.
        - field 'tct_masks': List containing numpy arrays of TCT masks.
        - field 'ct_by_dot_masks': List containing numpy arrays of CT_by_dot masks.
-        - field 'line_coordinates': List containing numpy arrays of line coordinates.
+        - field 'line_coordinates': List containing numpy arrays of line coordinates. Each row of the array specifies
+          the start and end points of one line.
        - field 'line_labels': List containing a list of dictionaries (one dict for each line specified as line
          coordinates).
        - field 'metadata': List containing dictionaries with all metadata (simcats configs) for each CSD.
@@ -72,6 +77,8 @@
    fieldnames = []
    if load_csds:
        fieldnames.append("csds")
+    if load_sensor_scans:
+        fieldnames.append("sensor_scans")
    if load_occupations:
        fieldnames.append("occupations")
    if load_tct_masks:
@@ -86,10 +93,17 @@
        fieldnames.append("metadata")
    if load_ids:
        fieldnames.append("ids")
-    CSDDataset = namedtuple(typename="CSDDataset", field_names=fieldnames)
+    SimcatsDataset = namedtuple(typename="SimcatsDataset", field_names=fieldnames)
 
    # use nullcontext to catch the case where a file is passed instead of the string
    with h5py.File(file, "r") if isinstance(file, str) else nullcontext(file) as _file:
+        # check if the dataset contains csd or sensor_scans
+        if "csds" in _file:
+            csd_dataset = True
+        elif "sensor_scans" in _file:
+            csd_dataset = False
+        else:
+            raise KeyError("The dataset that should be loaded does not contain any csds or sensor_scans!")
        # if only specific ids should be loaded, check if all ids are available
        if specific_ids is not None:
            if isinstance(specific_ids, list) or isinstance(specific_ids, np.ndarray):
@@ -104,18 +118,31 @@
        # Dataset with non-existing specific IDs (which else would only crash as soon as a non-existent ID is
        # requested during training). We can't check this on loading CSDs etc. as it massively slows down loading.
        if specific_ids is not None:
-            if np.min(specific_ids) < 0 or np.max(specific_ids) >= len(_file["csds"]):
-                msg = "Not all ids specified by 'specific_ids' are available in the dataset!"
-                raise IndexError(msg)
+            if csd_dataset:
+                if np.min(specific_ids) < 0 or np.max(specific_ids) >= len(_file["csds"]):
+                    msg = "Not all ids specified by 'specific_ids' are available in the dataset!"
+                    raise IndexError(msg)
+            else:
+                if np.min(specific_ids) < 0 or np.max(specific_ids) >= len(_file["sensor_scans"]):
+                    msg = "Not all ids specified by 'specific_ids' are available in the dataset!"
+                    raise IndexError(msg)
            available_ids = specific_ids
        else:
-            available_ids = range(len(_file["csds"]))
+            if csd_dataset:
+                available_ids = range(len(_file["csds"]))
+            else:
+                available_ids = range(len(_file["sensor_scans"]))
 
        if load_csds:
            if specific_ids is not None:
                csds = _file["csds"][specific_ids]
            else:
                csds = _file["csds"][:]
+        if load_sensor_scans:
+            if specific_ids is not None:
+                sensor_scans = _file["sensor_scans"][specific_ids]
+            else:
+                sensor_scans = _file["sensor_scans"][:]
        if load_occupations:
            if specific_ids is not None:
                occupations = _file["occupations"][specific_ids]
@@ -155,6 +182,8 @@
    return_data = []
    if load_csds:
        return_data.append(csds)
+    if load_sensor_scans:
+        return_data.append(sensor_scans)
    if load_occupations:
        return_data.append(occupations)
    if load_tct_masks:
@@ -174,4 +203,4 @@
    if specific_ids is not None and undo_sort_ids is not None:
        return_data = [[x[i] for i in undo_sort_ids] for x in return_data]
 
-    return CSDDataset._make(tuple(return_data))
+    return SimcatsDataset._make(tuple(return_data))
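
On the loading side, the matching flag selects which of the two mutually exclusive HDF5 datasets is read. A sketch against the hypothetical file from above:

    from simcats_datasets.loading import load_dataset

    # a file contains either "csds" or "sensor_scans", never both, so CSD
    # loading is disabled here; note that the returned namedtuple is now
    # called SimcatsDataset instead of CSDDataset
    data = load_dataset(file="example_sensor_scans.h5", load_csds=False,
                        load_sensor_scans=True, load_ids=True)
    print(data.sensor_scans[0].shape, list(data.ids[:3]))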
{simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/loading/load_ground_truth.py
@@ -30,6 +30,13 @@ from simcats_datasets.loading import load_dataset
 from simcats.support_functions import rotate_points
 
 
+# Lists defining which ground truth type is supported for CSD and sensor scan datasets, respectively
+_csd_ground_truths = ["load_zeros_masks", "load_tct_masks", "load_tct_by_dot_masks", "load_idt_masks", "load_ct_masks",
+                      "load_ct_by_dot_masks", "load_tc_region_masks", "load_tc_region_minus_tct_masks",
+                      "load_c_region_masks"]
+_sensor_scan_ground_truths = ["load_zeros_masks", "load_tct_masks"]
+
+
 def load_zeros_masks(file: Union[str, h5py.File],
                      specific_ids: Union[range, List[int], np.ndarray, None] = None,
                      progress_bar: bool = True) -> List[np.ndarray]:
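
Per these lists, load_zeros_masks and load_tct_masks are the only ground truth loaders valid for both dataset types. A small sketch (the file name is the hypothetical one from above, it assumes TCT masks were stored, and the signature is assumed to match load_zeros_masks shown here):

    from simcats_datasets.loading.load_ground_truth import load_tct_masks

    # TCT masks work for CSD and sensor scan datasets alike
    tct = load_tct_masks(file="example_sensor_scans.h5", specific_ids=[0, 1],
                         progress_bar=False)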
{simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets/loading/pytorch.py
@@ -28,7 +28,8 @@ class SimcatsDataset(Dataset):
                 ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
                 format_output: Union[Callable, str, None] = None, preload: bool = True,
                 max_concurrent_preloads: int = 100000,
-                 progress_bar: bool = False, ):
+                 progress_bar: bool = False,
+                 sensor_scan_dataset: bool = False,):
        """Initializes an object for providing simcats_datasets data to pytorch.
 
        Args:
@@ -77,8 +78,11 @@
                loading them step by step and for example converting the CSDs to float32 with a corresponding data
                preprocessor. Default is 100,000.
            progress_bar: Determines whether to display a progress bar while loading data. Default is False.
+            sensor_scan_dataset: Determines whether the dataset is a sensor scan dataset (contains sensor scans instead
+                of CSDs). Default is False.
        """
        self.__h5_path = h5_path
+        self.__sensor_scan_dataset = sensor_scan_dataset
        self.__specific_ids = specific_ids
        # set up the load ground truth function. Could be None, function referenced by string, or callable
        if load_ground_truth is None:
@@ -88,6 +92,20 @@
            self.__load_ground_truth = getattr(simcats_datasets.loading.load_ground_truth, load_ground_truth)
        else:
            self.__load_ground_truth = load_ground_truth
+        # check if it is possible to load the desired ground truth from the given dataset
+        try:
+            _ = self.load_ground_truth(file=self.__h5_path, specific_ids=[0], progress_bar=False)
+        except:
+            raise ValueError(
+                f"The specified ground truth ({self.load_ground_truth.__name__}) can't be loaded for the given "
+                f"dataset ({self.h5_path}). Please make sure to select a supported ground truth type.\n"
+                f"Supported ground truth types for CSD datasets created using "
+                f"simcats_datasets.generation.create_simulated_dataset are:\n"
+                f"{', '.join(simcats_datasets.loading.load_ground_truth._csd_ground_truths)}\n"
+                f"Supported ground truth types for sensor scan datasets created using "
+                f"simcats_datasets.generation.create_simulated_dataset are:\n"
+                f"{', '.join(simcats_datasets.loading.load_ground_truth._sensor_scan_ground_truths)}"
+            )
        # set up the data preprocessors. Could be None, functions referenced by strings, or callables
        if data_preprocessors is None:
            self.__data_preprocessors = data_preprocessors
@@ -121,31 +139,37 @@
                load_dataset(file=h5_file, load_csds=False, load_ids=True, specific_ids=self.specific_ids,
                             progress_bar=self.progress_bar, ).ids)
            # preprocess an exemplary image to get final shape (some preprocessors might adjust the shape)
-            _temp_csd = \
-                load_dataset(file=h5_file, load_csds=True, specific_ids=[0], progress_bar=self.progress_bar, ).csds[0]
+            _temp_measurement = load_dataset(file=h5_file,
+                                             load_csds=not self.__sensor_scan_dataset,
+                                             load_sensor_scans=self.__sensor_scan_dataset,
+                                             specific_ids=[0],
+                                             progress_bar=self.progress_bar, )[0][0]
            if self.data_preprocessors is not None:
                for processor in self.data_preprocessors:
-                    _temp_csd = processor(_temp_csd)
-            self.__shape = (self.__num_ids, *np.squeeze(_temp_csd).shape)
-        # preload all data if requested
+                    _temp_measurement = processor(_temp_measurement)
+            self.__shape = (self.__num_ids, *np.squeeze(_temp_measurement).shape)
+        # preload all measurements if requested
        if self.preload:
-            self.__csds = []
+            self.__measurements = []
            self.__ground_truths = []
-            # load and save data, at most max_concurrent_ids at a time
+            # load and save measurements, at most max_concurrent_ids at a time
            for i in range(math.ceil(self.__num_ids / max_concurrent_preloads)):
                _ids = range(i * max_concurrent_preloads,
                             np.min([(i + 1) * max_concurrent_preloads, self.__num_ids]))
                if self.specific_ids is not None:
                    _ids = [self.specific_ids[i] for i in _ids]
                # load
-                _temp_csds = [csd for csd in
-                              load_dataset(file=h5_file, specific_ids=_ids, progress_bar=self.progress_bar, ).csds]
-                # preprocess data
+                _temp_measurements = [data for data in load_dataset(file=h5_file,
+                                                                    load_csds=not self.__sensor_scan_dataset,
+                                                                    load_sensor_scans=self.__sensor_scan_dataset,
+                                                                    specific_ids=_ids,
+                                                                    progress_bar=self.progress_bar, )[0]]
+                # preprocess measurements
                if self.data_preprocessors is not None:
                    for processor in self.data_preprocessors:
-                        _temp_csds = processor(_temp_csds)
-                self.__csds.extend(_temp_csds)
-                del _temp_csds
+                        _temp_measurements = processor(_temp_measurements)
+                self.__measurements.extend(_temp_measurements)
+                del _temp_measurements
                try:
                    _temp_ground_truths = [gt for gt in
                                           self.load_ground_truth(file=h5_file, specific_ids=_ids, progress_bar=self.progress_bar, )]
@@ -162,6 +186,10 @@
    def h5_path(self) -> str:
        return self.__h5_path
 
+    @property
+    def sensor_scan_dataset(self) -> bool:
+        return self.__sensor_scan_dataset
+
    @property
    def specific_ids(self) -> Union[range, List[int], np.ndarray, None]:
        return self.__specific_ids
@@ -196,19 +224,19 @@
 
    def __len__(self):
        """
-        Returns the number of CSDs in the dataset.
+        Returns the number of measurements in the dataset.
        """
        return self.__num_ids
 
    def __getitem__(self, idx: int):
        """
-        Retrieves a csd and the corresponding ground truth at given index idx.
+        Retrieves a measurement and the corresponding ground truth at given index idx.
 
        Args:
            idx: The id of the csd and ground truth to be returned.
        """
        if self.preload:
-            csd = self.__csds[idx]
+            measurement = self.__measurements[idx]
            try:
                ground_truth = self.__ground_truths[idx]
            except IndexError:
@@ -220,12 +248,16 @@
                self.__h5_file = h5py.File(self.h5_path, mode="r")
            if self.specific_ids is not None:
                idx = self.specific_ids[idx]
-            # load data
-            csd = load_dataset(file=self.__h5_file, specific_ids=[idx], progress_bar=self.progress_bar).csds[0]
-            # preprocess data
+            # load measurement
+            measurement = load_dataset(file=self.__h5_file,
+                                       load_csds=not self.__sensor_scan_dataset,
+                                       load_sensor_scans=self.__sensor_scan_dataset,
+                                       specific_ids=[idx],
+                                       progress_bar=self.progress_bar)[0][0]
+            # preprocess measurement
            if self.data_preprocessors is not None:
                for processor in self.data_preprocessors:
-                    csd = processor(csd)
+                    measurement = processor(measurement)
            # load ground truth
            try:
                ground_truth = \
@@ -236,7 +268,7 @@
                        ground_truth = processor(ground_truth)
            except TypeError:
                ground_truth = None
-        return self.format_output(csd=csd, ground_truth=ground_truth, idx=idx)
+        return self.format_output(measurement=measurement, ground_truth=ground_truth, idx=idx)
 
    def __repr__(self):
        return (f"{self.__class__.__name__}(\n"
@@ -247,7 +279,8 @@
                f"\tground_truth_preprocessors=[{[', '.join([func.__name__ for func in self.ground_truth_preprocessors]) if self.ground_truth_preprocessors is not None else None][0]}],\n"
                f"\tformat_output={self.format_output.__name__},\n"
                f"\tpreload={self.preload},\n"
-                f"\tprogress_bar={self.progress_bar}\n"
+                f"\tprogress_bar={self.progress_bar},\n"
+                f"\tsensor_scan_dataset={self.sensor_scan_dataset}\n"
                f")")
 
    def __del__(self):
@@ -264,7 +297,8 @@ class SimcatsConcatDataset(ConcatDataset):
                 ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
                 format_output: Union[Callable, str, None] = None, preload: bool = True,
                 max_concurrent_preloads: int = 100000,
-                 progress_bar: bool = False, ):
+                 progress_bar: bool = False,
+                 sensor_scan_dataset: bool = False,):
        """Initializes an object for providing concatenated simcats_datasets data to pytorch.
 
        Args:
@@ -313,6 +347,8 @@
                loading them step by step and for example converting the CSDs to float32 with a corresponding data
                preprocessor. Default is 100.000.
            progress_bar: Determines whether to display a progress bar while loading data. Default is False.
+            sensor_scan_dataset: Determines whether the datasets are sensor scan datasets (contain sensor scans instead
+                of CSDs). Default is False.
        """
        _datasets = list()
        if specific_ids is not None and len(specific_ids) != len(h5_paths):
@@ -328,9 +364,10 @@
                                           data_preprocessors=data_preprocessors,
                                           ground_truth_preprocessors=ground_truth_preprocessors, format_output=format_output,
                                           preload=preload, max_concurrent_preloads=max_concurrent_preloads,
-                                           progress_bar=progress_bar))
+                                           progress_bar=progress_bar, sensor_scan_dataset=sensor_scan_dataset))
        super().__init__(_datasets)
        self.__h5_paths = h5_paths
+        self.__sensor_scan_dataset = sensor_scan_dataset
        self.__specific_ids = specific_ids
        # set up the load ground truth function. Could be None, function referenced by string, or callable
        if load_ground_truth is None:
@@ -373,7 +410,7 @@
            if shape is None:
                shape = dataset.shape[1:]
            elif dataset.shape[1:] != shape:
-                raise ValueError(f"The shape of the SimcatsDataset CSDs should be identical but found shapes "
+                raise ValueError(f"The shape of the SimcatsDataset Measurements should be identical but found shapes "
                                 f"{[dataset.shape[1:] for dataset in _datasets]}")
        self.__shape = (len(self), *shape)
 
@@ -381,6 +418,10 @@
    def h5_paths(self) -> List[str]:
        return self.__h5_paths
 
+    @property
+    def sensor_scan_dataset(self) -> bool:
+        return self.__sensor_scan_dataset
+
    @property
    def specific_ids(self) -> Union[List[Union[range, List[int], np.ndarray, None]], None]:
        return self.__specific_ids
@@ -422,5 +463,6 @@
                f"\tground_truth_preprocessors=[{[', '.join([func.__name__ for func in self.ground_truth_preprocessors]) if self.ground_truth_preprocessors is not None else None][0]}],\n"
                f"\tformat_output={self.format_output.__name__},\n"
                f"\tpreload={self.preload},\n"
-                f"\tprogress_bar={self.progress_bar}\n"
+                f"\tprogress_bar={self.progress_bar},\n"
+                f"\tsensor_scan_dataset={self.sensor_scan_dataset},\n"
                f")")
simcats_datasets-2.5.0/simcats_datasets/support_functions/pytorch_format_output.py
@@ -0,0 +1,169 @@
+"""Functions for formatting the output of the **Pytorch Dataset class**.
+
+Every function must accept a measurement (as array), a ground truth (e.g. TCT mask as array) and the image id as input.
+Output type depends on the ground truth type and the required pytorch datatype (tensor as long, float, ...). Ground
+truth could for example be a pixel mask or defined start end points of lines.
+**Please look at format_dict_csd_float_ground_truth_long for a reference.**
+
+@author: f.hader
+"""
+
+from __future__ import annotations
+
+from typing import Tuple
+
+import numpy as np
+import torch
+
+def format_dict_csd_float_ground_truth_long(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
+    str, torch.Tensor]:
+    """Format the output of the Pytorch Dataset class to be a dict with entries 'csd' and 'ground_truth' of dtype float and long, respectively. (default of Pytorch Dataset class.)
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask.
+        idx: index of the measurement. Not used in this format.
+
+    Returns:
+        Dict with 'csd' and 'ground_truth' of dtype float and long, respectively.
+    """
+    assert (measurement.size == ground_truth.size), \
+        f"Image and mask should be the same size, but are {measurement.size=} and {ground_truth.size=}"
+    return {"csd": torch.as_tensor(measurement.copy(), dtype=torch.float).contiguous(),
+            "ground_truth": torch.as_tensor(ground_truth.copy(), dtype=torch.long, ).contiguous(), }
+
+
+def format_dict_csd_float16_ground_truth_long(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
+    str, torch.Tensor]:
+    """Format the output of the Pytorch Dataset class to be a dict with entries 'csd' and 'ground_truth' of dtype float16 and long, respectively.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask.
+        idx: index of the measurement. Not used in this format.
+
+    Returns:
+        Dict with 'csd' and 'ground_truth' of dtype float16 and long, respectively.
+    """
+    assert (measurement.size == ground_truth.size), \
+        f"Image and mask should be the same size, but are {measurement.size=} and {ground_truth.size=}"
+    return {"csd": torch.as_tensor(measurement.copy(), dtype=torch.float16).contiguous(),
+            "ground_truth": torch.as_tensor(ground_truth.copy(), dtype=torch.long, ).contiguous(), }
+
+
+def format_dict_csd_float_ground_truth_float(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
+    str, torch.Tensor]:
+    """Format the output of the Pytorch Dataset class to be a dict with entries 'csd' and 'ground_truth' of dtype float and float, respectively.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask.
+        idx: index of the measurement. Not used in this format.
+
+    Returns:
+        Dict with 'csd' and 'ground_truth' of dtype float and float, respectively.
+    """
+    assert (measurement.size == ground_truth.size), \
+        f"Image and mask should be the same size, but are {measurement.size=} and {ground_truth.size=}"
+    return {"csd": torch.as_tensor(measurement.copy(), dtype=torch.float).contiguous(),
+            "ground_truth": torch.as_tensor(ground_truth.copy(), dtype=torch.float).contiguous(), }
+
+
+def format_mmsegmentation(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[str, torch.Tensor]:
+    """Format the output of the Pytorch Dataset class to be conform to the MMSegmentation CustomDataset of version 0.6.0, see https://github.com/open-mmlab/mmsegmentation/blob/v0.6.0/mmseg/datasets/custom.py.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask.
+        idx: index of the measurement.
+
+    Returns:
+        Dict with data conform to the MMSegmentation CustomDataset of version 0.6.0, see https://github.com/open-mmlab/mmsegmentation/blob/v0.6.0/mmseg/datasets/custom.py.
+    """
+    assert (measurement.size == ground_truth.size), \
+        f"Image and mask should be the same size, but are {measurement.size=} and {ground_truth.size=}"
+    return {"img": torch.as_tensor(measurement.copy()).float().contiguous(),
+            "gt_semantic_seg": torch.as_tensor(ground_truth.copy()).float().contiguous(),
+            "img_metas": {"filename": f"{idx}.jpg", "ori_filename": f"{idx}_ori.jpg", "ori_shape": measurement.shape[::-1],
+                          # we want (100, 100, 1) not (1, 100, 100)
+                          "img_shape": measurement.shape[::-1], "pad_shape": measurement.shape[::-1],  # image shape after padding
+                          "scale_factor": 1.0, "img_norm_cfg": {"mean": np.mean(measurement, axis=(-2, -1)),  # mean for each channel
+                                                                "std": np.std(measurement, axis=(-2, -1)),  # std for each channel
+                                                                "to_rgb": False, }, "img_id": f"{idx}", }, }
+
+
+def format_csd_only(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> torch.Tensor:
+    """Format the output of the Pytorch Dataset class to be just a measurement.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask. Not used in this format.
+        idx: Index of the measurement. Not used in this format.
+
+    Returns:
+        The measurement as tensor.
+    """
+    return torch.as_tensor(measurement.copy(), dtype=torch.float).contiguous()
+
+
+def format_csd_float16_only(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> torch.Tensor:
+    """Format the output of the Pytorch Dataset class to be just a float16 (half precision) measurement.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask. Not used in this format.
+        idx: Index of the measurement. Not used in this format.
+
+    Returns:
+        The float 16 (half precision) measurement as tensor.
+    """
+    return torch.as_tensor(measurement.copy(), dtype=torch.float16).contiguous()
+
+
+def format_csd_bfloat16_only(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> torch.Tensor:
+    """Format the output of the Pytorch Dataset class to be just a bfloat16 (half precision) measurement.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask. Not used in this format.
+        idx: Index of the measurement. Not used in this format.
+
+    Returns:
+        The brain float 16 (half precision) measurement as tensor.
+    """
+    return torch.as_tensor(measurement.copy(), dtype=torch.bfloat16).contiguous()
+
+
+def format_csd_class_index(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> Tuple[
+    torch.Tensor, torch.Tensor, int]:
+    """Format the output of the Pytorch Dataset class to be the measurement, a class index (which is always 0 as we have no classes) and the index.
+
+    This is needed to be conform to the datasets used in DeepSVDD, see https://github.com/lukasruff/Deep-SVDD-PyTorch.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask. Not used in this format.
+        idx: Index of the measurement.
+
+    Returns:
+        A tuple of measurement, class index, and the index.
+    """
+    return torch.as_tensor(measurement.copy(), dtype=torch.float).unsqueeze(0).contiguous(), torch.tensor(0), idx
+
+
+def format_tuple_csd_float_ground_truth_float(measurement: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
+    str, torch.Tensor]:
+    """Format the output of the Pytorch Dataset class to be a tuple of the measurement and the ground_truth.
+
+    Args:
+        measurement: The measurement array.
+        ground_truth: Ground truth as pixel mask.
+        idx: index of the measurement. Not used in this format.
+
+    Returns:
+        Tuple with measurement and ground_truth of dtype float and float, respectively.
+    """
+    assert (measurement.size == ground_truth.size), \
+        f"Image and mask should be the same size, but are {measurement.size=} and {ground_truth.size=}"
+    return (torch.as_tensor(measurement.copy(), dtype=torch.float).contiguous(),
+            torch.as_tensor(ground_truth.copy(), dtype=torch.float).contiguous(),)
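
As the module docstring states, any formatter just has to accept (measurement, ground_truth, idx); a hypothetical custom formatter in the same spirit could look like this:

    import numpy as np
    import torch

    def format_tuple_measurement_idx(measurement: np.ndarray,
                                     ground_truth: np.ndarray,
                                     idx: int) -> tuple[torch.Tensor, int]:
        # ground_truth is ignored; return the measurement together with its index
        return (torch.as_tensor(measurement.copy(), dtype=torch.float).contiguous(),
                idx)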
{simcats_datasets-2.4.0 → simcats_datasets-2.5.0}/simcats_datasets.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: simcats-datasets
-Version: 2.4.0
+Version: 2.5.0
 Summary: SimCATS-Datasets is a Python package that simplifies the creation and loading of SimCATS datasets.
 Author-email: Fabian Hader <f.hader@fz-juelich.de>, Fabian Fuchs <f.fuchs@fz-juelich.de>, Karin Havemann <k.havemann@fz-juelich.de>, Sarah Fleitmann <s.fleitmann@fz-juelich.de>, Jan Vogelbruch <j.vogelbruch@fz-juelich.de>
 License: GNU GENERAL PUBLIC LICENSE
simcats_datasets-2.4.0/simcats_datasets/__init__.py
@@ -1,2 +0,0 @@
-__all__ = []
-__version__ = "2.4.0"
simcats_datasets-2.4.0/simcats_datasets/support_functions/pytorch_format_output.py
@@ -1,170 +0,0 @@
-"""Functions for formatting the output of the **Pytorch Dataset class**.
-
-Every function must accept a CSD (as array), a ground truth (e.g. TCT mask as array) and the image id as input.
-Output type depends on the ground truth type and the required pytorch datatype (tensor as long, float, ...). Ground
-truth could for example be a pixel mask or defined start end points of lines.
-**Please look at format_dict_csd_float_ground_truth_long for a reference.**
-
-@author: f.hader
-"""
-
-from __future__ import annotations
-
-from typing import Tuple
-
-import numpy as np
-import torch
-
-
-def format_dict_csd_float_ground_truth_long(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
-    str, torch.Tensor]:
-    """Format the output of the Pytorch Dataset class to be a dict with entries 'csd' and 'ground_truth' of dtype float and long, respectively. (default of Pytorch Dataset class.)
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask.
-        idx: index of the csd. Not used in this format.
-
-    Returns:
-        Dict with 'csd' and 'ground_truth' of dtype float and long, respectively.
-    """
-    assert (
-        csd.size == ground_truth.size), f"Image and mask should be the same size, but are {csd.size=} and {ground_truth.size=}"
-    return {"csd": torch.as_tensor(csd.copy(), dtype=torch.float).contiguous(),
-            "ground_truth": torch.as_tensor(ground_truth.copy(), dtype=torch.long, ).contiguous(), }
-
-
-def format_dict_csd_float16_ground_truth_long(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
-    str, torch.Tensor]:
-    """Format the output of the Pytorch Dataset class to be a dict with entries 'csd' and 'ground_truth' of dtype float16 and long, respectively.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask.
-        idx: index of the csd. Not used in this format.
-
-    Returns:
-        Dict with 'csd' and 'ground_truth' of dtype float16 and long, respectively.
-    """
-    assert (
-        csd.size == ground_truth.size), f"Image and mask should be the same size, but are {csd.size=} and {ground_truth.size=}"
-    return {"csd": torch.as_tensor(csd.copy(), dtype=torch.float16).contiguous(),
-            "ground_truth": torch.as_tensor(ground_truth.copy(), dtype=torch.long, ).contiguous(), }
-
-
-def format_dict_csd_float_ground_truth_float(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
-    str, torch.Tensor]:
-    """Format the output of the Pytorch Dataset class to be a dict with entries 'csd' and 'ground_truth' of dtype float and float, respectively.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask.
-        idx: index of the csd. Not used in this format.
-
-    Returns:
-        Dict with 'csd' and 'ground_truth' of dtype float and float, respectively.
-    """
-    assert (
-        csd.size == ground_truth.size), f"Image and mask should be the same size, but are {csd.size=} and {ground_truth.size=}"
-    return {"csd": torch.as_tensor(csd.copy(), dtype=torch.float).contiguous(),
-            "ground_truth": torch.as_tensor(ground_truth.copy(), dtype=torch.float).contiguous(), }
-
-
-def format_mmsegmentation(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[str, torch.Tensor]:
-    """Format the output of the Pytorch Dataset class to be conform to the MMSegmentation CustomDataset of version 0.6.0, see https://github.com/open-mmlab/mmsegmentation/blob/v0.6.0/mmseg/datasets/custom.py.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask.
-        idx: index of the csd.
-
-    Returns:
-        Dict with data conform to the MMSegmentation CustomDataset of version 0.6.0, see https://github.com/open-mmlab/mmsegmentation/blob/v0.6.0/mmseg/datasets/custom.py.
-    """
-    assert (
-        csd.size == ground_truth.size), f"Image and mask should be the same size, but are {csd.size=} and {ground_truth.size=}"
-    return {"img": torch.as_tensor(csd.copy()).float().contiguous(),
-            "gt_semantic_seg": torch.as_tensor(ground_truth.copy()).float().contiguous(),
-            "img_metas": {"filename": f"{idx}.jpg", "ori_filename": f"{idx}_ori.jpg", "ori_shape": csd.shape[::-1],
-                          # we want (100, 100, 1) not (1, 100, 100)
-                          "img_shape": csd.shape[::-1], "pad_shape": csd.shape[::-1],  # image shape after padding
-                          "scale_factor": 1.0, "img_norm_cfg": {"mean": np.mean(csd, axis=(-2, -1)),  # mean for each channel
-                                                                "std": np.std(csd, axis=(-2, -1)),  # std for each channel
-                                                                "to_rgb": False, }, "img_id": f"{idx}", }, }
-
-
-def format_csd_only(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> torch.Tensor:
-    """Format the output of the Pytorch Dataset class to be just a CSD.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask. Not used in this format.
-        idx: Index of the csd. Not used in this format.
-
-    Returns:
-        The CSD as tensor.
-    """
-    return torch.as_tensor(csd.copy(), dtype=torch.float).contiguous()
-
-
-def format_csd_float16_only(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> torch.Tensor:
-    """Format the output of the Pytorch Dataset class to be just a float16 (half precision) CSD.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask. Not used in this format.
-        idx: Index of the csd. Not used in this format.
-
-    Returns:
-        The float 16 (half precision) CSD as tensor.
-    """
-    return torch.as_tensor(csd.copy(), dtype=torch.float16).contiguous()
-
-
-def format_csd_bfloat16_only(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> torch.Tensor:
-    """Format the output of the Pytorch Dataset class to be just a bfloat16 (half precision) CSD.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask. Not used in this format.
-        idx: Index of the csd. Not used in this format.
-
-    Returns:
-        The brain float 16 (half precision) CSD as tensor.
-    """
-    return torch.as_tensor(csd.copy(), dtype=torch.bfloat16).contiguous()
-
-
-def format_csd_class_index(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> Tuple[
-    torch.Tensor, torch.Tensor, int]:
-    """Format the output of the Pytorch Dataset class to be the CSD, a class index (which is always 0 as we have no classes) and the index.
-
-    This is needed to be conform to the datasets used in DeepSVDD, see https://github.com/lukasruff/Deep-SVDD-PyTorch.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask. Not used in this format.
-        idx: Index of the csd.
-
-    Returns:
-        A tuple of CSD, class index, and the index.
-    """
-    return torch.as_tensor(csd.copy(), dtype=torch.float).unsqueeze(0).contiguous(), torch.tensor(0), idx
-
-
-def format_tuple_csd_float_ground_truth_float(csd: np.ndarray, ground_truth: np.ndarray, idx: int, ) -> dict[
-    str, torch.Tensor]:
-    """Format the output of the Pytorch Dataset class to be a tuple of the csd and the ground_truth.
-
-    Args:
-        csd: The CSD array.
-        ground_truth: Ground truth as pixel mask.
-        idx: index of the csd. Not used in this format.
-
-    Returns:
-        Tuple with csd and ground_truth of dtype float and float, respectively.
-    """
-    assert (
-        csd.size == ground_truth.size), f"Image and mask should be the same size, but are {csd.size=} and {ground_truth.size=}"
-    return (torch.as_tensor(csd.copy(), dtype=torch.float).contiguous(),
-            torch.as_tensor(ground_truth.copy(), dtype=torch.float).contiguous(),)