konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

@@ -0,0 +1,199 @@
1
+ import SimpleITK as sitk
2
+ from typing import Union
3
+ import numpy as np
4
+ import sys
5
+ import scipy
6
+
7
def parameterMap_to_transform(path_src: str) -> Union[sitk.Transform, list[sitk.Transform]]:
    """Build SimpleITK transform(s) from an elastix transform-parameter file.

    The ``Transform`` entry of the parameter map selects the conversion:

    * ``EulerTransform``          -> ``Euler2DTransform`` (2-D) or ``Euler3DTransform``
    * ``TranslationTransform``    -> ``TranslationTransform``
    * ``AffineTransform``         -> ``AffineTransform``
    * ``BSplineStackTransform``   -> list of ``BSplineTransform`` sharing one grid
    * ``AffineLogStackTransform`` -> list of ``AffineTransform``; the linear part of
      every sub-transform is passed through ``scipy.linalg.expm``
    * anything else               -> a single ``BSplineTransform``

    Args:
        path_src: Path handed to ``sitk.ReadParameterFile``.

    Returns:
        One transform, or a list of transforms for the two stack types.
    """
    param_map = sitk.ReadParameterFile(path_src)
    ndim = int(param_map["FixedImageDimension"][0])
    kind = param_map["Transform"][0]

    def as_floats(entry: str) -> list[float]:
        # Parameter-map values are sequences of strings.
        return [float(token) for token in param_map[entry]]

    def bspline_grid(grid_size: list[float]) -> np.ndarray:
        # Fixed parameters of a BSplineTransform: grid size, origin, spacing and the
        # GridDirection matrix reshaped to ndim x ndim, transposed, then flattened.
        origin = as_floats("GridOrigin")
        spacing = as_floats("GridSpacing")
        direction = np.asarray(as_floats("GridDirection")).reshape((ndim, ndim)).T.flatten()
        return np.concatenate([grid_size, origin, spacing, direction])

    if kind == "BSplineStackTransform":
        coefficients = as_floats("TransformParameters")
        grid_size = as_floats("GridSize")
        grid = bspline_grid(grid_size)
        # Number of sub-transforms = last component of the image "Size" entry.
        n_sub = int(as_floats("Size")[-1])
        chunk = int(np.prod(grid_size)) * ndim
        stack = []
        for idx in range(n_sub):
            sub_transform = sitk.BSplineTransform(ndim)
            sub_transform.SetFixedParameters(grid)
            sub_transform.SetParameters(np.asarray(coefficients[idx * chunk:(idx + 1) * chunk]))
            stack.append(sub_transform)
        return stack

    if kind == "AffineLogStackTransform":
        coefficients = as_floats("TransformParameters")
        center = as_floats("CenterOfRotationPoint") + [0]
        n_sub = int(param_map["NumberOfSubTransforms"][0])
        # NOTE(review): ndim * 4 equals ndim * ndim + ndim only when ndim == 3 — confirm
        # the intended per-sub-transform parameter count for other dimensions.
        chunk = ndim * 4
        stack = []
        for idx in range(n_sub):
            sub_params = np.asarray(coefficients[idx * chunk:(idx + 1) * chunk])
            log_matrix = sub_params[:ndim * ndim].reshape((ndim, ndim))
            sub_transform = sitk.AffineTransform(ndim)
            sub_transform.SetFixedParameters(center)
            sub_transform.SetParameters(
                np.concatenate([scipy.linalg.expm(log_matrix).flatten(), sub_params[-ndim:]])
            )
            stack.append(sub_transform)
        return stack

    if kind == "EulerTransform":
        single = sitk.Euler2DTransform() if ndim == 2 else sitk.Euler3DTransform()
        parameters = as_floats("TransformParameters")
        fixed_parameters = as_floats("CenterOfRotationPoint") + [0]
    elif kind == "TranslationTransform":
        single = sitk.TranslationTransform(ndim)
        parameters = as_floats("TransformParameters")
        fixed_parameters = []
    elif kind == "AffineTransform":
        single = sitk.AffineTransform(ndim)
        parameters = as_floats("TransformParameters")
        fixed_parameters = as_floats("CenterOfRotationPoint") + [0]
    else:
        # Fallback: any other transform name is read as a plain B-spline.
        single = sitk.BSplineTransform(ndim)
        parameters = as_floats("TransformParameters")
        fixed_parameters = bspline_grid(as_floats("GridSize"))

    single.SetFixedParameters(fixed_parameters)
    single.SetParameters(parameters)
    return single
