simcats-datasets 2.5.0__py3-none-any.whl → 2.6.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -18,15 +18,23 @@ import numpy as np
18
18
 
19
19
  # imports required for eval, to create IdealCSDGeometric objects from metadata strings
20
20
  import re
21
-
22
21
  import sympy
23
22
  from simcats.ideal_csd import IdealCSDGeometric
24
23
  from simcats.ideal_csd.geometric import calculate_all_bezier_anchors, tct_bezier, initialize_tct_functions
24
+ from simcats.sensor import SensorScanSensorGeneric, SensorPeakLorentzian, SensorRiseGLF
25
+ from simcats.sensor.barrier_function import BarrierFunctionGLF
25
26
  from numpy import array
26
27
  from simcats.distortions import OccupationDotJumps
27
28
  from tqdm import tqdm
28
29
 
29
30
  from simcats_datasets.loading import load_dataset
31
+ from simcats_datasets.support_functions.convert_lines import (points_2d_voltage_to_1d_idx_space,
32
+ lines_voltage_to_pixel_space)
33
+ from simcats_datasets.support_functions.get_coulomb_oscillation_area_boundaries import (
34
+ get_coulomb_oscillation_area_boundaries,
35
+ calc_sensor_potential,
36
+ calc_fully_conductive_potential)
37
+ from simcats_datasets.support_functions.metadata_utils import reconstruct_metadata_types
30
38
  from simcats.support_functions import rotate_points
31
39
 
32
40
 
@@ -34,7 +42,9 @@ from simcats.support_functions import rotate_points
34
42
  _csd_ground_truths = ["load_zeros_masks", "load_tct_masks", "load_tct_by_dot_masks", "load_idt_masks", "load_ct_masks",
35
43
  "load_ct_by_dot_masks", "load_tc_region_masks", "load_tc_region_minus_tct_masks",
36
44
  "load_c_region_masks"]
37
- _sensor_scan_ground_truths = ["load_zeros_masks", "load_tct_masks"]
45
+ _sensor_scan_ground_truths = ["load_zeros_masks", "load_sensor_regime_masks", "load_sensor_peak_center_masks",
46
+ "load_sensor_oscillation_boundaries_masks",
47
+ "load_sensor_quadrilateral_oscillation_boundaries_masks"]
38
48
 
39
49
 
