hyperresashs 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- hyperresashs/__init__.py +25 -0
- hyperresashs/__main__.py +2 -0
- hyperresashs/ashs_cli.py +511 -0
- hyperresashs/ashs_exp.py +239 -0
- hyperresashs/ashs_inference.py +184 -0
- hyperresashs/ashs_preproc.py +641 -0
- hyperresashs/ashs_training.py +765 -0
- hyperresashs/config_templates/ashs_filename_scheme.yaml +67 -0
- hyperresashs/config_templates/config_inr_template.yaml +63 -0
- hyperresashs/main.py +203 -0
- hyperresashs/prepare_inr.py +327 -0
- hyperresashs/preprocessing.py +340 -0
- hyperresashs/testing.py +400 -0
- hyperresashs-1.0.0.dist-info/METADATA +376 -0
- hyperresashs-1.0.0.dist-info/RECORD +18 -0
- hyperresashs-1.0.0.dist-info/WHEEL +5 -0
- hyperresashs-1.0.0.dist-info/entry_points.txt +2 -0
- hyperresashs-1.0.0.dist-info/top_level.txt +1 -0
hyperresashs/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
import sys
|
|
2
|
+
import os
|
|
3
|
+
from importlib.metadata import version, PackageNotFoundError
|
|
4
|
+
|
|
5
|
+
def _purge_cached_nnunetv2():
    """Remove ``nnunetv2`` and all of its submodules from ``sys.modules``.

    Any copy of nnunetv2 imported before this package would shadow the bundled
    nnU-Net submodule that is put on ``sys.path`` below, so it is dropped and a
    notice is emitted for each removed module.
    """
    stale = [name for name in sys.modules if name == "nnunetv2" or name.startswith("nnunetv2.")]
    for name in stale:
        # Diagnostics go to stderr so they do not pollute command output on stdout.
        print(f'hyperresashs found existing nnunetv2 module {name} in sys.modules. '
              f'Removing it to avoid conflicts with hyperresashs submodule nnunetv2.',
              file=sys.stderr)
        del sys.modules[name]


def _prepend_sys_path(path):
    """Put ``path`` first on ``sys.path`` without leaving duplicate entries.

    Existing occurrences are removed first, so re-importing/reloading the
    package does not keep growing ``sys.path``.
    """
    while path in sys.path:
        sys.path.remove(path)
    sys.path.insert(0, path)


_purge_cached_nnunetv2()

# Insert the bundled submodule paths at the front so they take priority over any
# installed copies. The INR submodule is inserted last, so it ends up first on
# sys.path, ahead of nnUNet.
_submodule_path = os.path.join(os.path.dirname(__file__), "submodules", "nnUNet")
_prepend_sys_path(_submodule_path)

_submodule_path_inr = os.path.join(os.path.dirname(__file__), "submodules", "multi_contrast_inr")
_prepend_sys_path(_submodule_path_inr)

# Package version as recorded in the installed distribution metadata.
try:
    __version__ = version("hyperresashs")
except PackageNotFoundError:
    # Not installed (e.g. running directly from a source checkout without installation).
    __version__ = "unknown"
