konfai 1.1.2__tar.gz → 1.1.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- {konfai-1.1.2 → konfai-1.1.4}/PKG-INFO +1 -1
- {konfai-1.1.2 → konfai-1.1.4}/konfai/data/augmentation.py +41 -50
- {konfai-1.1.2 → konfai-1.1.4}/konfai/data/data_manager.py +15 -7
- {konfai-1.1.2 → konfai-1.1.4}/konfai/data/patching.py +2 -1
- {konfai-1.1.2 → konfai-1.1.4}/konfai/data/transform.py +1 -1
- {konfai-1.1.2 → konfai-1.1.4}/konfai/evaluator.py +6 -3
- {konfai-1.1.2 → konfai-1.1.4}/konfai/main.py +2 -1
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/segmentation/NestedUNet.py +0 -1
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/segmentation/UNet.py +9 -11
- {konfai-1.1.2 → konfai-1.1.4}/konfai/network/network.py +25 -12
- {konfai-1.1.2 → konfai-1.1.4}/konfai/predictor.py +18 -2
- {konfai-1.1.2 → konfai-1.1.4}/konfai/trainer.py +15 -1
- {konfai-1.1.2 → konfai-1.1.4}/konfai/utils/config.py +10 -7
- {konfai-1.1.2 → konfai-1.1.4}/konfai/utils/utils.py +4 -4
- {konfai-1.1.2 → konfai-1.1.4}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.1.2 → konfai-1.1.4}/pyproject.toml +1 -1
- {konfai-1.1.2 → konfai-1.1.4}/LICENSE +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/README.md +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/data/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/metric/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/metric/measure.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/metric/schedulers.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/classification/convNeXt.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/classification/resnet.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/generation/cStyleGan.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/generation/ddpm.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/generation/diffusionGan.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/generation/gan.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/generation/vae.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/registration/registration.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/models/representation/representation.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/network/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/network/blocks.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/utils/ITK.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/utils/__init__.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/utils/dataset.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai/utils/registration.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/setup.cfg +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/tests/test_config.py +0 -0
- {konfai-1.1.2 → konfai-1.1.4}/tests/test_dataset.py +0 -0
|
@@ -9,8 +9,7 @@ import os
|
|
|
9
9
|
from konfai import KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
11
|
from konfai.utils.utils import _getModule, AugmentationError
|
|
12
|
-
from konfai.utils.dataset import Attribute, data_to_image
|
|
13
|
-
|
|
12
|
+
from konfai.utils.dataset import Attribute, data_to_image, Dataset
|
|
14
13
|
|
|
15
14
|
def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
|
|
16
15
|
return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0,0,1]])), dim=0)
|
|
@@ -56,23 +55,29 @@ class DataAugmentationsList():
|
|
|
56
55
|
self.dataAugmentations : list[DataAugmentation] = []
|
|
57
56
|
self.dataAugmentationsLoader = dataAugmentations
|
|
58
57
|
|
|
59
|
-
def load(self, key: str):
|
|
58
|
+
def load(self, key: str, datasets: list[Dataset]):
|
|
60
59
|
for augmentation, prob in self.dataAugmentationsLoader.items():
|
|
61
60
|
module, name = _getModule(augmentation, "data.augmentation")
|
|
62
|
-
dataAugmentation: DataAugmentation =
|
|
61
|
+
dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
|
|
63
62
|
dataAugmentation.load(prob.prob)
|
|
63
|
+
dataAugmentation.setDatasets(datasets)
|
|
64
64
|
self.dataAugmentations.append(dataAugmentation)
|
|
65
|
-
|
|
65
|
+
|
|
66
66
|
class DataAugmentation(ABC):
|
|
67
67
|
|
|
68
|
-
def __init__(self) -> None:
|
|
68
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
69
69
|
self.who_index: dict[int, list[int]] = {}
|
|
70
70
|
self.shape_index: dict[int, list[list[int]]] = {}
|
|
71
71
|
self._prob: float = 0
|
|
72
|
+
self.groups = groups
|
|
73
|
+
self.datasets : list[Dataset] = []
|
|
72
74
|
|
|
73
75
|
def load(self, prob: float):
|
|
74
76
|
self._prob = prob
|
|
75
77
|
|
|
78
|
+
def setDatasets(self, datasets: list[Dataset]):
|
|
79
|
+
self.datasets = datasets
|
|
80
|
+
|
|
76
81
|
def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
77
82
|
if index is not None:
|
|
78
83
|
if index not in self.who_index:
|
|
@@ -93,14 +98,14 @@ class DataAugmentation(ABC):
|
|
|
93
98
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
94
99
|
pass
|
|
95
100
|
|
|
96
|
-
def __call__(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
101
|
+
def __call__(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
97
102
|
if len(self.who_index[index]) > 0:
|
|
98
|
-
for i, result in enumerate(self._compute(index, [inputs[i] for i in self.who_index[index]], device)):
|
|
103
|
+
for i, result in enumerate(self._compute(name, index, [inputs[i] for i in self.who_index[index]], device)):
|
|
99
104
|
inputs[self.who_index[index][i]] = result if device is None else result.cpu()
|
|
100
105
|
return inputs
|
|
101
106
|
|
|
102
107
|
@abstractmethod
|
|
103
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
108
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
104
109
|
pass
|
|
105
110
|
|
|
106
111
|
def inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
|
|
@@ -118,7 +123,7 @@ class EulerTransform(DataAugmentation):
|
|
|
118
123
|
super().__init__()
|
|
119
124
|
self.matrix: dict[int, list[torch.Tensor]] = {}
|
|
120
125
|
|
|
121
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
126
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
122
127
|
results = []
|
|
123
128
|
for input, matrix in zip(inputs, self.matrix[index]):
|
|
124
129
|
results.append(F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0))
|
|
@@ -126,10 +131,9 @@ class EulerTransform(DataAugmentation):
|
|
|
126
131
|
|
|
127
132
|
def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
|
|
128
133
|
return F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(self.matrix[index][a].inverse()[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0)
|
|
129
|
-
|
|
134
|
+
|
|
130
135
|
class Translate(EulerTransform):
|
|
131
136
|
|
|
132
|
-
@config("Translate")
|
|
133
137
|
def __init__(self, t_min: float = -10, t_max = 10, is_int: bool = False):
|
|
134
138
|
super().__init__()
|
|
135
139
|
self.t_min = t_min
|
|
@@ -147,7 +151,6 @@ class Translate(EulerTransform):
|
|
|
147
151
|
|
|
148
152
|
class Rotate(EulerTransform):
|
|
149
153
|
|
|
150
|
-
@config("Rotate")
|
|
151
154
|
def __init__(self, a_min: float = 0, a_max: float = 360, is_quarter: bool = False):
|
|
152
155
|
super().__init__()
|
|
153
156
|
self.a_min = a_min
|
|
@@ -169,7 +172,6 @@ class Rotate(EulerTransform):
|
|
|
169
172
|
|
|
170
173
|
class Scale(EulerTransform):
|
|
171
174
|
|
|
172
|
-
@config("Scale")
|
|
173
175
|
def __init__(self, s_std: float = 0.2):
|
|
174
176
|
super().__init__()
|
|
175
177
|
self.s_std = s_std
|
|
@@ -182,7 +184,6 @@ class Scale(EulerTransform):
|
|
|
182
184
|
|
|
183
185
|
class Flip(DataAugmentation):
|
|
184
186
|
|
|
185
|
-
@config("Flip")
|
|
186
187
|
def __init__(self, f_prob: Union[list[float], None] = [0.33, 0.33 ,0.33]) -> None:
|
|
187
188
|
super().__init__()
|
|
188
189
|
self.f_prob = f_prob
|
|
@@ -194,7 +195,7 @@ class Flip(DataAugmentation):
|
|
|
194
195
|
self.flip[index] = [dims[mask].tolist() for mask in prob]
|
|
195
196
|
return shapes
|
|
196
197
|
|
|
197
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
198
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
198
199
|
results = []
|
|
199
200
|
for input, flip in zip(inputs, self.flip[index]):
|
|
200
201
|
results.append(torch.flip(input, dims=flip))
|
|
@@ -206,12 +207,11 @@ class Flip(DataAugmentation):
|
|
|
206
207
|
|
|
207
208
|
class ColorTransform(DataAugmentation):
|
|
208
209
|
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
super().__init__()
|
|
210
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
211
|
+
super().__init__(groups)
|
|
212
212
|
self.matrix: dict[int, list[torch.Tensor]] = {}
|
|
213
213
|
|
|
214
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
214
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
215
215
|
results = []
|
|
216
216
|
for input, matrix in zip(inputs, self.matrix[index]):
|
|
217
217
|
result = input.reshape([*input.shape[:1], int(np.prod(input.shape[1:]))])
|
|
@@ -231,9 +231,8 @@ class ColorTransform(DataAugmentation):
|
|
|
231
231
|
|
|
232
232
|
class Brightness(ColorTransform):
|
|
233
233
|
|
|
234
|
-
|
|
235
|
-
|
|
236
|
-
super().__init__()
|
|
234
|
+
def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
|
|
235
|
+
super().__init__(groups)
|
|
237
236
|
self.b_std = b_std
|
|
238
237
|
|
|
239
238
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -243,9 +242,8 @@ class Brightness(ColorTransform):
|
|
|
243
242
|
|
|
244
243
|
class Contrast(ColorTransform):
|
|
245
244
|
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
super().__init__()
|
|
245
|
+
def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
|
|
246
|
+
super().__init__(groups)
|
|
249
247
|
self.c_std = c_std
|
|
250
248
|
|
|
251
249
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -255,9 +253,8 @@ class Contrast(ColorTransform):
|
|
|
255
253
|
|
|
256
254
|
class LumaFlip(ColorTransform):
|
|
257
255
|
|
|
258
|
-
|
|
259
|
-
|
|
260
|
-
super().__init__()
|
|
256
|
+
def __init__(self, groups: Union[list[str], None] = None) -> None:
|
|
257
|
+
super().__init__(groups)
|
|
261
258
|
self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
|
|
262
259
|
|
|
263
260
|
def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
|
|
@@ -268,8 +265,8 @@ class LumaFlip(ColorTransform):
|
|
|
268
265
|
class HUE(ColorTransform):
|
|
269
266
|
|
|
270
267
|
@config("HUE")
|
|
271
|
-
def __init__(self, hue_max: float) -> None:
|
|
272
|
-
super().__init__()
|
|
268
|
+
def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
|
|
269
|
+
super().__init__(groups)
|
|
273
270
|
self.hue_max = hue_max
|
|
274
271
|
self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))
|
|
275
272
|
|
|
@@ -280,9 +277,8 @@ class HUE(ColorTransform):
|
|
|
280
277
|
|
|
281
278
|
class Saturation(ColorTransform):
|
|
282
279
|
|
|
283
|
-
|
|
284
|
-
|
|
285
|
-
super().__init__()
|
|
280
|
+
def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
|
|
281
|
+
super().__init__(groups)
|
|
286
282
|
self.s_std = s_std
|
|
287
283
|
self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
|
|
288
284
|
|
|
@@ -367,9 +363,8 @@ class Saturation(ColorTransform):
|
|
|
367
363
|
|
|
368
364
|
class Noise(DataAugmentation):
|
|
369
365
|
|
|
370
|
-
|
|
371
|
-
|
|
372
|
-
super().__init__()
|
|
366
|
+
def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
|
|
367
|
+
super().__init__(groups)
|
|
373
368
|
self.n_std = n_std
|
|
374
369
|
self.noise_step = noise_step
|
|
375
370
|
|
|
@@ -410,7 +405,7 @@ class Noise(DataAugmentation):
|
|
|
410
405
|
self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
|
|
411
406
|
return shapes
|
|
412
407
|
|
|
413
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
408
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
414
409
|
results = []
|
|
415
410
|
for input, t in zip(inputs, self.ts[index]):
|
|
416
411
|
alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
|
|
@@ -422,9 +417,8 @@ class Noise(DataAugmentation):
|
|
|
422
417
|
|
|
423
418
|
class CutOUT(DataAugmentation):
|
|
424
419
|
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
super().__init__()
|
|
420
|
+
def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
|
|
421
|
+
super().__init__(groups)
|
|
428
422
|
self.c_prob = c_prob
|
|
429
423
|
self.cutout_size = cutout_size
|
|
430
424
|
self.centers: dict[int, list[torch.Tensor]] = {}
|
|
@@ -434,7 +428,7 @@ class CutOUT(DataAugmentation):
|
|
|
434
428
|
self.centers[index] = [torch.rand((3) if len(shape) == 3 else (2)) for shape in shapes]
|
|
435
429
|
return shapes
|
|
436
430
|
|
|
437
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
431
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
438
432
|
results = []
|
|
439
433
|
for input, center in zip(inputs, self.centers[index]):
|
|
440
434
|
masks = []
|
|
@@ -453,7 +447,6 @@ class CutOUT(DataAugmentation):
|
|
|
453
447
|
|
|
454
448
|
class Elastix(DataAugmentation):
|
|
455
449
|
|
|
456
|
-
@config("Elastix")
|
|
457
450
|
def __init__(self, grid_spacing: int = 16, max_displacement: int = 16) -> None:
|
|
458
451
|
super().__init__()
|
|
459
452
|
self.grid_spacing = grid_spacing
|
|
@@ -513,7 +506,7 @@ class Elastix(DataAugmentation):
|
|
|
513
506
|
print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
|
|
514
507
|
return shapes
|
|
515
508
|
|
|
516
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
509
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
517
510
|
results = []
|
|
518
511
|
for input, displacement_field in zip(inputs, self.displacement_fields[index]):
|
|
519
512
|
results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
|
|
@@ -524,7 +517,6 @@ class Elastix(DataAugmentation):
|
|
|
524
517
|
|
|
525
518
|
class Permute(DataAugmentation):
|
|
526
519
|
|
|
527
|
-
@config("Permute")
|
|
528
520
|
def __init__(self, prob_permute: Union[list[float], None] = [0.5 ,0.5]) -> None:
|
|
529
521
|
super().__init__()
|
|
530
522
|
self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
|
|
@@ -546,7 +538,7 @@ class Permute(DataAugmentation):
|
|
|
546
538
|
shapes[i] = [shapes[i][dim-1] for dim in permute[1:]]
|
|
547
539
|
return shapes
|
|
548
540
|
|
|
549
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
541
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
550
542
|
results = []
|
|
551
543
|
for input, prob in zip(inputs, self.permute[index]):
|
|
552
544
|
res = input
|
|
@@ -562,9 +554,8 @@ class Permute(DataAugmentation):
|
|
|
562
554
|
|
|
563
555
|
class Mask(DataAugmentation):
|
|
564
556
|
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
super().__init__()
|
|
557
|
+
def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
|
|
558
|
+
super().__init__(groups)
|
|
568
559
|
if mask is not None:
|
|
569
560
|
if os.path.exists(mask):
|
|
570
561
|
self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
|
|
@@ -577,7 +568,7 @@ class Mask(DataAugmentation):
|
|
|
577
568
|
self.positions[index] = [torch.rand((3) if len(shape) == 3 else (2))*(torch.tensor([max(s1-s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))])) for shape in shapes]
|
|
578
569
|
return [self.mask.shape for _ in shapes]
|
|
579
570
|
|
|
580
|
-
def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
571
|
+
def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
|
|
581
572
|
results = []
|
|
582
573
|
for input, position in zip(inputs, self.positions[index]):
|
|
583
574
|
slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
|
|
@@ -184,7 +184,7 @@ class DatasetIter(data.Dataset):
|
|
|
184
184
|
data = {}
|
|
185
185
|
x, a, p = self.map[index]
|
|
186
186
|
if x not in self._index_cache:
|
|
187
|
-
if len(self._index_cache) >= self.buffer_size and not self.use_cache:
|
|
187
|
+
if x not in self._index_cache and len(self._index_cache) >= self.buffer_size and not self.use_cache:
|
|
188
188
|
self._unloadData(self._index_cache[0])
|
|
189
189
|
self._loadData(x)
|
|
190
190
|
|
|
@@ -217,7 +217,6 @@ class Subset():
|
|
|
217
217
|
else:
|
|
218
218
|
names_filtred = names
|
|
219
219
|
size = len(names_filtred)
|
|
220
|
-
|
|
221
220
|
index = []
|
|
222
221
|
if self.subset is None:
|
|
223
222
|
index = list(range(0, size))
|
|
@@ -226,15 +225,24 @@ class Subset():
|
|
|
226
225
|
r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
|
|
227
226
|
index = list(range(r[0], r[1]))
|
|
228
227
|
elif os.path.exists(self.subset):
|
|
229
|
-
|
|
228
|
+
train_names = []
|
|
230
229
|
with open(self.subset, "r") as f:
|
|
231
230
|
for name in f:
|
|
232
|
-
|
|
231
|
+
train_names.append(name.strip())
|
|
233
232
|
index = []
|
|
234
233
|
for i, name in enumerate(names_filtred):
|
|
235
|
-
if name in
|
|
234
|
+
if name in train_names:
|
|
236
235
|
index.append(i)
|
|
237
|
-
|
|
236
|
+
elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
|
|
237
|
+
exclude_names = []
|
|
238
|
+
with open(self.subset[1:], "r") as f:
|
|
239
|
+
for name in f:
|
|
240
|
+
exclude_names.append(name.strip())
|
|
241
|
+
index = []
|
|
242
|
+
for i, name in enumerate(names_filtred):
|
|
243
|
+
if name not in exclude_names:
|
|
244
|
+
index.append(i)
|
|
245
|
+
|
|
238
246
|
elif isinstance(self.subset, list):
|
|
239
247
|
if len(self.subset) > 0:
|
|
240
248
|
if isinstance(self.subset[0], int):
|
|
@@ -390,7 +398,7 @@ class Data(ABC):
|
|
|
390
398
|
)
|
|
391
399
|
|
|
392
400
|
for key, dataAugmentations in self.dataAugmentationsList.items():
|
|
393
|
-
dataAugmentations.load(key)
|
|
401
|
+
dataAugmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
|
|
394
402
|
|
|
395
403
|
names = set()
|
|
396
404
|
dataset_name : dict[str, dict[str, list[str]]] = {}
|
|
@@ -326,7 +326,8 @@ class DatasetManager():
|
|
|
326
326
|
for dataAugmentations in dataAugmentationsList:
|
|
327
327
|
a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
|
|
328
328
|
for dataAugmentation in dataAugmentations.dataAugmentations:
|
|
329
|
-
|
|
329
|
+
if dataAugmentation.groups is None or self.group_dest in dataAugmentation.groups:
|
|
330
|
+
a_data = dataAugmentation(self.name, self.index, a_data, device)
|
|
330
331
|
|
|
331
332
|
for d in a_data:
|
|
332
333
|
self.data.append(d)
|
|
@@ -52,7 +52,7 @@ class Clip(Transform):
|
|
|
52
52
|
input[torch.where(input < self.min_value)] = self.min_value
|
|
53
53
|
input[torch.where(input > self.max_value)] = self.max_value
|
|
54
54
|
if self.saveClip_min:
|
|
55
|
-
cache_attribute["Min"] = self
|
|
55
|
+
cache_attribute["Min"] = self.min_value
|
|
56
56
|
if self.saveClip_max:
|
|
57
57
|
cache_attribute["Max"] = self.max_value
|
|
58
58
|
return input
|
|
@@ -102,7 +102,7 @@ class Evaluator(DistributedObject):
|
|
|
102
102
|
result = {}
|
|
103
103
|
for output_group in self.metrics:
|
|
104
104
|
for target_group in self.metrics[output_group]:
|
|
105
|
-
targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("
|
|
105
|
+
targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split(";") if group in data_dict]
|
|
106
106
|
name = data_dict[output_group][1][0]
|
|
107
107
|
for metric in self.metrics[output_group][target_group]:
|
|
108
108
|
result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
|
|
@@ -135,8 +135,11 @@ class Evaluator(DistributedObject):
|
|
|
135
135
|
f"Available groups: {sorted(groupsDest)}"
|
|
136
136
|
)
|
|
137
137
|
|
|
138
|
-
target_groups =
|
|
139
|
-
|
|
138
|
+
target_groups = []
|
|
139
|
+
for i in {target for targets in self.metrics.values() for target in targets}:
|
|
140
|
+
for u in i.split(";"):
|
|
141
|
+
target_groups.append(u)
|
|
142
|
+
missing_targets = set(target_groups) - set(groupsDest)
|
|
140
143
|
if missing_targets:
|
|
141
144
|
raise EvaluatorError(
|
|
142
145
|
f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
|
|
@@ -9,6 +9,8 @@ import sys
|
|
|
9
9
|
sys.path.insert(0, os.getcwd())
|
|
10
10
|
|
|
11
11
|
def main():
|
|
12
|
+
import tracemalloc
|
|
13
|
+
|
|
12
14
|
parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
13
15
|
try:
|
|
14
16
|
with setup(parser) as distributedObject:
|
|
@@ -22,7 +24,6 @@ def main():
|
|
|
22
24
|
except KeyboardInterrupt:
|
|
23
25
|
print("\n[KonfAI] Manual interruption (Ctrl+C)")
|
|
24
26
|
|
|
25
|
-
|
|
26
27
|
def cluster():
|
|
27
28
|
parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
28
29
|
|
|
@@ -31,7 +31,6 @@ class NestedUNet(network.Network):
|
|
|
31
31
|
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
32
32
|
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
33
33
|
|
|
34
|
-
@config("NestedUNet")
|
|
35
34
|
def __init__( self,
|
|
36
35
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
37
36
|
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
@@ -7,37 +7,37 @@ from konfai.data.patching import ModelPatch
|
|
|
7
7
|
|
|
8
8
|
class UNetHead(network.ModuleArgsDict):
|
|
9
9
|
|
|
10
|
-
def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
|
|
10
|
+
def __init__(self, in_channels: int, nb_class: int, dim: int, level: int) -> None:
|
|
11
11
|
super().__init__()
|
|
12
12
|
self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
|
|
13
13
|
self.add_module("Softmax", torch.nn.Softmax(dim=1))
|
|
14
14
|
self.add_module("Argmax", blocks.ArgMax(dim=1))
|
|
15
|
+
|
|
15
16
|
|
|
16
17
|
class UNetBlock(network.ModuleArgsDict):
|
|
17
18
|
|
|
18
|
-
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0
|
|
19
|
+
def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0) -> None:
|
|
19
20
|
super().__init__()
|
|
20
21
|
blockConfig_stride = blockConfig
|
|
21
22
|
if i > 0:
|
|
22
23
|
if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
|
|
23
24
|
self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
|
|
24
25
|
else:
|
|
25
|
-
blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size,
|
|
26
|
+
blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
|
|
26
27
|
self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
|
|
27
28
|
if len(channels) > 2:
|
|
28
|
-
self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1
|
|
29
|
+
self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1))
|
|
29
30
|
self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
|
|
30
31
|
if nb_class > 0:
|
|
31
|
-
self.add_module("Head", UNetHead(channels[1], nb_class, dim), out_branch=[-1])
|
|
32
|
+
self.add_module("Head", UNetHead(channels[1], nb_class, dim, i), out_branch=[-1])
|
|
32
33
|
if i > 0:
|
|
33
34
|
if attention:
|
|
34
35
|
self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
|
|
35
|
-
self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=
|
|
36
|
+
self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=2, stride=2))
|
|
36
37
|
self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
|
|
37
38
|
|
|
38
39
|
class UNet(network.Network):
|
|
39
40
|
|
|
40
|
-
@config("UNet")
|
|
41
41
|
def __init__( self,
|
|
42
42
|
optimizer : network.OptimizerLoader = network.OptimizerLoader(),
|
|
43
43
|
schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
|
|
@@ -51,8 +51,6 @@ class UNet(network.Network):
|
|
|
51
51
|
downSampleMode: str = "MAXPOOL",
|
|
52
52
|
upSampleMode: str = "CONV_TRANSPOSE",
|
|
53
53
|
attention : bool = False,
|
|
54
|
-
blockType: str = "Conv"
|
|
55
|
-
mri: bool = False) -> None:
|
|
54
|
+
blockType: str = "Conv") -> None:
|
|
56
55
|
super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
|
|
57
|
-
self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim
|
|
58
|
-
|
|
56
|
+
self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim))
|
|
@@ -157,7 +157,7 @@ class Measure():
|
|
|
157
157
|
outputs_group_rename = {}
|
|
158
158
|
|
|
159
159
|
modules = []
|
|
160
|
-
for i,_ in model.
|
|
160
|
+
for i,_,_ in model.named_ModuleArgsDict():
|
|
161
161
|
modules.append(i)
|
|
162
162
|
|
|
163
163
|
for output_group in self.outputsCriterions.keys():
|
|
@@ -167,11 +167,12 @@ class Measure():
|
|
|
167
167
|
"Please check that the name matches exactly a submodule or output of your model architecture."
|
|
168
168
|
)
|
|
169
169
|
for target_group in self.outputsCriterions[output_group]:
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
173
|
-
|
|
174
|
-
|
|
170
|
+
for target_group_tmp in target_group.split(";"):
|
|
171
|
+
if target_group_tmp not in group_dest:
|
|
172
|
+
raise MeasureError(
|
|
173
|
+
f"The target_group '{target_group_tmp}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
|
|
174
|
+
"This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
|
|
175
|
+
f"Please make sure that the group '{target_group_tmp}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' and correctly loaded from the dataset.")
|
|
175
176
|
for criterion in self.outputsCriterions[output_group][target_group]:
|
|
176
177
|
if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
|
|
177
178
|
outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
|
|
@@ -189,7 +190,7 @@ class Measure():
|
|
|
189
190
|
|
|
190
191
|
def update(self, output_group: str, output : torch.Tensor, data_dict: dict[str, torch.Tensor], it: int, nb_patch: int, training: bool) -> None:
|
|
191
192
|
for target_group in self.outputsCriterions[output_group]:
|
|
192
|
-
target = [data_dict[group].to(output[0].device).detach() for group in target_group.split("
|
|
193
|
+
target = [data_dict[group].to(output[0].device).detach() for group in target_group.split(";") if group in data_dict]
|
|
193
194
|
|
|
194
195
|
for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
|
|
195
196
|
if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
|
|
@@ -928,6 +929,16 @@ class Network(ModuleArgsDict, ABC):
|
|
|
928
929
|
if isinstance(module, ModuleArgsDict):
|
|
929
930
|
module._training = state
|
|
930
931
|
|
|
932
|
+
class MinimalModel(Network):
|
|
933
|
+
|
|
934
|
+
def __init__(self, model: Network, optimizer : OptimizerLoader = OptimizerLoader(),
|
|
935
|
+
schedulers : LRSchedulersLoader = LRSchedulersLoader(),
|
|
936
|
+
outputsCriterions: dict[str, TargetCriterionsLoader] = {"default" : TargetCriterionsLoader()},
|
|
937
|
+
patch : Union[ModelPatch, None] = None,
|
|
938
|
+
dim : int = 3, nb_batch_per_step = 1, init_type = "normal", init_gain = 0.02):
|
|
939
|
+
super().__init__(1, optimizer, schedulers, outputsCriterions, patch, nb_batch_per_step, init_type, init_gain, dim)
|
|
940
|
+
self.add_module("Model", model)
|
|
941
|
+
|
|
931
942
|
class ModelLoader():
|
|
932
943
|
|
|
933
944
|
@config("Model")
|
|
@@ -937,11 +948,13 @@ class ModelLoader():
|
|
|
937
948
|
def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
|
|
938
949
|
if not DL_args:
|
|
939
950
|
DL_args="{}.Model".format(KONFAI_ROOT())
|
|
940
|
-
|
|
941
|
-
if not train
|
|
942
|
-
|
|
943
|
-
|
|
944
|
-
|
|
951
|
+
DL_args += "."+self.name
|
|
952
|
+
model = config(DL_args)(getattr(importlib.import_module(self.module), self.name))(config = None, DL_without = DL_without if not train else [])
|
|
953
|
+
if not isinstance(model, Network):
|
|
954
|
+
model = config(DL_args)(partial(MinimalModel, model))(config = None, DL_without = DL_without+["model"] if not train else [])
|
|
955
|
+
model.setName(self.name)
|
|
956
|
+
return model
|
|
957
|
+
|
|
945
958
|
class CPU_Model():
|
|
946
959
|
|
|
947
960
|
def __init__(self, model: Network) -> None:
|
|
@@ -91,6 +91,23 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
91
91
|
super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
|
|
92
92
|
self.attributes.pop(index)
|
|
93
93
|
|
|
94
|
+
class Reduction():
|
|
95
|
+
|
|
96
|
+
def __init__(self):
|
|
97
|
+
pass
|
|
98
|
+
|
|
99
|
+
class ReductionMean():
|
|
100
|
+
|
|
101
|
+
def __init__(self):
|
|
102
|
+
pass
|
|
103
|
+
|
|
104
|
+
class ReductionMedian():
|
|
105
|
+
|
|
106
|
+
def __init__(self):
|
|
107
|
+
pass
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
|
|
94
111
|
class OutSameAsGroupDataset(OutDataset):
|
|
95
112
|
|
|
96
113
|
@config("OutDataset")
|
|
@@ -241,7 +258,6 @@ class _Predictor():
|
|
|
241
258
|
desc = lambda : "Prediction : {}".format(description(self.modelComposite))
|
|
242
259
|
self.dataloader_prediction.dataset.load("Prediction")
|
|
243
260
|
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
|
|
244
|
-
dist.barrier()
|
|
245
261
|
for it, data_dict in batch_iter:
|
|
246
262
|
with torch.amp.autocast('cuda', enabled=self.autocast):
|
|
247
263
|
input = self.getInput(data_dict)
|
|
@@ -255,7 +271,7 @@ class _Predictor():
|
|
|
255
271
|
|
|
256
272
|
batch_iter.set_description(desc())
|
|
257
273
|
self.it += 1
|
|
258
|
-
|
|
274
|
+
|
|
259
275
|
def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
|
|
260
276
|
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
|
|
261
277
|
|
|
@@ -34,7 +34,7 @@ class EarlyStoppingBase:
|
|
|
34
34
|
class EarlyStopping(EarlyStoppingBase):
|
|
35
35
|
|
|
36
36
|
@config("EarlyStopping")
|
|
37
|
-
def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
|
|
37
|
+
def __init__(self, monitor: Union[list[str], None] = [], patience: int=10, min_delta: float=0.0, mode: str="min"):
|
|
38
38
|
super().__init__()
|
|
39
39
|
self.monitor = [] if monitor is None else monitor
|
|
40
40
|
self.patience = patience
|
|
@@ -109,6 +109,8 @@ class _Trainer():
|
|
|
109
109
|
for data in data_log:
|
|
110
110
|
self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
|
|
111
111
|
|
|
112
|
+
|
|
113
|
+
|
|
112
114
|
def __enter__(self):
|
|
113
115
|
return self
|
|
114
116
|
|
|
@@ -314,6 +316,18 @@ class Trainer(DistributedObject):
|
|
|
314
316
|
self.ema_decay = ema_decay
|
|
315
317
|
self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
|
|
316
318
|
self.data_log = data_log
|
|
319
|
+
|
|
320
|
+
modules = []
|
|
321
|
+
for i,_ in self.model.named_modules():
|
|
322
|
+
modules.append(i)
|
|
323
|
+
for k in self.data_log:
|
|
324
|
+
tmp = k.split("/")[0].replace(":", ".")
|
|
325
|
+
if tmp not in self.dataset.getGroupsDest() and tmp not in modules:
|
|
326
|
+
raise TrainerError( f"Invalid key '{tmp}' in `data_log`.",
|
|
327
|
+
f"This key is neither a destination group from the dataset ({self.dataset.getGroupsDest()})",
|
|
328
|
+
f"nor a valid module name in the model ({modules}).",
|
|
329
|
+
"Please check your `data_log` configuration — it should reference either a model output or a dataset group.")
|
|
330
|
+
|
|
317
331
|
self.gradient_checkpoints = gradient_checkpoints
|
|
318
332
|
self.gpu_checkpoints = gpu_checkpoints
|
|
319
333
|
self.save_checkpoint_mode = save_checkpoint_mode
|
|
@@ -3,7 +3,7 @@ import ruamel.yaml
|
|
|
3
3
|
import inspect
|
|
4
4
|
import collections
|
|
5
5
|
from copy import deepcopy
|
|
6
|
-
from typing import Union, Literal, get_origin, get_args
|
|
6
|
+
from typing import Union, Literal, get_origin, get_args, Any
|
|
7
7
|
import torch
|
|
8
8
|
from konfai import CONFIG_FILE
|
|
9
9
|
from konfai.utils.utils import ConfigError
|
|
@@ -175,8 +175,11 @@ def config(key : Union[str, None] = None):
|
|
|
175
175
|
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
176
176
|
kwargs = {}
|
|
177
177
|
for param in list(inspect.signature(function).parameters.values())[len(args):]:
|
|
178
|
-
|
|
178
|
+
if param.name in without:
|
|
179
|
+
continue
|
|
180
|
+
|
|
179
181
|
annotation = param.annotation
|
|
182
|
+
|
|
180
183
|
# --- support Literal ---
|
|
181
184
|
if get_origin(annotation) is Literal:
|
|
182
185
|
allowed_values = get_args(annotation)
|
|
@@ -192,11 +195,10 @@ def config(key : Union[str, None] = None):
|
|
|
192
195
|
for i in annotation.__args__:
|
|
193
196
|
annotation = i
|
|
194
197
|
break
|
|
195
|
-
|
|
196
|
-
continue
|
|
198
|
+
|
|
197
199
|
if not annotation == inspect._empty:
|
|
198
200
|
if annotation not in [int, str, bool, float, torch.Tensor]:
|
|
199
|
-
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List"):
|
|
201
|
+
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List") or str(annotation).startswith("typing.Sequence"):
|
|
200
202
|
elem_type = annotation.__args__[0]
|
|
201
203
|
values = config.getValue(param.name, param.default)
|
|
202
204
|
if getattr(elem_type, '__origin__', None) is Union:
|
|
@@ -221,7 +223,7 @@ def config(key : Union[str, None] = None):
|
|
|
221
223
|
elif str(annotation).startswith("dict"):
|
|
222
224
|
if annotation.__args__[0] == str:
|
|
223
225
|
values = config.getValue(param.name, param.default)
|
|
224
|
-
if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
|
|
226
|
+
if values is not None and annotation.__args__[1] not in [int, str, bool, float, Any]:
|
|
225
227
|
try:
|
|
226
228
|
kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
|
|
227
229
|
except ValueError as e:
|
|
@@ -229,6 +231,7 @@ def config(key : Union[str, None] = None):
|
|
|
229
231
|
except Exception as e:
|
|
230
232
|
raise ConfigError("{} {}".format(values, e))
|
|
231
233
|
else:
|
|
234
|
+
|
|
232
235
|
kwargs[param.name] = values
|
|
233
236
|
else:
|
|
234
237
|
raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
|
|
@@ -236,7 +239,7 @@ def config(key : Union[str, None] = None):
|
|
|
236
239
|
try:
|
|
237
240
|
kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
|
|
238
241
|
except Exception as e:
|
|
239
|
-
raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation
|
|
242
|
+
raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation, e))
|
|
240
243
|
|
|
241
244
|
if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
|
|
242
245
|
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
@@ -21,7 +21,7 @@ import sys
|
|
|
21
21
|
import re
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
def description(model, modelEMA = None, showMemory: bool = True) -> str:
|
|
24
|
+
def description(model, modelEMA = None, showMemory: bool = True, train: bool = True) -> str:
|
|
25
25
|
values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
|
|
26
26
|
model_desc = lambda model : "("+" ".join(["{}({:.6f}) : {}".format(name, network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0, values_desc(network.measure.getLastWeights(), network.measure.getLastValues())) for name, network in model.module.getNetworks().items() if network.measure is not None])+")"
|
|
27
27
|
result = "Loss {}".format(model_desc(model))
|
|
@@ -33,9 +33,9 @@ def description(model, modelEMA = None, showMemory: bool = True) -> str:
|
|
|
33
33
|
return result
|
|
34
34
|
|
|
35
35
|
def _getModule(classpath : str, type : str) -> tuple[str, str]:
|
|
36
|
-
if len(classpath.split("
|
|
37
|
-
module = ".".join(classpath.split("
|
|
38
|
-
name = classpath.split("
|
|
36
|
+
if len(classpath.split(":")) > 1:
|
|
37
|
+
module = ".".join(classpath.split(":")[:-1])
|
|
38
|
+
name = classpath.split(":")[-1]
|
|
39
39
|
else:
|
|
40
40
|
module = "konfai."+type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
|
|
41
41
|
name = classpath.split(".")[-1]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|