simcats-datasets 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,426 @@
1
+ """Implementation of a pytorch dataset class. Can be used to train machine learning approaches with CSD data.
2
+
3
+ @author: f.hader
4
+ """
5
+
6
+ import math
7
+ from typing import Callable, List, Union, Tuple
8
+
9
+ import h5py
10
+ import numpy as np
11
+ from torch.utils.data import Dataset, ConcatDataset
12
+
13
+ import simcats_datasets.loading.load_ground_truth
14
+ import simcats_datasets.support_functions.data_preprocessing
15
+ from simcats_datasets.loading import load_dataset
16
+ from simcats_datasets.support_functions.pytorch_format_output import format_dict_csd_float_ground_truth_long
17
+
18
+
19
+ class SimcatsDataset(Dataset):
20
+ """Pytorch Dataset class implementation for SimCATS datasets. Uses simcats_datasets to load and provide (training) data.
21
+ """
22
+
23
+ def __init__(self,
24
+ h5_path: str,
25
+ specific_ids: Union[range, List[int], np.ndarray, None] = None,
26
+ load_ground_truth: Union[Callable, str, None] = None,
27
+ data_preprocessors: Union[List[Union[str, Callable]], None] = None,
28
+ ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
29
+ format_output: Union[Callable, str, None] = None, preload: bool = True,
30
+ max_concurrent_preloads: int = 100000,
31
+ progress_bar: bool = False, ):
32
+ """Initializes an object for providing simcats_datasets data to pytorch.
33
+
34
+ Args:
35
+ h5_path: The path to the h5 file containing the dataset.
36
+ specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
37
+ sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
38
+ Default is None.
39
+ load_ground_truth: Defines the required type of ground truth data to be loaded. Accepts either a callable or
40
+ a string. Callables must be of the same structure/interface as load_zeros_masks defined in
41
+ simcats_datasets.loading.load_ground_truth. Strings must map to the function names of the
42
+ loading functions defined in simcats_datasets.loading.load_ground_truth. If this is None,
43
+ no ground truth is loaded, which restricts the possible output formats. Default is None. \n
44
+ Example of available types (**full list at simcats_datasets.loading.load_ground_truth**): \n
45
+ - **'tct_masks'**: The Total Charge Transition (TCT) mask generated by SimCATS.
46
+ - **'tc_region_masks'**: Regions with a fixed number of total charges.
47
+ - **'tc_region_minus_tct_masks'**: Regions with a fixed number of total charges, but with zeros between
48
+ the regions (at tcts).
49
+ data_preprocessors: Defines if data should be preprocessed. Accepts a list of callables or strings.
50
+ Callables must be of the same structure/interface as example_preprocessor defined in
51
+ simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
52
+ preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
53
+ Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
54
+ - **'min_max_0_1'**: Min max scaling of the data to [0, 1]
55
+ - **'standardization'**: Standardization of the data (mean=0, std=1)
56
+ - **'add_newaxis'**: Adds new axis as first axis (required for UNET)
57
+ ground_truth_preprocessors: Defines if ground truth should be preprocessed. Accepts a list of callables or
58
+ strings. Callables must be of the same structure/interface as example_preprocessor defined in
59
+ simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
60
+ preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
61
+ Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
62
+ - **'only_two_classes'**: Reduce the number of classes in a mask to 2 (set every pixel > 1 = 1)
63
+ format_output: Defines the required type of data format for the output. Accepts either a callable or a
64
+ string. Callables must be of the same structure/interface as format_dict_csd_float_ground_truth_long
65
+ defined in simcats_datasets.support_functions.pytorch_format_output. Strings must map to the function
66
+ names of the format functions defined in simcats_datasets.support_functions.pytorch_format_output. If
67
+ this is None, format_dict_csd_float_ground_truth_long is used, which returns the output as a dict
68
+ with entries 'csd' and 'ground_truth' of dtype float and long, respectively. Default is None. \n
69
+ Example of available types (**full list at simcats_datasets.support_functions.pytorch_format_output**): \n
70
+ - **'format_dict_csd_float_ground_truth_long'**: formats the output as dict with entries 'csd' and
71
+ 'ground_truth' of dtype float and long, respectively
72
+ preload: Enables preloading the whole dataset during the initialization (requires more RAM). Default is
73
+ True.
74
+ max_concurrent_preloads: Determines how many CSDs are concurrently loaded from the dataset during the
75
+ preload phase. This option only affects instances with preload = True. It allows preloading large
76
+ datasets (for which it might not be possible to load the whole dataset into memory at once), by
77
+ loading them step by step and for example converting the CSDs to float32 with a corresponding data
78
+ preprocessor. Default is 100,000.
79
+ progress_bar: Determines whether to display a progress bar while loading data. Default is False.
80
+ """
81
+ self.__h5_path = h5_path
82
+ self.__specific_ids = specific_ids
83
+ # set up the load ground truth function. Could be None, function referenced by string, or callable
84
+ if load_ground_truth is None:
85
+ self.__load_ground_truth = None
86
+ else:
87
+ if isinstance(load_ground_truth, str):
88
+ self.__load_ground_truth = getattr(simcats_datasets.loading.load_ground_truth, load_ground_truth)
89
+ else:
90
+ self.__load_ground_truth = load_ground_truth
91
+ # set up the data preprocessors. Could be None, functions referenced by strings, or callables
92
+ if data_preprocessors is None:
93
+ self.__data_preprocessors = data_preprocessors
94
+ else:
95
+ self.__data_preprocessors = [
96
+ i if not isinstance(i, str) else getattr(simcats_datasets.support_functions.data_preprocessing, i) for i
97
+ in data_preprocessors]
98
+ # set up the ground truth preprocessors. Could be None, functions referenced by strings, or callables
99
+ if ground_truth_preprocessors is None:
100
+ self.__ground_truth_preprocessors = ground_truth_preprocessors
101
+ else:
102
+ if self.load_ground_truth is None:
103
+ raise ValueError("If load_ground_truth is None. ground_truth_preprocessors should also be None")
104
+ self.__ground_truth_preprocessors = [
105
+ i if not isinstance(i, str) else getattr(simcats_datasets.support_functions.data_preprocessing, i) for i
106
+ in ground_truth_preprocessors]
107
+ # set up the output format function. Could be None, function referenced by string, or callable
108
+ if format_output is None:
109
+ self.__format_output = format_dict_csd_float_ground_truth_long
110
+ else:
111
+ if isinstance(format_output, str):
112
+ self.__format_output = getattr(simcats_datasets.support_functions.pytorch_format_output,
113
+ format_output, )
114
+ else:
115
+ self.__format_output = format_output
116
+ self.__preload = preload
117
+ self.__progress_bar = progress_bar
118
+ with h5py.File(h5_path, "r") as h5_file:
119
+ # set up available ids (if specific ids were supplied, they are mapped to a new range from 0 to len(specific_ids))
120
+ self.__num_ids = len(
121
+ load_dataset(file=h5_file, load_csds=False, load_ids=True, specific_ids=self.specific_ids,
122
+ progress_bar=self.progress_bar, ).ids)
123
+ # preprocess an example CSD to determine the final shape (some preprocessors might change the shape)
124
+ _temp_csd = \
125
+ load_dataset(file=h5_file, load_csds=True, specific_ids=[0], progress_bar=self.progress_bar, ).csds[0]
126
+ if self.data_preprocessors is not None:
127
+ for processor in self.data_preprocessors:
128
+ _temp_csd = processor(_temp_csd)
129
+ self.__shape = (self.__num_ids, *np.squeeze(_temp_csd).shape)
130
+ # preload all data if requested
131
+ if self.preload:
132
+ self.__csds = []
133
+ self.__ground_truths = []
134
+ # load and store data, at most max_concurrent_preloads CSDs at a time
135
+ for i in range(math.ceil(self.__num_ids / max_concurrent_preloads)):
136
+ _ids = range(i * max_concurrent_preloads,
137
+ np.min([(i + 1) * max_concurrent_preloads, self.__num_ids]))
138
+ if self.specific_ids is not None:
139
+ _ids = [self.specific_ids[i] for i in _ids]
140
+ # load
141
+ _temp_csds = [csd for csd in
142
+ load_dataset(file=h5_file, specific_ids=_ids, progress_bar=self.progress_bar, ).csds]
143
+ # preprocess data
144
+ if self.data_preprocessors is not None:
145
+ for processor in self.data_preprocessors:
146
+ _temp_csds = processor(_temp_csds)
147
+ self.__csds.extend(_temp_csds)
148
+ del _temp_csds
149
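+ # if load_ground_truth is None, calling it below raises TypeError and no ground truth is stored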
+ try:
150
+ _temp_ground_truths = [gt for gt in
151
+ self.load_ground_truth(file=h5_file, specific_ids=_ids, progress_bar=self.progress_bar, )]
152
+ # preprocess ground truth
153
+ if self.ground_truth_preprocessors is not None:
154
+ for processor in self.ground_truth_preprocessors:
155
+ _temp_ground_truths = processor(_temp_ground_truths)
156
+ self.__ground_truths.extend(_temp_ground_truths)
157
+ del _temp_ground_truths
158
+ except TypeError:
159
+ pass
160
+
161
+ @property
162
+ def h5_path(self) -> str:
163
+ return self.__h5_path
164
+
165
+ @property
166
+ def specific_ids(self) -> Union[range, List[int], np.ndarray, None]:
167
+ return self.__specific_ids
168
+
169
+ @property
170
+ def load_ground_truth(self) -> Callable:
171
+ return self.__load_ground_truth
172
+
173
+ @property
174
+ def data_preprocessors(self) -> Union[List[Callable], None]:
175
+ return self.__data_preprocessors
176
+
177
+ @property
178
+ def ground_truth_preprocessors(self) -> Union[List[Callable], None]:
179
+ return self.__ground_truth_preprocessors
180
+
181
+ @property
182
+ def format_output(self) -> Callable:
183
+ return self.__format_output
184
+
185
+ @property
186
+ def preload(self) -> bool:
187
+ return self.__preload
188
+
189
+ @property
190
+ def progress_bar(self) -> bool:
191
+ return self.__progress_bar
192
+
193
+ @property
194
+ def shape(self) -> Tuple[int, ...]:
195
+ return self.__shape
196
+
197
+ def __len__(self):
198
+ """
199
+ Returns the number of CSDs in the dataset.
200
+ """
201
+ return self.__num_ids
202
+
203
+ def __getitem__(self, idx: int):
204
+ """
205
+ Retrieves a CSD and the corresponding ground truth at the given index idx.
206
+
207
+ Args:
208
+ idx: The index of the CSD and ground truth to be returned.
209
+ """
210
+ if self.preload:
211
+ csd = self.__csds[idx]
212
+ try:
213
+ ground_truth = self.__ground_truths[idx]
214
+ except IndexError:
215
+ ground_truth = None
216
+ else:
217
+ # open the h5 file lazily in non-preloaded mode. We can't open it in __init__, because a non-preloaded Dataset
218
+ # used with multiple workers has to be pickled, and open HDF5 file handles cannot be pickled!
219
+ if not hasattr(self, "__h5_file"):
220
+ self.__h5_file = h5py.File(self.h5_path, mode="r")
221
+ if self.specific_ids is not None:
222
+ idx = self.specific_ids[idx]
223
+ # load data
224
+ csd = load_dataset(file=self.__h5_file, specific_ids=[idx], progress_bar=self.progress_bar).csds[0]
225
+ # preprocess data
226
+ if self.data_preprocessors is not None:
227
+ for processor in self.data_preprocessors:
228
+ csd = processor(csd)
229
+ # load ground truth (if load_ground_truth is None, the call raises TypeError and ground_truth stays None)
230
+ try:
231
+ ground_truth = \
232
+ self.load_ground_truth(file=self.__h5_file, specific_ids=[idx], progress_bar=self.progress_bar)[0]
233
+ # preprocess ground truth
234
+ if self.ground_truth_preprocessors is not None:
235
+ for processor in self.ground_truth_preprocessors:
236
+ ground_truth = processor(ground_truth)
237
+ except TypeError:
238
+ ground_truth = None
239
+ return self.format_output(csd=csd, ground_truth=ground_truth, idx=idx)
240
+
241
+ def __repr__(self):
242
+ return (f"{self.__class__.__name__}(\n"
243
+ f"\th5_path={self.h5_path},\n"
244
+ f"\tspecific_ids={self.specific_ids},\n"
245
+ f"\tload_ground_truth={self.load_ground_truth.__name__ if self.load_ground_truth is not None else None},\n"
246
+ f"\tdata_preprocessors=[{[', '.join([func.__name__ for func in self.data_preprocessors]) if self.data_preprocessors is not None else None][0]}],\n"
247
+ f"\tground_truth_preprocessors=[{[', '.join([func.__name__ for func in self.ground_truth_preprocessors]) if self.ground_truth_preprocessors is not None else None][0]}],\n"
248
+ f"\tformat_output={self.format_output.__name__},\n"
249
+ f"\tpreload={self.preload},\n"
250
+ f"\tprogress_bar={self.progress_bar}\n"
251
+ f")")
252
+
253
+ def __del__(self):
254
+ if hasattr(self, "__h5_file"):
255
+ self.__h5_file.close()
256
+
257
+
258
+ class SimcatsConcatDataset(ConcatDataset):
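+ """PyTorch ConcatDataset implementation for concatenating multiple SimCATS datasets (one SimcatsDataset per h5 file)."""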
259
+ def __init__(self,
260
+ h5_paths: List[str],
261
+ specific_ids: Union[List[Union[range, int, np.ndarray, None]], None] = None,
262
+ load_ground_truth: Union[Callable, str, None] = None,
263
+ data_preprocessors: Union[List[Union[str, Callable]], None] = None,
264
+ ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
265
+ format_output: Union[Callable, str, None] = None, preload: bool = True,
266
+ max_concurrent_preloads: int = 100000,
267
+ progress_bar: bool = False, ):
268
+ """Initializes an object for providing concatenated simcats_datasets data to pytorch.
269
+
270
+ Args:
271
+ h5_paths: The paths to the h5 files containing the datasets to be concatenated.
272
+ specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
273
+ sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
274
+ Expects a list of specific_id settings, with one entry for each provided h5_path. Default is None.
275
+ load_ground_truth: Defines the required type of ground truth data to be loaded. Accepts either a callable or
276
+ a string. Callables must be of the same structure/interface as load_zeros_masks defined in
277
+ simcats_datasets.loading.load_ground_truth. Strings must map to the function names of the
278
+ loading functions defined in simcats_datasets.loading.load_ground_truth. If this is None,
279
+ no ground truth is loaded, which restricts the possible output formats. Default is None. \n
280
+ Example of available types (**full list at simcats_datasets.loading.load_ground_truth**): \n
281
+ - **'tct_masks'**: The Total Charge Transition (TCT) mask generated by SimCATS.
282
+ - **'tc_region_masks'**: Regions with a fixed number of total charges.
283
+ - **'tc_region_minus_tct_masks'**: Regions with a fixed number of total charges, but with zeros between
284
+ the regions (at tcts).
285
+ data_preprocessors: Defines if data should be preprocessed. Accepts a list of callables or strings.
286
+ Callables must be of the same structure/interface as example_preprocessor defined in
287
+ simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
288
+ preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
289
+ Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
290
+ - **'min_max_0_1'**: Min max scaling of the data to [0, 1]
291
+ - **'standardization'**: Standardization of the data (mean=0, std=1)
292
+ - **'add_newaxis'**: Adds new axis as first axis (required for UNET)
293
+ ground_truth_preprocessors: Defines if ground truth should be preprocessed. Accepts a list of callables or
294
+ strings. Callables must be of the same structure/interface as example_preprocessor defined in
295
+ simcats_datasets.support_functions.data_preprocessing. Strings must map to the function names of the
296
+ preprocessors defined in simcats_datasets.support_functions.data_preprocessing. Default is None. \n
297
+ Example of available types (**full list at simcats_datasets.support_functions.data_preprocessing**): \n
298
+ - **'only_two_classes'**: Reduce the number of classes in a mask to 2 (set every pixel > 1 = 1)
299
+ format_output: Defines the required type of data format for the output. Accepts either a callable or a
300
+ string. Callables must be of the same structure/interface as format_dict_csd_float_ground_truth_long
301
+ defined in simcats_datasets.support_functions.pytorch_format_output. Strings must map to the function
302
+ names of the format functions defined in simcats_datasets.support_functions.pytorch_format_output. If
303
+ this is None, format_dict_csd_float_ground_truth_long is used, which returns the output as a dict
304
+ with entries 'csd' and 'ground_truth' of dtype float and long, respectively. Default is None. \n
305
+ Example of available types (**full list at simcats_datasets.support_functions.pytorch_format_output**): \n
306
+ - **'format_dict_csd_float_ground_truth_long'**: formats the output as dict with entries 'csd' and
307
+ 'ground_truth' of dtype float and long, respectively
308
+ preload: Enables preloading the whole dataset during the initialization (requires more RAM). Default is
309
+ True.
310
+ max_concurrent_preloads: Determines how many CSDs are concurrently loaded from the dataset during the
311
+ preload phase. This option only affects instances with preload = True. It allows preloading large
312
+ datasets (for which it might not be possible to load the whole dataset into memory at once), by
313
+ loading them step by step and for example converting the CSDs to float32 with a corresponding data
314
+ preprocessor. Default is 100,000.
315
+ progress_bar: Determines whether to display a progress bar while loading data. Default is False.
316
+ """
317
+ _datasets = list()
318
+ if specific_ids is not None and len(specific_ids) != len(h5_paths):
319
+ raise IndexError("Specific_ids were provided but with a different number of entries than h5_paths! If "
320
+ "specific_ids are provided they need to contain the same number of entries!")
321
+ for i, path in enumerate(h5_paths):
322
+ if specific_ids is not None and len(specific_ids) == len(h5_paths):
323
+ temp_specific_ids = specific_ids[i]
324
+ else:
325
+ temp_specific_ids = None
326
+ _datasets.append(
327
+ SimcatsDataset(h5_path=path, specific_ids=temp_specific_ids, load_ground_truth=load_ground_truth,
328
+ data_preprocessors=data_preprocessors,
329
+ ground_truth_preprocessors=ground_truth_preprocessors, format_output=format_output,
330
+ preload=preload, max_concurrent_preloads=max_concurrent_preloads,
331
+ progress_bar=progress_bar))
332
+ super().__init__(_datasets)
333
+ self.__h5_paths = h5_paths
334
+ self.__specific_ids = specific_ids
335
+ # set up the load ground truth function. Could be None, function referenced by string, or callable
336
+ if load_ground_truth is None:
337
+ self.__load_ground_truth = None
338
+ else:
339
+ if isinstance(load_ground_truth, str):
340
+ self.__load_ground_truth = getattr(simcats_datasets.loading.load_ground_truth, load_ground_truth)
341
+ else:
342
+ self.__load_ground_truth = load_ground_truth
343
+ # set up the data preprocessors. Could be None, functions referenced by strings, or callables
344
+ if data_preprocessors is None:
345
+ self.__data_preprocessors = data_preprocessors
346
+ else:
347
+ self.__data_preprocessors = [
348
+ i if not isinstance(i, str) else getattr(simcats_datasets.support_functions.data_preprocessing, i) for i
349
+ in data_preprocessors]
350
+ # set up the ground truth preprocessors. Could be None, functions referenced by strings, or callables
351
+ if ground_truth_preprocessors is None:
352
+ self.__ground_truth_preprocessors = ground_truth_preprocessors
353
+ else:
354
+ if self.load_ground_truth is None:
355
+ raise ValueError("If load_ground_truth is None, ground_truth_preprocessors should also be None")
356
+ self.__ground_truth_preprocessors = [
357
+ i if not isinstance(i, str) else getattr(simcats_datasets.support_functions.data_preprocessing, i) for i
358
+ in ground_truth_preprocessors]
359
+ # set up the output format function. Could be None, function referenced by string, or callable
360
+ if format_output is None:
361
+ self.__format_output = format_dict_csd_float_ground_truth_long
362
+ else:
363
+ if isinstance(format_output, str):
364
+ self.__format_output = getattr(simcats_datasets.support_functions.pytorch_format_output,
365
+ format_output, )
366
+ else:
367
+ self.__format_output = format_output
368
+ self.__preload = preload
369
+ self.__progress_bar = progress_bar
370
+ # get dataset shapes and check if all shapes are the same
371
+ shape = None
372
+ for dataset in _datasets:
373
+ if shape is None:
374
+ shape = dataset.shape[1:]
375
+ elif dataset.shape[1:] != shape:
376
+ raise ValueError(f"The shape of the SimcatsDataset CSDs should be identical but found shapes "
377
+ f"{[dataset.shape[1:] for dataset in _datasets]}")
378
+ self.__shape = (len(self), *shape)
379
+
380
+ @property
381
+ def h5_paths(self) -> List[str]:
382
+ return self.__h5_paths
383
+
384
+ @property
385
+ def specific_ids(self) -> Union[List[Union[range, List[int], np.ndarray, None]], None]:
386
+ return self.__specific_ids
387
+
388
+ @property
389
+ def load_ground_truth(self) -> Callable:
390
+ return self.__load_ground_truth
391
+
392
+ @property
393
+ def data_preprocessors(self) -> Union[List[Callable], None]:
394
+ return self.__data_preprocessors
395
+
396
+ @property
397
+ def ground_truth_preprocessors(self) -> Union[List[Callable], None]:
398
+ return self.__ground_truth_preprocessors
399
+
400
+ @property
401
+ def format_output(self) -> Callable:
402
+ return self.__format_output
403
+
404
+ @property
405
+ def preload(self) -> bool:
406
+ return self.__preload
407
+
408
+ @property
409
+ def progress_bar(self) -> bool:
410
+ return self.__progress_bar
411
+
412
+ @property
413
+ def shape(self) -> Tuple[int, ...]:
414
+ return self.__shape
415
+
416
+ def __repr__(self):
417
+ return (f"{self.__class__.__name__}(\n"
418
+ f"\th5_paths={self.h5_paths},\n"
419
+ f"\tspecific_ids={self.specific_ids},\n"
420
+ f"\tload_ground_truth={self.load_ground_truth.__name__},\n"
421
+ f"\tdata_preprocessors=[{[', '.join([func.__name__ for func in self.data_preprocessors]) if self.data_preprocessors is not None else None][0]}],\n"
422
+ f"\tground_truth_preprocessors=[{[', '.join([func.__name__ for func in self.ground_truth_preprocessors]) if self.ground_truth_preprocessors is not None else None][0]}],\n"
423
+ f"\tformat_output={self.format_output.__name__},\n"
424
+ f"\tpreload={self.preload},\n"
425
+ f"\tprogress_bar={self.progress_bar}\n"
426
+ f")")
@@ -0,0 +1 @@
1
+ """Module with support functions for different main functionalities."""
@@ -0,0 +1,51 @@
1
+ """JSON encoders that are required to encode numpy arrays and xarray DataArrays while saving datasets.
2
+
3
+ @author: f.hader
4
+ """
5
+
6
+ import json
7
+ import numpy as np
8
+ from xarray import DataArray
9
+
10
+
11
+ class MultipleJsonEncoders(json.JSONEncoder):
12
+ """Combine multiple JSON encoders"""
13
+
14
+ def __init__(self, *encoders):
15
+ super().__init__()
16
+ self.encoders = encoders
17
+ self.args = ()
18
+ self.kwargs = {}
19
+
20
+ def default(self, obj):
21
+ for encoder in self.encoders:
22
+ try:
23
+ return encoder(*self.args, **self.kwargs).default(obj)
24
+ except TypeError:
25
+ pass
26
+ raise TypeError(f'Object of type {obj.__class__.__name__} is not JSON serializable')
27
+
28
+ def __call__(self, *args, **kwargs):
29
+ self.args = args
30
+ self.kwargs = kwargs
31
+ enc = json.JSONEncoder(*args, **kwargs)
32
+ enc.default = self.default
33
+ return enc
34
+
35
+
36
+ class DataArrayEncoder(json.JSONEncoder):
37
+ """JSON encoder for DataArrays."""
38
+
39
+ def default(self, obj):
40
+ if isinstance(obj, DataArray):
41
+ return obj.to_dict()
42
+ return json.JSONEncoder.default(self, obj)
43
+
44
+
45
+ class NumpyEncoder(json.JSONEncoder):
46
+ """JSON encoder for numpy arrays."""
47
+
48
+ def default(self, obj):
49
+ if isinstance(obj, np.ndarray):
50
+ return obj.tolist()
51
+ return json.JSONEncoder.default(self, obj)
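Usage sketch (editorial addition): how the combined encoder above is intended to be passed to json.dumps; the payload is illustrative and assumes the three encoder classes above are in scope.

import json
import numpy as np
from xarray import DataArray

# MultipleJsonEncoders tries the given encoders in order, so both DataArrays and numpy arrays serialize
payload = {"csd": np.eye(2), "x_axis": DataArray(np.linspace(0.0, 1.0, 3))}
text = json.dumps(payload, cls=MultipleJsonEncoders(DataArrayEncoder, NumpyEncoder))
print(text)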
@@ -0,0 +1,191 @@
1
+ """Helper functions for clipping lines into rectangles (into the CSD space).
2
+
3
+ Used to clip single transition lines into the CSD space to generate transition-specific labels with the function
4
+ `get_lead_transition_labels` from `simcats_datasets.support_functions.get_lead_transition_labels`.
5
+
6
+ @author: f.hader
7
+ """
8
+
9
+ from typing import Tuple, List, Union
10
+
11
+
12
+ def clip_slope_line_to_rectangle(slope: float, point: Tuple[float, float], rect_corners: List[Tuple[float, float]],
13
+ is_start: bool = True) -> Union[Tuple[Tuple[float, float], Tuple[float, float]], None]:
14
+ """Clips a line segment with a given slope to a rectangle, extending it to either positive or negative infinity from the provided point.
15
+
16
+ Args:
17
+ slope: The slope of the line.
18
+ point: A tuple (x, y) representing the starting or ending point of the line segment.
19
+ rect_corners: A list of four tuples, each representing the corner points of the rectangle.
20
+ is_start: Specifies whether the line extends to positive infinity (True) or negative infinity (False) from the
21
+ provided point. Default is True (positive infinity).
22
+
23
+ Returns:
24
+ A tuple (start, end) representing the clipped line segment. Returns None if the line is entirely outside the
25
+ rectangle.
26
+
27
+ Notes:
28
+ - The function handles lines defined by a slope and a single point.
29
+ - The 'is_start' parameter determines whether the line extends to positive or negative infinity from the
30
+ provided point.
31
+ """
32
+ x_range = (min(rect_corners, key=lambda p: p[0])[0], max(rect_corners, key=lambda p: p[0])[0])
33
+
34
+ if is_start:
35
+ start = point
36
+ end = (x_range[1] + 1, point[1] + slope * (x_range[1] + 1 - point[0]))
37
+ else:
38
+ start = (x_range[0] - 1, point[1] - slope * (point[0] - x_range[0] + 1))
39
+ end = point
40
+
41
+ # Use the clip_point_line_to_rectangle function to clip the line
42
+ return clip_point_line_to_rectangle(start, end, rect_corners)
43
+
44
+
45
+ def clip_infinite_slope_line_to_rectangle(slope: float, point: Tuple[float, float],
46
+ rect_corners: List[Tuple[float, float]]) -> Union[
47
+ Tuple[Tuple[float, float], Tuple[float, float]], None]:
48
+ """Clips a line segment with a given slope to a rectangle, extending it to positive and negative infinity from the provided point.
49
+
50
+ Args:
51
+ slope: The slope of the line.
52
+ point: A tuple (x, y) representing a point on the line.
53
+ rect_corners: A list of four tuples, each representing the corner points of the rectangle.
54
+
55
+ Returns:
56
+ A tuple (start, end) representing the clipped line segment. Returns None if the line is entirely outside the
57
+ rectangle.
58
+
59
+ Notes:
60
+ - The function handles lines defined by a slope and a single point.
61
+ """
62
+ x_range = (min(rect_corners, key=lambda p: p[0])[0], max(rect_corners, key=lambda p: p[0])[0])
63
+
64
+ start = (x_range[0] - 1, point[1] - slope * (point[0] - x_range[0] + 1))
65
+ end = (x_range[1] + 1, point[1] + slope * (x_range[1] + 1 - point[0]))
66
+
67
+ # Use the clip_point_line_to_rectangle function to clip the line
68
+ return clip_point_line_to_rectangle(start, end, rect_corners)
69
+
70
+
71
+ def clip_point_line_to_rectangle(start: Tuple[float, float], end: Tuple[float, float],
72
+ rect_corners: List[Tuple[float, float]]) -> Union[
73
+ Tuple[Tuple[float, float], Tuple[float, float]], None]:
74
+ """Clips a line segment defined by its start and end points to a rectangle.
75
+
76
+ Args:
77
+ start: A tuple (x, y) representing the start point of the line.
78
+ end: A tuple (x, y) representing the end point of the line.
79
+ rect_corners: A list of four tuples, each representing the corner points of the rectangle.
80
+
81
+ Returns:
82
+ A tuple representing the clipped line segment (start, end) if any part of the line is inside the rectangle.
83
+ If the line is entirely outside the rectangle, both returned endpoints are None.
84
+
85
+ Notes:
86
+ - The function handles lines defined by two points.
87
+ - The function handles the case when the line is entirely inside the rectangle.
88
+ """
89
+ if all(is_point_inside_rectangle(point, rect_corners) for point in (start, end)):
90
+ # The entire line is inside the rectangle
91
+ return start, end
92
+
93
+ clipped_start = None
94
+ clipped_end = None
95
+
96
+ # Check if the start point is inside the rectangle
97
+ if is_point_inside_rectangle(start, rect_corners):
98
+ clipped_start = start
99
+
100
+ # Check if the end point is inside the rectangle
101
+ if is_point_inside_rectangle(end, rect_corners):
102
+ clipped_end = end
103
+
104
+ # Iterate through pairs of adjacent corner points to check for intersections
105
+ for rect_point1, rect_point2 in zip(rect_corners, rect_corners[1:] + [rect_corners[0]]):
106
+ # Calculate the intersection point between the line and the rectangle edge
107
+ intersection = line_intersection(start, end, rect_point1, rect_point2)
108
+ if intersection is not None:
109
+ if clipped_start is None:
110
+ clipped_start = intersection
111
+ elif clipped_end is None:
112
+ clipped_end = intersection
113
+ if clipped_start is not None and clipped_end is not None:
114
+ break
115
+
116
+ return clipped_start, clipped_end
117
+
118
+
119
+ def is_point_inside_rectangle(point: Tuple[float, float], rect_corners: List[Tuple[float, float]]) -> bool:
120
+ """Checks if a point is inside a rectangle defined by its corner points.
121
+
122
+ Args:
123
+ point: A tuple (x, y) representing the point to be checked.
124
+ rect_corners: A list of four tuples, each representing the corner points of the rectangle. It is assumed that
125
+ these are ordered so that they form a closed path around the rectangle. Thus, the first and third (or,
126
+ alternatively, the second and fourth) corners define the rectangle.
127
+
128
+ Returns:
129
+ True if the point is inside the rectangle, False otherwise.
130
+ """
131
+ x, y = point
132
+ x1, y1 = rect_corners[0]
133
+ x2, y2 = rect_corners[2]
134
+ return x1 <= x <= x2 and y1 <= y <= y2
135
+
136
+
137
+ def line_intersection(p1: Tuple[float, float], p2: Tuple[float, float], q1: Tuple[float, float],
138
+ q2: Tuple[float, float]) -> Union[Tuple[float, float], None]:
139
+ """Calculates the intersection point between two line segments defined by their respective endpoints.
140
+
141
+ Args:
142
+ p1: A tuple (x, y) representing the first endpoint of the first line.
143
+ p2: A tuple (x, y) representing the second endpoint of the first line.
144
+ q1: A tuple (x, y) representing the first endpoint of the second line.
145
+ q2: A tuple (x, y) representing the second endpoint of the second line.
146
+
147
+ Returns:
148
+ A tuple (x, y) representing the intersection point if the lines intersect. Returns None if the lines are
149
+ parallel or do not intersect.
150
+ """
151
+ x1, y1 = p1
152
+ x2, y2 = p2
153
+ x3, y3 = q1
154
+ x4, y4 = q2
155
+
156
+ # denominator
157
+ den = (x1 - x2) * (y3 - y4) - (y1 - y2) * (x3 - x4)
158
+ if den == 0:
159
+ return None # Lines are parallel or coincident
160
+
161
+ t = ((x1 - x3) * (y3 - y4) - (y1 - y3) * (x3 - x4)) / den
162
+ u = -((x1 - x2) * (y1 - y3) - (y1 - y2) * (x1 - x3)) / den
163
+
164
+ if 0 <= t <= 1 and 0 <= u <= 1:
165
+ intersection_x = x1 + t * (x2 - x1)
166
+ intersection_y = y1 + t * (y2 - y1)
167
+ return (intersection_x, intersection_y)
168
+ else:
169
+ return None
170
+
171
+
172
+ def create_rectangle_corners(x_range: Tuple[float, float], y_range: Tuple[float, float]) -> List[Tuple[float, float]]:
173
+ """Creates rectangle corner points that form a rectangle around the specified x and y value ranges.
174
+
175
+ Args:
176
+ x_range: A tuple (x_min, x_max) representing the minimum and maximum x values.
177
+ y_range: A tuple (y_min, y_max) representing the minimum and maximum y values.
178
+
179
+ Returns:
180
+ A list of four tuples, each representing the corner points of the rectangle in the following order: bottom-left,
181
+ bottom-right, top-right, top-left.
182
+ """
183
+ x_min, x_max = x_range
184
+ y_min, y_max = y_range
185
+
186
+ bottom_left = (x_min, y_min)
187
+ bottom_right = (x_max, y_min)
188
+ top_right = (x_max, y_max)
189
+ top_left = (x_min, y_max)
190
+
191
+ return [bottom_left, bottom_right, top_right, top_left]
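Usage sketch (editorial addition): composing the helpers above to clip a transition line into a CSD voltage window; the window and line parameters are illustrative.

# rectangle spanning an (illustrative) voltage window
corners = create_rectangle_corners(x_range=(0.0, 1.0), y_range=(0.0, 1.0))

# clip the infinite line with slope -1 through (0.3, 0.5) to that window
segment = clip_infinite_slope_line_to_rectangle(slope=-1.0, point=(0.3, 0.5), rect_corners=corners)
print(segment)  # approximately ((0.8, 0.0), (0.0, 0.8)): the points where the line crosses the window boundary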