|
hyperresashs/__main__.py
ADDED
hyperresashs/ashs_cli.py
ADDED
|
@@ -0,0 +1,511 @@
|
|
|
1
|
+
from .ashs_inference import HyperASHSInference
|
|
2
|
+
from .utils.huggingface import hf_disable_ssl_verification, hf_read_yaml, torch_hub_disable_ssl_verification
|
|
3
|
+
from .ashs_training import HyperASHSTraining
|
|
4
|
+
from .utils.tool import copy_or_link_file
|
|
5
|
+
from . import __version__
|
|
6
|
+
import argparse
|
|
7
|
+
import huggingface_hub as hf
|
|
8
|
+
import pandas as pd
|
|
9
|
+
import re
|
|
10
|
+
import pprint
|
|
11
|
+
import json
|
|
12
|
+
import textwrap
|
|
13
|
+
import os
|
|
14
|
+
import shutil
|
|
15
|
+
import yaml
|
|
16
|
+
import sys
|
|
17
|
+
from typing import Dict, Any
|
|
18
|
+
from importlib.resources import files
|
|
19
|
+
|
|
20
|
+
"""
|
|
21
|
+
Original ASHS command line help message for reference:
|
|
22
|
+
required options:
|
|
23
|
+
-a dir Location of the atlas directory. Can be a full pathname or a
|
|
24
|
+
relative directory name under ASHS_ROOT/data directory.
|
|
25
|
+
-g image Filename of 3D (g)radient echo MRI (ASHS_MPRAGE, T1w)
|
|
26
|
+
-f image Filename of 2D focal (f)ast spin echo MRI (ASHS_TSE, T2w)
|
|
27
|
+
-w path Working/output directory
|
|
28
|
+
|
|
29
|
+
optional:
|
|
30
|
+
-d Enable debugging
|
|
31
|
+
-h Print help
|
|
32
|
+
-s integer Run only one stage (see below); also accepts range (e.g. -s 1-3)
|
|
33
|
+
-N No overriding of registration results. If a result from an earlier run
|
|
34
|
+
exists, don't run greedy again.
|
|
35
|
+
-G Use template brain mask in T1 template rigid registration
|
|
36
|
+
-T Tidy mode. Cleans up files once they are unneeded. The -N option will
|
|
37
|
+
have no effect in tidy mode, because intermediate results will be erased.
|
|
38
|
+
-I string Subject ID (for stats output). Defaults to last word of working dir.
|
|
39
|
+
-V Display version information and exit
|
|
40
|
+
-C file Configuration file. If not passed, uses $ASHS_ROOT/bin/ashs_config.sh
|
|
41
|
+
-Q Use Sun Grid Engine (SGE) to schedule sub-tasks in each stage.
|
|
42
|
+
By default, the whole ashs_main job runs in a single process.
|
|
43
|
+
If you are doing a lot of segmentations and have SGE, it is better to
|
|
44
|
+
run each segmentation (ashs_main) in a separate SGE job, rather than use the -q flag.
|
|
45
|
+
The -q flag is best for when you have only a few segmentations and want them to run fast.
|
|
46
|
+
-P Use GNU parallel to run on multiple cores on the local machine. You need to
|
|
47
|
+
have GNU parallel installed.
|
|
48
|
+
-S Use SLURM instead of SGE, LSF or GNU parallel
|
|
49
|
+
-l Use LSF instead of SGE, SLURM or GNU parallel
|
|
50
|
+
-q OPTS Pass in additional options to SGE/SLURM/LSF/GNU Parallel. If -S, -l or -P is not specified,
|
|
51
|
+
turns on SGE.
|
|
52
|
+
-z script Provide a path to an executable script that will be used to retrieve SGE, LSF, SLURM or
|
|
53
|
+
GNU parallel options for different stages of ASHS. Takes precedence over -q
|
|
54
|
+
-r files Compare segmentation results with a reference segmentation. The parameter
|
|
55
|
+
files should consist of two nifti files in quotation marks:
|
|
56
|
+
|
|
57
|
+
-r "ref_seg_left.nii.gz ref_seg_right.nii.gz"
|
|
58
|
+
|
|
59
|
+
The results will include overlap calculations between different
|
|
60
|
+
stages of the segmentation and the reference segmentation. Note that the
|
|
61
|
+
comparison takes into account the heuristic rules specified in the atlas, so
|
|
62
|
+
it is not as simple as computing dice overlaps between the reference seg
|
|
63
|
+
and the ASHS segs.
|
|
64
|
+
-m file Provide the .mat file for the transform between the T1w and T2w image. The file
|
|
65
|
+
is in the format used by ITK-SNAP and C3D and should be generated by performing
|
|
66
|
+
registration with the T2w MRI as fixed image and the T1w MRI as moving.
|
|
67
|
+
By default, the mat file is used as a hint to initialize rigid T2/T1
|
|
68
|
+
but this can be modified with the -M flag.
|
|
69
|
+
-M The mat file provided with -m is used as the final T2/T1 registration.
|
|
70
|
+
ASHS will not attempt to run registration between T2 and T1.
|
|
71
|
+
-t threads Specify number of parallel threads the greedy runs
|
|
72
|
+
-H Tell ASHS to use external hooks for reporting progress, errors, and warnings.
|
|
73
|
+
The environment variables ASHS_HOOK_SCRIPT must be set to point to the appropriate
|
|
74
|
+
script. For an example script with comments, see ashs_default_hook.sh
|
|
75
|
+
The purpose of the hook is to allow intermediary systems (e.g. XNAT)
|
|
76
|
+
to monitor ASHS performance. An optional ASHS_HOOK_DATA variable can be set
and will be forwarded to the script.
-B Do not perform the bootstrapping step, and use the output of the initial joint
label fusion (in multiatlas directory) as the final output.
|
|
80
|
+
"""
|
|
81
|
+
|
|
82
|
+
_ashs_naming_scheme = """
|
|
83
|
+
# ------- input -------
|
|
84
|
+
# input image names
|
|
85
|
+
t1_native_img: "mprage.nii.gz"
|
|
86
|
+
t2_native_img: "tse.nii.gz"
|
|
87
|
+
|
|
88
|
+
# input template names
|
|
89
|
+
template: "template.nii.gz"
|
|
90
|
+
left_roi_file: "left_round_in_global_space_larger.nii.gz"
|
|
91
|
+
right_roi_file: "right_round_in_global_space_larger.nii.gz"
|
|
92
|
+
|
|
93
|
+
# ------- output -------
|
|
94
|
+
# registration intermediate
|
|
95
|
+
t1_name_after_triming_neck: "mprage_necktrim.nii.gz"
|
|
96
|
+
affine_matrix: "t1_to_template_affine_inv.mat"
|
|
97
|
+
t1_whole_img: "mprage_to_tse_warped.nii.gz"
|
|
98
|
+
t2_padded_img: "tse_padded.nii.gz"
|
|
99
|
+
|
|
100
|
+
# registration output
|
|
101
|
+
template_to_3tt1: "template_to_mprage_warped.nii.gz"
|
|
102
|
+
global_roi_in_3tt1_XYZ: "template_roi_XYZ_to_mprage_warped.nii.gz"
|
|
103
|
+
|
|
104
|
+
# preprocessing
|
|
105
|
+
inr_primary: "primary.nii.gz"
|
|
106
|
+
inr_secondary: "secondary.nii.gz"
|
|
107
|
+
inr_seg: "primary_seg.nii.gz"
|
|
108
|
+
seg: "input_primary_seg.nii.gz"
|
|
109
|
+
hyper_primary: "hyper_primary.nii.gz"
|
|
110
|
+
hyper_secondary: "hyper_secondary.nii.gz"
|
|
111
|
+
reg_mat: "auxiluary_to_primary.mat"
|
|
112
|
+
hyper_secondary_after_registertion: "auxiluary_to_primary_registered.nii.gz"
|
|
113
|
+
hyper_primary_seg: "inr_hyper_primary_seg.nii.gz"
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
def main():
    """Entry point of the ``hrashs`` command-line tool.

    Defines the ``list``, ``desc``, ``run`` and ``train`` subcommands. If ``-k`` is
    given, it disables SSL verification for Hugging Face Hub and torch.hub access.
    It then dispatches to the matching handler.

    Returns:
        int: the status returned by ``run_training`` for ``train``; 0 for the
        other commands.

    Raises:
        ValueError: for an unknown command (not reachable through argparse,
            since a subcommand is required).
    """
    parser = argparse.ArgumentParser(
        prog='hrashs',
        description='HyperResASHS: High-Resolution Automatic Segmentation of Hippocampal Subfields',
        epilog="""HyperResASHS (C) 2026 by Yue Li, Paul Yushkevich, and UPenn Patch Lab.
