konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/utils/utils.py
CHANGED
@@ -1,119 +1,163 @@
+import argparse
+import importlib.util
 import itertools
-import pynvml
-import psutil
-
-import numpy as np
 import os
-import
-
+import random
+import re
+import shutil
+import socket
+import subprocess  # nosec B404
+import sys
 from abc import ABC, abstractmethod
+from contextlib import closing
 from enum import Enum
-from
+from pathlib import Path
+from typing import Any, TextIO, cast

-
+import numpy as np
+import psutil
+import pynvml
+import requests
+import torch
 import torch.distributed as dist
-import
-import subprocess
-import random
+import torch.nn.functional as F  # noqa: N812
 from torch.utils.data import DataLoader
-
-import sys
-import re
-
-import requests
+from torch.utils.tensorboard.writer import SummaryWriter
 from tqdm import tqdm
-import importlib
-from pathlib import Path
-import shutil

-
-
-
-
-
-
-
-
-
+from konfai import (
+    config_file,
+    cuda_visible_devices,
+    evaluations_directory,
+    konfai_state,
+    predictions_directory,
+    statistics_directory,
+)
+
+
+def description(model, model_ema=None, show_memory: bool = True, train: bool = True) -> str:
+    def loss_desc(model):
+        return (
+            "("
+            + " ".join(
+                [
+                    f"{name}({(network.optimizer.param_groups[0]['lr'] if network.optimizer else 0):.6f}) : "
+                    + " ".join(
+                        f"{k.split(':')[-1]}({w:.2f}) : {v:.6f}"
+                        for (k, v), w in zip(
+                            network.measure.get_last_values().items(), network.measure.get_last_weights().values()
+                        )
+                    )
+                    for name, network in model.module.get_networks().items()
+                    if network.measure is not None
+                ]
+            )
+            + ")"
+        )
+
+    model_loss_desc = loss_desc(model)
+    result = ""
+    if len(model_loss_desc) > 2:
+        result += f"Loss {model_loss_desc} "
+    if model_ema is not None:
+        model_ema_loss_desc = loss_desc(model_ema)
+        if len(model_ema_loss_desc) > 2:
+            result += f"Loss EMA {model_ema_loss_desc} "
+    result += gpu_info()
+    if show_memory:
+        result += f" | {get_memory_info()}"
     return result

-
+
+def get_module(classpath: str, default_classpath: str) -> tuple[str, str]:
     if len(classpath.split(":")) > 1:
         module = ".".join(classpath.split(":")[:-1])
         name = classpath.split(":")[-1]
     else:
-        module =
+        module = (
+            default_classpath + ("." if len(classpath.split(".")) > 2 else "") + ".".join(classpath.split(".")[:-1])
+        )
         name = classpath.split(".")[-1]
     return module, name.split("/")[0]

-def cpuInfo() -> str:
-    return "CPU ({:.2f} %)".format(psutil.cpu_percent(interval=0.5))

-def
-    return "
+def get_cpu_info() -> str:
+    return f"CPU ({psutil.cpu_percent(interval=0.5):.2f} %)"
+
+
+def get_memory_info() -> str:
+    return f"Memory ({psutil.virtual_memory()[3] / 2**30:.2f}G ({psutil.virtual_memory()[2]:.2f} %))"
+

-def
-    return psutil.virtual_memory()[3]/2**30
+def get_memory() -> float:
+    return psutil.virtual_memory()[3] / 2**30

-def memoryForecast(memory_init : float, i : float, size : float) -> str:
-    current_memory = getMemory()
-    forecast = memory_init + ((current_memory-memory_init)*size/i) if i > 0 else 0
-    return "Memory forecast ({:.2f}G ({:.2f} %))".format(forecast, forecast/(psutil.virtual_memory()[0]/2**30)*100)

-def
-
+def memory_forecast(memory_init: float, i: float, size: float) -> str:
+    current_memory = get_memory()
+    forecast = memory_init + ((current_memory - memory_init) * size / i) if i > 0 else 0
+    return f"Memory forecast ({forecast:.2f}G ({forecast / (psutil.virtual_memory()[0] / 2**30) * 100:.2f} %))"
+
+
+def gpu_info() -> str:
+    if cuda_visible_devices() == "":
         return ""
-
-    devices = [int(i) for i in
+
+    devices = [int(i) for i in cuda_visible_devices().split(",")]
     device = devices[0]
-
+
     if device < pynvml.nvmlDeviceGetCount():
         handle = pynvml.nvmlDeviceGetHandleByIndex(device)
         memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
     else:
         return ""
-    node_name = "Node: {} " +os.environ["SLURMD_NODENAME"] if "SLURMD_NODENAME" in os.environ else ""
-    return
+    node_name = "Node: {} " + os.environ["SLURMD_NODENAME"] if "SLURMD_NODENAME" in os.environ else ""
+    return f"{node_name}GPU({devices}) Memory GPU ({memory.used / 1e9:.2f}G ({memory.used / memory.total * 100:.2f} %))"
+

-def
+def get_max_gpu_memory(device: int | torch.device) -> float:
     if isinstance(device, torch.device):
         if str(device).startswith("cuda:"):
             device = int(str(device).replace("cuda:", ""))
         else:
             return 0
-    device = [int(i) for i in
+    device = [int(i) for i in cuda_visible_devices().split(",")][device]
     if device < pynvml.nvmlDeviceGetCount():
         handle = pynvml.nvmlDeviceGetHandleByIndex(device)
         memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
     else:
         return 0
-    return float(memory.total)/(10**9)
+    return float(memory.total) / (10**9)

-
+
+def get_gpu_memory(device: int | torch.device) -> float:
     if isinstance(device, torch.device):
         if str(device).startswith("cuda:"):
             device = int(str(device).replace("cuda:", ""))
         else:
             return 0
-    device = [int(i) for i in
+    device = [int(i) for i in cuda_visible_devices().split(",")][device]
     if device < pynvml.nvmlDeviceGetCount():
         handle = pynvml.nvmlDeviceGetHandleByIndex(device)
         memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
     else:
         return 0
-    return float(memory.used)/(10**9)
+    return float(memory.used) / (10**9)
+

-class NeedDevice
+class NeedDevice:

     def __init__(self) -> None:
         super().__init__()
-        self.device
-
-    def
-        self.device =
+        self.device: torch.device
+
+    def to(self, device: int):
+        self.device = get_device(device)
+
+
+def get_device(device: int):
+    return device if torch.cuda.is_available() and device >= 0 else torch.device("cpu")

-def getDevice(device : int):
-    return device if torch.cuda.is_available() and device >=0 else torch.device("cpu")

 class State(Enum):
     TRAIN = "TRAIN"
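Note on the hunk above: 1.2.0 renames the 1.1.8 camelCase helpers to snake_case (cpuInfo → get_cpu_info, getMemory → get_memory, memoryForecast → memory_forecast, getDevice → get_device) and completes get_module's fallback branch. A minimal sketch of the resolution rule as it now reads — the classpaths below are hypothetical:

    from konfai.utils.utils import get_module

    # "module:Class" form: everything before the last ":" is the module path.
    module, name = get_module("mymodels.segmentation:UNet", "konfai.models")
    # -> ("mymodels.segmentation", "UNet")

    # Bare-name form falls back to the default classpath.
    module, name = get_module("UNet", "konfai.models.segmentation")
    # -> ("konfai.models.segmentation", "UNet")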
@@ -122,13 +166,18 @@ class State(Enum):
     FINE_TUNING = "FINE_TUNING"
     PREDICTION = "PREDICTION"
     EVALUATION = "EVALUATION"
-
+
     def __str__(self) -> str:
         return self.value

-
+
+def get_patch_slices_from_nb_patch_per_dim(
+    patch_size_tmp: list[int],
+    nb_patch_per_dim: list[tuple[int, bool]],
+    overlap: int | None,
+) -> list[tuple[slice, ...]]:
     patch_slices = []
-    slices
+    slices: list[list[slice]] = []
     if overlap is None:
         overlap = 0
     patch_size = []
@@ -138,47 +187,54 @@ def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_p
             patch_size.append(1)
         else:
             patch_size.append(patch_size_tmp[i])
-        i+=1
+        i += 1

     for dim, nb in enumerate(nb_patch_per_dim):
         slices.append([])
         for index in range(nb[0]):
-            start = (patch_size[dim]-overlap)*index
+            start = (patch_size[dim] - overlap) * index
             end = start + patch_size[dim]
-            slices[dim].append(slice(start,end))
+            slices[dim].append(slice(start, end))
     for chunk in itertools.product(*slices):
         patch_slices.append(tuple(chunk))
     return patch_slices

-
+
+def get_patch_slices_from_shape(
+    patch_size: list[int], shape: list[int], overlap_tmp: int | None
+) -> tuple[list[tuple[slice, ...]], list[tuple[int, bool]]]:
     if len(shape) != len(patch_size):
-
-
-
-
-
-
+        raise DatasetManagerError(
+            f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
+            f"patch_size: {patch_size}",
+            f"shape: {shape}",
+            "Both must have the same number of dimensions (e.g., 3D patch for 3D volume).",
+        )
     patch_slices = []
     nb_patch_per_dim = []
-    slices
-    if
-        size = [np.ceil(a/b) for a, b in zip(shape, patch_size)]
+    slices: list[list[slice]] = []
+    if overlap_tmp is None:
+        size = [np.ceil(a / b) for a, b in zip(shape, patch_size)]
         tmp = np.zeros(len(size), dtype=np.int_)
         for i, s in enumerate(size):
             if s > 1:
-                tmp[i] = np.mod(patch_size[i]-np.mod(shape[i], patch_size[i]), patch_size[i])//(size[i]-1)
+                tmp[i] = np.mod(patch_size[i] - np.mod(shape[i], patch_size[i]), patch_size[i]) // (size[i] - 1)
         overlap = tmp
     else:
-        overlap = [
-
+        overlap = [overlap_tmp if size > 1 else 0 for size in patch_size]
+
     for dim in range(len(shape)):
-
+        if overlap[dim] >= patch_size[dim]:
+            raise ValueError(
+                f"Overlap must be less than patch size, got overlap={overlap[dim]}",
+                f" ≥ patch_size={patch_size[dim]} at dim={dim}",
+            )

     for dim in range(len(shape)):
         slices.append([])
         index = 0
         while True:
-            start = (patch_size[dim]-overlap[dim])*index
+            start = (patch_size[dim] - overlap[dim]) * index

             end = start + patch_size[dim]
             if end >= shape[dim]:
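The new get_patch_slices_from_shape above validates dimensionality (raising DatasetManagerError) and overlap (raising ValueError) instead of failing silently. A usage sketch, assuming konfai 1.2.0 is installed; the shapes are illustrative:

    from konfai.utils.utils import get_patch_slices_from_shape

    # Tile a 128^3 volume into 64^3 patches with overlap 0.
    patch_slices, nb_patch_per_dim = get_patch_slices_from_shape([64, 64, 64], [128, 128, 128], 0)
    # patch_slices holds one tuple of slices per patch, e.g. (slice(0, 64), slice(0, 64), slice(64, 128));
    # nb_patch_per_dim records (patches per axis, axis_is_singleton) pairs, here (2, False) per axis.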
@@ -187,84 +243,107 @@ def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overla
                 break
             slices[dim].append(slice(start, end))
             index += 1
-        nb_patch_per_dim.append((index+1, patch_size[dim] == 1))
+        nb_patch_per_dim.append((index + 1, patch_size[dim] == 1))

     for chunk in itertools.product(*slices):
         patch_slices.append(tuple(chunk))
-
+
     return patch_slices, nb_patch_per_dim

-
-
-
-
-
-
-
-
-
-    if len(
-
-
-
-
-
-
+
+def _log_signal_format(array: np.ndarray) -> dict[str, np.ndarray]:
+    return {str(i): channel for i, channel in enumerate(array)}
+
+
+def _log_image_format(array: np.ndarray) -> np.ndarray:
+    if len(array.shape) == 2:
+        array = np.expand_dims(array, axis=0)
+
+    if len(array.shape) == 3 and array.shape[0] != 1:
+        array = np.expand_dims(array, axis=0)
+    if len(array.shape) == 4:
+        array = array[:, array.shape[1] // 2]
+
+    array = array.astype(float)
+    b = -np.min(array)
+    if (np.max(array) + b) > 0:
+        return (array + b) / (np.max(array) + b)
     else:
-        return 0*
+        return 0 * array
+

-def
+def _log_images_format(array: np.ndarray) -> np.ndarray:
     result = []
-    for n in range(
-        result.append(
+    for n in range(array.shape[0]):
+        result.append(_log_image_format(array[n]))
     result = np.stack(result, axis=0)
     return result

-
-
-
-
-
+
+def _log_video_format(array: np.ndarray) -> np.ndarray:
+    result_list = []
+    for t in range(array.shape[1]):
+        result_list.append(_log_images_format(array[:, t, ...]))
+    result = np.stack(result_list, axis=1)

     nb_channel = result.shape[2]
     if nb_channel < 3:
         channel_split = [result[:, :, 0, ...] for i in range(3)]
     else:
         channel_split = np.split(result, 3, axis=0)
-
+    array = np.zeros((result.shape[0], result.shape[1], 3, *list(result.shape[3:])))
     for i, channels in enumerate(channel_split):
-
-    return
+        array[:, :, i] = np.mean(channels, axis=0)
+    return array
+

 class DataLog(Enum):
-    SIGNAL
-    IMAGE
-    IMAGES
-    VIDEO
-    AUDIO
+    SIGNAL = "SIGNAL"
+    IMAGE = "IMAGE"
+    IMAGES = "IMAGES"
+    VIDEO = "VIDEO"
+    AUDIO = "AUDIO"
+
+    def __call__(self, tb: SummaryWriter, name: str, layer: torch.Tensor, it: int):
+        if self == DataLog.SIGNAL:
+            return [
+                tb.add_scalars(name, _log_signal_format(layer[b, :, 0]), layer.shape[0] * it + b)
+                for b in range(layer.shape[0])
+            ]
+        elif self == DataLog.IMAGE:
+            return tb.add_image(name, _log_image_format(layer[0]), it)
+        elif self == DataLog.IMAGES:
+            return tb.add_images(name, _log_images_format(layer), it)
+        elif self == DataLog.VIDEO:
+            return tb.add_video(name, _log_video_format(layer), it)
+        elif self == DataLog.AUDIO:
+            return tb.add_audio(name, _log_image_format(layer), it)
+        else:
+            raise ValueError(f"Unsupported DataLog type: {self}")
+

 class Log:
     def __init__(self, name: str, rank: int) -> None:
-        if
-            path =
-        elif
-            path =
+        if konfai_state() == "PREDICTION":
+            path = predictions_directory()
+        elif konfai_state() == "EVALUATION":
+            path = evaluations_directory()
         else:
-            path =
-
+            path = statistics_directory()
+
         self.verbose = os.environ.get("KONFAI_VERBOSE", "True") == "True"
         self.log_path = os.path.join(path, name)
         os.makedirs(self.log_path, exist_ok=True)
         self.rank = rank
-        self.file = open(os.path.join(self.log_path, "log_{}.txt"
+        self.file = open(os.path.join(self.log_path, f"log_{rank}.txt"), "w", buffering=1)
         self.stdout_bak = sys.stdout
         self.stderr_bak = sys.stderr
         self._buffered_line = ""
-
+
     def __enter__(self):
         self.file.__enter__()
-        sys.stdout = self
-        sys.stderr = self
+        sys.stdout = cast(TextIO, self)
+        sys.stderr = cast(TextIO, self)
         return self

     def __exit__(self, exc_type, exc_val, exc_tb):
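DataLog members are now string-valued and callable: invoking a member dispatches the tensor to the matching SummaryWriter method through the _log_*_format helpers above. A minimal sketch — note the private helpers operate on NumPy arrays, so the tensor is converted first; the run directory is hypothetical:

    import torch
    from torch.utils.tensorboard.writer import SummaryWriter
    from konfai.utils.utils import DataLog

    tb = SummaryWriter("runs/demo")    # hypothetical log directory
    batch = torch.rand(4, 1, 64, 64)   # (batch, channel, H, W)
    DataLog.IMAGES(tb, "val/images", batch.cpu().numpy(), 0)  # routes to tb.add_images(...)
    tb.close()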
@@ -275,25 +354,22 @@ class Log:
     def write(self, msg: str):
         if not msg:
             return
-
-
-
-
-
-        if '\r' in msg_clean or '[A' in msg:
-            # On garde seulement le contenu après le dernier retour chariot
-            msg_clean = msg_clean.split('\r')[-1].strip()
+
+        ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]")
+        msg_clean = ansi_escape.sub("", msg)
+        if "\r" in msg_clean or "[A" in msg:
+            msg_clean = msg_clean.split("\r")[-1].strip()
             self._buffered_line = msg_clean
         else:
             self._buffered_line = msg_clean.strip()

         if self._buffered_line:
-            # Écrit dans le fichier
             self.file.write(self._buffered_line + "\n")
             self.file.flush()
             if self.verbose and (self.rank == 0 or "KONFAI_CLUSTER" in os.environ):
-                sys.__stdout__
-
+                if sys.__stdout__ is not None:
+                    sys.__stdout__.write(msg)
+                    sys.__stdout__.flush()

     def flush(self):
         self.file.flush()
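The rewritten Log.write above scrubs ANSI escape sequences and keeps only the last carriage-return frame, so tqdm-style redraws collapse to a single log line. The same two steps in isolation:

    import re

    ansi_escape = re.compile(r"\x1b\[[0-9;]*[a-zA-Z]")  # pattern used by Log.write
    msg = "\x1b[Aprogress 10%\rprogress 100%"
    clean = ansi_escape.sub("", msg)
    print(clean.split("\r")[-1].strip())  # -> "progress 100%"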
@@ -302,65 +378,104 @@ class Log:
         return False

     def fileno(self):
+        if sys.__stdout__ is None:
+            raise RuntimeError("sys.__stdout__ is None, cannot get fileno")
         return sys.__stdout__.fileno()
-
-
-
+
+
+class TensorBoard:
+
     def __init__(self, name: str) -> None:
-        self.process = None
+        self.process: subprocess.Popen | None = None
         self.name = name

     def __enter__(self):
         if "KONFAI_TENSORBOARD_PORT" in os.environ:
-
-
+            tensorboard_exe = shutil.which("tensorboard")
+            if tensorboard_exe is None:
+                raise RuntimeError("TensorBoard executable not found in PATH.")
+
+            logdir = (
+                predictions_directory() if konfai_state() == "PREDICTION" else statistics_directory() + self.name + "/"
+            )
+
+            port = os.environ.get("KONFAI_TENSORBOARD_PORT")
+            if not port or not port.isdigit():
+                raise ValueError("Invalid or missing KONFAI_TENSORBOARD_PORT.")
+
+            command = [
+                tensorboard_exe,
+                "--logdir",
+                logdir,
+                "--port",
+                port,
+                "--bind_all",
+            ]
+            self.process = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)  # nosec B603
             try:
                 s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
-                s.connect((
-
+                s.connect(("10.255.255.255", 1))
+                ip = s.getsockname()[0]
             except Exception:
-
+                ip = "127.0.0.1"
             finally:
                 s.close()
-            print("Tensorboard : http://{}:{
+            print(f"Tensorboard : http://{ip}:{os.environ['KONFAI_TENSORBOARD_PORT']}/")
         return self
-
-    def __exit__(self,
+
+    def __exit__(self, exc_type, value, traceback):
         if self.process is not None:
             self.process.terminate()
             self.process.wait()

-
+
+class DistributedObject(ABC):

     def __init__(self, name: str) -> None:
         self.port = find_free_port()
-        self.dataloader
-        self.manual_seed:
+        self.dataloader: list[list[DataLoader]]
+        self.manual_seed: int | None = None
         self.name = name
         self.size = 1
-
+
     @abstractmethod
     def setup(self, world_size: int):
         pass
-
+
     def __enter__(self):
         return self

-    def __exit__(self,
+    def __exit__(self, exc_type, value, traceback):
         cleanup()

     @abstractmethod
-    def run_process(
-
-
-
+    def run_process(
+        self,
+        world_size: int,
+        global_rank: int,
+        local_rank: int,
+        dataloaders: list[DataLoader],
+    ):
+        pass
+
+    @staticmethod
+    def get_measure(
+        world_size: int,
+        global_rank: int,
+        gpu: int,
+        models: dict[str, torch.nn.Module],
+        n: int,
+    ) -> dict[str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]]:
         data = {}
         for label, model in models.items():
-            for name, network in model.
+            for name, network in model.get_networks().items():
                 if network.measure is not None:
-                    data["{}{}"
+                    data[f"{name}{label}"] = (
+                        network.measure.format_loss(True, n),
+                        network.measure.format_loss(False, n),
+                    )
         outputs = synchronize_data(world_size, gpu, data)
-        result = {}
+        result: dict[str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]] = {}
         if global_rank == 0:
             for output in outputs:
                 for k, v in output.items():
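TensorBoard.__enter__ now resolves the binary with shutil.which, validates the port, and silences the child process; the printed URL uses the UDP-connect trick to find the host's outbound IP (connecting a datagram socket sends no packets, it only selects a route). That trick as a standalone helper — the function name is ours:

    import socket

    def local_ip() -> str:
        # No packet is sent: connect() on a UDP socket only picks a route,
        # so getsockname() reports the interface address that would be used.
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            s.connect(("10.255.255.255", 1))
            return s.getsockname()[0]
        except Exception:
            return "127.0.0.1"
        finally:
            s.close()

    print(local_ip())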
@@ -369,16 +484,19 @@ class DistributedObject():
                         if k not in result:
                             result[k] = ({}, {})
                         if u not in result[k][t]:
-                            result[k][t][u] = (n[0], 0)
-                        result[k][t][u] = (
+                            result[k][t][u] = (n[0], 0)  # type: ignore[index]
+                        result[k][t][u] = (
+                            result[k][t][u][0],
+                            result[k][t][u][1] + n[1] / world_size,  # type: ignore[index]
+                        )
         return result

-    def __call__(self, rank:
-
-
-
-
-
+    def __call__(self, rank: int | None = None) -> None:
+        world_size = len(self.dataloader)
+        global_rank, local_rank = setup_gpu(world_size, self.port, rank)
+        if global_rank is None or local_rank is None:
+            return
+        with Log(self.name, global_rank):
             if torch.cuda.is_available():
                 pynvml.nvmlInit()
             if self.manual_seed is not None:
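In get_measure above, every rank contributes (count, value) pairs and rank 0 accumulates value + n[1] / world_size per rank, i.e. the mean of the per-rank values. The arithmetic in miniature:

    # Two ranks report the same loss key; rank 0 sums value / world_size over ranks.
    world_size = 2
    per_rank_values = [0.50, 0.30]
    mean = sum(v / world_size for v in per_rank_values)
    print(mean)  # 0.4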
@@ -390,7 +508,7 @@ class DistributedObject():
             torch.backends.cuda.matmul.allow_tf32 = True
             torch.backends.cudnn.allow_tf32 = True
             dataloaders = self.dataloader[global_rank]
-            if torch.cuda.is_available():
+            if torch.cuda.is_available():
                 torch.cuda.set_device(local_rank)
             try:
                 self.run_process(world_size, global_rank, local_rank, dataloaders)
@@ -399,27 +517,75 @@ class DistributedObject():
             if torch.cuda.is_available():
                 pynvml.nvmlShutdown()

+
 def setup(parser: argparse.ArgumentParser) -> DistributedObject:
     # KONFAI arguments
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    konfai = parser.add_argument_group("KONFAI arguments")
+    konfai.add_argument("type", type=State, choices=list(State))
+    konfai.add_argument("-y", action="store_true", help="Accept overwrite")
+    konfai.add_argument("-tb", action="store_true", help="Start TensorBoard")
+    konfai.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
+    konfai.add_argument(
+        "-g",
+        "--gpu",
+        type=str,
+        default=(os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else ""),
+        help="List of GPU",
+    )
+    konfai.add_argument("-cpu", "--cpu", type=str, default="1", help="List of GPU")
+    konfai.add_argument(
+        "--num-workers",
+        "--num_workers",
+        default=4,
+        type=int,
+        help="No. of workers per DataLoader & GPU",
+    )
+    konfai.add_argument(
+        "-models_dir",
+        "--MODELS_DIRECTORY",
+        type=str,
+        default="./Models/",
+        help="Models location",
+    )
+    konfai.add_argument(
+        "-checkpoints_dir",
+        "--CHECKPOINTS_DIRECTORY",
+        type=str,
+        default="./Checkpoints/",
+        help="Checkpoints location",
+    )
+    konfai.add_argument("-model", "--MODEL", type=str, default="", help="URL Model")
+    konfai.add_argument(
+        "-predictions_dir",
+        "--PREDICTIONS_DIRECTORY",
+        type=str,
+        default="./Predictions/",
+        help="Predictions location",
+    )
+    konfai.add_argument(
+        "-evaluation_dir",
+        "--EVALUATIONS_DIRECTORY",
+        type=str,
+        default="./Evaluations/",
+        help="Evaluations location",
+    )
+    konfai.add_argument(
+        "-statistics_dir",
+        "--STATISTICS_DIRECTORY",
+        type=str,
+        default="./Statistics/",
+        help="Statistics location",
+    )
+    konfai.add_argument(
+        "-setups_dir",
+        "--SETUPS_DIRECTORY",
+        type=str,
+        default="./Setups/",
+        help="Setups location",
+    )
+    konfai.add_argument("-log", action="store_true", help="Save log")
+    konfai.add_argument("-quiet", action="store_false", help="")
+
     args = parser.parse_args()
     config = vars(args)

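The argument group above replaces the 1.1.8 CLI wiring. How an entry point presumably drives it — a sketch, not konfai's actual main.py; the launch path beyond setup() depends on the runner:

    import argparse
    from konfai.utils.utils import setup

    # Parses sys.argv (e.g. "TRAIN -c Config.yml -g 0,1 --num-workers 8 -tb"),
    # exports the KONFAI_* environment, and returns a Trainer/Predictor/Evaluator.
    runner = setup(argparse.ArgumentParser(description="KonfAI"))
    with runner:
        runner()  # DistributedObject.__call__ sets up ranks and calls run_process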
@@ -432,14 +598,14 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
     os.environ["KONFAI_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
     os.environ["KONFAI_EVALUATIONS_DIRECTORY"] = config["EVALUATIONS_DIRECTORY"]
     os.environ["KONFAI_STATISTICS_DIRECTORY"] = config["STATISTICS_DIRECTORY"]
-
+
     os.environ["KONFAI_STATE"] = str(config["type"])
-
+
     os.environ["KONFAI_MODEL"] = config["MODEL"]

     os.environ["KONFAI_SETUPS_DIRECTORY"] = config["SETUPS_DIRECTORY"]

-    os.environ["KONFAI_OVERWRITE"] =
+    os.environ["KONFAI_OVERWRITE"] = str(config["y"])
     os.environ["KONFAI_CONFIG_MODE"] = "Done"
     if config["tb"]:
         os.environ["KONFAI_TENSORBOARD_PORT"] = str(find_free_port())
@@ -448,37 +614,49 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:

     if config["config"] == "None":
         if config["type"] is State.PREDICTION:
-
+            os.environ["KONFAI_config_file"] = "Prediction.yml"
         elif config["type"] is State.EVALUATION:
-            os.environ["
+            os.environ["KONFAI_config_file"] = "Evaluation.yml"
         else:
-            os.environ["
+            os.environ["KONFAI_config_file"] = "Config.yml"
     else:
-        os.environ["
+        os.environ["KONFAI_config_file"] = config["config"]
     torch.autograd.set_detect_anomaly(True)
     os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
-
-    if config["type"] is State.PREDICTION:
+
+    if config["type"] is State.PREDICTION:
         from konfai.predictor import Predictor
-
-
+
+        os.environ["konfai_root"] = "Predictor"
+        return Predictor(config=config_file())
     elif config["type"] is State.EVALUATION:
         from konfai.evaluator import Evaluator
-
-
+
+        os.environ["konfai_root"] = "Evaluator"
+        return Evaluator(config=config_file())
     else:
         from konfai.trainer import Trainer
-        os.environ["KONFAI_ROOT"] = "Trainer"
-        return Trainer(config=CONFIG_FILE())

+        os.environ["konfai_root"] = "Trainer"
+        return Trainer(config=config_file())

-
+
+def setup_gpu(world_size: int, port: int, rank: int | None = None) -> tuple[int | None, int | None]:
     try:
-
-
+        nodelist = os.getenv("SLURM_JOB_NODELIST")
+        if nodelist is None:
+            raise RuntimeError("SLURM_JOB_NODELIST is not set.")
+        scontrol_path = shutil.which("scontrol")
+        if scontrol_path is None:
+            raise FileNotFoundError("scontrol not found in PATH")
+        host_name = subprocess.check_output(
+            [scontrol_path, "show", "hostnames", nodelist], text=True, stderr=subprocess.DEVNULL
+        ).strip()  # nosec B603
+    except Exception:
         host_name = "localhost"
     if rank is None:
         import submitit
+
         job_env = submitit.JobEnvironment()
         global_rank = job_env.global_rank
         local_rank = job_env.local_rank
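setup_gpu now resolves the rendezvous host from the SLURM allocation and falls back to localhost. A standalone mirror of that lookup (taking the first hostname; the function name is ours):

    import os
    import shutil
    import subprocess

    def rendezvous_host() -> str:
        nodelist = os.getenv("SLURM_JOB_NODELIST")
        scontrol = shutil.which("scontrol")
        if nodelist is None or scontrol is None:
            return "localhost"  # outside SLURM, mirror the except-branch fallback
        out = subprocess.check_output([scontrol, "show", "hostnames", nodelist], text=True)
        return out.splitlines()[0]  # first node of the allocation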
@@ -487,42 +665,48 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
         local_rank = rank
     if global_rank >= world_size:
         return None, None
-    #print("tcp://{}:{}".format(host_name, port))
+    # print("tcp://{}:{}".format(host_name, port))
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
-        dist.init_process_group(
+        dist.init_process_group(
+            "nccl",
+            rank=global_rank,
+            init_method=f"tcp://{host_name}:{port}",
+            world_size=world_size,
+        )
     else:
         if not dist.is_initialized():
             dist.init_process_group(
                 backend="gloo",
                 init_method=f"tcp://{host_name}:{port}",
                 rank=global_rank,
-                world_size=world_size
+                world_size=world_size,
             )
     return global_rank, local_rank

-import socket
-from contextlib import closing

 def find_free_port():
     with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
-        s.bind((
+        s.bind(("", 0))
         s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
         return s.getsockname()[1]
-
+
+
 def cleanup():
     if dist.is_initialized():
         dist.destroy_process_group()

-
+
+def synchronize_data(world_size: int, gpu: int, data: Any) -> list[Any]:
     if torch.cuda.is_available():
-        outputs: list[dict[str, tuple[dict[str, float], dict[str, float]]]] = [None
+        outputs: list[dict[str, tuple[dict[str, float], dict[str, float]]] | None] = [None] * world_size
         torch.cuda.set_device(gpu)
         dist.all_gather_object(outputs, data)
     else:
         outputs = [data]
     return outputs

+
 def _resample(data: torch.Tensor, size: list[int]) -> torch.Tensor:
     if data.dtype == torch.uint8:
         mode = "nearest"
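find_free_port is unchanged in spirit: it binds to port 0 so the OS picks a free port, with the usual caveat that the port can be claimed again between the check and its reuse. A direct usage sketch:

    from konfai.utils.utils import find_free_port

    port = find_free_port()
    print(f"rendezvous on tcp://localhost:{port}")  # e.g. for dist.init_process_group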
@@ -530,38 +714,79 @@ def _resample(data: torch.Tensor, size: list[int]) -> torch.Tensor:
         mode = "bilinear"
     else:
         mode = "trilinear"
-    return
+    return (
+        F.interpolate(
+            data.type(torch.float32).unsqueeze(0),
+            size=tuple(reversed(size)),
+            mode=mode,
+        )
+        .squeeze(0)
+        .type(data.dtype)
+    )
+

 def _affine_matrix(matrix: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
-    return torch.cat(
+    return torch.cat(
+        (
+            torch.cat((matrix, translation.unsqueeze(0).T), dim=1),
+            torch.tensor([[0, 0, 0, 1]]),
+        ),
+        dim=0,
+    )
+

 def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
     if data.dtype == torch.uint8:
         mode = "nearest"
     else:
         mode = "bilinear"
-    return
-
+    return (
+        F.grid_sample(
+            data.unsqueeze(0).type(torch.float32),
+            F.affine_grid(
+                matrix[:, :-1, ...].type(torch.float32),
+                [1] + list(data.shape),
+                align_corners=True,
+            ),
+            align_corners=True,
+            mode=mode,
+            padding_mode="reflection",
+        )
+        .squeeze(0)
+        .type(data.dtype)
+    )


 def download_url(model_name: str, url: str) -> str:
     spec = importlib.util.find_spec("konfai")
-
+    if spec is None or spec.submodule_search_locations is None:
+        raise ImportError("Could not locate 'konfai' package")
+    locations = spec.submodule_search_locations
+    if not isinstance(locations, list) or not locations:
+        raise ImportError("No valid submodule_search_locations found")
+    base_path = Path(locations[0]) / "metric" / "models"
+    os.makedirs(base_path, exist_ok=True)
+
     subdirs = Path(model_name).parent
     model_dir = base_path / subdirs
     model_dir.mkdir(exist_ok=True)
-    filetmp = model_dir / ("tmp_"+str(Path(model_name).name))
+    filetmp = model_dir / ("tmp_" + str(Path(model_name).name))
     file = model_dir / Path(model_name).name
     if file.exists():
         return str(file)
-
+
     try:
         print(f"[FOCUS] Downloading {model_name} to {file}")
-        with requests.get(url+model_name, stream=True) as r:
+        with requests.get(url + model_name, stream=True, timeout=10) as r:
             r.raise_for_status()
-            total = int(r.headers.get(
-            with open(filetmp,
-                with tqdm(
+            total = int(r.headers.get("content-length", 0))
+            with open(filetmp, "wb") as f:
+                with tqdm(
+                    total=total,
+                    unit="B",
+                    unit_scale=True,
+                    desc=f"Downloading {model_name}",
+                ) as pbar:
                     for chunk in r.iter_content(chunk_size=8192):
                         f.write(chunk)
                         pbar.update(len(chunk))
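_resample above picks nearest/bilinear/trilinear by dtype and dimensionality and reverses its size argument before calling F.interpolate. The equivalent call spelled out for a float volume (shapes illustrative):

    import torch
    import torch.nn.functional as F

    data = torch.rand(1, 16, 32, 32)  # (C, D, H, W) volume
    # _resample would pass size reversed; here the target (D, H, W) is given directly.
    out = F.interpolate(data.unsqueeze(0), size=(8, 16, 16), mode="trilinear").squeeze(0)
    print(out.shape)  # torch.Size([1, 8, 16, 16])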
@@ -573,63 +798,86 @@ def download_url(model_name: str, url: str) -> str:
         if filetmp.exists():
             os.remove(filetmp)
     return str(file)
-
-SUPPORTED_EXTENSIONS = [
-    "mha", "mhd",  # MetaImage
-    "nii", "nii.gz",  # NIfTI
-    "nrrd", "nrrd.gz",  # NRRD
-    "gipl", "gipl.gz",  # GIPL
-    "hdr", "img",  # Analyze
-    "dcm",  # DICOM (si GDCM activé)
-    "tif", "tiff",  # TIFF
-    "png", "jpg", "jpeg", "bmp",  # 2D formats
-    "h5", "itk.txt", ".fcsv", ".xml", ".vtk", ".npy"

+
+SUPPORTED_EXTENSIONS = [
+    "mha",
+    "mhd",  # MetaImage
+    "nii",
+    "nii.gz",  # NIfTI
+    "nrrd",
+    "nrrd.gz",  # NRRD
+    "gipl",
+    "gipl.gz",  # GIPL
+    "hdr",
+    "img",  # Analyze
+    "dcm",  # DICOM (si GDCM activé)
+    "tif",
+    "tiff",  # TIFF
+    "png",
+    "jpg",
+    "jpeg",
+    "bmp",  # 2D formats
+    "h5",
+    "itk.txt",
+    ".fcsv",
+    ".xml",
+    ".vtk",
+    ".npy",
 ]

+
 class KonfAIError(Exception):

-    def __init__(self,
-        super().__init__(
+    def __init__(self, type_error: str, *messages: str) -> None:
+        super().__init__(
+            f"\n[{type_error}] {messages[0]}" + ("\n" if len(messages) > 0 else "") + "\n→\t".join(messages[1:])
+        )


 class ConfigError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("Config", message)
+        super().__init__("Config", *message)


 class DatasetManagerError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("DatasetManager", message)
+        super().__init__("DatasetManager", *message)
+

 class MeasureError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("Measure", message)
+        super().__init__("Measure", *message)
+

 class TrainerError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("Trainer", message)
+        super().__init__("Trainer", *message)
+

 class AugmentationError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("Augmentation", message)
+        super().__init__("Augmentation", *message)
+

 class EvaluatorError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("Evaluator", message)
+        super().__init__("Evaluator", *message)
+

 class PredictorError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("Predictor", message)
+        super().__init__("Predictor", *message)
+

 class TransformError(KonfAIError):

     def __init__(self, *message) -> None:
-        super().__init__("Transform", message)
+        super().__init__("Transform", *message)
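The KonfAIError fix matters: 1.1.8 passed the whole message tuple as a single argument, so subclasses printed a tuple repr; 1.2.0 unpacks *message into a type-tagged header plus arrow-prefixed hints. What a subclass now produces — the message strings are invented for illustration:

    from konfai.utils.utils import ConfigError

    err = ConfigError(
        "Unknown key 'optimzer' in Trainer section",
        "Did you mean 'optimizer'?",
        "Check your Config.yml",
    )
    print(err)
    # (leading newline omitted)
    # [Config] Unknown key 'optimzer' in Trainer section
    # Did you mean 'optimizer'?
    # →	Check your Config.yml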