simcats-datasets 2.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- simcats_datasets/__init__.py +2 -0
- simcats_datasets/generation/__init__.py +6 -0
- simcats_datasets/generation/_create_dataset.py +221 -0
- simcats_datasets/generation/_create_simulated_dataset.py +372 -0
- simcats_datasets/loading/__init__.py +8 -0
- simcats_datasets/loading/_load_dataset.py +177 -0
- simcats_datasets/loading/load_ground_truth.py +486 -0
- simcats_datasets/loading/pytorch.py +426 -0
- simcats_datasets/support_functions/__init__.py +1 -0
- simcats_datasets/support_functions/_json_encoders.py +51 -0
- simcats_datasets/support_functions/clip_line_to_rectangle.py +191 -0
- simcats_datasets/support_functions/convert_lines.py +110 -0
- simcats_datasets/support_functions/data_preprocessing.py +351 -0
- simcats_datasets/support_functions/get_lead_transition_labels.py +102 -0
- simcats_datasets/support_functions/pytorch_format_output.py +170 -0
- simcats_datasets-2.4.0.dist-info/LICENSE +674 -0
- simcats_datasets-2.4.0.dist-info/METADATA +837 -0
- simcats_datasets-2.4.0.dist-info/RECORD +20 -0
- simcats_datasets-2.4.0.dist-info/WHEEL +5 -0
- simcats_datasets-2.4.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,426 @@
|
|
|
1
|
+
"""Implementation of a pytorch dataset class. Can be used to train machine learning approaches with CSD data.
|
|
2
|
+
|
|
3
|
+
@author: f.hader
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import math
import os
from typing import Callable, List, Union, Tuple

import h5py
import numpy as np
from torch.utils.data import Dataset, ConcatDataset

import simcats_datasets.loading.load_ground_truth
import simcats_datasets.support_functions.data_preprocessing
from simcats_datasets.loading import load_dataset
from simcats_datasets.support_functions.pytorch_format_output import format_dict_csd_float_ground_truth_long
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class SimcatsDataset(Dataset):
    """Pytorch Dataset class implementation for SimCATS datasets. Uses simcats_datasets to load and provide (training) data.
    """

    def __init__(self,
                 h5_path: str,
                 specific_ids: Union[range, List[int], np.ndarray, None] = None,
                 load_ground_truth: Union[Callable, str, None] = None,
                 data_preprocessors: Union[List[Union[str, Callable]], None] = None,
                 ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
                 format_output: Union[Callable, str, None] = None, preload: bool = True,
                 max_concurrent_preloads: int = 100000,
                 progress_bar: bool = False, ):
        """Initializes an object for providing simcats_datasets data to pytorch.

        Args:
            h5_path: The path to the h5 file containing the dataset.
            specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
                sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
                Default is None.
            load_ground_truth: Defines the required type of ground truth data to be loaded. Accepts either a callable or
                a string. Callables must be of the same structure/interface as load_zeros_masks defined in
                simcats_datasets.loading.load_ground_truth. Strings must map to the function names of the
                loading functions defined in simcats_datasets.loading.load_ground_truth. If this is None,
                no ground truth is loaded, which restricts what output formats are possible. Default is None. \n
                Example of available types (**full list at simcats_datasets.loading.load_ground_truth**): \n
                - **'tct_masks'**: The Total Charge Transition (TCT) mask generated by SimCATS.
                - **'tc_region_masks'**: Regions with a fixed number of total charges.
                - **'tc_region_minus_tct_masks'**: Regions with a fixed number of total charges, but with zeros between
                  the regions (at tcts).
            data_preprocessors: Defines if data should be preprocessed. Accepts a list of callables or strings.
                Callables must be of the same structure/interface as example_preprocessor defined in
                simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
                preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
                Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
                - **'min_max_0_1'**: Min max scaling of the data to [0, 1]
                - **'standardization'**: Standardization of the data (mean=0, std=1)
                - **'add_newaxis'**: Adds new axis as first axis (required for UNET)
            ground_truth_preprocessors: Defines if ground truth should be preprocessed. Accepts a list of callables or
                strings. Callables must be of the same structure/interface as example_preprocessor defined in
                simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
                preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
                Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
                - **'only_two_classes'**: Reduce the number of classes in a mask to 2 (set every pixel > 1 = 1)
            format_output: Defines the required type of data format for the output. Accepts either a callable or a
                string. Callables must be of the same structure/interface as format_dict_csd_float_ground_truth_long
                defined in simcats_datasets.support_functions.pytorch_format_output. Strings must map to the function
                names of the format functions defined in simcats_datasets.support_functions.pytorch_format_output. If
                this is None, format_dict_csd_float_ground_truth_long is used, which does return the output as dict
                with entries 'csd' and 'ground_truth' of dtype float and long, respectively. Default is None. \n
                Example of available types (**full list at simcats_datasets.support_functions.pytorch_format_output**): \n
                - **'format_dict_csd_float_ground_truth_long'**: formats the output as dict with entries 'csd' and
                  'ground_truth' of dtype float and long, respectively
            preload: Enables preloading the whole dataset during the initialization (requires more RAM). Default is
                True.
            max_concurrent_preloads: Determines how many CSDs are concurrently loaded from the dataset during the
                preload phase. This option only affects instances with preload = True. It allows to preload large
                datasets (for which it might not be possible to load the whole dataset into the memory at once), by
                loading them step by step and for example converting the CSDs to float32 with a corresponding data
                preprocessor. Must be at least 1 if preload is True. Default is 100,000.
            progress_bar: Determines whether to display a progress bar while loading data. Default is False.

        Raises:
            ValueError: If ground_truth_preprocessors are supplied while load_ground_truth is None, or if preload is
                True and max_concurrent_preloads is smaller than 1.
        """
        # Handle of the h5 file used in non-preloaded mode. It is opened lazily in __getitem__, because HDF5 files
        # can't be pickled (required for DataLoaders with multiple workers). Initializing it here guarantees that the
        # (name-mangled) attribute always exists, so that it can be checked reliably.
        self.__h5_file = None
        # id of the process that opened self.__h5_file (the handle is reopened when used from another process)
        self.__h5_file_pid = None
        self.__h5_path = h5_path
        self.__specific_ids = specific_ids
        # set up the load ground truth function. Could be None, function referenced by string, or callable
        if isinstance(load_ground_truth, str):
            self.__load_ground_truth = getattr(simcats_datasets.loading.load_ground_truth, load_ground_truth)
        else:
            self.__load_ground_truth = load_ground_truth
        # set up the data preprocessors. Could be None, functions referenced by strings, or callables
        self.__data_preprocessors = self._resolve_preprocessors(data_preprocessors)
        # set up the ground truth preprocessors. Could be None, functions referenced by strings, or callables
        if ground_truth_preprocessors is not None and self.load_ground_truth is None:
            raise ValueError("If load_ground_truth is None, ground_truth_preprocessors should also be None")
        self.__ground_truth_preprocessors = self._resolve_preprocessors(ground_truth_preprocessors)
        # set up the output format function. Could be None, function referenced by string, or callable
        if format_output is None:
            self.__format_output = format_dict_csd_float_ground_truth_long
        elif isinstance(format_output, str):
            self.__format_output = getattr(simcats_datasets.support_functions.pytorch_format_output, format_output)
        else:
            self.__format_output = format_output
        self.__preload = preload
        self.__progress_bar = progress_bar
        if self.preload and max_concurrent_preloads < 1:
            # would otherwise lead to a ZeroDivisionError / an endless preload loop
            raise ValueError(f"max_concurrent_preloads must be at least 1, but got {max_concurrent_preloads}")
        with h5py.File(h5_path, "r") as h5_file:
            # setup available ids (if specific ids were supplied, they are mapped to a new range from 0 to
            # len(specific_ids))
            self.__num_ids = len(
                load_dataset(file=h5_file, load_csds=False, load_ids=True, specific_ids=self.specific_ids,
                             progress_bar=self.progress_bar, ).ids)
            # preprocess an exemplary image to get final shape (some preprocessors might adjust the shape). The first
            # requested id is used, so that the probed CSD is actually part of this dataset.
            if self.specific_ids is not None and len(self.specific_ids) > 0:
                probe_id = self.specific_ids[0]
            else:
                probe_id = 0
            _temp_csd = load_dataset(file=h5_file, load_csds=True, specific_ids=[probe_id],
                                     progress_bar=self.progress_bar, ).csds[0]
            if self.data_preprocessors is not None:
                for processor in self.data_preprocessors:
                    _temp_csd = processor(_temp_csd)
            self.__shape = (self.__num_ids, *np.squeeze(_temp_csd).shape)
            # preload all data if requested
            if self.preload:
                self.__preload_data(h5_file=h5_file, max_concurrent_preloads=max_concurrent_preloads)

    @staticmethod
    def _resolve_preprocessors(preprocessors: Union[List[Union[str, Callable]], None]) -> Union[List[Callable], None]:
        """Maps preprocessor names to the functions of simcats_datasets.support_functions.data_preprocessing.

        Args:
            preprocessors: None, or a list of callables and/or names of functions in
                simcats_datasets.support_functions.data_preprocessing.

        Returns:
            None if preprocessors is None, else a list of callables (in the same order).
        """
        if preprocessors is None:
            return None
        return [getattr(simcats_datasets.support_functions.data_preprocessing, p) if isinstance(p, str) else p
                for p in preprocessors]

    def __preload_data(self, h5_file, max_concurrent_preloads: int) -> None:
        """Loads and preprocesses all CSDs (and ground truths, if configured), at most max_concurrent_preloads at a time.

        Args:
            h5_file: The opened h5 file of the dataset.
            max_concurrent_preloads: Maximum number of CSDs loaded per step.
        """
        self.__csds = []
        self.__ground_truths = []
        for chunk_start in range(0, self.__num_ids, max_concurrent_preloads):
            _ids = range(chunk_start, min(chunk_start + max_concurrent_preloads, self.__num_ids))
            if self.specific_ids is not None:
                # map the dataset indices to the ids stored in the file
                _ids = [self.specific_ids[j] for j in _ids]
            # load and preprocess data
            _temp_csds = list(load_dataset(file=h5_file, specific_ids=_ids, progress_bar=self.progress_bar, ).csds)
            if self.data_preprocessors is not None:
                for processor in self.data_preprocessors:
                    _temp_csds = processor(_temp_csds)
            self.__csds.extend(_temp_csds)
            del _temp_csds
            # load and preprocess ground truth (only if a ground truth loader was configured)
            if self.load_ground_truth is not None:
                _temp_ground_truths = list(
                    self.load_ground_truth(file=h5_file, specific_ids=_ids, progress_bar=self.progress_bar, ))
                if self.ground_truth_preprocessors is not None:
                    for processor in self.ground_truth_preprocessors:
                        _temp_ground_truths = processor(_temp_ground_truths)
                self.__ground_truths.extend(_temp_ground_truths)
                del _temp_ground_truths

    @property
    def h5_path(self) -> str:
        return self.__h5_path

    @property
    def specific_ids(self) -> Union[range, List[int], np.ndarray, None]:
        return self.__specific_ids

    @property
    def load_ground_truth(self) -> Union[Callable, None]:
        return self.__load_ground_truth

    @property
    def data_preprocessors(self) -> Union[List[Callable], None]:
        return self.__data_preprocessors

    @property
    def ground_truth_preprocessors(self) -> Union[List[Callable], None]:
        return self.__ground_truth_preprocessors

    @property
    def format_output(self) -> Callable:
        return self.__format_output

    @property
    def preload(self) -> bool:
        return self.__preload

    @property
    def progress_bar(self) -> bool:
        return self.__progress_bar

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.__shape

    def __len__(self):
        """
        Returns the number of CSDs in the dataset.
        """
        return self.__num_ids

    def __get_h5_file(self):
        """Returns the h5 file handle for non-preloaded mode, opening it lazily.

        The file is (re)opened if no handle exists yet or if the existing handle was opened by a different process
        (e.g. inherited by a forked DataLoader worker), so that HDF5 handles are not reused across processes.
        """
        pid = os.getpid()
        if self.__h5_file is None or self.__h5_file_pid != pid:
            self.__h5_file = h5py.File(self.h5_path, mode="r")
            self.__h5_file_pid = pid
        return self.__h5_file

    def __getitem__(self, idx: int):
        """
        Retrieves a csd and the corresponding ground truth at given index idx.

        Args:
            idx: The id of the csd and ground truth to be returned.

        Returns:
            The output of format_output for the csd, its ground truth (None if load_ground_truth is None) and idx.
        """
        if self.preload:
            csd = self.__csds[idx]
            ground_truth = self.__ground_truths[idx] if self.load_ground_truth is not None else None
        else:
            # the h5 file is opened here (and not in __init__), because a non preloaded Dataset used with multiple
            # workers is not able to pickle HDF5 files!
            h5_file = self.__get_h5_file()
            # NOTE(review): in this mode idx is mapped to the id in the file before being passed to format_output,
            # while the preloaded mode passes the dataset index - confirm which one format_output should receive.
            if self.specific_ids is not None:
                idx = self.specific_ids[idx]
            # load and preprocess data
            csd = load_dataset(file=h5_file, specific_ids=[idx], progress_bar=self.progress_bar).csds[0]
            if self.data_preprocessors is not None:
                for processor in self.data_preprocessors:
                    csd = processor(csd)
            # load and preprocess ground truth (only if a ground truth loader was configured)
            if self.load_ground_truth is None:
                ground_truth = None
            else:
                ground_truth = \
                    self.load_ground_truth(file=h5_file, specific_ids=[idx], progress_bar=self.progress_bar)[0]
                if self.ground_truth_preprocessors is not None:
                    for processor in self.ground_truth_preprocessors:
                        ground_truth = processor(ground_truth)
        return self.format_output(csd=csd, ground_truth=ground_truth, idx=idx)

    def __getstate__(self):
        """Returns the picklable state of the dataset. An open h5 file handle is dropped (it is reopened lazily)."""
        state = self.__dict__.copy()
        state["_SimcatsDataset__h5_file"] = None
        state["_SimcatsDataset__h5_file_pid"] = None
        return state

    def __repr__(self):
        def _names(funcs):
            # comma separated function names, or None if no functions are configured
            return ", ".join(func.__name__ for func in funcs) if funcs is not None else None

        load_ground_truth_name = self.load_ground_truth.__name__ if self.load_ground_truth is not None else None
        return (f"{self.__class__.__name__}(\n"
                f"\th5_path={self.h5_path},\n"
                f"\tspecific_ids={self.specific_ids},\n"
                f"\tload_ground_truth={load_ground_truth_name},\n"
                f"\tdata_preprocessors=[{_names(self.data_preprocessors)}],\n"
                f"\tground_truth_preprocessors=[{_names(self.ground_truth_preprocessors)}],\n"
                f"\tformat_output={self.format_output.__name__},\n"
                f"\tpreload={self.preload},\n"
                f"\tprogress_bar={self.progress_bar}\n"
                f")")

    def __del__(self):
        # getattr, because __init__ might have failed before the handle attribute was created
        h5_file = getattr(self, "_SimcatsDataset__h5_file", None)
        if h5_file is not None:
            h5_file.close()
            self.__h5_file = None
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
class SimcatsConcatDataset(ConcatDataset):
    """Concatenation of multiple SimcatsDatasets (one per h5 file) sharing the same loading configuration."""

    def __init__(self,
                 h5_paths: List[str],
                 specific_ids: Union[List[Union[range, List[int], np.ndarray, None]], None] = None,
                 load_ground_truth: Union[Callable, str, None] = None,
                 data_preprocessors: Union[List[Union[str, Callable]], None] = None,
                 ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
                 format_output: Union[Callable, str, None] = None, preload: bool = True,
                 max_concurrent_preloads: int = 100000,
                 progress_bar: bool = False, ):
        """Initializes an object for providing concatenated simcats_datasets data to pytorch.

        Args:
            h5_paths: The paths to the h5 files containing the datasets to be concatenated.
            specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
                sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
                Expects a list of specific_id settings, with one entry for each provided h5_path. Default is None.
            load_ground_truth: Defines the required type of ground truth data to be loaded. Accepts either a callable or
                a string. Callables must be of the same structure/interface as load_zeros_masks defined in
                simcats_datasets.loading.load_ground_truth. Strings must map to the function names of the
                loading functions defined in simcats_datasets.loading.load_ground_truth. If this is None,
                no ground truth is loaded, which restricts what output formats are possible. Default is None. \n
                Example of available types (**full list at simcats_datasets.loading.load_ground_truth**): \n
                - **'tct_masks'**: The Total Charge Transition (TCT) mask generated by SimCATS.
                - **'tc_region_masks'**: Regions with a fixed number of total charges.
                - **'tc_region_minus_tct_masks'**: Regions with a fixed number of total charges, but with zeros between
                  the regions (at tcts).
            data_preprocessors: Defines if data should be preprocessed. Accepts a list of callables or strings.
                Callables must be of the same structure/interface as example_preprocessor defined in
                simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
                preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
                Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
                - **'min_max_0_1'**: Min max scaling of the data to [0, 1]
                - **'standardization'**: Standardization of the data (mean=0, std=1)
                - **'add_newaxis'**: Adds new axis as first axis (required for UNET)
            ground_truth_preprocessors: Defines if ground truth should be preprocessed. Accepts a list of callables or
                strings. Callables must be of the same structure/interface as example_preprocessor defined in
                simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
                preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
                Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
                - **'only_two_classes'**: Reduce the number of classes in a mask to 2 (set every pixel > 1 = 1)
            format_output: Defines the required type of data format for the output. Accepts either a callable or a
                string. Callables must be of the same structure/interface as format_dict_csd_float_ground_truth_long
                defined in simcats_datasets.support_functions.pytorch_format_output. Strings must map to the function
                names of the format functions defined in simcats_datasets.support_functions.pytorch_format_output. If
                this is None, format_dict_csd_float_ground_truth_long is used, which does return the output as dict
                with entries 'csd' and 'ground_truth' of dtype float and long, respectively. Default is None. \n
                Example of available types (**full list at simcats_datasets.support_functions.pytorch_format_output**): \n
                - **'format_dict_csd_float_ground_truth_long'**: formats the output as dict with entries 'csd' and
                  'ground_truth' of dtype float and long, respectively
            preload: Enables preloading the whole dataset during the initialization (requires more RAM). Default is
                True.
            max_concurrent_preloads: Determines how many CSDs are concurrently loaded from the dataset during the
                preload phase. This option only affects instances with preload = True. It allows to preload large
                datasets (for which it might not be possible to load the whole dataset into the memory at once), by
                loading them step by step and for example converting the CSDs to float32 with a corresponding data
                preprocessor. Default is 100,000.
            progress_bar: Determines whether to display a progress bar while loading data. Default is False.

        Raises:
            IndexError: If specific_ids is given but does not contain exactly one entry per h5 path.
            ValueError: If ground_truth_preprocessors are given while load_ground_truth is None, or if the CSD shapes
                of the individual datasets differ.
        """
        if specific_ids is not None and len(specific_ids) != len(h5_paths):
            raise IndexError("Specific_ids were provided but with a different number of entries than h5_paths! If "
                             "specific_ids are provided they need to contain the same number of entries!")
        # one specific_ids setting per h5 file (None means: load all ids of that file)
        per_path_ids = specific_ids if specific_ids is not None else [None] * len(h5_paths)
        _datasets = [
            SimcatsDataset(h5_path=path, specific_ids=path_ids, load_ground_truth=load_ground_truth,
                           data_preprocessors=data_preprocessors,
                           ground_truth_preprocessors=ground_truth_preprocessors, format_output=format_output,
                           preload=preload, max_concurrent_preloads=max_concurrent_preloads,
                           progress_bar=progress_bar)
            for path, path_ids in zip(h5_paths, per_path_ids)]
        super().__init__(_datasets)
        self.__h5_paths = h5_paths
        self.__specific_ids = specific_ids
        # set up the load ground truth function. Could be None, function referenced by string, or callable
        if isinstance(load_ground_truth, str):
            self.__load_ground_truth = getattr(simcats_datasets.loading.load_ground_truth, load_ground_truth)
        else:
            self.__load_ground_truth = load_ground_truth
        # set up the data preprocessors. Could be None, functions referenced by strings, or callables
        if data_preprocessors is None:
            self.__data_preprocessors = None
        else:
            self.__data_preprocessors = [
                getattr(simcats_datasets.support_functions.data_preprocessing, p) if isinstance(p, str) else p
                for p in data_preprocessors]
        # set up the ground truth preprocessors. Could be None, functions referenced by strings, or callables
        if ground_truth_preprocessors is None:
            self.__ground_truth_preprocessors = None
        else:
            if self.load_ground_truth is None:
                raise ValueError("If load_ground_truth is None, ground_truth_preprocessors should also be None")
            self.__ground_truth_preprocessors = [
                getattr(simcats_datasets.support_functions.data_preprocessing, p) if isinstance(p, str) else p
                for p in ground_truth_preprocessors]
        # set up the output format function. Could be None, function referenced by string, or callable
        if format_output is None:
            self.__format_output = format_dict_csd_float_ground_truth_long
        elif isinstance(format_output, str):
            self.__format_output = getattr(simcats_datasets.support_functions.pytorch_format_output, format_output)
        else:
            self.__format_output = format_output
        self.__preload = preload
        self.__progress_bar = progress_bar
        # get dataset shapes and check if all shapes are the same (ConcatDataset already rejects an empty list)
        shapes = [dataset.shape[1:] for dataset in _datasets]
        if any(other != shapes[0] for other in shapes[1:]):
            raise ValueError(f"The shape of the SimcatsDataset CSDs should be identical but found shapes "
                             f"{shapes}")
        self.__shape = (len(self), *shapes[0])

    @property
    def h5_paths(self) -> List[str]:
        return self.__h5_paths

    @property
    def specific_ids(self) -> Union[List[Union[range, List[int], np.ndarray, None]], None]:
        return self.__specific_ids

    @property
    def load_ground_truth(self) -> Union[Callable, None]:
        return self.__load_ground_truth

    @property
    def data_preprocessors(self) -> Union[List[Callable], None]:
        return self.__data_preprocessors

    @property
    def ground_truth_preprocessors(self) -> Union[List[Callable], None]:
        return self.__ground_truth_preprocessors

    @property
    def format_output(self) -> Callable:
        return self.__format_output

    @property
    def preload(self) -> bool:
        return self.__preload

    @property
    def progress_bar(self) -> bool:
        return self.__progress_bar

    @property
    def shape(self) -> Tuple[int, ...]:
        return self.__shape

    def __repr__(self):
        def _names(funcs):
            # comma separated function names, or None if no functions are configured
            return ", ".join(func.__name__ for func in funcs) if funcs is not None else None

        # load_ground_truth is optional, so it must not be dereferenced blindly (would raise AttributeError)
        load_ground_truth_name = self.load_ground_truth.__name__ if self.load_ground_truth is not None else None
        return (f"{self.__class__.__name__}(\n"
                f"\th5_paths={self.h5_paths},\n"
                f"\tspecific_ids={self.specific_ids},\n"
                f"\tload_ground_truth={load_ground_truth_name},\n"
                f"\tdata_preprocessors=[{_names(self.data_preprocessors)}],\n"
                f"\tground_truth_preprocessors=[{_names(self.ground_truth_preprocessors)}],\n"
                f"\tformat_output={self.format_output.__name__},\n"
                f"\tpreload={self.preload},\n"
                f"\tprogress_bar={self.progress_bar}\n"
                f")")
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
"""Module with support functions for different main functionalities."""
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
"""JSON encoders that are required to encode numpy arrays and xarray DataArrays while saving datasets.
|
|
2
|
+
|
|
3
|
+
@author: f.hader
|
|
4
|
+
"""
|
|
5
|
+
|
|
6
|
+
import json
|
|
7
|
+
import numpy as np
|
|
8
|
+
from xarray import DataArray
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class MultipleJsonEncoders(json.JSONEncoder):
    """Chain several JSON encoder classes into one encoder.

    An instance is meant to be passed as ``cls`` to ``json.dump``/``json.dumps``. Calling it returns a plain
    ``json.JSONEncoder`` whose ``default`` hook asks each wrapped encoder class in turn and uses the first result that
    does not raise a ``TypeError``.
    """

    def __init__(self, *encoders):
        super().__init__()
        # encoder classes, tried in the given order
        self.encoders = encoders
        # constructor arguments of the latest __call__, reused to instantiate the wrapped encoders
        self.args = ()
        self.kwargs = {}

    def default(self, obj):
        """Serialize ``obj`` with the first wrapped encoder able to handle it.

        Raises:
            TypeError: If none of the wrapped encoders can serialize ``obj``.
        """
        for encoder_cls in self.encoders:
            try:
                encoded = encoder_cls(*self.args, **self.kwargs).default(obj)
            except TypeError:
                continue
            return encoded
        raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')

    def __call__(self, *args, **kwargs):
        """Build the actual encoder instance (invoked by the json module with its encoder options)."""
        self.args, self.kwargs = args, kwargs
        combined = json.JSONEncoder(*args, **kwargs)
        combined.default = self.default
        return combined
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class DataArrayEncoder(json.JSONEncoder):
    """JSON encoder that serializes xarray DataArrays through their dictionary representation."""

    def default(self, obj):
        """Return ``obj.to_dict()`` for DataArrays and defer everything else to the base encoder."""
        if not isinstance(obj, DataArray):
            return json.JSONEncoder.default(self, obj)
        return obj.to_dict()
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder for numpy arrays and numpy scalars."""

    def default(self, obj):
        """Convert numpy arrays to (nested) lists and numpy scalars to Python scalars.

        Raises:
            TypeError: If obj is neither a numpy array nor a numpy scalar (via the base encoder).
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, np.generic):
            # numpy scalars like np.int64 or np.bool_ are not subclasses of the Python builtins and would otherwise
            # raise a TypeError; item() converts them to the equivalent Python scalar
            return obj.item()
        return json.JSONEncoder.default(self, obj)
|
|
@@ -0,0 +1,191 @@
|
|
|
1
|
+
"""Helper functions for clipping lines into rectangles (into the CSD space).
|
|
2
|
+
|
|
3
|
+
Used to clip single transition lines into the CSD space to generate transition specific labels with the function
|
|
4
|
+
`get_lead_transition_labels` from `simcats_datasets.support_functions.get_lead_transition_labels`.
|
|
5
|
+
|
|
6
|
+
@author: f.hader
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
from typing import Tuple, List, Union
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
def clip_slope_line_to_rectangle(slope: float, point: Tuple[float, float], rect_corners: List[Tuple[float, float]],
                                 is_start: bool = True) -> Union[Tuple[Tuple[float, float], Tuple[float, float]], None]:
    """Clips a half-line, given by a slope and a point, to a rectangle.

    The half-line starts (is_start=True) or ends (is_start=False) at the given point and extends towards positive or
    negative x, respectively.

    Args:
        slope: The slope of the line.
        point: A tuple (x, y) representing the starting or ending point of the line segment.
        rect_corners: A list of four tuples, each representing the corner points of the rectangle.
        is_start: Specifies whether the line extends to positive infinity (True) or negative infinity (False) from the
            provided point. Default is True (positive infinity).

    Returns:
        A tuple (start, end) representing the clipped line segment. Returns None if the line is entirely outside the
        rectangle.
    """
    corner_xs = [corner[0] for corner in rect_corners]
    left, right = min(corner_xs), max(corner_xs)
    px, py = point[0], point[1]

    # Replace the infinite side by a finite end point one unit beyond the rectangle's x range, then clip the segment.
    if is_start:
        far_end = (right + 1, py + slope * (right + 1 - px))
        return clip_point_line_to_rectangle(point, far_end, rect_corners)
    far_start = (left - 1, py - slope * (px - left + 1))
    return clip_point_line_to_rectangle(far_start, point, rect_corners)
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def clip_infinite_slope_line_to_rectangle(slope: float, point: Tuple[float, float],
                                          rect_corners: List[Tuple[float, float]]) -> Union[
    Tuple[Tuple[float, float], Tuple[float, float]], None]:
    """Clips an infinite line, given by a slope and a point on it, to a rectangle.

    Args:
        slope: The slope of the line.
        point: A tuple (x, y) representing a point on the line.
        rect_corners: A list of four tuples, each representing the corner points of the rectangle.

    Returns:
        A tuple (start, end) representing the clipped line segment. Returns None if the line is entirely outside the
        rectangle.
    """
    corner_xs = [corner[0] for corner in rect_corners]
    left, right = min(corner_xs), max(corner_xs)
    px, py = point[0], point[1]

    # Build a finite segment whose end points lie one unit left and right of the rectangle's x range, so that it covers
    # the whole part of the line crossing that range, and clip that segment instead of the infinite line.
    left_end = (left - 1, py - slope * (px - left + 1))
    right_end = (right + 1, py + slope * (right + 1 - px))
    return clip_point_line_to_rectangle(left_end, right_end, rect_corners)
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def clip_point_line_to_rectangle(start: Tuple[float, float], end: Tuple[float, float],
                                 rect_corners: List[Tuple[float, float]]) -> Union[
        Tuple[Tuple[float, float], Tuple[float, float]], None]:
    """Clips a line segment defined by its start and end points to a rectangle.

    Args:
        start: A tuple (x, y) representing the start point of the line.
        end: A tuple (x, y) representing the end point of the line.
        rect_corners: A sequence of four (x, y) corner points of the rectangle. Consecutive corners (and the last and
            first corner) must share an edge, e.g. as returned by `create_rectangle_corners`.

    Returns:
        A tuple (clipped_start, clipped_end) with the part of the segment that lies inside the rectangle. It is
        oriented in the same direction as start -> end.

        - If the segment only touches the rectangle in a single point, both entries are that point.
        - If the segment does not touch the rectangle at all, (None, None) is returned. This matches the historic
          behavior of this function; the return annotation says None instead.

    Notes:
        - The function handles lines defined by two points.
        - The function handles the case when the line is entirely inside the rectangle.
        - A segment passing exactly through a corner intersects both adjacent edges in the same point. Such
          duplicate intersections are merged, so the real second intersection is not lost.
    """
    start_inside = is_point_inside_rectangle(start, rect_corners)
    end_inside = is_point_inside_rectangle(end, rect_corners)
    if start_inside and end_inside:
        # The entire line is inside the rectangle
        return start, end

    dx = end[0] - start[0]
    dy = end[1] - start[1]
    length_sq = dx * dx + dy * dy
    # Intersections whose positions along the segment (as a fraction of its length) differ by less than this are
    # treated as the same point (corner hits computed from two different edges may differ by rounding).
    tolerance = 1e-9

    # Use a plain list so the wrap-around pairing also works for tuples and numpy arrays (for an ndarray, "+" would
    # add element-wise instead of concatenating).
    corners = list(rect_corners)
    # Distinct intersections as (t, point), where t in [0, 1] is the position along start -> end.
    hits = []
    for rect_point1, rect_point2 in zip(corners, corners[1:] + corners[:1]):
        # Calculate the intersection point between the line and the rectangle edge
        intersection = line_intersection(start, end, rect_point1, rect_point2)
        if intersection is None or length_sq == 0:
            continue
        t = ((intersection[0] - start[0]) * dx + (intersection[1] - start[1]) * dy) / length_sq
        if all(abs(t - known_t) > tolerance for known_t, _ in hits):
            hits.append((t, intersection))
    # Order the intersections along the segment so the result keeps the start -> end orientation.
    hits.sort(key=lambda hit: hit[0])

    if start_inside:
        # The segment leaves the rectangle; the exit is the intersection farthest along the segment.
        return start, (hits[-1][1] if hits else None)
    if end_inside:
        # The segment enters the rectangle; the entry is the intersection closest to the start.
        return (hits[0][1] if hits else None), end
    if not hits:
        # The line is entirely outside the rectangle.
        return None, None
    # Both endpoints are outside: the clipped part runs from the entry to the exit intersection.
    return hits[0][1], hits[-1][1]
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def is_point_inside_rectangle(point: Tuple[float, float], rect_corners: List[Tuple[float, float]]) -> bool:
    """Checks if a point is inside (or on the border of) an axis-aligned rectangle defined by its corner points.

    Args:
        point: A tuple (x, y) representing the point to be checked.
        rect_corners: A list of four tuples, each representing a corner point of the rectangle. It is assumed that
            these are sorted so that they form a path around the rectangle. Thus, the first and third corner (or
            alternatively the second and fourth) are opposite corners and define the rectangle. The path may start
            at any corner and run in either direction.

    Returns:
        True if the point is inside the rectangle or on its border, False otherwise.
    """
    x, y = point
    x1, y1 = rect_corners[0]
    x2, y2 = rect_corners[2]
    # Normalize the opposite corners, so the check does not depend on which corner the path starts at.
    return min(x1, x2) <= x <= max(x1, x2) and min(y1, y2) <= y <= max(y1, y2)
