simcats-datasets 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,177 @@
1
+ """The main interface for loading simcats datasets, typically stored as HDF5 files.
2
+
3
+ @author: f.hader
4
+ """
5
+
6
+ import json
7
+
8
+ from collections import namedtuple
9
+ from contextlib import nullcontext
10
+ from copy import deepcopy
11
+
12
+ from typing import List, Tuple, Union
13
+
14
+ import h5py
15
+ import numpy as np
16
+ from tqdm import tqdm
17
+
18
+
19
def load_dataset(file: Union[str, h5py.File],
                 load_csds: bool = True,
                 load_occupations: bool = False,
                 load_tct_masks: bool = False,
                 load_ct_by_dot_masks: bool = False,
                 load_line_coords: bool = False,
                 load_line_labels: bool = False,
                 load_metadata: bool = False,
                 load_ids: bool = False,
                 specific_ids: Union[range, List[int], np.ndarray, None] = None,
                 progress_bar: bool = False) -> Tuple:
    """Load a dataset consisting of multiple CSDs from an HDF5 file.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If a path (str) is supplied, load_dataset opens the file itself in read mode and closes it
            again. Any other object is used as-is and is not closed. If you want to do multiple consecutive loads
            from the same file (e.g. for using the PyTorch SimcatsDataset without preloading), consider
            initializing the file object yourself and passing it, to improve the performance.
        load_csds: Determines if csds should be loaded. Default is True.
        load_occupations: Determines if occupation data should be loaded. Default is False.
        load_tct_masks: Determines if lead transition masks should be loaded. Default is False.
        load_ct_by_dot_masks: Determines if charge transition labeled by affected dot masks should be loaded. This
            requires that ct_by_dot_masks have been added to the dataset. If a dataset has been created using
            create_simulated_dataset, these masks can be added afterwards using add_ct_by_dot_masks_to_dataset,
            mainly to avoid recalculating them multiple times (for example for machine learning purposes).
            Default is False.
        load_line_coords: Determines if lead transition definitions using start and end points should be loaded.
            Default is False.
        load_line_labels: Determines if labels for lead transitions defined using start and end points should be
            loaded. Default is False.
        load_metadata: Determines if the metadata (SimCATS config) of the CSDs should be loaded. Default is False.
        load_ids: Determines if the available ids should be loaded (or in case of specific ids: the specific ids
            are returned in the given order). Default is False.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            ordered like the specified ids and not necessarily ascending. Lists and arrays may be unsorted and may
            contain the same id several times; the object passed by the caller is never modified. A range is
            passed to the file unchanged. If set to None, all data is loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. This parameter has no functionality since
            version 2, but is kept for compatibility reasons. Default is False.

    Returns:
        namedtuple: The namedtuple can be unpacked like every normal tuple, or instead accessed by field names. \n
        Depending on what has been enabled, the following data is included in the named tuple (in this order): \n
        - field 'csds': All CSDs as numpy arrays, ordered by id (if no specific_ids are provided, else the order
          is given by specific_ids).
        - field 'occupations': Numpy arrays with occupations.
        - field 'tct_masks': Numpy arrays of TCT masks.
        - field 'ct_by_dot_masks': Numpy arrays of CT_by_dot masks.
        - field 'line_coordinates': List containing numpy arrays of line coordinates (one row of four values per
          line, NaN padding removed).
        - field 'line_labels': List containing the JSON-decoded line labels of every CSD.
        - field 'metadata': List containing the JSON-decoded metadata (simcats configs) of every CSD.
        - field 'ids': The ids of the CSDs.

        Array-like fields are returned as Python lists whenever specific_ids is a list or array (because the
        entries have to be re-ordered); otherwise they are returned as read from the file.

    Raises:
        IndexError: If load_ids is True and specific_ids contains ids that are not available in the dataset.
    """

    def _strip_nan_padding(padded_coords):
        # line coordinates are stored NaN-padded; keep only the real values, four per line
        return padded_coords[~np.isnan(padded_coords)].reshape((-1, 4))

    def _decode_json(raw):
        # JSON documents are stored as padded byte arrays
        return json.loads(raw.tobytes().strip().decode("utf-8"))

    # (enabled, name, per-entry decoder); the name is both the key in the file and the namedtuple field,
    # and the order of this table defines the order of the returned fields
    sources = [
        (load_csds, "csds", None),
        (load_occupations, "occupations", None),
        (load_tct_masks, "tct_masks", None),
        (load_ct_by_dot_masks, "ct_by_dot_masks", None),
        (load_line_coords, "line_coordinates", _strip_nan_padding),
        (load_line_labels, "line_labels", _decode_json),
        (load_metadata, "metadata", _decode_json),
    ]
    fieldnames = [name for enabled, name, _ in sources if enabled]
    if load_ids:
        fieldnames.append("ids")
    CSDDataset = namedtuple(typename="CSDDataset", field_names=fieldnames)

    # Reading from h5 requires increasing ids without duplicates. Therefore, only the unique sorted ids are read
    # and `restore_order` maps every requested position to its entry in the unique ids afterwards.
    requested_ids = specific_ids
    restore_order = None
    if isinstance(specific_ids, (list, np.ndarray)):
        if len(specific_ids) > 0:
            specific_ids, restore_order = np.unique(specific_ids, return_inverse=True)
            restore_order = restore_order.reshape(-1)
        else:
            restore_order = np.empty(0, dtype=np.intp)
    selection = slice(None) if specific_ids is None else specific_ids

    return_data = []
    ids = None
    # use nullcontext to catch the case where a file is passed instead of the string
    file_context = h5py.File(file, "r") if isinstance(file, str) else nullcontext(file)
    with file_context as _file:
        if load_ids:
            # only check if ids are correct, if load_ids is True. This prevents initializing a non-preloaded
            # PyTorch Dataset with non-existing specific IDs (which else would only crash as soon as a
            # non-existent ID is requested during training). We can't check this on loading CSDs etc. as it
            # massively slows down loading.
            number_of_csds = len(_file["csds"])
            if specific_ids is None:
                ids = range(number_of_csds)
            else:
                if len(specific_ids) > 0 and (
                        np.min(specific_ids) < 0 or np.max(specific_ids) >= number_of_csds):
                    msg = "Not all ids specified by 'specific_ids' are available in the dataset!"
                    raise IndexError(msg)
                # ids are returned exactly in the requested order
                ids = specific_ids if restore_order is None else list(requested_ids)

        for enabled, name, decoder in sources:
            if not enabled:
                continue
            entries = _file[name][selection]
            if decoder is not None:
                entries = [decoder(entry) for entry in entries]
            if restore_order is not None:
                # revert the sorting (and re-insert duplicates) if specific ids were used
                entries = [entries[i] for i in restore_order]
            return_data.append(entries)

    if load_ids:
        return_data.append(ids)

    return CSDDataset._make(tuple(return_data))
