konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic; see the release advisory for more details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
import SimpleITK as sitk
|
|
2
|
+
from typing import Union
|
|
3
|
+
import numpy as np
|
|
4
|
+
import sys
|
|
5
|
+
import scipy
|
|
6
|
+
|
|
7
|
+
def parameterMap_to_transform(path_src: str) -> Union[sitk.Transform, list[sitk.Transform]]:
    """Read an elastix parameter-map file and build the matching SimpleITK transform(s).

    Stack transforms ("BSplineStackTransform", "AffineLogStackTransform") yield a
    list with one transform per sub-transform; every other kind yields a single
    transform.

    :param path_src: path to an elastix ``TransformParameters`` text file.
    :return: a single ``sitk.Transform``, or a list of transforms for stack kinds.
    """
    transform = sitk.ReadParameterFile(path_src)
    # Parameter-map entries are tuples of strings; convert them to float lists.
    # (Deliberately shadows the builtin ``format`` — local scope only.)
    format = lambda x: [float(i) for i in x]
    dimension = int(transform["FixedImageDimension"][0])

    if transform["Transform"][0] == "EulerTransform":
        if dimension == 2:
            result = sitk.Euler2DTransform()
        else:
            result = sitk.Euler3DTransform()
        parameters = format(transform["TransformParameters"])
        # The trailing 0 matches the extra fixed parameter SimpleITK stores for
        # Euler3D (the ComputeZYX flag). NOTE(review): confirm it is also valid
        # for the 2D case, whose fixed parameters are just the 2-D center.
        fixedParameters = format(transform["CenterOfRotationPoint"])+[0]
    elif transform["Transform"][0] == "TranslationTransform":
        result = sitk.TranslationTransform(dimension)
        parameters = format(transform["TransformParameters"])
        # A pure translation has no fixed parameters.
        fixedParameters = []
    elif transform["Transform"][0] == "AffineTransform":
        result = sitk.AffineTransform(dimension)
        parameters = format(transform["TransformParameters"])
        fixedParameters = format(transform["CenterOfRotationPoint"])+[0]
    elif transform["Transform"][0] == "BSplineStackTransform":
        parameters = format(transform["TransformParameters"])
        GridSize = format(transform["GridSize"])
        GridOrigin = format(transform["GridOrigin"])
        GridSpacing = format(transform["GridSpacing"])
        # elastix stores the direction matrix in the opposite flattening order;
        # transpose before flattening for SimpleITK.
        GridDirection = np.asarray(format(transform["GridDirection"])).reshape((dimension, dimension)).T.flatten()
        # B-spline fixed parameters: grid size, origin, spacing, direction.
        fixedParameters = np.concatenate([GridSize, GridOrigin, GridSpacing, GridDirection])

        # The last entry of "Size" is the stack length (number of sub-transforms).
        nb = int(format(transform["Size"])[-1])
        # Each sub-transform owns prod(GridSize) control points x `dimension` coefficients.
        sub = int(np.prod(GridSize))*dimension
        results = []
        for i in range(nb):
            result = sitk.BSplineTransform(dimension)
            sub_parameters = np.asarray(parameters[i*sub:(i+1)*sub])
            result.SetFixedParameters(fixedParameters)
            result.SetParameters(sub_parameters)
            results.append(result)
        return results
    elif transform["Transform"][0] == "AffineLogStackTransform":
        parameters = format(transform["TransformParameters"])
        fixedParameters = format(transform["CenterOfRotationPoint"])+[0]

        nb = int(transform["NumberOfSubTransforms"][0])
        # Per sub-transform: dimension*dimension matrix-log entries plus
        # `dimension` translations, i.e. dimension*(dimension+1).
        # NOTE(review): `dimension*4` only equals that for dimension == 3 — confirm
        # this path is 3-D only.
        sub = dimension*4
        results = []
        for i in range(nb):
            result = sitk.AffineTransform(dimension)
            sub_parameters = np.asarray(parameters[i*sub:(i+1)*sub])

            result.SetFixedParameters(fixedParameters)
            # The stack stores the matrix logarithm of the affine part:
            # exponentiate it back, then append the translation unchanged.
            result.SetParameters(np.concatenate([scipy.linalg.expm(sub_parameters[:dimension*dimension].reshape((dimension,dimension))).flatten(), sub_parameters[-dimension:]]))
            results.append(result)
        return results
    else:
        # Any other transform name is treated as a single B-spline transform.
        result = sitk.BSplineTransform(dimension)

        parameters = format(transform["TransformParameters"])
        GridSize = format(transform["GridSize"])
        GridOrigin = format(transform["GridOrigin"])
        GridSpacing = format(transform["GridSpacing"])
        GridDirection = np.array(format(transform["GridDirection"])).reshape((dimension,dimension)).T.flatten()
        fixedParameters = np.concatenate([GridSize, GridOrigin, GridSpacing, GridDirection])

    # Shared tail for every single-transform branch (Euler / Translation /
    # Affine / B-spline): install the parameters and return the transform.
    result.SetFixedParameters(fixedParameters)
    result.SetParameters(parameters)
    return result