Citation: https://doi.org/10.48550/arXiv.2508.17171
GitHub: https://github.com/liyue3780/HyperResASHS.git""")

    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest='command', required=True, help='Available commands')

    # List atlases subcommand
    list_parser = subparsers.add_parser('list', help='List available HyperResASHS atlases')
    list_parser.add_argument('-l', '--long', action='store_true',
                             help='Show detailed information about each atlas')

    # Describe atlas subcommand. describe_atlas only looks the argument up in the
    # Hugging Face atlas directory; local atlas paths are not handled there.
    describe_parser = subparsers.add_parser('desc', help='Describe a specific HyperResASHS atlas')
    describe_parser.add_argument('atlas', type=str,
                                 help='ID of the atlas to describe (a regular expression, matched '
                                      'case-insensitively against the start of atlas IDs)')

    # Run segmentation subcommand. run_segmentation treats a directory argument as a
    # local atlas containing atlas.yaml, anything else as an atlas ID.
    run_parser = subparsers.add_parser('run', help='Run HyperResASHS segmentation pipeline')
    run_parser.add_argument('-a', '--atlas', type=str, required=True,
                            help='ID of the atlas to use, or path to a local atlas directory '
                                 'containing atlas.yaml')
    run_parser.add_argument('-g', '--t1', type=str, required=True,
                            help='Path to T1-weighted image')
    run_parser.add_argument('-f', '--t2', type=str, required=True,
                            help='Path to T2-weighted image')

    # Training command
    train_parser = subparsers.add_parser('train', help='Train a HyperResASHS model on a dataset')
    train_parser.add_argument('-c', '--config', type=str, required=True,
                              help='Path to training configuration YAML file. See documentation for expected format.')
    train_parser.add_argument('-m', '--manifest', type=str, required=True,
                              help='Path to manifest CSV file describing the training dataset. See documentation for expected format.')
    train_parser.add_argument('-l', '--labels', type=str, required=True,
                              help='Path to ITK-SNAP label description file. See documentation for expected format.')
    train_parser.add_argument('-x', '--xval', type=str, default=None,
                              help='Cross-validation fold specification. See documentation for expected format.')
    train_parser.add_argument('-R', '--inr-random-seed', type=int, default=None,
                              help='Specify random seed for the INR optimization; use to rerun failed INR experiments.')
    train_parser.add_argument('--inr-batch-size', type=int, default=None,
                              help='Specify batch size for the INR optimization. Default: 10000')

    # Stage 5 (finalize) is part of the pipeline run by run_training, so it is listed here too.
    train_parser.add_argument('-s', '--stage', type=str, default='',
                              help='''
        Run one or more selected stages of the training pipeline.
        You can specify single stage as -s 1, a range of stages as -s 1-3,
        or a comma-separated combination such as -s 1,3-4. By default all stages are run.
        The stages are as follows:
        1: Preprocessing (neck trim, registration, patch extraction)
        2: INR training
        3: nnU-Net preparation (resampling, cropping, and formatting for nnU-Net)
        4: nnU-Net training
        5: Finalization of the trained atlas
        ''')
    train_parser.add_argument('-F', '--filter', type=str, metavar='REGEX', default=None,
                              help='''Restrict execution to image(s) that match specified regular expression.
        For stage 1 (preprocessing), REGEX will be matched to subject id and date.
        For stage 2 (INR training), REGEX will be matched to subject id, date, and side.
        For stage 3 (nnU-Net preparation), this is ignored.
        For stage 4 (nnU-Net training), REGEX is matched to fold number (0-4)
        For stage 5 (finalization), this is ignored.
        ''')

    # Add common arguments for run and train subcommands
    for p in (run_parser, train_parser):
        p.add_argument('-w', '--workdir', type=str, required=True,
                       help='Path to working directory')
        p.add_argument('-N', '--no-overwrite', action='store_true',
                       help='Do not overwrite existing results')
        p.add_argument('-t', '--threads', type=int, default=1,
                       help='Number of parallel threads to use [default: 1]')
        p.add_argument('--device', type=str, default='auto',
                       help='''Device to use for computation (e.g. "cuda" or "cpu"). Default: auto-detect
            To select specific GPU(s), use the CUDA_VISIBLE_DEVICES environment variable, e.g.:
            CUDA_VISIBLE_DEVICES=0 hrashs run ...''')
        p.add_argument('-L', '--no-links', action='store_true',
                       help='Do not create symlinks in the working directory; copy files instead. By default, symlinks are created to the input T1 and T2 images in the working directory')
        p.add_argument('-T', '--tidy', action='store_true',
                       help='Tidy mode. Reduce the number of intermediate files generated.')

    # Add -k flag for all relevant sub parsers
    for p in (list_parser, run_parser, describe_parser, train_parser):
        p.add_argument('-k', '--disable-ssl-verification', action='store_true',
                       help='Disable SSL verification for Hugging Face Hub access')

    args = parser.parse_args()

    # Disable SSL verification if requested
    if args.disable_ssl_verification:
        hf_disable_ssl_verification()
        torch_hub_disable_ssl_verification()

    # Handle commands
    if args.command == 'list':
        list_atlases(args)
    elif args.command == 'desc':
        describe_atlas(args)
    elif args.command == 'run':
        run_segmentation(args)
    elif args.command == 'train':
        return run_training(args)
    else:
        raise ValueError(f"Unknown command: {args.command}")

    return 0