40
50
  def load_zeros_masks(file: Union[str, h5py.File],
@@ -491,3 +501,204 @@ def load_c_region_masks(file: Union[str, h5py.File],
491
501
  for ct_mask, c_region_mask in zip(ct_masks, c_region_masks):
492
502
  c_region_mask[c_region_mask == ct_mask] = 0
493
503
  return c_region_masks
504
+
505
+
506
def load_sensor_regime_masks(file: Union[str, h5py.File],
                             specific_ids: Union[range, List[int], np.ndarray, None] = None,
                             progress_bar: bool = True) -> List[np.ndarray]:
    """Load sensor regime masks as ground truth data for sensor scans.

    In the sensor regime mask, the non-conductive area is marked with 0, the oscillation area with 1 and the fully
    conductive area with 2.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Sensor regime masks
    """
    # delegate to the generic loader, requesting only the sensor regime masks
    loaded = load_dataset(file=file, load_csds=False, load_sensor_regime_masks=True,
                          specific_ids=specific_ids, progress_bar=progress_bar)
    return loaded.sensor_regime_masks
529
+
530
+
531
def load_sensor_peak_center_masks(file: Union[str, h5py.File],
                                  specific_ids: Union[range, List[int], np.ndarray, None] = None,
                                  progress_bar: bool = True) -> List[np.ndarray]:
    """Load sensor peak center masks as ground truth data for sensor scans.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Sensor peak center masks
    """
    # delegate to the generic loader, requesting only the peak center masks
    loaded = load_dataset(file=file, load_csds=False, load_sensor_peak_center_masks=True,
                          specific_ids=specific_ids, progress_bar=progress_bar)
    return loaded.sensor_peak_center_masks
551
+
552
+
553
def load_sensor_oscillation_boundaries_masks(file: Union[str, h5py.File],
                                             specific_ids: Union[range, List[int], np.ndarray, None] = None,
                                             progress_bar: bool = True) -> List[np.ndarray]:
    """Masks for the boundaries of the oscillation area with a triangular basic shape (in 2D).

    The oscillation area boundaries are marked with value 1 for pinch-off lines and 2 for the fully conductive line.
    The rest of the mask is 0. In 1D data the oscillation area boundaries are points, in 2D data lines. This method
    uses the metadata from each scan to calculate the fully conductive line in the sensor potential.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Oscillation area boundaries masks (triangular shape)
    """
    metadata = load_dataset(file=file, load_csds=False, load_sensor_scans=False, load_metadata=True,
                            specific_ids=specific_ids, progress_bar=progress_bar).metadata

    sensor_masks = []
    for raw_meta in tqdm(metadata, desc="calculating sensor regime boundaries masks", total=len(metadata),
                         disable=not progress_bar):
        # adapt metadata types if necessary (e.g. strings stored in the HDF5 file back to python objects)
        meta = reconstruct_metadata_types(raw_meta)
        boundary_coords, boundary_labels = get_coulomb_oscillation_area_boundaries(meta)

        if isinstance(meta["resolution"], int):  # 1D scan: boundaries are single points
            mask = np.zeros(meta["resolution"], dtype=np.uint8)
            if boundary_coords.ndim == 2:  # otherwise no lines lie inside the scan area
                idx_points = points_2d_voltage_to_1d_idx_space(coordinates=boundary_coords[:, :2],
                                                               voltage_range_x=meta["sweep_range_sensor_g1"],
                                                               voltage_range_y=meta["sweep_range_sensor_g2"],
                                                               resolution=meta["resolution"],
                                                               round_to_int=True)
                for label, point in zip(boundary_labels, idx_points):
                    if "start" in label or "pinch_off" in label:
                        mask[point] = 1  # pinch-off point
                    elif label == "sensor_potential_fully_conductive":
                        mask[point] = 2  # fully-conductive point
                    # quadrangular fully conductive lines are not required here
        else:  # 2D scan (resolution is a list): boundaries are lines
            mask = np.zeros((meta["resolution"][1], meta["resolution"][0]), dtype=np.uint8)
            if boundary_coords.ndim == 2:  # otherwise no lines lie inside the scan area
                pixel_lines = lines_voltage_to_pixel_space(lines=boundary_coords,
                                                           voltage_range_x=np.array(meta["sweep_range_sensor_g1"]),
                                                           voltage_range_y=np.array(meta["sweep_range_sensor_g2"]),
                                                           image_width=meta["resolution"][0],
                                                           image_height=meta["resolution"][1],
                                                           round_to_int=True)
                for label, line in zip(boundary_labels, pixel_lines):
                    # rasterize with one sample per pixel along the longer axis of the line
                    n_samples = max(np.abs(line[2] - line[0]), np.abs(line[3] - line[1]))
                    if n_samples < meta["resolution"][0] and n_samples < meta["resolution"][1]:
                        n_samples += 1
                    xs = np.linspace(line[0], line[2], n_samples, dtype=int)
                    ys = np.linspace(line[1], line[3], n_samples, dtype=int)
                    if "start" in label or "pinch_off" in label:
                        mask[ys, xs] = 1  # pinch-off line
                    elif label == "sensor_potential_fully_conductive":
                        mask[ys, xs] = 2  # fully-conductive line
                    # quadrangular fully conductive lines are not required here
        sensor_masks.append(mask)

    return sensor_masks
630
+
631
+
632
def load_sensor_quadrilateral_oscillation_boundaries_masks(file: Union[str, h5py.File],
                                                           specific_ids: Union[range, List[int], np.ndarray, None] = None,
                                                           progress_bar: bool = True) -> List[np.ndarray]:
    """Load quadrilateral oscillation area boundaries masks as ground truth data for sensor scans.

    Unlike the standard triangular masks from load_sensor_oscillation_boundaries_masks, this method does not combine
    the sensor dot potential with the barriers to determine full conductivity. Instead, it treats the two barriers
    separately and draws fully conductive lines where each barrier becomes fully conductive (or effectively vanishes).

    The oscillation area boundaries are marked with the values 1 for pinch-off lines and 2 for fully conductive lines.
    The rest of the mask is 0. In 1D data the oscillation area boundaries are points, in 2D data lines.

    Args:
        file: The file to read the data from. Can either be an object of the type `h5py.File` or the path to the
            dataset. If you want to do multiple consecutive loads from the same file (e.g. for using the PyTorch
            SimcatsDataset without preloading), consider initializing the file object yourself and passing it, to
            improve the performance.
        specific_ids: Determines if only specific ids should be loaded. Using this option, the returned values are
            sorted according to the specified ids and not necessarily ascending. If set to None, all data is loaded.
            Default is None.
        progress_bar: Determines whether to display a progress bar. Default is True.

    Returns:
        Oscillation area boundaries masks
    """
    # load metadata, line coordinates and line labels in a single pass instead of reading the file three times
    loaded = load_dataset(file=file, load_csds=False, load_sensor_scans=False, load_metadata=True,
                          load_line_coords=True, load_line_labels=True, specific_ids=specific_ids,
                          progress_bar=progress_bar)
    metadata = loaded.metadata
    coordinates = loaded.line_coordinates
    labels = loaded.line_labels

    oscillation_boundaries_masks = []
    for meta, coords, label in tqdm(zip(metadata, coordinates, labels), desc="calculating oscillation boundary masks",
                                    total=len(metadata), disable=not progress_bar):
        # the raw metadata stores the 1D resolution as a string (no reconstruct_metadata_types applied here)
        if isinstance(meta["resolution"], str):  # 1D case
            oscillation_mask = np.zeros(int(meta["resolution"]), dtype=np.uint8)
            if coords.ndim == 2:  # guard against scans without any lines (coords would be empty)
                coordinate_indices = points_2d_voltage_to_1d_idx_space(coordinates=coords[:, :2],
                                                                       voltage_range_x=meta["sweep_range_sensor_g1"],
                                                                       voltage_range_y=meta["sweep_range_sensor_g2"],
                                                                       resolution=int(meta["resolution"]),
                                                                       round_to_int=True)
                for i, point in enumerate(coordinate_indices):
                    if "start" in label[i] or "pinch_off" in label[i]:  # pinch-off line
                        oscillation_mask[point] = 1
                    elif "stop" in label[i] or "fully_conductive" in label[i]:  # fully-conductive line
                        oscillation_mask[point] = 2
        else:  # if isinstance(meta["resolution"], list) (2D case)
            oscillation_mask = np.zeros((meta["resolution"][1], meta["resolution"][0]), dtype=np.uint8)
            if coords.ndim == 2:  # guard against scans without any lines (coords would be empty)
                coordinate_pixels = lines_voltage_to_pixel_space(lines=coords,
                                                                 voltage_range_x=np.array(meta["sweep_range_sensor_g1"]),
                                                                 voltage_range_y=np.array(meta["sweep_range_sensor_g2"]),
                                                                 image_width=meta["resolution"][0],
                                                                 image_height=meta["resolution"][1],
                                                                 round_to_int=True)
                for i, line in enumerate(coordinate_pixels):
                    # rasterize with one sample per pixel along the longer axis of the line
                    max_res = max(np.abs(line[2] - line[0]), np.abs(line[3] - line[1]))
                    if max_res < meta["resolution"][0] and max_res < meta["resolution"][1]:
                        max_res += 1
                    x_coords = np.linspace(line[0], line[2], max_res, dtype=int)
                    y_coords = np.linspace(line[1], line[3], max_res, dtype=int)

                    if "start" in label[i] or "pinch_off" in label[i]:  # pinch-off line
                        oscillation_mask[y_coords, x_coords] = 1
                    # "potential" labels belong to the triangular (sensor potential) variant and are skipped here
                    elif "stop" in label[i] or ("fully_conductive" in label[i] and "potential" not in label[i]):
                        oscillation_mask[y_coords, x_coords] = 2

        oscillation_boundaries_masks.append(oscillation_mask)

    return oscillation_boundaries_masks
@@ -26,7 +26,8 @@ class SimcatsDataset(Dataset):
26
26
  load_ground_truth: Union[Callable, str, None] = None,
27
27
  data_preprocessors: Union[List[Union[str, Callable]], None] = None,
28
28
  ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
29
- format_output: Union[Callable, str, None] = None, preload: bool = True,
29
+ format_output: Union[Callable, str, None] = None,
30
+ preload: bool = True,
30
31
  max_concurrent_preloads: int = 100000,
31
32
  progress_bar: bool = False,
32
33
  sensor_scan_dataset: bool = False,):
@@ -95,6 +96,8 @@ class SimcatsDataset(Dataset):
95
96
  # check if it is possible to load the desired ground truth from the given dataset
96
97
  try:
97
98
  _ = self.load_ground_truth(file=self.__h5_path, specific_ids=[0], progress_bar=False)
99
+ except FileNotFoundError:
100
+ raise
98
101
  except:
99
102
  raise ValueError(
100
103
  f"The specified ground truth ({self.load_ground_truth.__name__}) can't be loaded for the given "
@@ -289,13 +292,17 @@ class SimcatsDataset(Dataset):
289
292
 
290
293
 
291
294
  class SimcatsConcatDataset(ConcatDataset):
295
+ """Pytorch ConcatDataset class implementation for SimCATS datasets. Uses simcats_datasets to load and provide (training) data.
296
+ """
297
+
292
298
  def __init__(self,
293
299
  h5_paths: List[str],
294
300
  specific_ids: Union[List[Union[range, int, np.ndarray, None]], None] = None,
295
301
  load_ground_truth: Union[Callable, str, None] = None,
296
302
  data_preprocessors: Union[List[Union[str, Callable]], None] = None,
297
303
  ground_truth_preprocessors: Union[List[Union[str, Callable]], None] = None,
298
- format_output: Union[Callable, str, None] = None, preload: bool = True,
304
+ format_output: Union[Callable, str, None] = None,
305
+ preload: bool = True,
299
306
  max_concurrent_preloads: int = 100000,
300
307
  progress_bar: bool = False,
301
308
  sensor_scan_dataset: bool = False,):
@@ -7,6 +7,7 @@ Used to clip single transition lines into the CSD space to generate transition s
7
7
  """
8
8
 
9
9
  from typing import Tuple, List, Union
10
+ import math
10
11
 
11
12
 
12
13
  def clip_slope_line_to_rectangle(slope: float, point: Tuple[float, float], rect_corners: List[Tuple[float, float]],
@@ -106,12 +107,22 @@ def clip_point_line_to_rectangle(start: Tuple[float, float], end: Tuple[float, f
106
107
  # Calculate the intersection point between the line and the rectangle edge
107
108
  intersection = line_intersection(start, end, rect_point1, rect_point2)
108
109
  if intersection is not None:
109
- if clipped_start is None:
110
+ # set the clipped start point to the newly found intersection, if either no clipped_start was found so far
111
+ # or if the current value is the same as the initially supplied end point. This happens, if the end point
112
+ # was exactly on the boundary and therefore also caused an intersection.
113
+ if (clipped_start is None
114
+ or (clipped_end is not None
115
+ and math.isclose(clipped_start[0], end[0])
116
+ and math.isclose(clipped_start[1], end[1]))):
110
117
  clipped_start = intersection
111
- elif clipped_end is None:
118
+ # set the clipped end point to the newly found intersection, if either no clipped_end was found so far or
119
+ # if the current value is the same as the initially supplied start point. This happens, if the start point
120
+ # was exactly on the boundary and therefore also caused an intersection.
121
+ elif (clipped_end is None
122
+ or (clipped_start is not None
123
+ and math.isclose(clipped_end[0], start[0])
124
+ and math.isclose(clipped_end[1], start[1]))):
112
125
  clipped_end = intersection
113
- if clipped_start is not None and clipped_end is not None:
114
- break
115
126
 
116
127
  return clipped_start, clipped_end
117
128
 
@@ -4,6 +4,7 @@
4
4
  """
5
5
 
6
6
  from copy import deepcopy
7
+ import math
7
8
  from typing import List, Union
8
9
 
9
10
  import numpy as np
@@ -108,3 +109,36 @@ def lines_convert_two_coordinates_to_coordinate_plus_change(lines: Union[List[np
108
109
  else:
109
110
  new_lines_pairs.append([p2[0], p2[1], p1[0] - p2[0], p1[1] - p2[1]])
110
111
  return np.array(new_lines_pairs)
112
+
113
+
114
def points_2d_voltage_to_1d_idx_space(coordinates: Union[List[np.ndarray], np.ndarray],
                                      voltage_range_x: np.ndarray,
                                      voltage_range_y: np.ndarray,
                                      resolution: int,
                                      round_to_int: bool = False) -> np.ndarray:
    """Convert points from 2D voltage space to 1D array/index space.

    Args:
        coordinates: Array or list of points to convert, shape: (n, 2). \n
            Example: \n
            [[x, y], ...]
        voltage_range_x: Voltage range in x direction.
        voltage_range_y: Voltage range in y direction.
        resolution: Size of the index space.
        round_to_int: Toggles if the points are returned as floats (False) or are rounded and then returned as integers
            (True). Defaults to false.

    Returns:
        Array with rows containing the converted points.
    """
    points = np.array(coordinates)
    # a (numerically) constant x range means the sweep runs along the y axis (vertical scan)
    if math.isclose(voltage_range_x[0], voltage_range_x[1]):
        values, sweep_range = points[:, 1], voltage_range_y
    else:
        values, sweep_range = points[:, 0], voltage_range_x

    # linear map of the swept voltage axis onto [0, resolution - 1]
    coord_idx = (resolution - 1) * (values - sweep_range[0]) / (sweep_range[1] - sweep_range[0])

    if round_to_int:
        coord_idx = coord_idx.round(decimals=0).astype(int)

    return coord_idx
@@ -13,6 +13,7 @@ import cv2
13
13
  import skimage.restoration
14
14
  import bm3d
15
15
  from scipy.signal import resample, decimate
16
+ import warnings
16
17
 
17
18
 
18
19
  def example_preprocessor(data: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List[np.ndarray]]:
@@ -49,6 +50,40 @@ def cast_to_float32(data: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarr
49
50
  return data
50
51
 
51
52
 
53
def quantize_to_int_bit_depth(data: Union[np.ndarray, List[np.ndarray]],
                              bit_depth: int,
                              min_float: float,
                              max_float: float,
                              dtype: np.dtype = np.int_) -> Union[np.ndarray, List[np.ndarray]]:
    """Quantize the data to unsigned integers of the specified bit depth.

    This is especially useful to test algorithms that would rely on low resolution measurements, for example, if an
    efficient readout is implemented in the cryostat. The data will be quantized by linearly mapping the float range
    defined by the supplied min and max values to an integer with the specified bit depth. This allows to use the
    whole integer bit range for the signal, assuming that the readout is precisely calibrated for this range.

    Args:
        data: Numpy array to be quantized (or a list of such). The input array is not modified.
        bit_depth: Number of bits to quantize the data to.
        min_float: Minimum expected value of the float numbers, that will be mapped to the minimum of the integer (0).
        max_float: Maximum expected value of the float numbers, that will be mapped to the maximum of the integer
            (2^(bit_depth)-1).
        dtype: Data type of the output array. Defaults to np.int_.

    Returns:
        Integer numpy array (or a list of such).

    Raises:
        ValueError: If any value in data lies outside the interval [min_float, max_float].
    """
    if isinstance(data, list):
        # bug fix: the recursive call previously used the non-existent keywords `min`/`max` (TypeError on every
        # list input) and silently dropped the `dtype` argument
        return [quantize_to_int_bit_depth(temp_data, bit_depth=bit_depth, min_float=min_float,
                                          max_float=max_float, dtype=dtype)
                for temp_data in data]
    # raise instead of assert, so the validation survives `python -O`
    if np.min(data) < min_float or np.max(data) > max_float:
        raise ValueError(
            f"Data values must lie within [min_float={min_float}, max_float={max_float}] for quantization.")
    # work on a float copy so integer inputs are supported and the caller's array is not mutated in place
    scaled = (np.asarray(data, dtype=float) - min_float) / (max_float - min_float)
    scaled *= 2 ** bit_depth - 1
    return np.round(scaled).astype(dtype)
85
+
86
+
52
87
  def cast_to_float16(data: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List[np.ndarray]]:
53
88
  """Cast the data to float16. Especially useful to reduce memory usage for preloaded datasets.
54
89
 
@@ -108,6 +143,69 @@ def min_max_0_1(data: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray,
108
143
  return data
109
144
 
110
145
 
146
def min_max_0_1_global(data: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List[np.ndarray]]:
    """Global min max scaling of the data to [0, 1] across the entire data set.

    If a list of data arrays is provided, the global minimum and maximum are computed across all entries. These values
    are then passed to min_max_0_1_given_limits to perform the scaling (which happens in place).

    Args:
        data: Numpy array to be scaled (or a list of such).

    Returns:
        Rescaled numpy array (or a list of such).
    """
    if isinstance(data, list):
        # determine the limits across all arrays, then scale every array with these shared limits
        global_min = min(np.min(arr) for arr in data)
        global_max = max(np.max(arr) for arr in data)
        for arr in data:
            min_max_0_1_given_limits(arr, minimum=global_min, maximum=global_max)
    else:
        min_max_0_1_given_limits(data, minimum=np.min(data), maximum=np.max(data))
    return data
172
+
173
+
174
def min_max_0_1_given_limits(data: Union[np.ndarray, List[np.ndarray]], minimum: float, maximum: float) -> Union[np.ndarray, List[np.ndarray]]:
    """Min max scaling of the data to [0, 1] using the given limits minimum and maximum.

    If a list of data is passed, the data is scaled using the given limits minimum and maximum. If the specified limits
    of the min max scaling do not match the data and therefore scaled values lie outside the interval [0, 1], a warning
    is generated.

    Args:
        data: Numpy array to be scaled (or a list of such).
        minimum: Minimum value used for scaling.
        maximum: Maximum value used for scaling.

    Returns:
        Rescaled numpy array (or a list of such).
    """
    span = maximum - minimum
    # treat the single-array case as a one-element list, so both branches share the same code path
    arrays = data if isinstance(data, list) else [data]
    for arr in arrays:
        # scale in place so callers holding references to the arrays (e.g. min_max_0_1_global) see the update
        arr -= minimum
        arr /= span
        if np.any(arr < 0.0) or np.any(arr > 1.0):
            warnings.warn(
                f"The specified limits minimum={minimum} and maximum={maximum} do not match the data. "
                f"At least one scaled value is outside the interval [0, 1].")
    return data
207
+
208
+
111
209
  def min_max_minus_one_one(data: Union[np.ndarray, List[np.ndarray]]) -> Union[np.ndarray, List[np.ndarray]]:
112
210
  """Min max scaling of the data to [-1, 1].
113
211
 
@@ -223,7 +321,7 @@ def resample_image(data: Union[np.ndarray, List[np.ndarray]], target_size: Tuple
223
321
  The resampled image or a list of such.
224
322
  """
225
323
  if isinstance(data, list):
226
- data = [resample_image(temp_data) for temp_data in data]
324
+ data = [resample_image(data=temp_data, target_size=target_size) for temp_data in data]
227
325
  else:
228
326
  if data.shape[0] > target_size[0]:
229
327
  data = resample(data, target_size[0], axis=0)
@@ -232,6 +330,19 @@ def resample_image(data: Union[np.ndarray, List[np.ndarray]], target_size: Tuple
232
330
  return data
233
331
 
234
332
 
333
def resample_image_to_32x32(data: Union[np.ndarray, List[np.ndarray]]) -> Union[
    np.ndarray, List[np.ndarray]]:
    """Resample an image to size of 32x32 using scipy.signal.resample.

    Convenience wrapper around resample_image with a fixed target size, usable where a single-argument
    preprocessor callable is required.

    Args:
        data: The image to resample.

    Returns:
        The resampled image or a list of such.
    """
    target = (32, 32)
    return resample_image(data=data, target_size=target)
344
+
345
+
235
346
  def decimate_image(data: Union[np.ndarray, List[np.ndarray]], target_size: Tuple[int, int]) -> Union[
236
347
  np.ndarray, List[np.ndarray]]:
237
348
  """Decimate an image to target size using scipy.signal.decimate.