|
|
73
|
+
|
|
74
|
+
if __name__ == "__main__":
    # CLI usage: <output_dir> <source parameter-map file> <destination transform file>
    directory, source_name, destination_name = sys.argv[1], sys.argv[2], sys.argv[3]
    converted = parameterMap_to_transform("{}/{}".format(directory, source_name))
    sitk.WriteTransform(converted, "{}/{}".format(directory, destination_name))
|
|
80
|
+
|
|
81
|
+
def getFlatLabel(mask: sitk.Image, labels: list[int]) -> sitk.Image:
    """Merge every label value in ``labels`` into a single binary (0/1) mask image."""
    source = sitk.GetArrayFromImage(mask)
    merged = np.zeros_like(source, np.uint8)
    # Union of all requested labels: each matching voxel becomes foreground (1).
    for value in labels:
        merged[source == value] = 1
    flat = sitk.GetImageFromArray(merged)
    # Preserve origin / spacing / direction of the input mask.
    flat.CopyInformation(mask)
    return flat
|
|
91
|
+
|
|
92
|
+
def rampFilterHistogram(image: sitk.Image, rampStart: float, rampEnd: float) -> sitk.Image:
    """Attenuate voxel values strictly inside (rampStart, rampEnd) with a linear ramp.

    Values in the open interval are multiplied by a weight that rises from 0 at
    ``rampStart`` to 1 at ``rampEnd``; values outside the interval pass through
    unchanged.
    """
    values = sitk.GetArrayFromImage(image)
    inside = (values > rampStart) & (values < rampEnd)
    width = rampEnd - rampStart
    # weight = (x - rampStart) / width, applied to the value itself.
    values[inside] = (1/width) * (values[inside] - rampStart) * values[inside]
    ramped = sitk.GetImageFromArray(values)
    ramped.CopyInformation(image)
    return ramped