|
|
222
|
+
|
|
223
|
+
|
|
224
|
+
def get_atlas_listing(match=None, read_yaml=None):
    """Return a DataFrame describing the atlases published on the Hugging Face Hub.

    The function first reads the list of atlas repositories from
    ``active_atlases.yaml`` in the ``upennpatchlab/hyperresashs_atlas_directory``
    repo. It then reads each repo's ``atlas.yaml``. The result has one row per atlas
    and one column per key of the atlas' ``metadata`` section, plus:

    * ``json``: the full atlas configuration serialized as JSON, and
    * ``repo``: the repo id the atlas was read from.

    An atlas whose metadata lacks a key that other atlases have gets NaN in
    that column.

    Args:
        match: optional regular expression. Only atlases whose metadata ``id``
            matches it (``re.match``, i.e. anchored at the start,
            case-insensitive) are kept.
        read_yaml: optional callable accepting ``repo_id=`` and ``filename=``
            keywords and returning the parsed YAML. Defaults to ``hf_read_yaml``.

    Returns:
        pandas.DataFrame: one row per matching atlas (empty if there are none).
    """
    if read_yaml is None:
        read_yaml = hf_read_yaml

    # Read the list of available atlases from the Hugging Face Hub
    repo_list = read_yaml(
        repo_id="upennpatchlab/hyperresashs_atlas_directory",
        filename="active_atlases.yaml")

    records = []
    for repo in repo_list or []:
        repo_config = read_yaml(repo_id=repo, filename="atlas.yaml")
        metadata = repo_config['metadata']
        if match and not re.match(match, metadata.get('id', ''), re.IGNORECASE):
            continue
        # Serialize before adding the helper keys so 'json' holds the config as published.
        metadata['json'] = json.dumps(repo_config)
        metadata['repo'] = repo
        records.append(metadata)

    # Building from records (rather than a dict of per-key lists) tolerates atlases
    # whose metadata have different keys; missing values become NaN.
    return pd.DataFrame(records)
