konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/utils/utils.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
1
|
+
import itertools
|
|
2
|
+
import pynvml
|
|
3
|
+
import psutil
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import os
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from abc import ABC, abstractmethod
|
|
10
|
+
from enum import Enum
|
|
11
|
+
from typing import Any, Union
|
|
12
|
+
|
|
13
|
+
from KonfAI.konfai import CONFIG_FILE, STATISTICS_DIRECTORY, PREDICTIONS_DIRECTORY, DL_API_STATE, CUDA_VISIBLE_DEVICES
|
|
14
|
+
import torch.distributed as dist
|
|
15
|
+
import argparse
|
|
16
|
+
import subprocess
|
|
17
|
+
import random
|
|
18
|
+
from torch.utils.data import DataLoader
|
|
19
|
+
import torch.nn.functional as F
|
|
20
|
+
import sys
|
|
21
|
+
|
|
22
|
+
def description(model, modelEMA = None, showMemory: bool = True) -> str:
    """Build the one-line status string: losses, learning rates, GPU and RAM usage.

    Args:
        model: wrapped model; ``model.module.getNetworks()`` must return a
            mapping from network name to network object.
        modelEMA: optional second wrapped model with the same layout, whose
            losses are appended after the "Loss EMA" label.
        showMemory: when true, append the host memory summary.

    Returns:
        The formatted status line.
    """
    def _format_measures(weights, values) -> str:
        # One "name(weight) : value" entry per measure; only the part of the
        # name after the last ":" is shown.
        entries = []
        for (full_name, value), weight in zip(values.items(), weights.values()):
            entries.append("{}({:.2f}) : {:.6f}".format(full_name.split(":")[-1], weight, value))
        return " ".join(entries)

    def _format_networks(wrapped) -> str:
        entries = []
        for net_name, net in wrapped.module.getNetworks().items():
            # Networks without a measure are not reported.
            if net.measure is None:
                continue
            # A network without optimizer is displayed with a learning rate of 0.
            lr = net.optimizer.param_groups[0]['lr'] if net.optimizer is not None else 0
            measures = _format_measures(net.measure.getLastWeights(), net.measure.getLastValues())
            entries.append("{}({:.6f}) : {}".format(net_name, lr, measures))
        return "(" + " ".join(entries) + ")"

    parts = ["Loss {}".format(_format_networks(model))]
    if modelEMA is not None:
        parts.append("Loss EMA {}".format(_format_networks(modelEMA)))
    parts.append(" " + gpuInfo())
    if showMemory:
        parts.append(" | {}".format(memoryInfo()))
    return "".join(parts)
|
|
32
|
+
|
|
33
|
+
def _getModule(classpath : str, type : str) -> tuple[str, str]:
|
|
34
|
+
if len(classpath.split("_")) > 1:
|
|
35
|
+
module = ".".join(classpath.split("_")[:-1])
|
|
36
|
+
name = classpath.split("_")[-1]
|
|
37
|
+
else:
|
|
38
|
+
module = "KonfAI."+type
|
|
39
|
+
name = classpath
|
|
40
|
+
return module, name
|
|
41
|
+
|
|
42
|
+
def cpuInfo() -> str:
    """Return the CPU usage, sampled over a 0.5 s interval, as ``"CPU (xx.xx %)"``."""
    usage = psutil.cpu_percent(interval=0.5)
    return f"CPU ({usage:.2f} %)"
|
|
44
|
+
|
|
45
|
+
def memoryInfo() -> str:
    """Return host memory usage as ``"Memory (x.xxG (y.yy %))"``.

    A single ``psutil.virtual_memory()`` snapshot is taken so that the used
    amount (in GiB) and the percentage describe the same instant.
    """
    snapshot = psutil.virtual_memory()
    return "Memory ({:.2f}G ({:.2f} %))".format(snapshot.used/2**30, snapshot.percent)
|
|
47
|
+
|
|
48
|
+
def getMemory() -> float:
    """Return the host memory currently in use, in GiB (bytes / 2**30)."""
    used_bytes = psutil.virtual_memory().used
    return used_bytes/2**30
|
|
50
|
+
|
|
51
|
+
def memoryForecast(memory_init : float, i : float, size : float) -> str:
    """Linearly extrapolate host memory usage after ``size`` steps.

    Args:
        memory_init: memory in use (GiB) before the first step.
        i: number of steps done so far; the forecast is 0 when ``i`` is not positive.
        size: total number of steps.

    Returns:
        ``"Memory forecast (x.xxG (y.yy %))"``, the percentage being relative
        to the total host memory.
    """
    current_memory = getMemory()
    if i > 0:
        forecast = memory_init + ((current_memory-memory_init)*size/i)
    else:
        forecast = 0
    total_memory = psutil.virtual_memory().total/2**30
    return "Memory forecast ({:.2f}G ({:.2f} %))".format(forecast, forecast/total_memory*100)
