mosamatic2 2.0.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (136)
  1. models.py +259 -0
  2. mosamatic2/__init__.py +0 -0
  3. mosamatic2/app.py +32 -0
  4. mosamatic2/cli.py +50 -0
  5. mosamatic2/commands/__init__.py +0 -0
  6. mosamatic2/commands/boadockerpipeline.py +48 -0
  7. mosamatic2/commands/calculatemaskstatistics.py +59 -0
  8. mosamatic2/commands/calculatescores.py +73 -0
  9. mosamatic2/commands/createdicomsummary.py +61 -0
  10. mosamatic2/commands/createpngsfromsegmentations.py +65 -0
  11. mosamatic2/commands/defaultdockerpipeline.py +84 -0
  12. mosamatic2/commands/defaultpipeline.py +70 -0
  13. mosamatic2/commands/dicom2nifti.py +55 -0
  14. mosamatic2/commands/liveranalysispipeline.py +61 -0
  15. mosamatic2/commands/rescaledicomimages.py +54 -0
  16. mosamatic2/commands/segmentmusclefatl3tensorflow.py +55 -0
  17. mosamatic2/commands/selectslicefromscans.py +66 -0
  18. mosamatic2/commands/totalsegmentator.py +77 -0
  19. mosamatic2/constants.py +27 -0
  20. mosamatic2/core/__init__.py +0 -0
  21. mosamatic2/core/data/__init__.py +5 -0
  22. mosamatic2/core/data/dicomimage.py +27 -0
  23. mosamatic2/core/data/dicomimageseries.py +26 -0
  24. mosamatic2/core/data/dixonseries.py +22 -0
  25. mosamatic2/core/data/filedata.py +26 -0
  26. mosamatic2/core/data/multidicomimage.py +30 -0
  27. mosamatic2/core/data/multiniftiimage.py +26 -0
  28. mosamatic2/core/data/multinumpyimage.py +26 -0
  29. mosamatic2/core/data/niftiimage.py +13 -0
  30. mosamatic2/core/data/numpyimage.py +13 -0
  31. mosamatic2/core/managers/__init__.py +0 -0
  32. mosamatic2/core/managers/logmanager.py +45 -0
  33. mosamatic2/core/managers/logmanagerlistener.py +3 -0
  34. mosamatic2/core/pipelines/__init__.py +4 -0
  35. mosamatic2/core/pipelines/boadockerpipeline/__init__.py +0 -0
  36. mosamatic2/core/pipelines/boadockerpipeline/boadockerpipeline.py +70 -0
  37. mosamatic2/core/pipelines/defaultdockerpipeline/__init__.py +0 -0
  38. mosamatic2/core/pipelines/defaultdockerpipeline/defaultdockerpipeline.py +28 -0
  39. mosamatic2/core/pipelines/defaultpipeline/__init__.py +0 -0
  40. mosamatic2/core/pipelines/defaultpipeline/defaultpipeline.py +90 -0
  41. mosamatic2/core/pipelines/liveranalysispipeline/__init__.py +0 -0
  42. mosamatic2/core/pipelines/liveranalysispipeline/liveranalysispipeline.py +48 -0
  43. mosamatic2/core/pipelines/pipeline.py +14 -0
  44. mosamatic2/core/singleton.py +9 -0
  45. mosamatic2/core/tasks/__init__.py +13 -0
  46. mosamatic2/core/tasks/applythresholdtosegmentationstask/__init__.py +0 -0
  47. mosamatic2/core/tasks/applythresholdtosegmentationstask/applythresholdtosegmentationstask.py +117 -0
  48. mosamatic2/core/tasks/calculatemaskstatisticstask/__init__.py +0 -0
  49. mosamatic2/core/tasks/calculatemaskstatisticstask/calculatemaskstatisticstask.py +104 -0
  50. mosamatic2/core/tasks/calculatescorestask/__init__.py +0 -0
  51. mosamatic2/core/tasks/calculatescorestask/calculatescorestask.py +152 -0
  52. mosamatic2/core/tasks/createdicomsummarytask/__init__.py +0 -0
  53. mosamatic2/core/tasks/createdicomsummarytask/createdicomsummarytask.py +88 -0
  54. mosamatic2/core/tasks/createpngsfromsegmentationstask/__init__.py +0 -0
  55. mosamatic2/core/tasks/createpngsfromsegmentationstask/createpngsfromsegmentationstask.py +101 -0
  56. mosamatic2/core/tasks/dicom2niftitask/__init__.py +0 -0
  57. mosamatic2/core/tasks/dicom2niftitask/dicom2niftitask.py +45 -0
  58. mosamatic2/core/tasks/rescaledicomimagestask/__init__.py +0 -0
  59. mosamatic2/core/tasks/rescaledicomimagestask/rescaledicomimagestask.py +64 -0
  60. mosamatic2/core/tasks/segmentationnifti2numpytask/__init__.py +0 -0
  61. mosamatic2/core/tasks/segmentationnifti2numpytask/segmentationnifti2numpytask.py +57 -0
  62. mosamatic2/core/tasks/segmentationnumpy2niftitask/__init__.py +0 -0
  63. mosamatic2/core/tasks/segmentationnumpy2niftitask/segmentationnumpy2niftitask.py +86 -0
  64. mosamatic2/core/tasks/segmentmusclefatl3tensorflowtask/__init__.py +0 -0
  65. mosamatic2/core/tasks/segmentmusclefatl3tensorflowtask/paramloader.py +39 -0
  66. mosamatic2/core/tasks/segmentmusclefatl3tensorflowtask/segmentmusclefatl3tensorflowtask.py +122 -0
  67. mosamatic2/core/tasks/segmentmusclefatt4pytorchtask/__init__.py +0 -0
  68. mosamatic2/core/tasks/segmentmusclefatt4pytorchtask/paramloader.py +39 -0
  69. mosamatic2/core/tasks/segmentmusclefatt4pytorchtask/segmentmusclefatt4pytorchtask.py +128 -0
  70. mosamatic2/core/tasks/selectslicefromscanstask/__init__.py +0 -0
  71. mosamatic2/core/tasks/selectslicefromscanstask/selectslicefromscanstask.py +249 -0
  72. mosamatic2/core/tasks/task.py +50 -0
  73. mosamatic2/core/tasks/totalsegmentatortask/__init__.py +0 -0
  74. mosamatic2/core/tasks/totalsegmentatortask/totalsegmentatortask.py +75 -0
  75. mosamatic2/core/utils.py +405 -0
  76. mosamatic2/server.py +146 -0
  77. mosamatic2/ui/__init__.py +0 -0
  78. mosamatic2/ui/mainwindow.py +426 -0
  79. mosamatic2/ui/resources/VERSION +1 -0
  80. mosamatic2/ui/resources/icons/mosamatic2.icns +0 -0
  81. mosamatic2/ui/resources/icons/mosamatic2.ico +0 -0
  82. mosamatic2/ui/resources/icons/spinner.gif +0 -0
  83. mosamatic2/ui/resources/images/body-composition.jpg +0 -0
  84. mosamatic2/ui/settings.py +62 -0
  85. mosamatic2/ui/utils.py +36 -0
  86. mosamatic2/ui/widgets/__init__.py +0 -0
  87. mosamatic2/ui/widgets/dialogs/__init__.py +0 -0
  88. mosamatic2/ui/widgets/dialogs/dialog.py +16 -0
  89. mosamatic2/ui/widgets/dialogs/helpdialog.py +9 -0
  90. mosamatic2/ui/widgets/panels/__init__.py +0 -0
  91. mosamatic2/ui/widgets/panels/defaultpanel.py +31 -0
  92. mosamatic2/ui/widgets/panels/logpanel.py +65 -0
  93. mosamatic2/ui/widgets/panels/mainpanel.py +82 -0
  94. mosamatic2/ui/widgets/panels/pipelines/__init__.py +0 -0
  95. mosamatic2/ui/widgets/panels/pipelines/boadockerpipelinepanel.py +195 -0
  96. mosamatic2/ui/widgets/panels/pipelines/defaultdockerpipelinepanel.py +314 -0
  97. mosamatic2/ui/widgets/panels/pipelines/defaultpipelinepanel.py +302 -0
  98. mosamatic2/ui/widgets/panels/pipelines/liveranalysispipelinepanel.py +187 -0
  99. mosamatic2/ui/widgets/panels/pipelines/pipelinepanel.py +6 -0
  100. mosamatic2/ui/widgets/panels/settingspanel.py +16 -0
  101. mosamatic2/ui/widgets/panels/stackedpanel.py +22 -0
  102. mosamatic2/ui/widgets/panels/tasks/__init__.py +0 -0
  103. mosamatic2/ui/widgets/panels/tasks/applythresholdtosegmentationstaskpanel.py +271 -0
  104. mosamatic2/ui/widgets/panels/tasks/calculatemaskstatisticstaskpanel.py +215 -0
  105. mosamatic2/ui/widgets/panels/tasks/calculatescorestaskpanel.py +238 -0
  106. mosamatic2/ui/widgets/panels/tasks/createdicomsummarytaskpanel.py +206 -0
  107. mosamatic2/ui/widgets/panels/tasks/createpngsfromsegmentationstaskpanel.py +247 -0
  108. mosamatic2/ui/widgets/panels/tasks/dicom2niftitaskpanel.py +183 -0
  109. mosamatic2/ui/widgets/panels/tasks/rescaledicomimagestaskpanel.py +184 -0
  110. mosamatic2/ui/widgets/panels/tasks/segmentationnifti2numpytaskpanel.py +192 -0
  111. mosamatic2/ui/widgets/panels/tasks/segmentationnumpy2niftitaskpanel.py +213 -0
  112. mosamatic2/ui/widgets/panels/tasks/segmentmusclefatl3tensorflowtaskpanel.py +216 -0
  113. mosamatic2/ui/widgets/panels/tasks/segmentmusclefatt4pytorchtaskpanel.py +217 -0
  114. mosamatic2/ui/widgets/panels/tasks/selectslicefromscanstaskpanel.py +193 -0
  115. mosamatic2/ui/widgets/panels/tasks/taskpanel.py +6 -0
  116. mosamatic2/ui/widgets/panels/tasks/totalsegmentatortaskpanel.py +195 -0
  117. mosamatic2/ui/widgets/panels/visualizations/__init__.py +0 -0
  118. mosamatic2/ui/widgets/panels/visualizations/liversegmentvisualization/__init__.py +0 -0
  119. mosamatic2/ui/widgets/panels/visualizations/liversegmentvisualization/liversegmentpicker.py +96 -0
  120. mosamatic2/ui/widgets/panels/visualizations/liversegmentvisualization/liversegmentviewer.py +130 -0
  121. mosamatic2/ui/widgets/panels/visualizations/liversegmentvisualization/liversegmentvisualization.py +120 -0
  122. mosamatic2/ui/widgets/panels/visualizations/sliceselectionvisualization/__init__.py +0 -0
  123. mosamatic2/ui/widgets/panels/visualizations/sliceselectionvisualization/sliceselectionviewer.py +61 -0
  124. mosamatic2/ui/widgets/panels/visualizations/sliceselectionvisualization/sliceselectionvisualization.py +133 -0
  125. mosamatic2/ui/widgets/panels/visualizations/sliceselectionvisualization/slicetile.py +63 -0
  126. mosamatic2/ui/widgets/panels/visualizations/slicevisualization/__init__.py +0 -0
  127. mosamatic2/ui/widgets/panels/visualizations/slicevisualization/custominteractorstyle.py +80 -0
  128. mosamatic2/ui/widgets/panels/visualizations/slicevisualization/sliceviewer.py +116 -0
  129. mosamatic2/ui/widgets/panels/visualizations/slicevisualization/slicevisualization.py +141 -0
  130. mosamatic2/ui/widgets/panels/visualizations/visualization.py +6 -0
  131. mosamatic2/ui/widgets/splashscreen.py +101 -0
  132. mosamatic2/ui/worker.py +29 -0
  133. mosamatic2-2.0.24.dist-info/METADATA +43 -0
  134. mosamatic2-2.0.24.dist-info/RECORD +136 -0
  135. mosamatic2-2.0.24.dist-info/WHEEL +4 -0
  136. mosamatic2-2.0.24.dist-info/entry_points.txt +5 -0