|
|
100
|
+
|
|
101
|
+
def elastic_registration( fixed_image : sitk.Image,
                moving_image : sitk.Image,
                fixed_mask : Union[sitk.Image, None],
                moving_mask : Union[sitk.Image, None],
                name_parameterMap : str,
                outputDir: str) -> sitk.Transform:
    """Multi-channel elastix registration driven by gradient + intensity images.

    The images are registered with three stacked channels (gradient magnitude,
    then the image twice) so the parameter map can weight edge information
    against raw intensity. When masks are given, every non-background label is
    flattened to one binary mask, the images are masked, and dilated masks are
    attached for the intensity channels.

    :param fixed_image: reference image.
    :param moving_image: image to be deformed onto ``fixed_image``.
    :param fixed_mask: optional label image for the fixed image (``None`` disables masking).
    :param moving_mask: optional label image for the moving image.
    :param name_parameterMap: parameter-map file name without the ``.txt`` suffix.
    :param outputDir: elastix working directory.
    :return: transform(s) parsed from the elastix output via ``parameterMap_to_transform``.
    """
    fixed_mask_dillated = None
    moving_mask_dillated = None
    # BUGFIX: the mask pipeline below used to run unconditionally and crashed
    # with None masks, even though the signature and the later `is not None`
    # guards clearly intend the masks to be optional.
    if fixed_mask is not None and moving_mask is not None:
        # Collapse all non-background labels (labels[0] assumed background)
        # into a single binary foreground mask for both images.
        labels = np.unique(sitk.GetArrayFromImage(fixed_mask))
        fixed_mask = getFlatLabel(fixed_mask, labels[1:])
        moving_mask = getFlatLabel(moving_mask, labels[1:])

        fixed_mask.CopyInformation(fixed_image)
        moving_mask.CopyInformation(moving_image)

        # Slightly enlarged masks used as region-of-interest for the intensity channels.
        fixed_mask_dillated = sitk.BinaryDilate(fixed_mask, [5,5,5])
        moving_mask_dillated = sitk.BinaryDilate(moving_mask, [5,5,5])

        # Zero out everything outside the masks before computing gradients.
        fixed_image = sitk.Mask(fixed_image, fixed_mask)
        moving_image = sitk.Mask(moving_image, moving_mask)

    # Edge channel: gradient magnitude, ramp-weighted below this threshold.
    minGradientMagnitude = 50
    fixed_image_gradient = rampFilterHistogram(sitk.VectorMagnitude(sitk.Gradient(fixed_image)), 0, minGradientMagnitude)
    moving_image_gradient = rampFilterHistogram(sitk.VectorMagnitude(sitk.Gradient(moving_image)), 0, minGradientMagnitude)

    elastixImageFilter = sitk.ElastixImageFilter()
    # Channel 0: gradient magnitude; channels 1-2: the (masked) image itself.
    elastixImageFilter.SetFixedImage(fixed_image_gradient)
    elastixImageFilter.AddFixedImage(fixed_image)
    elastixImageFilter.AddFixedImage(fixed_image)

    if fixed_mask is not None:
        elastixImageFilter.SetFixedMask(fixed_mask)
        elastixImageFilter.AddFixedMask(fixed_mask_dillated)
        elastixImageFilter.AddFixedMask(fixed_mask_dillated)

    elastixImageFilter.SetMovingImage(moving_image_gradient)
    elastixImageFilter.AddMovingImage(moving_image)
    elastixImageFilter.AddMovingImage(moving_image)

    if moving_mask is not None:
        elastixImageFilter.SetMovingMask(moving_mask)
        elastixImageFilter.AddMovingMask(moving_mask_dillated)
        elastixImageFilter.AddMovingMask(moving_mask_dillated)

    elastixImageFilter.SetParameterMap(sitk.ReadParameterFile("{}.txt".format(name_parameterMap)))
    elastixImageFilter.LogToConsoleOn()
    elastixImageFilter.SetOutputDirectory(outputDir)

    elastixImageFilter.Execute()

    # NOTE(review): assumes outputDir ends with a path separator and that elastix
    # wrote a file literally named "TransformParameters" — confirm the naming
    # (elastix typically writes "TransformParameters.0.txt").
    transform = parameterMap_to_transform("{}TransformParameters".format(outputDir))

    return transform
|
|
154
|
+
|
|
155
|
+
def registration( fixed_image : sitk.Image,
            moving_image : sitk.Image,
            fixed_mask : Union[sitk.Image, None],
            moving_mask : Union[sitk.Image, None],
            name_parameterMap : str,
            outputDir: str) -> sitk.Transform:
    """Run a plain single-channel elastix registration and return the transform(s).

    Masks are optional; when ``None`` the registration runs unmasked.
    """
    elastix = sitk.ElastixImageFilter()

    elastix.SetFixedImage(fixed_image)
    if fixed_mask is not None:
        elastix.SetFixedMask(fixed_mask)

    elastix.SetMovingImage(moving_image)
    if moving_mask is not None:
        elastix.SetMovingMask(moving_mask)

    # Parameter maps are stored as "<name>.txt".
    elastix.SetParameterMap(sitk.ReadParameterFile("{}.txt".format(name_parameterMap)))
    elastix.LogToConsoleOn()
    elastix.SetOutputDirectory(outputDir)
    elastix.Execute()

    # NOTE(review): assumes outputDir ends with a path separator and that the
    # elastix output is literally named "TransformParameters" — confirm.
    return parameterMap_to_transform("{}TransformParameters".format(outputDir))
