konfai 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +1 -1
- konfai/data/augmentation.py +2 -2
- konfai/data/data_manager.py +145 -42
- konfai/data/patching.py +39 -13
- konfai/data/transform.py +48 -21
- konfai/evaluator.py +24 -7
- konfai/main.py +7 -5
- konfai/models/registration/registration.py +0 -1
- konfai/network/blocks.py +0 -1
- konfai/network/network.py +29 -16
- konfai/predictor.py +24 -21
- konfai/trainer.py +15 -15
- konfai/utils/config.py +12 -12
- konfai/utils/dataset.py +27 -2
- konfai/utils/utils.py +108 -24
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/METADATA +1 -1
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/RECORD +21 -21
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/WHEEL +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/top_level.txt +0 -0
konfai/utils/utils.py
CHANGED
|
@@ -18,13 +18,15 @@ import random
|
|
|
18
18
|
from torch.utils.data import DataLoader
|
|
19
19
|
import torch.nn.functional as F
|
|
20
20
|
import sys
|
|
21
|
+
import re
|
|
22
|
+
|
|
21
23
|
|
|
22
24
|
def description(model, modelEMA = None, showMemory: bool = True) -> str:
|
|
23
25
|
values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
|
|
24
26
|
model_desc = lambda model : "("+" ".join(["{}({:.6f}) : {}".format(name, network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0, values_desc(network.measure.getLastWeights(), network.measure.getLastValues())) for name, network in model.module.getNetworks().items() if network.measure is not None])+")"
|
|
25
27
|
result = "Loss {}".format(model_desc(model))
|
|
26
28
|
if modelEMA is not None:
|
|
27
|
-
result += "Loss EMA {}".format(model_desc(modelEMA))
|
|
29
|
+
result += " Loss EMA {}".format(model_desc(modelEMA))
|
|
28
30
|
result += " "+gpuInfo()
|
|
29
31
|
if showMemory:
|
|
30
32
|
result +=" | {}".format(memoryInfo())
|
|
@@ -145,8 +147,12 @@ def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_p
|
|
|
145
147
|
|
|
146
148
|
def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overlap: Union[int, None]) -> tuple[list[tuple[slice]], list[tuple[int, bool]]]:
|
|
147
149
|
if len(shape) != len(patch_size):
|
|
148
|
-
|
|
149
|
-
|
|
150
|
+
raise DatasetManagerError(
|
|
151
|
+
f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
|
|
152
|
+
f"patch_size: {patch_size}",
|
|
153
|
+
f"shape: {shape}",
|
|
154
|
+
"Both must have the same number of dimensions (e.g., 3D patch for 3D volume)."
|
|
155
|
+
)
|
|
150
156
|
patch_slices = []
|
|
151
157
|
nb_patch_per_dim = []
|
|
152
158
|
slices : list[list[slice]] = []
|
|
@@ -192,12 +198,8 @@ def _logImageFormat(input : np.ndarray):
|
|
|
192
198
|
|
|
193
199
|
if len(input.shape) == 3 and input.shape[0] != 1:
|
|
194
200
|
input = np.expand_dims(input, axis=0)
|
|
195
|
-
|
|
196
201
|
if len(input.shape) == 4:
|
|
197
202
|
input = input[:, input.shape[1]//2]
|
|
198
|
-
|
|
199
|
-
if input.dtype == np.uint8:
|
|
200
|
-
return input
|
|
201
203
|
|
|
202
204
|
input = input.astype(float)
|
|
203
205
|
b = -np.min(input)
|
|
@@ -237,7 +239,7 @@ class DataLog(Enum):
|
|
|
237
239
|
AUDIO = lambda tb, name, layer, it : tb.add_audio(name, _logImageFormat(layer), it)
|
|
238
240
|
|
|
239
241
|
class Log:
|
|
240
|
-
def __init__(self, name: str) -> None:
|
|
242
|
+
def __init__(self, name: str, rank: int) -> None:
|
|
241
243
|
if KONFAI_STATE() == "PREDICTION":
|
|
242
244
|
path = PREDICTIONS_DIRECTORY()
|
|
243
245
|
elif KONFAI_STATE() == "EVALUATION":
|
|
@@ -248,11 +250,12 @@ class Log:
|
|
|
248
250
|
self.verbose = os.environ.get("KONFAI_VERBOSE", "True") == "True"
|
|
249
251
|
self.log_path = os.path.join(path, name)
|
|
250
252
|
os.makedirs(self.log_path, exist_ok=True)
|
|
251
|
-
|
|
252
|
-
self.file = open(os.path.join(self.log_path, "
|
|
253
|
+
self.rank = rank
|
|
254
|
+
self.file = open(os.path.join(self.log_path, "log_{}.txt".format(rank)), "w", buffering=1)
|
|
253
255
|
self.stdout_bak = sys.stdout
|
|
254
256
|
self.stderr_bak = sys.stderr
|
|
255
|
-
|
|
257
|
+
self._buffered_line = ""
|
|
258
|
+
|
|
256
259
|
def __enter__(self):
|
|
257
260
|
self.file.__enter__()
|
|
258
261
|
sys.stdout = self
|
|
@@ -264,12 +267,26 @@ class Log:
|
|
|
264
267
|
sys.stdout = self.stdout_bak
|
|
265
268
|
sys.stderr = self.stderr_bak
|
|
266
269
|
|
|
267
|
-
def write(self, msg):
|
|
270
|
+
def write(self, msg: str):
|
|
268
271
|
if not msg:
|
|
269
272
|
return
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')
|
|
276
|
+
CARRIAGE_RETURN = re.compile(r'(?:\r|\x1b\[A).*')
|
|
277
|
+
msg_clean = ANSI_ESCAPE.sub('', msg)
|
|
278
|
+
if '\r' in msg_clean or '[A' in msg:
|
|
279
|
+
# On garde seulement le contenu après le dernier retour chariot
|
|
280
|
+
msg_clean = msg_clean.split('\r')[-1].strip()
|
|
281
|
+
self._buffered_line = msg_clean
|
|
282
|
+
else:
|
|
283
|
+
self._buffered_line = msg_clean.strip()
|
|
284
|
+
|
|
285
|
+
if self._buffered_line:
|
|
286
|
+
# Écrit dans le fichier
|
|
287
|
+
self.file.write(self._buffered_line + "\n")
|
|
288
|
+
self.file.flush()
|
|
289
|
+
if self.verbose and (self.rank == 0 or "KONFAI_CLUSTER" in os.environ):
|
|
273
290
|
sys.__stdout__.write(msg)
|
|
274
291
|
sys.__stdout__.flush()
|
|
275
292
|
|
|
@@ -325,7 +342,7 @@ class DistributedObject():
|
|
|
325
342
|
return self
|
|
326
343
|
|
|
327
344
|
def __exit__(self, type, value, traceback):
|
|
328
|
-
|
|
345
|
+
cleanup()
|
|
329
346
|
|
|
330
347
|
@abstractmethod
|
|
331
348
|
def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
|
|
@@ -352,7 +369,7 @@ class DistributedObject():
|
|
|
352
369
|
return result
|
|
353
370
|
|
|
354
371
|
def __call__(self, rank: Union[int, None] = None) -> None:
|
|
355
|
-
with Log(self.name):
|
|
372
|
+
with Log(self.name, rank):
|
|
356
373
|
world_size = len(self.dataloader)
|
|
357
374
|
global_rank, local_rank = setupGPU(world_size, self.port, rank)
|
|
358
375
|
if global_rank is None:
|
|
@@ -370,11 +387,12 @@ class DistributedObject():
|
|
|
370
387
|
dataloaders = self.dataloader[global_rank]
|
|
371
388
|
if torch.cuda.is_available():
|
|
372
389
|
torch.cuda.set_device(local_rank)
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
390
|
+
try:
|
|
391
|
+
self.run_process(world_size, global_rank, local_rank, dataloaders)
|
|
392
|
+
finally:
|
|
393
|
+
cleanup()
|
|
394
|
+
if torch.cuda.is_available():
|
|
395
|
+
pynvml.nvmlShutdown()
|
|
378
396
|
|
|
379
397
|
def setup(parser: argparse.ArgumentParser) -> DistributedObject:
|
|
380
398
|
# KONFAI arguments
|
|
@@ -384,6 +402,7 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
|
|
|
384
402
|
KONFAI_args.add_argument('-tb', action='store_true', help='Start TensorBoard')
|
|
385
403
|
KONFAI_args.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
|
|
386
404
|
KONFAI_args.add_argument("-g", "--gpu", type=str, default=os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "", help="List of GPU")
|
|
405
|
+
KONFAI_args.add_argument("-cpu", "--cpu", type=str, default="1" , help="List of GPU")
|
|
387
406
|
KONFAI_args.add_argument('--num-workers', '--num_workers', default=4, type=int, help='No. of workers per DataLoader & GPU')
|
|
388
407
|
KONFAI_args.add_argument("-models_dir", "--MODELS_DIRECTORY", type=str, default="./Models/", help="Models location")
|
|
389
408
|
KONFAI_args.add_argument("-checkpoints_dir", "--CHECKPOINTS_DIRECTORY", type=str, default="./Checkpoints/", help="Checkpoints location")
|
|
@@ -400,6 +419,7 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
|
|
|
400
419
|
config = vars(args)
|
|
401
420
|
|
|
402
421
|
os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu"]
|
|
422
|
+
os.environ["KONFAI_NB_CORES"] = config["cpu"]
|
|
403
423
|
os.environ["KONFAI_MODELS_DIRECTORY"] = config["MODELS_DIRECTORY"]
|
|
404
424
|
os.environ["KONFAI_CHECKPOINTS_DIRECTORY"] = config["CHECKPOINTS_DIRECTORY"]
|
|
405
425
|
os.environ["KONFAI_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
|
|
@@ -460,10 +480,18 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
|
|
|
460
480
|
local_rank = rank
|
|
461
481
|
if global_rank >= world_size:
|
|
462
482
|
return None, None
|
|
463
|
-
print("tcp://{}:{}".format(host_name, port))
|
|
483
|
+
#print("tcp://{}:{}".format(host_name, port))
|
|
464
484
|
if torch.cuda.is_available():
|
|
465
485
|
torch.cuda.empty_cache()
|
|
466
486
|
dist.init_process_group("nccl", rank=global_rank, init_method="tcp://{}:{}".format(host_name, port), world_size=world_size)
|
|
487
|
+
else:
|
|
488
|
+
if not dist.is_initialized():
|
|
489
|
+
dist.init_process_group(
|
|
490
|
+
backend="gloo",
|
|
491
|
+
init_method=f"tcp://{host_name}:{port}",
|
|
492
|
+
rank=global_rank,
|
|
493
|
+
world_size=world_size
|
|
494
|
+
)
|
|
467
495
|
return global_rank, local_rank
|
|
468
496
|
|
|
469
497
|
import socket
|
|
@@ -476,7 +504,7 @@ def find_free_port():
|
|
|
476
504
|
return s.getsockname()[1]
|
|
477
505
|
|
|
478
506
|
def cleanup():
    """Destroy the default ``torch.distributed`` process group if one exists.

    Intended to be safe to call unconditionally (for example from ``finally``
    blocks and context-manager ``__exit__`` methods): it does nothing when no
    process group has been initialised.
    """
    # On torch builds compiled without distributed support, ``is_initialized``
    # may not be defined at all, so check ``is_available()`` first instead of
    # letting the teardown path raise AttributeError.
    if dist.is_available() and dist.is_initialized():
        dist.destroy_process_group()
|
|
481
509
|
|
|
482
510
|
def synchronize_data(world_size: int, gpu: int, data: any) -> list[Any]:
|
|
@@ -506,3 +534,59 @@ def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
|
|
|
506
534
|
else:
|
|
507
535
|
mode = "bilinear"
|
|
508
536
|
return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
|
|
537
|
+
|
|
538
|
+
|
|
539
|
+
SUPPORTED_EXTENSIONS = [
|
|
540
|
+
"mha", "mhd", # MetaImage
|
|
541
|
+
"nii", "nii.gz", # NIfTI
|
|
542
|
+
"nrrd", "nrrd.gz", # NRRD
|
|
543
|
+
"gipl", "gipl.gz", # GIPL
|
|
544
|
+
"hdr", "img", # Analyze
|
|
545
|
+
"dcm", # DICOM (si GDCM activé)
|
|
546
|
+
"tif", "tiff", # TIFF
|
|
547
|
+
"png", "jpg", "jpeg", "bmp", # 2D formats
|
|
548
|
+
"h5", "itk.txt", ".fcsv", ".xml", ".vtk", ".npy"
|
|
549
|
+
|
|
550
|
+
]
|
|
551
|
+
|
|
552
|
+
class KonfAIError(Exception):
|
|
553
|
+
|
|
554
|
+
def __init__(self, typeError: str, messages: list[str]) -> None:
|
|
555
|
+
super().__init__("\n[{}] {}".format(typeError, messages[0])+("\n" if len(messages) > 0 else "")+"\n→\t".join(messages[1:]))
|
|
556
|
+
|
|
557
|
+
|
|
558
|
+
class ConfigError(KonfAIError):
|
|
559
|
+
|
|
560
|
+
def __init__(self, *message) -> None:
|
|
561
|
+
super().__init__("Config", message)
|
|
562
|
+
|
|
563
|
+
|
|
564
|
+
class DatasetManagerError(KonfAIError):
|
|
565
|
+
|
|
566
|
+
def __init__(self, *message) -> None:
|
|
567
|
+
super().__init__("DatasetManager", message)
|
|
568
|
+
|
|
569
|
+
class MeasureError(KonfAIError):
|
|
570
|
+
|
|
571
|
+
def __init__(self, *message) -> None:
|
|
572
|
+
super().__init__("Measure", message)
|
|
573
|
+
|
|
574
|
+
class TrainerError(KonfAIError):
|
|
575
|
+
|
|
576
|
+
def __init__(self, *message) -> None:
|
|
577
|
+
super().__init__("Trainer", message)
|
|
578
|
+
|
|
579
|
+
class AugmentationError(KonfAIError):
|
|
580
|
+
|
|
581
|
+
def __init__(self, *message) -> None:
|
|
582
|
+
super().__init__("Augmentation", message)
|
|
583
|
+
|
|
584
|
+
class EvaluatorError(KonfAIError):
|
|
585
|
+
|
|
586
|
+
def __init__(self, *message) -> None:
|
|
587
|
+
super().__init__("Evaluator", message)
|
|
588
|
+
|
|
589
|
+
class TransformError(KonfAIError):
|
|
590
|
+
|
|
591
|
+
def __init__(self, *message) -> None:
|
|
592
|
+
super().__init__("Transform", message)
|
|
@@ -1,13 +1,13 @@
|
|
|
1
|
-
konfai/__init__.py,sha256=
|
|
2
|
-
konfai/evaluator.py,sha256=
|
|
3
|
-
konfai/main.py,sha256=
|
|
4
|
-
konfai/predictor.py,sha256=
|
|
5
|
-
konfai/trainer.py,sha256=
|
|
1
|
+
konfai/__init__.py,sha256=YXG-wpSEXWs6Jt3BDI77V4r89gEUNX-6lxW9btj5VYI,851
|
|
2
|
+
konfai/evaluator.py,sha256=rAhfdRemMjzC3VoaqyQKJR0SBekuLDiLT1nhblH8RQk,8293
|
|
3
|
+
konfai/main.py,sha256=rTTJl-biaX4CkLNxPtqwwsrybXSDxjWlU39TE2ImU5o,2574
|
|
4
|
+
konfai/predictor.py,sha256=e5V7awlIRDE1NkLTR_D6fUuDg93ZFGaGLTDR71KeqM8,22549
|
|
5
|
+
konfai/trainer.py,sha256=NQDzp4hUKYAcDdKmeBcAE9QNFW9KGW2qV-7q_PXeWi8,19509
|
|
6
6
|
konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
-
konfai/data/augmentation.py,sha256=
|
|
8
|
-
konfai/data/data_manager.py,sha256=
|
|
9
|
-
konfai/data/patching.py,sha256=
|
|
10
|
-
konfai/data/transform.py,sha256=
|
|
7
|
+
konfai/data/augmentation.py,sha256=ASmKWBpykLBHDB_YeTgoBTlqvJ06v5OUG7n7ugPN6NU,31718
|
|
8
|
+
konfai/data/data_manager.py,sha256=9VpF4CTTAnqS1Hq0csbTE-XAocRwziFssITg-SZhA4A,29603
|
|
9
|
+
konfai/data/patching.py,sha256=OtGNs99jKQSYmuj8B3MNGcbupKOcX5PUpMCIe0pKnX0,15636
|
|
10
|
+
konfai/data/transform.py,sha256=dTZh2CgfxNitHW-HEkO0jQ9GG26-Wf-YlI-qoI2mqC0,26472
|
|
11
11
|
konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
12
|
konfai/metric/measure.py,sha256=K7s15wNJm1_Iaav9y7Oe_UhSnbLUiybhLTVumHZ8ZaY,22001
|
|
13
13
|
konfai/metric/schedulers.py,sha256=UoSr1TW_hrus3DhvOEbDefxCSUGz7lJS_8vbz0GEye8,1370
|
|
@@ -18,22 +18,22 @@ konfai/models/generation/ddpm.py,sha256=awvuRo-vk8M80N93NWF4i0-WWfaycBxSOmdYJNJv
|
|
|
18
18
|
konfai/models/generation/diffusionGan.py,sha256=KnJyV-tx4CiE_ag-5IXwiYLCuC2yFHX16k2CtASdecg,33199
|
|
19
19
|
konfai/models/generation/gan.py,sha256=-GoKxHm3W9NdD4U77UcJrG5TfOZ3NWFUZG663kt2XPo,7854
|
|
20
20
|
konfai/models/generation/vae.py,sha256=_3JYVT2ojZ0P98tYcD2ny7a-gWVUmnByLDhY7i-n_4g,4719
|
|
21
|
-
konfai/models/registration/registration.py,sha256=
|
|
21
|
+
konfai/models/registration/registration.py,sha256=18EiWt4RJIXLyFtqU-kHjV1sMnQRm9mxAA6_-2B1YqI,6313
|
|
22
22
|
konfai/models/representation/representation.py,sha256=RwQYoxtdph440-t_ZLelykl0hkUAD1zdspQaLkgxb-0,2677
|
|
23
23
|
konfai/models/segmentation/NestedUNet.py,sha256=6XGizAIc4bDL8vx4AHW8BBFjUvovRYcjdMBHsN4ViNo,4301
|
|
24
24
|
konfai/models/segmentation/UNet.py,sha256=9-g63oNqaxSlGmrD1-IVzJ-kac9QphmM7i-3rEsH23I,4218
|
|
25
25
|
konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
|
-
konfai/network/blocks.py,sha256=
|
|
27
|
-
konfai/network/network.py,sha256=
|
|
26
|
+
konfai/network/blocks.py,sha256=RaTI0Lrvq1V-GIFei-WTUB6wlg4LydZksAyJ8DMk40M,13502
|
|
27
|
+
konfai/network/network.py,sha256=SnwQUKFTZZU8zuix5sm0vW8s0G3h6Fa8pHoIcwC984Q,46512
|
|
28
28
|
konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
|
|
29
29
|
konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
-
konfai/utils/config.py,sha256=
|
|
31
|
-
konfai/utils/dataset.py,sha256=
|
|
30
|
+
konfai/utils/config.py,sha256=OsCUqgNaPX56nD_MpJ2WRpiQ7xxqHonfPYaS9VGtsl8,12413
|
|
31
|
+
konfai/utils/dataset.py,sha256=6ZzevdhJ7e5zlXATAVwSh9O6acKXM7gYNxkMAa5DrmM,36351
|
|
32
32
|
konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
|
|
33
|
-
konfai/utils/utils.py,sha256=
|
|
34
|
-
konfai-1.1.
|
|
35
|
-
konfai-1.1.
|
|
36
|
-
konfai-1.1.
|
|
37
|
-
konfai-1.1.
|
|
38
|
-
konfai-1.1.
|
|
39
|
-
konfai-1.1.
|
|
33
|
+
konfai/utils/utils.py,sha256=Laq8bGc5mGKFZlJIkHxa-BrC9uR2F7MTRuL4YFAIxQY,23439
|
|
34
|
+
konfai-1.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
35
|
+
konfai-1.1.2.dist-info/METADATA,sha256=Usg7hzW1xmlHswHZb3fC2Hgl2V1Jkq43J0T5j5TDEIw,2515
|
|
36
|
+
konfai-1.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
37
|
+
konfai-1.1.2.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
|
|
38
|
+
konfai-1.1.2.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
|
|
39
|
+
konfai-1.1.2.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|