@@ -0,0 +1,486 @@
1
+ """Functions for providing ground truth data to be used with the **Pytorch Dataset class**.
2
+
3
+ For examples of the different ground truth types, please have a look at the notebook Examples_Pytorch_SimcatsDataset.
4
+
5
+ Every function must accept a h5 File or path for a simcats_dataset as input, provide an option to use only specific_ids
6
+ and allow disabling the progress_bar.
7
+ Output type depends on the ground truth type. Could for example be a pixel mask or defined start end points of lines.
8
+ **Please look at load_zeros_masks for a reference.**
9
+
10
+ @author: f.hader
11
+ """
12
+
13
+ from typing import Union, List
14
+
15
+ import bezier
16
+ import h5py
17
+ import numpy as np
18
+
19
+ # imports required for eval, to create IdealCSDGeometric objects from metadata strings
20
+ import re
21
+
22
+ import sympy
23
+ from simcats.ideal_csd import IdealCSDGeometric
24
+ from simcats.ideal_csd.geometric import calculate_all_bezier_anchors, tct_bezier, initialize_tct_functions
25
+ from numpy import array
26
+ from simcats.distortions import OccupationDotJumps
27
+ from tqdm import tqdm
28
+
29
+ from simcats_datasets.loading import load_dataset
30
+ from simcats.support_functions import rotate_points
31
+
32
+
33
def load_zeros_masks(file: Union[str, h5py.File],
                     specific_ids: Union[range, List[int], np.ndarray, None] = None,
                     progress_bar: bool = True) -> List[np.ndarray]:
    """Provide empty ground truth data (one all-zero uint8 array per CSD).

    Intended for sets without ground truth, e.g. experimental datasets without labels that should be used with
    the pytorch SimcatsDataset class, so that train results can be analyzed with the same interface as for
    simulated data.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. The value is only forwarded to load_dataset.
            Default is True.

    Returns:
        List with one zero-filled uint8 array per loaded CSD, each with the shape of the corresponding CSD.
    """
    dataset = load_dataset(file=file, load_csds=True, specific_ids=specific_ids, progress_bar=progress_bar)
    zero_masks = []
    for csd in dataset.csds:
        zero_masks.append(np.zeros_like(csd, dtype=np.uint8))
    return zero_masks