73
+
74
if __name__ == "__main__":
    # Usage: <script> DIRECTORY SOURCE_PARAMETER_FILE DESTINATION_TRANSFORM_FILE
    # Converts DIRECTORY/SOURCE_PARAMETER_FILE (elastix parameter map) into a
    # SimpleITK transform file written to DIRECTORY/DESTINATION_TRANSFORM_FILE.
    # NOTE(review): for stack transforms parameterMap_to_transform returns a list,
    # which sitk.WriteTransform presumably does not accept — confirm.
    directory = sys.argv[1]
    source_name = sys.argv[2]
    destination_name = sys.argv[3]
    converted = parameterMap_to_transform(f"{directory}/{source_name}")
    sitk.WriteTransform(converted, f"{directory}/{destination_name}")
80
+
81
def getFlatLabel(mask: sitk.Image, labels: list[int]) -> sitk.Image:
    """Collapse selected labels of a label image into a single binary mask.

    Args:
        mask: Label image.
        labels: Label values to keep (any iterable of values accepted by ``np.isin``).

    Returns:
        A ``uint8`` image equal to 1 where *mask* holds one of *labels* and 0
        elsewhere, with the geometry of *mask*.
    """
    label_array = sitk.GetArrayFromImage(mask)
    binary_array = np.isin(label_array, labels).astype(np.uint8)
    flat_mask = sitk.GetImageFromArray(binary_array)
    flat_mask.CopyInformation(mask)
    return flat_mask
91
+
92
def rampFilterHistogram(image: sitk.Image, rampStart: float, rampEnd: float) -> sitk.Image:
    """Scale intensities inside the open interval (rampStart, rampEnd) by a linear ramp.

    Every voxel value ``v`` with ``rampStart < v < rampEnd`` becomes
    ``(v - rampStart) / (rampEnd - rampStart) * v``. Values outside the interval
    are left untouched, and the output copies the geometry of *image*.

    Args:
        image: Input image.
        rampStart: Lower (exclusive) bound of the ramp.
        rampEnd: Upper (exclusive) bound of the ramp.

    Returns:
        A new image with the ramp applied. For integer pixel types the scaled
        values are truncated when written back into the array.
    """
    image_data = sitk.GetArrayFromImage(image)
    in_ramp = np.logical_and(image_data > rampStart, image_data < rampEnd)
    # No voxel lies strictly inside an empty interval (rampStart >= rampEnd), so
    # there is nothing to scale. Skipping also avoids the ZeroDivisionError that
    # 1 / (rampEnd - rampStart) raised when rampStart == rampEnd.
    if in_ramp.any():
        ramp_width = rampEnd - rampStart
        values = image_data[in_ramp]
        image_data[in_ramp] = (1 / ramp_width) * (values - rampStart) * values
    filtered_image = sitk.GetImageFromArray(image_data)
    filtered_image.CopyInformation(image)
    return filtered_image
100
+
101
def _mask_for_registration(image: sitk.Image, mask: Union[sitk.Image, None], labels) -> tuple:
    """Binarize *mask* to *labels*, dilate it and apply it to *image*.

    Returns ``(masked_image, flat_mask, dilated_mask)``. When *mask* is None the
    image is returned unchanged together with two None masks.
    """
    if mask is None:
        return image, None, None
    flat_mask = getFlatLabel(mask, labels)
    # Force the mask onto the image's geometry (origin, spacing, direction).
    flat_mask.CopyInformation(image)
    # Radius 5 along every axis; equals the former hard-coded [5, 5, 5] in 3-D.
    dilated_mask = sitk.BinaryDilate(flat_mask, [5] * flat_mask.GetDimension())
    return sitk.Mask(image, flat_mask), flat_mask, dilated_mask


