simcats-datasets 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,2 @@
+ __all__ = []
+ __version__ = "2.4.0"
@@ -0,0 +1,6 @@
+ """Module with functions for creating datasets."""
+
+ from simcats_datasets.generation._create_dataset import create_dataset
+ from simcats_datasets.generation._create_simulated_dataset import create_simulated_dataset, add_ct_by_dot_masks_to_dataset
+
+ __all__ = ["create_dataset", "create_simulated_dataset", "add_ct_by_dot_masks_to_dataset"]
@@ -0,0 +1,221 @@
+ """Module with functions for creating a dataset from already existing data.
+
+ @author: f.hader
+ """
+
+ import json
+ import h5py
+ from typing import Optional, List
+ from pathlib import Path
+ from os.path import dirname
+ import numpy as np
+
+ from simcats_datasets.support_functions._json_encoders import MultipleJsonEncoders, NumpyEncoder, DataArrayEncoder
+
+ __all__ = []
+
+
+ def create_dataset(dataset_path: str,
+                    csds: List[np.ndarray],
+                    occupations: Optional[List[np.ndarray]] = None,
+                    tct_masks: Optional[List[np.ndarray]] = None,
+                    ct_by_dot_masks: Optional[List[np.ndarray]] = None,
+                    line_coordinates: Optional[List[np.ndarray]] = None,
+                    line_labels: Optional[List[dict]] = None,
+                    metadata: Optional[List[dict]] = None,
+                    max_len_line_coordinates_chunk: Optional[int] = None,
+                    max_len_line_labels_chunk: Optional[int] = None,
+                    max_len_metadata_chunk: Optional[int] = None,
+                    dtype_csd: np.dtype = np.float32,
+                    dtype_occ: np.dtype = np.float32,
+                    dtype_tct: np.dtype = np.uint8,
+                    dtype_ct_by_dot: np.dtype = np.uint8,
+                    dtype_line_coordinates: np.dtype = np.float32) -> None:
+     """Function for creating simcats_datasets v2 format datasets from given data.
+
+     Args:
+         dataset_path: The path where the new (v2) HDF5 dataset will be stored.
+         csds: The list of CSDs to use for creating the dataset.
+         occupations: List of occupations to use for creating the dataset. Defaults to None.
+         tct_masks: List of TCT masks to use for creating the dataset. Defaults to None.
+         ct_by_dot_masks: List of CT by dot masks to use for creating the dataset. Defaults to None.
+         line_coordinates: List of line coordinates to use for creating the dataset. Defaults to None.
+         line_labels: List of line labels to use for creating the dataset. Defaults to None.
+         metadata: List of metadata to use for creating the dataset. Defaults to None.
+         max_len_line_coordinates_chunk: The expected maximal length for line coordinates in number of float values
+             (each line requires 4 floats). If None, it is set to the largest value of the CSD shape. Default is None.
+         max_len_line_labels_chunk: The expected maximal length for line labels in number of uint8/char values (each
+             line label, encoded as utf-8 json, should require at most 80 chars). If None, it is set to the largest
+             value of the CSD shape * 20 (matching the allowed number of line coordinates). Default is None.
+         max_len_metadata_chunk: The expected maximal length for metadata in number of uint8/char values (each metadata
+             dict, encoded as utf-8 json, should require at most 8000 chars, typically closer to 4000, but could get
+             larger for dot jumps metadata of high resolution scans). If None, it is set to 8000. Default is None.
+         dtype_csd: Specifies the dtype to be used for saving CSDs. Default is np.float32.
+         dtype_occ: Specifies the dtype to be used for saving Occupations. Default is np.float32.
+         dtype_tct: Specifies the dtype to be used for saving TCTs. Default is np.uint8.
+         dtype_ct_by_dot: Specifies the dtype to be used for saving CT by dot masks. Default is np.uint8.
+         dtype_line_coordinates: Specifies the dtype to be used for saving line coordinates. Default is np.float32.
+     """
+     # Create path where the dataset will be saved (if folder doesn't exist already)
+     Path(dirname(dataset_path)).mkdir(parents=True, exist_ok=True)
+
+     with h5py.File(dataset_path, "a") as hdf5_file:
+         # get the number of total ids. This is especially required if a large dataset is loaded and saved step by step
+         num_ids = len(csds)
+
+         # process CSDs
+         # save an example CSD to get shape and dtype
+         temp_csd = csds[0].copy()
+         # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to load
+         # one image at a time during training)
+         ds = hdf5_file.require_dataset(name='csds', shape=(0, *temp_csd.shape), dtype=dtype_csd,
+                                        maxshape=(None, *temp_csd.shape))
+         # determine index offset if there is already data in the dataset
+         id_offset = ds.shape[0]
+         # resize datasets to fit new data
+         ds.resize(ds.shape[0] + num_ids, axis=0)
+         ds[id_offset:] = np.array(csds).astype(dtype_csd)
+         if occupations is not None:
+             if len(occupations) != num_ids:
+                 raise ValueError(
+                     f"Number of new occupation arrays ({len(occupations)}) does not match the number of new CSDs "
+                     f"({num_ids}).")
+             # process Occupations
+             # save an example occ to get shape
+             temp_occ = occupations[0].copy()
+             # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+             # load one image at a time during training)
+             ds = hdf5_file.require_dataset(name='occupations', shape=(0, *temp_occ.shape), dtype=dtype_occ,
+                                            maxshape=(None, *temp_occ.shape))
+             if ds.shape[0] != id_offset:
+                 raise ValueError(
+                     f"Number of already stored occupation arrays ({ds.shape[0]}) does not match the number of already "
+                     f"stored CSDs ({id_offset}).")
+             # resize datasets to fit new data
+             ds.resize(ds.shape[0] + num_ids, axis=0)
+             ds[id_offset:] = np.array(occupations).astype(dtype_occ)
+         if tct_masks is not None:
+             if len(tct_masks) != num_ids:
+                 raise ValueError(
+                     f"Number of new TCT mask arrays ({len(tct_masks)}) does not match the number of new CSDs "
+                     f"({num_ids}).")
+             # process tct masks
+             # save an example tct to get shape and dtype
+             temp_tct = tct_masks[0].copy()
+             # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+             # load one image at a time during training)
+             ds = hdf5_file.require_dataset(name='tct_masks', shape=(0, *temp_tct.shape), dtype=dtype_tct,
+                                            maxshape=(None, *temp_tct.shape))
+             if ds.shape[0] != id_offset:
+                 raise ValueError(
+                     f"Number of already stored TCT mask arrays ({ds.shape[0]}) does not match the number of already "
+                     f"stored CSDs ({id_offset}).")
+             # resize datasets to fit new data
+             ds.resize(ds.shape[0] + num_ids, axis=0)
+             ds[id_offset:] = np.array(tct_masks).astype(dtype_tct)
+         if ct_by_dot_masks is not None:
+             if len(ct_by_dot_masks) != num_ids:
+                 raise ValueError(
+                     f"Number of new CT by dot mask arrays ({len(ct_by_dot_masks)}) does not match the number of new "
+                     f"CSDs ({num_ids}).")
+             # process ct by dot masks
+             # save an example ct by dot mask to get shape and dtype
+             temp_ct_by_dot = ct_by_dot_masks[0].copy()
+             # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+             # load one image at a time during training)
+             ds = hdf5_file.require_dataset(name='ct_by_dot_masks', shape=(0, *temp_ct_by_dot.shape),
+                                            dtype=dtype_ct_by_dot, maxshape=(None, *temp_ct_by_dot.shape))
+             if ds.shape[0] != id_offset:
+                 raise ValueError(
+                     f"Number of already stored CT by dot mask arrays ({ds.shape[0]}) does not match the number of "
+                     f"already stored CSDs ({id_offset}).")
+             # resize datasets to fit new data
+             ds.resize(ds.shape[0] + num_ids, axis=0)
+             ds[id_offset:] = np.array(ct_by_dot_masks).astype(dtype_ct_by_dot)
+         if line_coordinates is not None:
+             if len(line_coordinates) != num_ids:
+                 raise ValueError(
+                     f"Number of new line coordinates ({len(line_coordinates)}) does not match the number of new "
+                     f"CSDs ({num_ids}).")
+             # retrieve fixed length for chunks
+             if max_len_line_coordinates_chunk is None:
+                 # calculate max expected length (max_number_of_lines * 4 entries, max number estimated as max(shape)/4)
+                 max_len = max(temp_csd.shape)
+             else:
+                 max_len = max_len_line_coordinates_chunk
+             # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+             # load one image at a time during training)
+             ds = hdf5_file.require_dataset(name='line_coordinates', shape=(0, max_len), dtype=dtype_line_coordinates,
+                                            maxshape=(None, max_len))
+             if ds.shape[0] != id_offset:
+                 raise ValueError(
+                     f"Number of already stored line coordinates ({ds.shape[0]}) does not match the number of already "
+                     f"stored CSDs ({id_offset}).")
+             # resize datasets to fit new data
+             ds.resize(ds.shape[0] + num_ids, axis=0)
+             # process line coordinates
+             # pad to a fixed size, so that we don't need the leaky special dtype
+             line_coordinates = np.array(
+                 [np.pad(l_c.flatten(), ((0, max_len - l_c.size)), 'constant', constant_values=np.nan) for l_c in
+                  line_coordinates])
+             ds[id_offset:] = line_coordinates.astype(dtype_line_coordinates)
+         if line_labels is not None:
+             if len(line_labels) != num_ids:
+                 raise ValueError(
+                     f"Number of new line labels ({len(line_labels)}) does not match the number of new CSDs "
+                     f"({num_ids}).")
+             # retrieve fixed length for chunks
+             if max_len_line_labels_chunk is None:
+                 # calculate max expected length (max_number_of_lines * 80 uint8 numbers, max number estimated as
+                 # max(shape)/4)
+                 max_len = max(temp_csd.shape) * 20
+             else:
+                 max_len = max_len_line_labels_chunk
+             # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+             # load one image at a time during training)
+             ds = hdf5_file.require_dataset(name='line_labels', shape=(0, max_len), dtype=np.uint8,
+                                            maxshape=(None, max_len))
+             if ds.shape[0] != id_offset:
+                 raise ValueError(
+                     f"Number of already stored line labels ({ds.shape[0]}) does not match the number of already stored "
+                     f"CSDs ({id_offset}).")
+             # resize datasets to fit new data
+             ds.resize(ds.shape[0] + num_ids, axis=0)
+             # process line labels
+             line_labels = [json.dumps(l_l).encode("utf-8") for l_l in line_labels]
+             # go to bytes array for better saving and loading
+             line_labels = [np.frombuffer(l_l, dtype=np.uint8) for l_l in line_labels]
+             # pad with whitespace (" " in uint8 = 32) to a fixed size, so that we don't need the leaky special dtype
+             line_labels = np.array(
+                 [np.pad(l_l, ((0, max_len - l_l.size)), 'constant', constant_values=32) for l_l in line_labels])
+             ds[id_offset:] = line_labels
+         if metadata is not None:
+             if len(metadata) != num_ids:
+                 raise ValueError(
+                     f"Number of new metadata ({len(metadata)}) does not match the number of new CSDs ({num_ids}).")
+             # retrieve fixed length for chunks
+             if max_len_metadata_chunk is None:
+                 # set len to 8000 uint8 numbers, that should already include some extra safety (expected smth. like
+                 # 3200-4000 chars)
+                 max_len = 8000
+             else:
+                 max_len = max_len_metadata_chunk
+             # use chunks as this will speed up reading later! One chunk is set to be exactly one image (optimized to
+             # load one image at a time during training)
+             ds = hdf5_file.require_dataset(name='metadata', shape=(0, max_len), dtype=np.uint8,
+                                            maxshape=(None, max_len))
+             if ds.shape[0] != id_offset:
+                 raise ValueError(
+                     f"Number of already stored metadata ({ds.shape[0]}) does not match the number of already stored "
+                     f"CSDs ({id_offset}).")
+             # resize datasets to fit new data
+             ds.resize(ds.shape[0] + num_ids, axis=0)
+             # process metadata
+             metadata = [json.dumps(meta, cls=MultipleJsonEncoders(NumpyEncoder, DataArrayEncoder)).encode("utf-8") for
+                         meta in metadata]
+             # go to bytes array for better saving and loading
+             metadata = [np.frombuffer(m, dtype=np.uint8) for m in metadata]
+             # pad with whitespace (" " in uint8 = 32) to a fixed size, so that we don't need the leaky special dtype
+             metadata = np.array([np.pad(m, ((0, max_len - m.size)), 'constant', constant_values=32) for m in metadata])
+             ds[id_offset:] = metadata
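For orientation, a minimal usage sketch of create_dataset based only on the signature above; the file path, array shapes, and metadata contents are illustrative assumptions, not values taken from the package:

    import numpy as np
    from simcats_datasets.generation import create_dataset

    # two toy 100x100 CSDs with matching occupation maps (shapes are assumptions)
    csds = [np.random.rand(100, 100) for _ in range(2)]
    occupations = [np.random.rand(100, 100, 2) for _ in range(2)]

    # writes (or appends to) an HDF5 file in the simcats_datasets v2 layout
    create_dataset(
        dataset_path="datasets/example_dataset.h5",
        csds=csds,
        occupations=occupations,
        metadata=[{"note": "toy example"}, {"note": "toy example"}],
    )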
@@ -0,0 +1,372 @@
+ """Module with functions for creating a csd dataset using SimCATS for simulations.
+
+ @author: f.hader
+ """
+
+ import itertools
+ import json
+ import math
+ from pathlib import Path
+ from typing import List, Tuple, Optional
+
+ import h5py
+
+ # data handling
+ import numpy as np
+
+ # parallel
+ from parallelbar import progress_imap
+
+ # for SimCATS simulation
+ from simcats import Simulation, default_configs
+ from simcats.distortions import OccupationDotJumps
+ from simcats.support_functions import (
+     LogNormalSamplingRange,
+     NormalSamplingRange,
+     UniformSamplingRange, ExponentialSamplingRange,
+ )
+ from tqdm import tqdm
+
+ from simcats_datasets.loading import load_dataset
+ from simcats_datasets.loading.load_ground_truth import load_ct_by_dot_masks
+ # label creation based on line intersection
+ from simcats_datasets.support_functions.get_lead_transition_labels import get_lead_transition_labels
+ from simcats_datasets.support_functions._json_encoders import NumpyEncoder
+
+ __all__ = []
+
+
+ def _simulate(args: Tuple) -> Tuple:
+     """Function to simulate a CSD with the given args. Required for parallel simulation in create_simulated_dataset.
+
+     Args:
+         args: Tuple of sample_range_g1, sample_range_g2, volt_range, simcats_config, resolution.
+
+     Returns:
+         Tuple of csd, occ, lead_trans, metadata, line_points, labels.
+     """
+     sample_range_g1, sample_range_g2, volt_range, simcats_config, resolution = args
+
+     # random number generator used for sampling volt ranges.
+     # !Must be generated here! Else same for every process!
+     rng = np.random.default_rng()
+     # !also update the rng of the configs, because else all workers sample the same noise!
+     for distortion in (
+         *simcats_config["occupation_distortions"],
+         *simcats_config["sensor_potential_distortions"],
+         *simcats_config["sensor_response_distortions"],
+     ):
+         if hasattr(distortion, "rng"):
+             distortion.rng = np.random.default_rng()
+         if hasattr(distortion, "sigma"):
+             # get sigma
+             temp_sigma = distortion.sigma
+             # modify sigma
+             if isinstance(distortion.sigma, LogNormalSamplingRange):
+                 temp_sigma._LogNormalSamplingRange__rng = np.random.default_rng()
+             elif isinstance(distortion.sigma, UniformSamplingRange):
+                 temp_sigma._UniformSamplingRange__rng = np.random.default_rng()
+             elif isinstance(distortion.sigma, NormalSamplingRange):
+                 temp_sigma._NormalSamplingRange__rng = np.random.default_rng()
+             elif isinstance(distortion.sigma, ExponentialSamplingRange):
+                 temp_sigma._ExponentialSamplingRange__rng = np.random.default_rng()
+             # set sigma
+             distortion.sigma = temp_sigma
+     sim = Simulation(**simcats_config)
+
+     # sample voltage ranges
+     g1_start = rng.uniform(low=sample_range_g1[0], high=sample_range_g1[1])
+     g2_start = rng.uniform(low=sample_range_g2[0], high=sample_range_g2[1])
+     g1_range = np.array([g1_start, g1_start + volt_range[0]])
+     g2_range = np.array([g2_start, g2_start + volt_range[1]])
+     # perform simulation
+     csd, occ, lead_trans, metadata = sim.measure(
+         sweep_range_g1=g1_range, sweep_range_g2=g2_range, resolution=resolution
+     )
+     # calculate lead_transition labels
+     ideal_csd_conf = metadata["ideal_csd_config"]
+     line_points, labels = get_lead_transition_labels(
+         sweep_range_g1=g1_range,
+         sweep_range_g2=g2_range,
+         ideal_csd_config=ideal_csd_conf,
+         lead_transition_mask=lead_trans,
+     )
+     return csd, occ, lead_trans, metadata, line_points, labels
+
+
+ def create_simulated_dataset(
+     dataset_path: str,
+     simcats_config: dict = default_configs["GaAs_v1"],
+     n_runs: int = 10000,
+     resolution: np.ndarray = np.array([100, 100]),
+     volt_range: np.ndarray = np.array([0.03, 0.03]),
+     tags: Optional[dict] = None,
+     num_workers: int = 1,
+     progress_bar: bool = True,
+     max_len_line_coordinates_chunk: int = 100,
+     max_len_line_labels_chunk: int = 2000,
+     max_len_metadata_chunk: int = 8000,
+     dtype_csd: np.dtype = np.float32,
+     dtype_occ: np.dtype = np.float32,
+     dtype_tct: np.dtype = np.uint8,
+     dtype_line_coordinates: np.dtype = np.float32,
+ ) -> None:
+ """Function for generating simulated datasets using SimCATS for simulations.
115
+
116
+ **Warning**: This function expects that the simulation config uses IdealCSDGeometric from SimCATS. Other
117
+ implementations are not guaranteed to work.
118
+
119
+ Args:
120
+ dataset_path: The path where the dataset will be stored. Can also be an already existing dataset, to which new
121
+ data is added.
122
+ simcats_config: Configuration for simcats simulation class. Default is the GaAs_v1 config provided by simcats.
123
+ n_runs: Number of CSDs to be generated. Default is 10000.
124
+ resolution: Pixel resolution for both axis of the CSDs, first number of columns (x), then number of rows (y).
125
+ Default is np.array([100, 100]). \n
126
+ Example: \n
127
+ [res_g1, res_g2]
128
+ volt_range: Volt range for both axis of the CSDs. Individual CSDs with the specified size are randomly sampled
129
+ in the voltage space. Default is np.array([0.03, 0.03]) (usually the scans from RWTH GaAs offler sample are
130
+ 30mV x 30mV).
131
+ tags: Additional tags for the data to be simulated, which will be added to the dataset DataFrame. Default is
132
+ None. \n
133
+ Example: \n
134
+ {"tags": "shifted sensor, no noise", "sample": "GaAs"}.
135
+ num_workers: Number of workers to parallelize dataset creation. Minimum is 1. Default is 1.
136
+ progress_bar: Determines whether to display a progress bar. Default is True.
137
+ max_len_line_coordinates_chunk: Maximum number of line coordinates. This is the size of the flattened array,
138
+ therefore 100 means 20 lines. Default is 100.
139
+ max_len_line_labels_chunk: Maximum number of chars for the line label dict. Default is 2000.
140
+ max_len_metadata_chunk: Maximum number of chars for the metadata dict. Default is 8000.
141
+ dtype_csd: Specifies the dtype to be used for saving CSDs. Default is np.float32.
142
+ dtype_occ: Specifies the dtype to be used for saving Occupations. Default is np.float32.
143
+ dtype_tct: Specifies the dtype to be used for saving TCTs. Default is np.uint8.
144
+ dtype_line_coordinates: Specifies the dtype to be used for saving line coordinates. Default is np.float32.
145
+ """
146
+     # set tags to an empty dict if none were supplied
+     if tags is None:
+         tags = {}
+
+     # Create path where the dataset will be saved (if folder doesn't exist already)
+     Path(Path(dataset_path).parent).mkdir(parents=True, exist_ok=True)
+
+     # arrange volt limits so that random sampling gives us a starting point that is at least the defined volt_range
+     # below the maximum
+     sample_range_g1 = simcats_config["volt_limits_g1"].copy()
+     sample_range_g1[-1] -= volt_range[0]
+     sample_range_g2 = simcats_config["volt_limits_g2"].copy()
+     sample_range_g2[-1] -= volt_range[1]
+
+     with h5py.File(dataset_path, "a") as hdf5_file:
+         # load datasets or create them if not already there
+         csds = hdf5_file.require_dataset(
+             name="csds",
+             shape=(0, resolution[1], resolution[0]),
+             chunks=(1, resolution[1], resolution[0]),
+             dtype=dtype_csd,
+             maxshape=(None, resolution[1], resolution[0]),
+         )
+         occupations = hdf5_file.require_dataset(
+             name="occupations",
+             shape=(0, resolution[1], resolution[0], 2),
+             chunks=(1, resolution[1], resolution[0], 2),
+             dtype=dtype_occ,
+             maxshape=(None, resolution[1], resolution[0], 2),
+         )
+         tct_masks = hdf5_file.require_dataset(
+             name="tct_masks",
+             shape=(0, resolution[1], resolution[0]),
+             chunks=(1, resolution[1], resolution[0]),
+             dtype=dtype_tct,
+             maxshape=(None, resolution[1], resolution[0]),
+         )
+         line_coords = hdf5_file.require_dataset(
+             name="line_coordinates",
+             shape=(0, max_len_line_coordinates_chunk),
+             chunks=(1, max_len_line_coordinates_chunk),
+             dtype=dtype_line_coordinates,
+             maxshape=(None, max_len_line_coordinates_chunk),
+         )
+         line_labels = hdf5_file.require_dataset(
+             name="line_labels",
+             shape=(0, max_len_line_labels_chunk),
+             chunks=(1, max_len_line_labels_chunk),
+             dtype=np.uint8,
+             maxshape=(None, max_len_line_labels_chunk),
+         )
+         metadatas = hdf5_file.require_dataset(
+             name="metadata",
+             shape=(0, max_len_metadata_chunk),
+             chunks=(1, max_len_metadata_chunk),
+             dtype=np.uint8,
+             maxshape=(None, max_len_metadata_chunk),
+         )
+         # determine index offset if there is already data in the dataset
+         id_offset = csds.shape[0]
+
+         # resize datasets to fit new data
+         csds.resize(csds.shape[0] + n_runs, axis=0)
+         occupations.resize(occupations.shape[0] + n_runs, axis=0)
+         tct_masks.resize(tct_masks.shape[0] + n_runs, axis=0)
+         line_coords.resize(line_coords.shape[0] + n_runs, axis=0)
+         line_labels.resize(line_labels.shape[0] + n_runs, axis=0)
+         metadatas.resize(metadatas.shape[0] + n_runs, axis=0)
+
+         # simulate and save data
+         indices = range(id_offset, n_runs + id_offset)
+         arguments = itertools.repeat(
+             (sample_range_g1, sample_range_g2, volt_range, simcats_config, resolution),
+             times=len(indices),
+         )
+         for index, (csd, occ, lead_trans, metadata, line_points, labels) in zip(
+             indices,
+             progress_imap(
+                 func=_simulate,
+                 tasks=arguments,
+                 n_cpu=num_workers,
+                 total=len(indices),
+                 chunk_size=len(indices) // num_workers,
+                 disable=not progress_bar,
+             ),
+         ):
+             # save data
+             csds[index] = csd.astype(dtype_csd)
+             occupations[index] = occ.astype(dtype_occ)
+             tct_masks[index] = lead_trans.astype(dtype_tct)
+             line_coords[index] = np.pad(
+                 line_points.flatten(),
+                 ((0, max_len_line_coordinates_chunk - line_points.size)),
+                 "constant",
+                 constant_values=np.nan,
+             ).astype(dtype_line_coordinates)
+             # Convert the line label dictionary to a JSON string
+             json_line_labels = np.frombuffer(
+                 json.dumps(labels).encode("utf-8"), dtype=np.uint8
+             )
+             # pad with whitespace (" " in uint8 = 32) to a fixed size, so that we don't need the leaky special dtype
+             json_line_labels_padded = np.pad(
+                 json_line_labels,
+                 ((0, max_len_line_labels_chunk - json_line_labels.size)),
+                 "constant",
+                 constant_values=32,
+             )
+             line_labels[index] = json_line_labels_padded
+
+             # convert metadata
+             metadata_converted = dict()
+             for metadata_key, metadata_value in {**metadata, **tags}.items():
+                 if isinstance(metadata_value, np.ndarray):
+                     metadata_converted[metadata_key] = metadata_value
+                 else:
+                     metadata_converted[metadata_key] = str(metadata_value)
+                 # add dot jumps to the metadata, to be able to apply it to all ground truth types
+                 if metadata_key == "occupation_distortions":
+                     for distortion in metadata_value:
+                         if (
+                             isinstance(distortion, OccupationDotJumps)
+                             and distortion.activated
+                         ):
+                             metadata_converted[
+                                 f"OccupationDotJumps_axis{distortion.axis}"
+                             ] = distortion._OccupationDotJumps__previous_noise
+             # save metadata
+             metadata_converted = json.dumps(
+                 metadata_converted, cls=NumpyEncoder
+             ).encode("utf-8")
+             metadata_converted = np.frombuffer(metadata_converted, dtype=np.uint8)
+             # pad with whitespace (" " in uint8 = 32) to a fixed size, so that we don't need the leaky special dtype
+             metadata_padded = np.pad(
+                 metadata_converted,
+                 ((0, max_len_metadata_chunk - metadata_converted.size)),
+                 "constant",
+                 constant_values=32,
+             )
+             metadatas[index] = metadata_padded
+
+
+ def _load_ct_by_dot_masks_for_parallel(args: tuple) -> List[np.ndarray]:
+     """Helper function for parallel loading of CT_by_dot masks in add_ct_by_dot_masks_to_dataset.
+
+     Args:
+         args: Tuple of arguments.
+
+     Returns:
+         Loaded CT_by_dot masks.
+     """
+     return load_ct_by_dot_masks(*args)
+
+
+ def add_ct_by_dot_masks_to_dataset(
+     dataset_path: str,
+     num_workers: int = 10,
+     progress_bar: bool = True,
+     dtype_ct_by_dot: np.dtype = np.uint8,
+     batch_size_per_worker: int = 100,
+ ) -> None:
+     """Function for adding charge transitions labeled by dot masks to existing simulated datasets.
+
+     Args:
+         dataset_path: The path where the dataset is stored.
+         num_workers: Number of workers to parallelize dataset creation. Minimum is 1. Default is 10.
+         progress_bar: Determines whether to display a progress bar. Default is True.
+         dtype_ct_by_dot: Specifies the dtype to be used for saving CT_by_dot masks. Default is np.uint8.
+         batch_size_per_worker: Determines how many CT_by_dot masks are consecutively calculated by each worker, before
+             saving them. Default is 100.
+     """
+     num_ids = len(load_dataset(file=dataset_path, load_csds=False, load_ids=True).ids)
+     resolution = load_dataset(file=dataset_path, load_csds=True, specific_ids=[0]).csds[0].shape
+
+     # setup id ranges for the batches
+     id_ranges = list()
+     for i in range(math.ceil(num_ids / batch_size_per_worker)):
+         id_ranges.append(range(
+             i * batch_size_per_worker, np.min([(i + 1) * batch_size_per_worker, num_ids])
+         ))
+
+     # Iterate and always calculate exactly one batch per worker, so that we can write to HDF5 after all workers have
+     # finished their batch. SingleWriterMultipleReader mode of HDF5 was causing problems. Therefore, we now always
+     # write after all workers have stopped and closed their file objects.
+     with tqdm(unit="batches", total=len(id_ranges)) as pbar:
+         for i in range(math.ceil(len(id_ranges) / num_workers)):
+             temp_ct_by_dot_masks = list()
+             temp_indices = list()
+
+             arguments = zip(itertools.cycle([dataset_path]),
+                             id_ranges[i*num_workers:(i+1)*num_workers],
+                             itertools.cycle([False]),
+                             itertools.cycle([1000]),
+                             itertools.cycle([False])
+                             )
+
+             # calculate data
+             for indices, _ct_by_dot_masks in zip(
+                 id_ranges[i*num_workers:(i+1)*num_workers],
+                 progress_imap(
+                     func=_load_ct_by_dot_masks_for_parallel,
+                     tasks=arguments,
+                     n_cpu=num_workers,
+                     total=len(id_ranges),
+                     chunk_size=1,
+                     disable=True,
+                 ),
+             ):
+                 temp_ct_by_dot_masks.append(_ct_by_dot_masks)
+                 temp_indices.append(indices)
+                 pbar.update(1)
+
+             # save data
+             with h5py.File(dataset_path, "a") as hdf5_file:
+                 # load datasets or create them if not already there
+                 ct_by_dot_masks = hdf5_file.require_dataset(
+                     name="ct_by_dot_masks",
+                     shape=(num_ids, resolution[0], resolution[1]),
+                     dtype=dtype_ct_by_dot,
+                     chunks=(1, resolution[0], resolution[1]),
+                     maxshape=(None, resolution[0], resolution[1]),
+                 )
+                 for ids, masks in zip(temp_indices, temp_ct_by_dot_masks):
+                     if isinstance(masks, list):
+                         masks = np.array(masks)
+                     # just to be sure to save the masks one by one as chunks (as I absolutely don't trust HDF5 anymore)
+                     for _id, _mask in zip(ids, masks):
+                         ct_by_dot_masks[_id] = _mask.astype(dtype_ct_by_dot)
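As a sketch of how the two public functions above fit together (all parameter values here are illustrative; only the keyword names come from the signatures in this diff):

    import numpy as np
    from simcats import default_configs
    from simcats_datasets.generation import create_simulated_dataset, add_ct_by_dot_masks_to_dataset

    # simulate a small dataset with the GaAs_v1 default config
    create_simulated_dataset(
        dataset_path="datasets/simulated.h5",
        simcats_config=default_configs["GaAs_v1"],
        n_runs=100,
        resolution=np.array([100, 100]),
        volt_range=np.array([0.03, 0.03]),
        num_workers=4,
    )

    # afterwards, CT-by-dot masks can be derived and appended to the same file
    add_ct_by_dot_masks_to_dataset(dataset_path="datasets/simulated.h5", num_workers=4)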
@@ -0,0 +1,8 @@
+ """Module with functionalities for loading data from dataset files (HDF5 format).
+
+ Also contains functionalities for loading data as a PyTorch dataset with different ground truth types.
+ """
+
+ from simcats_datasets.loading._load_dataset import load_dataset
+
+ __all__ = ["load_dataset"]