|
|
178
|
+
|
|
179
|
+
def registration_groupewise(images_1: sitk.Image, masks: sitk.Image, images_2: sitk.Image, name_parameterMap : str, output_dir: str):
    """Groupwise elastix registration of an image stack.

    Groupwise registration uses the same stack as fixed and moving image; an
    optional second stack adds a second metric channel.
    """
    elastix = sitk.ElastixImageFilter()
    elastix.SetFixedImage(images_1)
    elastix.SetMovingImage(images_1)
    elastix.SetFixedMask(masks)
    elastix.SetMovingMask(masks)

    if images_2 is not None:
        # Optional extra channel for a multi-metric registration.
        elastix.AddFixedImage(images_2)
        elastix.AddMovingImage(images_2)

    elastix.SetParameterMap(sitk.ReadParameterFile("{}.txt".format(name_parameterMap)))
    elastix.LogToConsoleOn()
    elastix.SetOutputDirectory(output_dir)
    elastix.LogToFileOn()
    elastix.Execute()

    # NOTE(review): assumes output_dir ends with a path separator — confirm.
    return parameterMap_to_transform("{}TransformParameters".format(output_dir))
|
konfai/utils/__init__.py
ADDED
|
File without changes
|
konfai/utils/config.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import ruamel.yaml
|
|
3
|
+
import inspect
|
|
4
|
+
import collections
|
|
5
|
+
from copy import deepcopy
|
|
6
|
+
from typing import Union
|
|
7
|
+
import torch
|
|
8
|
+
|
|
9
|
+
from KonfAI.konfai import CONFIG_FILE
|
|
10
|
+
|
|
11
|
+
yaml = ruamel.yaml.YAML()
|
|
12
|
+
|
|
13
|
+
class ConfigError(Exception):
    """Raised when a configuration entry uses a type the config system cannot handle."""

    def __init__(self, message : str = "The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]") -> None:
        # Keep the message on the instance so callers can inspect it directly.
        self.message = message
        super().__init__(message)