def elastic_registration( fixed_image : sitk.Image,
                          moving_image : sitk.Image,
                          fixed_mask : Union[sitk.Image, None],
                          moving_mask : Union[sitk.Image, None],
                          name_parameterMap : str,
                          outputDir: str) -> Union[sitk.Transform, list[sitk.Transform]]:
    """Multi-channel elastix registration driven by masked images and their gradients.

    When masks are given, each is reduced to a binary mask of the selected labels
    and applied to its image. The elastix input channels are then:

    1. the ramp-filtered gradient magnitude of the (masked) image, with the
       binary mask;
    2. and 3. the (masked) image itself, twice, each with the mask dilated by a
       radius of 5.

    Args:
        fixed_image: Fixed image.
        moving_image: Moving image.
        fixed_mask: Optional label image for the fixed image.
        moving_mask: Optional label image for the moving image.
        name_parameterMap: Parameter file path without the ``.txt`` suffix.
        outputDir: elastix output directory, used verbatim as a string prefix
            (include the trailing path separator).

    Returns:
        Whatever ``parameterMap_to_transform`` builds from
        ``"{outputDir}TransformParameters"`` — a transform or a list of them.
        NOTE(review): confirm this file name matches what elastix writes to outputDir.
    """
    # Labels are taken from the fixed mask, or from the moving mask when no fixed
    # mask is given. The smallest label, assumed to be background, is dropped.
    # Previously a None fixed mask crashed here even though it is typed Optional.
    label_source = fixed_mask if fixed_mask is not None else moving_mask
    labels = np.unique(sitk.GetArrayFromImage(label_source))[1:] if label_source is not None else []

    fixed_image, fixed_mask, fixed_mask_dilated = _mask_for_registration(fixed_image, fixed_mask, labels)
    moving_image, moving_mask, moving_mask_dilated = _mask_for_registration(moving_image, moving_mask, labels)

    # Gradient magnitudes in (0, minGradientMagnitude) are ramp-attenuated.
    minGradientMagnitude = 50
    fixed_image_gradient = rampFilterHistogram(sitk.VectorMagnitude(sitk.Gradient(fixed_image)), 0, minGradientMagnitude)
    moving_image_gradient = rampFilterHistogram(sitk.VectorMagnitude(sitk.Gradient(moving_image)), 0, minGradientMagnitude)

    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(fixed_image_gradient)
    # The intensity image is intentionally added twice; presumably the parameter
    # map defines one metric per channel — confirm against the parameter file.
    elastixImageFilter.AddFixedImage(fixed_image)
    elastixImageFilter.AddFixedImage(fixed_image)
    if fixed_mask is not None:
        elastixImageFilter.SetFixedMask(fixed_mask)
        elastixImageFilter.AddFixedMask(fixed_mask_dilated)
        elastixImageFilter.AddFixedMask(fixed_mask_dilated)

    elastixImageFilter.SetMovingImage(moving_image_gradient)
    elastixImageFilter.AddMovingImage(moving_image)
    elastixImageFilter.AddMovingImage(moving_image)
    if moving_mask is not None:
        elastixImageFilter.SetMovingMask(moving_mask)
        elastixImageFilter.AddMovingMask(moving_mask_dilated)
        elastixImageFilter.AddMovingMask(moving_mask_dilated)

    elastixImageFilter.SetParameterMap(sitk.ReadParameterFile("{}.txt".format(name_parameterMap)))
    elastixImageFilter.LogToConsoleOn()
    elastixImageFilter.SetOutputDirectory(outputDir)
    elastixImageFilter.Execute()

    return parameterMap_to_transform("{}TransformParameters".format(outputDir))