|
|
135
|
+
|
|
136
|
+
|
|
137
|
+
def line_intersection(p1: Tuple[float, float], p2: Tuple[float, float], q1: Tuple[float, float],
                      q2: Tuple[float, float]) -> Union[Tuple[float, float], None]:
    """Computes the point where two line segments cross each other.

    The segments are p1 -> p2 and q1 -> q2. Both are parametrized on [0, 1]; an intersection is only reported if it
    lies on both segments (endpoints included).

    Args:
        p1: A tuple (x, y) representing the first endpoint of the first line.
        p2: A tuple (x, y) representing the second endpoint of the first line.
        q1: A tuple (x, y) representing the first endpoint of the second line.
        q2: A tuple (x, y) representing the second endpoint of the second line.

    Returns:
        A tuple (x, y) with the intersection point. None if the segments are parallel or collinear (including
        overlapping collinear segments), or if their supporting lines cross outside of either segment.
    """
    (ax, ay), (bx, by) = p1, p2
    (cx, cy), (dx, dy) = q1, q2

    denominator = (ax - bx) * (cy - dy) - (ay - by) * (cx - dx)
    if denominator == 0:
        # Lines are parallel or coincident
        return None

    # t: position along p1 -> p2, u: position along q1 -> q2
    t = ((ax - cx) * (cy - dy) - (ay - cy) * (cx - dx)) / denominator
    u = -((ax - bx) * (ay - cy) - (ay - by) * (ax - cx)) / denominator

    if not (0 <= t <= 1 and 0 <= u <= 1):
        return None

    return ax + t * (bx - ax), ay + t * (by - ay)
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def create_rectangle_corners(x_range: Tuple[float, float], y_range: Tuple[float, float]) -> List[Tuple[float, float]]:
    """Builds the four corner points of the axis-aligned rectangle spanned by the given x and y value ranges.

    Args:
        x_range: A tuple (x_min, x_max) representing the minimum and maximum x values.
        y_range: A tuple (y_min, y_max) representing the minimum and maximum y values.

    Returns:
        A list of four (x, y) tuples in the order bottom-left, bottom-right, top-right, top-left. Consecutive entries
        (and the last and first entry) therefore share an edge of the rectangle.
    """
    (x_lo, x_hi), (y_lo, y_hi) = x_range, y_range
    return [(x_lo, y_lo), (x_hi, y_lo), (x_hi, y_hi), (x_lo, y_hi)]
|