55
+
56
+
57
def load_tct_masks(file: Union[str, h5py.File],
                   specific_ids: Union[range, List[int], np.ndarray, None] = None,
                   progress_bar: bool = True) -> List[np.ndarray]:
    """Provide the stored Total Charge Transition (TCT) masks as ground truth data.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. The value is only forwarded to load_dataset.
            Default is True.

    Returns:
        Total Charge Transition (TCT) masks, exactly as returned by load_dataset.
    """
    dataset = load_dataset(file=file,
                           load_csds=False,
                           load_tct_masks=True,
                           specific_ids=specific_ids,
                           progress_bar=progress_bar)
    return dataset.tct_masks
77
+
78
+
79
def load_tct_by_dot_masks(file: Union[str, h5py.File],
                          specific_ids: Union[range, List[int], np.ndarray, None] = None,
                          progress_bar: bool = True,
                          lut_entries: int = 1000) -> List[np.ndarray]:
    """Load Total Charge Transition (TCT) masks with transitions labeled by affected dot as ground truth data.

    The stored TCT masks are only used to find out which TCTs are visible in each CSD. Every visible TCT is then
    re-drawn segment by segment from the ideal CSD configuration stored in the metadata, and the segments are
    labeled alternately (2 for the first segment of a TCT, then 1, 2, ...). Finally, dot jumps recorded in the
    metadata are re-applied to the new mask.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.
        lut_entries: Number of lookup-table entries to use for tct_bezier. Default is 1000.

    Returns:
        The loaded TCT masks, where the mask of every CSD that contains at least one TCT has been replaced by a
        mask with the values 0 (no transition), 1 and 2 (alternating segment labels). CSDs without TCTs keep their
        (all-zero) mask.
    """
    tct_masks = load_dataset(file=file, load_csds=False, load_tct_masks=True, specific_ids=specific_ids,
                             progress_bar=progress_bar).tct_masks
    metadata = load_dataset(file=file, load_csds=False, load_metadata=True, specific_ids=specific_ids,
                            progress_bar=progress_bar).metadata

    for mask_id, meta in tqdm(enumerate(metadata), desc="calculating transitions", total=len(metadata),
                              disable=not progress_bar):
        # NOTE(security): the metadata strings are evaluated as Python code to rebuild the simulation objects.
        # Only load datasets from trusted sources.
        try:
            csd_geometric = eval(meta["ideal_csd_config"])
        except Exception:
            # This is required to support metadata from older simcats versions, where the class had a different name
            csd_geometric = eval(meta["ideal_csd_config"].replace("IdealCSDGeometrical", "IdealCSDGeometric"))

        tct_ids = np.unique(tct_masks[mask_id][np.nonzero(tct_masks[mask_id])])
        # skip images with no TCTs
        if tct_ids.size == 0:
            continue
        lowest_tct = np.min(tct_ids)
        highest_tct = np.max(tct_ids)
        visible_tct_params = csd_geometric.tct_params[lowest_tct - 1:highest_tct]
        # setup tct functions
        tct_funcs = initialize_tct_functions(tct_params=visible_tct_params, max_peaks=lowest_tct)
        # calculate all bezier anchors
        bezier_coords = calculate_all_bezier_anchors(tct_params=visible_tct_params, max_peaks=lowest_tct,
                                                     rotation=csd_geometric.rotation)

        # get parameters of pixel vs voltage space for discretization
        x_res = tct_masks[mask_id].shape[1]
        y_res = tct_masks[mask_id].shape[0]
        x_lims = meta["sweep_range_g1"]
        y_lims = meta["sweep_range_g2"]
        # stepsize x/y
        x_step = (x_lims[-1] - x_lims[0]) / (x_res - 1)
        y_step = (y_lims[-1] - y_lims[0]) / (y_res - 1)

        # get corner points to know max value range for generating points
        corner_points = np.array(
            [[x_lims[0], y_lims[0]], [x_lims[0], y_lims[1]], [x_lims[1], y_lims[0]], [x_lims[1], y_lims[1]]])
        x_c_rot = rotate_points(points=corner_points, angle=-csd_geometric.rotation)[:, 0]

        # generate enough x-values to cover the complete range of the CSD with a
        # higher resolution than required to have a precise result after discretization
        n_samples = (x_res + y_res) * 4

        # replace tct_mask by a new empty array
        tct_masks[mask_id] = np.zeros_like(tct_masks[mask_id], dtype=np.uint8)
        for tct_id in tct_ids:
            last_transition = tct_id * 2 - 1
            for transition in range(tct_id * 2):
                # get start x position of current transition
                if transition == 0:
                    x_start = np.min(x_c_rot) - x_step
                else:
                    # rotate bezier coords
                    bezier_coords_rot = rotate_points(points=bezier_coords[tct_id][transition - 1, 1],
                                                      angle=-csd_geometric.rotation)
                    x_start = bezier_coords_rot[0]
                # get stop x position of current transition
                if transition == last_transition:
                    x_stop = np.max(x_c_rot) + x_step
                else:
                    # rotate bezier coords
                    bezier_coords_rot = rotate_points(points=bezier_coords[tct_id][transition, 1],
                                                      angle=-csd_geometric.rotation)
                    x_stop = bezier_coords_rot[0]

                # Insert the transition line into the CSD
                # The required TCT points are sampled and discretized.
                tct_points = np.empty((n_samples, 2))
                tct_points[:, 0] = np.linspace(x_start, x_stop, n_samples)
                # generate the y-values for all generated x-values
                tct_points[:, 1] = tct_funcs[tct_id](x_eval=tct_points[:, 0], lut_entries=lut_entries)

                # rotate the TCT into the original orientation
                wf_points_rot = rotate_points(points=tct_points, angle=csd_geometric.rotation)

                # select only TCT pixels that are in the csd-limits
                valid_ids = np.where((wf_points_rot[:, 0] > (x_lims[0] - 0.5 * x_step)) & (
                        wf_points_rot[:, 0] < (x_lims[1] + 0.5 * x_step)) & (
                                             wf_points_rot[:, 1] > (y_lims[0] - 0.5 * y_step)) & (
                                             wf_points_rot[:, 1] < (y_lims[1] + 0.5 * y_step)))
                wf_points_rot = wf_points_rot[valid_ids[0], :]

                # insert TCT pixels into the csd
                # calculation of the ids for the values:
                # x = min(csd_x) + id * x_step
                # add half step size, so that the pixel id of the nearest pixel is obtained after the division
                # (round up if next higher value in range of 0.5 * step_size)
                x_id = np.floor_divide(wf_points_rot[:, 0] + 0.5 * x_step - x_lims[0], x_step).astype(int)
                y_id = np.floor_divide(wf_points_rot[:, 1] + 0.5 * y_step - y_lims[0], y_step).astype(int)
                # consecutive segments of a TCT alternate between the labels 2 and 1
                tct_masks[mask_id][y_id, x_id] = ((transition + 1) % 2) + 1

        # apply dot jumps if any were active
        if "OccupationDotJumps_axis0" in meta or "OccupationDotJumps_axis1" in meta:
            # rebuild the distortion objects from their string representation (everything from ", rng" on is cut)
            occ_jumps_objects_string = [s.split(", rng")[0] + ")" for s in
                                        re.findall(r"OccupationDotJumps[^\]]*", meta["occupation_distortions"]) if
                                        "[activated" in s]
            occ_jumps_objects = [eval(s) for s in occ_jumps_objects_string]
            for obj in occ_jumps_objects:
                if f"OccupationDotJumps_axis{obj.axis}" in meta:
                    # restore the jumps that were stored in the metadata and replay them on the new mask
                    # (freeze=True presumably makes the object reuse the restored noise - verify against simcats)
                    obj._OccupationDotJumps__activated = True
                    obj._OccupationDotJumps__previous_noise = meta[f"OccupationDotJumps_axis{obj.axis}"]
                    _, tct_masks[mask_id] = obj.noise_function(occupations=np.empty((0, 0, 2)),
                                                               lead_transitions=tct_masks[mask_id],
                                                               volt_limits_g1=x_lims, volt_limits_g2=y_lims,
                                                               freeze=True)
    return tct_masks
