careamics 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of careamics might be problematic. Click here for more details.
- careamics/__init__.py +6 -1
- careamics/careamist.py +729 -0
- careamics/config/__init__.py +39 -0
- careamics/config/architectures/__init__.py +17 -0
- careamics/config/architectures/architecture_model.py +37 -0
- careamics/config/architectures/custom_model.py +162 -0
- careamics/config/architectures/lvae_model.py +174 -0
- careamics/config/architectures/register_model.py +103 -0
- careamics/config/architectures/unet_model.py +118 -0
- careamics/config/callback_model.py +123 -0
- careamics/config/configuration_factory.py +583 -0
- careamics/config/configuration_model.py +604 -0
- careamics/config/data_model.py +527 -0
- careamics/config/fcn_algorithm_model.py +147 -0
- careamics/config/inference_model.py +239 -0
- careamics/config/likelihood_model.py +43 -0
- careamics/config/nm_model.py +101 -0
- careamics/config/optimizer_models.py +187 -0
- careamics/config/references/__init__.py +45 -0
- careamics/config/references/algorithm_descriptions.py +132 -0
- careamics/config/references/references.py +39 -0
- careamics/config/support/__init__.py +31 -0
- careamics/config/support/supported_activations.py +27 -0
- careamics/config/support/supported_algorithms.py +33 -0
- careamics/config/support/supported_architectures.py +17 -0
- careamics/config/support/supported_data.py +109 -0
- careamics/config/support/supported_loggers.py +10 -0
- careamics/config/support/supported_losses.py +29 -0
- careamics/config/support/supported_optimizers.py +57 -0
- careamics/config/support/supported_pixel_manipulations.py +15 -0
- careamics/config/support/supported_struct_axis.py +21 -0
- careamics/config/support/supported_transforms.py +11 -0
- careamics/config/tile_information.py +65 -0
- careamics/config/training_model.py +72 -0
- careamics/config/transformations/__init__.py +15 -0
- careamics/config/transformations/n2v_manipulate_model.py +64 -0
- careamics/config/transformations/normalize_model.py +60 -0
- careamics/config/transformations/transform_model.py +45 -0
- careamics/config/transformations/xy_flip_model.py +43 -0
- careamics/config/transformations/xy_random_rotate90_model.py +35 -0
- careamics/config/vae_algorithm_model.py +171 -0
- careamics/config/validators/__init__.py +5 -0
- careamics/config/validators/validator_utils.py +101 -0
- careamics/conftest.py +39 -0
- careamics/dataset/__init__.py +17 -0
- careamics/dataset/dataset_utils/__init__.py +19 -0
- careamics/dataset/dataset_utils/dataset_utils.py +101 -0
- careamics/dataset/dataset_utils/file_utils.py +141 -0
- careamics/dataset/dataset_utils/iterate_over_files.py +83 -0
- careamics/dataset/dataset_utils/running_stats.py +186 -0
- careamics/dataset/in_memory_dataset.py +310 -0
- careamics/dataset/in_memory_pred_dataset.py +88 -0
- careamics/dataset/in_memory_tiled_pred_dataset.py +129 -0
- careamics/dataset/iterable_dataset.py +295 -0
- careamics/dataset/iterable_pred_dataset.py +122 -0
- careamics/dataset/iterable_tiled_pred_dataset.py +140 -0
- careamics/dataset/patching/__init__.py +1 -0
- careamics/dataset/patching/patching.py +299 -0
- careamics/dataset/patching/random_patching.py +201 -0
- careamics/dataset/patching/sequential_patching.py +212 -0
- careamics/dataset/patching/validate_patch_dimension.py +64 -0
- careamics/dataset/tiling/__init__.py +10 -0
- careamics/dataset/tiling/collate_tiles.py +33 -0
- careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
- careamics/dataset/tiling/tiled_patching.py +164 -0
- careamics/dataset/zarr_dataset.py +151 -0
- careamics/file_io/__init__.py +15 -0
- careamics/file_io/read/__init__.py +12 -0
- careamics/file_io/read/get_func.py +56 -0
- careamics/file_io/read/tiff.py +58 -0
- careamics/file_io/read/zarr.py +60 -0
- careamics/file_io/write/__init__.py +15 -0
- careamics/file_io/write/get_func.py +63 -0
- careamics/file_io/write/tiff.py +40 -0
- careamics/lightning/__init__.py +18 -0
- careamics/lightning/callbacks/__init__.py +11 -0
- careamics/lightning/callbacks/hyperparameters_callback.py +49 -0
- careamics/lightning/callbacks/prediction_writer_callback/__init__.py +20 -0
- careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +56 -0
- careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +233 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +398 -0
- careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +215 -0
- careamics/lightning/callbacks/progress_bar_callback.py +90 -0
- careamics/lightning/lightning_module.py +632 -0
- careamics/lightning/predict_data_module.py +333 -0
- careamics/lightning/train_data_module.py +680 -0
- careamics/losses/__init__.py +15 -0
- careamics/losses/fcn/__init__.py +1 -0
- careamics/losses/fcn/losses.py +98 -0
- careamics/losses/loss_factory.py +155 -0
- careamics/losses/lvae/__init__.py +1 -0
- careamics/losses/lvae/loss_utils.py +83 -0
- careamics/losses/lvae/losses.py +445 -0
- careamics/lvae_training/__init__.py +0 -0
- careamics/lvae_training/dataset/__init__.py +0 -0
- careamics/lvae_training/dataset/data_utils.py +701 -0
- careamics/lvae_training/dataset/lc_dataset.py +259 -0
- careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
- careamics/lvae_training/dataset/vae_data_config.py +179 -0
- careamics/lvae_training/dataset/vae_dataset.py +1054 -0
- careamics/lvae_training/eval_utils.py +905 -0
- careamics/lvae_training/get_config.py +84 -0
- careamics/lvae_training/lightning_module.py +701 -0
- careamics/lvae_training/metrics.py +214 -0
- careamics/lvae_training/train_lvae.py +342 -0
- careamics/lvae_training/train_utils.py +121 -0
- careamics/model_io/__init__.py +7 -0
- careamics/model_io/bioimage/__init__.py +11 -0
- careamics/model_io/bioimage/_readme_factory.py +121 -0
- careamics/model_io/bioimage/bioimage_utils.py +52 -0
- careamics/model_io/bioimage/model_description.py +327 -0
- careamics/model_io/bmz_io.py +246 -0
- careamics/model_io/model_io_utils.py +95 -0
- careamics/models/__init__.py +5 -0
- careamics/models/activation.py +39 -0
- careamics/models/layers.py +493 -0
- careamics/models/lvae/__init__.py +3 -0
- careamics/models/lvae/layers.py +1998 -0
- careamics/models/lvae/likelihoods.py +364 -0
- careamics/models/lvae/lvae.py +901 -0
- careamics/models/lvae/noise_models.py +541 -0
- careamics/models/lvae/utils.py +395 -0
- careamics/models/model_factory.py +67 -0
- careamics/models/unet.py +443 -0
- careamics/prediction_utils/__init__.py +10 -0
- careamics/prediction_utils/lvae_prediction.py +158 -0
- careamics/prediction_utils/lvae_tiling_manager.py +362 -0
- careamics/prediction_utils/prediction_outputs.py +135 -0
- careamics/prediction_utils/stitch_prediction.py +112 -0
- careamics/transforms/__init__.py +20 -0
- careamics/transforms/compose.py +107 -0
- careamics/transforms/n2v_manipulate.py +146 -0
- careamics/transforms/normalize.py +243 -0
- careamics/transforms/pixel_manipulation.py +407 -0
- careamics/transforms/struct_mask_parameters.py +20 -0
- careamics/transforms/transform.py +24 -0
- careamics/transforms/tta.py +88 -0
- careamics/transforms/xy_flip.py +123 -0
- careamics/transforms/xy_random_rotate90.py +101 -0
- careamics/utils/__init__.py +19 -0
- careamics/utils/autocorrelation.py +40 -0
- careamics/utils/base_enum.py +60 -0
- careamics/utils/context.py +66 -0
- careamics/utils/logging.py +322 -0
- careamics/utils/metrics.py +188 -0
- careamics/utils/path_utils.py +26 -0
- careamics/utils/ram.py +15 -0
- careamics/utils/receptive_field.py +108 -0
- careamics/utils/torch_utils.py +127 -0
- careamics-0.0.3.dist-info/METADATA +78 -0
- careamics-0.0.3.dist-info/RECORD +154 -0
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/WHEEL +1 -1
- {careamics-0.0.1.dist-info → careamics-0.0.3.dist-info}/licenses/LICENSE +1 -1
- careamics-0.0.1.dist-info/METADATA +0 -46
- careamics-0.0.1.dist-info/RECORD +0 -6
|
@@ -0,0 +1,701 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utility functions needed by dataloader & co.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import os
|
|
6
|
+
from dataclasses import dataclass
|
|
7
|
+
from typing import List
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
from skimage.io import imread, imsave
|
|
11
|
+
|
|
12
|
+
from careamics.lvae_training.dataset.vae_data_config import DataSplitType, DataType
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def load_tiff(path):
    """
    Read the TIFF file at `path` with skimage's `imread` (`tifffile` plugin).

    The array is returned exactly as `imread` produced it.

    NOTE(review): the previous docstring promised a 4d array laid out as
    num_imgs*h*w*num_channels, but nothing here enforces a rank — confirm
    against the callers before relying on it.
    """
    return imread(path, plugin="tifffile")
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def save_tiff(path, data):
    """Write `data` to `path` with skimage's `imsave` (`tifffile` plugin)."""
    imsave(path, data, plugin="tifffile")
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def load_tiffs(paths):
    """
    Load every file in `paths` with `load_tiff` and join the arrays on axis 0.

    The files are read in the order given, so the result keeps that order
    along its first axis.
    """
    return np.concatenate([load_tiff(single_path) for single_path in paths], axis=0)
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def split_in_half(s, e):
    """
    Split the index range [s, e) into two consecutive lists of indices.

    The first list receives the first (e - s) // 2 indices and the second
    list the remaining ones, so the second list gets the extra element when
    the range has an odd length.
    """
    length = e - s
    middle = length // 2
    # Build both halves relative to zero, then shift them by the start index.
    lower_half = np.arange(middle) + s
    upper_half = np.arange(middle, length) + s
    return list(lower_half), list(upper_half)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def adjust_for_imbalance_in_fraction_value(
    val: List[int],
    test: List[int],
    val_fraction: float,
    test_fraction: float,
    total_size: int,
):
    """
    Rebalance validation and test indices according to their fractions.

    Here, val and test are divided almost equally. We take into account their
    respective fractions and pick elements randomly (with a fixed seed) from
    one list and put them in the other list.

    Parameters
    ----------
    val, test:
        Index lists to rebalance. They are not modified; new lists are returned.
    val_fraction, test_fraction:
        Requested fractions of the whole dataset for validation and test.
    total_size:
        Number of samples the fractions refer to.

    Returns
    -------
    (val, test):
        The rebalanced index lists. When one fraction is 0, all indices go to
        the other split. When the fractions are equal, both lists are returned
        unchanged.
    """
    # Work on copies: the earlier implementation used `+=` on the arguments,
    # which silently extended the caller's lists in place.
    val = list(val)
    test = list(test)

    if val_fraction == 0:
        return [], test + val
    if test_fraction == 0:
        return val + test, []

    diff_fraction = test_fraction - val_fraction
    if diff_fraction > 0:
        # Test is under-represented: move a seeded random subset of val into it.
        imb_count = int(diff_fraction * total_size / 2)
        val = list(np.random.RandomState(seed=955).permutation(val))
        test += val[:imb_count]
        val = val[imb_count:]
    elif diff_fraction < 0:
        # Val is under-represented: move a seeded random subset of test into it.
        imb_count = int(-1 * diff_fraction * total_size / 2)
        test = list(np.random.RandomState(seed=955).permutation(test))
        val += test[:imb_count]
        test = test[imb_count:]
    return val, test
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_train_val_data(
    data_config,
    fpath,
    datasplit_type: DataSplitType,
    val_fraction=None,
    test_fraction=None,
    allow_generation=False,  # TODO: what is this
):
    """
    Load the data from the given path and split them in training, validation and test sets.

    Ensure that the shape of data should be N*H*W*C: N is number of data points. H,W are the image dimensions.
    C is the number of channels.

    Parameters
    ----------
    data_config:
        Configuration object. This function reads `data_type`, the
        `ch<k>_fname` attributes, `num_channels` (MRC branch) and, when
        present, `ch_input_fname` (TIFF branch).
    fpath:
        Directory that contains the channel files.
    datasplit_type:
        Which split to return (`All`, `Train`, `Val` or `Test`).
    val_fraction, test_fraction:
        Forwarded to `get_datasplit_tuples` together with `len(data)`.
    allow_generation:
        Not used anywhere in this function.

    Returns
    -------
    The selected split cast to float32. The function falls through and
    returns None when `data_type` is neither `SeparateTiffData` nor
    `BioSR_MRC`, or when `datasplit_type` matches none of the four handled
    values.
    """
    if data_config.data_type == DataType.SeparateTiffData:
        fpath1 = os.path.join(fpath, data_config.ch1_fname)
        fpath2 = os.path.join(fpath, data_config.ch2_fname)
        fpaths = [fpath1, fpath2]
        fpath0 = ""
        # An optional dedicated input channel is placed before the two target channels.
        if "ch_input_fname" in data_config:
            fpath0 = os.path.join(fpath, data_config.ch_input_fname)
            fpaths = [fpath0] + fpaths

        print(
            f"Loading from {fpath} Channels: "
            f"{fpath1},{fpath2}, inp:{fpath0} Mode:{DataSplitType.name(datasplit_type)}"
        )

        # Each file gets a trailing channel axis and the files are stacked on axis 3.
        data = np.concatenate([load_tiff(fpath)[..., None] for fpath in fpaths], axis=3)
        # NOTE(review): this branch is only entered when data_type is
        # SeparateTiffData, so the check below looks unreachable — confirm.
        if data_config.data_type == DataType.PredictedTiffData:
            assert len(data.shape) == 5 and data.shape[-1] == 1
            data = data[..., 0].copy()
        # data = data[::3].copy()
        # NOTE: This was not the correct way to do it. It is so because the noise present in the input was directly related
        # to the noise present in the channels and so this is not the way we would get the data.
        # We need to add the noise independently to the input and the target.

        # if data_config.get('poisson_noise_factor', False):
        #     data = np.random.poisson(data)
        # if data_config.get('enable_gaussian_noise', False):
        #     synthetic_scale = data_config.get('synthetic_gaussian_scale', 0.1)
        #     print('Adding Gaussian noise with scale', synthetic_scale)
        #     noise = np.random.normal(0, synthetic_scale, data.shape)
        #     data = data + noise

        if datasplit_type == DataSplitType.All:
            return data.astype(np.float32)

        train_idx, val_idx, test_idx = get_datasplit_tuples(
            val_fraction, test_fraction, len(data), starting_test=True
        )
        if datasplit_type == DataSplitType.Train:
            return data[train_idx].astype(np.float32)
        elif datasplit_type == DataSplitType.Val:
            return data[val_idx].astype(np.float32)
        elif datasplit_type == DataSplitType.Test:
            return data[test_idx].astype(np.float32)

    elif data_config.data_type == DataType.BioSR_MRC:
        num_channels = data_config.num_channels
        fpaths = []
        data_list = []
        # Channel files are looked up as ch1_fname, ch2_fname, ... on the config.
        for i in range(num_channels):
            fpath1 = os.path.join(fpath, getattr(data_config, f"ch{i + 1}_fname"))
            fpaths.append(fpath1)
            data = get_mrc_data(fpath1)[..., None]
            data_list.append(data)

        dirname = os.path.dirname(os.path.dirname(fpaths[0])) + "/"

        msg = ",".join([x[len(dirname) :] for x in fpaths])
        print(
            f"Loaded from {dirname} Channels:{len(fpaths)} {msg} Mode:{datasplit_type}"
        )
        # Channels may hold different numbers of frames: crop all to the shortest one.
        N = data_list[0].shape[0]
        for data in data_list:
            N = min(N, data.shape[0])

        cropped_data = []
        for data in data_list:
            cropped_data.append(data[:N])

        data = np.concatenate(cropped_data, axis=3)

        if datasplit_type == DataSplitType.All:
            return data.astype(np.float32)

        train_idx, val_idx, test_idx = get_datasplit_tuples(
            val_fraction, test_fraction, len(data), starting_test=True
        )
        if datasplit_type == DataSplitType.Train:
            return data[train_idx].astype(np.float32)
        elif datasplit_type == DataSplitType.Val:
            return data[val_idx].astype(np.float32)
        elif datasplit_type == DataSplitType.Test:
            return data[test_idx].astype(np.float32)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def get_datasplit_tuples(
    val_fraction: float,
    test_fraction: float,
    total_size: int,
    starting_test: bool = False,
):
    """
    Compute index lists for the train, validation and test splits.

    Parameters
    ----------
    val_fraction, test_fraction:
        Fractions of `total_size` assigned to validation and test.
    total_size:
        Number of samples to split.
    starting_test:
        If True, the indices are laid out contiguously as test, then val,
        then train. Otherwise the first `(val + test)` share of the indices
        is distributed between val and test in small alternating chunks and
        the remainder becomes train.

    Returns
    -------
    (train, val, test):
        Three lists of indices.
    """
    if starting_test:
        # test => val => train
        # Boundaries are derived from the split sizes rather than from
        # `test[-1]` / `val[-1]`, which raised IndexError whenever a fraction
        # rounded down to zero samples.
        test_size = int(total_size * test_fraction)
        val_size = int(total_size * val_fraction)
        test = list(range(0, test_size))
        val = list(range(test_size, test_size + val_size))
        train = list(range(test_size + val_size, total_size))
    else:
        # {test,val}=> train
        test_val_size = int((val_fraction + test_fraction) * total_size)
        train = list(range(test_val_size, total_size))

        if test_val_size == 0:
            test = []
            val = []
            return train, val, test

        # Split the test and validation in chunks.
        chunksize = max(1, min(3, test_val_size // 2))

        nchunks = test_val_size // chunksize

        test = []
        val = []
        s = 0
        # Even-numbered chunks go to val, odd-numbered chunks go to test.
        for i in range(nchunks):
            if i % 2 == 0:
                val += list(np.arange(s, s + chunksize))
            else:
                test += list(np.arange(s, s + chunksize))
            s += chunksize

        # Distribute the leftover indices that did not fill a whole chunk.
        if i % 2 == 0:
            test += list(np.arange(s, test_val_size))
        else:
            p1, p2 = split_in_half(s, test_val_size)
            test += p1
            val += p2

    val, test = adjust_for_imbalance_in_fraction_value(
        val, test, val_fraction, test_fraction, total_size
    )

    return train, val, test
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
def get_mrc_data(fpath):
    """
    Read an MRC file with `read_mrc` and return its data with the last axis
    moved to the front.

    NOTE(review): the indexing below assumes `read_mrc` yields a 3-d array
    (H x W x N according to the original comment) — confirm.
    """
    # HXWXN
    _, volume = read_mrc(fpath)
    # Prepend a singleton axis, swap it with axis 3, then drop the singleton
    # that ended up last: (H, W, N) -> (1, H, W, N) -> (N, H, W, 1) -> (N, H, W).
    return np.swapaxes(volume[None], 0, 3)[..., 0]
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
@dataclass
class GridIndexManager:
    """
    Convert between flat dataset indices, per-dimension grid indices and
    grid/patch start coordinates for data covered by a regular grid.

    Attributes are documented inline. `data_shape`, `grid_shape` and
    `patch_shape` must all have the same number of dimensions, and in every
    dimension the patch must be at least as large as the grid with an even
    size difference (checked in `__post_init__`).
    """

    # Shape of the full data array.
    data_shape: tuple
    # Size of one grid cell per dimension.
    grid_shape: tuple
    # Size of one patch per dimension.
    patch_shape: tuple
    # When True, only grid cells whose patch fits inside the data are counted.
    trim_boundary: bool

    # Vera: patch is centered on index in the grid, grid size not used in training,
    # used only during val / test, grid size controls the overlap of the patches
    # in training you only get random patches every time
    # For borders - just cropped the data, so it perfectly divisible

    def __post_init__(self):
        assert len(self.data_shape) == len(
            self.grid_shape
        ), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
        assert len(self.data_shape) == len(
            self.patch_shape
        ), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}"
                )
            if pad % 2 != 0:
                raise ValueError(
                    f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}"
                )

    def _check_dim(self, dim: int):
        """Assert that `dim` is a valid dimension index for `data_shape`."""
        assert dim < len(
            self.data_shape
        ), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

    def patch_offset(self):
        """
        Returns the per-dimension offset between a patch start and its grid start.
        """
        return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2

    def get_individual_dim_grid_count(self, dim: int):
        """
        Returns the number of the grid in the specified dimension, ignoring all other dimensions.
        """
        self._check_dim(dim)

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return self.data_shape[dim]
        elif self.trim_boundary is False:
            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
        else:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            return int(
                np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim])
            )

    def total_grid_count(self):
        """
        Returns the total number of grids in the dataset.
        """
        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

    def grid_count(self, dim: int):
        """
        Returns the total number of grids for one value in the specified dimension.
        """
        self._check_dim(dim)
        if dim == len(self.data_shape) - 1:
            return 1

        return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)

    def get_grid_index(self, dim: int, coordinate: int):
        """
        Returns the index of the grid in the specified dimension.

        The result of the floor division is converted to `int`; it used to be
        returned as a numpy float, which then leaked into the dataset indices
        computed from it.
        """
        self._check_dim(dim)
        assert (
            coordinate < self.data_shape[dim]
        ), f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return coordinate
        elif self.trim_boundary is False:
            return int(np.floor(coordinate / self.grid_shape[dim]))
        else:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(
                0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim]))
            )

    def dataset_idx_from_grid_idx(self, grid_idx: tuple):
        """
        Returns the index of the grid in the dataset.
        """
        assert len(grid_idx) == len(
            self.data_shape
        ), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
        index = 0
        for dim in range(len(grid_idx)):
            index += grid_idx[dim] * self.grid_count(dim)
        return index

    def get_patch_location_from_dataset_idx(self, dataset_idx: int):
        """
        Returns the patch location of the grid in the dataset.
        """
        location = self.get_location_from_dataset_idx(dataset_idx)
        offset = self.patch_offset()
        return tuple(np.array(location) - np.array(offset))

    def get_dataset_idx_from_grid_location(self, location: tuple):
        """
        Returns the dataset index of the grid that contains `location`.
        """
        assert len(location) == len(
            self.data_shape
        ), f"Location {location} must have the same dimension as data shape {self.data_shape}"
        grid_idx = [
            self.get_grid_index(dim, location[dim]) for dim in range(len(location))
        ]
        return self.dataset_idx_from_grid_idx(tuple(grid_idx))

    def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
        """
        Returns the grid-start coordinate of the grid in the specified dimension.
        """
        self._check_dim(dim)
        assert dim_index < self.get_individual_dim_grid_count(
            dim
        ), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return dim_index
        elif self.trim_boundary is False:
            return dim_index * self.grid_shape[dim]
        else:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
            return dim_index * self.grid_shape[dim] + excess_size

    def get_location_from_dataset_idx(self, dataset_idx: int):
        """
        Returns the grid-start location (one coordinate per dimension) for a dataset index.
        """
        grid_idx = []
        # Peel off one dimension at a time, from the outermost to the innermost.
        for dim in range(len(self.data_shape)):
            grid_idx.append(dataset_idx // self.grid_count(dim))
            dataset_idx = dataset_idx % self.grid_count(dim)
        location = [
            self.get_gridstart_location_from_dim_index(dim, grid_idx[dim])
            for dim in range(len(self.data_shape))
        ]
        return tuple(location)

    def on_boundary(self, dataset_idx: int, dim: int):
        """
        Returns True if the grid is on the boundary in the specified dimension.
        """
        self._check_dim(dim)

        if dim > 0:
            dataset_idx = dataset_idx % self.grid_count(dim - 1)

        dim_index = dataset_idx // self.grid_count(dim)
        return (
            dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
        )

    def next_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the dataset index of the next grid along the specified dimension,
        or None when that index would fall outside the dataset.
        """
        self._check_dim(dim)
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx: int, dim: int):
        """
        Returns the dataset index of the previous grid along the specified dimension,
        or None when that index would be negative.
        """
        self._check_dim(dim)
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
        # The computed index was previously never returned, so callers always got None.
        return new_idx
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class IndexSwitcher:
|
|
423
|
+
"""
|
|
424
|
+
The idea is to switch from valid indices for target to invalid indices for target.
|
|
425
|
+
If index in invalid for the target, then we return all zero vector as target.
|
|
426
|
+
This combines both logic:
|
|
427
|
+
1. Using less amount of total data.
|
|
428
|
+
2. Using less amount of target data but using full data.
|
|
429
|
+
"""
|
|
430
|
+
|
|
431
|
+
def __init__(self, idx_manager, data_config, patch_size) -> None:
|
|
432
|
+
self.idx_manager = idx_manager
|
|
433
|
+
self._data_shape = self.idx_manager.get_data_shape()
|
|
434
|
+
self._training_validtarget_fraction = data_config.get(
|
|
435
|
+
"training_validtarget_fraction", 1.0
|
|
436
|
+
)
|
|
437
|
+
self._validtarget_ceilT = int(
|
|
438
|
+
np.ceil(self._data_shape[0] * self._training_validtarget_fraction)
|
|
439
|
+
)
|
|
440
|
+
self._patch_size = patch_size
|
|
441
|
+
assert (
|
|
442
|
+
data_config.deterministic_grid is True
|
|
443
|
+
), "This only works when the dataset has deterministic grid. Needed randomness comes from this class."
|
|
444
|
+
assert (
|
|
445
|
+
"grid_size" in data_config and data_config.grid_size == 1
|
|
446
|
+
), "We need a one to one mapping between index and h, w, t"
|
|
447
|
+
|
|
448
|
+
self._h_validmax, self._w_validmax = self.get_reduced_frame_size(
|
|
449
|
+
self._data_shape[:3], self._training_validtarget_fraction
|
|
450
|
+
)
|
|
451
|
+
if self._h_validmax < self._patch_size or self._w_validmax < self._patch_size:
|
|
452
|
+
print(
|
|
453
|
+
"WARNING: The valid target size is smaller than the patch size. This will result in all zero target. so, we are ignoring this frame for target."
|
|
454
|
+
)
|
|
455
|
+
self._h_validmax = 0
|
|
456
|
+
self._w_validmax = 0
|
|
457
|
+
|
|
458
|
+
print(
|
|
459
|
+
f"[{self.__class__.__name__}] Target Indices: [0,{self._validtarget_ceilT - 1}]. Index={self._validtarget_ceilT - 1} has shape [:{self._h_validmax},:{self._w_validmax}]. Available data: {self._data_shape[0]}"
|
|
460
|
+
)
|
|
461
|
+
|
|
462
|
+
def get_valid_target_index(self):
|
|
463
|
+
"""
|
|
464
|
+
Returns an index which corresponds to a frame which is expected to have a target.
|
|
465
|
+
"""
|
|
466
|
+
_, h, w, _ = self._data_shape
|
|
467
|
+
framepixelcount = h * w
|
|
468
|
+
targetpixels = np.array(
|
|
469
|
+
[framepixelcount] * (self._validtarget_ceilT - 1)
|
|
470
|
+
+ [self._h_validmax * self._w_validmax]
|
|
471
|
+
)
|
|
472
|
+
targetpixels = targetpixels / np.sum(targetpixels)
|
|
473
|
+
t = np.random.choice(self._validtarget_ceilT, p=targetpixels)
|
|
474
|
+
# t = np.random.randint(0, self._validtarget_ceilT) if self._validtarget_ceilT >= 1 else 0
|
|
475
|
+
h, w = self.get_valid_target_hw(t)
|
|
476
|
+
index = self.idx_manager.idx_from_hwt(h, w, t)
|
|
477
|
+
# print('Valid', index, h,w,t)
|
|
478
|
+
return index
|
|
479
|
+
|
|
480
|
+
def get_invalid_target_index(self):
|
|
481
|
+
# if self._validtarget_ceilT == 0:
|
|
482
|
+
# TODO: There may not be enough data for this to work. The better way is to skip using 0 for invalid target.
|
|
483
|
+
# t = np.random.randint(1, self._data_shape[0])
|
|
484
|
+
# elif self._validtarget_ceilT < self._data_shape[0]:
|
|
485
|
+
# t = np.random.randint(self._validtarget_ceilT, self._data_shape[0])
|
|
486
|
+
# else:
|
|
487
|
+
# t = self._validtarget_ceilT - 1
|
|
488
|
+
# 5
|
|
489
|
+
# 1.2 => 2
|
|
490
|
+
total_t, h, w, _ = self._data_shape
|
|
491
|
+
framepixelcount = h * w
|
|
492
|
+
available_h = h - self._h_validmax
|
|
493
|
+
if available_h < self._patch_size:
|
|
494
|
+
available_h = 0
|
|
495
|
+
available_w = w - self._w_validmax
|
|
496
|
+
if available_w < self._patch_size:
|
|
497
|
+
available_w = 0
|
|
498
|
+
|
|
499
|
+
targetpixels = np.array(
|
|
500
|
+
[available_h * available_w]
|
|
501
|
+
+ [framepixelcount] * (total_t - self._validtarget_ceilT)
|
|
502
|
+
)
|
|
503
|
+
t_probab = targetpixels / np.sum(targetpixels)
|
|
504
|
+
t = np.random.choice(
|
|
505
|
+
np.arange(self._validtarget_ceilT - 1, total_t), p=t_probab
|
|
506
|
+
)
|
|
507
|
+
|
|
508
|
+
h, w = self.get_invalid_target_hw(t)
|
|
509
|
+
index = self.idx_manager.idx_from_hwt(h, w, t)
|
|
510
|
+
# print('Invalid', index, h,w,t)
|
|
511
|
+
return index
|
|
512
|
+
|
|
513
|
+
def get_valid_target_hw(self, t):
|
|
514
|
+
"""
|
|
515
|
+
This is the opposite of get_invalid_target_hw. It returns a h,w which is valid for target.
|
|
516
|
+
This is only valid for single frame setup.
|
|
517
|
+
"""
|
|
518
|
+
if t == self._validtarget_ceilT - 1:
|
|
519
|
+
h = np.random.randint(0, self._h_validmax - self._patch_size)
|
|
520
|
+
w = np.random.randint(0, self._w_validmax - self._patch_size)
|
|
521
|
+
else:
|
|
522
|
+
h = np.random.randint(0, self._data_shape[1] - self._patch_size)
|
|
523
|
+
w = np.random.randint(0, self._data_shape[2] - self._patch_size)
|
|
524
|
+
return h, w
|
|
525
|
+
|
|
526
|
+
def get_invalid_target_hw(self, t):
|
|
527
|
+
"""
|
|
528
|
+
This is the opposite of get_valid_target_hw. It returns a h,w which is not valid for target.
|
|
529
|
+
This is only valid for single frame setup.
|
|
530
|
+
"""
|
|
531
|
+
if t == self._validtarget_ceilT - 1:
|
|
532
|
+
h = np.random.randint(
|
|
533
|
+
self._h_validmax, self._data_shape[1] - self._patch_size
|
|
534
|
+
)
|
|
535
|
+
w = np.random.randint(
|
|
536
|
+
self._w_validmax, self._data_shape[2] - self._patch_size
|
|
537
|
+
)
|
|
538
|
+
else:
|
|
539
|
+
h = np.random.randint(0, self._data_shape[1] - self._patch_size)
|
|
540
|
+
w = np.random.randint(0, self._data_shape[2] - self._patch_size)
|
|
541
|
+
return h, w
|
|
542
|
+
|
|
543
|
+
def _get_tidx(self, index):
|
|
544
|
+
if isinstance(index, int) or isinstance(index, np.int64):
|
|
545
|
+
idx = index
|
|
546
|
+
else:
|
|
547
|
+
idx = index[0]
|
|
548
|
+
return self.idx_manager.get_t(idx)
|
|
549
|
+
|
|
550
|
+
def index_should_have_target(self, index):
    """
    Tell whether the patch addressed by `index` lies in the region that has a target.

    Frames before the boundary frame (self._validtarget_ceilT - 1) always do and frames
    after it never do. On the boundary frame itself the patch must end strictly before
    both _h_validmax and _w_validmax.
    """
    boundary_t = self._validtarget_ceilT - 1
    frame = self._get_tidx(index)

    if frame < boundary_t:
        return True
    if frame > boundary_t:
        return False

    # Boundary frame: only the top-left (_h_validmax x _w_validmax) area has a target.
    row, col, _ = self.idx_manager.hwt_from_idx(index)
    patch = self._patch_size
    return row + patch < self._h_validmax and col + patch < self._w_validmax
|
|
562
|
+
|
|
563
|
+
@staticmethod
|
|
564
|
+
def get_reduced_frame_size(data_shape_nhw, fraction):
|
|
565
|
+
n, h, w = data_shape_nhw
|
|
566
|
+
|
|
567
|
+
framepixelcount = h * w
|
|
568
|
+
targetpixelcount = int(n * framepixelcount * fraction)
|
|
569
|
+
|
|
570
|
+
# We are currently supporting this only when there is just one frame.
|
|
571
|
+
# if np.ceil(pixelcount / framepixelcount) > 1:
|
|
572
|
+
# return None, None
|
|
573
|
+
|
|
574
|
+
lastframepixelcount = targetpixelcount % framepixelcount
|
|
575
|
+
assert data_shape_nhw[1] == data_shape_nhw[2]
|
|
576
|
+
if lastframepixelcount > 0:
|
|
577
|
+
new_size = int(np.sqrt(lastframepixelcount))
|
|
578
|
+
return new_size, new_size
|
|
579
|
+
else:
|
|
580
|
+
assert (
|
|
581
|
+
targetpixelcount / framepixelcount >= 1
|
|
582
|
+
), "This is not possible in euclidean space :D (so this is a bug)"
|
|
583
|
+
return h, w
|
|
584
|
+
|
|
585
|
+
|
|
586
|
+
# Field layout of the fixed MRC file header (IMOD flavour), given as a NumPy
# structured-dtype specification. Field order and sizes define the on-disk layout.
rec_header_dtd = [
    # --- Image dimensions ---
    ("nx", "i4"),  # Number of columns
    ("ny", "i4"),  # Number of rows
    ("nz", "i4"),  # Number of sections
    # Types of pixels in the image. Values used by IMOD:
    #    0 = unsigned or signed bytes depending on flag in imodFlags
    #    1 = signed short integers (16 bits)
    #    2 = float (32 bits)
    #    3 = short * 2, (used for complex data)
    #    4 = float * 2, (used for complex data)
    #    6 = unsigned 16-bit integers (non-standard)
    #   16 = unsigned char * 3 (for rgb data, non-standard)
    ("mode", "i4"),
    # Starting point of sub-image (not used in IMOD)
    ("nxstart", "i4"),
    ("nystart", "i4"),
    ("nzstart", "i4"),
    # Grid size in X, Y and Z
    ("mx", "i4"),
    ("my", "i4"),
    ("mz", "i4"),
    # Cell size; pixel spacing = xlen/mx, ylen/my, zlen/mz
    ("xlen", "f4"),
    ("ylen", "f4"),
    ("zlen", "f4"),
    # Cell angles - ignored by IMOD
    ("alpha", "f4"),
    ("beta", "f4"),
    ("gamma", "f4"),
    # --- Axis mapping ---
    # These need to be set to 1, 2, and 3 for pixel spacing to be interpreted correctly.
    ("mapc", "i4"),  # map column  1=x,2=y,3=z.
    ("mapr", "i4"),  # map row     1=x,2=y,3=z.
    ("maps", "i4"),  # map section 1=x,2=y,3=z.
    # --- Pixel statistics ---
    # These need to be set for proper scaling of data.
    ("amin", "f4"),  # Minimum pixel value
    ("amax", "f4"),  # Maximum pixel value
    ("amean", "f4"),  # Mean pixel value
    ("ispg", "i4"),  # space group number (ignored by IMOD)
    # number of bytes in extended header (called nsymbt in MRC standard)
    ("next", "i4"),
    ("creatid", "i2"),  # used to be an ID number, is 0 as of IMOD 4.2.23
    ("extra_data", "V30"),  # (not used, first two bytes should be 0)
    # --- Extended header description ---
    # nint and nreal specify the structure of data in the extended header; their
    # meaning depends on whether the extended header has the Agard format, a series
    # of 4-byte integers then real numbers, or has data produced by SerialEM, a
    # series of short integers. SerialEM stores a float as two shorts, s1 and s2, by:
    #   value = (sign of s1)*(|s1|*256 + (|s2| modulo 256)) * 2**((sign of s2) * (|s2|/256))
    #
    # nint: number of integers per section (Agard format) or number of bytes per
    # section (SerialEM format).
    ("nint", "i2"),
    # nreal: number of reals per section (Agard format) or bit flags for which types
    # of short data (SerialEM format):
    #    1 = tilt angle * 100 (2 bytes)
    #    2 = piece coordinates for montage (6 bytes)
    #    4 = Stage position * 25 (4 bytes)
    #    8 = Magnification / 100 (2 bytes)
    #   16 = Intensity * 25000 (2 bytes)
    #   32 = Exposure dose in e-/A2, a float in 4 bytes
    #   128, 512: Reserved for 4-byte items
    #   64, 256, 1024: Reserved for 2-byte items
    # If the number of bytes implied by these flags does not add up to the value in
    # nint, then nint and nreal are interpreted as ints and reals per section.
    ("nreal", "i2"),
    ("extra_data2", "V20"),  # extra data (not used)
    # --- IMOD specific fields ---
    ("imodStamp", "i4"),  # 1146047817 indicates that file was created by IMOD
    ("imodFlags", "i4"),  # Bit flags: 1 = bytes are stored as signed
    # Explanation of type of data
    ("idtype", "i2"),  # ( 0 = mono, 1 = tilt, 2 = tilts, 3 = lina, 4 = lins)
    ("lens", "i2"),
    # "nphase" takes the place of two disabled i2 fields:
    #   ("nd1", "i2"),  for idtype = 1, nd1 = axis (1, 2, or 3)
    #   ("nd2", "i2"),
    ("nphase", "i4"),
    ("vd1", "i2"),  # vd1 = 100. * tilt increment
    ("vd2", "i2"),  # vd2 = 100. * starting angle
    # Current angles are used to rotate a model to match a new rotated image. The
    # three values in each set are rotations about X, Y, and Z axes, applied in the
    # order Z, Y, X.
    ("triangles", "f4", 6),  # 0,1,2 = original: 3,4,5 = current
    # Origin of image
    ("xorg", "f4"),
    ("yorg", "f4"),
    ("zorg", "f4"),
    ("cmap", "S4"),  # Contains "MAP "
    # First two bytes have 17 and 17 for big-endian or 68 and 65 for little-endian
    ("stamp", "u1", 4),
    ("rms", "f4"),  # RMS deviation of densities from mean density
    # --- Text labels ---
    ("nlabl", "i4"),  # Number of labels with useful data
    ("labels", "S80", 10),  # 10 labels of 80 characters
]
|
|
672
|
+
|
|
673
|
+
|
|
674
|
+
def read_mrc(filename, filetype="image"):
    """
    Read an MRC file and return its header together with the pixel data.

    Parameters
    ----------
    filename : str or path-like
        Path of the MRC file to read.
    filetype : str
        With "image" (default) the payload is returned as a float64 array of shape
        (nx, ny, nz), one Fortran-ordered (nx, ny) plane per section. With any other
        value the flat array read from the file is returned unchanged.

    Returns
    -------
    header : np.ndarray
        One-element structured array with dtype `rec_header_dtd`.
    data : np.ndarray
        Pixel data, see `filetype`.

    Raises
    ------
    ValueError
        If the file does not contain a complete header, if the pixel mode stored in
        the header is not one of 1, 2, 4 or 6, or (raised by NumPy) if the payload is
        too short to fill all sections.
    """
    # MRC pixel mode -> NumPy dtype used to decode the payload.
    mode_to_dtype = {1: "int16", 2: "float32", 4: "single", 6: "uint16"}

    # The context manager closes the file even when parsing fails part-way.
    with open(filename, "rb") as fd:
        header = np.fromfile(fd, dtype=rec_header_dtd, count=1)
        if header.size == 0:
            raise ValueError(f"{filename!r} does not contain a complete MRC header")

        mode = int(header["mode"][0])
        if mode not in mode_to_dtype:
            raise ValueError(
                f"Unsupported MRC pixel mode {mode} in {filename!r}; "
                "supported modes are 1, 2, 4 and 6"
            )

        # NOTE(review): the payload is read starting right after the fixed header;
        # bytes of an extended header (header["next"] > 0) are not skipped. Confirm
        # that the files read here carry no extended header.
        imgrawdata = np.fromfile(fd, mode_to_dtype[mode])

    # Plain Python ints so that the offsets nx * ny * iz below cannot overflow int32.
    nx = int(header["nx"][0])
    ny = int(header["ny"][0])
    nz = int(header["nz"][0])
    if mode == 4:
        # Each pixel is a pair of floats, so a row holds twice as many values.
        nx *= 2

    if filetype != "image":
        return header, imgrawdata

    data = np.empty((nx, ny, nz))
    plane_size = nx * ny
    for iz in range(nz):
        plane = imgrawdata[plane_size * iz : plane_size * (iz + 1)]
        data[:, :, iz] = plane.reshape(nx, ny, order="F")

    return header, data
|