@@ -0,0 +1,86 @@
1
+ import os
2
+ import numpy as np
3
+ import SimpleITK as sitk
4
+ from mosamatic2.core.tasks.task import Task
5
+ from mosamatic2.core.managers.logmanager import LogManager
6
+ from mosamatic2.core.utils import (
7
+ is_dicom,
8
+ convert_numpy_array_to_png_image,
9
+ convert_dicom_to_png_image,
10
+ AlbertaColorMap,
11
+ )
12
+
13
# Module-level logger; not referenced elsewhere in this module.
LOG = LogManager()
14
+
15
+
16
class SegmentationNumpy2NiftiTask(Task):
    """Convert NumPy segmentations (``*.seg.npy``) into NIfTI files aligned with their DICOM images.

    Inputs:
        images: directory with DICOM images.
        segmentations: directory with ``<image file name>.seg.npy`` files.

    For every segmentation that has a matching image, the output directory
    receives ``<segmentation file name>.nii.gz`` (with the DICOM image's
    geometry copied onto it), a PNG rendering of the segmentation and a PNG
    rendering of the DICOM image.
    """
    INPUTS = ['images', 'segmentations']
    PARAMS = []

    # Suffix of segmentation files; the part before it is the file name of the
    # DICOM image the segmentation belongs to.
    SEGMENTATION_SUFFIX = '.seg.npy'

    def __init__(self, inputs, params, output, overwrite):
        super(SegmentationNumpy2NiftiTask, self).__init__(inputs, params, output, overwrite)

    def load_images(self):
        """Return the paths of all DICOM files in the 'images' input directory."""
        images_dir = self.input('images')
        images = []
        for f in os.listdir(images_dir):
            f_path = os.path.join(images_dir, f)
            if is_dicom(f_path):
                images.append(f_path)
        return images

    def load_segmentations(self):
        """Return the paths of all ``*.seg.npy`` files in the 'segmentations' input directory."""
        segmentations_dir = self.input('segmentations')
        return [
            os.path.join(segmentations_dir, f)
            for f in os.listdir(segmentations_dir)
            if f.endswith(self.SEGMENTATION_SUFFIX)
        ]

    def create_pairs_of_images_and_segmentations(self, images, segmentations):
        """Pair each segmentation with the DICOM image it was generated from.

        A segmentation ``<name>.seg.npy`` is paired with the image whose file
        name is exactly ``<name>``. An exact match (rather than a substring
        test) prevents e.g. image ``1.dcm`` from also being paired with
        ``11.dcm.seg.npy``.

        Args:
            images: image file paths.
            segmentations: segmentation file paths.

        Returns:
            List of ``(image_path, segmentation_path)`` tuples in the order of
            ``segmentations``. Segmentations without a matching image are skipped.
        """
        images_by_name = {os.path.basename(image): image for image in images}
        suffix = self.SEGMENTATION_SUFFIX
        pairs = []
        for segmentation in segmentations:
            segmentation_name = os.path.basename(segmentation)
            if not segmentation_name.endswith(suffix):
                continue
            image = images_by_name.get(segmentation_name[:-len(suffix)])
            if image is not None:
                pairs.append((image, segmentation))
        return pairs

    def load_segmentation_as_narray(self, segmentation):
        """Load a segmentation array from a ``.npy`` file."""
        return np.load(segmentation)

    def create_png_from_dicom(self, file_path):
        """Render the DICOM image at ``file_path`` as a PNG in the output directory."""
        convert_dicom_to_png_image(file_path, self.output())

    def create_png_from_array(self, data, file_path):
        """Render ``data`` with the Alberta color map as ``<basename of file_path>.png`` in the output directory."""
        png_file_name = os.path.split(file_path)[1] + '.png'
        convert_numpy_array_to_png_image(
            data,
            self.output(),
            AlbertaColorMap(),
            png_file_name,
            fig_width=10, fig_height=10,
        )

    def run(self):
        """Convert every matched segmentation to NIfTI and write PNG previews."""
        images = self.load_images()
        segmentations = self.load_segmentations()
        image_and_segmentation_pairs = self.create_pairs_of_images_and_segmentations(images, segmentations)
        nr_steps = len(image_and_segmentation_pairs)
        for step, (image_path, segmentation_path) in enumerate(image_and_segmentation_pairs):
            # Load the segmentation and add a leading axis so it becomes a
            # one-slice volume (assumes the stored array is 2D).
            segmentation_narray = self.load_segmentation_as_narray(segmentation_path).astype(np.uint16)
            segmentation_narray3d = segmentation_narray[None, ...]
            # Copy origin/spacing/direction from the DICOM image so the NIfTI lines
            # up with it; CopyInformation requires both images to have the same size.
            image_itk = sitk.ReadImage(image_path)
            segmentation_itk = sitk.GetImageFromArray(segmentation_narray3d)
            segmentation_itk.CopyInformation(image_itk)
            segmentation_itk_name = os.path.split(segmentation_path)[1] + '.nii.gz'
            segmentation_itk_path = os.path.join(self.output(), segmentation_itk_name)
            sitk.WriteImage(segmentation_itk, segmentation_itk_path, useCompression=True)
            self.create_png_from_array(segmentation_narray, segmentation_itk_path)
            self.create_png_from_dicom(image_path)
            self.set_progress(step, nr_steps)