|
|
18
|
+
|
|
19
|
+
class Config():
    """Context manager over one dotted section of a YAML configuration file.

    ``with Config(file, "a.b.c") as c`` exposes the sub-dictionary at key path
    ``a.b.c`` through ``getValue``; on exit, any values that were filled in
    (from defaults or interactive input) are merged back into the file.
    Interactive behaviour is controlled by the ``DEEP_LEANING_API_CONFIG_MODE``
    environment variable (values: "default", "interactive", "remove").
    """

    def __init__(self, filename, key) -> None:
        # filename: path of the YAML configuration file.
        # key: dotted path (e.g. "Trainer.Model") of the section to expose.
        self.filename = filename
        self.keys = key.split(".")

    def __enter__(self):
        """Open the YAML file (creating it on demand) and descend to the section."""
        if not os.path.exists(self.filename):
            # No config file yet: offer to create one with defaults or
            # interactively; any other answer aborts the program.
            result = input("Create a new config file ? [no,yes,interactive] : ")
            if result in ["yes", "interactive"]:
                # NOTE(review): the env-var name spells "LEANING" (sic); the
                # same spelling is used consistently throughout this class.
                os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
            else:
                exit(0)
            # Create an empty file so the open() below succeeds.
            with open(self.filename, "w") as f:
                pass

        self.yml = open(self.filename, 'r')
        self.data = yaml.load(self.yml)
        if self.data == None:
            self.data = {}

        self.config = self.data

        # Walk down the dotted path, materialising an empty section for any
        # missing key (it is merged back into the file on __exit__).
        for key in self.keys:
            if self.config == None or key not in self.config:
                self.config = {key : {}}

            self.config = self.config[key]
        return self

    def createDictionary(self, data, keys, i) -> dict:
        """Re-wrap ``data`` under ``keys[i] .. keys[0]`` (inner to outer),
        rebuilding the nested dictionary matching the dotted key path."""
        if keys[i] not in data:
            data = {keys[i]: data}
        if i == 0:
            return data
        else:
            i -= 1
            return self.createDictionary(data, keys, i)

    def merge(self, dict1, dict2) -> dict:
        """Recursively merge ``dict2`` into a deep copy of ``dict1``.

        Nested mappings are merged key by key; scalar values from ``dict2``
        win, except None, which never overwrites an existing value."""
        result = deepcopy(dict1)

        for key, value in dict2.items():
            if isinstance(value, collections.abc.Mapping):
                result[key] = self.merge(result.get(key, {}), value)
            else:
                if not dict2[key] == None:
                    result[key] = deepcopy(dict2[key])
        return result


    def __exit__(self, type, value, traceback) -> None:
        """Close the file and persist the (possibly completed) section.

        In "remove" mode the whole config file is deleted instead (the user
        aborted an interactive session)."""
        self.yml.close()
        if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "remove":
            if os.path.exists(CONFIG_FILE()):
                os.remove(CONFIG_FILE())
            return
        # Re-read the file (it may have been changed by nested Config
        # instances) and merge this section back into it.
        with open(self.filename, 'r') as yml:
            data = yaml.load(yml)
            if data == None:
                data = {}
        with open(self.filename, 'w') as yml:
            yaml.dump(self.merge(data, self.createDictionary(self.config, self.keys, len(self.keys)-1)), yml)

    @staticmethod
    def _getInput(name : str, default : str) -> str:
        """Prompt for one value; an interrupted prompt (Ctrl-C / EOF) lets the
        user keep defaults or abort, switching the MODE env var accordingly."""
        try:
            return input("{} [{}]: ".format(name, ",".join(default.split(":")[1:]) if len(default.split(":")) else ""))
        except:
            # Input interrupted: fall back to a full default configuration,
            # or mark the partial file for removal and quit.
            result = input("\nKeep a default configuration file ? (yes,no) : ")
            if result == "yes":
                os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "default"
            else:
                os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "remove"
                exit(0)
            return default.split(":")[1] if len(default.split(":")) > 1 else default

    @staticmethod
    def _getInputDefault(name : str, default : Union[str, None], isList : bool = False) -> Union[list[Union[str, None]], str, None]:
        """Resolve a "default" / "default:<value>" placeholder.

        In interactive mode the user is prompted (repeatedly, until "!", when
        ``isList``); otherwise the literal value after "default:" is used.
        Non-placeholder values are returned unchanged (wrapped in a list when
        ``isList``)."""
        if isinstance(default, str) and (default == "default" or (len(default.split(":")) > 1 and default.split(":")[0] == "default")):
            if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
                if isList:
                    list_tmp = []
                    key_tmp = "OK"
                    # "!" terminates the list; an empty answer keeps the default.
                    while key_tmp != "!" and os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
                        key_tmp = Config._getInput(name, default)
                        if key_tmp != "!":
                            if key_tmp == "":
                                key_tmp = default.split(":")[1] if len(default.split(":")) > 1 else default
                            list_tmp.append(key_tmp)
                    return list_tmp
                else:
                    value = Config._getInput(name, default)
                    if value == "":
                        return default.split(":")[1] if len(default.split(":")) > 1 else default
                    else:
                        return value
            else:
                # Non-interactive: strip the "default:" prefix if present.
                default = default.split(":")[1] if len(default.split(":")) > 1 else default
        return [default] if isList else default

    def getValue(self, name, default) -> object:
        """Fetch ``name`` from the current section, falling back to ``default``
        (prompting in interactive mode), and record the resolved value so it is
        written back to the YAML file on exit.

        Lists and dicts are expanded element-wise through ``_getInputDefault``;
        the string "None" round-trips as Python ``None``."""
        if name in self.config and self.config[name] is not None:
            value = self.config[name]
            if value == None:
                value = default
            value_config = value
        else:
            value = Config._getInputDefault(name, default if default != inspect._empty else None)

            value_config = value
            # Tuple defaults (from function signatures) are stored as lists.
            if type(value_config) == tuple:
                value_config = list(value)

        if type(value_config) == list:
            # Expand "default" placeholders inside the list element-wise.
            list_tmp = []
            for key in value_config:
                list_tmp.extend(Config._getInputDefault(name, key, isList=True))

            value = list_tmp
            value_config = list_tmp

        if type(value) == dict:
            key_tmp = []

            value_config = {}
            dict_value = {}
            for key in value:
                key_tmp.extend(Config._getInputDefault(name, key, isList=True))
            for key in key_tmp:
                if key in value:
                    value_tmp = value[key]
                else:
                    # A renamed interactive key inherits the value of the
                    # entry whose original key contained "default".
                    value_tmp = next(v for k,v in value.items() if "default" in k)

                value_config[key] = None
                dict_value[key] = value_tmp
            value = dict_value
        if isinstance(self.config, str):
            # The section itself collapsed to a scalar string: signal the
            # ``config`` decorator that this argument is "variable" and
            # cannot be resolved here.
            os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "True"
            return None

        # Persist the resolved value (None is stored as the string "None").
        self.config[name] = value_config if value_config is not None else "None"
        if value == "None":
            value = None
        return value
