konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

konfai/utils/utils.py ADDED
@@ -0,0 +1,493 @@
1
+ import itertools
2
+ import pynvml
3
+ import psutil
4
+
5
+ import numpy as np
6
+ import os
7
+ import torch
8
+
9
+ from abc import ABC, abstractmethod
10
+ from enum import Enum
11
+ from typing import Any, Union
12
+
13
+ from KonfAI.konfai import CONFIG_FILE, STATISTICS_DIRECTORY, PREDICTIONS_DIRECTORY, DL_API_STATE, CUDA_VISIBLE_DEVICES
14
+ import torch.distributed as dist
15
+ import argparse
16
+ import subprocess
17
+ import random
18
+ from torch.utils.data import DataLoader
19
+ import torch.nn.functional as F
20
+ import sys
21
+
22
def description(model, modelEMA = None, showMemory: bool = True) -> str:
    """Build the one-line status string shown while a model is running.

    Args:
        model: wrapped model; ``model.module.getNetworks()`` must return a
            mapping of network name to network.
        modelEMA: optional second model reported under the "Loss EMA" label.
        showMemory: when true, append the host memory summary.

    Returns:
        The loss summary, followed by the GPU summary and, optionally, the
        host memory summary.
    """

    def _format_values(weights, values) -> str:
        # One "name(weight) : value" entry per measured value; only the part
        # of the name after the last ":" is displayed.
        entries = []
        for (name, value), weight in zip(values.items(), weights.values()):
            entries.append("{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value))
        return " ".join(entries)

    def _format_model(wrapped) -> str:
        # One "network(lr) : values" entry per network that owns a measure.
        entries = []
        for name, network in wrapped.module.getNetworks().items():
            if network.measure is None:
                continue
            # Networks without an optimizer are reported with a learning rate of 0.
            lr = network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0
            entries.append("{}({:.6f}) : {}".format(name, lr, _format_values(network.measure.getLastWeights(), network.measure.getLastValues())))
        return "(" + " ".join(entries) + ")"

    pieces = ["Loss {}".format(_format_model(model))]
    if modelEMA is not None:
        pieces.append("Loss EMA {}".format(_format_model(modelEMA)))
    pieces.append(" " + gpuInfo())
    if showMemory:
        pieces.append(" | {}".format(memoryInfo()))
    return "".join(pieces)
32
+
33
+ def _getModule(classpath : str, type : str) -> tuple[str, str]:
34
+ if len(classpath.split("_")) > 1:
35
+ module = ".".join(classpath.split("_")[:-1])
36
+ name = classpath.split("_")[-1]
37
+ else:
38
+ module = "KonfAI."+type
39
+ name = classpath
40
+ return module, name
41
+
42
def cpuInfo() -> str:
    """Return the CPU usage, sampled by psutil over a 0.5 s interval, as display text."""
    usage = psutil.cpu_percent(interval=0.5)
    return "CPU ({:.2f} %)".format(usage)
44
+
45
def memoryInfo() -> str:
    """Return the host memory usage as display text (GiB and percentage)."""
    # Index 3 and index 2 of psutil's virtual memory tuple; psutil documents
    # these positions as `used` (bytes) and `percent`.
    used_gib = psutil.virtual_memory()[3] / 2**30
    used_percent = psutil.virtual_memory()[2]
    return "Memory ({:.2f}G ({:.2f} %))".format(used_gib, used_percent)
47
+
48
def getMemory() -> float:
    """Return field 3 of psutil's virtual memory tuple, converted from bytes to GiB."""
    used_bytes = psutil.virtual_memory()[3]
    return used_bytes / 2**30
50
+
51
def memoryForecast(memory_init : float, i : float, size : float) -> str:
    """Linearly extrapolate memory usage to ``size`` steps from the first ``i`` steps.

    Args:
        memory_init: memory in use (GiB) before the first step.
        i: number of steps done so far; the forecast is 0 while ``i`` is not positive.
        size: total number of steps.

    Returns:
        Display text with the forecast in GiB and as a percentage of field 0
        of psutil's virtual memory tuple.
    """
    current_memory = getMemory()
    if i > 0:
        # Growth observed so far, scaled to the full run.
        forecast = memory_init + ((current_memory - memory_init) * size / i)
    else:
        forecast = 0
    total_gib = psutil.virtual_memory()[0] / 2**30
    return "Memory forecast ({:.2f}G ({:.2f} %))".format(forecast, forecast / total_gib * 100)
55
+
56
def gpuInfo() -> str:
    """Return display text with the memory usage of the first visible GPU.

    Returns:
        An empty string when no device is listed by ``CUDA_VISIBLE_DEVICES()``
        or when the first listed index is not known to NVML; otherwise the
        optional SLURM node name, the list of visible devices, and the used
        memory of the first device (in 10**9 bytes and as a percentage).
    """
    visible = CUDA_VISIBLE_DEVICES()
    if visible == "":
        return ""

    devices = [int(i) for i in visible.split(",")]
    device = devices[0]
    if device >= pynvml.nvmlDeviceGetCount():
        return ""

    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    memory = pynvml.nvmlDeviceGetMemoryInfo(handle)

    # The node name must be substituted into the template; plain string
    # concatenation would leave a literal "{}" in the displayed text.
    if "SLURMD_NODENAME" in os.environ:
        node_name = "Node: {} ".format(os.environ["SLURMD_NODENAME"])
    else:
        node_name = ""
    return "{}GPU({}) Memory GPU ({:.2f}G ({:.2f} %))".format(node_name, devices, float(memory.used)/(10**9), float(memory.used)/float(memory.total)*100)
70
+
71
def getMaxGPUMemory(device : Union[int, torch.device]) -> float:
    """Return the total memory of a GPU in units of 10**9 bytes.

    Args:
        device: position in the visible-device list, or a ``torch.device``
            whose string form is ``"cuda:<n>"``.

    Returns:
        The total memory reported by NVML, or 0 for any other ``torch.device``
        and for indices NVML does not know.
    """
    if isinstance(device, torch.device):
        device_text = str(device)
        if not device_text.startswith("cuda:"):
            return 0
        device = int(device_text.replace("cuda:", ""))

    # Translate the position in the visible list into the NVML device index.
    visible = [int(i) for i in CUDA_VISIBLE_DEVICES().split(",")]
    device = visible[device]
    if not device < pynvml.nvmlDeviceGetCount():
        return 0

    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return float(memory.total)/(10**9)
84
+
85
def getGPUMemory(device : Union[int, torch.device]) -> float:
    """Return the memory currently used on a GPU in units of 10**9 bytes.

    Args:
        device: position in the visible-device list, or a ``torch.device``
            whose string form is ``"cuda:<n>"``.

    Returns:
        The used memory reported by NVML, or 0 for any other ``torch.device``
        and for indices NVML does not know.
    """
    if isinstance(device, torch.device):
        device_text = str(device)
        if not device_text.startswith("cuda:"):
            return 0
        device = int(device_text.replace("cuda:", ""))

    # Translate the position in the visible list into the NVML device index.
    visible = [int(i) for i in CUDA_VISIBLE_DEVICES().split(",")]
    device = visible[device]
    if not device < pynvml.nvmlDeviceGetCount():
        return 0

    handle = pynvml.nvmlDeviceGetHandleByIndex(device)
    memory = pynvml.nvmlDeviceGetMemoryInfo(handle)
    return float(memory.used)/(10**9)
98
+
99
class NeedDevice(ABC):
    """Base class for objects that are bound to a compute device after construction."""

    def __init__(self) -> None:
        super().__init__()
        # Declared only; the attribute does not exist until setDevice() runs.
        self.device : torch.device

    def setDevice(self, device : int):
        """Store the device resolved by ``getDevice`` for the given index."""
        resolved = getDevice(device)
        self.device = resolved
107
+
108
def getDevice(device : int):
    """Return ``device`` unchanged when CUDA is usable, else the CPU device.

    The index is returned as given (an ``int``) when CUDA is available and the
    index is non-negative; in every other case ``torch.device("cpu")`` is returned.
    """
    if torch.cuda.is_available() and device >= 0:
        return device
    return torch.device("cpu")
110
+
111
class State(Enum):
    """Run modes of the API; each member's value is its own name as a string."""

    # Declaration order is kept: iterating the enum yields members in this order.
    TRAIN = "TRAIN"
    RESUME = "RESUME"
    TRANSFER_LEARNING = "TRANSFER_LEARNING"
    FINE_TUNING = "FINE_TUNING"
    PREDICTION = "PREDICTION"
    EVALUATION = "EVALUATION"

    def __str__(self) -> str:
        """Render the member as its bare value (e.g. ``"TRAIN"``)."""
        return self.value
121
+
122
def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_per_dim : list[tuple[int, bool]], overlap: Union[int, None]) -> list[tuple[slice]]:
    """Enumerate patch slices from a per-dimension patch count.

    Args:
        patch_size_tmp: patch extents, consumed in order by the dimensions
            whose flag in ``nb_patch_per_dim`` is false.
        nb_patch_per_dim: one ``(count, flag)`` pair per dimension; a true
            flag forces an extent of 1 for that dimension.
        overlap: overlap between consecutive patches; ``None`` means 0.

    Returns:
        The Cartesian product of the per-dimension slices, one tuple per patch.
    """
    if overlap is None:
        overlap = 0

    # Resolve the extent of each dimension; flagged dimensions do not consume
    # an entry of patch_size_tmp.
    patch_size = []
    cursor = 0
    for nb in nb_patch_per_dim:
        if nb[1]:
            patch_size.append(1)
            continue
        patch_size.append(patch_size_tmp[cursor])
        cursor += 1

    slices : list[list[slice]] = []
    for extent, nb in zip(patch_size, nb_patch_per_dim):
        stride = extent - overlap
        slices.append([slice(stride * index, stride * index + extent) for index in range(nb[0])])

    return [tuple(chunk) for chunk in itertools.product(*slices)]
