simcats-datasets 2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcats_datasets/__init__.py +2 -0
- simcats_datasets/generation/__init__.py +6 -0
- simcats_datasets/generation/_create_dataset.py +221 -0
- simcats_datasets/generation/_create_simulated_dataset.py +372 -0
- simcats_datasets/loading/__init__.py +8 -0
- simcats_datasets/loading/_load_dataset.py +177 -0
- simcats_datasets/loading/load_ground_truth.py +486 -0
- simcats_datasets/loading/pytorch.py +426 -0
- simcats_datasets/support_functions/__init__.py +1 -0
- simcats_datasets/support_functions/_json_encoders.py +51 -0
- simcats_datasets/support_functions/clip_line_to_rectangle.py +191 -0
- simcats_datasets/support_functions/convert_lines.py +110 -0
- simcats_datasets/support_functions/data_preprocessing.py +351 -0
- simcats_datasets/support_functions/get_lead_transition_labels.py +102 -0
- simcats_datasets/support_functions/pytorch_format_output.py +170 -0
- simcats_datasets-2.4.0.dist-info/LICENSE +674 -0
- simcats_datasets-2.4.0.dist-info/METADATA +837 -0
- simcats_datasets-2.4.0.dist-info/RECORD +20 -0
- simcats_datasets-2.4.0.dist-info/WHEEL +5 -0
- simcats_datasets-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,177 @@
|
|
|
1
|
+
"""The main interface for loading simcats datasets, typically stored as HDF5 files.
|
|
2
|
+
|
|
3
|
+
@author: f.hader
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
|
|
8
|
+
from collections import namedtuple
|
|
9
|
+
from contextlib import nullcontext
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
|
|
12
|
+
from typing import List, Tuple, Union
|
|
13
|
+
|
|
14
|
+
import h5py
|
|
15
|
+
import numpy as np
|
|
16
|
+
from tqdm import tqdm
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def load_dataset(file: Union[str, h5py.File],
                 load_csds: bool = True,
                 load_occupations: bool = False,
                 load_tct_masks: bool = False,
                 load_ct_by_dot_masks: bool = False,
                 load_line_coords: bool = False,
                 load_line_labels: bool = False,
                 load_metadata: bool = False,
                 load_ids: bool = False,
                 specific_ids: Union[range, List[int], np.ndarray, None] = None,
                 progress_bar: bool = False) -> Tuple:
    """Loads a dataset consisting of multiple CSDs from a given path.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset (`str`, `bytes` or any `os.PathLike` object such as `pathlib.Path`). If a path is supplied,
            load_dataset will open the file itself (read-only) and close it again. If you want to do multiple
            consecutive loads from the same file (e.g. for using the PyTorch SimcatsDataset without preloading),
            consider initializing the file object yourself and passing it, to improve the performance.
        load_csds: Determines if csds should be loaded. Default is True.
        load_occupations: Determines if occupation data should be loaded. Default is False.
        load_tct_masks: Determines if lead transition masks should be loaded. Default is False.
        load_ct_by_dot_masks: Determines if charge transition labeled by affected dot masks should be loaded. This
            requires that ct_by_dot_masks have been added to the dataset. If a dataset has been created using
            create_simulated_dataset, these masks can be added afterwards using add_ct_by_dot_masks_to_dataset, mainly
            to avoid recalculating them multiple times (for example for machine learning purposes). Default is False.
        load_line_coords: Determines if lead transition definitions using start and end points should be loaded.
            Default is False.
        load_line_labels: Determines if labels for lead transitions defined using start and end points should be
            loaded. Default is False.
        load_metadata: Determines if the metadata (SimCATS config) of the CSDs should be loaded. Default is False.
        load_ids: Determines if the available ids should be loaded (or in case of specific ids: the specific ids are
            returned in the given order). Default is False.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. This parameter has no functionality since version
            2, but is kept for compatibility reasons. Default is False.

    Returns:
        namedtuple: The namedtuple can be unpacked like every normal tuple, or instead accessed by field names. \n
        Depending on what has been enabled, the following data is included in the named tuple: \n
        - field 'csds': List containing all CSDs as numpy arrays. The list is sorted by the id of the CSDs (if no
          specific_ids are provided, else the order is given by specific_ids).
        - field 'occupations': List containing numpy arrays with occupations.
        - field 'tct_masks': List containing numpy arrays of TCT masks.
        - field 'ct_by_dot_masks': List containing numpy arrays of CT_by_dot masks.
        - field 'line_coordinates': List containing numpy arrays of line coordinates.
        - field 'line_labels': List containing a list of dictionaries (one dict for each line specified as line
          coordinates).
        - field 'metadata': List containing dictionaries with all metadata (simcats configs) for each CSD.
        - field 'ids': List of the ids of the CSDs.

    Raises:
        IndexError: If load_ids is True and specific_ids contains an id that is negative or not smaller than the
            number of CSDs stored in the file.
    """
    # local import: only needed here to recognise path-like objects
    import os

    # The field names double as the HDF5 dataset keys. Their order defines the order of the returned namedtuple.
    requested = (
        ("csds", load_csds),
        ("occupations", load_occupations),
        ("tct_masks", load_tct_masks),
        ("ct_by_dot_masks", load_ct_by_dot_masks),
        ("line_coordinates", load_line_coords),
        ("line_labels", load_line_labels),
        ("metadata", load_metadata),
    )
    data_fields = [name for name, wanted in requested if wanted]
    fieldnames = data_fields + ["ids"] if load_ids else data_fields
    CSDDataset = namedtuple(typename="CSDDataset", field_names=fieldnames)

    # Only open (and later close) the file ourselves if a path was supplied. An already opened file object is
    # wrapped in a nullcontext so that it stays open for the caller.
    if isinstance(file, (str, bytes, os.PathLike)):
        file_context = h5py.File(file, "r")
    else:
        file_context = nullcontext(file)

    undo_sort_ids = None
    available_ids = None
    return_data = []
    with file_context as _file:
        if isinstance(specific_ids, (list, np.ndarray)):
            # Reading from h5 requires sorted ids. Work on a copy (the caller's object must not be modified) and
            # remember the permutation that restores the requested order afterwards.
            specific_ids = deepcopy(specific_ids)
            undo_sort_ids = np.argsort(np.argsort(specific_ids))
            specific_ids.sort()

        if load_ids:
            # only check if ids are correct, if load_ids is True. This prevents initializing a non-preloaded PyTorch
            # Dataset with non-existing specific IDs (which else would only crash as soon as a non-existent ID is
            # requested during training). We can't check this on loading CSDs etc. as it massively slows down loading.
            if specific_ids is not None:
                if np.min(specific_ids) < 0 or np.max(specific_ids) >= len(_file["csds"]):
                    msg = "Not all ids specified by 'specific_ids' are available in the dataset!"
                    raise IndexError(msg)
                available_ids = specific_ids
            else:
                available_ids = range(len(_file["csds"]))

        # a full slice reads everything, otherwise only the requested entries are read
        selection = slice(None) if specific_ids is None else specific_ids

        for field in data_fields:
            values = _file[field][selection]
            if field == "line_coordinates":
                # remove padded nan values
                values = [coords[~np.isnan(coords)].reshape((-1, 4)) for coords in values]
            elif field in ("line_labels", "metadata"):
                # these entries are stored as byte arrays containing JSON strings
                values = [json.loads(raw.tobytes().strip().decode("utf-8")) for raw in values]
            return_data.append(values)

    if load_ids:
        return_data.append(available_ids)

    # revert sorting if specific ids were given as list / array
    if undo_sort_ids is not None:
        return_data = [[x[i] for i in undo_sort_ids] for x in return_data]

    return CSDDataset._make(tuple(return_data))
|
|
@@ -0,0 +1,486 @@
|
|
|
1
|
+
"""Functions for providing ground truth data to be used with the **Pytorch Dataset class**.
|
|
2
|
+
|
|
3
|
+
For examples of the different ground truth types, please have a look at the notebook Examples_Pytorch_SimcatsDataset.
|
|
4
|
+
|
|
5
|
+
Every function must accept a h5 File or path for a simcats_dataset as input, provide an option to use only specific_ids
|
|
6
|
+
and allow disabling the progress_bar.
|
|
7
|
+
The output type depends on the ground truth type; it could, for example, be a pixel mask or the start and end points of lines.
|
|
8
|
+
**Please look at load_zeros_masks for a reference.**
|
|
9
|
+
|
|
10
|
+
@author: f.hader
|
|
11
|
+
"""
|
|
12
|
+
|
|
13
|
+
from typing import Union, List
|
|
14
|
+
|
|
15
|
+
import bezier
|
|
16
|
+
import h5py
|
|
17
|
+
import numpy as np
|
|
18
|
+
|
|
19
|
+
# imports required for eval, to create IdealCSDGeometric objects from metadata strings
|
|
20
|
+
import re
|
|
21
|
+
|
|
22
|
+
import sympy
|
|
23
|
+
from simcats.ideal_csd import IdealCSDGeometric
|
|
24
|
+
from simcats.ideal_csd.geometric import calculate_all_bezier_anchors, tct_bezier, initialize_tct_functions
|
|
25
|
+
from numpy import array
|
|
26
|
+
from simcats.distortions import OccupationDotJumps
|
|
27
|
+
from tqdm import tqdm
|
|
28
|
+
|
|
29
|
+
from simcats_datasets.loading import load_dataset
|
|
30
|
+
from simcats.support_functions import rotate_points
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def load_zeros_masks(file: Union[str, h5py.File],
                     specific_ids: Union[range, List[int], np.ndarray, None] = None,
                     progress_bar: bool = True) -> List[np.ndarray]:
    """Provide all-zero arrays as placeholder ground truth.

    Intended for sets that have no ground truth at all, e.g. experimental datasets without labels, so that they can
    be used with the pytorch SimcatsDataset class through the same interface as simulated data.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        One uint8 array of zeros per loaded CSD, each with the same shape as the corresponding CSD.
    """
    # the CSDs are only loaded to obtain the shape of every mask
    dataset = load_dataset(file=file, load_csds=True, specific_ids=specific_ids, progress_bar=progress_bar)
    zero_masks = []
    for csd in dataset.csds:
        zero_masks.append(np.zeros_like(csd, dtype=np.uint8))
    return zero_masks
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def load_tct_masks(file: Union[str, h5py.File],
                   specific_ids: Union[range, List[int], np.ndarray, None] = None,
                   progress_bar: bool = True) -> List[np.ndarray]:
    """Provide the stored Total Charge Transition (TCT) masks as ground truth.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Total Charge Transition (TCT) masks
    """
    # only the TCT masks are requested, the CSDs themselves are not needed
    loaded = load_dataset(file=file,
                          load_csds=False,
                          load_tct_masks=True,
                          specific_ids=specific_ids,
                          progress_bar=progress_bar)
    return loaded.tct_masks
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def load_tct_by_dot_masks(file: Union[str, h5py.File],
                          specific_ids: Union[range, List[int], np.ndarray, None] = None,
                          progress_bar: bool = True,
                          lut_entries: int = 1000) -> List[np.ndarray]:
    """Load Total Charge Transition (TCT) masks with transitions labeled by affected dot as ground truth data.

    The stored TCT masks are only used to find out which TCTs are visible in an image. The masks themselves are
    recalculated from the metadata: every TCT is sampled piecewise between its bezier anchors and the pieces are
    written into the mask with alternating labels (2 for pieces with an even index, 1 for pieces with an odd index).
    Images without any TCT are returned as loaded.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.
        lut_entries: Number of lookup-table entries to use for tct_bezier. Default is 1000.

    Returns:
        Total Charge Transition (TCT) masks with the transitions labeled by affected dot
    """
    # load masks and metadata with a single call, so that a file given as path is opened only once
    loaded = load_dataset(file=file, load_csds=False, load_tct_masks=True, load_metadata=True,
                          specific_ids=specific_ids, progress_bar=progress_bar)
    tct_masks = loaded.tct_masks
    metadata = loaded.metadata

    for mask_id, meta in tqdm(enumerate(metadata), desc="calculating transitions", total=len(metadata),
                              disable=not progress_bar):
        # SECURITY NOTE: the config string stored in the dataset is evaluated as Python code. Only load datasets
        # from trusted sources.
        try:
            csd_geometric = eval(meta["ideal_csd_config"])
        except Exception:
            # This is required to support metadata from older simcats versions, where the class had a different name.
            # `Exception` (instead of a bare except) keeps KeyboardInterrupt / SystemExit working.
            csd_geometric = eval(meta["ideal_csd_config"].replace("IdealCSDGeometrical", "IdealCSDGeometric"))

        tct_ids = np.unique(tct_masks[mask_id][np.nonzero(tct_masks[mask_id])])
        # skip images with no TCTs
        if tct_ids.size == 0:
            continue
        lowest_tct = np.min(tct_ids)
        highest_tct = np.max(tct_ids)
        # setup tct functions
        tct_funcs = initialize_tct_functions(tct_params=csd_geometric.tct_params[lowest_tct - 1:highest_tct],
                                             max_peaks=lowest_tct)
        # calculate all bezier anchors
        bezier_coords = calculate_all_bezier_anchors(
            tct_params=csd_geometric.tct_params[lowest_tct - 1:highest_tct], max_peaks=lowest_tct,
            rotation=csd_geometric.rotation)

        # get parameters of pixel vs voltage space for discretization
        x_res = tct_masks[mask_id].shape[1]
        y_res = tct_masks[mask_id].shape[0]
        x_lims = meta["sweep_range_g1"]
        y_lims = meta["sweep_range_g2"]
        # stepsize x/y
        x_step = (x_lims[-1] - x_lims[0]) / (x_res - 1)
        y_step = (y_lims[-1] - y_lims[0]) / (y_res - 1)

        # get corner points to know max value range for generating points
        corner_points = np.array(
            [[x_lims[0], y_lims[0]], [x_lims[0], y_lims[1]], [x_lims[1], y_lims[0]], [x_lims[1], y_lims[1]]])
        x_c_rot = rotate_points(points=corner_points, angle=-csd_geometric.rotation)[:, 0]

        # number of sample points per transition piece: higher than the image resolution to have a precise result
        # after discretization
        n_samples = (x_res + y_res) * 4

        # replace tct_mask by a new empty array
        tct_masks[mask_id] = np.zeros_like(tct_masks[mask_id], dtype=np.uint8)
        for tct_id in tct_ids:
            for transition in range(tct_id * 2):
                # get start x position of current transition
                if transition == 0:
                    x_start = np.min(x_c_rot) - x_step
                else:
                    # rotate bezier coords
                    bezier_coords_rot = rotate_points(points=bezier_coords[tct_id][transition - 1, 1],
                                                      angle=-csd_geometric.rotation)
                    x_start = bezier_coords_rot[0]
                # get stop x position of current transition
                if transition == tct_id * 2 - 1:
                    x_stop = np.max(x_c_rot) + x_step
                else:
                    # rotate bezier coords
                    bezier_coords_rot = rotate_points(points=bezier_coords[tct_id][transition, 1],
                                                      angle=-csd_geometric.rotation)
                    x_stop = bezier_coords_rot[0]

                # generate the x-values of the current piece and the corresponding y-values of the TCT
                tct_points = np.empty((n_samples, 2))
                tct_points[:, 0] = np.linspace(x_start, x_stop, n_samples)
                tct_points[:, 1] = tct_funcs[tct_id](x_eval=tct_points[:, 0], lut_entries=lut_entries)

                # rotate the TCT into the original orientation
                wf_points_rot = rotate_points(points=tct_points, angle=csd_geometric.rotation)

                # select only TCT pixels that are in the csd-limits
                valid_ids = np.where((wf_points_rot[:, 0] > (x_lims[0] - 0.5 * x_step)) & (
                        wf_points_rot[:, 0] < (x_lims[1] + 0.5 * x_step)) & (
                                             wf_points_rot[:, 1] > (y_lims[0] - 0.5 * y_step)) & (
                                             wf_points_rot[:, 1] < (y_lims[1] + 0.5 * y_step)))
                wf_points_rot = wf_points_rot[valid_ids[0], :]

                # insert TCT pixels into the csd
                # calculation of the ids for the values:
                # x = min(csd_x) + id * x_step
                # add half step size, so that the pixel id of the nearest pixel is obtained after the division
                # (round up if next higher value in range of 0.5 * step_size)
                x_id = np.floor_divide(wf_points_rot[:, 0] + 0.5 * x_step - x_lims[0], x_step).astype(int)
                y_id = np.floor_divide(wf_points_rot[:, 1] + 0.5 * y_step - y_lims[0], y_step).astype(int)
                # alternating label: 2 for even piece indices, 1 for odd piece indices
                tct_masks[mask_id][y_id, x_id] = ((transition + 1) % 2) + 1

        # apply dot jumps if any were active
        if "OccupationDotJumps_axis0" in meta or "OccupationDotJumps_axis1" in meta:
            # rebuild the distortion objects from their string representation (without the rng argument)
            occ_jumps_objects_string = [s.split(", rng")[0] + ")" for s in
                                        re.findall(r"OccupationDotJumps[^\]]*", meta["occupation_distortions"]) if
                                        "[activated" in s]
            # SECURITY NOTE: evaluates strings stored in the dataset, see above.
            occ_jumps_objects = [eval(s) for s in occ_jumps_objects_string]
            for obj in occ_jumps_objects:
                if f"OccupationDotJumps_axis{obj.axis}" in meta:
                    # restore the stored noise and replay it (freeze=True) on the new mask
                    obj._OccupationDotJumps__activated = True
                    obj._OccupationDotJumps__previous_noise = meta[f"OccupationDotJumps_axis{obj.axis}"]
                    _, tct_masks[mask_id] = obj.noise_function(occupations=np.empty((0, 0, 2)),
                                                               lead_transitions=tct_masks[mask_id],
                                                               volt_limits_g1=x_lims, volt_limits_g2=y_lims,
                                                               freeze=True)
    return tct_masks
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
def _idt_overshoot_fraction(anchors, inter_dot_rot, endpoint: int, rotation, n_samples: int):
    """Estimate how far an inter-dot transition reaches beyond a central bezier anchor of a TCT.

    The TCT piece defined by the three bezier anchors is sampled and the sample closest to the line through the
    central anchor with the slope of the inter-dot transition is searched. This is required because inter-dot
    transitions can be longer than the distance between the central anchors (length depends on rounding, more
    rounding = longer).

    Args:
        anchors: The three bezier anchors (start, center, end) of the TCT piece, in the original orientation.
        inter_dot_rot: Start and end point of the inter-dot transition, rotated into the default representation.
        endpoint: Index (0 or 1) of the point of inter_dot_rot to which the distance is measured.
        rotation: Rotation of the CSD. The anchors are rotated by the negative rotation.
        n_samples: Number of x-values to evaluate. The bezier lookup table uses twice as many entries.

    Returns:
        Distance between the found intersection and the selected endpoint, relative to the length of the inter-dot
        transition.
    """
    anchors_rot = rotate_points(points=anchors, angle=-rotation)
    x_eval = np.linspace(anchors_rot[0, 0], anchors_rot[2, 0], n_samples)
    # bezier nodes as fortran array, required to initialize the bezier curve
    bezier_curve = bezier.Curve.from_nodes(np.asfortranarray(anchors_rot.T))
    bezier_lut = bezier_curve.evaluate_multi(np.linspace(0, 1, n_samples * 2))
    y_eval = np.array([bezier_lut[1, np.argmin(np.abs(bezier_lut[0, :] - x))] for x in x_eval])
    # find the closest point:
    # y_eval - (central_tct_anchor_y - (central_tct_anchor_x - x_eval) * (interdot_vec_y / interdot_vec_x))
    slope = (inter_dot_rot[0][1] - inter_dot_rot[1][1]) / (inter_dot_rot[0][0] - inter_dot_rot[1][0])
    residual = np.abs(y_eval - (anchors_rot[1, 1] - (anchors_rot[1, 0] - x_eval) * slope))
    intersection_pixel = np.argmin(residual)
    # calculate the distance from this point to the central bezier anchor
    intersection_dist_to_bezier = np.linalg.norm(
        [x_eval[intersection_pixel] - inter_dot_rot[endpoint][0],
         y_eval[intersection_pixel] - inter_dot_rot[endpoint][1]])
    return intersection_dist_to_bezier / np.linalg.norm((inter_dot_rot[1] - inter_dot_rot[0]))


def load_idt_masks(file: Union[str, h5py.File],
                   specific_ids: Union[range, List[int], np.ndarray, None] = None,
                   progress_bar: bool = True) -> List[np.ndarray]:
    """Load Inter-Dot Transition (IDT) masks as ground truth data.

    In comparison to the Total Charge Transition (TCT) masks, only inter-dot transitions are included. The masks are
    calculated from the metadata; the stored TCT masks are only used for the image shape and to find out which TCTs
    are visible. The pixels of an inter-dot transition are labeled with the id of the lower of the two TCTs it
    connects.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Inter-Dot Transition (IDT) masks
    """
    # load masks and metadata with a single call, so that a file given as path is opened only once
    loaded = load_dataset(file=file, load_csds=False, load_tct_masks=True, load_metadata=True,
                          specific_ids=specific_ids, progress_bar=progress_bar)
    tct_masks = loaded.tct_masks
    metadata = loaded.metadata

    idt_masks = []
    for tct_mask, meta in tqdm(zip(tct_masks, metadata), desc="calculating idt", total=len(tct_masks),
                               disable=not progress_bar):
        idt_mask = np.zeros(tct_mask.shape, dtype=np.uint8)
        # SECURITY NOTE: the config string stored in the dataset is evaluated as Python code. Only load datasets
        # from trusted sources.
        try:
            csd_geometric = eval(meta["ideal_csd_config"])
        except Exception:
            # This is required to support metadata from older simcats versions, where the class had a different name.
            # `Exception` (instead of a bare except) keeps KeyboardInterrupt / SystemExit working.
            csd_geometric = eval(meta["ideal_csd_config"].replace("IdealCSDGeometrical", "IdealCSDGeometric"))
        bezier_coords = calculate_all_bezier_anchors(tct_params=csd_geometric.tct_params[:int(np.max(tct_mask) + 1)],
                                                     rotation=csd_geometric.rotation)
        # get parameters of pixel vs voltage space for discretization
        x_res = idt_mask.shape[1]
        y_res = idt_mask.shape[0]
        x_lims = meta["sweep_range_g1"]
        y_lims = meta["sweep_range_g2"]
        # stepsize x/y
        x_step = (x_lims[-1] - x_lims[0]) / (x_res - 1)
        y_step = (y_lims[-1] - y_lims[0]) / (y_res - 1)
        # number of points used to sample TCT pieces and inter-dot transitions
        n_samples = np.max(idt_mask.shape)
        try:
            # min tct minus 1, because we always need to start with the one below of the lowest in the image
            first_tct = int(np.max([np.min(tct_mask[np.nonzero(tct_mask)]) - 1, 1]))
        except ValueError:
            # np.min raises a ValueError if the image contains no TCT at all
            first_tct = 1
        # iterate over all tcts that are part of the current voltage range
        for i in range(first_tct, int(np.max(tct_mask) + 1)):
            # The number of inter-dot transitions connected to the next higher TCT for every TCT is given by the ID
            # of the TCT
            for j in range(i):
                # get start and end vector for inter dot transitions
                inter_dot_transition = np.array([bezier_coords[i + 1][j * 2 + 1, 1, :], bezier_coords[i][j * 2, 1, :]])
                # rotate interdot transition into default representation
                inter_dot_transition_rot = rotate_points(points=inter_dot_transition, angle=-csd_geometric.rotation)

                # distance to lower TCT (measured from the second point of the transition)
                intersection_dist_to_lower_bezier_percentage = _idt_overshoot_fraction(
                    anchors=bezier_coords[i][j * 2], inter_dot_rot=inter_dot_transition_rot, endpoint=1,
                    rotation=csd_geometric.rotation, n_samples=n_samples)
                # distance to upper TCT (measured from the first point of the transition)
                intersection_dist_to_upper_bezier_percentage = _idt_overshoot_fraction(
                    anchors=bezier_coords[i + 1][j * 2 + 1], inter_dot_rot=inter_dot_transition_rot, endpoint=0,
                    rotation=csd_geometric.rotation, n_samples=n_samples)

                # sample some points along the transition (extended at both ends up to the TCTs)
                inter_dot_points = inter_dot_transition[0] + \
                                   np.linspace(0 - intersection_dist_to_upper_bezier_percentage,
                                               1 + intersection_dist_to_lower_bezier_percentage,
                                               n_samples)[..., np.newaxis] * (
                                           inter_dot_transition[1] - inter_dot_transition[0])
                # select only pixels that are in the csd-limits
                valid_ids = np.where((inter_dot_points[:, 0] > (x_lims[0] - 0.5 * x_step)) & (
                        inter_dot_points[:, 0] < (x_lims[1] + 0.5 * x_step)) & (
                                             inter_dot_points[:, 1] > (y_lims[0] - 0.5 * y_step)) & (
                                             inter_dot_points[:, 1] < (y_lims[1] + 0.5 * y_step)))
                inter_dot_points = inter_dot_points[valid_ids[0], :]

                # insert pixels into the mask
                # calculation of the pixel ids for the values:
                # x = min(csd_x) + id * x_step
                # add half step size, so that the pixel id of the nearest pixel is obtained after the division
                # (round up if next higher value in range of 0.5 * step_size)
                x_id = np.floor_divide(inter_dot_points[:, 0] + 0.5 * x_step - x_lims[0], x_step).astype(int)
                y_id = np.floor_divide(inter_dot_points[:, 1] + 0.5 * y_step - y_lims[0], y_step).astype(int)
                idt_mask[y_id, x_id] = i
        # apply dot jumps if any were active
        if "OccupationDotJumps_axis0" in meta or "OccupationDotJumps_axis1" in meta:
            # rebuild the distortion objects from their string representation (without the rng argument)
            occ_jumps_objects_string = [s.split(", rng")[0] + ")" for s in
                                        re.findall(r"OccupationDotJumps[^\]]*", meta["occupation_distortions"]) if
                                        "[activated" in s]
            # SECURITY NOTE: evaluates strings stored in the dataset, see above.
            occ_jumps_objects = [eval(s) for s in occ_jumps_objects_string]
            for obj in occ_jumps_objects:
                if f"OccupationDotJumps_axis{obj.axis}" in meta:
                    # restore the stored noise and replay it (freeze=True) on the new mask
                    obj._OccupationDotJumps__activated = True
                    obj._OccupationDotJumps__previous_noise = meta[f"OccupationDotJumps_axis{obj.axis}"]
                    _, idt_mask = obj.noise_function(occupations=np.empty((0, 0, 2)), lead_transitions=idt_mask,
                                                     volt_limits_g1=x_lims, volt_limits_g2=y_lims, freeze=True)
        idt_masks.append(idt_mask)
    return idt_masks
|
|
346
|
+
|
|
347
|
+
|
|
348
|
+
def load_ct_masks(file: Union[str, h5py.File],
                  specific_ids: Union[range, List[int], np.ndarray, None] = None,
                  progress_bar: bool = True) -> List[np.ndarray]:
    """Load Charge Transition (CT) masks as ground truth data.

    The CT masks are built from the Total Charge Transition (TCT) masks of the dataset, into which the inter-dot
    transition (IDT) masks are merged. IDT labels are only written to pixels that are not already labeled in the TCT
    mask (pixel value 0), so existing TCT labels are never overwritten.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Charge Transition (CT) masks
    """
    tct_dataset = load_dataset(file=file, load_csds=False, load_tct_masks=True, specific_ids=specific_ids,
                               progress_bar=progress_bar)
    merged_masks = tct_dataset.tct_masks
    inter_dot_masks = load_idt_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    for merged_mask, inter_dot_mask in zip(merged_masks, inter_dot_masks):
        # pixels that carry an inter-dot label but are still unlabeled in the TCT mask
        fill_pixels = (inter_dot_mask > 0) & (merged_mask == 0)
        # the loaded TCT mask is updated in place
        merged_mask[fill_pixels] = inter_dot_mask[fill_pixels]
    return merged_masks
|
|
373
|
+
|
|
374
|
+
|
|
375
|
+
def load_ct_by_dot_masks(file: Union[str, h5py.File],
                         specific_ids: Union[range, List[int], np.ndarray, None] = None,
                         progress_bar: bool = True,
                         lut_entries: int = 1000,
                         try_directly_loading_from_file: bool = True) -> List[np.ndarray]:
    """Load Charge Transition (CT) masks with transitions labeled by affected dot as ground truth data.

    In comparison to the Total Charge Transition (TCT) masks, the inter-dot transitions are included.

    If `try_directly_loading_from_file` is set and the masks are stored in the file, they are returned exactly as
    stored. Otherwise, the masks are calculated: the TCT-by-dot masks are loaded and every pixel that belongs to an
    inter-dot transition and is not labeled yet (pixel value 0) is assigned the label 3.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.
        lut_entries: Number of lookup-table entries to use for tct_bezier. Only forwarded to
            `load_tct_by_dot_masks` when the masks have to be calculated. Default is 1000.
        try_directly_loading_from_file: Specifies if the loader should try to find the masks in the h5 file before
            falling back to calculating them (not all datasets include these masks). Default is True.

    Returns:
        Charge Transition (CT) masks
    """
    stored_masks = None
    if try_directly_loading_from_file:
        try:
            stored_masks = load_dataset(file=file, load_csds=False, load_ct_by_dot_masks=True,
                                        specific_ids=specific_ids, progress_bar=progress_bar).ct_by_dot_masks
        except KeyError:
            # a KeyError here triggers the fallback calculation below
            pass
    if stored_masks is not None:
        return stored_masks

    # the data could not be loaded from the file, or loading from the file was disabled: calculate it manually
    by_dot_masks = load_tct_by_dot_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar,
                                         lut_entries=lut_entries)
    inter_dot_masks = load_idt_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    for by_dot_mask, inter_dot_mask in zip(by_dot_masks, inter_dot_masks):
        # label inter-dot transition pixels with 3, but never overwrite an existing label
        by_dot_mask[(inter_dot_mask > 0) & (by_dot_mask == 0)] = 3
    return by_dot_masks
|
|
414
|
+
|
|
415
|
+
|
|
416
|
+
def load_tc_region_masks(file: Union[str, h5py.File],
                         specific_ids: Union[range, List[int], np.ndarray, None] = None,
                         progress_bar: bool = True) -> List[np.ndarray]:
    """Load Total Charge (TC) region masks as ground truth data.

    For every loaded occupation array, the values are summed along the last axis, rounded and converted to
    `np.uint8`.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Total Charge (TC) region masks
    """
    occupations = load_dataset(file=file, load_csds=False, load_occupations=True, specific_ids=specific_ids,
                               progress_bar=progress_bar).occupations
    region_masks = []
    for occupation in occupations:
        # sum over the last axis of the occupation array to obtain the total charge per pixel
        total_charge = np.sum(occupation, axis=-1)
        region_masks.append(np.round(total_charge).astype(np.uint8))
    return region_masks
|
|
437
|
+
|
|
438
|
+
|
|
439
|
+
def load_tc_region_minus_tct_masks(file: Union[str, h5py.File],
                                   specific_ids: Union[range, List[int], np.ndarray, None] = None,
                                   progress_bar: bool = True) -> List[np.ndarray]:
    """Load Total Charge (TC) region minus Total Charge Transition (TCT) masks as ground truth data.

    TCTs are basically excluded from the regions: for every entry, the occupations are summed along the last axis,
    the TCT mask is subtracted, and the result is rounded and converted to `np.uint8`.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Total Charge (TC) region minus Total Charge Transition (TCT) masks
    """
    # occupations and TCT masks are loaded with two separate passes over the dataset
    occupations = load_dataset(file=file, load_csds=False, load_occupations=True, specific_ids=specific_ids,
                               progress_bar=progress_bar).occupations
    tct_masks = load_dataset(file=file, load_csds=False, load_tct_masks=True, specific_ids=specific_ids,
                             progress_bar=progress_bar).tct_masks
    region_masks = []
    for occupation, tct_mask in zip(occupations, tct_masks):
        charge_without_tct = np.sum(occupation, axis=-1) - tct_mask
        region_masks.append(np.round(charge_without_tct).astype(np.uint8))
    return region_masks
|
|
462
|
+
|
|
463
|
+
|
|
464
|
+
def load_c_region_masks(file: Union[str, h5py.File],
                        specific_ids: Union[range, List[int], np.ndarray, None] = None,
                        progress_bar: bool = True) -> List[np.ndarray]:
    """Load Charge (C) region masks as ground truth data (CTs are basically excluded from the TC regions).

    The Total Charge (TC) region masks are loaded and, in each of them, every pixel whose value equals the value of
    the Charge Transition (CT) mask at the same position is set to 0.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Charge (C) region masks
    """
    region_masks = load_tc_region_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    transition_masks = load_ct_masks(file=file, specific_ids=specific_ids, progress_bar=progress_bar)
    for transition_mask, region_mask in zip(transition_masks, region_masks):
        # pixels where the region value coincides with the CT mask value are reset to 0 (in place)
        coinciding_pixels = region_mask == transition_mask
        region_mask[coinciding_pixels] = 0
    return region_masks
|