205
+
206
+
207
def _idt_overshoot_fraction(anchors_rot, slope, reference_point, transition_length, n_samples):
    """Return the distance between a rounded TCT and an IDT end point, relative to the IDT length.

    The bezier curve defined by `anchors_rot` (three anchors, rotated into the default representation) is sampled,
    the sample closest to the straight line through the central anchor with the given slope is determined, and its
    distance to `reference_point` is divided by `transition_length`.
    """
    x_eval = np.linspace(anchors_rot[0, 0], anchors_rot[2, 0], n_samples)
    # bezier nodes are required as fortran array
    bezier_curve = bezier.Curve.from_nodes(np.asfortranarray(anchors_rot.T))
    bezier_lut = bezier_curve.evaluate_multi(np.linspace(0, 1, n_samples * 2))
    # y-value of the bezier curve for every x (nearest lookup-table entry)
    y_eval = [bezier_lut[1, np.argmin(np.abs(bezier_lut[0, :] - x))] for x in x_eval]
    # find the closest point:
    # y_eval - (central_tct_anchor_y - (central_tct_anchor_x - x_eval) * (interdot_vec_y / interdot_vec_x))
    residuals = np.abs(y_eval - (anchors_rot[1, 1] - (anchors_rot[1, 0] - x_eval) * slope))
    intersection_pixel = np.argmin(residuals)
    # calculate the distance from this point to the central bezier anchor
    intersection_dist = np.linalg.norm([x_eval[intersection_pixel] - reference_point[0],
                                        y_eval[intersection_pixel] - reference_point[1]])
    return intersection_dist / transition_length


