konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/utils/dataset.py (CHANGED)
```diff
@@ -1,255 +1,33 @@
-import SimpleITK as sitk
-import h5py
-from abc import ABC, abstractmethod
-import numpy as np
-from typing import Any, Union
+import ast
 import copy
-import torch
-import os
-
-from lxml import etree
 import csv
-
-
-
-
-    def __init__(self, root: etree.ElementTree) -> None:
-        self.root = root
+import os
+from abc import ABC, abstractmethod
+from typing import Any

-
-
-
-
-
-            result[label] = np.fromstring(root.text, sep = ",").astype('double')
-        else:
-            for node in root:
-                Plot._explore(node, result, label+":"+node.tag)
-
-    def getNodes(root, path = None, id = None):
-        nodes = []
-        if path != None:
-            path = path.split(":")
-            for node_name in path:
-                node = root.find(node_name)
-                if node != None:
-                    root = node
-                else:
-                    break
-        if id != None:
-            for node in root.findall(".//"+id):
-                nodes.append(node)
-        else:
-            nodes.append(root)
-        return nodes
+import h5py
+import numpy as np
+import SimpleITK as sitk  # noqa: N813
+import torch
+from lxml import etree  # nosec B410

-
-        result = dict()
-        for node in Plot.getNodes(root, path, id):
-            Plot._explore(node, result, etree.ElementTree(root).getpath(node))
-        return result
+from konfai import current_date

-    def _extract(self, ids = [], patients = []):
-        result = dict()
-        if len(patients) == 0:
-            if len(ids) == 0:
-                result.update(Plot.read(self.root,None, None))
-            else:
-                for id in ids:
-                    result.update(Plot.read(self.root, None, id))
-        else:
-            for path in patients:
-                if len(ids) == 0:
-                    result.update(Plot.read(self.root, path, None))
-                else:
-                    for id in ids:
-                        result.update(Plot.read(self.root, path, id))
-        return result
-
-    def getErrors(self, ids = [], patients = []):
-        results = self._extract(ids=ids, patients=patients)
-        errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
-        results : dict[str, dict[str, np.ndarray]]= {}
-        for key, error in errors.items():
-            patient = key.replace("/",":").split(":")[2]
-            k = key.replace("/",":").split(":")[-1]
-            err = np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1)
-            if patient not in results:
-                results[patient] = {k : err}
-            else:
-                results[patient].update({k : err})
-        return results
-
-    def statistic_attrib(self, ids = [], patients = [], type: str = "HD95Mean"):
-        results = self._extract(ids=ids, patients=patients)
-
-        errors = {k.replace("attrib:", ""): float(v) for k, v in results.items() if type in k}
-
-        values = {key : np.array([]) for key in ids}
-        for key, error in errors.items():
-            k = key.replace("/",":").split(":")[-2]
-            values[k] = np.append(values[k], error)
-
-        for k in values:
-            values[k] = np.mean(values[k])
-        print(values)
-        return values
-
-    def statistic_parameter(self, ids = [], patients = []):
-        results = self._extract(ids=ids, patients=patients)
-        errors = {k.replace("attrib:", "").replace(":Time", "") : np.load("./Results/{}/{}.npy".format(k.split("/")[3].split(":")[0], k.split("/")[2])) for k in results.keys()}
-
-        norms = {key : np.array([]) for key in ids}
-        max = 0
-        for key, error in errors.items():
-            if max < int(error.shape[0]/3):
-                max = int(error.shape[0]/3)
-
-        for key, error in errors.items():
-            k = key.replace("/",":").split(":")[-1]
-            norms[k] = np.append(norms[k], np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1))
-            v = np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1)
-
-            print(key, "{} {} {} {} {}".format(np.round(np.mean(v), 2), np.round(np.std(v), 2), np.round(np.quantile(v, 0.25), 2), np.round(np.quantile(v, 0.5), 2), np.round(np.quantile(v, 0.75), 2)))
-        results = {}
-        for key, values in norms.items():
-            if key == "Rigid":
-                results.update({key : values})
-            else:
-                try:
-                    name = "{}".format("_".join(key.split("_"))[:-1])
-                    it = int(key.split("_")[-1])
-                except:
-                    name = "{}".format(key.split("-")[0])
-                    it = int(key.split("-")[-1])
-
-            if name in results:
-                results[name].update({it : values})
-            else:
-                results.update({name : {it : values}})
-
-        r = []
-        for key, values in norms.items():
-            #r.append("{} $\pm$ {}".format(np.round(np.mean(values), 2), np.round(np.std(values), 2)))
-            r.append("{} {} {}".format(np.round(np.quantile(values, 0.25), 2), np.round(np.quantile(values, 0.5),2), np.round(np.quantile(values, 0.75), 2)))
-            #r.append("{} $\pm$ {}".format(np.round(np.quantile(values, 0.5), 2), np.round(np.quantile(values, 0.75)-np.quantile(values, 0.25), 2)))
-        print(" & ".join(r))
-
-    def statistic(self, ids = [], patients = []):
-        results = self._extract(ids=ids, patients=patients)
-        #errors = {k.replace("attrib:", "").replace(":Time", "") : np.load("./Dataset/{}/{}.npy".format(k.split("/")[3].split(":")[0], k.split("/")[2])) for k in results.keys()}
-        errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
-        print(errors)
-        norms = {key : np.array([]) for key in ids}
-        max = 0
-        for key, error in errors.items():
-            if max < int(error.shape[0]/3):
-                max = int(error.shape[0]/3)
-
-        for key, error in errors.items():
-            k = key.replace("/",":").split(":")[-1]
-            norms[k] = np.append(norms[k], np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1))
-            v = np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1)
-            print(key, (np.mean(v), np.std(v), np.quantile(v, 0.25), np.quantile(v, 0.5), np.quantile(v, 0.75)) )
-        results = {}
-        """for key, values in norms.items():
-            if key == "Rigid":
-                results.update({key : values})
-            else:
-                try:
-                    name = "{}".format("_".join(key.split("_"))[:-1])
-                    it = int(key.split("_")[-1])
-                except:
-                    name = "{}".format(key.split("-")[0])
-                    it = int(key.split("-")[-1])
-
-            if name in results:
-                results[name].update({it : values})
-            else:
-                results.update({name : {it : values}})"""
-
-
-        print({key: (np.mean(values), np.std(values), np.quantile(values, 0.25), np.quantile(values, 0.5), np.quantile(values, 0.75)) for key, values in norms.items()})
-        return results
-
-    def plot(self, ids = [], patients = [], labels = [], colors = None):
-
-        import matplotlib.pyplot as pyplot
-        results = self._extract(ids=ids, patients=patients)
-
-        attrs = {k: v for k, v in results.items() if k.startswith("attrib:")}
-        errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
-
-        patients = set()
-        max = 0
-        for key, error in errors.items():
-            patients.add(key.replace("/",":").split(":")[2])
-            if max < int(error.shape[0]/3):
-                max = int(error.shape[0]/3)
-        patients = sorted(patients)
-
-        norms = {patient : np.array([]) for patient in patients}
-        markups = {patient : np.array([]) for patient in patients}
-        series = list()
-        for key, error in errors.items():
-            patient = key.replace("/",":").split(":")[2]
-            markup = np.full((max,3), np.nan)
-            markup[0:int(error.shape[0]/3), :] = error.reshape(int(error.shape[0]/3),3)
-
-            markups[patient] = np.append(markups[patient], markup)
-            norms[patient] = np.append(norms[patient], np.linalg.norm(markup, ord=2, axis=1))
-
-        if len(labels) == 0:
-            labels = list(set([k.split("/")[-1] for k in errors.keys()]))
-
-        for label in labels:
-            series = series+[label]*max
-        import pandas as pd
-        df = pd.DataFrame(dict([(k,pd.Series(v)) for k, v in norms.items()]))
-        df['Categories'] = pd.Series(series)
-
-        bp = df.boxplot(by='Categories', color="black", figsize=(12,8), notch=True,layout=(1,len(patients)), fontsize=18, rot=0, patch_artist = True, return_type='both', widths=[0.5]*len(labels))
-
-        color_pallet = {"b" : "paleturquoise", "g" : "lightgreen"}
-        if colors == None:
-            colors = ["b"] * len(patients)
-        pyplot.suptitle('')
-        it_1 = 0
-        for index, (ax,row) in bp.items():
-            ax.set_xlabel('')
-            ax.set_ylim(ymin=0)
-            ax.set_ylabel("TRE (mm)", fontsize=18)
-            ax.set_yticks([0,1,2,3,4,5,6,7,8,9,10,15,20,25]) # Set label locations.
-            for i,object in enumerate(row["boxes"]):
-                object.set_edgecolor("black")
-                object.set_facecolor(color_pallet[colors[i]])
-                object.set_alpha(0.7)
-                object.set_linewidth(1.0)
-
-            for i,object in enumerate(row["medians"]):
-                object.set_color("indianred")
-                xy = object.get_xydata()
-                object.set_linewidth(2.0)
-            it_1+=1
-        return self
-
-    def show(self):
-        import matplotlib.pyplot as pyplot
-        pyplot.show()

 class Attribute(dict[str, Any]):

-    def __init__(self, attributes
+    def __init__(self, attributes: dict[str, Any] | None = None) -> None:
         super().__init__()
+        attributes = attributes or {}
         for k, v in attributes.items():
             super().__setitem__(copy.deepcopy(k), copy.deepcopy(v))
-
+
     def __getitem__(self, key: str) -> Any:
         i = len([k for k in super().keys() if k.startswith(key)])
-        if i > 0 and "{}_{}".format(key, i-1) in super().keys():
-            return str(super().__getitem__("{}_{}".format(key, i-1)))
+        if i > 0 and f"{key}_{i - 1}" in super().keys():
+            return str(super().__getitem__(f"{key}_{i - 1}"))
         else:
-            raise NameError("{} not in cache_attribute".format(key))
+            raise NameError(f"{key} not in cache_attribute")

     def __setitem__(self, key: str, value: Any) -> None:
         if "_" not in key:
```
```diff
@@ -259,58 +37,63 @@ class Attribute(dict[str, Any]):
                 result = str(value.numpy())
             else:
                 result = str(value)
-            result = result.replace("\n","")
-            super().__setitem__("{}_{}".format(key, i), result)
+            result = result.replace("\n", "")
+            super().__setitem__(f"{key}_{i}", result)
         else:
             result = None
             if isinstance(value, torch.Tensor):
                 result = str(value.numpy())
             else:
                 result = str(value)
-            result = result.replace("\n","")
+            result = result.replace("\n", "")
             super().__setitem__(key, result)

-    def pop(self, key: str) -> Any:
+    def pop(self, key: str, default: Any = None) -> Any:
         i = len([k for k in super().keys() if k.startswith(key)])
-        if i > 0 and "{}_{}".format(key, i-1) in super().keys():
-            return super().pop("{}_{}".format(key, i-1))
+        if i > 0 and f"{key}_{i - 1}" in super().keys():
+            return super().pop(f"{key}_{i - 1}")
         else:
-            raise NameError("{} not in cache_attribute".format(key))
+            raise NameError(f"{key} not in cache_attribute")

     def get_np_array(self, key) -> np.ndarray:
         return np.fromstring(self[key][1:-1], sep=" ", dtype=np.double)
-
+
     def get_tensor(self, key) -> torch.Tensor:
         return torch.tensor(self.get_np_array(key)).to(torch.float32)
-
+
     def pop_np_array(self, key):
         return np.fromstring(self.pop(key)[1:-1], sep=" ", dtype=np.double)
-
+
     def pop_tensor(self, key) -> torch.Tensor:
         return torch.tensor(self.pop_np_array(key))
-
-    def __contains__(self, key:
-
-
-
+
+    def __contains__(self, key: object) -> bool:
+        if not isinstance(key, str):
+            return False
+        return any(k.startswith(key) for k in super().keys())
+
+    def is_info(self, key: str, value: str) -> bool:
         return key in self and self[key] == value

-
+
+def is_an_image(attributes: Attribute):
     return "Origin" in attributes and "Spacing" in attributes and "Direction" in attributes

-
-
+
+def data_to_image(data: np.ndarray, attributes: Attribute) -> sitk.Image:
+    if not is_an_image(attributes):
         raise NameError("Data is not an image")
     if data.shape[0] == 1:
         image = sitk.GetImageFromArray(data[0])
     else:
-        data = data.transpose(tuple([i+1 for i in range(len(data.shape)-1)]+[0]))
+        data = data.transpose(tuple([i + 1 for i in range(len(data.shape) - 1)] + [0]))
         image = sitk.GetImageFromArray(data, isVector=True)
     image.SetOrigin(attributes.get_np_array("Origin").tolist())
     image.SetSpacing(attributes.get_np_array("Spacing").tolist())
     image.SetDirection(attributes.get_np_array("Direction").tolist())
     return image

+
 def image_to_data(image: sitk.Image) -> tuple[np.ndarray, Attribute]:
     attributes = Attribute()
     attributes["Origin"] = np.asarray(image.GetOrigin())
```
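The rewritten `Attribute` class keeps the 1.1.7 behaviour of versioning keys (`key_0`, `key_1`, ...) while moving to f-strings and a type-safe `__contains__`. A minimal sketch of how the class shown above behaves, assuming only the code in this hunk (the exact stored string is whatever `str()` yields for the value):

```python
import numpy as np

from konfai.utils.dataset import Attribute

attrs = Attribute()
attrs["Spacing"] = np.asarray([1.0, 1.0, 2.0])  # stored under "Spacing_0"
attrs["Spacing"] = np.asarray([0.5, 0.5, 1.0])  # a second write appends "Spacing_1"

# A bare key resolves to its latest version; values are kept as strings.
print(attrs["Spacing"])    # e.g. "[0.5 0.5 1. ]"
print("Spacing" in attrs)  # True: __contains__ does a prefix match on versioned keys

# get_np_array()/get_tensor() parse the stored string back into numeric form.
spacing = attrs.get_np_array("Spacing")  # array([0.5, 0.5, 1.0])
```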
```diff
@@ -321,81 +104,98 @@ def image_to_data(image: sitk.Image) -> tuple[np.ndarray, Attribute]:
     if image.GetNumberOfComponentsPerPixel() == 1:
         data = np.expand_dims(data, 0)
     else:
-        data = np.transpose(data, (len(data.shape)-1, *
+        data = np.transpose(data, (len(data.shape) - 1, *list(range(len(data.shape) - 1))))
     return data, attributes

-class Dataset():
+
+class Dataset:

     class AbstractFile(ABC):

+        @abstractmethod
         def __init__(self) -> None:
             pass
-
+
+        @abstractmethod
         def __enter__(self):
             pass

-
+        @abstractmethod
+        def __exit__(self, exc_type, value, traceback):
             pass

         @abstractmethod
-        def file_to_data(self):
+        def file_to_data(self, group: str, name: str) -> tuple[np.ndarray, Attribute]:
             pass

         @abstractmethod
-        def data_to_file(
+        def data_to_file(
+            self,
+            name: str,
+            data: sitk.Image | sitk.Transform | np.ndarray,
+            attributes: Attribute | None = None,
+        ) -> None:
             pass

         @abstractmethod
-        def
+        def get_names(self, group: str) -> list[str]:
             pass
-
+
         @abstractmethod
-        def
+        def get_group(self) -> list[str]:
             pass

         @abstractmethod
-        def
+        def is_exist(self, group: str, name: str | None = None) -> bool:
             pass
-
+
         @abstractmethod
-        def
+        def get_infos(self, group: str, name: str) -> tuple[list[int], Attribute]:
             pass

     class H5File(AbstractFile):

         def __init__(self, filename: str, read: bool) -> None:
-            self.h5:
+            self.h5: h5py.File | None = None
             self.filename = filename
             if not self.filename.endswith(".h5"):
                 self.filename += ".h5"
             self.read = read

         def __enter__(self):
-            args = {}
             if self.read:
-                self.h5 = h5py.File(self.filename,
+                self.h5 = h5py.File(self.filename, "r")
             else:
                 if not os.path.exists(self.filename):
-                    if len(self.filename.split("/")) > 1 and not os.path.exists(
+                    if len(self.filename.split("/")) > 1 and not os.path.exists(
+                        "/".join(self.filename.split("/")[:-1])
+                    ):
                         os.makedirs("/".join(self.filename.split("/")[:-1]))
-                    self.h5 = h5py.File(self.filename,
-                else:
-                    self.h5 = h5py.File(self.filename,
-                self.h5.attrs["Date"] =
+                    self.h5 = h5py.File(self.filename, "w")
+                else:
+                    self.h5 = h5py.File(self.filename, "r+")
+                self.h5.attrs["Date"] = current_date()
             self.h5.__enter__()
             return self.h5
-
-        def __exit__(self, type, value, traceback):
+
+        def __exit__(self, exc_type, value, traceback):
             if self.h5 is not None:
                 self.h5.close()
-
+
         def file_to_data(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
-            dataset = self.
+            dataset = self._get_dataset(groups, name)
             data = np.zeros(dataset.shape, dataset.dtype)
             dataset.read_direct(data)
-            return data, Attribute({k : str(v) for k, v in dataset.attrs.items()})
-
-        def data_to_file(
+            return data, Attribute({k: str(v) for k, v in dataset.attrs.items()})
+
+        def data_to_file(
+            self,
+            name: str,
+            data: sitk.Image | sitk.Transform | np.ndarray,
+            attributes: Attribute | None = None,
+        ) -> None:
+            if self.h5 is None:
+                return
             if attributes is None:
                 attributes = Attribute()
             if isinstance(data, sitk.Image):
```
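`data_to_image` and `image_to_data` round-trip between channel-first NumPy arrays and `sitk.Image`, with geometry carried through the `Origin`, `Spacing` and `Direction` attributes. A hedged round-trip sketch against the signatures shown above:

```python
import numpy as np
import SimpleITK as sitk

from konfai.utils.dataset import data_to_image, image_to_data

# A single-channel 3-D volume; konfai keeps the channel axis first.
image = sitk.GetImageFromArray(np.random.rand(8, 16, 16).astype(np.float32))
image.SetSpacing((1.0, 1.0, 2.0))

data, attributes = image_to_data(image)     # data.shape == (1, 8, 16, 16)
restored = data_to_image(data, attributes)  # geometry restored from attributes

assert restored.GetSpacing() == image.GetSpacing()
```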
```diff
@@ -405,7 +205,7 @@ class Dataset():
                 transforms = []
                 if isinstance(data, sitk.CompositeTransform):
                     for i in range(data.GetNumberOfTransforms()):
-                        transforms.append(data.GetNthTransform(i))
+                        transforms.append(data.GetNthTransform(i))
                 else:
                     transforms.append(data)
                 datas = []
```
```diff
@@ -416,8 +216,8 @@ class Dataset():
                     transform_type = "AffineTransform_double_3_3"
                     if isinstance(transform, sitk.BSplineTransform):
                         transform_type = "BSplineTransform_double_3_3"
-                    attributes["{}:Transform".format(i)] = transform_type
-                    attributes["{}:FixedParameters".format(i)] = transform.GetFixedParameters()
+                    attributes[f"{i}:Transform"] = transform_type
+                    attributes[f"{i}:FixedParameters"] = transform.GetFixedParameters()

                     datas.append(np.asarray(transform.GetParameters()))
                 data = np.asarray(datas)
```
```diff
@@ -434,38 +234,41 @@ class Dataset():
                     del h5_group[name]

             dataset = h5_group.create_dataset(name, data=data, dtype=data.dtype, chunks=None)
-            dataset.attrs.update({k : str(v) for k, v in attributes.items()})
-
-        def
-            if
-                if
-
-
-
-
-
+            dataset.attrs.update({k: str(v) for k, v in attributes.items()})
+
+        def is_exist(self, group: str, name: str | None = None) -> bool:
+            if self.h5 is not None:
+                if group in self.h5:
+                    if isinstance(self.h5[group], h5py.Dataset):
+                        return True
+                    elif name is not None:
+                        return name in self.h5[group]
+                else:
+                    return False
             return False

-        def
+        def get_names(self, groups: str, h5_group: h5py.Group = None) -> list[str]:
             names = []
             if h5_group is None:
                 h5_group = self.h5
             group = groups.split("/")[0]
             if group == "":
-                names = [
+                names = [
+                    dataset.name.split("/")[-1] for dataset in h5_group.values() if isinstance(dataset, h5py.Dataset)
+                ]
             elif group == "*":
                 for k in h5_group.keys():
                     if isinstance(h5_group[k], h5py.Group):
-                        names.extend(self.
+                        names.extend(self.get_names("/".join(groups.split("/")[1:]), h5_group[k]))
             else:
                 if group in h5_group:
-                    names.extend(self.
+                    names.extend(self.get_names("/".join(groups.split("/")[1:]), h5_group[group]))
             return names
-
-        def
-            return self.h5.keys()
-
-        def
+
+        def get_group(self) -> list[str]:
+            return list(self.h5.keys()) if self.h5 is not None else []
+
+        def _get_dataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
             if h5_group is None:
                 h5_group = self.h5
             if groups != "":
```
```diff
@@ -479,39 +282,42 @@ class Dataset():
             elif group == "*":
                 for k in h5_group.keys():
                     if isinstance(h5_group[k], h5py.Group):
-                        result_tmp = self.
+                        result_tmp = self._get_dataset("/".join(groups.split("/")[1:]), name, h5_group[k])
                         if result_tmp is not None:
                             result = result_tmp
             else:
                 if group in h5_group:
-                    result_tmp = self.
+                    result_tmp = self._get_dataset("/".join(groups.split("/")[1:]), name, h5_group[group])
                     if result_tmp is not None:
                         result = result_tmp
             return result
-
-        def
-            dataset = self.
-            return (
-
+
+        def get_infos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
+            dataset = self._get_dataset(groups, name)
+            return (
+                dataset.shape,
+                Attribute({k: str(v) for k, v in dataset.attrs.items()}),
+            )
+
     class SitkFile(AbstractFile):

-        def __init__(self, filename: str, read: bool,
+        def __init__(self, filename: str, read: bool, file_format: str) -> None:
             self.filename = filename
             self.read = read
-            self.format = format
-
+            self.file_format = file_format
+
         def file_to_data(self, group: str, name: str) -> tuple[np.ndarray, Attribute]:
             attributes = Attribute()
-            if os.path.exists("{}{}.{}".format(self.filename, name, self.format)):
-                image = sitk.ReadImage("{}{}.{}".format(self.filename, name, self.format))
+            if os.path.exists(f"{self.filename}{name}.{self.file_format}"):
+                image = sitk.ReadImage(f"{self.filename}{name}.{self.file_format}")
                 data, attributes_tmp = image_to_data(image)
                 attributes.update(attributes_tmp)
-            elif os.path.exists("{}{}.itk.txt".format(self.filename, name)):
-                data = sitk.ReadTransform("{}{}.itk.txt".format(self.filename, name))
+            elif os.path.exists(f"{self.filename}{name}.itk.txt"):
+                data = sitk.ReadTransform(f"{self.filename}{name}.itk.txt")
                 transforms = []
                 if isinstance(data, sitk.CompositeTransform):
                     for i in range(data.GetNumberOfTransforms()):
-                        transforms.append(data.GetNthTransform(i))
+                        transforms.append(data.GetNthTransform(i))
                 else:
                     transforms.append(data)
                 datas = []
```
```diff
@@ -522,113 +328,151 @@ class Dataset():
                     transform_type = "AffineTransform_double_3_3"
                     if isinstance(transform, sitk.BSplineTransform):
                         transform_type = "BSplineTransform_double_3_3"
-                    attributes["{}:Transform".format(i)] = transform_type
-                    attributes["{}:FixedParameters".format(i)] = transform.GetFixedParameters()
+                    attributes[f"{i}:Transform"] = transform_type
+                    attributes[f"{i}:FixedParameters"] = transform.GetFixedParameters()

                     datas.append(np.asarray(transform.GetParameters()))
                 data = np.asarray(datas)
-            elif os.path.exists("{}{}.fcsv".format(self.filename, name)):
-                with open("{}{}.fcsv"
-                    reader = csv.reader(filter(lambda row: row[0]!="#", csvfile))
+            elif os.path.exists(f"{self.filename}{name}.fcsv"):
+                with open(f"{self.filename}{name}.fcsv", newline="") as csvfile:
+                    reader = csv.reader(filter(lambda row: row[0] != "#", csvfile))
                     lines = list(reader)
                     data = np.zeros((len(list(lines)), 3), dtype=np.double)
                     for i, row in enumerate(lines):
                         data[i] = np.array(row[1:4], dtype=np.double)
                     csvfile.close()
-            elif os.path.exists("{}{}.xml".format(self.filename, name)):
-                with open("{}{}.xml"
-                    result = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()
+            elif os.path.exists(f"{self.filename}{name}.xml"):
+                with open(f"{self.filename}{name}.xml", "rb") as xml_file:
+                    result = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()  # nosec B320
                     xml_file.close()
                 return result
-            elif os.path.exists("{}{}.vtk".format(self.filename, name)):
+            elif os.path.exists(f"{self.filename}{name}.vtk"):
                 import vtk
-
-
-
+
+                vtk_reader = vtk.vtkPolyDataReader()
+                vtk_reader.SetFileName(f"{self.filename}{name}.vtk")
+                vtk_reader.Update()
                 data = []
-                points =
+                points = vtk_reader.GetOutput().GetPoints()
                 num_points = points.GetNumberOfPoints()
                 for i in range(num_points):
                     data.append(list(points.GetPoint(i)))
                 data = np.asarray(data)
-            elif os.path.exists("{}{}.npy".format(self.filename, name)):
-                data = np.load("{}{}.npy".format(self.filename, name))
+            elif os.path.exists(f"{self.filename}{name}.npy"):
+                data = np.load(f"{self.filename}{name}.npy")
             return data, attributes
-
+
         def is_vtk_polydata(self, obj):
             try:
                 import vtk
+
                 return isinstance(obj, vtk.vtkPolyData)
             except ImportError:
                 return False
-
-        def
+
+        def __enter__(self):
+            pass
+
+        def __exit__(self, exc_type, value, traceback):
+            pass
+
+        def data_to_file(
+            self,
+            name: str,
+            data: sitk.Image | sitk.Transform | np.ndarray,
+            attributes: Attribute | None = None,
+        ) -> None:
+            if attributes is None:
+                attributes = Attribute()
             if not os.path.exists(self.filename):
                 os.makedirs(self.filename)
             if isinstance(data, sitk.Image):
                 for k, v in attributes.items():
                     data.SetMetaData(k, v)
-                sitk.WriteImage(data, "{}{}.{}".format(self.filename, name, self.format))
+                sitk.WriteImage(data, f"{self.filename}{name}.{self.file_format}")
             elif isinstance(data, sitk.Transform):
-                sitk.WriteTransform(data, "{}{}.itk.txt".format(self.filename, name))
+                sitk.WriteTransform(data, f"{self.filename}{name}.itk.txt")
             elif self.is_vtk_polydata(data):
                 import vtk
-
-
-
-
-
+
+                vtk_writer = vtk.vtkPolyDataWriter()
+                vtk_writer.SetFileName(f"{self.filename}{name}.vtk")
+                vtk_writer.SetInputData(data)
+                vtk_writer.Write()
+            elif is_an_image(attributes):
                 self.data_to_file(name, data_to_image(data, attributes), attributes)
-            elif
+            elif len(data.shape) == 2 and data.shape[1] == 3 and data.shape[0] > 0:
                 data = np.round(data, 4)
-                with open("{}{}.fcsv"
-                    f.write(
+                with open(f"{self.filename}{name}.fcsv", "w") as f:
+                    f.write(
+                        "# Markups fiducial file version = 4.6\n# CoordinateSystem = 0\n#"
+                        " columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n",
+                    )
                     for i in range(data.shape[0]):
-                        f.write(
+                        f.write(
+                            "vtkMRMLMarkupsFiducialNode_"
+                            + str(i + 1)
+                            + ","
+                            + str(data[i, 0])
+                            + ","
+                            + str(data[i, 1])
+                            + ","
+                            + str(data[i, 2])
+                            + ",0,0,0,1,1,1,0,F-"
+                            + str(i + 1)
+                            + ",,vtkMRMLScalarVolumeNode1\n"
+                        )
                     f.close()
             elif "path" in attributes:
-                if os.path.exists("{}{}.xml".format(self.filename, name)):
-                    with open("{}{}.xml"
-                        root = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()
+                if os.path.exists(f"{self.filename}{name}.xml"):
+                    with open(f"{self.filename}{name}.xml", "rb") as xml_file:
+                        root = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()  # nosec B320
                     xml_file.close()
                 else:
                     root = etree.Element(name)
                 node = root
-                path = attributes["path"].split(
+                path = attributes["path"].split(":")

                 for node_name in path:
                     node_tmp = node.find(node_name)
-                    if node_tmp == None:
+                    if node_tmp is None:
                         node_tmp = etree.SubElement(node, node_name)
                         node.append(node_tmp)
                     node = node_tmp
-                if attributes != None:
+                if attributes is not None:
                     for attribute_tmp in attributes.keys():
                         attribute = "_".join(attribute_tmp.split("_")[:-1])
                         if attribute != "path":
                             node.set(attribute, attributes[attribute])
                 if data.size > 0:
-                    node.text = ", ".join(
-
-
+                    node.text = ", ".join(
+                        map(str, data.flatten())
+                    )  # np.array2string(data, separator=',')[1:-1].replace('\n','')
+                with open(f"{self.filename}{name}.xml", "wb") as f:
+                    f.write(etree.tostring(root, pretty_print=True, encoding="utf-8"))
                     f.close()
             else:
-                np.save("{}{}.npy".format(self.filename, name), data)
-
-        def
-            return
-
-
+                np.save(f"{self.filename}{name}.npy", data)
+
+        def is_exist(self, group: str, name: str | None = None) -> bool:
+            return (
+                os.path.exists(f"{self.filename}{group}.{self.file_format}")
+                or os.path.exists(f"{self.filename}{group}.itk.txt")
+                or os.path.exists(f"{self.filename}{group}.fcsv")
+                or os.path.exists(f"{self.filename}{group}.npy")
+            )
+
+        def get_names(self, group: str) -> list[str]:
             raise NotImplementedError()
-
-        def
+
+        def get_group(self):
             raise NotImplementedError()
-
-        def
+
+        def get_infos(self, group: str, name: str) -> tuple[list[int], Attribute]:
             attributes = Attribute()
-            if os.path.exists("{
+            if os.path.exists(f"{self.filename}{group if group is not None else ''}{name}.{self.file_format}"):
                 file_reader = sitk.ImageFileReader()
-                file_reader.SetFileName("{
+                file_reader.SetFileName(f"{self.filename}{group if group is not None else ''}{name}.{self.file_format}")
                 file_reader.ReadImageInformation()
                 attributes["Origin"] = np.asarray(file_reader.GetOrigin())
                 attributes["Spacing"] = np.asarray(file_reader.GetSpacing())
```
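The `SitkFile.data_to_file` branch above dispatches on the value's type: `sitk.Image` goes to `.<file_format>`, `sitk.Transform` to `.itk.txt`, an `(N, 3)` point array to a Slicer `.fcsv` markups file, and anything else to `.npy`. A sketch of that dispatch, assuming the nested class is reachable and used directly as shown in the diff:

```python
import numpy as np
import SimpleITK as sitk

from konfai.utils.dataset import Dataset

# filename acts as a directory prefix; it is created on first write.
sitk_file = Dataset.SitkFile("./out/", read=False, file_format="mha")

sitk_file.data_to_file("volume", sitk.Image(16, 16, 8, sitk.sitkFloat32))  # -> out/volume.mha
sitk_file.data_to_file("rigid", sitk.Euler3DTransform())                   # -> out/rigid.itk.txt
sitk_file.data_to_file("landmarks", np.zeros((5, 3)))                      # -> out/landmarks.fcsv
```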
```diff
@@ -638,72 +482,85 @@ class Dataset():
                 size = list(file_reader.GetSize())
                 if len(size) == 3:
                     size = list(reversed(size))
-                size = [file_reader.GetNumberOfComponents()]+size
+                size = [file_reader.GetNumberOfComponents()] + size
             else:
                 data, attributes = self.file_to_data(group if group is not None else "", name)
                 size = data.shape
-            return
+            return size, attributes

-    class File():
+    class File:

-        def __init__(self, filename: str, read: bool,
+        def __init__(self, filename: str, read: bool, file_format: str) -> None:
            self.filename = filename
            self.read = read
-            self.file = None
-            self.format = format
+            self.file: "Dataset.AbstractFile" | None = None
+            self.file_format = file_format

-        def __enter__(self):
-            if self.format == "h5":
+        def __enter__(self) -> "Dataset.AbstractFile":
+            if self.file_format == "h5":
                 self.file = Dataset.H5File(self.filename, self.read)
             else:
-                self.file = Dataset.SitkFile(self.filename+"/", self.read, self.format)
+                self.file = Dataset.SitkFile(self.filename + "/", self.read, self.file_format)
             self.file.__enter__()
             return self.file

-        def __exit__(self, type, value, traceback):
-            self.file.__exit__(type, value, traceback)
+        def __exit__(self, exc_type, value, traceback):
+            if self.file is not None:
+                self.file.__exit__(exc_type, value, traceback)

-    def __init__(self, filename
-        if
-            filename = "{}/".format(filename)
-        self.is_directory = filename.endswith("/")
+    def __init__(self, filename: str, file_format: str) -> None:
+        if file_format != "h5" and not filename.endswith("/"):
+            filename = f"{filename}/"
+        self.is_directory = filename.endswith("/")
         self.filename = filename
-        self.format = format
-
-    def write(
+        self.file_format = file_format
+
+    def write(
+        self,
+        group: str,
+        name: str,
+        data: sitk.Image | sitk.Transform | np.ndarray,
+        attributes: Attribute | None = None,
+    ):
+        if attributes is None:
+            attributes = Attribute()
         if self.is_directory:
             if not os.path.exists(self.filename):
                 os.makedirs(self.filename)
         if self.is_directory:
             s_group = group.split("/")
             if len(s_group) > 1:
-
-                name = "{}/{}"
+                sub_directory = "/".join(s_group[:-1])
+                name = f"{sub_directory}/{name}"
                 group = s_group[-1]
-            with Dataset.File("{}{}".format(self.filename, name), False, self.format) as file:
+            with Dataset.File(f"{self.filename}{name}", False, self.file_format) as file:
                 file.data_to_file(group, data, attributes)
         else:
-            with Dataset.File(self.filename, False, self.format) as file:
-                file.data_to_file("{}/{}".format(group, name), data, attributes)
-
-    def
+            with Dataset.File(self.filename, False, self.file_format) as file:
+                file.data_to_file(f"{group}/{name}", data, attributes)
+
+    def read_data(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
         if not os.path.exists(self.filename):
-            raise NameError("Dataset {} not found".format(self.filename))
+            raise NameError(f"Dataset {self.filename} not found")
         if self.is_directory:
-            for
+            for sub_directory in self._get_sub_directories(groups):
                 group = groups.split("/")[-1]
-                if os.path.exists("{}{}{}{
-                    with Dataset.File(
+                if os.path.exists(f"{self.filename}{sub_directory}{name}{'.h5' if self.file_format == 'h5' else ''}"):
+                    with Dataset.File(
+                        f"{self.filename}{sub_directory}{name}",
+                        False,
+                        self.file_format,
+                    ) as file:
                         result = file.file_to_data("", group)
         else:
-            with Dataset.File(self.filename, False, self.format) as file:
+            with Dataset.File(self.filename, False, self.file_format) as file:
                 result = file.file_to_data(groups, name)
         return result
-
-    def
+
+    def read_transform(self, group: str, name: str) -> sitk.Transform:
         if not os.path.exists(self.filename):
-            raise NameError("Dataset {} not found".format(self.filename))
-
+            raise NameError(f"Dataset {self.filename} not found")
+        transform_parameters, attribute = self.read_data(group, name)
         transforms_type = [v for k, v in attribute.items() if k.endswith(":Transform_0")]
         transforms = []
         for i, transform_type in enumerate(transforms_type):
```
```diff
@@ -713,78 +570,92 @@ class Dataset():
                 transform = sitk.AffineTransform(3)
             if transform_type == "BSplineTransform_double_3_3":
                 transform = sitk.BSplineTransform(3)
-            transform.SetFixedParameters(
-            transform.SetParameters(tuple(
+            transform.SetFixedParameters(ast.literal_eval(attribute[f"{i}:FixedParameters"]))
+            transform.SetParameters(tuple(transform_parameters[i]))
             transforms.append(transform)
         return sitk.CompositeTransform(transforms) if len(transforms) > 1 else transforms[0]

-    def
-
-
-
-    def
-        return len(self.
-
-    def
-        return self.
-
-    def
-        return name in self.
-
-    def
+    def read_image(self, group: str, name: str):
+        data, attribute = self.read_data(group, name)
+        return data_to_image(data, attribute)
+
+    def get_size(self, group: str) -> int:
+        return len(self.get_names(group))
+
+    def is_group_exist(self, group: str) -> bool:
+        return self.get_size(group) > 0
+
+    def is_dataset_exist(self, group: str, name: str) -> bool:
+        return name in self.get_names(group)
+
+    def _get_sub_directories(self, groups: str, sub_directory: str = ""):
         group = groups.split("/")[0]
-
+        sub_directories = []
         if len(groups.split("/")) == 1:
-
+            sub_directories.append(sub_directory)
         elif group == "*":
-            for k in os.listdir("{}{}"
-                if not os.path.isfile("{}{}{}"
-
+            for k in os.listdir(f"{self.filename}{sub_directory}"):
+                if not os.path.isfile(f"{self.filename}{sub_directory}{k}"):
+                    sub_directories.extend(
+                        self._get_sub_directories(
+                            "/".join(groups.split("/")[1:]),
+                            f"{sub_directory}{k}/",
+                        )
+                    )
         else:
-
-            if os.path.exists("{}{}"
-
-        return
+            sub_directory = f"{sub_directory}{group}/"
+            if os.path.exists(f"{self.filename}{sub_directory}"):
+                sub_directories.extend(self._get_sub_directories("/".join(groups.split("/")[1:]), sub_directory))
+        return sub_directories

-    def
+    def get_names(self, groups: str, index: list[int] | None = None) -> list[str]:
         names = []
         if self.is_directory:
-            for
+            for sub_directory in self._get_sub_directories(groups):
                 group = groups.split("/")[-1]
-                if os.path.exists("{}{}"
-                    for name in sorted(os.listdir("{}{}"
-                        if os.path.isfile("{}{}{}"
-                            with Dataset.File(
-
-
+                if os.path.exists(f"{self.filename}{sub_directory}"):
+                    for name in sorted(os.listdir(f"{self.filename}{sub_directory}")):
+                        if os.path.isfile(f"{self.filename}{sub_directory}{name}") or self.file_format != "h5":
+                            with Dataset.File(
+                                f"{self.filename}{sub_directory}{name}",
+                                True,
+                                self.file_format,
+                            ) as file:
+                                if file.is_exist(group):
+                                    names.append(name.replace(".h5", "") if self.file_format == "h5" else name)
         else:
-            with Dataset.File(self.filename, True, self.format) as file:
-                names = file.
+            with Dataset.File(self.filename, True, self.file_format) as file:
+                names = file.get_names(groups)
         return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
-
-    def
+
+    def get_group(self):
         if self.is_directory:
-
+            groups_set = set()
             for root, _, files in os.walk(self.filename):
                 for file in files:
                     path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
                     parts = path.split("/")
                     if len(parts) >= 2:
                         del parts[-2]
-
+                    groups_set.add("/".join(parts))
+            groups = list(groups_set)
         else:
-            with Dataset.File(self.filename, True, self.format) as file:
-                groups =
+            with Dataset.File(self.filename, True, self.file_format) as dataset_file:
+                groups = dataset_file.get_group()
         return list(groups)
-
-    def
+
+    def get_infos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
         if self.is_directory:
-            for
+            for sub_directory in self._get_sub_directories(groups):
                 group = groups.split("/")[-1]
-                if os.path.exists("{}{}{}{
-                    with Dataset.File(
-
+                if os.path.exists(f"{self.filename}{sub_directory}{name}{'.h5' if self.file_format == 'h5' else ''}"):
+                    with Dataset.File(
+                        f"{self.filename}{sub_directory}{name}",
+                        True,
+                        self.file_format,
+                    ) as file:
+                        result = file.get_infos("", group)
         else:
-            with Dataset.File(self.filename, True, self.format) as file:
-                result = file.
-        return result
+            with Dataset.File(self.filename, True, self.file_format) as file:
+                result = file.get_infos(groups, name)
+        return result
```
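Taken together, the rewritten `Dataset` front end keys everything on `file_format`: `"h5"` selects the single-file `H5File` backend, while any other format selects the `SitkFile` directory layout. A usage sketch based only on the signatures visible in this diff (the group and patient names are illustrative):

```python
import numpy as np
import SimpleITK as sitk

from konfai.utils.dataset import Dataset

# "h5" -> one HDF5 file; e.g. "mha" or "nii.gz" -> a directory of image files.
dataset = Dataset("./data/dataset.h5", "h5")

image = sitk.GetImageFromArray(np.random.rand(8, 16, 16))
dataset.write("CT", "patient_001", image)        # stored under group "CT"

names = dataset.get_names("CT")                  # ["patient_001"]
data, attributes = dataset.read_data("CT", "patient_001")
size, attributes = dataset.get_infos("CT", "patient_001")
```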