|
|
55
|
+
|
|
56
|
+
def gpuInfo() -> str:
    """Return a memory summary for the first visible GPU.

    Returns:
        ``""`` when no GPU is listed by ``CUDA_VISIBLE_DEVICES()`` or when the
        first listed index is not known to NVML; otherwise
        ``"[Node: <name> ]GPU([ids]) Memory GPU (x.xxG (y.yy %))"`` where the
        node prefix is present only if ``SLURMD_NODENAME`` is set.
    """
    visible = CUDA_VISIBLE_DEVICES()
    if visible == "":
        return ""

    devices = [int(i) for i in visible.split(",")]
    device = devices[0]
    if device >= pynvml.nvmlDeviceGetCount():
        return ""

    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
    # The placeholder has to be filled here: concatenating the node name after
    # "Node: {} " left a literal "{}" in the output.
    node_name = "Node: {} ".format(os.environ["SLURMD_NODENAME"]) if "SLURMD_NODENAME" in os.environ else ""
    return "{}GPU({}) Memory GPU ({:.2f}G ({:.2f} %))".format(node_name, devices, float(memory.used)/(10**9), float(memory.used)/float(memory.total)*100)
|
|
70
|
+
|
|
71
|
+
def getMaxGPUMemory(device : Union[int, torch.device]) -> float:
    """Return the total memory of a GPU in units of 10**9 bytes.

    Args:
        device: position in the ``CUDA_VISIBLE_DEVICES()`` list, or a
            ``torch.device`` whose string form is ``"cuda:<position>"``.

    Returns:
        0 for any ``torch.device`` not of the form ``"cuda:<n>"`` or when the
        resolved index is not known to NVML; the total memory otherwise.
    """
    if isinstance(device, torch.device):
        text = str(device)
        if not text.startswith("cuda:"):
            return 0
        device = int(text.replace("cuda:", ""))

    # Translate the position into the physical index listed in CUDA_VISIBLE_DEVICES.
    visible = [int(i) for i in CUDA_VISIBLE_DEVICES().split(",")]
    device = visible[device]
    if device >= pynvml.nvmlDeviceGetCount():
        return 0
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return float(memory.total)/(10**9)
|
|
84
|
+
|
|
85
|
+
def getGPUMemory(device : Union[int, torch.device]) -> float:
    """Return the memory currently used on a GPU in units of 10**9 bytes.

    Args:
        device: position in the ``CUDA_VISIBLE_DEVICES()`` list, or a
            ``torch.device`` whose string form is ``"cuda:<position>"``.

    Returns:
        0 for any ``torch.device`` not of the form ``"cuda:<n>"`` or when the
        resolved index is not known to NVML; the used memory otherwise.
    """
    if isinstance(device, torch.device):
        text = str(device)
        if not text.startswith("cuda:"):
            return 0
        device = int(text.replace("cuda:", ""))

    # Translate the position into the physical index listed in CUDA_VISIBLE_DEVICES.
    visible = [int(i) for i in CUDA_VISIBLE_DEVICES().split(",")]
    device = visible[device]
    if device >= pynvml.nvmlDeviceGetCount():
        return 0
    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return float(memory.used)/(10**9)
|
|
98
|
+
|
|
99
|
+
class NeedDevice(ABC):
    """Base class for objects that carry a ``device`` attribute."""

    def __init__(self) -> None:
        super().__init__()
        # Declaration only: the attribute does not exist until setDevice() is called.
        self.device : torch.device

    def setDevice(self, device : int):
        """Store on ``self.device`` whatever ``getDevice(device)`` returns."""
        resolved = getDevice(device)
        self.device = resolved
|
|
107
|
+
|
|
108
|
+
def getDevice(device : int):
    """Return ``device`` unchanged when CUDA is available and ``device >= 0``.

    In every other case ``torch.device("cpu")`` is returned. Note that the
    CUDA case returns the integer index itself, not a ``torch.device``.
    """
    if torch.cuda.is_available() and device >= 0:
        return device
    return torch.device("cpu")
|
|
110
|
+
|
|
111
|
+
class State(Enum):
    """Run modes accepted on the command line; each value equals its name."""

    TRAIN = "TRAIN"
    RESUME = "RESUME"
    TRANSFER_LEARNING = "TRANSFER_LEARNING"
    FINE_TUNING = "FINE_TUNING"
    PREDICTION = "PREDICTION"
    EVALUATION = "EVALUATION"

    def __str__(self) -> str:
        # Render as the bare value ("TRAIN") rather than "State.TRAIN".
        text = self.value
        return text
|
|
121
|
+
|
|
122
|
+
def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_per_dim : list[tuple[int, bool]], overlap: Union[int, None]) -> list[tuple[slice]]:
    """Enumerate patch slices from a known number of patches per dimension.

    Args:
        patch_size_tmp: patch sizes, consumed in order for the dimensions
            whose flag is false.
        nb_patch_per_dim: one ``(count, flag)`` pair per dimension; a true
            flag forces a patch size of 1 for that dimension.
        overlap: overlap between consecutive patches; ``None`` means 0.

    Returns:
        One tuple of slices per patch, in ``itertools.product`` order
        (last dimension varying fastest).
    """
    step_overlap = 0 if overlap is None else overlap

    # NOTE(review): the cursor is assumed to advance only for non-flagged
    # dimensions (nesting was not recoverable from the flattened source) - confirm.
    patch_size = []
    cursor = 0
    for entry in nb_patch_per_dim:
        if entry[1]:
            patch_size.append(1)
        else:
            patch_size.append(patch_size_tmp[cursor])
            cursor += 1

    per_dim : list[list[slice]] = []
    for dim, entry in enumerate(nb_patch_per_dim):
        extent = patch_size[dim]
        starts = [(extent-step_overlap)*index for index in range(entry[0])]
        per_dim.append([slice(start, start + extent) for start in starts])

    return [tuple(chunk) for chunk in itertools.product(*per_dim)]
