konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
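
Before the file-level diff, a note on the pattern that repeats across all 36 files: camelCase methods are renamed to snake_case, typing.Union annotations are replaced by PEP 604 unions (X | None), "{}".format() calls become f-strings, mutable default arguments become None with an in-body default, and eval-style parsing is hardened. A minimal migration sketch for callers of the dataset API diffed below, assuming a hypothetical Dataset.h5 with an "images" group (only renames visible in this diff are used):

    from konfai.utils.dataset import Dataset

    dataset = Dataset("Dataset.h5", "h5")  # hypothetical file and group names

    # 1.1.8: getNames / readData / getInfos
    # 1.2.0:
    names = dataset.get_names("images")
    data, attributes = dataset.read_data("images", names[0])
    size, attributes = dataset.get_infos("images", names[0])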
konfai/utils/dataset.py CHANGED
@@ -1,255 +1,33 @@
- import SimpleITK as sitk
- import h5py
- from abc import ABC, abstractmethod
- import numpy as np
- from typing import Any, Union
+ import ast
  import copy
- import torch
- import os
-
- from lxml import etree
  import csv
- from konfai import DATE
-
- class Plot():
-
- def __init__(self, root: etree.ElementTree) -> None:
- self.root = root
+ import os
+ from abc import ABC, abstractmethod
+ from typing import Any
 
- def _explore(root, result, label):
- if len(root) == 0:
- for attribute in root.attrib:
- result["attrib:"+label+":"+attribute] = root.attrib[attribute]
- if root.text is not None:
- result[label] = np.fromstring(root.text, sep = ",").astype('double')
- else:
- for node in root:
- Plot._explore(node, result, label+":"+node.tag)
-
- def getNodes(root, path = None, id = None):
- nodes = []
- if path != None:
- path = path.split(":")
- for node_name in path:
- node = root.find(node_name)
- if node != None:
- root = node
- else:
- break
- if id != None:
- for node in root.findall(".//"+id):
- nodes.append(node)
- else:
- nodes.append(root)
- return nodes
+ import h5py
+ import numpy as np
+ import SimpleITK as sitk # noqa: N813
+ import torch
+ from lxml import etree # nosec B410
 
- def read(root, path = None, id = None):
- result = dict()
- for node in Plot.getNodes(root, path, id):
- Plot._explore(node, result, etree.ElementTree(root).getpath(node))
- return result
+ from konfai import current_date
 
- def _extract(self, ids = [], patients = []):
- result = dict()
- if len(patients) == 0:
- if len(ids) == 0:
- result.update(Plot.read(self.root,None, None))
- else:
- for id in ids:
- result.update(Plot.read(self.root, None, id))
- else:
- for path in patients:
- if len(ids) == 0:
- result.update(Plot.read(self.root, path, None))
- else:
- for id in ids:
- result.update(Plot.read(self.root, path, id))
- return result
-
- def getErrors(self, ids = [], patients = []):
- results = self._extract(ids=ids, patients=patients)
- errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
- results : dict[str, dict[str, np.ndarray]]= {}
- for key, error in errors.items():
- patient = key.replace("/",":").split(":")[2]
- k = key.replace("/",":").split(":")[-1]
- err = np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1)
- if patient not in results:
- results[patient] = {k : err}
- else:
- results[patient].update({k : err})
- return results
-
- def statistic_attrib(self, ids = [], patients = [], type: str = "HD95Mean"):
- results = self._extract(ids=ids, patients=patients)
-
- errors = {k.replace("attrib:", ""): float(v) for k, v in results.items() if type in k}
-
- values = {key : np.array([]) for key in ids}
- for key, error in errors.items():
- k = key.replace("/",":").split(":")[-2]
- values[k] = np.append(values[k], error)
-
- for k in values:
- values[k] = np.mean(values[k])
- print(values)
- return values
-
- def statistic_parameter(self, ids = [], patients = []):
- results = self._extract(ids=ids, patients=patients)
- errors = {k.replace("attrib:", "").replace(":Time", "") : np.load("./Results/{}/{}.npy".format(k.split("/")[3].split(":")[0], k.split("/")[2])) for k in results.keys()}
-
- norms = {key : np.array([]) for key in ids}
- max = 0
- for key, error in errors.items():
- if max < int(error.shape[0]/3):
- max = int(error.shape[0]/3)
-
- for key, error in errors.items():
- k = key.replace("/",":").split(":")[-1]
- norms[k] = np.append(norms[k], np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1))
- v = np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1)
-
- print(key, "{} {} {} {} {}".format(np.round(np.mean(v), 2), np.round(np.std(v), 2), np.round(np.quantile(v, 0.25), 2), np.round(np.quantile(v, 0.5), 2), np.round(np.quantile(v, 0.75), 2)))
- results = {}
- for key, values in norms.items():
- if key == "Rigid":
- results.update({key : values})
- else:
- try:
- name = "{}".format("_".join(key.split("_"))[:-1])
- it = int(key.split("_")[-1])
- except:
- name = "{}".format(key.split("-")[0])
- it = int(key.split("-")[-1])
-
- if name in results:
- results[name].update({it : values})
- else:
- results.update({name : {it : values}})
-
- r = []
- for key, values in norms.items():
- #r.append("{} $\pm$ {}".format(np.round(np.mean(values), 2), np.round(np.std(values), 2)))
- r.append("{} {} {}".format(np.round(np.quantile(values, 0.25), 2), np.round(np.quantile(values, 0.5),2), np.round(np.quantile(values, 0.75), 2)))
- #r.append("{} $\pm$ {}".format(np.round(np.quantile(values, 0.5), 2), np.round(np.quantile(values, 0.75)-np.quantile(values, 0.25), 2)))
- print(" & ".join(r))
-
- def statistic(self, ids = [], patients = []):
- results = self._extract(ids=ids, patients=patients)
- #errors = {k.replace("attrib:", "").replace(":Time", "") : np.load("./Dataset/{}/{}.npy".format(k.split("/")[3].split(":")[0], k.split("/")[2])) for k in results.keys()}
- errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
- print(errors)
- norms = {key : np.array([]) for key in ids}
- max = 0
- for key, error in errors.items():
- if max < int(error.shape[0]/3):
- max = int(error.shape[0]/3)
-
- for key, error in errors.items():
- k = key.replace("/",":").split(":")[-1]
- norms[k] = np.append(norms[k], np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1))
- v = np.linalg.norm(error.reshape(int(error.shape[0]/3),3), ord=2, axis=1)
- print(key, (np.mean(v), np.std(v), np.quantile(v, 0.25), np.quantile(v, 0.5), np.quantile(v, 0.75)) )
- results = {}
- """for key, values in norms.items():
- if key == "Rigid":
- results.update({key : values})
- else:
- try:
- name = "{}".format("_".join(key.split("_"))[:-1])
- it = int(key.split("_")[-1])
- except:
- name = "{}".format(key.split("-")[0])
- it = int(key.split("-")[-1])
-
- if name in results:
- results[name].update({it : values})
- else:
- results.update({name : {it : values}})"""
-
-
- print({key: (np.mean(values), np.std(values), np.quantile(values, 0.25), np.quantile(values, 0.5), np.quantile(values, 0.75)) for key, values in norms.items()})
- return results
-
- def plot(self, ids = [], patients = [], labels = [], colors = None):
-
- import matplotlib.pyplot as pyplot
- results = self._extract(ids=ids, patients=patients)
-
- attrs = {k: v for k, v in results.items() if k.startswith("attrib:")}
- errors = {k: v for k, v in results.items() if not k.startswith("attrib:")}
-
- patients = set()
- max = 0
- for key, error in errors.items():
- patients.add(key.replace("/",":").split(":")[2])
- if max < int(error.shape[0]/3):
- max = int(error.shape[0]/3)
- patients = sorted(patients)
-
- norms = {patient : np.array([]) for patient in patients}
- markups = {patient : np.array([]) for patient in patients}
- series = list()
- for key, error in errors.items():
- patient = key.replace("/",":").split(":")[2]
- markup = np.full((max,3), np.nan)
- markup[0:int(error.shape[0]/3), :] = error.reshape(int(error.shape[0]/3),3)
-
- markups[patient] = np.append(markups[patient], markup)
- norms[patient] = np.append(norms[patient], np.linalg.norm(markup, ord=2, axis=1))
-
- if len(labels) == 0:
- labels = list(set([k.split("/")[-1] for k in errors.keys()]))
-
- for label in labels:
- series = series+[label]*max
- import pandas as pd
- df = pd.DataFrame(dict([(k,pd.Series(v)) for k, v in norms.items()]))
- df['Categories'] = pd.Series(series)
-
- bp = df.boxplot(by='Categories', color="black", figsize=(12,8), notch=True,layout=(1,len(patients)), fontsize=18, rot=0, patch_artist = True, return_type='both', widths=[0.5]*len(labels))
-
- color_pallet = {"b" : "paleturquoise", "g" : "lightgreen"}
- if colors == None:
- colors = ["b"] * len(patients)
- pyplot.suptitle('')
- it_1 = 0
- for index, (ax,row) in bp.items():
- ax.set_xlabel('')
- ax.set_ylim(ymin=0)
- ax.set_ylabel("TRE (mm)", fontsize=18)
- ax.set_yticks([0,1,2,3,4,5,6,7,8,9,10,15,20,25]) # Set label locations.
- for i,object in enumerate(row["boxes"]):
- object.set_edgecolor("black")
- object.set_facecolor(color_pallet[colors[i]])
- object.set_alpha(0.7)
- object.set_linewidth(1.0)
-
- for i,object in enumerate(row["medians"]):
- object.set_color("indianred")
- xy = object.get_xydata()
- object.set_linewidth(2.0)
- it_1+=1
- return self
-
- def show(self):
- import matplotlib.pyplot as pyplot
- pyplot.show()
 
  class Attribute(dict[str, Any]):
 
- def __init__(self, attributes : dict[str, Any] = {}) -> None:
+ def __init__(self, attributes: dict[str, Any] | None = None) -> None:
  super().__init__()
+ attributes = attributes or {}
  for k, v in attributes.items():
  super().__setitem__(copy.deepcopy(k), copy.deepcopy(v))
-
+
  def __getitem__(self, key: str) -> Any:
  i = len([k for k in super().keys() if k.startswith(key)])
- if i > 0 and "{}_{}".format(key, i-1) in super().keys():
- return str(super().__getitem__("{}_{}".format(key, i-1)))
+ if i > 0 and f"{key}_{i - 1}" in super().keys():
+ return str(super().__getitem__(f"{key}_{i - 1}"))
  else:
- raise NameError("{} not in cache_attribute".format(key))
+ raise NameError(f"{key} not in cache_attribute")
 
  def __setitem__(self, key: str, value: Any) -> None:
  if "_" not in key:
@@ -259,58 +37,63 @@ class Attribute(dict[str, Any]):
  result = str(value.numpy())
  else:
  result = str(value)
- result = result.replace('\n', '')
- super().__setitem__("{}_{}".format(key, i), result)
+ result = result.replace("\n", "")
+ super().__setitem__(f"{key}_{i}", result)
  else:
  result = None
  if isinstance(value, torch.Tensor):
  result = str(value.numpy())
  else:
  result = str(value)
- result = result.replace('\n', '')
+ result = result.replace("\n", "")
  super().__setitem__(key, result)
 
- def pop(self, key: str) -> Any:
+ def pop(self, key: str, default: Any = None) -> Any:
  i = len([k for k in super().keys() if k.startswith(key)])
- if i > 0 and "{}_{}".format(key, i-1) in super().keys():
- return super().pop("{}_{}".format(key, i-1))
+ if i > 0 and f"{key}_{i - 1}" in super().keys():
+ return super().pop(f"{key}_{i - 1}")
  else:
- raise NameError("{} not in cache_attribute".format(key))
+ raise NameError(f"{key} not in cache_attribute")
 
  def get_np_array(self, key) -> np.ndarray:
  return np.fromstring(self[key][1:-1], sep=" ", dtype=np.double)
-
+
  def get_tensor(self, key) -> torch.Tensor:
  return torch.tensor(self.get_np_array(key)).to(torch.float32)
-
+
  def pop_np_array(self, key):
  return np.fromstring(self.pop(key)[1:-1], sep=" ", dtype=np.double)
-
+
  def pop_tensor(self, key) -> torch.Tensor:
  return torch.tensor(self.pop_np_array(key))
-
- def __contains__(self, key: str) -> bool:
- return len([k for k in super().keys() if k.startswith(key)]) > 0
-
- def isInfo(self, key: str, value: str) -> bool:
+
+ def __contains__(self, key: object) -> bool:
+ if not isinstance(key, str):
+ return False
+ return any(k.startswith(key) for k in super().keys())
+
+ def is_info(self, key: str, value: str) -> bool:
  return key in self and self[key] == value
 
- def isAnImage(attributes: Attribute):
+
+ def is_an_image(attributes: Attribute):
  return "Origin" in attributes and "Spacing" in attributes and "Direction" in attributes
 
- def data_to_image(data : np.ndarray, attributes: Attribute) -> sitk.Image:
- if not isAnImage(attributes):
+
+ def data_to_image(data: np.ndarray, attributes: Attribute) -> sitk.Image:
+ if not is_an_image(attributes):
  raise NameError("Data is not an image")
  if data.shape[0] == 1:
  image = sitk.GetImageFromArray(data[0])
  else:
- data = data.transpose(tuple([i+1 for i in range(len(data.shape)-1)]+[0]))
+ data = data.transpose(tuple([i + 1 for i in range(len(data.shape) - 1)] + [0]))
  image = sitk.GetImageFromArray(data, isVector=True)
  image.SetOrigin(attributes.get_np_array("Origin").tolist())
  image.SetSpacing(attributes.get_np_array("Spacing").tolist())
  image.SetDirection(attributes.get_np_array("Direction").tolist())
  return image
 
+
  def image_to_data(image: sitk.Image) -> tuple[np.ndarray, Attribute]:
  attributes = Attribute()
  attributes["Origin"] = np.asarray(image.GetOrigin())
@@ -321,81 +104,98 @@ def image_to_data(image: sitk.Image) -> tuple[np.ndarray, Attribute]:
  if image.GetNumberOfComponentsPerPixel() == 1:
  data = np.expand_dims(data, 0)
  else:
- data = np.transpose(data, (len(data.shape)-1, *[i for i in range(len(data.shape)-1)]))
+ data = np.transpose(data, (len(data.shape) - 1, *list(range(len(data.shape) - 1))))
  return data, attributes
 
- class Dataset():
+
+ class Dataset:
 
  class AbstractFile(ABC):
 
+ @abstractmethod
  def __init__(self) -> None:
  pass
-
+
+ @abstractmethod
  def __enter__(self):
  pass
 
- def __exit__(self, type, value, traceback):
+ @abstractmethod
+ def __exit__(self, exc_type, value, traceback):
  pass
 
  @abstractmethod
- def file_to_data(self):
+ def file_to_data(self, group: str, name: str) -> tuple[np.ndarray, Attribute]:
  pass
 
  @abstractmethod
- def data_to_file(self):
+ def data_to_file(
+ self,
+ name: str,
+ data: sitk.Image | sitk.Transform | np.ndarray,
+ attributes: Attribute | None = None,
+ ) -> None:
  pass
 
  @abstractmethod
- def getNames(self, group: str) -> list[str]:
+ def get_names(self, group: str) -> list[str]:
  pass
-
+
  @abstractmethod
- def getGroup(self) -> list[str]:
+ def get_group(self) -> list[str]:
  pass
 
  @abstractmethod
- def isExist(self, group: str, name: Union[str, None] = None) -> bool:
+ def is_exist(self, group: str, name: str | None = None) -> bool:
  pass
-
+
  @abstractmethod
- def getInfos(self, group: Union[str, None], name: str) -> tuple[list[int], Attribute]:
+ def get_infos(self, group: str, name: str) -> tuple[list[int], Attribute]:
  pass
 
  class H5File(AbstractFile):
 
  def __init__(self, filename: str, read: bool) -> None:
- self.h5: Union[h5py.File, None] = None
+ self.h5: h5py.File | None = None
  self.filename = filename
  if not self.filename.endswith(".h5"):
  self.filename += ".h5"
  self.read = read
 
  def __enter__(self):
- args = {}
  if self.read:
- self.h5 = h5py.File(self.filename, 'r', **args)
+ self.h5 = h5py.File(self.filename, "r")
  else:
  if not os.path.exists(self.filename):
- if len(self.filename.split("/")) > 1 and not os.path.exists("/".join(self.filename.split("/")[:-1])):
+ if len(self.filename.split("/")) > 1 and not os.path.exists(
+ "/".join(self.filename.split("/")[:-1])
+ ):
  os.makedirs("/".join(self.filename.split("/")[:-1]))
- self.h5 = h5py.File(self.filename, 'w', **args)
- else:
- self.h5 = h5py.File(self.filename, 'r+', **args)
- self.h5.attrs["Date"] = DATE()
+ self.h5 = h5py.File(self.filename, "w")
+ else:
+ self.h5 = h5py.File(self.filename, "r+")
+ self.h5.attrs["Date"] = current_date()
  self.h5.__enter__()
  return self.h5
-
- def __exit__(self, type, value, traceback):
+
+ def __exit__(self, exc_type, value, traceback):
  if self.h5 is not None:
  self.h5.close()
-
+
  def file_to_data(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
- dataset = self._getDataset(groups, name)
+ dataset = self._get_dataset(groups, name)
  data = np.zeros(dataset.shape, dataset.dtype)
  dataset.read_direct(data)
- return data, Attribute({k : str(v) for k, v in dataset.attrs.items()})
-
- def data_to_file(self, name : str, data : Union[sitk.Image, sitk.Transform, np.ndarray], attributes : Union[Attribute, None] = None) -> None:
+ return data, Attribute({k: str(v) for k, v in dataset.attrs.items()})
+
+ def data_to_file(
+ self,
+ name: str,
+ data: sitk.Image | sitk.Transform | np.ndarray,
+ attributes: Attribute | None = None,
+ ) -> None:
+ if self.h5 is None:
+ return
  if attributes is None:
  attributes = Attribute()
  if isinstance(data, sitk.Image):
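
H5File now acquires its handle in __enter__ ("r" when reading, otherwise "w" or "r+" with a "Date" attribute stamped from current_date) and data_to_file returns early if no file is open. A direct-use sketch under those assumptions; the path is hypothetical and the group-creation code between these hunks is not shown here (konfai callers normally go through Dataset.File):

    import numpy as np
    from konfai.utils.dataset import Dataset

    h5file = Dataset.H5File("scratch/demo", read=False)  # ".h5" is appended
    with h5file:
        # store a small array as dataset "p01" under group "train"
        h5file.data_to_file("train/p01", np.zeros((1, 4, 4), dtype=np.float32))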
@@ -405,7 +205,7 @@ class Dataset():
  transforms = []
  if isinstance(data, sitk.CompositeTransform):
  for i in range(data.GetNumberOfTransforms()):
- transforms.append(data.GetNthTransform(i))
+ transforms.append(data.GetNthTransform(i))
  else:
  transforms.append(data)
  datas = []
@@ -416,8 +216,8 @@ class Dataset():
  transform_type = "AffineTransform_double_3_3"
  if isinstance(transform, sitk.BSplineTransform):
  transform_type = "BSplineTransform_double_3_3"
- attributes["{}:Transform".format(i)] = transform_type
- attributes["{}:FixedParameters".format(i)] = transform.GetFixedParameters()
+ attributes[f"{i}:Transform"] = transform_type
+ attributes[f"{i}:FixedParameters"] = transform.GetFixedParameters()
 
  datas.append(np.asarray(transform.GetParameters()))
  data = np.asarray(datas)
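
As this hunk shows, transforms are flattened for storage: per-transform type and fixed parameters land in "{i}:Transform" / "{i}:FixedParameters" attributes and the parameter vectors are stacked into a single array; read_transform (later in this diff) rebuilds SimpleITK transforms from exactly these keys. A round-trip sketch with a hypothetical path:

    import SimpleITK as sitk
    from konfai.utils.dataset import Dataset

    dataset = Dataset("scratch/demo", "h5")
    dataset.write("transforms", "p01", sitk.AffineTransform(3))
    transform = dataset.read_transform("transforms", "p01")  # AffineTransform again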
@@ -434,38 +234,41 @@ class Dataset():
  del h5_group[name]
 
  dataset = h5_group.create_dataset(name, data=data, dtype=data.dtype, chunks=None)
- dataset.attrs.update({k : str(v) for k, v in attributes.items()})
-
- def isExist(self, group: str, name: Union[str, None] = None) -> bool:
- if group in self.h5:
- if isinstance(self.h5[group], h5py.Dataset):
- return True
- elif name is not None:
- return name in self.h5[group]
- else:
- return False
+ dataset.attrs.update({k: str(v) for k, v in attributes.items()})
+
+ def is_exist(self, group: str, name: str | None = None) -> bool:
+ if self.h5 is not None:
+ if group in self.h5:
+ if isinstance(self.h5[group], h5py.Dataset):
+ return True
+ elif name is not None:
+ return name in self.h5[group]
+ else:
+ return False
  return False
 
- def getNames(self, groups: str, h5_group: h5py.Group = None) -> list[str]:
+ def get_names(self, groups: str, h5_group: h5py.Group = None) -> list[str]:
  names = []
  if h5_group is None:
  h5_group = self.h5
  group = groups.split("/")[0]
  if group == "":
- names = [dataset.name.split("/")[-1] for dataset in h5_group.values() if isinstance(dataset, h5py.Dataset)]
+ names = [
+ dataset.name.split("/")[-1] for dataset in h5_group.values() if isinstance(dataset, h5py.Dataset)
+ ]
  elif group == "*":
  for k in h5_group.keys():
  if isinstance(h5_group[k], h5py.Group):
- names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[k]))
+ names.extend(self.get_names("/".join(groups.split("/")[1:]), h5_group[k]))
  else:
  if group in h5_group:
- names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
+ names.extend(self.get_names("/".join(groups.split("/")[1:]), h5_group[group]))
  return names
-
- def getGroup(self):
- return self.h5.keys()
-
- def _getDataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
+
+ def get_group(self) -> list[str]:
+ return list(self.h5.keys()) if self.h5 is not None else []
+
+ def _get_dataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
  if h5_group is None:
  h5_group = self.h5
  if groups != "":
@@ -479,39 +282,42 @@ class Dataset():
  elif group == "*":
  for k in h5_group.keys():
  if isinstance(h5_group[k], h5py.Group):
- result_tmp = self._getDataset("/".join(groups.split("/")[1:]), name, h5_group[k])
+ result_tmp = self._get_dataset("/".join(groups.split("/")[1:]), name, h5_group[k])
  if result_tmp is not None:
  result = result_tmp
  else:
  if group in h5_group:
- result_tmp = self._getDataset("/".join(groups.split("/")[1:]), name, h5_group[group])
+ result_tmp = self._get_dataset("/".join(groups.split("/")[1:]), name, h5_group[group])
  if result_tmp is not None:
  result = result_tmp
  return result
-
- def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
- dataset = self._getDataset(groups, name)
- return (dataset.shape, Attribute({k : str(v) for k, v in dataset.attrs.items()}))
-
+
+ def get_infos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
+ dataset = self._get_dataset(groups, name)
+ return (
+ dataset.shape,
+ Attribute({k: str(v) for k, v in dataset.attrs.items()}),
+ )
+
  class SitkFile(AbstractFile):
 
- def __init__(self, filename: str, read: bool, format: str) -> None:
+ def __init__(self, filename: str, read: bool, file_format: str) -> None:
  self.filename = filename
  self.read = read
- self.format = format
-
+ self.file_format = file_format
+
  def file_to_data(self, group: str, name: str) -> tuple[np.ndarray, Attribute]:
  attributes = Attribute()
- if os.path.exists("{}{}.{}".format(self.filename, name, self.format)):
- image = sitk.ReadImage("{}{}.{}".format(self.filename, name, self.format))
+ if os.path.exists(f"{self.filename}{name}.{self.file_format}"):
+ image = sitk.ReadImage(f"{self.filename}{name}.{self.file_format}")
  data, attributes_tmp = image_to_data(image)
  attributes.update(attributes_tmp)
- elif os.path.exists("{}{}.itk.txt".format(self.filename, name)):
- data = sitk.ReadTransform("{}{}.itk.txt".format(self.filename, name))
+ elif os.path.exists(f"{self.filename}{name}.itk.txt"):
+ data = sitk.ReadTransform(f"{self.filename}{name}.itk.txt")
  transforms = []
  if isinstance(data, sitk.CompositeTransform):
  for i in range(data.GetNumberOfTransforms()):
- transforms.append(data.GetNthTransform(i))
+ transforms.append(data.GetNthTransform(i))
  else:
  transforms.append(data)
  datas = []
@@ -522,113 +328,151 @@ class Dataset():
  transform_type = "AffineTransform_double_3_3"
  if isinstance(transform, sitk.BSplineTransform):
  transform_type = "BSplineTransform_double_3_3"
- attributes["{}:Transform".format(i)] = transform_type
- attributes["{}:FixedParameters".format(i)] = transform.GetFixedParameters()
+ attributes[f"{i}:Transform"] = transform_type
+ attributes[f"{i}:FixedParameters"] = transform.GetFixedParameters()
 
  datas.append(np.asarray(transform.GetParameters()))
  data = np.asarray(datas)
- elif os.path.exists("{}{}.fcsv".format(self.filename, name)):
- with open("{}{}.fcsv".format(self.filename, name), newline="") as csvfile:
- reader = csv.reader(filter(lambda row: row[0]!='#', csvfile))
+ elif os.path.exists(f"{self.filename}{name}.fcsv"):
+ with open(f"{self.filename}{name}.fcsv", newline="") as csvfile:
+ reader = csv.reader(filter(lambda row: row[0] != "#", csvfile))
  lines = list(reader)
  data = np.zeros((len(list(lines)), 3), dtype=np.double)
  for i, row in enumerate(lines):
  data[i] = np.array(row[1:4], dtype=np.double)
  csvfile.close()
- elif os.path.exists("{}{}.xml".format(self.filename, name)):
- with open("{}{}.xml".format(self.filename, name), 'rb') as xml_file:
- result = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()
+ elif os.path.exists(f"{self.filename}{name}.xml"):
+ with open(f"{self.filename}{name}.xml", "rb") as xml_file:
+ result = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot() # nosec B320
  xml_file.close()
  return result
- elif os.path.exists("{}{}.vtk".format(self.filename, name)):
+ elif os.path.exists(f"{self.filename}{name}.vtk"):
  import vtk
- vtkReader = vtk.vtkPolyDataReader()
- vtkReader.SetFileName("{}{}.vtk".format(self.filename, name))
- vtkReader.Update()
+
+ vtk_reader = vtk.vtkPolyDataReader()
+ vtk_reader.SetFileName(f"{self.filename}{name}.vtk")
+ vtk_reader.Update()
  data = []
- points = vtkReader.GetOutput().GetPoints()
+ points = vtk_reader.GetOutput().GetPoints()
  num_points = points.GetNumberOfPoints()
  for i in range(num_points):
  data.append(list(points.GetPoint(i)))
  data = np.asarray(data)
- elif os.path.exists("{}{}.npy".format(self.filename, name)):
- data = np.load("{}{}.npy".format(self.filename, name))
+ elif os.path.exists(f"{self.filename}{name}.npy"):
+ data = np.load(f"{self.filename}{name}.npy")
  return data, attributes
-
+
  def is_vtk_polydata(self, obj):
  try:
  import vtk
+
  return isinstance(obj, vtk.vtkPolyData)
  except ImportError:
  return False
-
- def data_to_file(self, name : str, data : Union[sitk.Image, sitk.Transform, np.ndarray], attributes : Attribute = Attribute()) -> None:
+
+ def __enter__(self):
+ pass
+
+ def __exit__(self, exc_type, value, traceback):
+ pass
+
+ def data_to_file(
+ self,
+ name: str,
+ data: sitk.Image | sitk.Transform | np.ndarray,
+ attributes: Attribute | None = None,
+ ) -> None:
+ if attributes is None:
+ attributes = Attribute()
  if not os.path.exists(self.filename):
  os.makedirs(self.filename)
  if isinstance(data, sitk.Image):
  for k, v in attributes.items():
  data.SetMetaData(k, v)
- sitk.WriteImage(data, "{}{}.{}".format(self.filename, name, self.format))
+ sitk.WriteImage(data, f"{self.filename}{name}.{self.file_format}")
  elif isinstance(data, sitk.Transform):
- sitk.WriteTransform(data, "{}{}.itk.txt".format(self.filename, name))
+ sitk.WriteTransform(data, f"{self.filename}{name}.itk.txt")
  elif self.is_vtk_polydata(data):
  import vtk
- vtkWriter = vtk.vtkPolyDataWriter()
- vtkWriter.SetFileName("{}{}.vtk".format(self.filename, name))
- vtkWriter.SetInputData(data)
- vtkWriter.Write()
- elif isAnImage(attributes):
+
+ vtk_writer = vtk.vtkPolyDataWriter()
+ vtk_writer.SetFileName(f"{self.filename}{name}.vtk")
+ vtk_writer.SetInputData(data)
+ vtk_writer.Write()
+ elif is_an_image(attributes):
  self.data_to_file(name, data_to_image(data, attributes), attributes)
- elif (len(data.shape) == 2 and data.shape[1] == 3 and data.shape[0] > 0):
+ elif len(data.shape) == 2 and data.shape[1] == 3 and data.shape[0] > 0:
  data = np.round(data, 4)
- with open("{}{}.fcsv".format(self.filename, name), 'w') as f:
- f.write("# Markups fiducial file version = 4.6\n# CoordinateSystem = 0\n# columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n")
+ with open(f"{self.filename}{name}.fcsv", "w") as f:
+ f.write(
+ "# Markups fiducial file version = 4.6\n# CoordinateSystem = 0\n#"
+ " columns = id,x,y,z,ow,ox,oy,oz,vis,sel,lock,label,desc,associatedNodeID\n",
+ )
  for i in range(data.shape[0]):
- f.write("vtkMRMLMarkupsFiducialNode_"+str(i+1)+","+str(data[i, 0])+","+str(data[i, 1])+","+str(data[i, 2])+",0,0,0,1,1,1,0,F-"+str(i+1)+",,vtkMRMLScalarVolumeNode1\n")
+ f.write(
+ "vtkMRMLMarkupsFiducialNode_"
+ + str(i + 1)
+ + ","
+ + str(data[i, 0])
+ + ","
+ + str(data[i, 1])
+ + ","
+ + str(data[i, 2])
+ + ",0,0,0,1,1,1,0,F-"
+ + str(i + 1)
+ + ",,vtkMRMLScalarVolumeNode1\n"
+ )
  f.close()
  elif "path" in attributes:
- if os.path.exists("{}{}.xml".format(self.filename, name)):
- with open("{}{}.xml".format(self.filename, name), 'rb') as xml_file:
- root = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot()
+ if os.path.exists(f"{self.filename}{name}.xml"):
+ with open(f"{self.filename}{name}.xml", "rb") as xml_file:
+ root = etree.parse(xml_file, etree.XMLParser(remove_blank_text=True)).getroot() # nosec B320
  xml_file.close()
  else:
  root = etree.Element(name)
  node = root
- path = attributes["path"].split(':')
+ path = attributes["path"].split(":")
 
  for node_name in path:
  node_tmp = node.find(node_name)
- if node_tmp == None:
+ if node_tmp is None:
  node_tmp = etree.SubElement(node, node_name)
  node.append(node_tmp)
  node = node_tmp
- if attributes != None:
+ if attributes is not None:
  for attribute_tmp in attributes.keys():
  attribute = "_".join(attribute_tmp.split("_")[:-1])
  if attribute != "path":
  node.set(attribute, attributes[attribute])
  if data.size > 0:
- node.text = ", ".join(map(str, data.flatten())) #np.array2string(data, separator=',')[1:-1].replace('\n','')
- with open("{}{}.xml".format(self.filename, name), 'wb') as f:
- f.write(etree.tostring(root, pretty_print=True, encoding='utf-8'))
+ node.text = ", ".join(
+ map(str, data.flatten())
+ ) # np.array2string(data, separator=',')[1:-1].replace('\n','')
+ with open(f"{self.filename}{name}.xml", "wb") as f:
+ f.write(etree.tostring(root, pretty_print=True, encoding="utf-8"))
  f.close()
  else:
- np.save("{}{}.npy".format(self.filename, name), data)
-
- def isExist(self, group: str, name: Union[str, None] = None) -> bool:
- return os.path.exists("{}{}.{}".format(self.filename, group, self.format)) or os.path.exists("{}{}.itk.txt".format(self.filename, group)) or os.path.exists("{}{}.fcsv".format(self.filename, group)) or os.path.exists("{}{}.npy".format(self.filename, group))
-
- def getNames(self, group: str) -> list[str]:
+ np.save(f"{self.filename}{name}.npy", data)
+
+ def is_exist(self, group: str, name: str | None = None) -> bool:
+ return (
+ os.path.exists(f"{self.filename}{group}.{self.file_format}")
+ or os.path.exists(f"{self.filename}{group}.itk.txt")
+ or os.path.exists(f"{self.filename}{group}.fcsv")
+ or os.path.exists(f"{self.filename}{group}.npy")
+ )
+
+ def get_names(self, group: str) -> list[str]:
  raise NotImplementedError()
-
- def getGroup(self):
+
+ def get_group(self):
  raise NotImplementedError()
-
- def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
+
+ def get_infos(self, group: str, name: str) -> tuple[list[int], Attribute]:
  attributes = Attribute()
- if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):
+ if os.path.exists(f"{self.filename}{group if group is not None else ''}{name}.{self.file_format}"):
  file_reader = sitk.ImageFileReader()
- file_reader.SetFileName("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format))
+ file_reader.SetFileName(f"{self.filename}{group if group is not None else ''}{name}.{self.file_format}")
  file_reader.ReadImageInformation()
  attributes["Origin"] = np.asarray(file_reader.GetOrigin())
  attributes["Spacing"] = np.asarray(file_reader.GetSpacing())
@@ -638,72 +482,85 @@ class Dataset():
  size = list(file_reader.GetSize())
  if len(size) == 3:
  size = list(reversed(size))
- size = [file_reader.GetNumberOfComponents()]+size
+ size = [file_reader.GetNumberOfComponents()] + size
  else:
  data, attributes = self.file_to_data(group if group is not None else "", name)
  size = data.shape
- return tuple(size), attributes
+ return size, attributes
 
- class File(ABC):
+ class File:
 
- def __init__(self, filename: str, read: bool, format: str) -> None:
+ def __init__(self, filename: str, read: bool, file_format: str) -> None:
  self.filename = filename
  self.read = read
- self.file = None
- self.format = format
+ self.file: "Dataset.AbstractFile" | None = None
+ self.file_format = file_format
 
- def __enter__(self):
- if self.format == "h5":
+ def __enter__(self) -> "Dataset.AbstractFile":
+ if self.file_format == "h5":
  self.file = Dataset.H5File(self.filename, self.read)
  else:
- self.file = Dataset.SitkFile(self.filename+"/", self.read, self.format)
+ self.file = Dataset.SitkFile(self.filename + "/", self.read, self.file_format)
  self.file.__enter__()
  return self.file
 
- def __exit__(self, type, value, traceback):
- self.file.__exit__(type, value, traceback)
+ def __exit__(self, exc_type, value, traceback):
+ if self.file is not None:
+ self.file.__exit__(exc_type, value, traceback)
 
- def __init__(self, filename : str, format: str) -> None:
- if format != "h5" and not filename.endswith("/"):
- filename = "{}/".format(filename)
- self.is_directory = filename.endswith("/")
+ def __init__(self, filename: str, file_format: str) -> None:
+ if file_format != "h5" and not filename.endswith("/"):
+ filename = f"{filename}/"
+ self.is_directory = filename.endswith("/")
  self.filename = filename
- self.format = format
-
- def write(self, group : str, name : str, data : Union[sitk.Image, sitk.Transform, np.ndarray], attributes : Attribute = Attribute()):
+ self.file_format = file_format
+
+ def write(
+ self,
+ group: str,
+ name: str,
+ data: sitk.Image | sitk.Transform | np.ndarray,
+ attributes: Attribute | None = None,
+ ):
+ if attributes is None:
+ attributes = Attribute()
  if self.is_directory:
  if not os.path.exists(self.filename):
  os.makedirs(self.filename)
  if self.is_directory:
  s_group = group.split("/")
  if len(s_group) > 1:
- subDirectory = "/".join(s_group[:-1])
- name = "{}/{}".format(subDirectory, name)
+ sub_directory = "/".join(s_group[:-1])
+ name = f"{sub_directory}/{name}"
  group = s_group[-1]
- with Dataset.File("{}{}".format(self.filename, name), False, self.format) as file:
+ with Dataset.File(f"{self.filename}{name}", False, self.file_format) as file:
  file.data_to_file(group, data, attributes)
  else:
- with Dataset.File(self.filename, False, self.format) as file:
- file.data_to_file("{}/{}".format(group, name), data, attributes)
-
- def readData(self, groups : str, name : str) -> tuple[np.ndarray, Attribute]:
+ with Dataset.File(self.filename, False, self.file_format) as file:
+ file.data_to_file(f"{group}/{name}", data, attributes)
+
+ def read_data(self, groups: str, name: str) -> tuple[np.ndarray, Attribute]:
  if not os.path.exists(self.filename):
- raise NameError("Dataset {} not found".format(self.filename))
+ raise NameError(f"Dataset {self.filename} not found")
  if self.is_directory:
- for subDirectory in self._getSubDirectories(groups):
+ for sub_directory in self._get_sub_directories(groups):
  group = groups.split("/")[-1]
- if os.path.exists("{}{}{}{}".format(self.filename, subDirectory, name, ".h5" if self.format == "h5" else "")):
- with Dataset.File("{}{}{}".format(self.filename, subDirectory, name), False, self.format) as file:
+ if os.path.exists(f"{self.filename}{sub_directory}{name}{'.h5' if self.file_format == 'h5' else ''}"):
+ with Dataset.File(
+ f"{self.filename}{sub_directory}{name}",
+ False,
+ self.file_format,
+ ) as file:
  result = file.file_to_data("", group)
  else:
- with Dataset.File(self.filename, False, self.format) as file:
+ with Dataset.File(self.filename, False, self.file_format) as file:
  result = file.file_to_data(groups, name)
  return result
-
- def readTransform(self, group : str, name : str) -> sitk.Transform:
+
+ def read_transform(self, group: str, name: str) -> sitk.Transform:
  if not os.path.exists(self.filename):
- raise NameError("Dataset {} not found".format(self.filename))
- transformParameters, attribute = self.readData(group, name)
+ raise NameError(f"Dataset {self.filename} not found")
+ transform_parameters, attribute = self.read_data(group, name)
  transforms_type = [v for k, v in attribute.items() if k.endswith(":Transform_0")]
  transforms = []
  for i, transform_type in enumerate(transforms_type):
@@ -713,78 +570,92 @@ class Dataset():
  transform = sitk.AffineTransform(3)
  if transform_type == "BSplineTransform_double_3_3":
  transform = sitk.BSplineTransform(3)
- transform.SetFixedParameters(eval(attribute["{}:FixedParameters".format(i)]))
- transform.SetParameters(tuple(transformParameters[i]))
+ transform.SetFixedParameters(ast.literal_eval(attribute[f"{i}:FixedParameters"]))
+ transform.SetParameters(tuple(transform_parameters[i]))
  transforms.append(transform)
  return sitk.CompositeTransform(transforms) if len(transforms) > 1 else transforms[0]
 
- def readImage(self, group : str, name : str):
- data, attribute = self.readData(group, name)
- return data_to_image(data, attribute)
-
- def getSize(self, group: str) -> int:
- return len(self.getNames(group))
-
- def isGroupExist(self, group: str) -> bool:
- return self.getSize(group) > 0
-
- def isDatasetExist(self, group: str, name: str) -> bool:
- return name in self.getNames(group)
-
- def _getSubDirectories(self, groups: str, subDirectory: str = ""):
+ def read_image(self, group: str, name: str):
+ data, attribute = self.read_data(group, name)
+ return data_to_image(data, attribute)
+
+ def get_size(self, group: str) -> int:
+ return len(self.get_names(group))
+
+ def is_group_exist(self, group: str) -> bool:
+ return self.get_size(group) > 0
+
+ def is_dataset_exist(self, group: str, name: str) -> bool:
+ return name in self.get_names(group)
+
+ def _get_sub_directories(self, groups: str, sub_directory: str = ""):
  group = groups.split("/")[0]
- subDirectories = []
+ sub_directories = []
  if len(groups.split("/")) == 1:
- subDirectories.append(subDirectory)
+ sub_directories.append(sub_directory)
  elif group == "*":
- for k in os.listdir("{}{}".format(self.filename, subDirectory)):
- if not os.path.isfile("{}{}{}".format(self.filename, subDirectory, k)):
- subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), "{}{}/".format(subDirectory , k)))
+ for k in os.listdir(f"{self.filename}{sub_directory}"):
+ if not os.path.isfile(f"{self.filename}{sub_directory}{k}"):
+ sub_directories.extend(
+ self._get_sub_directories(
+ "/".join(groups.split("/")[1:]),
+ f"{sub_directory}{k}/",
+ )
+ )
  else:
- subDirectory = "{}{}/".format(subDirectory, group)
- if os.path.exists("{}{}".format(self.filename, subDirectory)):
- subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
- return subDirectories
+ sub_directory = f"{sub_directory}{group}/"
+ if os.path.exists(f"{self.filename}{sub_directory}"):
+ sub_directories.extend(self._get_sub_directories("/".join(groups.split("/")[1:]), sub_directory))
+ return sub_directories
 
- def getNames(self, groups: str, index: Union[list[int], None] = None) -> list[str]:
+ def get_names(self, groups: str, index: list[int] | None = None) -> list[str]:
  names = []
  if self.is_directory:
- for subDirectory in self._getSubDirectories(groups):
+ for sub_directory in self._get_sub_directories(groups):
  group = groups.split("/")[-1]
- if os.path.exists("{}{}".format(self.filename, subDirectory)):
- for name in sorted(os.listdir("{}{}".format(self.filename, subDirectory))):
- if os.path.isfile("{}{}{}".format(self.filename, subDirectory, name)) or self.format != "h5":
- with Dataset.File("{}{}{}".format(self.filename, subDirectory, name), True, self.format) as file:
- if file.isExist(group):
- names.append(name.replace(".h5", "") if self.format == "h5" else name)
+ if os.path.exists(f"{self.filename}{sub_directory}"):
+ for name in sorted(os.listdir(f"{self.filename}{sub_directory}")):
+ if os.path.isfile(f"{self.filename}{sub_directory}{name}") or self.file_format != "h5":
+ with Dataset.File(
+ f"{self.filename}{sub_directory}{name}",
+ True,
+ self.file_format,
+ ) as file:
+ if file.is_exist(group):
+ names.append(name.replace(".h5", "") if self.file_format == "h5" else name)
  else:
- with Dataset.File(self.filename, True, self.format) as file:
- names = file.getNames(groups)
+ with Dataset.File(self.filename, True, self.file_format) as file:
+ names = file.get_names(groups)
  return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
-
- def getGroup(self):
+
+ def get_group(self):
  if self.is_directory:
- groups = set()
+ groups_set = set()
  for root, _, files in os.walk(self.filename):
  for file in files:
  path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
  parts = path.split("/")
  if len(parts) >= 2:
  del parts[-2]
- groups.add("/".join(parts))
+ groups_set.add("/".join(parts))
+ groups = list(groups_set)
  else:
- with Dataset.File(self.filename, True, self.format) as file:
- groups = file.getGroup()
+ with Dataset.File(self.filename, True, self.file_format) as dataset_file:
+ groups = dataset_file.get_group()
  return list(groups)
-
- def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
+
+ def get_infos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
  if self.is_directory:
- for subDirectory in self._getSubDirectories(groups):
+ for sub_directory in self._get_sub_directories(groups):
  group = groups.split("/")[-1]
- if os.path.exists("{}{}{}{}".format(self.filename, subDirectory, name, ".h5" if self.format == "h5" else "")):
- with Dataset.File("{}{}{}".format(self.filename, subDirectory, name), True, self.format) as file:
- result = file.getInfos("", group)
+ if os.path.exists(f"{self.filename}{sub_directory}{name}{'.h5' if self.file_format == 'h5' else ''}"):
+ with Dataset.File(
+ f"{self.filename}{sub_directory}{name}",
+ True,
+ self.file_format,
+ ) as file:
+ result = file.get_infos("", group)
  else:
- with Dataset.File(self.filename, True, self.format) as file:
- result = file.getInfos(groups, name)
- return result
+ with Dataset.File(self.filename, True, self.file_format) as file:
+ result = file.get_infos(groups, name)
+ return result
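
Beyond the renames, the behavioral hardening in this file is that read_transform parses stored FixedParameters with ast.literal_eval instead of eval, and the lxml import and XML parsing are annotated for bandit (nosec B410/B320). An end-to-end sketch of the renamed 1.2.0 API on the directory backend (file names and groups are illustrative):

    import numpy as np
    import SimpleITK as sitk
    from konfai.utils.dataset import Dataset

    dataset = Dataset("./scratch/demo", "mha")  # directory mode
    image = sitk.GetImageFromArray(np.zeros((4, 8, 8), dtype=np.float32))
    dataset.write("images", "p01", image)       # ./scratch/demo/p01/images.mha
    data, attributes = dataset.read_data("images", "p01")  # (1, 4, 8, 8) + Attribute
    print(dataset.get_names("images"))          # ["p01"]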