154
+
155
def registration( fixed_image : sitk.Image,
                  moving_image : sitk.Image,
                  fixed_mask : Union[sitk.Image, None],
                  moving_mask : Union[sitk.Image, None],
                  name_parameterMap : str,
                  outputDir: str) -> sitk.Transform:
    """Run a single-channel elastix registration and load the resulting transform.

    Args:
        fixed_image: Fixed image.
        moving_image: Moving image.
        fixed_mask: Optional fixed mask (skipped when None).
        moving_mask: Optional moving mask (skipped when None).
        name_parameterMap: Parameter file path without the ``.txt`` suffix.
        outputDir: elastix output directory, used verbatim as a string prefix.

    Returns:
        The result of ``parameterMap_to_transform`` on ``"{outputDir}TransformParameters"``.
    """
    elastix_filter = sitk.ElastixImageFilter()

    # Fixed side first, then moving side — each image followed by its optional mask.
    for set_image, set_mask, image, mask in (
        (elastix_filter.SetFixedImage, elastix_filter.SetFixedMask, fixed_image, fixed_mask),
        (elastix_filter.SetMovingImage, elastix_filter.SetMovingMask, moving_image, moving_mask),
    ):
        set_image(image)
        if mask is not None:
            set_mask(mask)

    parameter_file = f"{name_parameterMap}.txt"
    elastix_filter.SetParameterMap(sitk.ReadParameterFile(parameter_file))
    elastix_filter.LogToConsoleOn()
    elastix_filter.SetOutputDirectory(outputDir)
    elastix_filter.Execute()

    return parameterMap_to_transform(f"{outputDir}TransformParameters")
178
+
179
def registration_groupewise(images_1: sitk.Image, masks: sitk.Image, images_2: Union[sitk.Image, None], name_parameterMap : str, output_dir: str) -> Union[sitk.Transform, list[sitk.Transform]]:
    """Run an elastix registration with *images_1* as both fixed and moving input.

    Using the same image (and mask) on both sides presumably serves a groupwise
    (stack) metric defined in the parameter map — confirm against the parameter file.

    Args:
        images_1: Image used as both fixed and moving image.
        masks: Mask used as both fixed and moving mask.
        images_2: Optional second channel, added as both fixed and moving image.
            It is skipped when None (the annotation now reflects that).
        name_parameterMap: Parameter file path without the ``.txt`` suffix.
        output_dir: elastix output directory, used verbatim as a string prefix.

    Returns:
        The result of ``parameterMap_to_transform`` on ``"{output_dir}TransformParameters"``.
    """
    elastixImageFilter = sitk.ElastixImageFilter()
    elastixImageFilter.SetFixedImage(images_1)
    elastixImageFilter.SetMovingImage(images_1)
    elastixImageFilter.SetFixedMask(masks)
    elastixImageFilter.SetMovingMask(masks)

    if images_2 is not None:
        elastixImageFilter.AddFixedImage(images_2)
        elastixImageFilter.AddMovingImage(images_2)

    elastixImageFilter.SetParameterMap(sitk.ReadParameterFile("{}.txt".format(name_parameterMap)))
    elastixImageFilter.LogToConsoleOn()
    elastixImageFilter.SetOutputDirectory(output_dir)
    elastixImageFilter.LogToFileOn()
    elastixImageFilter.Execute()

    transforms = parameterMap_to_transform("{}TransformParameters".format(output_dir))
    return transforms
File without changes
konfai/utils/config.py ADDED
@@ -0,0 +1,218 @@
1
+ import os
2
+ import ruamel.yaml
3
+ import inspect
4
+ import collections
5
+ from copy import deepcopy
6
+ from typing import Union
7
+ import torch
8
+
9
+ from KonfAI.konfai import CONFIG_FILE
10
+
11
# Module-level ruamel.yaml instance (default constructor settings) shared by
# Config.__enter__ / Config.__exit__ to load and dump the config file.
yaml = ruamel.yaml.YAML()
12
+
13
class ConfigError(Exception):
    """Error raised when a parameter annotation is not supported by the config system."""

    def __init__(self, message : str = "The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]") -> None:
        super().__init__(message)
        # Exposed as an attribute so callers can read the text without str(exc).
        self.message = message