|
|
145
|
+
|
|
146
|
+
def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overlap: Union[int, None]) -> tuple[list[tuple[slice]], list[tuple[int, bool]]]:
    """Tile ``shape`` with patches of ``patch_size``.

    Args:
        patch_size: patch extent per dimension.
        shape: extent of the volume per dimension.
        overlap: overlap applied to every dimension whose patch extent is
            greater than 1; ``None`` derives a per-dimension overlap from
            ``shape`` and ``patch_size``.

    Returns:
        ``(patch_slices, nb_patch_per_dim)``: the slices of every patch in
        ``itertools.product`` order, and per dimension the number of patches
        together with a flag telling whether the patch extent is 1. When the
        two inputs have different lengths, a single patch covering the whole
        shape is returned with ``(1, True)`` for every dimension.

    Raises:
        AssertionError: if an overlap is not smaller than the patch extent.
    """
    nb_dim = len(shape)
    if nb_dim != len(patch_size):
        whole = tuple([slice(0, extent) for extent in shape])
        return [whole], [(1, True)]*nb_dim

    if overlap is None:
        counts = [np.ceil(extent/window) for extent, window in zip(shape, patch_size)]
        auto_overlap = np.zeros(len(counts), dtype=np.int_)
        for axis, count in enumerate(counts):
            if count > 1:
                # Amount by which `count` patches overshoot the shape, shared
                # between the `count - 1` junctions.
                overshoot = np.mod(patch_size[axis]-np.mod(shape[axis], patch_size[axis]), patch_size[axis])
                auto_overlap[axis] = overshoot//(count-1)
        overlap = auto_overlap
    else:
        overlap = [overlap if window > 1 else 0 for window in patch_size]

    for axis in range(nb_dim):
        assert overlap[axis] < patch_size[axis], "Overlap must be less than patch size"

    per_dim : list[list[slice]] = []
    nb_patch_per_dim = []
    for axis in range(nb_dim):
        window = patch_size[axis]
        axis_slices = []
        index = 0
        while True:
            start = (window-overlap[axis])*index
            end = start + window
            if end >= shape[axis]:
                # Last patch: clip to the border and stop.
                axis_slices.append(slice(start, shape[axis]))
                break
            axis_slices.append(slice(start, end))
            index += 1
        per_dim.append(axis_slices)
        nb_patch_per_dim.append((index+1, window == 1))

    patch_slices = [tuple(chunk) for chunk in itertools.product(*per_dim)]
    return patch_slices, nb_patch_per_dim
