konfai 1.0.0__py3-none-any.whl

This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of konfai has been flagged as potentially problematic.

@@ -0,0 +1,764 @@
+ import SimpleITK as sitk
+ import h5py
+ from abc import ABC, abstractmethod
+ import numpy as np
+ from typing import Any, Union
+ import copy
+ import torch
+ import os
+
+ from lxml import etree
+ import csv
+ import matplotlib.pyplot as pyplot
+ import pandas as pd
+ from KonfAI.konfai import DATE
+
+ class Plot:
+
+     def __init__(self, root: etree.ElementTree) -> None:
+         self.root = root
+
+     @staticmethod
+     def _explore(root, result, label):
+         # Leaf nodes contribute their attributes and their comma-separated
+         # payload; inner nodes are traversed recursively.
+         if len(root) == 0:
+             for attribute in root.attrib:
+                 result["attrib:" + label + ":" + attribute] = root.attrib[attribute]
+             if root.text is not None:
+                 result[label] = np.fromstring(root.text, sep=",").astype("double")
+         else:
+             for node in root:
+                 Plot._explore(node, result, label + ":" + node.tag)
+
+     @staticmethod
+     def getNodes(root, path=None, id=None):
+         nodes = []
+         if path is not None:
+             for node_name in path.split(":"):
+                 node = root.find(node_name)
+                 if node is not None:
+                     root = node
+                 else:
+                     break
+         if id is not None:
+             for node in root.findall(".//" + id):
+                 nodes.append(node)
+         else:
+             nodes.append(root)
+         return nodes
+
+     @staticmethod
+     def read(root, path=None, id=None):
+         result = dict()
+         for node in Plot.getNodes(root, path, id):
+             Plot._explore(node, result, etree.ElementTree(root).getpath(node))
+         return result
+
+     def _extract(self, ids=[], patients=[]):
+         result = dict()
+         if len(patients) == 0:
+             if len(ids) == 0:
+                 result.update(Plot.read(self.root, None, None))
+             else:
+                 for id in ids:
+                     result.update(Plot.read(self.root, None, id))
+         else:
+             for path in patients:
+                 if len(ids) == 0:
+                     result.update(Plot.read(self.root, path, None))
+                 else:
+                     for id in ids:
+                         result.update(Plot.read(self.root, path, id))
+         return result
+
+     def getErrors(self, ids=[], patients=[]):
+         results = self._extract(ids=ids, patients=patients)
+         errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
+         results: dict[str, dict[str, np.ndarray]] = {}
+         for key, error in errors.items():
+             patient = key.replace("/", ":").split(":")[2]
+             k = key.replace("/", ":").split(":")[-1]
+             # Each error entry is a flat (3*n,) vector; reshape to (n, 3) and
+             # take the per-point Euclidean norm.
+             err = np.linalg.norm(error.reshape(error.shape[0] // 3, 3), ord=2, axis=1)
+             if patient not in results:
+                 results[patient] = {k: err}
+             else:
+                 results[patient].update({k: err})
+         return results
+
+     def statistic_attrib(self, ids=[], patients=[], type: str = "HD95Mean"):
+         results = self._extract(ids=ids, patients=patients)
+
+         errors = {k.replace("attrib:", ""): float(v) for k, v in results.items() if type in k}
+
+         values = {key: np.array([]) for key in ids}
+         for key, error in errors.items():
+             k = key.replace("/", ":").split(":")[-2]
+             values[k] = np.append(values[k], error)
+
+         for k in values:
+             values[k] = np.mean(values[k])
+         print(values)
+         return values
+
+     def statistic_parameter(self, ids=[], patients=[]):
+         results = self._extract(ids=ids, patients=patients)
+         errors = {k.replace("attrib:", "").replace(":Time", ""): np.load("./Results/{}/{}.npy".format(k.split("/")[3].split(":")[0], k.split("/")[2])) for k in results.keys()}
+
+         norms = {key: np.array([]) for key in ids}
+         for key, error in errors.items():
+             k = key.replace("/", ":").split(":")[-1]
+             v = np.linalg.norm(error.reshape(error.shape[0] // 3, 3), ord=2, axis=1)
+             norms[k] = np.append(norms[k], v)
+             print(key, "{} {} {} {} {}".format(np.round(np.mean(v), 2), np.round(np.std(v), 2), np.round(np.quantile(v, 0.25), 2), np.round(np.quantile(v, 0.5), 2), np.round(np.quantile(v, 0.75), 2)))
+
+         results = {}
+         for key, values in norms.items():
+             if key == "Rigid":
+                 results.update({key: values})
+             else:
+                 # Keys look like "<name>_<iteration>" or "<name>-<iteration>".
+                 try:
+                     name = "_".join(key.split("_")[:-1])
+                     it = int(key.split("_")[-1])
+                 except ValueError:
+                     name = key.split("-")[0]
+                     it = int(key.split("-")[-1])
+
+                 if name in results:
+                     results[name].update({it: values})
+                 else:
+                     results.update({name: {it: values}})
+
+         r = []
+         for key, values in norms.items():
+             # r.append("{} $\pm$ {}".format(np.round(np.mean(values), 2), np.round(np.std(values), 2)))
+             r.append("{} {} {}".format(np.round(np.quantile(values, 0.25), 2), np.round(np.quantile(values, 0.5), 2), np.round(np.quantile(values, 0.75), 2)))
+             # r.append("{} $\pm$ {}".format(np.round(np.quantile(values, 0.5), 2), np.round(np.quantile(values, 0.75) - np.quantile(values, 0.25), 2)))
+         print(" & ".join(r))
+
+     def statistic(self, ids=[], patients=[]):
+         results = self._extract(ids=ids, patients=patients)
+         # errors = {k.replace("attrib:", "").replace(":Time", ""): np.load("./Dataset/{}/{}.npy".format(k.split("/")[3].split(":")[0], k.split("/")[2])) for k in results.keys()}
+         errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
+         print(errors)
+         norms = {key: np.array([]) for key in ids}
+         for key, error in errors.items():
+             k = key.replace("/", ":").split(":")[-1]
+             v = np.linalg.norm(error.reshape(error.shape[0] // 3, 3), ord=2, axis=1)
+             norms[k] = np.append(norms[k], v)
+             print(key, (np.mean(v), np.std(v), np.quantile(v, 0.25), np.quantile(v, 0.5), np.quantile(v, 0.75)))
+         results = {}
+         """for key, values in norms.items():
+             if key == "Rigid":
+                 results.update({key : values})
+             else:
+                 try:
+                     name = "_".join(key.split("_")[:-1])
+                     it = int(key.split("_")[-1])
+                 except ValueError:
+                     name = key.split("-")[0]
+                     it = int(key.split("-")[-1])
+
+                 if name in results:
+                     results[name].update({it : values})
+                 else:
+                     results.update({name : {it : values}})"""
+
+         print({key: (np.mean(values), np.std(values), np.quantile(values, 0.25), np.quantile(values, 0.5), np.quantile(values, 0.75)) for key, values in norms.items()})
+         return results
+
+     def plot(self, ids=[], patients=[], labels=[], colors=None):
+         results = self._extract(ids=ids, patients=patients)
+
+         attrs = {k: v for k, v in results.items() if k.startswith("attrib:")}
+         errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
+
+         patients = set()
+         max_len = 0
+         for key, error in errors.items():
+             patients.add(key.replace("/", ":").split(":")[2])
+             if max_len < error.shape[0] // 3:
+                 max_len = error.shape[0] // 3
+         patients = sorted(patients)
+
+         norms = {patient: np.array([]) for patient in patients}
+         markups = {patient: np.array([]) for patient in patients}
+         series = list()
+         for key, error in errors.items():
+             patient = key.replace("/", ":").split(":")[2]
+             # Pad every series to the same length so the boxplot columns align.
+             markup = np.full((max_len, 3), np.nan)
+             markup[0:error.shape[0] // 3, :] = error.reshape(error.shape[0] // 3, 3)
+
+             markups[patient] = np.append(markups[patient], markup)
+             norms[patient] = np.append(norms[patient], np.linalg.norm(markup, ord=2, axis=1))
+
+         if len(labels) == 0:
+             labels = list(set([k.split("/")[-1] for k in errors.keys()]))
+
+         for label in labels:
+             series = series + [label] * max_len
+
+         df = pd.DataFrame(dict([(k, pd.Series(v)) for k, v in norms.items()]))
+         df["Categories"] = pd.Series(series)
+
+         bp = df.boxplot(by="Categories", color="black", figsize=(12, 8), notch=True, layout=(1, len(patients)), fontsize=18, rot=0, patch_artist=True, return_type="both", widths=[0.5] * len(labels))
+
+         color_pallet = {"b": "paleturquoise", "g": "lightgreen"}
+         if colors is None:
+             colors = ["b"] * len(patients)
+         pyplot.suptitle("")
+         for index, (ax, row) in bp.items():
+             ax.set_xlabel("")
+             ax.set_ylim(ymin=0)
+             ax.set_ylabel("TRE (mm)", fontsize=18)
+             ax.set_yticks([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25])
+             for i, box in enumerate(row["boxes"]):
+                 box.set_edgecolor("black")
+                 box.set_facecolor(color_pallet[colors[i]])
+                 box.set_alpha(0.7)
+                 box.set_linewidth(1.0)
+
+             for median in row["medians"]:
+                 median.set_color("indianred")
+                 median.set_linewidth(2.0)
+         return self
+
+     def show(self):
+         pyplot.show()
+
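+ # A minimal usage sketch for Plot; the XML path, the "TRE" id and the
+ # patient name below are hypothetical, not part of this package:
+ #
+ #     root = etree.parse("./Results/errors.xml").getroot()
+ #     plot = Plot(root)
+ #     plot.statistic(ids=["TRE"], patients=["Patient_01"])
+ #     plot.plot(ids=["TRE"], patients=["Patient_01"]).show()
+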
+ class Attribute(dict[str, Any]):
+
+     def __init__(self, attributes: dict[str, Any] = {}) -> None:
+         super().__init__()
+         for k, v in attributes.items():
+             super().__setitem__(copy.deepcopy(k), copy.deepcopy(v))
+
+     def __getitem__(self, key: str) -> Any:
+         # Keys are versioned (key_0, key_1, ...); return the latest version.
+         i = len([k for k in super().keys() if k.startswith(key)])
+         if i > 0 and "{}_{}".format(key, i - 1) in super().keys():
+             return str(super().__getitem__("{}_{}".format(key, i - 1)))
+         else:
+             raise NameError("{} not in cache_attribute".format(key))
+
+     def __setitem__(self, key: str, value: Any) -> None:
+         # Values are stored as flat strings; tensors go through numpy first.
+         if isinstance(value, torch.Tensor):
+             result = str(value.numpy())
+         else:
+             result = str(value)
+         result = result.replace("\n", "")
+         if "_" not in key:
+             # Un-suffixed keys get a version suffix: key_0, key_1, ...
+             i = len([k for k in super().keys() if k.startswith(key)])
+             super().__setitem__("{}_{}".format(key, i), result)
+         else:
+             super().__setitem__(key, result)
+
+     def pop(self, key: str) -> Any:
+         i = len([k for k in super().keys() if k.startswith(key)])
+         if i > 0 and "{}_{}".format(key, i - 1) in super().keys():
+             return super().pop("{}_{}".format(key, i - 1))
+         else:
+             raise NameError("{} not in cache_attribute".format(key))
+
+     def get_np_array(self, key) -> np.ndarray:
+         # Stored arrays look like "[x y z]"; strip the brackets and parse.
+         return np.fromstring(self[key][1:-1], sep=" ", dtype=np.double)
+
+     def get_tensor(self, key) -> torch.Tensor:
+         return torch.tensor(self.get_np_array(key)).to(torch.float32)
+
+     def pop_np_array(self, key):
+         return np.fromstring(self.pop(key)[1:-1], sep=" ", dtype=np.double)
+
+     def pop_tensor(self, key) -> torch.Tensor:
+         return torch.tensor(self.pop_np_array(key))
+
+     def __contains__(self, key: str) -> bool:
+         return len([k for k in super().keys() if k.startswith(key)]) > 0
+
+     def isInfo(self, key: str, value: str) -> bool:
+         return key in self and self[key] == value
+
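+ # A short sketch of the versioned-key behaviour (illustrative values):
+ #
+ #     attrs = Attribute()
+ #     attrs["Spacing"] = np.array([1.0, 1.0, 1.0])   # stored as "Spacing_0"
+ #     attrs["Spacing"] = np.array([2.0, 2.0, 2.0])   # stored as "Spacing_1"
+ #     attrs["Spacing"]               # -> latest value, as a string
+ #     attrs.get_np_array("Spacing")  # -> array([2., 2., 2.])
+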
+ def isAnImage(attributes: Attribute) -> bool:
+     return "Origin" in attributes and "Spacing" in attributes and "Direction" in attributes
+
+ def data_to_image(data: np.ndarray, attributes: Attribute) -> sitk.Image:
+     if not isAnImage(attributes):
+         raise NameError("Data is not an image")
+     if data.shape[0] == 1:
+         image = sitk.GetImageFromArray(data[0])
+     else:
+         # Move the channel axis from first to last for SimpleITK vector images.
+         data = data.transpose(tuple([i + 1 for i in range(len(data.shape) - 1)] + [0]))
+         image = sitk.GetImageFromArray(data, isVector=True)
+     image.SetOrigin(attributes.get_np_array("Origin").tolist())
+     image.SetSpacing(attributes.get_np_array("Spacing").tolist())
+     image.SetDirection(attributes.get_np_array("Direction").tolist())
+     return image
+
+ def image_to_data(image: sitk.Image) -> tuple[np.ndarray, Attribute]:
+     attributes = Attribute()
+     attributes["Origin"] = np.asarray(image.GetOrigin())
+     attributes["Spacing"] = np.asarray(image.GetSpacing())
+     attributes["Direction"] = np.asarray(image.GetDirection())
+     data = sitk.GetArrayFromImage(image)
+
+     if image.GetNumberOfComponentsPerPixel() == 1:
+         data = np.expand_dims(data, 0)
+     else:
+         # Move the channel axis from last to first.
+         data = np.transpose(data, (len(data.shape) - 1, *[i for i in range(len(data.shape) - 1)]))
+     return data, attributes
+
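+ # Round-trip sketch: image_to_data and data_to_image are inverses for the
+ # geometric metadata (the input file is hypothetical):
+ #
+ #     image = sitk.ReadImage("volume.mha")
+ #     data, attrs = image_to_data(image)   # data has shape (C, Z, Y, X)
+ #     restored = data_to_image(data, attrs)
+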
+ class Dataset:
+
+     class AbstractFile(ABC):
+
+         def __init__(self) -> None:
+             pass
+
+         def __enter__(self):
+             pass
+
+         def __exit__(self, type, value, traceback):
+             pass
+
+         @abstractmethod
+         def file_to_data(self):
+             pass
+
+         @abstractmethod
+         def data_to_file(self):
+             pass
+
+         @abstractmethod
+         def getNames(self, group: str) -> list[str]:
+             pass
+
+         @abstractmethod
+         def isExist(self, group: str, name: Union[str, None] = None) -> bool:
+             pass
+
+         @abstractmethod
+         def getInfos(self, group: Union[str, None], name: str) -> tuple[list[int], Attribute]:
+             pass
+
+     class H5File(AbstractFile):
+
+         def __init__(self, filename: str, read: bool) -> None:
+             self.h5: Union[h5py.File, None] = None
+             self.filename = filename
+             if not self.filename.endswith(".h5"):
+                 self.filename += ".h5"
+             self.read = read
+
+         def __enter__(self):
+             if self.read:
+                 self.h5 = h5py.File(self.filename, "r")
+             else:
+                 if not os.path.exists(self.filename):
+                     directory = os.path.dirname(self.filename)
+                     if directory != "" and not os.path.exists(directory):
+                         os.makedirs(directory)
+                     self.h5 = h5py.File(self.filename, "w")
+                 else:
+                     self.h5 = h5py.File(self.filename, "r+")
+                 self.h5.attrs["Date"] = DATE()
+             return self.h5
+
+         def __exit__(self, type, value, traceback):
+             if self.h5 is not None:
+                 self.h5.close()
+
+         def file_to_data(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
+             dataset = self._getDataset(groups, name)
+             data = np.zeros(dataset.shape, dataset.dtype)
+             dataset.read_direct(data)
+             return data, Attribute({k: str(v) for k, v in dataset.attrs.items()})
+
+         def data_to_file(self, name: str, data: Union[sitk.Image, sitk.Transform, np.ndarray], attributes: Union[Attribute, None] = None) -> None:
+             if attributes is None:
+                 attributes = Attribute()
+             if isinstance(data, sitk.Image):
+                 data, attributes_tmp = image_to_data(data)
+                 attributes.update(attributes_tmp)
+             elif isinstance(data, sitk.Transform):
+                 # Transforms are flattened to their parameter vectors; the
+                 # transform type and fixed parameters go into the attributes.
+                 transforms = []
+                 if isinstance(data, sitk.CompositeTransform):
+                     for i in range(data.GetNumberOfTransforms()):
+                         transforms.append(data.GetNthTransform(i))
+                 else:
+                     transforms.append(data)
+                 datas = []
+                 for i, transform in enumerate(transforms):
+                     if isinstance(transform, sitk.Euler3DTransform):
+                         transform_type = "Euler3DTransform_double_3_3"
+                     elif isinstance(transform, sitk.AffineTransform):
+                         transform_type = "AffineTransform_double_3_3"
+                     elif isinstance(transform, sitk.BSplineTransform):
+                         transform_type = "BSplineTransform_double_3_3"
+                     else:
+                         raise NameError("Transform {} not supported".format(type(transform)))
+                     attributes["{}:Transform".format(i)] = transform_type
+                     attributes["{}:FixedParameters".format(i)] = transform.GetFixedParameters()
+
+                     datas.append(np.asarray(transform.GetParameters()))
+                 data = np.asarray(datas)
+
+             h5_group = self.h5
+             if len(name.split("/")) > 1:
+                 group = "/".join(name.split("/")[:-1])
+                 if group not in self.h5:
+                     self.h5.create_group(group)
+                 h5_group = self.h5[group]
+
+             name = name.split("/")[-1]
+             if name in h5_group:
+                 del h5_group[name]
+
+             dataset = h5_group.create_dataset(name, data=data, dtype=data.dtype, chunks=None)
+             dataset.attrs.update({k: str(v) for k, v in attributes.items()})
+
+         def isExist(self, group: str, name: Union[str, None] = None) -> bool:
+             if group in self.h5:
+                 if isinstance(self.h5[group], h5py.Dataset):
+                     return True
+                 elif name is not None:
+                     return name in self.h5[group]
+                 else:
+                     return False
+             return False
+
+         def getNames(self, groups: str, h5_group: Union[h5py.Group, None] = None) -> list[str]:
+             # "groups" is a "/"-separated path; "*" matches any sub-group.
+             names = []
+             if h5_group is None:
+                 h5_group = self.h5
+             group = groups.split("/")[0]
+             if group == "":
+                 names = [dataset.name.split("/")[-1] for dataset in h5_group.values() if isinstance(dataset, h5py.Dataset)]
+             elif group == "*":
+                 for k in h5_group.keys():
+                     if isinstance(h5_group[k], h5py.Group):
+                         names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[k]))
+             else:
+                 if group in h5_group:
+                     names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
+             return names
+
+         def _getDataset(self, groups: str, name: str, h5_group: Union[h5py.Group, None] = None) -> h5py.Dataset:
+             if h5_group is None:
+                 h5_group = self.h5
+             if groups != "":
+                 group = groups.split("/")[0]
+             else:
+                 group = ""
+             result = None
+             if group == "":
+                 if name in h5_group:
+                     result = h5_group[name]
+             elif group == "*":
+                 for k in h5_group.keys():
+                     if isinstance(h5_group[k], h5py.Group):
+                         result_tmp = self._getDataset("/".join(groups.split("/")[1:]), name, h5_group[k])
+                         if result_tmp is not None:
+                             result = result_tmp
+             else:
+                 if group in h5_group:
+                     result_tmp = self._getDataset("/".join(groups.split("/")[1:]), name, h5_group[group])
+                     if result_tmp is not None:
+                         result = result_tmp
+             return result
+
+         def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
+             dataset = self._getDataset(groups, name)
+             return (dataset.shape, Attribute({k: str(v) for k, v in dataset.attrs.items()}))
+
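+ # Group-path sketch: given a hypothetical HDF5 layout "Patient_01/CT/...",
+ # getNames("*/CT") lists the datasets under every patient's "CT" group, and
+ # _getDataset("*/CT", "volume") finds a dataset named "volume" under any
+ # patient.
+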
488
+ class SitkFile(AbstractFile):
489
+
490
+ def __init__(self, filename: str, read: bool, format: str) -> None:
491
+ self.filename = filename
492
+ self.read = read
493
+ self.format = format
494
+
495
+ def file_to_data(self, group: str, name: str) -> tuple[np.ndarray, Attribute]:
496
+ attributes = Attribute()
497
+ if os.path.exists("{}{}.{}".format(self.filename, name, self.format)):
498
+ image = sitk.ReadImage("{}{}.{}".format(self.filename, name, self.format))
499
+ data, attributes_tmp = image_to_data(image)
500
+ attributes.update(attributes_tmp)
501
+ elif os.path.exists("{}{}.itk.txt".format(self.filename, name)):
502
+ data = sitk.ReadTransform("{}{}.itk.txt".format(self.filename, name))
503
+ transforms = []
504
+ if isinstance(data, sitk.CompositeTransform):
505
+ for i in range(data.GetNumberOfTransforms()):
506
+ transforms.append(data.GetNthTransform(i))
507
+ else:
508
+ transforms.append(data)
509
+ datas = []
510
+ for i, transform in enumerate(transforms):
511
+ if isinstance(transform, sitk.Euler3DTransform):
512
+ transform_type = "Euler3DTransform_double_3_3"
513
+ if isinstance(transform, sitk.AffineTransform):
514
+ transform_type = "AffineTransform_double_3_3"
515
+ if isinstance(transform, sitk.BSplineTransform):
516
+ transform_type = "BSplineTransform_double_3_3"
517
+ attributes["{}:Transform".format(i)] = transform_type
518
+ attributes["{}:FixedParameters".format(i)] = transform.GetFixedParameters()
519
+
520
+ datas.append(np.asarray(transform.GetParameters()))
521
+ data = np.asarray(datas)
522
+ elif os.path.exists("{}{}.fcsv".format(self.filename, name)):
523
+ with open("{}{}.fcsv".format(self.filename, name), newline="") as csvfile:
524
+ reader = csv.reader(filter(lambda row: row[0]!='#', csvfile))
525
+ lines = list(reader)
526
+ data = np.zeros((len(list(lines)), 3), dtype=np.double)
527
+ for i, row in enumerate(lines):
528
+ data[i] = np.array(row[1:4], dtype=np.double)
529
+ csvfile.close()
530
+ elif os.path.exists("{}{}.xml".format(self.filename, name)):
531
+ with open("{}{}.xml".format(self.filename, name), 'rb') as xml_file:
532
+ result = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()
533
+ xml_file.close()
534
+ return result
535
+ elif os.path.exists("{}{}.vtk".format(self.filename, name)):
536
+ import vtk
537
+ vtkReader = vtk.vtkPolyDataReader()
538
+ vtkReader.SetFileName("{}{}.vtk".format(self.filename, name))
539
+ vtkReader.Update()
540
+ data = []
541
+ points = vtkReader.GetOutput().GetPoints()
542
+ num_points = points.GetNumberOfPoints()
543
+ for i in range(num_points):
544
+ data.append(list(points.GetPoint(i)))
545
+ data = np.asarray(data)
546
+ elif os.path.exists("{}{}.npy".format(self.filename, name)):
547
+ data = np.load("{}{}.npy".format(self.filename, name))
548
+ return data, attributes
549
+
+         def is_vtk_polydata(self, obj):
+             # vtk is an optional dependency; without it nothing can be a
+             # vtkPolyData instance.
+             try:
+                 import vtk
+                 return isinstance(obj, vtk.vtkPolyData)
+             except ImportError:
+                 return False
+
+         def data_to_file(self, name: str, data: Union[sitk.Image, sitk.Transform, np.ndarray], attributes: Union[Attribute, None] = None) -> None:
+             if attributes is None:
+                 attributes = Attribute()
+             if not os.path.exists(self.filename):
+                 os.makedirs(self.filename)
+             if isinstance(data, sitk.Image):
+                 for k, v in attributes.items():
+                     data.SetMetaData(k, v)
+                 sitk.WriteImage(data, "{}{}.{}".format(self.filename, name, self.format))
+             elif isinstance(data, sitk.Transform):
+                 sitk.WriteTransform(data, "{}{}.itk.txt".format(self.filename, name))
+             elif self.is_vtk_polydata(data):
+                 import vtk
+                 vtkWriter = vtk.vtkPolyDataWriter()
+                 vtkWriter.SetFileName("{}{}.vtk".format(self.filename, name))
+                 vtkWriter.SetInputData(data)
+                 vtkWriter.Write()
+             elif isAnImage(attributes):
+                 self.data_to_file(name, data_to_image(data, attributes), attributes)
+             elif len(data.shape) == 2 and data.shape[1] == 3 and data.shape[0] > 0:
+                 # (n, 3) point sets are written as Slicer fiducial files.
+                 data = np.round(data, 4)
+                 with open("{}{}.fcsv".format(self.filename, name), "w") as f:
+                     f.write("# Markups fiducial file version = 4.6\n# CoordinateSystem = 0\n# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n")
+                     for i in range(data.shape[0]):
+                         f.write("vtkMRMLMarkupsFiducialNode_" + str(i + 1) + "," + str(data[i, 0]) + "," + str(data[i, 1]) + "," + str(data[i, 2]) + ",0,0,0,1,1,1,0,F-" + str(i + 1) + ",,vtkMRMLScalarVolumeNode1\n")
+             elif "path" in attributes:
+                 if os.path.exists("{}{}.xml".format(self.filename, name)):
+                     with open("{}{}.xml".format(self.filename, name), "rb") as xml_file:
+                         root = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()
+                 else:
+                     root = etree.Element(name)
+                 node = root
+                 path = attributes["path"].split(":")
+
+                 for node_name in path:
+                     node_tmp = node.find(node_name)
+                     if node_tmp is None:
+                         node_tmp = etree.SubElement(node, node_name)
+                     node = node_tmp
+                 for attribute_tmp in attributes.keys():
+                     attribute = "_".join(attribute_tmp.split("_")[:-1])
+                     if attribute != "path":
+                         node.set(attribute, attributes[attribute])
+                 if data.size > 0:
+                     node.text = ", ".join(map(str, data.flatten()))
+                 with open("{}{}.xml".format(self.filename, name), "wb") as f:
+                     f.write(etree.tostring(root, pretty_print=True, encoding="utf-8"))
+             else:
+                 np.save("{}{}.npy".format(self.filename, name), data)
+
+         def isExist(self, group: str, name: Union[str, None] = None) -> bool:
+             return os.path.exists("{}{}.{}".format(self.filename, group, self.format)) or os.path.exists("{}{}.itk.txt".format(self.filename, group)) or os.path.exists("{}{}.fcsv".format(self.filename, group)) or os.path.exists("{}{}.npy".format(self.filename, group))
+
+         def getNames(self, group: str) -> list[str]:
+             raise NotImplementedError()
+
+         def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
+             attributes = Attribute()
+             if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):
+                 # Read only the header, not the pixel data.
+                 file_reader = sitk.ImageFileReader()
+                 file_reader.SetFileName("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format))
+                 file_reader.ReadImageInformation()
+                 attributes["Origin"] = np.asarray(file_reader.GetOrigin())
+                 attributes["Spacing"] = np.asarray(file_reader.GetSpacing())
+                 attributes["Direction"] = np.asarray(file_reader.GetDirection())
+                 for k in file_reader.GetMetaDataKeys():
+                     attributes[k] = file_reader.GetMetaData(k)
+                 size = list(file_reader.GetSize())
+                 if len(size) == 3:
+                     size = list(reversed(size))
+                 size = [file_reader.GetNumberOfComponents()] + size
+             else:
+                 data, attributes = self.file_to_data(group if group is not None else "", name)
+                 size = data.shape
+             return tuple(size), attributes
+
+     class File(ABC):
+
+         def __init__(self, filename: str, read: bool, format: str) -> None:
+             self.filename = filename
+             self.read = read
+             self.file = None
+             self.format = format
+
+         def __enter__(self):
+             # Pick the concrete backend from the requested format.
+             if self.format == "h5":
+                 self.file = Dataset.H5File(self.filename, self.read)
+             else:
+                 self.file = Dataset.SitkFile(self.filename + "/", self.read, self.format)
+             self.file.__enter__()
+             return self.file
+
+         def __exit__(self, type, value, traceback):
+             self.file.__exit__(type, value, traceback)
+
+     def __init__(self, filename: str, format: str) -> None:
+         if format != "h5" and not filename.endswith("/"):
+             filename = "{}/".format(filename)
+         self.is_directory = filename.endswith("/")
+         self.filename = filename
+         self.format = format
+
+     def write(self, group: str, name: str, data: Union[sitk.Image, sitk.Transform, np.ndarray], attributes: Union[Attribute, None] = None):
+         if self.is_directory:
+             if not os.path.exists(self.filename):
+                 os.makedirs(self.filename)
+             s_group = group.split("/")
+             if len(s_group) > 1:
+                 subDirectory = "/".join(s_group[:-1])
+                 name = "{}/{}".format(subDirectory, name)
+                 group = s_group[-1]
+             with Dataset.File("{}{}".format(self.filename, name), False, self.format) as file:
+                 file.data_to_file(group, data, attributes)
+         else:
+             with Dataset.File(self.filename, False, self.format) as file:
+                 file.data_to_file("{}/{}".format(group, name), data, attributes)
+
+     def readData(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
+         if not os.path.exists(self.filename):
+             raise NameError("Dataset {} not found".format(self.filename))
+         if self.is_directory:
+             for subDirectory in self._getSubDirectories(groups):
+                 group = groups.split("/")[-1]
+                 if os.path.exists("{}{}{}{}".format(self.filename, subDirectory, name, ".h5" if self.format == "h5" else "")):
+                     with Dataset.File("{}{}{}".format(self.filename, subDirectory, name), True, self.format) as file:
+                         result = file.file_to_data("", group)
+         else:
+             with Dataset.File(self.filename, True, self.format) as file:
+                 result = file.file_to_data(groups, name)
+         return result
+
+     def readTransform(self, group: str, name: str) -> sitk.Transform:
+         if not os.path.exists(self.filename):
+             raise NameError("Dataset {} not found".format(self.filename))
+         transformParameters, attribute = self.readData(group, name)
+         transforms_type = [v for k, v in attribute.items() if k.endswith(":Transform_0")]
+         transforms = []
+         for i, transform_type in enumerate(transforms_type):
+             if transform_type == "Euler3DTransform_double_3_3":
+                 transform = sitk.Euler3DTransform()
+             elif transform_type == "AffineTransform_double_3_3":
+                 transform = sitk.AffineTransform(3)
+             elif transform_type == "BSplineTransform_double_3_3":
+                 transform = sitk.BSplineTransform(3)
+             else:
+                 raise NameError("Transform {} not supported".format(transform_type))
+             transform.SetFixedParameters(eval(attribute["{}:FixedParameters".format(i)]))
+             transform.SetParameters(tuple(transformParameters[i]))
+             transforms.append(transform)
+         return sitk.CompositeTransform(transforms) if len(transforms) > 1 else transforms[0]
+
+     def readImage(self, group: str, name: str):
+         data, attribute = self.readData(group, name)
+         return data_to_image(data, attribute)
+
+     def getSize(self, group: str) -> int:
+         return len(self.getNames(group))
+
+     def isGroupExist(self, group: str) -> bool:
+         return self.getSize(group) > 0
+
+     def isDatasetExist(self, group: str, name: str) -> bool:
+         return name in self.getNames(group)
+
+     def _getSubDirectories(self, groups: str, subDirectory: str = ""):
+         # Expand a "/"-separated group path into the matching sub-directories;
+         # "*" matches every directory at that level.
+         group = groups.split("/")[0]
+         subDirectories = []
+         if len(groups.split("/")) == 1:
+             subDirectories.append(subDirectory)
+         elif group == "*":
+             for k in os.listdir("{}{}".format(self.filename, subDirectory)):
+                 if not os.path.isfile("{}{}{}".format(self.filename, subDirectory, k)):
+                     subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), "{}{}/".format(subDirectory, k)))
+         else:
+             subDirectory = "{}{}/".format(subDirectory, group)
+             if os.path.exists("{}{}".format(self.filename, subDirectory)):
+                 subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
+         return subDirectories
+
+     def getNames(self, groups: str, index: Union[list[int], None] = None, subDirectory: str = "") -> list[str]:
+         names = []
+         if self.is_directory:
+             for subDirectory in self._getSubDirectories(groups):
+                 group = groups.split("/")[-1]
+                 if os.path.exists("{}{}".format(self.filename, subDirectory)):
+                     for name in sorted(os.listdir("{}{}".format(self.filename, subDirectory))):
+                         if os.path.isfile("{}{}{}".format(self.filename, subDirectory, name)) or self.format != "h5":
+                             with Dataset.File("{}{}{}".format(self.filename, subDirectory, name), True, self.format) as file:
+                                 if file.isExist(group):
+                                     names.append(name.replace(".h5", "") if self.format == "h5" else name)
+         else:
+             with Dataset.File(self.filename, True, self.format) as file:
+                 names = file.getNames(groups)
+         return [name for i, name in enumerate(names) if index is None or i in index]
+
+     def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
+         if self.is_directory:
+             for subDirectory in self._getSubDirectories(groups):
+                 group = groups.split("/")[-1]
+                 if os.path.exists("{}{}{}{}".format(self.filename, subDirectory, name, ".h5" if self.format == "h5" else "")):
+                     with Dataset.File("{}{}{}".format(self.filename, subDirectory, name), True, self.format) as file:
+                         result = file.getInfos("", group)
+         else:
+             with Dataset.File(self.filename, True, self.format) as file:
+                 result = file.getInfos(groups, name)
+         return result