18
+
19
+ class Config():
20
+
21
+ def __init__(self, filename, key) -> None:
22
+ self.filename = filename
23
+ self.keys = key.split(".")
24
+
25
+ def __enter__(self):
26
+ if not os.path.exists(self.filename):
27
+ result = input("Create a new config file ? [no,yes,interactive] : ")
28
+ if result in ["yes", "interactive"]:
29
+ os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
30
+ else:
31
+ exit(0)
32
+ with open(self.filename, "w") as f:
33
+ pass
34
+
35
+ self.yml = open(self.filename, 'r')
36
+ self.data = yaml.load(self.yml)
37
+ if self.data == None:
38
+ self.data = {}
39
+
40
+ self.config = self.data
41
+
42
+ for key in self.keys:
43
+ if self.config == None or key not in self.config:
44
+ self.config = {key : {}}
45
+
46
+ self.config = self.config[key]
47
+ return self
48
+
49
+ def createDictionary(self, data, keys, i) -> dict:
50
+ if keys[i] not in data:
51
+ data = {keys[i]: data}
52
+ if i == 0:
53
+ return data
54
+ else:
55
+ i -= 1
56
+ return self.createDictionary(data, keys, i)
57
+
58
+ def merge(self, dict1, dict2) -> dict:
59
+ result = deepcopy(dict1)
60
+
61
+ for key, value in dict2.items():
62
+ if isinstance(value, collections.abc.Mapping):
63
+ result[key] = self.merge(result.get(key, {}), value)
64
+ else:
65
+ if not dict2[key] == None:
66
+ result[key] = deepcopy(dict2[key])
67
+ return result
68
+
69
+
70
+ def __exit__(self, type, value, traceback) -> None:
71
+ self.yml.close()
72
+ if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "remove":
73
+ if os.path.exists(CONFIG_FILE()):
74
+ os.remove(CONFIG_FILE())
75
+ return
76
+ with open(self.filename, 'r') as yml:
77
+ data = yaml.load(yml)
78
+ if data == None:
79
+ data = {}
80
+ with open(self.filename, 'w') as yml:
81
+ yaml.dump(self.merge(data, self.createDictionary(self.config, self.keys, len(self.keys)-1)), yml)
82
+
83
+ @staticmethod
84
+ def _getInput(name : str, default : str) -> str:
85
+ try:
86
+ return input("{} [{}]: ".format(name, ",".join(default.split(":")[1:]) if len(default.split(":")) else ""))
87
+ except:
88
+ result = input("\nKeep a default configuration file ? (yes,no) : ")
89
+ if result == "yes":
90
+ os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "default"
91
+ else:
92
+ os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "remove"
93
+ exit(0)
94
+ return default.split(":")[1] if len(default.split(":")) > 1 else default
95
+
96
+ @staticmethod
97
+ def _getInputDefault(name : str, default : Union[str, None], isList : bool = False) -> Union[list[Union[str, None]], str, None]:
98
+ if isinstance(default, str) and (default == "default" or (len(default.split(":")) > 1 and default.split(":")[0] == "default")):
99
+ if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
100
+ if isList:
101
+ list_tmp = []
102
+ key_tmp = "OK"
103
+ while key_tmp != "!" and os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
104
+ key_tmp = Config._getInput(name, default)
105
+ if key_tmp != "!":
106
+ if key_tmp == "":
107
+ key_tmp = default.split(":")[1] if len(default.split(":")) > 1 else default
108
+ list_tmp.append(key_tmp)
109
+ return list_tmp
110
+ else:
111
+ value = Config._getInput(name, default)
112
+ if value == "":
113
+ return default.split(":")[1] if len(default.split(":")) > 1 else default
114
+ else:
115
+ return value
116
+ else:
117
+ default = default.split(":")[1] if len(default.split(":")) > 1 else default
118
+ return [default] if isList else default
119
+
120
+ def getValue(self, name, default) -> object:
121
+ if name in self.config and self.config[name] is not None:
122
+ value = self.config[name]
123
+ if value == None:
124
+ value = default
125
+ value_config = value
126
+ else:
127
+ value = Config._getInputDefault(name, default if default != inspect._empty else None)
128
+
129
+ value_config = value
130
+ if type(value_config) == tuple:
131
+ value_config = list(value)
132
+
133
+ if type(value_config) == list:
134
+ list_tmp = []
135
+ for key in value_config:
136
+ list_tmp.extend(Config._getInputDefault(name, key, isList=True))
137
+
138
+ value = list_tmp
139
+ value_config = list_tmp
140
+
141
+ if type(value) == dict:
142
+ key_tmp = []
143
+
144
+ value_config = {}
145
+ dict_value = {}
146
+ for key in value:
147
+ key_tmp.extend(Config._getInputDefault(name, key, isList=True))
148
+ for key in key_tmp:
149
+ if key in value:
150
+ value_tmp = value[key]
151
+ else:
152
+ value_tmp = next(v for k,v in value.items() if "default" in k)
153
+
154
+ value_config[key] = None
155
+ dict_value[key] = value_tmp
156
+ value = dict_value
157
+ if isinstance(self.config, str):
158
+ os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "True"
159
+ return None
160
+
161
+ self.config[name] = value_config if value_config is not None else "None"
162
+ if value == "None":
163
+ value = None
164
+ return value
165
+
166
def config(key : Union[str, None] = None):
    """Decorator factory that fills a callable's arguments from the YAML config file.

    A decorated call that passes ``config=<file name or None>`` is resolved from the
    file. ``None`` reuses the file name stored in DEEP_LEARNING_API_CONFIG_FILE by
    an earlier call. Optional keywords:

    * ``DL_args``: dotted path of the parent section; ``key`` is appended to it.
    * ``DL_without``: parameter names to skip (they are not passed at all).

    Positional arguments are passed through. Every remaining parameter is rebuilt
    from the config section:

    * primitives and unannotated parameters go through ``Config.getValue``;
    * ``list``/``tuple`` of primitives are read as lists;
    * ``dict[str, T]`` builds one ``T`` per key when ``T`` is not a primitive;
    * any other annotation is instantiated recursively as a nested config object.

    Raises:
        ConfigError: for list/tuple of non-primitives or dicts with non-str keys.

    Calls without a ``config`` keyword call the function unchanged.
    NOTE(review): if both ``key`` and ``DL_args`` are absent, ``key_tmp`` is None and
    the os.environ assignment below raises TypeError.
    """
    def decorator(function):
        def new_function(*args, **kwargs):
            if "config" in kwargs:
                filename = kwargs["config"]
                if filename == None:
                    # Inherit the file name recorded by an enclosing decorated call.
                    filename = os.environ['DEEP_LEARNING_API_CONFIG_FILE']
                else:
                    os.environ['DEEP_LEARNING_API_CONFIG_FILE'] = filename
                # Dotted path of this section: parent path (DL_args) + this decorator's key.
                key_tmp = kwargs["DL_args"]+("."+key if key is not None else "") if "DL_args" in kwargs else key
                without = kwargs["DL_without"] if "DL_without" in kwargs else []
                os.environ['DEEP_LEARNING_API_CONFIG_PATH'] = key_tmp
                with Config(filename, key_tmp) as config:
                    # Reset the flag a nested Config.getValue sets when its section is a plain string.
                    os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
                    # Caller keyword arguments (config, DL_args, ...) are discarded; every
                    # parameter not supplied positionally is rebuilt from the config below.
                    kwargs = {}
                    for param in list(inspect.signature(function).parameters.values())[len(args):]:
                        annotation = param.annotation
                        # Union/Optional: only the first member type is considered.
                        if str(annotation).startswith("typing.Union") or str(annotation).startswith("typing.Optional"):
                            for i in annotation.__args__:
                                annotation = i
                                break
                        if param.name in without:
                            continue
                        if not annotation == inspect._empty:
                            if annotation not in [int, str, bool, float, torch.Tensor]:
                                if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple"):
                                    if annotation.__args__[0] in [int, str, bool, float]:
                                        values = config.getValue(param.name, param.default)
                                        kwargs[param.name] = values
                                    else:
                                        raise ConfigError()
                                elif str(annotation).startswith("dict"):
                                    if annotation.__args__[0] == str:
                                        values = config.getValue(param.name, param.default)
                                        # Non-primitive values: build one object per key, each in
                                        # its own sub-section "<path>.<param>.<key>".
                                        if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
                                            kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
                                        else:
                                            kwargs[param.name] = values
                                    else:
                                        raise ConfigError()
                                else:
                                    # Any other class: instantiate it as a nested config object.
                                    kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
                                    # The nested section turned out to be a plain string
                                    # (flag set by Config.getValue): pass None instead.
                                    if os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] == "True":
                                        os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
                                        kwargs[param.name] = None
                            else:
                                kwargs[param.name] = config.getValue(param.name, param.default)
                        elif param.name != "self":
                            kwargs[param.name] = config.getValue(param.name, param.default)
            # Runs after the Config context has exited (section already written back).
            result = function(*args, **kwargs)
            return result
        return new_function
    return decorator