145
+
146
def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overlap: Union[int, None]) -> tuple[list[tuple[slice]], list[tuple[int, bool]]]:
    """Tile ``shape`` with patches of ``patch_size`` and list the resulting slices.

    Args:
        patch_size: patch extent per dimension.
        shape: extent of the data per dimension.
        overlap: overlap between consecutive patches; ``None`` derives a
            per-dimension overlap from ``shape`` and ``patch_size``.

    Returns:
        ``(patch_slices, nb_patch_per_dim)``, where ``nb_patch_per_dim`` holds
        ``(count, patch extent == 1)`` per dimension. When ``patch_size`` and
        ``shape`` differ in length, a single slice covering everything is
        returned together with ``(1, True)`` for every dimension.

    Raises:
        AssertionError: if an overlap is not smaller than the patch extent.
    """
    if len(shape) != len(patch_size):
        whole = tuple([slice(0, s) for s in shape])
        return [whole], [(1, True)]*len(shape)

    if overlap is None:
        # Number of patches per dimension, then the overlap computed from the
        # remainder of shape modulo patch size, divided across the gaps.
        size = [np.ceil(a/b) for a, b in zip(shape, patch_size)]
        derived = np.zeros(len(size), dtype=np.int_)
        for i, s in enumerate(size):
            if s > 1:
                derived[i] = np.mod(patch_size[i]-np.mod(shape[i], patch_size[i]), patch_size[i])//(size[i]-1)
        overlap = derived
    else:
        # Dimensions with a patch extent of 1 never overlap.
        overlap = [overlap if extent > 1 else 0 for extent in patch_size]

    for dim in range(len(shape)):
        assert overlap[dim] < patch_size[dim], "Overlap must be less than patch size"

    slices : list[list[slice]] = []
    nb_patch_per_dim = []
    for dim in range(len(shape)):
        extent = patch_size[dim]
        stride = extent - overlap[dim]
        dim_slices = []
        index = 0
        while True:
            start = stride*index
            end = start + extent
            if end >= shape[dim]:
                # Last patch: clip it to the data extent and stop.
                dim_slices.append(slice(start, shape[dim]))
                break
            dim_slices.append(slice(start, end))
            index += 1
        slices.append(dim_slices)
        nb_patch_per_dim.append((index+1, extent == 1))

    patch_slices = [tuple(chunk) for chunk in itertools.product(*slices)]
    return patch_slices, nb_patch_per_dim