|
|
244
|
+
|
|
245
|
+
|
|
246
|
+
def print_atlas_listing(long=False, match=None):
    """Print the (optionally filtered) atlas directory to stdout.

    Args:
        long: if True, pretty-print every metadata field of each atlas.
            Otherwise print a two-column table of atlas ID and name.
        match: optional regular expression forwarded to ``get_atlas_listing``
            to restrict the atlases shown.
    """
    df = get_atlas_listing(match=match)

    # Without this guard, an empty listing would fail below with KeyError 'Id'
    # and be reported to the user as a Hub fetch error.
    if df.empty:
        if match:
            print(f"No atlases found matching '{match}'.")
        else:
            print("No atlases found.")
        return

    # The serialized configuration is for programmatic use only; hide it.
    df = df.drop(columns=['json'], errors='ignore')
    df.columns = [c.capitalize().replace('_', ' ') for c in df.columns]

    if long:
        for _, row in df.iterrows():
            pprint.pp(row.to_dict())
    else:
        print(df[['Id', 'Name']].to_string(index=False))
|
|
258
|
+
|
|
259
|
+
|
|
260
|
+
def list_atlases(args):
|
|
261
|
+
"""List available atlases."""
|
|
262
|
+
try:
|
|
263
|
+
print_atlas_listing(long=args.long)
|
|
264
|
+
except Exception as e:
|
|
265
|
+
print(f"Error fetching atlas directory from Hugging Face Hub: {e}")
|
|
266
|
+
return
|
|
267
|
+
|
|
268
|
+
|
|
269
|
+
def describe_atlas(args):
    """Print the full metadata of the atlases whose ID matches ``args.atlas``.

    The match is delegated to ``print_atlas_listing`` in long mode. Errors are
    reported as a message on stdout rather than raised.
    """
    try:
        print_atlas_listing(match=args.atlas, long=True)
    except Exception as exc:
        print(f"Error fetching atlas directory from Hugging Face Hub: {exc}")
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def print_header(metadata):
    """Print the run banner: package and atlas versions plus the papers to cite.

    Args:
        metadata: the atlas ``metadata`` mapping. ``id``, ``version`` and ``name``
            are required. ``citations`` is optional and may be a list of citation
            strings, a single string, or empty/None.

    Raises:
        KeyError: if ``id``, ``version`` or ``name`` is missing.
    """
    main_citation = """\
        Li, Y., Khandelwal, P., Jena, R., Xie, L., Duong, M., Denning, A.E., Brown, C.A.,
        Wisse, L.E., Das, S.R., Wolk, D.A. and Yushkevich, P.A., 2025.
        Achieving detailed medial temporal lobe segmentation with upsampled
        isotropic training from implicit neural representation.
        arXiv preprint arXiv:2508.17171."""
    nnunet_citation = """\
        Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021).
        nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation.
        Nature methods, 18(2), 203-211."""

    # Atlas-specific citations: tolerate a bare string or an empty YAML value, which
    # would otherwise break list concatenation.
    extra_citations = metadata.get('citations') or []
    if isinstance(extra_citations, str):
        extra_citations = [extra_citations]
    citations = [main_citation, nnunet_citation] + list(extra_citations)

    rule = '=' * 80
    print(rule)
    print(f"HyperResASHS {__version__} using atlas {metadata['id']} {metadata['version']}")
    print(f" - {metadata['name']}")
    print()
    print("Please cite the following papers if you use this atlas in your research:")

    for i, ref in enumerate(citations, start=1):
        print(f"\n[{i}] {textwrap.fill(textwrap.dedent(ref), width=74, subsequent_indent=' ' * 5)}")
    print(rule)
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
class Logger(object):
    """Tee everything written to ``sys.stdout``/``sys.stderr`` into a log file.

    Creating a ``Logger`` opens ``filename`` for writing, truncating it. It then
    replaces ``sys.stdout`` and ``sys.stderr`` with :class:`Logger.Stream` wrappers
    that write to both the original stream and the file. :meth:`close` is also
    invoked from ``__del__``; it restores the original streams and closes the file.
    It is safe to call ``close`` more than once.
    """

    class Stream(object):
        """File-like wrapper that duplicates writes to a stream and a log file."""

        def __init__(self, original_stream, log_file):
            self.original_stream = original_stream
            self.log_file = log_file

        def write(self, message):
            self.original_stream.write(message)  # Write to console
            # After the logger is closed, keep echoing to the console instead of
            # raising on the closed file (e.g. for handlers that kept a reference).
            if not self.log_file.closed:
                self.log_file.write(message)  # Write to file

        def flush(self):
            # This flush method is important for proper buffering behavior
            self.original_stream.flush()
            if not self.log_file.closed:
                self.log_file.flush()

        def __getattr__(self, name):
            # Delegate everything else (isatty, fileno, encoding, ...) to the wrapped
            # stream so code that inspects sys.stdout/sys.stderr keeps working.
            # The two own attributes are excluded to avoid infinite recursion on
            # partially constructed instances.
            if name in ('original_stream', 'log_file'):
                raise AttributeError(name)
            return getattr(self.original_stream, name)

    def __init__(self, filename):
        # Explicit UTF-8 so non-ASCII console output cannot fail to encode into the log.
        self.log_file = open(filename, "w", encoding="utf-8")
        self.out = self.Stream(sys.stdout, self.log_file)
        self.err = self.Stream(sys.stderr, self.log_file)
        sys.stdout = self.out  # Redirect all print statements to this object
        sys.stderr = self.err  # Redirect all error messages to this object

    def close(self):
        """Restore the original stdout/stderr and close the log file (idempotent)."""
        # Only restore a stream that is still ours, so a redirection installed later
        # by someone else is not clobbered.
        if sys.stdout is self.out:
            sys.stdout = self.out.original_stream  # Restore original stdout
        if sys.stderr is self.err:
            sys.stderr = self.err.original_stream  # Restore original stderr
        if not self.log_file.closed:
            self.log_file.close()

    def __del__(self):
        # __init__ may have raised (e.g. unwritable log path) before the streams
        # were set up; there is nothing to undo in that case.
        if hasattr(self, 'err'):
            self.close()
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _fetch_template(atlas_config : Dict[str,Any]) -> str:
|
|
337
|
+
# Fetch the template from Hugging Face Hub if specified in the atlas config
|
|
338
|
+
template_hf = atlas_config.get('template', {}).get('hf', None)
|
|
339
|
+
if template_hf:
|
|
340
|
+
template_path = hf.snapshot_download(template_hf)
|
|
341
|
+
else:
|
|
342
|
+
template_path = atlas_config.get('template', {}).get('local', None)
|
|
343
|
+
if template_path is None:
|
|
344
|
+
raise ValueError("No template specified in atlas configuration. Please check the atlas configuration and try again.")
|
|
345
|
+
|
|
346
|
+
return template_path
|
|
347
|
+
|
|
348
|
+
|
|
349
|
+
def _setup_config(atlas_config: Dict[str, Any], args: argparse.Namespace, atlas_local_path: str, training: bool = False):
    """Build the runtime configuration dict expected by the inference/training code.

    The function does three things:

    * Fills run-specific entries into ``atlas_config['config']``: working directory,
      template location, atlas paths, thread counts and file-name scheme location.
    * Writes that config to
      ``<workdir>/config/configtest_<EXP_NUM>_<MODEL_NAME>.yaml``.
    * Copies the packaged ``config_templates/ashs_filename_scheme.yaml`` to
      ``<workdir>/config/ashs_filename_scheme.yaml``.

    NOTE: ``atlas_config['config']`` is updated in place, so the caller's dict also
    sees the added keys.

    Args:
        atlas_config: atlas description with a ``config`` section that contains
            ``EXP_NUM`` and ``MODEL_NAME``, and a ``template`` section used by
            ``_fetch_template``.
        args: parsed CLI arguments; ``workdir`` and ``threads`` are used.
        atlas_local_path: local directory of the atlas; ``itksnap_labels.txt`` is
            expected there.
        training: currently unused.

    Returns:
        dict: the configuration as re-read from the YAML file that was written.

    Raises:
        ValueError: if the atlas configuration specifies no template.
    """
    config_dir = os.path.join(args.workdir, 'config')

    # Set up the atlas configuration the way Yue's code expects it
    config_src = atlas_config['config']
    config_src['TEST_PATH'] = args.workdir
    config_src['TEMPLATE_PATH'] = os.path.join(_fetch_template(atlas_config), 'hyperashs-template')
    config_src['ATLAS_PATH'] = atlas_local_path
    config_src['ITKSNAP_LABEL_FILE'] = os.path.join(atlas_local_path, 'itksnap_labels.txt')
    config_src['GREEDY_NUM_THREADS'] = args.threads
    config_src['NNUNET_NUM_THREADS'] = args.threads
    config_src['FILE_NAME_CONFIG'] = os.path.join(config_dir, 'ashs_filename_scheme.yaml')

    # Write the config to the working directory
    os.makedirs(config_dir, exist_ok=True)
    fn_config_yaml = os.path.join(config_dir, f'configtest_{config_src["EXP_NUM"]}_{config_src["MODEL_NAME"]}.yaml')
    with open(fn_config_yaml, 'wt') as f:
        yaml.dump(config_src, f)

    # Write the preferred naming scheme (shipped with the package) to the working directory
    fn_scheme = files('hyperresashs').joinpath('config_templates/ashs_filename_scheme.yaml')
    with fn_scheme.open('r') as f:
        naming_scheme_text = f.read()
    with open(config_src['FILE_NAME_CONFIG'], 'wt') as f:
        f.write(naming_scheme_text)

    # Re-read the written file so the returned dict is exactly what is on disk.
    # The context manager ensures the handle is closed.
    with open(fn_config_yaml, 'r') as f:
        config = yaml.safe_load(f)

    return config
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
|
|
381
|
+
def _select_atlas(df_meta, atlas):
    """Return the single atlas-listing row that ``atlas`` refers to.

    Args:
        df_meta: listing returned by ``get_atlas_listing(match=atlas)``, i.e. the
            atlases whose ID matches ``atlas`` as a case-insensitive regular
            expression anchored at the start of the ID.
        atlas: the atlas name/pattern given on the command line.

    Raises:
        ValueError: if nothing matches, or if several atlases match and none of
            their IDs equals ``atlas`` exactly (ignoring case).
    """
    if len(df_meta) == 0:
        raise ValueError(f"Error: No atlases found matching '{atlas}'. Please check the atlas name and try again.")
    if len(df_meta) > 1:
        # A prefix match such as 'foo' also hits 'foo_v2'; an exact ID match wins,
        # otherwise an atlas whose ID prefixes another could never be selected.
        if 'id' in df_meta.columns:
            exact = df_meta[df_meta['id'].astype(str).str.lower() == atlas.lower()]
        else:
            exact = df_meta.iloc[0:0]
        if len(exact) != 1:
            raise ValueError(f"Error: Multiple atlases found matching '{atlas}'. Please be more specific.")
        df_meta = exact
    return df_meta.iloc[0]