|
|
165
|
+
|
|
166
|
+
def config(key : Union[str, None] = None):
    """Decorator factory: fill the decorated callable's keyword arguments from
    the YAML configuration file.

    ``key`` is appended to the caller-supplied ``DL_args`` dotted path to
    locate the config section. Reserved kwargs understood by the wrapper:

    * ``config``     -- YAML file path (falls back to / stored in the
                        ``DEEP_LEARNING_API_CONFIG_FILE`` env var);
    * ``DL_args``    -- dotted path prefix of the section;
    * ``DL_without`` -- parameter names to skip.

    Supported annotations: int/str/bool/float/torch.Tensor, lists/tuples and
    ``dict[str, T]`` of those, and config-decorated classes (instantiated
    recursively from their own sub-section). Anything else raises
    ``ConfigError``.
    """
    def decorator(function):
        def new_function(*args, **kwargs):
            # The wrapper only resolves arguments when invoked through the
            # config machinery (a "config" kwarg is always passed, maybe None).
            if "config" in kwargs:
                filename = kwargs["config"]
                if filename == None:
                    filename = os.environ['DEEP_LEARNING_API_CONFIG_FILE']
                else:
                    os.environ['DEEP_LEARNING_API_CONFIG_FILE'] = filename
                key_tmp = kwargs["DL_args"]+("."+key if key is not None else "") if "DL_args" in kwargs else key
                without = kwargs["DL_without"] if "DL_without" in kwargs else []
                os.environ['DEEP_LEARNING_API_CONFIG_PATH'] = key_tmp
                with Config(filename, key_tmp) as config:
                    os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
                    # Rebuild kwargs from the signature: positional args pass
                    # through; every remaining parameter is resolved from the
                    # config section.
                    kwargs = {}
                    for param in list(inspect.signature(function).parameters.values())[len(args):]:
                        annotation = param.annotation
                        # For Optional/Union annotations, use the first member type.
                        if str(annotation).startswith("typing.Union") or str(annotation).startswith("typing.Optional"):
                            for i in annotation.__args__:
                                annotation = i
                                break
                        if param.name in without:
                            continue
                        if not annotation == inspect._empty:
                            if annotation not in [int, str, bool, float, torch.Tensor]:
                                if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple"):
                                    if annotation.__args__[0] in [int, str, bool, float]:
                                        values = config.getValue(param.name, param.default)
                                        kwargs[param.name] = values
                                    else:
                                        raise ConfigError()
                                elif str(annotation).startswith("dict"):
                                    if annotation.__args__[0] == str:
                                        values = config.getValue(param.name, param.default)
                                        if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
                                            # dict of config objects: instantiate each
                                            # value class from its own sub-section.
                                            kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
                                        else:
                                            kwargs[param.name] = values
                                    else:
                                        raise ConfigError()
                                else:
                                    # Config-decorated class: build it recursively
                                    # from the current section.
                                    kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
                                    if os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] == "True":
                                        # getValue flagged a "variable" entry:
                                        # the argument stays None.
                                        os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
                                        kwargs[param.name] = None
                            else:
                                kwargs[param.name] = config.getValue(param.name, param.default)
                        elif param.name != "self":
                            kwargs[param.name] = config.getValue(param.name, param.default)
                    # NOTE(review): indentation reconstructed from a stripped
                    # rendering — the call may originally sit just after the
                    # `with` block; here the call runs before the config file
                    # is written back on __exit__.
                    result = function(*args, **kwargs)
                return result
        return new_function
    return decorator
|