@@ -0,0 +1,39 @@
1
+ import json
2
+
3
+
4
class ParamLoader:
    """Load model parameters from a JSON file and expose them as attributes.

    Every top-level key of the JSON object becomes an instance attribute, so
    ``params.learning_rate`` and ``params.dict['learning_rate']`` are equivalent.
    """

    def __init__(self, json_path):
        self.update(json_path)

    def save(self, json_path):
        """
        Save the parameters to a JSON file.

        Parameters
        ----------
        json_path : string
            Path to save location
        """
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """
        Load parameters from a JSON file, overwriting existing ones with the same name.

        The file is decoded as UTF-8 (the JSON standard encoding) regardless
        of the platform's default locale encoding.

        Parameters
        ----------
        json_path : string
            Path to json file
        """
        with open(json_path, encoding='utf-8') as f:
            params = json.load(f)
        self.__dict__.update(params)

    @property
    def dict(self):
        """
        Give dict-like access to the parameters, e.g. ``params.dict['learning_rate']``.

        The returned object is the instance's own attribute dict, so changes to
        it are reflected on the instance.
        """
        return self.__dict__
@@ -0,0 +1,122 @@
1
+ import os
2
+ import zipfile
3
+ import tempfile
4
+ import numpy as np
5
+
6
+ import models
7
+
8
+ from mosamatic2.core.tasks.task import Task
9
+ from mosamatic2.core.tasks.segmentmusclefatl3tensorflowtask.paramloader import ParamLoader
10
+ from mosamatic2.core.data.multidicomimage import MultiDicomImage
11
+ from mosamatic2.core.data.dicomimage import DicomImage
12
+ from mosamatic2.core.managers.logmanager import LogManager
13
+ from mosamatic2.core.utils import (
14
+ normalize_between,
15
+ get_pixels_from_dicom_object,
16
+ convert_labels_to_157,
17
+ )
18
+
19
# Not referenced within this module.
DEVICE = 'cpu'
# Not referenced within this module; presumably a slice index related to L3 — TODO confirm or remove.
L3_INDEX = 167
# Module-level logger; not referenced within this module.
LOG = LogManager()
22
+
23
+
24
class SegmentMuscleFatL3TensorFlowTask(Task):
    """Segment muscle and fat in DICOM images with TensorFlow/Keras models (L3 variant).

    Inputs:
        images: directory of DICOM images.
        model_files: directory containing ``model-<version>.zip``,
            ``contour_model-<version>.zip`` and ``params-<version>.json``.
    Params:
        model_version: version suffix used to select the model files.

    For every image, the predicted label map is saved as
    ``<image file name>.seg.npy`` in the output directory.
    """
    INPUTS = [
        'images',
        'model_files'
    ]
    PARAMS = ['model_version']

    def __init__(self, inputs, params, output, overwrite=True):
        super(SegmentMuscleFatL3TensorFlowTask, self).__init__(inputs, params, output, overwrite)

    def load_images(self):
        """Load all DICOM images from the 'images' input directory.

        Raises:
            RuntimeError: if the images could not be loaded.
        """
        image_data = MultiDicomImage()
        image_data.set_path(self.input('images'))
        if image_data.load():
            return image_data
        raise RuntimeError('Could not load images')

    def load_model_files(self):
        """Return the ``.zip`` and ``.json`` files in the 'model_files' input directory.

        Raises:
            RuntimeError: if not exactly 3 such files are found.
        """
        model_files_dir = self.input('model_files')
        model_files = [
            os.path.join(model_files_dir, f)
            for f in os.listdir(model_files_dir)
            if f.endswith(('.zip', '.json'))
        ]
        if len(model_files) != 3:
            raise RuntimeError(f'Found {len(model_files)} model files. This should be 3!')
        return model_files

    @staticmethod
    def _load_zipped_keras_model(zip_path):
        """Unzip a saved Keras model into a temporary directory and load it without compiling."""
        # Imported lazily so TensorFlow is only loaded when a model is actually needed.
        import tensorflow as tf
        with tempfile.TemporaryDirectory() as model_dir_unzipped:
            with zipfile.ZipFile(zip_path) as zip_obj:
                zip_obj.extractall(path=model_dir_unzipped)
            return tf.keras.models.load_model(model_dir_unzipped, compile=False)

    def load_models_and_params(self, model_files, model_version):
        """Load the segmentation model, the optional contour model and the parameters.

        Files are selected by exact name: ``model-<version>.zip``,
        ``contour_model-<version>.zip`` and ``params-<version>.json``.

        Returns:
            Tuple ``(model, contour_model, params)``; ``contour_model`` and
            ``params`` are None when their file is absent.

        Raises:
            RuntimeError: if the segmentation model is missing, or if a contour
                model is present without the parameter file it needs.
        """
        model_file_name = f'model-{model_version}.zip'
        contour_model_file_name = f'contour_model-{model_version}.zip'
        params_file_name = f'params-{model_version}.json'
        model, contour_model, params = None, None, None
        for f_path in model_files:
            f_name = os.path.basename(f_path)
            if f_name == model_file_name:
                model = self._load_zipped_keras_model(f_path)
            elif f_name == contour_model_file_name:
                contour_model = self._load_zipped_keras_model(f_path)
            elif f_name == params_file_name:
                params = ParamLoader(f_path)
        # Fail here with a clear message instead of an AttributeError on None later.
        if model is None:
            raise RuntimeError(f'Could not find model file {model_file_name}')
        if contour_model is not None and params is None:
            raise RuntimeError(f'Could not find parameter file {params_file_name}')
        return model, contour_model, params

    def extract_contour(self, image, contour_model, params):
        """Predict a per-pixel mask for ``image`` with the contour model.

        The image is passed through ``normalize_between`` with the
        ``min_bound_contour``/``max_bound_contour`` parameters and fed to the
        model as a batch of one single-channel image. Returns the arg-max class
        per pixel as ``uint8``.
        """
        ct = normalize_between(np.copy(image), params.dict['min_bound_contour'], params.dict['max_bound_contour'])
        batch = np.expand_dims(np.expand_dims(ct, 0), -1)
        pred = np.squeeze(contour_model.predict([batch]))
        return np.uint8(pred.argmax(axis=-1))

    def segment_muscle_and_fat(self, image, model):
        """Return the arg-max class per pixel predicted by ``model`` for ``image``."""
        batch = np.expand_dims(np.expand_dims(image, 0), -1)
        pred = np.squeeze(model.predict([batch]))
        return pred.argmax(axis=-1)

    def process_file(self, image, output_dir, model, contour_model, params):
        """Segment one DICOM image and save the label map as ``<file name>.seg.npy`` in ``output_dir``."""
        assert isinstance(image, DicomImage)
        pixels = get_pixels_from_dicom_object(image.object(), normalize=True)
        if contour_model is not None:
            # Zero out pixels outside the predicted contour mask.
            mask = self.extract_contour(pixels, contour_model, params)
            pixels = normalize_between(pixels, params.dict['min_bound'], params.dict['max_bound'])
            pixels = pixels * mask
        pixels = pixels.astype(np.float32)
        segmentation = self.segment_muscle_and_fat(pixels, model)
        segmentation = convert_labels_to_157(segmentation)
        segmentation_file_name = os.path.split(image.path())[1]
        segmentation_file_path = os.path.join(output_dir, f'{segmentation_file_name}.seg.npy')
        np.save(segmentation_file_path, segmentation)

    def run(self):
        """Load images and models, then segment every image."""
        image_data = self.load_images()
        model_files = self.load_model_files()
        model_version = self.param('model_version')
        model, contour_model, params = self.load_models_and_params(model_files, model_version)
        images = image_data.images()
        nr_steps = len(images)
        for step, image in enumerate(images):
            self.process_file(image, self.output(), model, contour_model, params)
            self.set_progress(step, nr_steps)