def run_segmentation(args):
    """Run the HyperResASHS segmentation pipeline for one subject.

    Steps:
      1. Tee stdout/stderr into ``<workdir>/logs/hyperashs_log.txt``.
      2. Resolve the atlas. ``args.atlas`` is either a local directory containing
         ``atlas.yaml``, or an atlas ID/regex looked up in the Hugging Face atlas
         directory, whose repo is then downloaded.
      3. Write the runtime configuration into the working directory.
      4. Link the T1 and T2 images into the working directory as
         ``mprage.nii.gz`` and ``tse.nii.gz``, or copy them with ``--no-links``.
      5. Run inference on the working directory.

    Raises:
        ValueError: if no atlas, or more than one atlas, matches ``args.atlas``.
    """
    # Create a logger. It is held in a local so the output redirection stays
    # active for the rest of this function.
    log_dir = os.path.join(args.workdir, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    logger = Logger(os.path.join(log_dir, 'hyperashs_log.txt'))

    # The atlas may be either a local atlas directory or a Hugging Face atlas ID.
    if os.path.isdir(args.atlas):
        # Local atlas directory: load its atlas.yaml directly
        with open(os.path.join(args.atlas, 'atlas.yaml'), 'r') as f:
            atlas_config = yaml.safe_load(f)
        atlas_local_path = args.atlas

        # Print the header with atlas information and citations
        print_header(atlas_config['metadata'])
    else:
        # Fetch the atlas the user wants
        atlas_row = _select_atlas(get_atlas_listing(match=args.atlas), args.atlas)

        # Read the atlas configuration from the JSON metadata
        atlas_config = json.loads(atlas_row['json'])

        # Print the header with atlas information and citations
        print_header(atlas_config['metadata'])

        # Download the entire atlas snapshot
        atlas_local_path = hf.snapshot_download(atlas_row['repo'])

    # Set up the atlas configuration the way Yue's code expects it
    config = _setup_config(atlas_config, args, atlas_local_path, training=False)

    # Create the inferencer
    tester = HyperASHSInference(config)

    # Create links (or copies) of the input images in the working directory
    create_links, overwrite_existing = not args.no_links, not args.no_overwrite
    for img_type, img_path in (('mprage', args.t1), ('tse', args.t2)):
        dest = os.path.join(args.workdir, f'{img_type}.nii.gz')
        copy_or_link_file(img_path, dest, create_links=create_links, force_overwrite=overwrite_existing, relative_links=False)

    # Run the segmentation
    print('-' * 60)
    print("Running HyperResASHS with:")
    print(f" Atlas: {args.atlas}")
    print(f" T1: {args.t1}")
    print(f" T2: {args.t2}")
    print(f" Workdir: {args.workdir}")
    print('-' * 60)
    tester.run_inference_for_one_case(case_path=args.workdir,
                                      save_intermediates=args.tidy is False,
                                      overwrite_existing=overwrite_existing,
                                      create_links=create_links,
                                      device=args.device)
|
|
440
|
+
|
|
441
|
+
|
|
442
|
+
# Stage numbers understood by run_training (see the --stage help text).
_TRAINING_STAGES = (1, 2, 3, 4, 5)


def _parse_stage_spec(spec):
    """Parse a ``--stage`` value into the set of stage numbers to run.

    Accepts a single stage (``"2"``), an inclusive range (``"1-3"``), or a
    comma-separated combination (``"1,3-4"``). An empty value selects all stages.

    Raises:
        ValueError: if the value is malformed, a range is reversed, or a stage
            number is not in ``_TRAINING_STAGES``.
    """
    if not spec:
        return set(_TRAINING_STAGES)

    stages = set()
    for part in spec.split(','):
        if '-' in part:
            start_txt, end_txt = part.split('-', 1)
            start, end = int(start_txt), int(end_txt)
            if start > end:
                raise ValueError(f"reversed stage range '{part.strip()}'")
            stages.update(range(start, end + 1))
        else:
            stages.add(int(part))

    unknown = sorted(stages.difference(_TRAINING_STAGES))
    if unknown:
        raise ValueError(f"unknown stage(s) {unknown}; valid stages are {list(_TRAINING_STAGES)}")
    return stages


def run_training(args) -> int:
    """Run all or part of the HyperResASHS training pipeline.

    Stages are selected with ``--stage`` (all by default):
    1 preprocessing, 2 INR training, 3 nnU-Net preparation, 4 nnU-Net training,
    5 finalization.

    Returns:
        int: the exit status.

        * 0 on success.
        * 1 if a stage's validity check fails before a later requested stage.
        * -1 if the ``--stage`` value is invalid or the training configuration
          cannot be loaded.
    """
    # Validate the stage selection before touching the working directory, so a typo
    # such as "-s 6" or "-s 3-1" is reported instead of silently running nothing.
    try:
        stage_no = _parse_stage_spec(args.stage)
    except ValueError as e:
        print(f"Error: invalid --stage value '{args.stage}': {e}")
        return -1

    create_links, overwrite_existing = not args.no_links, not args.no_overwrite

    # Create a logger. It is held in a local so the output redirection stays
    # active for the rest of this function.
    log_dir = os.path.join(args.workdir, 'logs')
    os.makedirs(log_dir, exist_ok=True)
    logger = Logger(os.path.join(log_dir, 'hyperashs_log.txt'))

    # Load the training configuration from the specified YAML file
    try:
        with open(args.config, 'r') as f:
            atlas_config = yaml.safe_load(f)

        # Print the header with atlas information and citations
        print_header(atlas_config['metadata'])
    except Exception as e:
        print(f"Error loading atlas configuration: {e}")
        return -1

    # Set up the atlas configuration the way Yue's code expects it
    config = _setup_config(atlas_config, args, args.workdir, training=True)

    # Create the training object
    trainer = HyperASHSTraining(config,
                                manifest_file=args.manifest,
                                label_file=args.labels,
                                xval_file=args.xval,
                                output_dir=args.workdir,
                                overwrite_existing=overwrite_existing,
                                save_intermediates=args.tidy is False,
                                create_links=create_links)

    # (stage number, function, keyword arguments, optional validity check)
    stages = [
        (1, trainer.preprocess, {'filter': args.filter}, None),
        (2, trainer.train_inr, {'filter': args.filter, 'device': args.device, 'random_seed': args.inr_random_seed, 'batch_size': args.inr_batch_size}, trainer.validity_check_inr_results),
        (3, trainer.prepare_nnunet, {}, None),
        (4, trainer.train_nnunet, {'filter': args.filter, 'device': args.device}, trainer.validity_check_nnunet_results),
        (5, trainer.finalize, {'full_metadata': atlas_config}, None),
    ]

    last_stage = max(stage_no)
    for i_stage, stage_func, stage_kwargs, validity_check in stages:
        if i_stage not in stage_no:
            continue

        stage_func(**stage_kwargs)

        # NOTE(review): the validity check is skipped for the last requested stage,
        # so e.g. "-s 2" exits 0 even if some INR results are missing. This is
        # presumably deliberate, so that filtered (-F) partial runs are not flagged
        # as failures; confirm.
        if i_stage == last_stage:
            return 0

        # Perform validity check at the end of this stage. This prevents moving on
        # to the next stage if there were failures in the current stage.
        if validity_check is not None and validity_check() is False:
            return 1

    return 0
|