konfai 1.1.8-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.

Potentially problematic release: this version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/utils/utils.py CHANGED
@@ -1,119 +1,163 @@
+ import argparse
+ import importlib.util
  import itertools
- import pynvml
- import psutil
-
- import numpy as np
  import os
- import torch
-
+ import random
+ import re
+ import shutil
+ import socket
+ import subprocess # nosec B404
+ import sys
  from abc import ABC, abstractmethod
+ from contextlib import closing
  from enum import Enum
- from typing import Any, Union
+ from pathlib import Path
+ from typing import Any, TextIO, cast
 
- from konfai import CONFIG_FILE, EVALUATIONS_DIRECTORY, STATISTICS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_STATE, CUDA_VISIBLE_DEVICES
+ import numpy as np
+ import psutil
+ import pynvml
+ import requests
+ import torch
  import torch.distributed as dist
- import argparse
- import subprocess
- import random
+ import torch.nn.functional as F # noqa: N812
  from torch.utils.data import DataLoader
- import torch.nn.functional as F
- import sys
- import re
-
- import requests
+ from torch.utils.tensorboard.writer import SummaryWriter
  from tqdm import tqdm
- import importlib
- from pathlib import Path
- import shutil
 
- def description(model, modelEMA = None, showMemory: bool = True, train: bool = True) -> str:
- values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
- model_desc = lambda model : "("+" ".join(["{}({:.6f}) : {}".format(name, network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0, values_desc(network.measure.getLastWeights(), network.measure.getLastValues())) for name, network in model.module.getNetworks().items() if network.measure is not None])+")"
- result = "Loss {}".format(model_desc(model))
- if modelEMA is not None:
- result += " Loss EMA {}".format(model_desc(modelEMA))
- result += " "+gpuInfo()
- if showMemory:
- result +=" | {}".format(memoryInfo())
+ from konfai import (
+ config_file,
+ cuda_visible_devices,
+ evaluations_directory,
+ konfai_state,
+ predictions_directory,
+ statistics_directory,
+ )
+
+
+ def description(model, model_ema=None, show_memory: bool = True, train: bool = True) -> str:
+ def loss_desc(model):
+ return (
+ "("
+ + " ".join(
+ [
+ f"{name}({(network.optimizer.param_groups[0]['lr'] if network.optimizer else 0):.6f}) : "
+ + " ".join(
+ f"{k.split(':')[-1]}({w:.2f}) : {v:.6f}"
+ for (k, v), w in zip(
+ network.measure.get_last_values().items(), network.measure.get_last_weights().values()
+ )
+ )
+ for name, network in model.module.get_networks().items()
+ if network.measure is not None
+ ]
+ )
+ + ")"
+ )
+
+ model_loss_desc = loss_desc(model)
+ result = ""
+ if len(model_loss_desc) > 2:
+ result += f"Loss {model_loss_desc} "
+ if model_ema is not None:
+ model_ema_loss_desc = loss_desc(model_ema)
+ if len(model_ema_loss_desc) > 2:
+ result += f"Loss EMA {model_ema_loss_desc} "
+ result += gpu_info()
+ if show_memory:
+ result += f" | {get_memory_info()}"
  return result
 
- def _getModule(classpath : str, type : str) -> tuple[str, str]:
+
+ def get_module(classpath: str, default_classpath: str) -> tuple[str, str]:
  if len(classpath.split(":")) > 1:
  module = ".".join(classpath.split(":")[:-1])
  name = classpath.split(":")[-1]
  else:
- module = type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
+ module = (
+ default_classpath + ("." if len(classpath.split(".")) > 2 else "") + ".".join(classpath.split(".")[:-1])
+ )
  name = classpath.split(".")[-1]
  return module, name.split("/")[0]
 
- def cpuInfo() -> str:
- return "CPU ({:.2f} %)".format(psutil.cpu_percent(interval=0.5))
 
- def memoryInfo() -> str:
- return "Memory ({:.2f}G ({:.2f} %))".format(psutil.virtual_memory()[3]/2**30, psutil.virtual_memory()[2])
+ def get_cpu_info() -> str:
+ return f"CPU ({psutil.cpu_percent(interval=0.5):.2f} %)"
+
+
+ def get_memory_info() -> str:
+ return f"Memory ({psutil.virtual_memory()[3] / 2**30:.2f}G ({psutil.virtual_memory()[2]:.2f} %))"
+
 
- def getMemory() -> float:
- return psutil.virtual_memory()[3]/2**30
+ def get_memory() -> float:
+ return psutil.virtual_memory()[3] / 2**30
 
- def memoryForecast(memory_init : float, i : float, size : float) -> str:
- current_memory = getMemory()
- forecast = memory_init + ((current_memory-memory_init)*size/i) if i > 0 else 0
- return "Memory forecast ({:.2f}G ({:.2f} %))".format(forecast, forecast/(psutil.virtual_memory()[0]/2**30)*100)
 
- def gpuInfo() -> str:
- if CUDA_VISIBLE_DEVICES() == "":
+ def memory_forecast(memory_init: float, i: float, size: float) -> str:
+ current_memory = get_memory()
+ forecast = memory_init + ((current_memory - memory_init) * size / i) if i > 0 else 0
+ return f"Memory forecast ({forecast:.2f}G ({forecast / (psutil.virtual_memory()[0] / 2**30) * 100:.2f} %))"
+
+
+ def gpu_info() -> str:
+ if cuda_visible_devices() == "":
  return ""
-
- devices = [int(i) for i in CUDA_VISIBLE_DEVICES().split(",")]
+
+ devices = [int(i) for i in cuda_visible_devices().split(",")]
  device = devices[0]
-
+
  if device < pynvml.nvmlDeviceGetCount():
  handle = pynvml.nvmlDeviceGetHandleByIndex(device)
  memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
  else:
  return ""
- node_name = "Node: {} " +os.environ["SLURMD_NODENAME"] if "SLURMD_NODENAME" in os.environ else ""
- return "{}GPU({}) Memory GPU ({:.2f}G ({:.2f} %))".format(node_name, devices, float(memory.used)/(10**9), float(memory.used)/float(memory.total)*100)
+ node_name = "Node: {} " + os.environ["SLURMD_NODENAME"] if "SLURMD_NODENAME" in os.environ else ""
+ return f"{node_name}GPU({devices}) Memory GPU ({memory.used / 1e9:.2f}G ({memory.used / memory.total * 100:.2f} %))"
+
 
- def getMaxGPUMemory(device : Union[int, torch.device]) -> float:
+ def get_max_gpu_memory(device: int | torch.device) -> float:
  if isinstance(device, torch.device):
  if str(device).startswith("cuda:"):
  device = int(str(device).replace("cuda:", ""))
  else:
  return 0
- device = [int(i) for i in CUDA_VISIBLE_DEVICES().split(",")][device]
+ device = [int(i) for i in cuda_visible_devices().split(",")][device]
  if device < pynvml.nvmlDeviceGetCount():
  handle = pynvml.nvmlDeviceGetHandleByIndex(device)
  memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
  else:
  return 0
- return float(memory.total)/(10**9)
+ return float(memory.total) / (10**9)
 
- def getGPUMemory(device : Union[int, torch.device]) -> float:
+
+ def get_gpu_memory(device: int | torch.device) -> float:
  if isinstance(device, torch.device):
  if str(device).startswith("cuda:"):
  device = int(str(device).replace("cuda:", ""))
  else:
  return 0
- device = [int(i) for i in CUDA_VISIBLE_DEVICES().split(",")][device]
+ device = [int(i) for i in cuda_visible_devices().split(",")][device]
  if device < pynvml.nvmlDeviceGetCount():
  handle = pynvml.nvmlDeviceGetHandleByIndex(device)
  memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
  else:
  return 0
- return float(memory.used)/(10**9)
+ return float(memory.used) / (10**9)
+
 
- class NeedDevice(ABC):
+ class NeedDevice:
 
  def __init__(self) -> None:
  super().__init__()
- self.device : torch.device
-
- def setDevice(self, device : int):
- self.device = getDevice(device)
+ self.device: torch.device
+
+ def to(self, device: int):
+ self.device = get_device(device)
+
+
+ def get_device(device: int):
+ return device if torch.cuda.is_available() and device >= 0 else torch.device("cpu")
 
- def getDevice(device : int):
- return device if torch.cuda.is_available() and device >=0 else torch.device("cpu")
 
  class State(Enum):
  TRAIN = "TRAIN"
@@ -122,13 +166,18 @@ class State(Enum):
  FINE_TUNING = "FINE_TUNING"
  PREDICTION = "PREDICTION"
  EVALUATION = "EVALUATION"
-
+
  def __str__(self) -> str:
  return self.value
 
- def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_per_dim : list[tuple[int, bool]], overlap: Union[int, None]) -> list[tuple[slice]]:
+
+ def get_patch_slices_from_nb_patch_per_dim(
+ patch_size_tmp: list[int],
+ nb_patch_per_dim: list[tuple[int, bool]],
+ overlap: int | None,
+ ) -> list[tuple[slice, ...]]:
  patch_slices = []
- slices : list[list[slice]] = []
+ slices: list[list[slice]] = []
  if overlap is None:
  overlap = 0
  patch_size = []
@@ -138,47 +187,54 @@ def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_p
  patch_size.append(1)
  else:
  patch_size.append(patch_size_tmp[i])
- i+=1
+ i += 1
 
  for dim, nb in enumerate(nb_patch_per_dim):
  slices.append([])
  for index in range(nb[0]):
- start = (patch_size[dim]-overlap)*index
+ start = (patch_size[dim] - overlap) * index
  end = start + patch_size[dim]
- slices[dim].append(slice(start,end))
+ slices[dim].append(slice(start, end))
  for chunk in itertools.product(*slices):
  patch_slices.append(tuple(chunk))
  return patch_slices
 
- def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overlap: Union[int, None]) -> tuple[list[tuple[slice]], list[tuple[int, bool]]]:
+
+ def get_patch_slices_from_shape(
+ patch_size: list[int], shape: list[int], overlap_tmp: int | None
+ ) -> tuple[list[tuple[slice, ...]], list[tuple[int, bool]]]:
  if len(shape) != len(patch_size):
- raise DatasetManagerError(
- f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
- f"patch_size: {patch_size}",
- f"shape: {shape}",
- "Both must have the same number of dimensions (e.g., 3D patch for 3D volume)."
- )
+ raise DatasetManagerError(
+ f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
+ f"patch_size: {patch_size}",
+ f"shape: {shape}",
+ "Both must have the same number of dimensions (e.g., 3D patch for 3D volume).",
+ )
  patch_slices = []
  nb_patch_per_dim = []
- slices : list[list[slice]] = []
- if overlap is None:
- size = [np.ceil(a/b) for a, b in zip(shape, patch_size)]
+ slices: list[list[slice]] = []
+ if overlap_tmp is None:
+ size = [np.ceil(a / b) for a, b in zip(shape, patch_size)]
  tmp = np.zeros(len(size), dtype=np.int_)
  for i, s in enumerate(size):
  if s > 1:
- tmp[i] = np.mod(patch_size[i]-np.mod(shape[i], patch_size[i]), patch_size[i])//(size[i]-1)
+ tmp[i] = np.mod(patch_size[i] - np.mod(shape[i], patch_size[i]), patch_size[i]) // (size[i] - 1)
  overlap = tmp
  else:
- overlap = [overlap if size > 1 else 0 for size in patch_size]
-
+ overlap = [overlap_tmp if size > 1 else 0 for size in patch_size]
+
  for dim in range(len(shape)):
- assert overlap[dim] < patch_size[dim], "Overlap must be less than patch size"
+ if overlap[dim] >= patch_size[dim]:
+ raise ValueError(
+ f"Overlap must be less than patch size, got overlap={overlap[dim]}",
+ f" ≥ patch_size={patch_size[dim]} at dim={dim}",
+ )
 
  for dim in range(len(shape)):
  slices.append([])
  index = 0
  while True:
- start = (patch_size[dim]-overlap[dim])*index
+ start = (patch_size[dim] - overlap[dim]) * index
 
  end = start + patch_size[dim]
  if end >= shape[dim]:
@@ -187,84 +243,107 @@ def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overla
  break
  slices[dim].append(slice(start, end))
  index += 1
- nb_patch_per_dim.append((index+1, patch_size[dim] == 1))
+ nb_patch_per_dim.append((index + 1, patch_size[dim] == 1))
 
  for chunk in itertools.product(*slices):
  patch_slices.append(tuple(chunk))
-
+
  return patch_slices, nb_patch_per_dim
 
- def _logSignalFormat(input : np.ndarray):
- return {str(i): channel for i, channel in enumerate(input)}
-
- def _logImageFormat(input : np.ndarray):
- if len(input.shape) == 2:
- input = np.expand_dims(input, axis=0)
-
- if len(input.shape) == 3 and input.shape[0] != 1:
- input = np.expand_dims(input, axis=0)
- if len(input.shape) == 4:
- input = input[:, input.shape[1]//2]
-
- input = input.astype(float)
- b = -np.min(input)
- if (np.max(input)+b) > 0:
- return (input+b)/(np.max(input)+b)
+
+ def _log_signal_format(array: np.ndarray) -> dict[str, np.ndarray]:
+ return {str(i): channel for i, channel in enumerate(array)}
+
+
+ def _log_image_format(array: np.ndarray) -> np.ndarray:
+ if len(array.shape) == 2:
+ array = np.expand_dims(array, axis=0)
+
+ if len(array.shape) == 3 and array.shape[0] != 1:
+ array = np.expand_dims(array, axis=0)
+ if len(array.shape) == 4:
+ array = array[:, array.shape[1] // 2]
+
+ array = array.astype(float)
+ b = -np.min(array)
+ if (np.max(array) + b) > 0:
+ return (array + b) / (np.max(array) + b)
  else:
- return 0*input
+ return 0 * array
+
 
- def _logImagesFormat(input : np.ndarray):
+ def _log_images_format(array: np.ndarray) -> np.ndarray:
  result = []
- for n in range(input.shape[0]):
- result.append(_logImageFormat(input[n]))
+ for n in range(array.shape[0]):
+ result.append(_log_image_format(array[n]))
  result = np.stack(result, axis=0)
  return result
 
- def _logVideoFormat(input : np.ndarray):
- result = []
- for t in range(input.shape[1]):
- result.append( _logImagesFormat(input[:, t,...]))
- result = np.stack(result, axis=1)
+
+ def _log_video_format(array: np.ndarray) -> np.ndarray:
+ result_list = []
+ for t in range(array.shape[1]):
+ result_list.append(_log_images_format(array[:, t, ...]))
+ result = np.stack(result_list, axis=1)
 
  nb_channel = result.shape[2]
  if nb_channel < 3:
  channel_split = [result[:, :, 0, ...] for i in range(3)]
  else:
  channel_split = np.split(result, 3, axis=0)
- input = np.zeros((result.shape[0], result.shape[1], 3, *list(result.shape[3:])))
+ array = np.zeros((result.shape[0], result.shape[1], 3, *list(result.shape[3:])))
  for i, channels in enumerate(channel_split):
- input[:,:,i] = np.mean(channels, axis=0)
- return input
+ array[:, :, i] = np.mean(channels, axis=0)
+ return array
+
 
  class DataLog(Enum):
- SIGNAL = lambda tb, name, layer, it : [tb.add_scalars(name, _logSignalFormat(layer[b, :, 0]), layer.shape[0]*it+b) for b in range(layer.shape[0])],
- IMAGE = lambda tb, name, layer, it : tb.add_image(name, _logImageFormat(layer[0]), it),
- IMAGES = lambda tb, name, layer, it : tb.add_images(name, _logImagesFormat(layer), it),
- VIDEO = lambda tb, name, layer, it : tb.add_video(name, _logVideoFormat(layer), it),
- AUDIO = lambda tb, name, layer, it : tb.add_audio(name, _logImageFormat(layer), it)
+ SIGNAL = "SIGNAL"
+ IMAGE = "IMAGE"
+ IMAGES = "IMAGES"
+ VIDEO = "VIDEO"
+ AUDIO = "AUDIO"
+
+ def __call__(self, tb: SummaryWriter, name: str, layer: torch.Tensor, it: int):
+ if self == DataLog.SIGNAL:
+ return [
+ tb.add_scalars(name, _log_signal_format(layer[b, :, 0]), layer.shape[0] * it + b)
+ for b in range(layer.shape[0])
+ ]
+ elif self == DataLog.IMAGE:
+ return tb.add_image(name, _log_image_format(layer[0]), it)
+ elif self == DataLog.IMAGES:
+ return tb.add_images(name, _log_images_format(layer), it)
+ elif self == DataLog.VIDEO:
+ return tb.add_video(name, _log_video_format(layer), it)
+ elif self == DataLog.AUDIO:
+ return tb.add_audio(name, _log_image_format(layer), it)
+ else:
+ raise ValueError(f"Unsupported DataLog type: {self}")
+
 
  class Log:
  def __init__(self, name: str, rank: int) -> None:
- if KONFAI_STATE() == "PREDICTION":
- path = PREDICTIONS_DIRECTORY()
- elif KONFAI_STATE() == "EVALUATION":
- path = EVALUATIONS_DIRECTORY()
+ if konfai_state() == "PREDICTION":
+ path = predictions_directory()
+ elif konfai_state() == "EVALUATION":
+ path = evaluations_directory()
  else:
- path = STATISTICS_DIRECTORY()
-
+ path = statistics_directory()
+
  self.verbose = os.environ.get("KONFAI_VERBOSE", "True") == "True"
  self.log_path = os.path.join(path, name)
  os.makedirs(self.log_path, exist_ok=True)
  self.rank = rank
- self.file = open(os.path.join(self.log_path, "log_{}.txt".format(rank)), "w", buffering=1)
+ self.file = open(os.path.join(self.log_path, f"log_{rank}.txt"), "w", buffering=1)
  self.stdout_bak = sys.stdout
  self.stderr_bak = sys.stderr
  self._buffered_line = ""
-
+
  def __enter__(self):
  self.file.__enter__()
- sys.stdout = self
- sys.stderr = self
+ sys.stdout = cast(TextIO, self)
+ sys.stderr = cast(TextIO, self)
  return self
 
  def __exit__(self, exc_type, exc_val, exc_tb):
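Note on the DataLog rewrite in this hunk: the lambda-valued enum members from 1.1.8 become plain string members, and dispatch moves into __call__, so members stay directly callable (enum members are instances of DataLog, so DataLog.IMAGE(...) invokes the class's __call__). A minimal sketch of logging one batch, assuming a writable run directory and a NumPy batch, since the _log_*_format helpers operate with NumPy:

    import numpy as np
    from torch.utils.tensorboard.writer import SummaryWriter
    from konfai.utils.utils import DataLog

    tb = SummaryWriter("runs/demo")        # hypothetical log directory
    batch = np.random.rand(4, 1, 16, 16)   # (batch, channel, height, width)
    DataLog.IMAGE(tb, "example/image", batch, 0)  # dispatches to tb.add_image(...)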
@@ -275,25 +354,22 @@ class Log:
  def write(self, msg: str):
  if not msg:
  return
-
-
- ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')
- CARRIAGE_RETURN = re.compile(r'(?:\r|\x1b\[A).*')
- msg_clean = ANSI_ESCAPE.sub('', msg)
- if '\r' in msg_clean or '' in msg:
- # On garde seulement le contenu après le dernier retour chariot
- msg_clean = msg_clean.split('\r')[-1].strip()
+
+ ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]")
+ msg_clean = ansi_escape.sub("", msg)
+ if "\r" in msg_clean or "" in msg:
+ msg_clean = msg_clean.split("\r")[-1].strip()
  self._buffered_line = msg_clean
  else:
  self._buffered_line = msg_clean.strip()
 
  if self._buffered_line:
- # Écrit dans le fichier
  self.file.write(self._buffered_line + "\n")
  self.file.flush()
  if self.verbose and (self.rank == 0 or "KONFAI_CLUSTER" in os.environ):
- sys.__stdout__.write(msg)
- sys.__stdout__.flush()
+ if sys.__stdout__ is not None:
+ sys.__stdout__.write(msg)
+ sys.__stdout__.flush()
 
  def flush(self):
  self.file.flush()
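For orientation, Log is the context manager that tees stdout/stderr into a per-rank log file; 1.2.0 adds the cast(TextIO, ...) redirection for type checkers, drops the unused CARRIAGE_RETURN regex, and keeps the ANSI-stripping behaviour. A minimal usage sketch, assuming the KONFAI_* environment has already been initialised by setup(), since the constructor resolves its output directory from konfai_state():

    from konfai.utils.utils import Log

    with Log("demo_run", rank=0):
        print("training started")  # appended to <statistics dir>/demo_run/log_0.txt
                                   # and mirrored to the real stdout on rank 0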
@@ -302,65 +378,104 @@ class Log:
  return False
 
  def fileno(self):
+ if sys.__stdout__ is None:
+ raise RuntimeError("sys.__stdout__ is None, cannot get fileno")
  return sys.__stdout__.fileno()
-
- class TensorBoard():
-
+
+
+ class TensorBoard:
+
  def __init__(self, name: str) -> None:
- self.process = None
+ self.process: subprocess.Popen | None = None
  self.name = name
 
  def __enter__(self):
  if "KONFAI_TENSORBOARD_PORT" in os.environ:
- command = ["tensorboard", "--logdir", PREDICTIONS_DIRECTORY() if KONFAI_STATE() == "PREDICTION" else STATISTICS_DIRECTORY() + self.name + "/", "--port", os.environ["KONFAI_TENSORBOARD_PORT"], "--bind_all"]
- self.process = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
+ tensorboard_exe = shutil.which("tensorboard")
+ if tensorboard_exe is None:
+ raise RuntimeError("TensorBoard executable not found in PATH.")
+
+ logdir = (
+ predictions_directory() if konfai_state() == "PREDICTION" else statistics_directory() + self.name + "/"
+ )
+
+ port = os.environ.get("KONFAI_TENSORBOARD_PORT")
+ if not port or not port.isdigit():
+ raise ValueError("Invalid or missing KONFAI_TENSORBOARD_PORT.")
+
+ command = [
+ tensorboard_exe,
+ "--logdir",
+ logdir,
+ "--port",
+ port,
+ "--bind_all",
+ ]
+ self.process = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) # nosec B603
  try:
  s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
- s.connect(('10.255.255.255', 1))
- IP = s.getsockname()[0]
+ s.connect(("10.255.255.255", 1))
+ ip = s.getsockname()[0]
  except Exception:
- IP = '127.0.0.1'
+ ip = "127.0.0.1"
  finally:
  s.close()
- print("Tensorboard : http://{}:{}/".format(IP, os.environ["KONFAI_TENSORBOARD_PORT"]))
+ print(f"Tensorboard : http://{ip}:{os.environ['KONFAI_TENSORBOARD_PORT']}/")
  return self
-
- def __exit__(self, type, value, traceback):
+
+ def __exit__(self, exc_type, value, traceback):
  if self.process is not None:
  self.process.terminate()
  self.process.wait()
 
- class DistributedObject():
+
+ class DistributedObject(ABC):
 
  def __init__(self, name: str) -> None:
  self.port = find_free_port()
- self.dataloader : list[list[DataLoader]]
- self.manual_seed: bool = None
+ self.dataloader: list[list[DataLoader]]
+ self.manual_seed: int | None = None
  self.name = name
  self.size = 1
-
+
  @abstractmethod
  def setup(self, world_size: int):
  pass
-
+
  def __enter__(self):
  return self
 
- def __exit__(self, type, value, traceback):
+ def __exit__(self, exc_type, value, traceback):
  cleanup()
 
  @abstractmethod
- def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
- pass
-
- def getMeasure(world_size: int, global_rank: int, gpu: int, models: dict[str, torch.nn.Module], n: int) -> dict[str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]]:
+ def run_process(
+ self,
+ world_size: int,
+ global_rank: int,
+ local_rank: int,
+ dataloaders: list[DataLoader],
+ ):
+ pass
+
+ @staticmethod
+ def get_measure(
+ world_size: int,
+ global_rank: int,
+ gpu: int,
+ models: dict[str, torch.nn.Module],
+ n: int,
+ ) -> dict[str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]]:
  data = {}
  for label, model in models.items():
- for name, network in model.getNetworks().items():
+ for name, network in model.get_networks().items():
  if network.measure is not None:
- data["{}{}".format(name, label)] = (network.measure.format(True, n), network.measure.format(False, n))
+ data[f"{name}{label}"] = (
+ network.measure.format_loss(True, n),
+ network.measure.format_loss(False, n),
+ )
  outputs = synchronize_data(world_size, gpu, data)
- result = {}
+ result: dict[str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]] = {}
  if global_rank == 0:
  for output in outputs:
  for k, v in output.items():
@@ -369,16 +484,19 @@ class DistributedObject():
  if k not in result:
  result[k] = ({}, {})
  if u not in result[k][t]:
- result[k][t][u] = (n[0], 0)
- result[k][t][u] = (result[k][t][u][0], result[k][t][u][1]+n[1]/world_size)
+ result[k][t][u] = (n[0], 0) # type: ignore[index]
+ result[k][t][u] = (
+ result[k][t][u][0],
+ result[k][t][u][1] + n[1] / world_size, # type: ignore[index]
+ )
  return result
 
- def __call__(self, rank: Union[int, None] = None) -> None:
- with Log(self.name, rank):
- world_size = len(self.dataloader)
- global_rank, local_rank = setupGPU(world_size, self.port, rank)
- if global_rank is None:
- return
+ def __call__(self, rank: int | None = None) -> None:
+ world_size = len(self.dataloader)
+ global_rank, local_rank = setup_gpu(world_size, self.port, rank)
+ if global_rank is None or local_rank is None:
+ return
+ with Log(self.name, global_rank):
  if torch.cuda.is_available():
  pynvml.nvmlInit()
  if self.manual_seed is not None:
@@ -390,7 +508,7 @@
  torch.backends.cuda.matmul.allow_tf32 = True
  torch.backends.cudnn.allow_tf32 = True
  dataloaders = self.dataloader[global_rank]
- if torch.cuda.is_available():
+ if torch.cuda.is_available():
  torch.cuda.set_device(local_rank)
  try:
  self.run_process(world_size, global_rank, local_rank, dataloaders)
@@ -399,27 +517,75 @@
  if torch.cuda.is_available():
  pynvml.nvmlShutdown()
 
+
  def setup(parser: argparse.ArgumentParser) -> DistributedObject:
  # KONFAI arguments
- KONFAI_args = parser.add_argument_group('KONFAI arguments')
- KONFAI_args.add_argument("type", type=State, choices=list(State))
- KONFAI_args.add_argument('-y', action='store_true', help="Accept overwrite")
- KONFAI_args.add_argument('-tb', action='store_true', help='Start TensorBoard')
- KONFAI_args.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
- KONFAI_args.add_argument("-g", "--gpu", type=str, default=os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "", help="List of GPU")
- KONFAI_args.add_argument("-cpu", "--cpu", type=str, default="1" , help="List of GPU")
- KONFAI_args.add_argument('--num-workers', '--num_workers', default=4, type=int, help='No. of workers per DataLoader & GPU')
- KONFAI_args.add_argument("-models_dir", "--MODELS_DIRECTORY", type=str, default="./Models/", help="Models location")
- KONFAI_args.add_argument("-checkpoints_dir", "--CHECKPOINTS_DIRECTORY", type=str, default="./Checkpoints/", help="Checkpoints location")
- KONFAI_args.add_argument("-model", "--MODEL", type=str, default="", help="URL Model")
- KONFAI_args.add_argument("-predictions_dir", "--PREDICTIONS_DIRECTORY", type=str, default="./Predictions/", help="Predictions location")
- KONFAI_args.add_argument("-evaluation_dir", "--EVALUATIONS_DIRECTORY", type=str, default="./Evaluations/", help="Evaluations location")
- KONFAI_args.add_argument("-statistics_dir", "--STATISTICS_DIRECTORY", type=str, default="./Statistics/", help="Statistics location")
- KONFAI_args.add_argument("-setups_dir", "--SETUPS_DIRECTORY", type=str, default="./Setups/", help="Setups location")
- KONFAI_args.add_argument('-log', action='store_true', help='Save log')
- KONFAI_args.add_argument('-quiet', action='store_false', help='')
-
-
+ konfai = parser.add_argument_group("KONFAI arguments")
+ konfai.add_argument("type", type=State, choices=list(State))
+ konfai.add_argument("-y", action="store_true", help="Accept overwrite")
+ konfai.add_argument("-tb", action="store_true", help="Start TensorBoard")
+ konfai.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
+ konfai.add_argument(
+ "-g",
+ "--gpu",
+ type=str,
+ default=(os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else ""),
+ help="List of GPU",
+ )
+ konfai.add_argument("-cpu", "--cpu", type=str, default="1", help="List of GPU")
+ konfai.add_argument(
+ "--num-workers",
+ "--num_workers",
+ default=4,
+ type=int,
+ help="No. of workers per DataLoader & GPU",
+ )
+ konfai.add_argument(
+ "-models_dir",
+ "--MODELS_DIRECTORY",
+ type=str,
+ default="./Models/",
+ help="Models location",
+ )
+ konfai.add_argument(
+ "-checkpoints_dir",
+ "--CHECKPOINTS_DIRECTORY",
+ type=str,
+ default="./Checkpoints/",
+ help="Checkpoints location",
+ )
+ konfai.add_argument("-model", "--MODEL", type=str, default="", help="URL Model")
+ konfai.add_argument(
+ "-predictions_dir",
+ "--PREDICTIONS_DIRECTORY",
+ type=str,
+ default="./Predictions/",
+ help="Predictions location",
+ )
+ konfai.add_argument(
+ "-evaluation_dir",
+ "--EVALUATIONS_DIRECTORY",
+ type=str,
+ default="./Evaluations/",
+ help="Evaluations location",
+ )
+ konfai.add_argument(
+ "-statistics_dir",
+ "--STATISTICS_DIRECTORY",
+ type=str,
+ default="./Statistics/",
+ help="Statistics location",
+ )
+ konfai.add_argument(
+ "-setups_dir",
+ "--SETUPS_DIRECTORY",
+ type=str,
+ default="./Setups/",
+ help="Setups location",
+ )
+ konfai.add_argument("-log", action="store_true", help="Save log")
+ konfai.add_argument("-quiet", action="store_false", help="")
+
  args = parser.parse_args()
  config = vars(args)
 
@@ -432,14 +598,14 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
  os.environ["KONFAI_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
  os.environ["KONFAI_EVALUATIONS_DIRECTORY"] = config["EVALUATIONS_DIRECTORY"]
  os.environ["KONFAI_STATISTICS_DIRECTORY"] = config["STATISTICS_DIRECTORY"]
-
+
  os.environ["KONFAI_STATE"] = str(config["type"])
-
+
  os.environ["KONFAI_MODEL"] = config["MODEL"]
 
  os.environ["KONFAI_SETUPS_DIRECTORY"] = config["SETUPS_DIRECTORY"]
 
- os.environ["KONFAI_OVERWRITE"] = "{}".format(config["y"])
+ os.environ["KONFAI_OVERWRITE"] = str(config["y"])
  os.environ["KONFAI_CONFIG_MODE"] = "Done"
  if config["tb"]:
  os.environ["KONFAI_TENSORBOARD_PORT"] = str(find_free_port())
@@ -448,37 +614,49 @@
 
  if config["config"] == "None":
  if config["type"] is State.PREDICTION:
- os.environ["KONFAI_CONFIG_FILE"] = "Prediction.yml"
+ os.environ["KONFAI_config_file"] = "Prediction.yml"
  elif config["type"] is State.EVALUATION:
- os.environ["KONFAI_CONFIG_FILE"] = "Evaluation.yml"
+ os.environ["KONFAI_config_file"] = "Evaluation.yml"
  else:
- os.environ["KONFAI_CONFIG_FILE"] = "Config.yml"
+ os.environ["KONFAI_config_file"] = "Config.yml"
  else:
- os.environ["KONFAI_CONFIG_FILE"] = config["config"]
+ os.environ["KONFAI_config_file"] = config["config"]
  torch.autograd.set_detect_anomaly(True)
  os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
- if config["type"] is State.PREDICTION:
+
+ if config["type"] is State.PREDICTION:
  from konfai.predictor import Predictor
- os.environ["KONFAI_ROOT"] = "Predictor"
- return Predictor(config=CONFIG_FILE())
+
+ os.environ["konfai_root"] = "Predictor"
+ return Predictor(config=config_file())
  elif config["type"] is State.EVALUATION:
  from konfai.evaluator import Evaluator
- os.environ["KONFAI_ROOT"] = "Evaluator"
- return Evaluator(config=CONFIG_FILE())
+
+ os.environ["konfai_root"] = "Evaluator"
+ return Evaluator(config=config_file())
  else:
  from konfai.trainer import Trainer
- os.environ["KONFAI_ROOT"] = "Trainer"
- return Trainer(config=CONFIG_FILE())
 
+ os.environ["konfai_root"] = "Trainer"
+ return Trainer(config=config_file())
 
- def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple[int , int]:
+
+ def setup_gpu(world_size: int, port: int, rank: int | None = None) -> tuple[int | None, int | None]:
  try:
- host_name = subprocess.check_output("scontrol show hostnames {}".format(os.getenv('SLURM_JOB_NODELIST')).split()).decode().splitlines()[0]
- except:
+ nodelist = os.getenv("SLURM_JOB_NODELIST")
+ if nodelist is None:
+ raise RuntimeError("SLURM_JOB_NODELIST is not set.")
+ scontrol_path = shutil.which("scontrol")
+ if scontrol_path is None:
+ raise FileNotFoundError("scontrol not found in PATH")
+ host_name = subprocess.check_output(
+ [scontrol_path, "show", "hostnames", nodelist], text=True, stderr=subprocess.DEVNULL
+ ).strip() # nosec B603
+ except Exception:
  host_name = "localhost"
  if rank is None:
  import submitit
+
  job_env = submitit.JobEnvironment()
  global_rank = job_env.global_rank
  local_rank = job_env.local_rank
@@ -487,42 +665,48 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
  local_rank = rank
  if global_rank >= world_size:
  return None, None
- #print("tcp://{}:{}".format(host_name, port))
+ # print("tcp://{}:{}".format(host_name, port))
  if torch.cuda.is_available():
  torch.cuda.empty_cache()
- dist.init_process_group("nccl", rank=global_rank, init_method="tcp://{}:{}".format(host_name, port), world_size=world_size)
+ dist.init_process_group(
+ "nccl",
+ rank=global_rank,
+ init_method=f"tcp://{host_name}:{port}",
+ world_size=world_size,
+ )
  else:
  if not dist.is_initialized():
  dist.init_process_group(
  backend="gloo",
  init_method=f"tcp://{host_name}:{port}",
  rank=global_rank,
- world_size=world_size
+ world_size=world_size,
  )
  return global_rank, local_rank
 
- import socket
- from contextlib import closing
 
  def find_free_port():
  with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
- s.bind(('', 0))
+ s.bind(("", 0))
  s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  return s.getsockname()[1]
-
+
+
  def cleanup():
  if dist.is_initialized():
  dist.destroy_process_group()
 
- def synchronize_data(world_size: int, gpu: int, data: any) -> list[Any]:
+
+ def synchronize_data(world_size: int, gpu: int, data: Any) -> list[Any]:
  if torch.cuda.is_available():
- outputs: list[dict[str, tuple[dict[str, float], dict[str, float]]]] = [None for _ in range(world_size)]
+ outputs: list[dict[str, tuple[dict[str, float], dict[str, float]]] | None] = [None] * world_size
  torch.cuda.set_device(gpu)
  dist.all_gather_object(outputs, data)
  else:
  outputs = [data]
  return outputs
 
+
  def _resample(data: torch.Tensor, size: list[int]) -> torch.Tensor:
  if data.dtype == torch.uint8:
  mode = "nearest"
@@ -530,38 +714,79 @@ def _resample(data: torch.Tensor, size: list[int]) -> torch.Tensor:
  mode = "bilinear"
  else:
  mode = "trilinear"
- return F.interpolate(data.type(torch.float32).unsqueeze(0), size=tuple([s for s in reversed(size)]), mode=mode).squeeze(0).type(data.dtype)
+ return (
+ F.interpolate(
+ data.type(torch.float32).unsqueeze(0),
+ size=tuple(reversed(size)),
+ mode=mode,
+ )
+ .squeeze(0)
+ .type(data.dtype)
+ )
+
 
  def _affine_matrix(matrix: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
- return torch.cat((torch.cat((matrix, translation.unsqueeze(0).T), dim=1), torch.tensor([[0, 0, 0, 1]])), dim=0)
+ return torch.cat(
+ (
+ torch.cat((matrix, translation.unsqueeze(0).T), dim=1),
+ torch.tensor([[0, 0, 0, 1]]),
+ ),
+ dim=0,
+ )
+
 
  def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
  if data.dtype == torch.uint8:
  mode = "nearest"
  else:
  mode = "bilinear"
- return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
-
+ return (
+ F.grid_sample(
+ data.unsqueeze(0).type(torch.float32),
+ F.affine_grid(
+ matrix[:, :-1, ...].type(torch.float32),
+ [1] + list(data.shape),
+ align_corners=True,
+ ),
+ align_corners=True,
+ mode=mode,
+ padding_mode="reflection",
+ )
+ .squeeze(0)
+ .type(data.dtype)
+ )
 
 
  def download_url(model_name: str, url: str) -> str:
  spec = importlib.util.find_spec("konfai")
- base_path = Path(spec.submodule_search_locations[0]) / "metric" / "models"
+ if spec is None or spec.submodule_search_locations is None:
+ raise ImportError("Could not locate 'konfai' package")
+ locations = spec.submodule_search_locations
+ if not isinstance(locations, list) or not locations:
+ raise ImportError("No valid submodule_search_locations found")
+ base_path = Path(locations[0]) / "metric" / "models"
+ os.makedirs(base_path, exist_ok=True)
+
  subdirs = Path(model_name).parent
  model_dir = base_path / subdirs
  model_dir.mkdir(exist_ok=True)
- filetmp = model_dir / ("tmp_"+str(Path(model_name).name))
+ filetmp = model_dir / ("tmp_" + str(Path(model_name).name))
  file = model_dir / Path(model_name).name
  if file.exists():
  return str(file)
-
+
  try:
  print(f"[FOCUS] Downloading {model_name} to {file}")
- with requests.get(url+model_name, stream=True) as r:
+ with requests.get(url + model_name, stream=True, timeout=10) as r:
  r.raise_for_status()
- total = int(r.headers.get('content-length', 0))
- with open(filetmp, 'wb') as f:
- with tqdm(total=total, unit='B', unit_scale=True, desc=f"Downloading {model_name}") as pbar:
+ total = int(r.headers.get("content-length", 0))
+ with open(filetmp, "wb") as f:
+ with tqdm(
+ total=total,
+ unit="B",
+ unit_scale=True,
+ desc=f"Downloading {model_name}",
+ ) as pbar:
  for chunk in r.iter_content(chunk_size=8192):
  f.write(chunk)
  pbar.update(len(chunk))
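The reflowed _resample, _affine_matrix and _resample_affine above are behaviour-preserving reformattings. As a quick check of what _affine_matrix produces, it stacks a 3×3 matrix and a translation vector into a 4×4 homogeneous transform; the helper is private, so the import below is for illustration only:

    import torch
    from konfai.utils.utils import _affine_matrix

    a = _affine_matrix(torch.eye(3), torch.tensor([1.0, 2.0, 3.0]))
    # tensor([[1., 0., 0., 1.],
    #         [0., 1., 0., 2.],
    #         [0., 0., 1., 3.],
    #         [0., 0., 0., 1.]])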
@@ -573,63 +798,86 @@ def download_url(model_name: str, url: str) -> str:
  if filetmp.exists():
  os.remove(filetmp)
  return str(file)
-
- SUPPORTED_EXTENSIONS = [
- "mha", "mhd", # MetaImage
- "nii", "nii.gz", # NIfTI
- "nrrd", "nrrd.gz", # NRRD
- "gipl", "gipl.gz", # GIPL
- "hdr", "img", # Analyze
- "dcm", # DICOM (si GDCM activé)
- "tif", "tiff", # TIFF
- "png", "jpg", "jpeg", "bmp", # 2D formats
- "h5", "itk.txt", ".fcsv", ".xml", ".vtk", ".npy"
 
+
+ SUPPORTED_EXTENSIONS = [
+ "mha",
+ "mhd", # MetaImage
+ "nii",
+ "nii.gz", # NIfTI
+ "nrrd",
+ "nrrd.gz", # NRRD
+ "gipl",
+ "gipl.gz", # GIPL
+ "hdr",
+ "img", # Analyze
+ "dcm", # DICOM (si GDCM activé)
+ "tif",
+ "tiff", # TIFF
+ "png",
+ "jpg",
+ "jpeg",
+ "bmp", # 2D formats
+ "h5",
+ "itk.txt",
+ ".fcsv",
+ ".xml",
+ ".vtk",
+ ".npy",
  ]
 
+
  class KonfAIError(Exception):
 
- def __init__(self, typeError: str, messages: list[str]) -> None:
- super().__init__("\n[{}] {}".format(typeError, messages[0])+("\n" if len(messages) > 0 else "")+"\n→\t".join(messages[1:]))
+ def __init__(self, type_error: str, *messages: str) -> None:
+ super().__init__(
+ f"\n[{type_error}] {messages[0]}" + ("\n" if len(messages) > 0 else "") + "\n→\t".join(messages[1:])
+ )
 
 
  class ConfigError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("Config", message)
+ super().__init__("Config", *message)
 
 
  class DatasetManagerError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("DatasetManager", message)
+ super().__init__("DatasetManager", *message)
+
 
  class MeasureError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("Measure", message)
+ super().__init__("Measure", *message)
+
 
  class TrainerError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("Trainer", message)
+ super().__init__("Trainer", *message)
+
 
  class AugmentationError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("Augmentation", message)
+ super().__init__("Augmentation", *message)
+
 
  class EvaluatorError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("Evaluator", message)
+ super().__init__("Evaluator", *message)
+
 
  class PredictorError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("Predictor", message)
+ super().__init__("Predictor", *message)
+
 
  class TransformError(KonfAIError):
 
  def __init__(self, *message) -> None:
- super().__init__("Transform", message)
+ super().__init__("Transform", *message)