@@ -0,0 +1,39 @@
1
+ import json
2
+
3
+
4
class ParamLoader:
    """Load model parameters from a JSON file and expose them as attributes.

    Every top-level key of the JSON object becomes an instance attribute, so
    ``params.learning_rate`` and ``params.dict['learning_rate']`` are equivalent.
    """

    def __init__(self, json_path):
        self.update(json_path)

    def save(self, json_path):
        """
        Save the parameters to a JSON file.

        Parameters
        ----------
        json_path : string
            Path to save location
        """
        with open(json_path, 'w', encoding='utf-8') as f:
            json.dump(self.__dict__, f, indent=4)

    def update(self, json_path):
        """
        Load parameters from a JSON file, overwriting existing ones with the same name.

        The file is decoded as UTF-8 (the JSON standard encoding) regardless
        of the platform's default locale encoding.

        Parameters
        ----------
        json_path : string
            Path to json file
        """
        with open(json_path, encoding='utf-8') as f:
            params = json.load(f)
        self.__dict__.update(params)

    @property
    def dict(self):
        """
        Give dict-like access to the parameters, e.g. ``params.dict['learning_rate']``.

        The returned object is the instance's own attribute dict, so changes to
        it are reflected on the instance.
        """
        return self.__dict__
@@ -0,0 +1,128 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import matplotlib.pyplot as plt
5
+
6
+ import models
7
+
8
+ from mosamatic2.core.tasks.task import Task
9
+ from mosamatic2.core.utils import (
10
+ normalize_between,
11
+ get_pixels_from_dicom_object,
12
+ convert_labels_to_157,
13
+ )
14
+ from mosamatic2.core.data.multidicomimage import MultiDicomImage
15
+ from mosamatic2.core.data.dicomimage import DicomImage
16
+ from mosamatic2.core.tasks.segmentmusclefatt4pytorchtask.paramloader import ParamLoader
17
+
18
# Device that the PyTorch models, loaded weights and input tensors are placed on.
DEVICE = 'cpu'
19
+
20
+
21
class SegmentMuscleFatT4PyTorchTask(Task):
    """Segment muscle and fat in DICOM images with PyTorch UNet models (T4 variant).

    Inputs:
        images: directory of DICOM images.
        model_files: directory containing ``model-<version>.pt``,
            ``contour_model-<version>.pt`` and ``params-<version>.json``.
    Params:
        model_version: version suffix used to select the model files.

    For every image, the predicted label map is saved as
    ``<image file name>.seg.npy`` in the output directory.
    """
    INPUTS = [
        'images',
        'model_files'
    ]
    PARAMS = ['model_version']

    def __init__(self, inputs, params, output, overwrite=True):
        super(SegmentMuscleFatT4PyTorchTask, self).__init__(inputs, params, output, overwrite)

    def load_images(self):
        """Load all DICOM images from the 'images' input directory.

        Raises:
            RuntimeError: if the images could not be loaded.
        """
        image_data = MultiDicomImage()
        image_data.set_path(self.input('images'))
        if image_data.load():
            return image_data
        raise RuntimeError('Could not load images')

    def load_model_files(self):
        """Return the ``.pt`` and ``.json`` files in the 'model_files' input directory.

        Raises:
            RuntimeError: if not exactly 3 such files are found.
        """
        model_files_dir = self.input('model_files')
        model_files = [
            os.path.join(model_files_dir, f)
            for f in os.listdir(model_files_dir)
            if f.endswith(('.pt', '.json'))
        ]
        if len(model_files) != 3:
            raise RuntimeError(f'Found {len(model_files)} model files. This should be 3!')
        return model_files

    @staticmethod
    def _load_unet(weights_path, params, nr_outputs):
        """Build ``models.UNet(params, nr_outputs)`` on DEVICE, load its weights and switch to eval mode."""
        network = models.UNet(params, nr_outputs).to(device=DEVICE)
        # NOTE(review): weights_only=False lets torch.load unpickle arbitrary objects,
        # which can execute code from a malicious file. Only load trusted model files;
        # consider weights_only=True if the checkpoints are plain state dicts.
        state_dict = torch.load(weights_path, weights_only=False, map_location=torch.device(DEVICE))
        network.load_state_dict(state_dict)
        network.eval()
        return network

    def load_models_and_params(self, model_files, model_version):
        """Load the parameters, the segmentation model and the optional contour model.

        Files are selected by exact name: ``params-<version>.json``,
        ``model-<version>.pt`` and ``contour_model-<version>.pt``.

        Returns:
            Tuple ``(model, contour_model, params)``; ``contour_model`` is None
            when its file is absent.

        Raises:
            RuntimeError: if the parameter file or the segmentation model is missing.
        """
        paths_by_name = {os.path.basename(f_path): f_path for f_path in model_files}
        # The parameters are loaded first because they are needed to instantiate the models.
        params_path = paths_by_name.get(f'params-{model_version}.json')
        if params_path is None:
            raise RuntimeError('Could not load parameters')
        params = ParamLoader(params_path)
        model_file_name = f'model-{model_version}.pt'
        model_path = paths_by_name.get(model_file_name)
        if model_path is None:
            # Fail here with a clear message instead of a TypeError on None later.
            raise RuntimeError(f'Could not find model file {model_file_name}')
        # Second UNet argument: 4 for the segmentation model, 2 for the contour model
        # (presumably the number of output classes — confirm in models.UNet).
        model = self._load_unet(model_path, params, 4)
        contour_model = None
        contour_model_path = paths_by_name.get(f'contour_model-{model_version}.pt')
        if contour_model_path is not None:
            contour_model = self._load_unet(contour_model_path, params, 2)
        return model, contour_model, params

    def extract_contour(self, image, contour_model):
        """Multiply ``image`` element-wise by the contour model's arg-max prediction.

        The image is fed to the model as a 4D tensor (batch of one, one channel).
        Pixels whose predicted class is 0 become 0.
        """
        with torch.no_grad():
            tensor = torch.Tensor(np.expand_dims(np.expand_dims(image, 0), 0))
            tensor = tensor.to(DEVICE, dtype=torch.float)
            prediction = contour_model(tensor)
            prediction = torch.argmax(prediction, dim=1).squeeze()
            prediction = prediction.detach().cpu().numpy()
        return image * prediction

    def show_result(self, result, file_path=None):
        """Save ``result`` as a grayscale PNG (debugging helper).

        Args:
            result: 2D array to render.
            file_path: destination PNG; defaults to ``result.png`` in the task's
                output directory instead of a machine-specific hard-coded path.
        """
        if file_path is None:
            file_path = os.path.join(self.output(), 'result.png')
        fig = plt.figure()
        try:
            plt.imshow(result, cmap='gray')
            plt.savefig(file_path)
        finally:
            # Close the figure so repeated calls do not accumulate open figures.
            plt.close(fig)

    def segment_muscle_and_fat(self, image, model):
        """Return the arg-max class per pixel predicted by ``model`` for ``image``."""
        # no_grad: inference only, so do not build an autograd graph (saves memory and time).
        with torch.no_grad():
            tensor = torch.Tensor(np.expand_dims(np.expand_dims(image, 0), 0))
            tensor = tensor.to(DEVICE, dtype=torch.float)
            segmentation = model(tensor)
            segmentation = torch.argmax(segmentation, dim=1).squeeze()
            return segmentation.detach().cpu().numpy()

    def process_file(self, image, output_dir, model, contour_model, params):
        """Segment one DICOM image and save the label map as ``<file name>.seg.npy`` in ``output_dir``."""
        assert isinstance(image, DicomImage)
        pixels = get_pixels_from_dicom_object(image.object(), normalize=True)
        if contour_model is not None:
            pixels = normalize_between(pixels, params.dict['lower_bound'], params.dict['upper_bound'])
            pixels = self.extract_contour(pixels, contour_model)
        pixels = pixels.astype(np.float32)
        segmentation = self.segment_muscle_and_fat(pixels, model)
        segmentation = convert_labels_to_157(segmentation)
        segmentation_file_name = os.path.split(image.path())[1]
        segmentation_file_path = os.path.join(output_dir, f'{segmentation_file_name}.seg.npy')
        np.save(segmentation_file_path, segmentation)

    def run(self):
        """Load images and models, then segment every image."""
        image_data = self.load_images()
        model_files = self.load_model_files()
        model_version = self.param('model_version')
        model, contour_model, params = self.load_models_and_params(model_files, model_version)
        images = image_data.images()
        nr_steps = len(images)
        for step, image in enumerate(images):
            self.process_file(image, self.output(), model, contour_model, params)
            self.set_progress(step, nr_steps)