def load_idt_masks(file: Union[str, h5py.File],
                   specific_ids: Union[range, List[int], np.ndarray, None] = None,
                   progress_bar: bool = True) -> List[np.ndarray]:
    """Load Inter-Dot Transition (IDT) masks as ground truth data.

    In comparison to the Total Charge Transition (TCT) masks, only inter-dot transitions are included. The IDTs
    are calculated from the ideal CSD configuration stored in the metadata, the stored TCT masks are only used to
    determine which TCTs are part of the image. Dot jumps recorded in the metadata are re-applied to the mask.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Inter-Dot Transition (IDT) masks (one uint8 array per CSD). The pixels of an IDT are labeled with the id
        of the lower one of the two TCTs it connects.
    """
    tct_masks = load_dataset(file=file, load_csds=False, load_tct_masks=True, specific_ids=specific_ids,
                             progress_bar=progress_bar).tct_masks
    metadata = load_dataset(file=file, load_csds=False, load_metadata=True, specific_ids=specific_ids,
                            progress_bar=progress_bar).metadata
    idt_masks = []
    for tct_mask, meta in tqdm(zip(tct_masks, metadata), desc="calculating idt", total=len(tct_masks),
                               disable=not progress_bar):
        idt_mask = np.zeros(tct_mask.shape, dtype=np.uint8)
        # NOTE(security): the metadata strings are evaluated as Python code to rebuild the simulation objects.
        # Only load datasets from trusted sources.
        try:
            csd_geometric = eval(meta["ideal_csd_config"])
        except Exception:
            # This is required to support metadata from older simcats versions, where the class had a different name
            csd_geometric = eval(meta["ideal_csd_config"].replace("IdealCSDGeometrical", "IdealCSDGeometric"))
        tct_stop = int(np.max(tct_mask) + 1)
        bezier_coords = calculate_all_bezier_anchors(tct_params=csd_geometric.tct_params[:tct_stop],
                                                     rotation=csd_geometric.rotation)
        # get parameters of pixel vs voltage space for discretization
        x_res = idt_mask.shape[1]
        y_res = idt_mask.shape[0]
        x_lims = meta["sweep_range_g1"]
        y_lims = meta["sweep_range_g2"]
        # stepsize x/y
        x_step = (x_lims[-1] - x_lims[0]) / (x_res - 1)
        y_step = (y_lims[-1] - y_lims[0]) / (y_res - 1)
        # number of points sampled per curve / transition
        n_samples = np.max(idt_mask.shape)
        try:
            # min tct minus 1, because we always need to start with the one below of the lowest in the image
            first_tct = int(np.max([np.min(tct_mask[np.nonzero(tct_mask)]) - 1, 1]))
        except ValueError:
            # the mask contains no TCT at all (np.min of an empty selection)
            first_tct = 1
        # iterate over all tcts that are part of the current voltage range
        for tct_id in range(first_tct, tct_stop):
            # The number of inter-dot transitions connected to the next higher TCT for every TCT is given by the ID
            # of the TCT
            for idt_index in range(tct_id):
                # get start and end vector for inter dot transitions
                # (central anchor on the upper TCT first, central anchor on the lower TCT second)
                inter_dot_transition = np.array([bezier_coords[tct_id + 1][idt_index * 2 + 1, 1, :],
                                                 bezier_coords[tct_id][idt_index * 2, 1, :]])

                # sample points between the two outer bezier anchors of the current TCT to find the intersection
                # with the interdot vector (localized by the center bezier anchor). This is required because
                # interdot transitions can be longer than the distance between the central anchors (length depends
                # on rounding, more rounding = longer)
                # rotate interdot transition into default representation
                inter_dot_transition_rot = rotate_points(points=inter_dot_transition,
                                                         angle=-csd_geometric.rotation)
                slope = (inter_dot_transition_rot[0][1] - inter_dot_transition_rot[1][1]) / (
                        inter_dot_transition_rot[0][0] - inter_dot_transition_rot[1][0])
                transition_length = np.linalg.norm((inter_dot_transition_rot[1] - inter_dot_transition_rot[0]))

                # distance to lower TCT
                intersection_dist_to_lower_bezier_percentage = _idt_overshoot_fraction(
                    anchors_rot=rotate_points(points=bezier_coords[tct_id][idt_index * 2],
                                              angle=-csd_geometric.rotation),
                    slope=slope,
                    reference_point=inter_dot_transition_rot[1],
                    transition_length=transition_length,
                    n_samples=n_samples)

                # distance to upper TCT
                intersection_dist_to_upper_bezier_percentage = _idt_overshoot_fraction(
                    anchors_rot=rotate_points(points=bezier_coords[tct_id + 1][idt_index * 2 + 1],
                                              angle=-csd_geometric.rotation),
                    slope=slope,
                    reference_point=inter_dot_transition_rot[0],
                    transition_length=transition_length,
                    n_samples=n_samples)

                # sample some points along the transition (extended at both ends up to the rounded TCTs)
                inter_dot_points = inter_dot_transition[0] + \
                                   np.linspace(0 - intersection_dist_to_upper_bezier_percentage,
                                               1 + intersection_dist_to_lower_bezier_percentage,
                                               n_samples)[..., np.newaxis] * (
                                           inter_dot_transition[1] - inter_dot_transition[0])
                # select only pixels that are in the csd-limits
                valid_ids = np.where((inter_dot_points[:, 0] > (x_lims[0] - 0.5 * x_step)) & (
                        inter_dot_points[:, 0] < (x_lims[1] + 0.5 * x_step)) & (
                                             inter_dot_points[:, 1] > (y_lims[0] - 0.5 * y_step)) & (
                                             inter_dot_points[:, 1] < (y_lims[1] + 0.5 * y_step)))
                inter_dot_points = inter_dot_points[valid_ids[0], :]

                # insert pixels into the mask
                # calculation of the pixel ids for the values:
                # x = min(csd_x) + id * x_step
                # add half step size, so that the pixel id of the nearest pixel is obtained after the division
                # (round up if next higher value in range of 0.5 * step_size)
                x_id = np.floor_divide(inter_dot_points[:, 0] + 0.5 * x_step - x_lims[0], x_step).astype(int)
                y_id = np.floor_divide(inter_dot_points[:, 1] + 0.5 * y_step - y_lims[0], y_step).astype(int)
                idt_mask[y_id, x_id] = tct_id
        # apply dot jumps if any were active
        if "OccupationDotJumps_axis0" in meta or "OccupationDotJumps_axis1" in meta:
            # rebuild the distortion objects from their string representation (everything from ", rng" on is cut)
            occ_jumps_objects_string = [s.split(", rng")[0] + ")" for s in
                                        re.findall(r"OccupationDotJumps[^\]]*", meta["occupation_distortions"]) if
                                        "[activated" in s]
            occ_jumps_objects = [eval(s) for s in occ_jumps_objects_string]
            for obj in occ_jumps_objects:
                if f"OccupationDotJumps_axis{obj.axis}" in meta:
                    # restore the jumps that were stored in the metadata and replay them on the new mask
                    # (freeze=True presumably makes the object reuse the restored noise - verify against simcats)
                    obj._OccupationDotJumps__activated = True
                    obj._OccupationDotJumps__previous_noise = meta[f"OccupationDotJumps_axis{obj.axis}"]
                    _, idt_mask = obj.noise_function(occupations=np.empty((0, 0, 2)), lead_transitions=idt_mask,
                                                     volt_limits_g1=x_lims, volt_limits_g2=y_lims, freeze=True)
        idt_masks.append(idt_mask)
    return idt_masks