185
+
186
+ def _logSignalFormat(input : np.ndarray):
187
+ return {str(i): channel for i, channel in enumerate(input)}
188
+
189
+ def _logImageFormat(input : np.ndarray):
190
+ if len(input.shape) == 2:
191
+ input = np.expand_dims(input, axis=0)
192
+
193
+ if len(input.shape) == 3 and input.shape[0] != 1:
194
+ input = np.expand_dims(input, axis=0)
195
+
196
+ if len(input.shape) == 4:
197
+ input = input[:, input.shape[1]//2]
198
+
199
+ if input.dtype == np.uint8:
200
+ return input
201
+
202
+ input = input.astype(float)
203
+ b = -np.min(input)
204
+ if (np.max(input)+b) > 0:
205
+ return (input+b)/(np.max(input)+b)
206
+ else:
207
+ return 0*input
208
+
209
def _logImagesFormat(input : np.ndarray):
    """Apply ``_logImageFormat`` to every element along axis 0 and stack the results."""
    formatted = [_logImageFormat(input[n]) for n in range(input.shape[0])]
    return np.stack(formatted, axis=0)
215
+
216
def _logVideoFormat(input : np.ndarray):
    """Format an array as a 3-channel video, frame by frame.

    Every index along axis 1 is formatted with ``_logImagesFormat`` and the
    frames are stacked back on axis 1. The result is then written into a
    zero-initialised array whose axis 2 has length 3.
    """
    frames = [_logImagesFormat(input[:, t, ...]) for t in range(input.shape[1])]
    result = np.stack(frames, axis=1)

    if result.shape[2] < 3:
        # Fewer than three channels: reuse channel 0 for all three outputs.
        channel_split = [result[:, :, 0, ...] for _ in range(3)]
    else:
        channel_split = np.split(result, 3, axis=0)

    video = np.zeros((result.shape[0], result.shape[1], 3, *list(result.shape[3:])))
    for i, channels in enumerate(channel_split):
        video[:, :, i] = np.mean(channels, axis=0)
    return video
231
+
232
class DataLog(Enum):
    """TensorBoard logging strategies.

    Each member's value is a 1-tuple holding a callable
    ``(tb, name, layer, it)`` that writes ``layer`` to the writer ``tb``.
    The tuple wrapper (trailing comma) is required: a bare function assigned
    in an Enum body is treated as a method, not as a member.
    """

    SIGNAL = lambda tb, name, layer, it : [tb.add_scalars(name, _logSignalFormat(layer[b, :, 0]), layer.shape[0]*it+b) for b in range(layer.shape[0])],
    IMAGE = lambda tb, name, layer, it : tb.add_image(name, _logImageFormat(layer[0]), it),
    IMAGES = lambda tb, name, layer, it : tb.add_images(name, _logImagesFormat(layer), it),
    VIDEO = lambda tb, name, layer, it : tb.add_video(name, _logVideoFormat(layer), it),
    # Trailing comma added so AUDIO is a real member like the other four.
    AUDIO = lambda tb, name, layer, it : tb.add_audio(name, _logImageFormat(layer), it),
238
+
239
class Log():
    """Context manager that redirects ``sys.stdout``/``sys.stderr`` to ``log.txt``.

    The file lives in ``<directory><name>/log.txt``, where the directory is the
    predictions directory in PREDICTION state and the statistics directory
    otherwise. When the ``DEEP_LEARNING_VERBOSE`` environment variable is
    ``"True"``, messages are also echoed to the real standard output.
    """

    def __init__(self, name: str) -> None:
        """Create the log directory if needed and open the log file for writing.

        Raises:
            KeyError: if ``DEEP_LEARNING_VERBOSE`` is not set.
        """
        path = PREDICTIONS_DIRECTORY() if DL_API_STATE() == "PREDICTION" else STATISTICS_DIRECTORY()
        directory = "{}{}".format(path, name)
        # exist_ok avoids the check-then-create race when several processes
        # build the same directory at the same time.
        os.makedirs(directory, exist_ok=True)
        self.file = open("{}/log.txt".format(directory), 'w')
        self.stdout_bak = sys.stdout
        self.stderr_bak = sys.stderr
        self.verbose = os.environ["DEEP_LEARNING_VERBOSE"] == "True"

    def __enter__(self):
        self.file.__enter__()
        sys.stdout = self
        sys.stderr = self
        return self

    def __exit__(self, type, value, traceback):
        # Close the file, then restore the streams saved in __init__.
        self.file.__exit__(type, value, traceback)
        sys.stdout = self.stdout_bak
        sys.stderr = self.stderr_bak

    def write(self, msg):
        """Write ``msg`` to the log file; whitespace-only messages are dropped."""
        if msg.strip() != "":
            self.file.write(msg)
            if self.verbose:
                # sys.__stdout__ is the original stream, so this does not recurse.
                print(msg, file=sys.__stdout__)

    def flush(self):
        """No-op, present so the object can stand in for a stream."""
        pass
269
+
270
class TensorBoard():
    """Context manager that runs a ``tensorboard`` server for the duration of a run.

    The server is started only when the ``DEEP_LEARNING_TENSORBOARD_PORT``
    environment variable is set; otherwise entering and leaving do nothing.
    """

    def __init__(self, name: str) -> None:
        self.process = None
        self.name = name

    def __enter__(self):
        """Start the server (if a port is configured) and print its URL."""
        if "DEEP_LEARNING_TENSORBOARD_PORT" in os.environ:
            port = os.environ["DEEP_LEARNING_TENSORBOARD_PORT"]
            # Pick the root directory first, then append the run name. Written
            # as one expression, the conditional would bind looser than "+"
            # and drop the run name in PREDICTION state.
            root = PREDICTIONS_DIRECTORY() if DL_API_STATE() == "PREDICTION" else STATISTICS_DIRECTORY()
            command = ["tensorboard", "--logdir", root + self.name + "/", "--port", port, "--bind_all"]
            self.process = subprocess.Popen(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
            print("Tensorboard : http://{}:{}/".format(self._local_ip(), port))
        return self

    @staticmethod
    def _local_ip() -> str:
        """Return the local address of a UDP socket connected outwards, or loopback on failure."""
        try:
            # The socket is closed by the with-block even when connect fails,
            # and a failure to create it is handled by the except clause.
            with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
                s.connect(('10.255.255.255', 1))
                return s.getsockname()[0]
        except Exception:
            return '127.0.0.1'

    def __exit__(self, type, value, traceback):
        """Terminate the server process, if one was started, and wait for it."""
        if self.process is not None:
            self.process.terminate()
            self.process.wait()
295
+
296
class DistributedObject():
    """Base class for a job executed once per rank.

    Subclasses provide ``setup`` and ``run_process``. Calling the instance
    runs one rank: it sets up the process group, seeds the random generators,
    configures cuDNN and then calls ``run_process``.

    NOTE(review): the class does not derive from ``ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation.
    """

    def __init__(self, name: str) -> None:
        self.port = find_free_port()
        # One list of dataloaders per rank; expected to be filled before __call__.
        self.dataloader : list[list[DataLoader]]
        # Base seed; None leaves the generators unseeded and enables cudnn.benchmark.
        self.manual_seed: Union[int, None] = None
        self.name = name
        self.size = 1

    @abstractmethod
    def setup(self, world_size: int):
        pass

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        pass

    @abstractmethod
    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
        pass

    @staticmethod
    def getMeasure(world_size: int, global_rank: int, gpu: int, models: dict[str, torch.nn.Module], n: int) -> dict[str, tuple[dict[str, tuple[float, float]], dict[str, tuple[float, float]]]]:
        """Gather the formatted measures of every rank and average them on rank 0.

        Declared static because it takes no ``self``; it can be called on the
        class or on an instance.

        Args:
            world_size: number of ranks; used as the divisor of the average.
            global_rank: rank of the caller; only rank 0 builds a result.
            gpu: device index forwarded to ``synchronize_data``.
            models: label -> model; each model's networks are read through
                ``getNetworks()``.
            n: forwarded to ``measure.format``.

        Returns:
            On rank 0, ``"<network name><label>"`` mapped to a pair of dicts
            (from ``format(True, n)`` and ``format(False, n)``) whose values
            are ``(first element seen, summed second element / world_size)``.
            An empty dict on every other rank.
        """
        data = {}
        for label, model in models.items():
            for name, network in model.getNetworks().items():
                if network.measure is not None:
                    data["{}{}".format(name, label)] = (network.measure.format(True, n), network.measure.format(False, n))
        outputs = synchronize_data(world_size, gpu, data)

        result = {}
        if global_rank != 0:
            return result
        for output in outputs:
            for key, pair in output.items():
                for t, measures in enumerate(pair):
                    for measure_name, entry in measures.items():
                        if key not in result:
                            result[key] = ({}, {})
                        if measure_name not in result[key][t]:
                            result[key][t][measure_name] = (entry[0], 0)
                        first, total = result[key][t][measure_name]
                        result[key][t][measure_name] = (first, total+entry[1]/world_size)
        return result

    def __call__(self, rank: Union[int, None] = None) -> None:
        """Run one rank; ``rank=None`` lets ``setupGPU`` determine the rank itself."""
        with Log(self.name):
            world_size = len(self.dataloader)
            global_rank, local_rank = setupGPU(world_size, self.port, rank)
            if global_rank is None:
                # setupGPU returns (None, None) for ranks beyond world_size.
                return
            if torch.cuda.is_available():
                pynvml.nvmlInit()
            if self.manual_seed is not None:
                # A distinct, reproducible seed per rank.
                seed = self.manual_seed * world_size + global_rank
                np.random.seed(seed)
                random.seed(seed)
                torch.manual_seed(seed)
            torch.backends.cudnn.benchmark = self.manual_seed is None
            torch.backends.cudnn.deterministic = self.manual_seed is not None
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            dataloaders = self.dataloader[global_rank]
            if torch.cuda.is_available():
                torch.cuda.set_device(local_rank)

            self.run_process(world_size, global_rank, local_rank, dataloaders)
            if torch.cuda.is_available():
                pynvml.nvmlShutdown()
            cleanup()
363
+
364
def setupAPI(parser: argparse.ArgumentParser) -> DistributedObject:
    """Parse the command line, export the settings as environment variables
    and build the object that matches the requested state.

    Args:
        parser: argument parser to which the API arguments are added.

    Returns:
        A ``Predictor`` for PREDICTION, an ``Evaluator`` for EVALUATION and a
        ``Trainer`` for every other state, each built from ``CONFIG_FILE()``.
    """
    # API arguments
    api_args = parser.add_argument_group('API arguments')
    api_args.add_argument("type", type=State, choices=list(State))
    api_args.add_argument('-y', action='store_true', help="Accept overwrite")
    api_args.add_argument('-tb', action='store_true', help='Start TensorBoard')
    api_args.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
    api_args.add_argument("-g", "--gpu", type=str, default=os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "", help="List of GPU")
    api_args.add_argument('--num-workers', '--num_workers', default=4, type=int, help='No. of workers per DataLoader & GPU')
    api_args.add_argument("-models_dir", "--MODELS_DIRECTORY", type=str, default="./Models/", help="Models location")
    api_args.add_argument("-checkpoints_dir", "--CHECKPOINTS_DIRECTORY", type=str, default="./Checkpoints/", help="Checkpoints location")
    api_args.add_argument("-model", "--MODEL", type=str, default="", help="URL Model")
    api_args.add_argument("-predictions_dir", "--PREDICTIONS_DIRECTORY", type=str, default="./Predictions/", help="Predictions location")
    api_args.add_argument("-evaluation_dir", "--EVALUATIONS_DIRECTORY", type=str, default="./Evaluations/", help="Evaluations location")
    api_args.add_argument("-statistics_dir", "--STATISTICS_DIRECTORY", type=str, default="./Statistics/", help="Statistics location")
    api_args.add_argument("-setups_dir", "--SETUPS_DIRECTORY", type=str, default="./Setups/", help="Setups location")
    api_args.add_argument('-log', action='store_true', help='Save log')
    # store_false: "quiet" is True unless -quiet is given, i.e. verbose by default.
    api_args.add_argument('-quiet', action='store_false', help='')

    args = parser.parse_args()
    config = vars(args)

    os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu"]
    os.environ["DL_API_MODELS_DIRECTORY"] = config["MODELS_DIRECTORY"]
    os.environ["DL_API_CHECKPOINTS_DIRECTORY"] = config["CHECKPOINTS_DIRECTORY"]
    os.environ["DL_API_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
    # The key must match the argparse destination exactly (upper-case "S");
    # a mismatching key raises KeyError on every invocation.
    os.environ["DL_API_EVALUATIONS_DIRECTORY"] = config["EVALUATIONS_DIRECTORY"]
    os.environ["DL_API_STATISTICS_DIRECTORY"] = config["STATISTICS_DIRECTORY"]

    os.environ["DL_API_STATE"] = str(config["type"])

    os.environ["DL_API_MODEL"] = config["MODEL"]

    os.environ["DL_API_SETUPS_DIRECTORY"] = config["SETUPS_DIRECTORY"]

    os.environ["DL_API_OVERWRITE"] = "{}".format(config["y"])
    # NOTE(review): "LEANING" looks like a typo, but the name is left as is
    # because other modules may read this exact variable - confirm before renaming.
    os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "Done"
    if config["tb"]:
        os.environ["DEEP_LEARNING_TENSORBOARD_PORT"] = str(find_free_port())

    os.environ["DEEP_LEARNING_VERBOSE"] = str(config["quiet"])

    if config["config"] == "None":
        # No explicit configuration file: use the default one for the state.
        if config["type"] is State.PREDICTION:
            os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = "Prediction.yml"
        elif config["type"] is State.EVALUATION:
            os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = "Evaluation.yml"
        else:
            os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = "Config.yml"
    else:
        os.environ["DEEP_LEARNING_API_CONFIG_FILE"] = config["config"]
    torch.autograd.set_detect_anomaly(True)
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    if config["type"] is State.PREDICTION:
        from KonfAI.konfai.predictor import Predictor
        os.environ["DEEP_LEARNING_API_ROOT"] = "Predictor"
        return Predictor(config=CONFIG_FILE())
    elif config["type"] is State.EVALUATION:
        from KonfAI.konfai.evaluator import Evaluator
        os.environ["DEEP_LEARNING_API_ROOT"] = "Evaluator"
        return Evaluator(config=CONFIG_FILE())
    else:
        from KonfAI.konfai.trainer import Trainer
        os.environ["DEEP_LEARNING_API_ROOT"] = "Trainer"
        return Trainer(config=CONFIG_FILE())
431
+
432
+ import submitit
433
+
434
def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple[int , int]:
    """Resolve the ranks of this process and join the process group.

    Args:
        world_size: total number of ranks.
        port: TCP port of the rendezvous address.
        rank: explicit rank, used as both global and local rank; ``None``
            reads both from ``submitit.JobEnvironment()``.

    Returns:
        ``(global_rank, local_rank)``, or ``(None, None)`` when the global
        rank is not below ``world_size``.
    """
    try:
        host_name = subprocess.check_output("scontrol show hostnames {}".format(os.getenv('SLURM_JOB_NODELIST')).split()).decode().splitlines()[0]
    except Exception:
        # Best effort: any failure to query scontrol falls back to localhost.
        # Catching Exception (not a bare except) lets KeyboardInterrupt and
        # SystemExit propagate.
        host_name = "localhost"

    if rank is None:
        job_env = submitit.JobEnvironment()
        global_rank = job_env.global_rank
        local_rank = job_env.local_rank
    else:
        global_rank = rank
        local_rank = rank

    if global_rank >= world_size:
        return None, None

    init_method = "tcp://{}:{}".format(host_name, port)
    print(init_method)
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        dist.init_process_group("nccl", rank=global_rank, init_method=init_method, world_size=world_size)
    return global_rank, local_rank
453
+
454
+ import socket
455
+ from contextlib import closing
456
+
457
def find_free_port():
    """Return the port number the OS assigns to a temporary TCP socket bound to port 0."""
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    with closing(sock):
        # Port 0 asks the OS to pick a free port.
        sock.bind(('', 0))
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        _, port = sock.getsockname()
        return port
462
+
463
def cleanup():
    """Destroy the process group; does nothing when CUDA is unavailable."""
    if not torch.cuda.is_available():
        return
    dist.destroy_process_group()
466
+
467
def synchronize_data(world_size: int, gpu: int, data: any) -> list[Any]:
    """Collect ``data`` from every rank into a list.

    Without CUDA no collective is run and the result is ``[data]``. With CUDA,
    the current device is set to ``gpu`` and ``all_gather_object`` fills a
    list of ``world_size`` entries.
    """
    if not torch.cuda.is_available():
        return [data]

    outputs: list[dict[str, tuple[dict[str, float], dict[str, float]]]] = [None] * world_size
    torch.cuda.set_device(gpu)
    dist.all_gather_object(outputs, data)
    return outputs
475
+
476
+ def _resample(data: torch.Tensor, size: list[int]) -> torch.Tensor:
477
+ if data.dtype == torch.uint8:
478
+ mode = "nearest"
479
+ elif len(data.shape) < 4:
480
+ mode = "bilinear"
481
+ else:
482
+ mode = "trilinear"
483
+ return F.interpolate(data.type(torch.float32).unsqueeze(0), size=tuple([s for s in reversed(size)]), mode=mode).squeeze(0).type(data.dtype)
484
+
485
+ def _affine_matrix(matrix: torch.Tensor, translation: torch.Tensor) -> torch.Tensor:
486
+ return torch.cat((torch.cat((matrix, translation.unsqueeze(0).T), dim=1), torch.tensor([[0, 0, 0, 1]])), dim=0)
487
+
488
+ def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
489
+ if data.dtype == torch.uint8:
490
+ mode = "nearest"
491
+ else:
492
+ mode = "bilinear"
493
+ return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
@@ -0,0 +1,68 @@
1
+ Metadata-Version: 2.4
2
+ Name: konfai
3
+ Version: 1.0.0
4
+ Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
+ Author-email: Valentin Boussot <boussot.v@gmail.com>
6
+ License-Expression: Apache-2.0
7
+ Project-URL: Homepage, https://github.com/vboussot/KonfAI
8
+ Project-URL: Repository, https://github.com/vboussot/KonfAI
9
+ Project-URL: Issues, https://github.com/vboussot/KonfAI/issues
10
+ Project-URL: License, https://www.apache.org/licenses/LICENSE-2.0
11
+ Requires-Python: >=3.8
12
+ Description-Content-Type: text/markdown
13
+ License-File: LICENSE
14
+ Requires-Dist: torch
15
+ Requires-Dist: tqdm
16
+ Requires-Dist: numpy
17
+ Requires-Dist: ruamel.yaml
18
+ Requires-Dist: psutil
19
+ Requires-Dist: tensorboard
20
+ Requires-Dist: SimpleITK
21
+ Requires-Dist: lxml
22
+ Requires-Dist: h5py
23
+ Requires-Dist: pynvml
24
+ Provides-Extra: vtk
25
+ Requires-Dist: vtk; extra == "vtk"
26
+ Provides-Extra: lpips
27
+ Requires-Dist: lpips; extra == "lpips"
28
+ Provides-Extra: cluster
29
+ Requires-Dist: submitit; extra == "cluster"
30
+ Dynamic: license-file
31
+
32
+
33
+ # 🧠 KonfAI
34
+ <img src="logo.png" alt="KonfAI Logo" width="200" align="right"/>
35
+
36
+ **KonfAI** is a modular and highly configurable deep learning framework built on PyTorch, driven entirely by YAML configuration files.
37
+
38
+ It is designed to support complex medical imaging workflows, flexible model architectures, customizable training loops, and advanced loss scheduling, without hardcoding anything.
39
+
40
+ ---
41
+
42
+ ## 🔧 Key Features
43
+
44
+ - 🔀 Full training/prediction/evaluation orchestration via YAML configuration file
45
+ - 🧩 Modular plugin-like structure (transforms, augmentations, models, losses, schedulers)
46
+ - 🔄 Dynamic criterion scheduling per head / target
47
+ - 🧠 Multi-branch / multi-output model support
48
+ - 🖥️ Cluster-ready
49
+ - 📈 TensorBoard and custom logging support
50
+
51
+ ---
52
+
53
+ ## 🚀 Installation
54
+
55
+ ```bash
56
+ git clone https://github.com/vboussot/KonfAI.git && cd KonfAI
57
+ pip install -e .
58
+ ```
59
+
60
+ ---
61
+
62
+ ## 🧪 Usage
63
+
64
+ ```bash
65
+ konfai TRAIN --gpu 0
66
+ konfai PREDICTION --gpu 0
67
+ konfai EVALUATION
68
+ ```
@@ -0,0 +1,39 @@
1
+ konfai/__init__.py,sha256=jXMTNml38eX6FSq9d3C_gJVgRLTKHPBUXqOLC7Pqkuo,828
2
+ konfai/evaluator.py,sha256=6YU3bXBy1YcS2kl0YwebkgwXYeO_VoBN59-MLMqj-ds,7468
3
+ konfai/main.py,sha256=Y-8vTgHVecMglB8krGKLbBoFlcoo-Oa4l-PjTDhBzbM,2142
4
+ konfai/predictor.py,sha256=IOh70fCVm8q-sgZyACNperTO-Vel8QKvYp-FoBY39ao,20236
5
+ konfai/trainer.py,sha256=zGvXd2skcqWgRN9GLx93xYB4Bv-46C0oo7J9My4Levk,16901
6
+ konfai/data/HDF5.py,sha256=QfU8VnyslkQhT_k2AJNFMNkJK7lm75ozxT4WELZt8wk,14390
7
+ konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ konfai/data/augmentation.py,sha256=F4hAuw5j-8zpmAzSAIn4VDhZG1-CCDCZQD_aoAYO-LA,31757
9
+ konfai/data/dataset.py,sha256=JnouutlJGlgIe7XnAijFe4FUatTZyWDLzWSLE3OxjZM,23963
10
+ konfai/data/transform.py,sha256=AxGqtEHC6XIk4AT-Clbq7w1sWUBrONLMsDHGiC0wIhI,25171
11
+ konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
+ konfai/metric/measure.py,sha256=fYkLgtttoGiQWU9a67Kk8eEOP1K_4yCwk87IA3x84Uc,21840
13
+ konfai/metric/schedulers.py,sha256=zQORXilMGPGBHq7Rg3l9JmbdoEebnTiB6yPHxuyEl7c,1377
14
+ konfai/models/classification/convNeXt.py,sha256=JbEPl7PHLY-FkSsjqDJtLlixH5JBIv_TDf0rMZeEe8s,9260
15
+ konfai/models/classification/resnet.py,sha256=i7SWA00yQdSEpLwYnn0_Cf32e4uAAzN-67mbcuC0wzw,7975
16
+ konfai/models/generation/cStyleGan.py,sha256=0b3lUH3PYJEBDSN6J0wanA-R2bWyj-gMmaDvSMiHw3A,8057
17
+ konfai/models/generation/ddpm.py,sha256=X4hMnrkIfiqdyL_XY8YE6hKeLUPqV15EkrsWxQ3DXEY,13184
18
+ konfai/models/generation/diffusionGan.py,sha256=x0sksJe2CI_Oqj-skuXfFChlWhBw4s0HoYk2ICE0fDM,33237
19
+ konfai/models/generation/gan.py,sha256=fN6CIDi23_XcsXdre8fJL_xRXhtLVSnLzoVcvF3bmLk,7871
20
+ konfai/models/generation/vae.py,sha256=Qq1nKnAGyv7VsgH5nZatMPjjrlIpXZ3bWifvu8W8W7Y,4733
21
+ konfai/models/registration/registration.py,sha256=vpzFl-ozga1TDLadyGE6w0xosDblC6PBABmcS-4E31w,6362
22
+ konfai/models/representation/representation.py,sha256=t9gX49KhyK7PvO2CduK_RmUrNY0m8pyr5nXwrOZ8szo,2691
23
+ konfai/models/segmentation/NestedUNet.py,sha256=hgbawKp27elTgkK5APjEa1nUEt2oJi9x1nsJLE22p7g,4318
24
+ konfai/models/segmentation/UNet.py,sha256=X-ddiQBJboq3ZHDfj8CvoZiNc9RT-eXKlBXriaL_mFY,4235
25
+ konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
+ konfai/network/blocks.py,sha256=U5P4EAHShvQh6s5VWkhn5_VIGj8gHpUvd7WyVWG3MiI,13542
27
+ konfai/network/network.py,sha256=qfqnPFQHOVu_fmZXOrFKXj3-Ej0HzcfJpkhW6FEbgD4,45810
28
+ konfai/utils/ITK.py,sha256=tErt6ymFesZWg4Mw6ZYc8kOC9zpUSjBRQMm1PUnvgF8,13962
29
+ konfai/utils/Registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
30
+ konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
+ konfai/utils/config.py,sha256=4NkR1BWXxwtsf95G_xH9_t-pe2unmbRvxzyKCIuS6eE,9894
32
+ konfai/utils/dataset.py,sha256=PNmzxaeFGMBh-pAaR92tDresDjF0QXSli2eGNJyzSVQ,35465
33
+ konfai/utils/utils.py,sha256=BCH7nwDf_Aqt_5lDieBCUEI0HE5TlinG4Z96EOAwA94,20206
34
+ konfai-1.0.0.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
+ konfai-1.0.0.dist-info/METADATA,sha256=flcowFH9PD-G7CYRMjBtpX2Dy-XLTY2mCXPfSUL5kAM,1971
36
+ konfai-1.0.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
37
+ konfai-1.0.0.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
38
+ konfai-1.0.0.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
39
+ konfai-1.0.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (80.9.0)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ konfai = konfai.main:main
3
+ konfai-cluster = konfai.main:cluster