|
|
185
|
+
|
|
186
|
+
def _logSignalFormat(input : np.ndarray):
|
|
187
|
+
return {str(i): channel for i, channel in enumerate(input)}
|
|
188
|
+
|
|
189
|
+
def _logImageFormat(input : np.ndarray):
|
|
190
|
+
if len(input.shape) == 2:
|
|
191
|
+
input = np.expand_dims(input, axis=0)
|
|
192
|
+
|
|
193
|
+
if len(input.shape) == 3 and input.shape[0] != 1:
|
|
194
|
+
input = np.expand_dims(input, axis=0)
|
|
195
|
+
|
|
196
|
+
if len(input.shape) == 4:
|
|
197
|
+
input = input[:, input.shape[1]//2]
|
|
198
|
+
|
|
199
|
+
if input.dtype == np.uint8:
|
|
200
|
+
return input
|
|
201
|
+
|
|
202
|
+
input = input.astype(float)
|
|
203
|
+
b = -np.min(input)
|
|
204
|
+
if (np.max(input)+b) > 0:
|
|
205
|
+
return (input+b)/(np.max(input)+b)
|
|
206
|
+
else:
|
|
207
|
+
return 0*input
|
|
208
|
+
|
|
209
|
+
def _logImagesFormat(input : np.ndarray):
    """Apply ``_logImageFormat`` to every element along axis 0 and stack the results on axis 0."""
    formatted = [_logImageFormat(input[n]) for n in range(input.shape[0])]
    return np.stack(formatted, axis=0)
|
|
215
|
+
|
|
216
|
+
def _logVideoFormat(input : np.ndarray):
    """Convert a batch of frame sequences into a 3-channel array.

    Axis 1 of ``input`` is walked index by index; each slice goes through
    ``_logImagesFormat`` and the results are stacked back on axis 1. The
    output has the same first two axes, 3 on axis 2, and the trailing axes of
    the stacked result.
    """
    frames = [_logImagesFormat(input[:, t,...]) for t in range(input.shape[1])]
    result = np.stack(frames, axis=1)

    nb_channel = result.shape[2]
    if nb_channel < 3:
        # Fewer than three channels: channel 0 feeds all three outputs.
        channel_split = [result[:, :, 0, ...] for _ in range(3)]
    else:
        # NOTE(review): the split (and the mean below) run along axis 0, not
        # the channel axis 2 - looks unintended, confirm before changing.
        channel_split = np.split(result, 3, axis=0)

    video = np.zeros((result.shape[0], result.shape[1], 3, *list(result.shape[3:])))
    for i, channels in enumerate(channel_split):
        video[:,:,i] = np.mean(channels, axis=0)
    return video
|
|
231
|
+
|
|
232
|
+
class DataLog(Enum):
    """TensorBoard logging strategies.

    Each member's value is a 1-tuple holding a callable
    ``(tb, name, layer, it)``; call it through ``member.value[0]``. The
    trailing commas are required: a bare function assigned in an ``Enum`` body
    is treated as a method, not as a member. ``AUDIO`` was missing its comma
    and therefore was not a member (``DataLog["AUDIO"]`` raised ``KeyError``).
    """

    # One add_scalars call per batch element, using channel values at index 0 of axis 2.
    SIGNAL = lambda tb, name, layer, it : [tb.add_scalars(name, _logSignalFormat(layer[b, :, 0]), layer.shape[0]*it+b) for b in range(layer.shape[0])],
    # Only the first element of the batch is logged.
    IMAGE = lambda tb, name, layer, it : tb.add_image(name, _logImageFormat(layer[0]), it),
    IMAGES = lambda tb, name, layer, it : tb.add_images(name, _logImagesFormat(layer), it),
    VIDEO = lambda tb, name, layer, it : tb.add_video(name, _logVideoFormat(layer), it),
    AUDIO = lambda tb, name, layer, it : tb.add_audio(name, _logImageFormat(layer), it),
|
|
238
|
+
|
|
239
|
+
class Log():
    """Context manager that redirects ``sys.stdout``/``sys.stderr`` into ``<dir><name>/log.txt``.

    The directory is ``PREDICTIONS_DIRECTORY()`` when ``DL_API_STATE()`` is
    ``"PREDICTION"`` and ``STATISTICS_DIRECTORY()`` otherwise. When the
    environment variable ``DEEP_LEARNING_VERBOSE`` equals ``"True"`` every
    message is also echoed to the real standard output.
    """

    def __init__(self, name: str) -> None:
        """Create the log directory and open ``log.txt`` for writing.

        Raises:
            KeyError: if ``DEEP_LEARNING_VERBOSE`` is not set.
        """
        # Read the environment first so that a missing variable cannot leave
        # an open file handle behind.
        self.verbose = os.environ["DEEP_LEARNING_VERBOSE"] == "True"
        path = PREDICTIONS_DIRECTORY() if DL_API_STATE() == "PREDICTION" else STATISTICS_DIRECTORY()
        directory = "{}{}".format(path, name)
        # exist_ok: several processes may create the same directory at the
        # same time; an exists()-then-makedirs() check races.
        os.makedirs(directory, exist_ok=True)
        self.file = open("{}/log.txt".format(directory), 'w')
        self.stdout_bak = sys.stdout
        self.stderr_bak = sys.stderr

    def __enter__(self):
        self.file.__enter__()
        sys.stdout = self
        sys.stderr = self
        return self

    def __exit__(self, type, value, traceback):
        try:
            self.file.__exit__(type, value, traceback)
        finally:
            # Always give the streams back, even if closing the file failed.
            sys.stdout = self.stdout_bak
            sys.stderr = self.stderr_bak

    def write(self, msg):
        """Write ``msg`` to the log file; whitespace-only messages are dropped."""
        if msg.strip() != "":
            self.file.write(msg)
            if self.verbose:
                print(msg, file=sys.__stdout__)

    def flush(self):
        """No-op, present so the object can stand in for a text stream."""
        pass
|
|
269
|
+
|
|
270
|
+
class TensorBoard():
    """Context manager that runs a ``tensorboard`` subprocess for the duration of a block.

    Nothing is started unless the environment variable
    ``DEEP_LEARNING_TENSORBOARD_PORT`` is set.
    """

    def __init__(self, name: str) -> None:
        self.process = None
        self.name = name

    def __enter__(self):
        if "DEEP_LEARNING_TENSORBOARD_PORT" in os.environ:
            port = os.environ["DEEP_LEARNING_TENSORBOARD_PORT"]
            # NOTE(review): parentheses only make the existing precedence
            # explicit - in PREDICTION state the log directory is the bare
            # predictions directory, without `self.name`. Confirm this is intended.
            logdir = PREDICTIONS_DIRECTORY() if DL_API_STATE() == "PREDICTION" else (STATISTICS_DIRECTORY() + self.name + "/")
            command = ["tensorboard", "--logdir", logdir, "--port", port, "--bind_all"]
            self.process = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            print("Tensorboard : http://{}:{}/".format(self._local_ip(), port))
        return self

    @staticmethod
    def _local_ip() -> str:
        """Return the local address used for outgoing traffic, or 127.0.0.1 on any failure."""
        try:
            # The socket is created inside the `with` so that a failure in
            # socket.socket() itself cannot reach a close() on an unbound name.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                # Connecting a UDP socket only selects a route.
                s.connect(('10.255.255.255', 1))
                return s.getsockname()[0]
        except Exception:
            return '127.0.0.1'

    def __exit__(self, type, value, traceback):
        if self.process is not None:
            self.process.terminate()
            self.process.wait()
|
|
295
|
+
|
|
296
|
+
class DistributedObject():
    """Base class for a job executed once per rank.

    Calling the instance sets up logging, the process group, NVML, the random
    seeds and the cuDNN flags, then delegates to ``run_process``.

    NOTE(review): ``setup`` and ``run_process`` carry ``@abstractmethod`` but
    the class does not derive from ``ABC``, so the decorator is not enforced.
    """

    def __init__(self, name: str) -> None:
        self.port = find_free_port()
        # One list of DataLoader per rank; its length is used as the world size.
        self.dataloader : list[list[DataLoader]]
        # Used arithmetically as a seed below, hence int (was annotated bool).
        self.manual_seed: Union[int, None] = None
        self.name = name
        self.size = 1

    @abstractmethod
    def setup(self, world_size: int):
        pass

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        pass

    @abstractmethod
    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
        pass

    # Declared static: the function takes no `self`, so without the decorator
    # it could only be called through the class, never through an instance.
    @staticmethod
    def getMeasure(world_size: int, global_rank: int, gpu: int, models: dict[str, torch.nn.Module], n: int) -> dict[str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]]:
        """Gather the measures of every network from all ranks.

        Args:
            world_size: number of participating processes.
            global_rank: rank of the calling process.
            gpu: device index forwarded to ``synchronize_data``.
            models: label -> model; every model must expose ``getNetworks()``.
            n: forwarded to ``network.measure.format``.

        Returns:
            On rank 0, ``{"<network><label>": (dict, dict)}`` where each entry
            keeps the first element seen and accumulates the second element
            divided by ``world_size``. An empty dict on every other rank.
        """
        data = {}
        for label, model in models.items():
            for name, network in model.getNetworks().items():
                if network.measure is not None:
                    data["{}{}".format(name, label)] = (network.measure.format(True, n), network.measure.format(False, n))
        outputs = synchronize_data(world_size, gpu, data)

        result = {}
        if global_rank == 0:
            for output in outputs:
                for k, v in output.items():
                    for t in range(len(v)):
                        for u, entry in v[t].items():
                            if k not in result:
                                result[k] = ({}, {})
                            if u not in result[k][t]:
                                result[k][t][u] = (entry[0], 0)
                            result[k][t][u] = (result[k][t][u][0], result[k][t][u][1]+entry[1]/world_size)
        return result

    def _configure_randomness(self, world_size: int, global_rank: int) -> None:
        """Seed numpy/random/torch (one distinct seed per rank) and set the cuDNN/TF32 flags."""
        if self.manual_seed is not None:
            seed = self.manual_seed * world_size + global_rank
            np.random.seed(seed)
            random.seed(seed)
            torch.manual_seed(seed)
        torch.backends.cudnn.benchmark = self.manual_seed is None
        torch.backends.cudnn.deterministic = self.manual_seed is not None
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    def __call__(self, rank: Union[int, None] = None) -> None:
        """Run the job for one rank; returns immediately when ``setupGPU`` yields no rank."""
        with Log(self.name):
            world_size = len(self.dataloader)
            global_rank, local_rank = setupGPU(world_size, self.port, rank)
            if global_rank is None:
                return
            try:
                if torch.cuda.is_available():
                    pynvml.nvmlInit()
                try:
                    self._configure_randomness(world_size, global_rank)
                    dataloaders = self.dataloader[global_rank]
                    if torch.cuda.is_available():
                        torch.cuda.set_device(local_rank)

                    self.run_process(world_size, global_rank, local_rank, dataloaders)
                finally:
                    # Release NVML even when run_process raises.
                    if torch.cuda.is_available():
                        pynvml.nvmlShutdown()
            finally:
                # Always tear the process group down, also on failure.
                cleanup()