346
+
347
+
348
def load_ct_masks(file: Union[str, h5py.File],
                  specific_ids: Union[range, List[int], np.ndarray, None] = None,
                  progress_bar: bool = True) -> List[np.ndarray]:
    """Provide Charge Transition (CT) masks as ground truth data.

    A CT mask is the stored Total Charge Transition (TCT) mask in which the inter-dot transitions have been added
    at all pixels that are not already part of a TCT.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Charge Transition (CT) masks
    """
    dataset = load_dataset(file=file, load_csds=False, load_tct_masks=True, specific_ids=specific_ids,
                           progress_bar=progress_bar)
    ct_masks = dataset.tct_masks
    idt_masks = load_idt_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    for mask_index, idt_mask in enumerate(idt_masks):
        ct_mask = ct_masks[mask_index]
        # TCT pixels take precedence, inter-dot pixels are only written to free pixels
        free_idt_pixels = (idt_mask > 0) & (ct_mask == 0)
        ct_mask[free_idt_pixels] = idt_mask[free_idt_pixels]
    return ct_masks
373
+
374
+
375
def load_ct_by_dot_masks(file: Union[str, h5py.File],
                         specific_ids: Union[range, List[int], np.ndarray, None] = None,
                         progress_bar: bool = True,
                         lut_entries: int = 1000,
                         try_directly_loading_from_file: bool = True) -> List[np.ndarray]:
    """Provide Charge Transition (CT) masks with transitions labeled by affected dot as ground truth data.

    In comparison to the Total Charge Transition (TCT) masks, the inter-dot transitions are included. If the masks
    are not taken from the file, they are built from the TCT-by-dot masks, and all inter-dot transition pixels
    that are not part of a TCT get the label 3.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.
        lut_entries: Number of lookup-table entries to use for tct_bezier. Default is 1000.
        try_directly_loading_from_file: Specifies if the loader should try to find the masks in the h5 file before
            falling back to calculating them (not all datasets include these masks). Default is True.

    Returns:
        Charge Transition (CT) masks labeled by affected dot
    """
    masks_from_file = None
    if try_directly_loading_from_file:
        try:
            masks_from_file = load_dataset(file=file,
                                           load_csds=False,
                                           load_ct_by_dot_masks=True,
                                           specific_ids=specific_ids,
                                           progress_bar=progress_bar).ct_by_dot_masks
        except KeyError:
            # the masks are not available in the file and have to be calculated below
            pass
    if masks_from_file is not None:
        return masks_from_file

    # the data could not be loaded from the file, or loading from the file was disabled: calculate it manually
    calculated_masks = load_tct_by_dot_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar,
                                             lut_entries=lut_entries)
    idt_masks = load_idt_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    for mask_index, idt_mask in enumerate(idt_masks):
        ct_mask = calculated_masks[mask_index]
        free_idt_pixels = (idt_mask > 0) & (ct_mask == 0)
        ct_mask[free_idt_pixels] = 3
    return calculated_masks