@@ -0,0 +1,249 @@
1
+ import os
2
+ import math
3
+ import tempfile
4
+ import shutil
5
+ import nibabel as nib
6
+ import numpy as np
7
+ import SimpleITK as sitk
8
+ import matplotlib.pyplot as plt
9
+ from totalsegmentator.python_api import totalsegmentator
10
+ from mosamatic2.core.tasks.task import Task
11
+ from mosamatic2.core.managers.logmanager import LogManager
12
+ from mosamatic2.core.utils import load_dicom
13
+
14
# Module-level logger.
LOG = LogManager()

# Fixed scratch directory in the system temp dir where TotalSegmentator writes its
# masks. The path is the same for every run, so concurrent runs could clash.
TOTAL_SEGMENTATOR_OUTPUT_DIR = os.path.join(tempfile.gettempdir(), 'total_segmentator_output')
# Not referenced within this module.
TOTAL_SEGMENTATOR_TASK = 'total'
# Per-mask fractions of the vertebra's z-extent, looked up by
# SelectSliceFromScansTask.get_z_delta_offset_for_mask.
Z_DELTA_OFFSETS = {
    'vertebrae_L3': 0.333,
    'vertebrae_T4': 0.5,
}
22
+
23
+
24
class SelectSliceFromScansTask(Task):
    """Select one axial DICOM slice at a vertebral level (L3 or T4) from each CT scan.

    Each sub-directory of the 'scans' input is treated as one scan. The scan is
    segmented with TotalSegmentator and the requested vertebra mask determines a
    target z-position. The closest DICOM slice is copied to the output directory
    as ``<vertebra>_<scan name>.dcm``, together with a sagittal preview PNG.
    Problems are written to ``errors.txt`` and failing scans are copied to an
    error directory next to the output directory.
    """
    INPUTS = ['scans']
    PARAMS = ['vertebra']

    # Supported ``vertebra`` parameter values and their TotalSegmentator mask names.
    _VERTEBRAL_LEVELS = {
        'L3': 'vertebrae_L3',
        'T4': 'vertebrae_T4',
    }

    def __init__(self, inputs, params, output, overwrite):
        super(SelectSliceFromScansTask, self).__init__(inputs, params, output, overwrite)
        # The error directory is a sibling of the output directory.
        self._error_dir = os.path.join(os.path.split(self.output())[0], 'selectslicefromscanstask_errors')
        os.makedirs(self._error_dir, exist_ok=True)
        self._error_file = os.path.join(self._error_dir, 'errors.txt')
        # Start a fresh error log for this task instance.
        with open(self._error_file, 'w') as f:
            f.write('Errors:\n\n')
        LOG.info(f'Error directory: {self._error_dir}')

    def write_error(self, message):
        """Log ``message`` as an error and append it to the error file."""
        LOG.error(message)
        with open(self._error_file, 'a') as f:
            f.write(message + '\n')

    def load_scan_dirs(self):
        """Return the sub-directories of the 'scans' input; each one is a scan."""
        scans_dir = self.input('scans')
        return [
            os.path.join(scans_dir, d)
            for d in os.listdir(scans_dir)
            if os.path.isdir(os.path.join(scans_dir, d))
        ]

    def read_ct_series_sitk(self, scan_dir):
        """Read the first DICOM series found in ``scan_dir`` as a SimpleITK image.

        Raises:
            ValueError: if ``scan_dir`` contains no DICOM series.
        """
        reader = sitk.ImageSeriesReader()
        series_ids = reader.GetGDCMSeriesIDs(scan_dir)
        if not series_ids:
            raise ValueError(f'No DICOM series found in {scan_dir}')
        reader.SetFileNames(reader.GetGDCMSeriesFileNames(scan_dir, series_ids[0]))
        return reader.Execute()

    def resample_to_reference(self, moving: sitk.Image, reference: sitk.Image, is_label: bool) -> sitk.Image:
        """Resample ``moving`` onto the grid of ``reference`` with an identity transform.

        Labels use nearest-neighbour interpolation so no new label values are
        invented. Voxels outside ``moving`` get 0; the pixel type is kept.
        """
        interp = sitk.sitkNearestNeighbor if is_label else sitk.sitkLinear
        return sitk.Resample(moving, reference, sitk.Transform(), interp, 0, moving.GetPixelID())

    def centroid_x_index_from_mask(self, mask_ref: sitk.Image) -> int:
        """Return the mean x-index of the non-zero voxels of ``mask_ref`` (the image's middle column if empty).

        ``mask_ref`` must already be on the CT grid.
        """
        arr = sitk.GetArrayFromImage(mask_ref)  # shape: [z, y, x]
        idx = np.argwhere(arr > 0)
        if idx.size == 0:
            # Fall back to the middle column.
            return mask_ref.GetSize()[0] // 2
        return int(round(idx[:, 2].mean()))

    def z_index_from_physical_z(self, ct_img: sitk.Image, z_phys: float, x_index: int) -> int:
        """Return the slice index of ``ct_img`` for physical z-coordinate ``z_phys``, clamped to the valid range.

        The lookup uses column ``x_index`` and the middle row of the image.
        """
        size = ct_img.GetSize()  # (x, y, z)
        y_index = size[1] // 2
        # Physical (x, y) of the chosen voxel column, with z replaced by z_phys.
        p = ct_img.TransformIndexToPhysicalPoint((x_index, y_index, size[2] // 2))
        phys_point = (p[0], p[1], z_phys)
        try:
            iz = ct_img.TransformPhysicalPointToIndex(phys_point)[2]
        except RuntimeError:
            # Fallback: pick the slice whose centre is physically closest to z_phys.
            z_centers = np.array([
                ct_img.TransformIndexToPhysicalPoint((x_index, y_index, k))[2]
                for k in range(size[2])
            ])
            iz = int(np.argmin(np.abs(z_centers - z_phys)))
        return int(np.clip(iz, 0, size[2] - 1))

    def plot_sagittal_with_vertebra_overlay(self, scan_dir, mask_file, z_vert_phys_mm, out_png):
        """Save a sagittal PNG of the scan with the vertebra mask overlaid and a line at the selected slice."""
        ct = self.read_ct_series_sitk(scan_dir)
        vert_mask = sitk.ReadImage(mask_file)
        # Put the mask on the CT grid so both arrays share voxel indices.
        vert_mask_ref = self.resample_to_reference(vert_mask, ct, is_label=True)
        x_idx = self.centroid_x_index_from_mask(vert_mask_ref)
        z_idx = self.z_index_from_physical_z(ct, z_vert_phys_mm, x_idx)
        ct_arr = sitk.GetArrayFromImage(ct).astype(np.float32)
        mk_arr = sitk.GetArrayFromImage(vert_mask_ref).astype(np.uint8)
        # Arrays are indexed [z, y, x]; fixing x gives a sagittal [z, y] plane.
        sag_ct = ct_arr[:, :, x_idx]
        sag_mk = mk_arr[:, :, x_idx]
        vmin, vmax = np.percentile(sag_ct, (1, 99))
        spacing = ct.GetSpacing()
        aspect = spacing[2] / spacing[1]  # z spacing / y spacing (mm)
        fig = plt.figure(figsize=(7, 9))
        try:
            plt.imshow(sag_ct, cmap="gray", vmin=vmin, vmax=vmax, origin="lower", aspect=aspect)
            plt.imshow(sag_mk, alpha=0.35, origin="lower", aspect=aspect)
            plt.axhline(z_idx, linewidth=2)  # line at the selected axial slice position
            plt.title("Sagittal view with vertebral mask overlay + selected vertebra axial slice")
            plt.axis("off")
            plt.savefig(out_png, bbox_inches="tight", dpi=200)
        finally:
            # Close the figure; otherwise every processed scan leaks an open figure.
            plt.close(fig)

    def extract_masks(self, scan_dir):
        """Run TotalSegmentator (fast mode) on ``scan_dir``, writing masks to TOTAL_SEGMENTATOR_OUTPUT_DIR.

        Raises:
            Exception: if ``vertebrae_L3.nii.gz`` was not produced.
        """
        os.makedirs(TOTAL_SEGMENTATOR_OUTPUT_DIR, exist_ok=True)
        totalsegmentator(input=scan_dir, output=TOTAL_SEGMENTATOR_OUTPUT_DIR, fast=True)
        # NOTE(review): checks the L3 mask regardless of the requested vertebra;
        # presumably used as a sign that TotalSegmentator produced output — confirm.
        if not os.path.isfile(os.path.join(TOTAL_SEGMENTATOR_OUTPUT_DIR, 'vertebrae_L3.nii.gz')):
            raise Exception(f'{scan_dir}: vertebrae_L3.nii.gz does not exist')

    def delete_total_segmentator_output(self):
        """Remove TOTAL_SEGMENTATOR_OUTPUT_DIR if it exists."""
        if os.path.exists(TOTAL_SEGMENTATOR_OUTPUT_DIR):
            shutil.rmtree(TOTAL_SEGMENTATOR_OUTPUT_DIR)

    def get_z_delta_offset_for_mask(self, mask_name):
        """Return the z-extent fraction configured for ``mask_name``, or None if there is none."""
        return Z_DELTA_OFFSETS.get(mask_name)

    def _load_dicom_z_positions(self, scan_dir):
        """Map ImagePositionPatient z-coordinate -> file path for each readable DICOM file in ``scan_dir``."""
        z_positions = {}
        for f in os.listdir(scan_dir):
            f_path = os.path.join(scan_dir, f)
            try:
                p = load_dicom(f_path, stop_before_pixels=True)
                if p is not None and hasattr(p, "ImagePositionPatient"):
                    z_positions[p.ImagePositionPatient[2]] = f_path
            except Exception as e:
                # Skip only this file; stopping here would drop all remaining slices.
                self.write_error(f"{scan_dir}: Failed to load DICOM {f_path}: {e}")
                continue
        return z_positions

    def _find_target_z(self, scan_dir, vertebral_level):
        """Compute the world z-coordinate at which to select the slice; write an error and return None on failure."""
        mask_file = os.path.join(TOTAL_SEGMENTATOR_OUTPUT_DIR, f'{vertebral_level}.nii.gz')
        if not os.path.exists(mask_file):
            self.write_error(f"{scan_dir}: Mask file not found: {mask_file}")
            return None
        try:
            mask_obj = nib.load(mask_file)
            mask = mask_obj.get_fdata()
            affine_transform = mask_obj.affine
        except Exception as e:
            self.write_error(f"{scan_dir}: Failed to load mask {mask_file}: {e}")
            return None
        indexes = np.array(np.where(mask == 1))
        if indexes.size == 0:
            self.write_error(f"{scan_dir}: No voxels found in mask {mask_file} for {vertebral_level}")
            return None
        # indexes is non-empty here, so min/max cannot fail.
        world_min = nib.affines.apply_affine(affine_transform, indexes.min(axis=1))
        world_max = nib.affines.apply_affine(affine_transform, indexes.max(axis=1))
        z_direction = affine_transform[2, 2]
        if z_direction == 0:
            self.write_error(f"{scan_dir}: Affine z-direction is zero.")
            return None
        z_sign = math.copysign(1, z_direction)
        z_delta_offset = self.get_z_delta_offset_for_mask(vertebral_level)
        if z_delta_offset is None:
            self.write_error(f"{scan_dir}: No z-offset defined for {vertebral_level}")
            return None
        # Use the vertebra-specific fraction from Z_DELTA_OFFSETS (0.333 for L3, 0.5 for T4).
        z_delta = z_delta_offset * abs(world_max[2] - world_min[2])
        # NOTE(review): world_max comes from the maximum voxel indices. When z_direction
        # is negative it is the low-z end of the mask, so the offset is then measured
        # from the opposite end of the vertebra — confirm this is intended.
        return world_max[2] - z_sign * z_delta

    @staticmethod
    def _find_closest_file(z_positions, z_target):
        """Return the file whose z-position is closest to ``z_target``.

        Returns None if fewer than two slices exist or ``z_target`` lies outside
        the z-range they span.
        """
        positions = sorted(z_positions)
        if len(positions) < 2 or not positions[0] <= z_target <= positions[-1]:
            return None
        closest_z = min(z_positions, key=lambda z: abs(z - z_target))
        closest_file = z_positions[closest_z]
        LOG.info(f'Closest image: {closest_file}')
        return closest_file

    def find_slice(self, scan_dir, vertebra):
        """Find the DICOM file in ``scan_dir`` closest to the target z-position of ``vertebra``.

        Requires the TotalSegmentator masks for the scan in
        TOTAL_SEGMENTATOR_OUTPUT_DIR (see ``extract_masks``).

        Returns:
            Tuple ``(file_path, z_position)``. ``file_path`` is None if no slice
            could be selected, and ``z_position`` is None if no target could be
            computed. Every failure is recorded with ``write_error``.
        """
        vertebral_level = self._VERTEBRAL_LEVELS.get(vertebra)
        if vertebral_level is None:
            self.write_error(f'{scan_dir}: Unknown vertebra {vertebra}. Options are "L3" and "T4"')
            return None, None
        z_positions = self._load_dicom_z_positions(scan_dir)
        if not z_positions:
            self.write_error(f"{scan_dir}: No valid DICOM z-positions found.")
            return None, None
        z_target = self._find_target_z(scan_dir, vertebral_level)
        if z_target is None:
            return None, None
        closest_file = self._find_closest_file(z_positions, z_target)
        if closest_file is None:
            self.write_error(f"{scan_dir}: No matching slice found.")
        return closest_file, z_target

    def _process_scan(self, scan_dir, vertebra):
        """Segment one scan, copy its selected slice and write the sagittal preview; return True on success."""
        scan_name = os.path.split(scan_dir)[1]
        try:
            self.extract_masks(scan_dir)
        except Exception as e:
            self.write_error(f'{scan_dir}: Could not extract masks [{str(e)}]. Skipping scan...')
            return False
        try:
            file_path, z_vert = self.find_slice(scan_dir, vertebra)
            if file_path is None:
                self.write_error(f'{scan_dir}: Could not find slice for vertebral level: {vertebra}')
                return False
            extension = '' if file_path.endswith('.dcm') else '.dcm'
            target_file_path = os.path.join(self.output(), vertebra + '_' + scan_name + extension)
            shutil.copyfile(file_path, target_file_path)
            mask_file = os.path.join(TOTAL_SEGMENTATOR_OUTPUT_DIR, f'vertebrae_{vertebra}.nii.gz')
            output_png = os.path.join(self.output(), f"{vertebra}_{scan_name}_sagittal.png")
            self.plot_sagittal_with_vertebra_overlay(scan_dir, mask_file, z_vert, output_png)
        except Exception as e:
            # Record the failure and continue with the next scan instead of aborting the task.
            self.write_error(f'{scan_dir}: Could not select slice for vertebral level {vertebra} [{e}]')
            return False
        return True

    def _copy_scan_to_error_dir(self, scan_dir):
        """Copy a failed scan into the error directory."""
        LOG.info(f'Copying problematic scan {scan_dir} to error directory: {self._error_dir}')
        scan_error_dir = os.path.join(self._error_dir, os.path.split(scan_dir)[1])
        # The error directory persists between runs, so the scan may already be there.
        shutil.copytree(scan_dir, scan_error_dir, dirs_exist_ok=True)

    def run(self):
        """Select the vertebral slice for every scan, collecting failing scans in the error directory."""
        scan_dirs = self.load_scan_dirs()
        vertebra = self.param('vertebra')
        nr_steps = len(scan_dirs)
        for step, scan_dir in enumerate(scan_dirs):
            LOG.info(f'Processing {scan_dir}...')
            try:
                errors = not self._process_scan(scan_dir, vertebra)
            finally:
                # Always clear TotalSegmentator output so the next scan starts clean.
                self.delete_total_segmentator_output()
            if errors:
                self._copy_scan_to_error_dir(scan_dir)
            self.set_progress(step, nr_steps)
+ self.set_progress(step, nr_steps)