|
|
363
|
+
|
|
364
|
+
def setupAPI(parser: argparse.ArgumentParser) -> DistributedObject:
    """Parse the command line, export the options as environment variables and build the runner.

    Args:
        parser: parser to which the "API arguments" group is added;
            ``parse_args()`` is called on it (reads ``sys.argv``).

    Returns:
        A ``Predictor`` for ``State.PREDICTION``, an ``Evaluator`` for
        ``State.EVALUATION`` and a ``Trainer`` for every other state, each
        built with ``config=CONFIG_FILE()``.
    """
    # API arguments
    api_args = parser.add_argument_group('API arguments')
    api_args.add_argument("type", type=State, choices=list(State))
    api_args.add_argument('-y', action='store_true', help="Accept overwrite")
    api_args.add_argument('-tb', action='store_true', help='Start TensorBoard')
    api_args.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
    api_args.add_argument("-g", "--gpu", type=str, default=os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "", help="List of GPU")
    api_args.add_argument('--num-workers', '--num_workers', default=4, type=int, help='No. of workers per DataLoader & GPU')
    api_args.add_argument("-models_dir", "--MODELS_DIRECTORY", type=str, default="./Models/", help="Models location")
    api_args.add_argument("-checkpoints_dir", "--CHECKPOINTS_DIRECTORY", type=str, default="./Checkpoints/", help="Checkpoints location")
    api_args.add_argument("-model", "--MODEL", type=str, default="", help="URL Model")
    api_args.add_argument("-predictions_dir", "--PREDICTIONS_DIRECTORY", type=str, default="./Predictions/", help="Predictions location")
    api_args.add_argument("-evaluation_dir", "--EVALUATIONS_DIRECTORY", type=str, default="./Evaluations/", help="Evaluations location")
    api_args.add_argument("-statistics_dir", "--STATISTICS_DIRECTORY", type=str, default="./Statistics/", help="Statistics location")
    api_args.add_argument("-setups_dir", "--SETUPS_DIRECTORY", type=str, default="./Setups/", help="Setups location")
    api_args.add_argument('-log', action='store_true', help='Save log')
    # store_false: the stored value is True unless -quiet is given; it is
    # exported below as the verbosity flag.
    api_args.add_argument('-quiet', action='store_false', help='')

    args = parser.parse_args()
    config = vars(args)

    os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu"]
    os.environ["DL_API_MODELS_DIRECTORY"] = config["MODELS_DIRECTORY"]
    os.environ["DL_API_CHECKPOINTS_DIRECTORY"] = config["CHECKPOINTS_DIRECTORY"]
    os.environ["DL_API_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
    # Key must match the argparse destination exactly; the former
    # "EVALUATIONs_DIRECTORY" spelling raised KeyError on every run.
    os.environ["DL_API_EVALUATIONS_DIRECTORY"] = config["EVALUATIONS_DIRECTORY"]
    os.environ["DL_API_STATISTICS_DIRECTORY"] = config["STATISTICS_DIRECTORY"]

    os.environ["DL_API_STATE"] = str(config["type"])

    os.environ["DL_API_MODEL"] = config["MODEL"]

    os.environ["DL_API_SETUPS_DIRECTORY"] = config["SETUPS_DIRECTORY"]

    os.environ["DL_API_OVERWRITE"] = "{}".format(config["y"])
    # NOTE(review): "LEANING" looks like a typo for "LEARNING"; left as is
    # because readers of this variable are outside this file - confirm.
    os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "Done"
    if config["tb"]:
        os.environ["DEEP_LEARNING_TENSORBOARD_PORT"] = str(find_free_port())

    os.environ["DEEP_LEARNING_VERBOSE"] = str(config["quiet"])

    if config["config"] == "None":
        # No explicit file: fall back to the default name for the chosen state.
        if config["type"] is State.PREDICTION:
            os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = "Prediction.yml"
        elif config["type"] is State.EVALUATION:
            os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = "Evaluation.yml"
        else:
            os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = "Config.yml"
    else:
        os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = config["config"]
    # NOTE(review): anomaly detection and CUDA_LAUNCH_BLOCKING are enabled
    # unconditionally; both are debugging aids - confirm they are wanted in
    # regular runs.
    torch.autograd.set_detect_anomaly(True)
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    # NOTE(review): these imports use the "KonfAI.konfai" path while the wheel
    # RECORD lists the package as "konfai/..." - verify they resolve once installed.
    if config["type"] is State.PREDICTION:
        from KonfAI.konfai.predictor import Predictor
        os.environ["DEEP_LEARNING_API_ROOT"] = "Predictor"
        return Predictor(config=CONFIG_FILE())
    elif config["type"] is State.EVALUATION:
        from KonfAI.konfai.evaluator import Evaluator
        os.environ["DEEP_LEARNING_API_ROOT"] = "Evaluator"
        return Evaluator(config=CONFIG_FILE())
    else:
        from KonfAI.konfai.trainer import Trainer
        os.environ["DEEP_LEARNING_API_ROOT"] = "Trainer"
        return Trainer(config=CONFIG_FILE())
|
|
431
|
+
|
|
432
|
+
import submitit
|
|
433
|
+
|
|
434
|
+
def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple[int , int]:
    """Resolve the ranks of the current process and join the NCCL process group.

    Args:
        world_size: number of participating processes.
        port: TCP port of the rendezvous address.
        rank: explicit rank, used as both global and local rank; when ``None``
            the ranks are read from ``submitit.JobEnvironment()``.

    Returns:
        ``(global_rank, local_rank)``, or ``(None, None)`` when the global
        rank is not below ``world_size``.
    """
    try:
        command = "scontrol show hostnames {}".format(os.getenv('SLURM_JOB_NODELIST')).split()
        host_name = subprocess.check_output(command).decode().splitlines()[0]
    except Exception:
        # Best effort: outside SLURM (no scontrol, non-zero exit, empty
        # output) fall back to the local host. `Exception` rather than a bare
        # except so that KeyboardInterrupt/SystemExit still propagate.
        host_name = "localhost"

    if rank is None:
        job_env = submitit.JobEnvironment()
        global_rank = job_env.global_rank
        local_rank = job_env.local_rank
    else:
        global_rank = rank
        local_rank = rank

    if global_rank >= world_size:
        return None, None

    init_method = "tcp://{}:{}".format(host_name, port)
    print(init_method)
    # The process group is only created when CUDA is available.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        dist.init_process_group("nccl", rank=global_rank, init_method=init_method, world_size=world_size)
    return global_rank, local_rank
|
|
453
|
+
|
|
454
|
+
import socket
|
|
455
|
+
from contextlib import closing
|
|
456
|
+
|
|
457
|
+
def find_free_port():
    """Return a TCP port number that the OS reported as free.

    The socket is bound to port 0 on all interfaces so the OS picks the port;
    it is closed before returning, so the port is not reserved.
    """
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(probe):
        probe.bind(('', 0))
        probe.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _, port = probe.getsockname()
        return port
|
|
462
|
+
|
|
463
|
+
def cleanup():
    """Destroy the default process group; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    dist.destroy_process_group()
|
|
466
|
+
|
|
467
|
+
def synchronize_data(world_size: int, gpu: int, data: Any) -> list[Any]:
    """Collect ``data`` from every rank.

    Args:
        world_size: number of participating processes.
        gpu: CUDA device selected before the collective call.
        data: picklable object contributed by this rank.

    Returns:
        With CUDA available, the ``world_size`` objects gathered by
        ``all_gather_object``; otherwise a one-element list holding ``data``.
    """
    # `data` was annotated with the builtin function `any`; `typing.Any` is
    # the intended type.
    if not torch.cuda.is_available():
        return [data]

    outputs: list[Any] = [None for _ in range(world_size)]
    torch.cuda.set_device(gpu)
    dist.all_gather_object(outputs, data)
    return outputs
|
|
475
|
+
|
|
476
|
+
def _resample(data: torch.Tensor, size: list[int]) -> torch.Tensor:
|
|
477
|
+
if data.dtype == torch.uint8:
|
|
478
|
+
mode = "nearest"
|
|
479
|
+
elif len(data.shape) < 4:
|
|
480
|
+
mode = "bilinear"
|
|
481
|
+
else:
|
|
482
|
+
mode = "trilinear"
|
|
483
|
+
return F.interpolate(data.type(torch.float32).unsqueeze(0), size=tuple([s for s in reversed(size)]), mode=mode).squeeze(0).type(data.dtype)
|
|
484
|
+
|
|
485
|
+
def _affine_matrix(matrix: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
|
|
486
|
+
return torch.cat((torch.cat((matrix, translation.unsqueeze(0).T), dim=1), torch.tensor([[0, 0, 0, 1]])), dim=0)
|
|
487
|
+
|
|
488
|
+
def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
|
|
489
|
+
if data.dtype == torch.uint8:
|
|
490
|
+
mode = "nearest"
|
|
491
|
+
else:
|
|
492
|
+
mode = "bilinear"
|
|
493
|
+
return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: konfai
|
|
3
|
+
Version: 1.0.0
|
|
4
|
+
Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
|
|
5
|
+
Author-email: Valentin Boussot <boussot.v@gmail.com>
|
|
6
|
+
License-Expression: Apache-2.0
|
|
7
|
+
Project-URL: Homepage, https://github.com/vboussot/KonfAI
|
|
8
|
+
Project-URL: Repository, https://github.com/vboussot/KonfAI
|
|
9
|
+
Project-URL: Issues, https://github.com/vboussot/KonfAI/issues
|
|
10
|
+
Project-URL: License, https://www.apache.org/licenses/LICENSE-2.0
|
|
11
|
+
Requires-Python: >=3.8
|
|
12
|
+
Description-Content-Type: text/markdown
|
|
13
|
+
License-File: LICENSE
|
|
14
|
+
Requires-Dist: torch
|
|
15
|
+
Requires-Dist: tqdm
|
|
16
|
+
Requires-Dist: numpy
|
|
17
|
+
Requires-Dist: ruamel.yaml
|
|
18
|
+
Requires-Dist: psutil
|
|
19
|
+
Requires-Dist: tensorboard
|
|
20
|
+
Requires-Dist: SimpleITK
|
|
21
|
+
Requires-Dist: lxml
|
|
22
|
+
Requires-Dist: h5py
|
|
23
|
+
Requires-Dist: pynvml
|
|
24
|
+
Provides-Extra: vtk
|
|
25
|
+
Requires-Dist: vtk; extra == "vtk"
|
|
26
|
+
Provides-Extra: lpips
|
|
27
|
+
Requires-Dist: lpips; extra == "lpips"
|
|
28
|
+
Provides-Extra: cluster
|
|
29
|
+
Requires-Dist: submitit; extra == "cluster"
|
|
30
|
+
Dynamic: license-file
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
# 🧠 KonfAI
|
|
34
|
+
<img src="logo.png" alt="KonfAI Logo" width="200" align="right"/>
|
|
35
|
+
|
|
36
|
+
**KonfAI** is a modular and highly configurable deep learning framework built on PyTorch, driven entirely by YAML configuration files.
|
|
37
|
+
|
|
38
|
+
It is designed to support complex medical imaging workflows, flexible model architectures, customizable training loops, and advanced loss scheduling, without hardcoding anything.
|
|
39
|
+
|
|
40
|
+
---
|
|
41
|
+
|
|
42
|
+
## 🔧 Key Features
|
|
43
|
+
|
|
44
|
+
- 🔀 Full training/prediction/evaluation orchestration via YAML configuration file
|
|
45
|
+
- 🧩 Modular plugin-like structure (transforms, augmentations, models, losses, schedulers)
|
|
46
|
+
- 🔄 Dynamic criterion scheduling per head / target
|
|
47
|
+
- 🧠 Multi-branch / multi-output model support
|
|
48
|
+
- 🖥️ Cluster-ready
|
|
49
|
+
- 📈 TensorBoard and custom logging support
|
|
50
|
+
|
|
51
|
+
---
|
|
52
|
+
|
|
53
|
+
## 🚀 Installation
|
|
54
|
+
|
|
55
|
+
```bash
|
|
56
|
+
git clone https://github.com/vboussot/KonfAI.git && cd KonfAI
|
|
57
|
+
pip install -e .
|
|
58
|
+
```
|
|
59
|
+
|
|
60
|
+
---
|
|
61
|
+
|
|
62
|
+
## 🧪 Usage
|
|
63
|
+
|
|
64
|
+
```bash
|
|
65
|
+
konfai TRAIN --gpu 0
|
|
66
|
+
konfai PREDICTION --gpu 0
|
|
67
|
+
konfai EVALUATION
|
|
68
|
+
```
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
konfai/__init__.py,sha256=jXMTNml38eX6FSq9d3C_gJVgRLTKHPBUXqOLC7Pqkuo,828
|
|
2
|
+
konfai/evaluator.py,sha256=6YU3bXBy1YcS2kl0YwebkgwXYeO_VoBN59-MLMqj-ds,7468
|
|
3
|
+
konfai/main.py,sha256=Y-8vTgHVecMglB8krGKLbBoFlcoo-Oa4l-PjTDhBzbM,2142
|
|
4
|
+
konfai/predictor.py,sha256=IOh70fCVm8q-sgZyACNperTO-Vel8QKvYp-FoBY39ao,20236
|
|
5
|
+
konfai/trainer.py,sha256=zGvXd2skcqWgRN9GLx93xYB4Bv-46C0oo7J9My4Levk,16901
|
|
6
|
+
konfai/data/HDF5.py,sha256=QfU8VnyslkQhT_k2AJNFMNkJK7lm75ozxT4WELZt8wk,14390
|
|
7
|
+
konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
8
|
+
konfai/data/augmentation.py,sha256=F4hAuw5j-8zpmAzSAIn4VDhZG1-CCDCZQD_aoAYO-LA,31757
|
|
9
|
+
konfai/data/dataset.py,sha256=JnouutlJGlgIe7XnAijFe4FUatTZyWDLzWSLE3OxjZM,23963
|
|
10
|
+
konfai/data/transform.py,sha256=AxGqtEHC6XIk4AT-Clbq7w1sWUBrONLMsDHGiC0wIhI,25171
|
|
11
|
+
konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
konfai/metric/measure.py,sha256=fYkLgtttoGiQWU9a67Kk8eEOP1K_4yCwk87IA3x84Uc,21840
|
|
13
|
+
konfai/metric/schedulers.py,sha256=zQORXilMGPGBHq7Rg3l9JmbdoEebnTiB6yPHxuyEl7c,1377
|
|
14
|
+
konfai/models/classification/convNeXt.py,sha256=JbEPl7PHLY-FkSsjqDJtLlixH5JBIv_TDf0rMZeEe8s,9260
|
|
15
|
+
konfai/models/classification/resnet.py,sha256=i7SWA00yQdSEpLwYnn0_Cf32e4uAAzN-67mbcuC0wzw,7975
|
|
16
|
+
konfai/models/generation/cStyleGan.py,sha256=0b3lUH3PYJEBDSN6J0wanA-R2bWyj-gMmaDvSMiHw3A,8057
|
|
17
|
+
konfai/models/generation/ddpm.py,sha256=X4hMnrkIfiqdyL_XY8YE6hKeLUPqV15EkrsWxQ3DXEY,13184
|
|
18
|
+
konfai/models/generation/diffusionGan.py,sha256=x0sksJe2CI_Oqj-skuXfFChlWhBw4s0HoYk2ICE0fDM,33237
|
|
19
|
+
konfai/models/generation/gan.py,sha256=fN6CIDi23_XcsXdre8fJL_xRXhtLVSnLzoVcvF3bmLk,7871
|
|
20
|
+
konfai/models/generation/vae.py,sha256=Qq1nKnAGyv7VsgH5nZatMPjjrlIpXZ3bWifvu8W8W7Y,4733
|
|
21
|
+
konfai/models/registration/registration.py,sha256=vpzFl-ozga1TDLadyGE6w0xosDblC6PBABmcS-4E31w,6362
|
|
22
|
+
konfai/models/representation/representation.py,sha256=t9gX49KhyK7PvO2CduK_RmUrNY0m8pyr5nXwrOZ8szo,2691
|
|
23
|
+
konfai/models/segmentation/NestedUNet.py,sha256=hgbawKp27elTgkK5APjEa1nUEt2oJi9x1nsJLE22p7g,4318
|
|
24
|
+
konfai/models/segmentation/UNet.py,sha256=X-ddiQBJboq3ZHDfj8CvoZiNc9RT-eXKlBXriaL_mFY,4235
|
|
25
|
+
konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
|
+
konfai/network/blocks.py,sha256=U5P4EAHShvQh6s5VWkhn5_VIGj8gHpUvd7WyVWG3MiI,13542
|
|
27
|
+
konfai/network/network.py,sha256=qfqnPFQHOVu_fmZXOrFKXj3-Ej0HzcfJpkhW6FEbgD4,45810
|
|
28
|
+
konfai/utils/ITK.py,sha256=tErt6ymFesZWg4Mw6ZYc8kOC9zpUSjBRQMm1PUnvgF8,13962
|
|
29
|
+
konfai/utils/Registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
|
|
30
|
+
konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
31
|
+
konfai/utils/config.py,sha256=4NkR1BWXxwtsf95G_xH9_t-pe2unmbRvxzyKCIuS6eE,9894
|
|
32
|
+
konfai/utils/dataset.py,sha256=PNmzxaeFGMBh-pAaR92tDresDjF0QXSli2eGNJyzSVQ,35465
|
|
33
|
+
konfai/utils/utils.py,sha256=BCH7nwDf_Aqt_5lDieBCUEI0HE5TlinG4Z96EOAwA94,20206
|
|
34
|
+
konfai-1.0.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
35
|
+
konfai-1.0.0.dist-info/METADATA,sha256=flcowFH9PD-G7CYRMjBtpX2Dy-XLTY2mCXPfSUL5kAM,1971
|
|
36
|
+
konfai-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
37
|
+
konfai-1.0.0.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
|
|
38
|
+
konfai-1.0.0.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
|
|
39
|
+
konfai-1.0.0.dist-info/RECORD,,
|