414
+
415
+
416
def load_tc_region_masks(file: Union[str, h5py.File],
                         specific_ids: Union[range, List[int], np.ndarray, None] = None,
                         progress_bar: bool = True) -> List[np.ndarray]:
    """Provide Total Charge (TC) region masks as ground truth data.

    The mask of a CSD is obtained by summing its occupation data along the last axis, rounding the result and
    converting it to uint8.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. The value is only forwarded to load_dataset.
            Default is True.

    Returns:
        Total Charge (TC) region masks
    """
    dataset = load_dataset(file=file, load_csds=False, load_occupations=True, specific_ids=specific_ids,
                           progress_bar=progress_bar)
    region_masks = []
    for occupation in dataset.occupations:
        total_charge = np.sum(occupation, axis=-1)
        region_masks.append(np.round(total_charge).astype(np.uint8))
    return region_masks
437
+
438
+
439
def load_tc_region_minus_tct_masks(file: Union[str, h5py.File],
                                   specific_ids: Union[range, List[int], np.ndarray, None] = None,
                                   progress_bar: bool = True) -> List[np.ndarray]:
    """Provide Total Charge (TC) region minus Total Charge Transition (TCT) masks as ground truth data.

    The TCT mask is subtracted from the total charge (sum of the occupations along the last axis) of every CSD
    before rounding, so that TCTs are basically excluded from the regions.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. The value is only forwarded to load_dataset.
            Default is True.

    Returns:
        Total Charge (TC) region minus Total Charge Transition (TCT) masks
    """
    occupations = load_dataset(file=file, load_csds=False, load_occupations=True, specific_ids=specific_ids,
                               progress_bar=progress_bar).occupations
    tct_masks = load_dataset(file=file, load_csds=False, load_tct_masks=True, specific_ids=specific_ids,
                             progress_bar=progress_bar).tct_masks
    region_masks = []
    for occupation, tct_mask in zip(occupations, tct_masks):
        total_charge = np.sum(occupation, axis=-1)
        region_masks.append(np.round(total_charge - tct_mask).astype(np.uint8))
    return region_masks
462
+
463
+
464
def load_c_region_masks(file: Union[str, h5py.File],
                        specific_ids: Union[range, List[int], np.ndarray, None] = None,
                        progress_bar: bool = True) -> List[np.ndarray]:
    """Provide Charge (C) region masks as ground truth data (CTs are basically excluded from the TC regions).

    Starting from the Total Charge (TC) region masks, every pixel whose value is equal to the value of the Charge
    Transition (CT) mask at the same position is set to zero.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is
            loaded. Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Charge (C) region masks
    """
    region_masks = load_tc_region_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    transition_masks = load_ct_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    for region_mask, transition_mask in zip(region_masks, transition_masks):
        on_transition = region_mask == transition_mask
        region_mask[on_transition] = 0
    return region_masks