hyperresashs 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ import sys
2
+ import os
3
+ from importlib.metadata import version, PackageNotFoundError
4
+
5
# Drop any already-imported nnunetv2 modules so that subsequent imports
# resolve to the copy vendored under ``submodules/nnUNet``.
to_remove = [key for key in sys.modules if key == "nnunetv2" or key.startswith("nnunetv2.")]
for key in to_remove:
    # Diagnostic goes to stderr so it does not pollute CLI output on stdout.
    print(f'hyperresashs found existing nnunetv2 module {key} in sys.modules. Removing it to avoid conflicts with hyperresashs submodule nnunetv2.',
          file=sys.stderr)
    del sys.modules[key]


def _prepend_sys_path(path):
    """Move *path* to the front of ``sys.path`` so it takes import priority.

    Existing occurrences are removed first, so re-importing or reloading the
    package does not keep adding duplicate entries to ``sys.path``.
    """
    while path in sys.path:
        sys.path.remove(path)
    sys.path.insert(0, path)


# Vendored submodules. They are prepended in this order, so the
# multi_contrast_inr directory ends up ahead of nnUNet on sys.path.
_submodule_path = os.path.join(os.path.dirname(__file__), "submodules", "nnUNet")
_prepend_sys_path(_submodule_path)

_submodule_path_inr = os.path.join(os.path.dirname(__file__), "submodules", "multi_contrast_inr")
_prepend_sys_path(_submodule_path_inr)

# Resolve the version of the installed "hyperresashs" distribution.
try:
    __version__ = version("hyperresashs")
except PackageNotFoundError:
    # Running from a source tree without the package installed.
    __version__ = "unknown"
@@ -0,0 +1,2 @@
1
import sys

from .ashs_cli import main

if __name__ == "__main__":
    # main() returns the process exit status (e.g. non-zero when a training
    # stage fails); forward it instead of discarding it.
    sys.exit(main())
@@ -0,0 +1,511 @@
1
+ from .ashs_inference import HyperASHSInference
2
+ from .utils.huggingface import hf_disable_ssl_verification, hf_read_yaml, torch_hub_disable_ssl_verification
3
+ from .ashs_training import HyperASHSTraining
4
+ from .utils.tool import copy_or_link_file
5
+ from . import __version__
6
+ import argparse
7
+ import huggingface_hub as hf
8
+ import pandas as pd
9
+ import re
10
+ import pprint
11
+ import json
12
+ import textwrap
13
+ import os
14
+ import shutil
15
+ import yaml
16
+ import sys
17
+ from typing import Dict, Any
18
+ from importlib.resources import files
19
+
20
+ """
21
+ Original ASHS command line help message for reference:
22
+ required options:
23
+ -a dir Location of the atlas directory. Can be a full pathname or a
24
+ relative directory name under ASHS_ROOT/data directory.
25
+ -g image Filename of 3D (g)radient echo MRI (ASHS_MPRAGE, T1w)
26
+ -f image Filename of 2D focal (f)ast spin echo MRI (ASHS_TSE, T2w)
27
+ -w path Working/output directory
28
+
29
+ optional:
30
+ -d Enable debugging
31
+ -h Print help
32
+ -s integer Run only one stage (see below); also accepts range (e.g. -s 1-3)
33
+ -N No overriding of registration results. If a result from an earlier run
34
+ exists, don't run greedy again.
35
+ -G Use template brain mask in T1 template rigid registration
36
+ -T Tidy mode. Cleans up files once they are unneeded. The -N option will
37
+ have no effect in tidy mode, because intermediate results will be erased.
38
+ -I string Subject ID (for stats output). Defaults to last word of working dir.
39
+ -V Display version information and exit
40
+ -C file Configuration file. If not passed, uses $ASHS_ROOT/bin/ashs_config.sh
41
+ -Q Use Sun Grid Engine (SGE) to schedule sub-tasks in each stage.
42
+ By default, the whole ashs_main job runs in a single process.
43
+ If you are doing a lot of segmentations and have SGE, it is better to
44
+ run each segmentation (ashs_main) in a separate SGE job, rather than use the -q flag.
45
+ The -q flag is best for when you have only a few segmentations and want them to run fast.
46
+ -P Use GNU parallel to run on multiple cores on the local machine. You need to
47
+ have GNU parallel installed.
48
+ -S Use SLURM instead of SGE, LSF or GNU parallel
49
+ -l Use LSF instead of SGE, SLURM or GNU parallel
50
+ -q OPTS Pass in additional options to SGE/SLURM/LSF/GNU Parallel. If -S -B or -P not specified
51
+ turns on SGE.
52
+ -z script Provide a path to an executable script that will be used to retrieve SGE, LSF, SLURM or
53
+ GNU parallel options for different stages of ASHS. Takes precedence over -q
54
+ -r files Compare segmentation results with a reference segmentation. The parameter
55
+ files should consist of two nifti files in quotation marks:
56
+
57
+ -r "ref_seg_left.nii.gz ref_seg_right.nii.gz"
58
+
59
+ The results will include overlap calculations between different
60
+ stages of the segmentation and the reference segmentation. Note that the
61
+ comparison takes into account the heuristic rules specified in the atlas, so
62
+ it is not as simple as computing dice overlaps between the reference seg
63
+ and the ASHS segs.
64
+ -m file Provide the .mat file for the transform between the T1w and T2w image. The file
65
+ is in the format used by ITK-SNAP and C3D and should be generated by performing
66
+ registration with the T2w MRI as fixed image and the T1w MRI as moving.
67
+ By default, the mat file is used as a hint to initialize rigid T2/T1
68
+ but this can be modified with the -M flag.
69
+ -M The mat file provided with -m is used as the final T2/T1 registration.
70
+ ASHS will not attempt to run registration between T2 and T1.
71
+ -t threads Specify number of parallel threads the greedy runs
72
+ -H Tell ASHS to use external hooks for reporting progress, errors, and warnings.
73
+ The environment variables ASHS_HOOK_SCRIPT must be set to point to the appropriate
74
+ script. For an example script with comments, see ashs_default_hook.sh
75
+ The purpose of the hook is to allow intermediary systems (e.g. XNAT)
76
+ to monitor ASHS performance. An optional ASHS_HOOK_DATA variable can be set
77
+ -B Do not perform the bootstrapping step, and use the output of the initial joint
78
+ label fusion (in multiatlas directory) as the final output.
79
+ and will be forwarded to the script
80
+ """
81
+
82
# YAML mapping from logical file roles (input images, registration
# intermediates/outputs, preprocessing products) to the filenames used in the
# working directory.
# NOTE(review): nothing in this module reads this constant. `_setup_config`
# loads the scheme from the packaged `config_templates/ashs_filename_scheme.yaml`
# resource instead, into a local variable of the same name. Confirm whether
# other modules import it before removing it.
# The misspellings in some keys/filenames ("triming", "auxiluary",
# "registertion") are runtime identifiers and must stay as-is.
_ashs_naming_scheme = """
# ------- input -------
# input image names
t1_native_img: "mprage.nii.gz"
t2_native_img: "tse.nii.gz"

# input template names
template: "template.nii.gz"
left_roi_file: "left_round_in_global_space_larger.nii.gz"
right_roi_file: "right_round_in_global_space_larger.nii.gz"

# ------- output -------
# registration intermediate
t1_name_after_triming_neck: "mprage_necktrim.nii.gz"
affine_matrix: "t1_to_template_affine_inv.mat"
t1_whole_img: "mprage_to_tse_warped.nii.gz"
t2_padded_img: "tse_padded.nii.gz"

# registration output
template_to_3tt1: "template_to_mprage_warped.nii.gz"
global_roi_in_3tt1_XYZ: "template_roi_XYZ_to_mprage_warped.nii.gz"

# preprocessing
inr_primary: "primary.nii.gz"
inr_secondary: "secondary.nii.gz"
inr_seg: "primary_seg.nii.gz"
seg: "input_primary_seg.nii.gz"
hyper_primary: "hyper_primary.nii.gz"
hyper_secondary: "hyper_secondary.nii.gz"
reg_mat: "auxiluary_to_primary.mat"
hyper_secondary_after_registertion: "auxiluary_to_primary_registered.nii.gz"
hyper_primary_seg: "inr_hyper_primary_seg.nii.gz"
"""
115
+
116
+
117
def main():
    """Entry point for the ``hrashs`` command-line tool.

    Builds the argument parser (subcommands ``list``, ``desc``, ``run`` and
    ``train``), optionally disables SSL verification for Hugging Face Hub and
    torch hub access, and dispatches to the handler for the chosen subcommand.

    Returns:
        int: process exit status. A handler that returns an integer supplies
        the status (``run_training`` does); a handler returning ``None`` maps
        to 0.

    Raises:
        ValueError: if an unknown subcommand reaches dispatch (argparse
            normally rejects these first).
    """
    parser = argparse.ArgumentParser(
        prog='hrashs',
        description='HyperResASHS: High-Resolution Automatic Segmentation of Hippocampal Subfields',
        epilog="""HyperResASHS (C) 2026 by Yue Li, Paul Yushkevich, and UPenn Patch Lab.
Citation: https://doi.org/10.48550/arXiv.2508.17171
GitHub: https://github.com/liyue3780/HyperResASHS.git""")
    parser.add_argument('-V', '--version', action='version',
                        version=f'%(prog)s {__version__}')

    # Create subparsers for different commands
    subparsers = parser.add_subparsers(dest='command', required=True, help='Available commands')

    # List atlases subcommand
    list_parser = subparsers.add_parser('list', help='List available HyperResASHS atlases')
    list_parser.add_argument('-l', '--long', action='store_true',
                             help='Show detailed information about each atlas')

    # Describe atlas subcommand. The atlas is looked up on the Hub only.
    describe_parser = subparsers.add_parser('desc', help='Describe a specific HyperResASHS atlas')
    describe_parser.add_argument('atlas', type=str,
                                 help='Name of the atlas to describe (matched against atlas ids '
                                      'as a case-insensitive regular expression)')

    # Run segmentation subcommand
    run_parser = subparsers.add_parser('run', help='Run HyperResASHS segmentation pipeline')
    run_parser.add_argument('-a', '--atlas', type=str, required=True,
                            help='Name of the atlas to use, or path to a local atlas directory containing atlas.yaml')
    run_parser.add_argument('-g', '--t1', type=str, required=True,
                            help='Path to T1-weighted image')
    run_parser.add_argument('-f', '--t2', type=str, required=True,
                            help='Path to T2-weighted image')

    # Training command
    train_parser = subparsers.add_parser('train', help='Train a HyperResASHS model on a dataset')
    train_parser.add_argument('-c', '--config', type=str, required=True,
                              help='Path to training configuration YAML file. See documentation for expected format.')
    train_parser.add_argument('-m', '--manifest', type=str, required=True,
                              help='Path to manifest CSV file describing the training dataset. See documentation for expected format.')
    train_parser.add_argument('-l', '--labels', type=str, required=True,
                              help='Path to ITK-SNAP label description file. See documentation for expected format.')
    train_parser.add_argument('-x', '--xval', type=str, default=None,
                              help='Cross-validation fold specification. See documentation for expected format.')
    train_parser.add_argument('-R', '--inr-random-seed', type=int, default=None,
                              help='Specify random seed for the INR optimization; use to rerun failed INR experiments.')
    train_parser.add_argument('--inr-batch-size', type=int, default=None,
                              help='Specify batch size for the INR optimization. Default: 10000')

    train_parser.add_argument('-s', '--stage', type=str, default='',
                              help='''
Run one or more selected stages of the training pipeline.
You can specify a single stage as -s 1, a range of stages as -s 1-3,
or a comma-separated combination such as -s 1,3-4.
If omitted, all stages (1-5) are run.
The stages are as follows:
1: Preprocessing (neck trim, registration, patch extraction)
2: INR training
3: nnU-Net preparation (resampling, cropping, and formatting for nnU-Net)
4: nnU-Net training
5: Finalization
''')
    train_parser.add_argument('-F', '--filter', type=str, metavar='REGEX', default=None,
                              help='''Restrict execution to image(s) that match specified regular expression.
For stage 1 (preprocessing), REGEX will be matched to subject id and date.
For stage 2 (INR training), REGEX will be matched to subject id, date, and side.
For stage 3 (nnU-Net preparation), this is ignored.
For stage 4 (nnU-Net training), REGEX is matched to fold number (0-4).
For stage 5 (finalization), this is ignored.
''')

    # Add common arguments for run and train subcommands
    for p in [run_parser, train_parser]:
        p.add_argument('-w', '--workdir', type=str, required=True,
                       help='Path to working directory')
        p.add_argument('-N', '--no-overwrite', action='store_true',
                       help='Do not overwrite existing results')
        p.add_argument('-t', '--threads', type=int, default=1,
                       help='Number of parallel threads to use for segmentation [default: 1]')
        p.add_argument('--device', type=str, default='auto',
                       help='''Device to use for segmentation (e.g. "cuda" or "cpu"). Default: auto-detect
To select specific GPU(s), use the CUDA_VISIBLE_DEVICES environment variable, e.g.:
CUDA_VISIBLE_DEVICES=0 hrashs run ...''')
        p.add_argument('-L', '--no-links', action='store_true',
                       help='Do not create symlinks in the working directory; copy files instead. By default, symlinks are created to the input T1 and T2 images in the working directory')
        p.add_argument('-T', '--tidy', action='store_true',
                       help='Tidy mode. Reduce the number of intermediate files generated.')

    # Add -k flag for all relevant sub parsers
    for p in [list_parser, run_parser, describe_parser, train_parser]:
        p.add_argument('-k', '--disable-ssl-verification', action='store_true',
                       help='Disable SSL verification for Hugging Face Hub access')

    args = parser.parse_args()

    # Disable SSL verification if requested
    if args.disable_ssl_verification:
        hf_disable_ssl_verification()
        torch_hub_disable_ssl_verification()

    handlers = {
        'list': list_atlases,
        'desc': describe_atlas,
        'run': run_segmentation,
        'train': run_training,
    }
    handler = handlers.get(args.command)
    if handler is None:
        raise ValueError(f"Unknown command: {args.command}")

    # Forward integer statuses (e.g. training failures) to the caller so the
    # process exit code reflects them; handlers returning None mean success.
    status = handler(args)
    return 0 if status is None else status
222
+
223
+
224
def get_atlas_listing(match=None):
    """Return a DataFrame describing the atlases published on the Hugging Face Hub.

    The repositories to inspect are read from ``active_atlases.yaml`` in the
    ``upennpatchlab/hyperresashs_atlas_directory`` repo. For each one, its
    ``atlas.yaml`` is fetched and its ``metadata`` section becomes one row.

    Args:
        match: optional regular expression. It is matched case-insensitively
            with ``re.match`` (anchored at the start) against each atlas's
            metadata ``id``. Non-matching atlases are skipped.

    Returns:
        pandas.DataFrame with one row per atlas and one column per metadata key.
        Two extra columns are added: ``json``, the full atlas config serialized
        as JSON, and ``repo``, the Hub repo id. Keys absent from some atlases
        appear as NaN in those rows.
    """
    repo_list = hf_read_yaml(
        repo_id="upennpatchlab/hyperresashs_atlas_directory",
        filename="active_atlases.yaml")

    rows = []
    # An empty directory file parses to None; treat it as "no atlases".
    for repo in repo_list or []:
        repo_config = hf_read_yaml(repo_id=repo, filename="atlas.yaml")
        metadata = repo_config['metadata']
        if match and not re.match(match, str(metadata.get('id', '')), re.IGNORECASE):
            continue
        # Serialize before adding the extra keys, so 'json' does not contain them.
        metadata['json'] = json.dumps(repo_config)
        metadata['repo'] = repo
        rows.append(dict(metadata))

    # Building from row dicts (instead of parallel per-column lists) tolerates
    # atlases whose metadata omit some keys: the missing cells become NaN rather
    # than pandas raising on columns of unequal length.
    return pd.DataFrame(rows)
244
+
245
+
246
def print_atlas_listing(long=False, match=None):
    """Print the atlases available on the Hugging Face Hub.

    Args:
        long: if True, pretty-print every metadata field of each atlas.
            Otherwise print a compact table of the ``Id`` and ``Name`` columns.
        match: optional regular expression forwarded to ``get_atlas_listing``
            to restrict which atlases are shown.
    """
    df = get_atlas_listing(match=match)

    # An empty listing has no columns at all. Report it explicitly instead of
    # failing on the column selection below (or printing nothing in long mode).
    if df.empty:
        print("No matching atlases found." if match else "No atlases found.")
        return

    df = df[[c for c in df.columns if c != 'json']]  # the raw JSON is not for display
    df.columns = [c.capitalize().replace('_', ' ') for c in df.columns]

    if long:
        for _, row in df.iterrows():
            pprint.pp(row.to_dict())
    else:
        # Tolerate listings where no atlas provides an id or a name.
        summary_cols = [c for c in ('Id', 'Name') if c in df.columns]
        print(df[summary_cols].to_string(index=False))
258
+
259
+
260
def list_atlases(args):
    """Handle ``hrashs list``: print the atlases published on the Hub.

    Args:
        args: parsed CLI namespace; ``args.long`` selects the detailed view.

    Returns:
        int: 0 on success. Returns 1 if the atlas directory could not be
        fetched or displayed; the error message is written to stderr.
    """
    try:
        print_atlas_listing(long=args.long)
    except Exception as e:
        # CLI boundary: report the failure without a traceback, but signal it
        # through the return status instead of silently succeeding.
        print(f"Error fetching atlas directory from Hugging Face Hub: {e}", file=sys.stderr)
        return 1
    return 0
267
+
268
+
269
def describe_atlas(args):
    """Handle ``hrashs desc``: print full metadata for atlases matching ``args.atlas``.

    ``args.atlas`` is used as a case-insensitive regular expression against
    atlas ids on the Hugging Face Hub. Local paths are not consulted here.

    Returns:
        int: 0 on success, or 1 on failure. The error message is written to
        stderr.
    """
    try:
        print_atlas_listing(long=True, match=args.atlas)
    except re.error as e:
        # A malformed pattern is a user error, not a Hub connectivity problem.
        print(f"Invalid atlas name pattern '{args.atlas}': {e}", file=sys.stderr)
        return 1
    except Exception as e:
        print(f"Error fetching atlas directory from Hugging Face Hub: {e}", file=sys.stderr)
        return 1
    return 0
277
+
278
+
279
def print_header(metadata):
    """Print the run banner: versions, atlas name and the papers to cite.

    Args:
        metadata: the ``metadata`` section of an atlas config. The keys ``id``,
            ``version`` and ``name`` are required. The entries of an optional
            ``citations`` list are appended after the built-in HyperResASHS
            and nnU-Net references.
    """
    rule = '=' * 80

    hyperresashs_ref = """\
        Li, Y., Khandelwal, P., Jena, R., Xie, L., Duong, M., Denning, A.E., Brown, C.A.,
        Wisse, L.E., Das, S.R., Wolk, D.A. and Yushkevich, P.A., 2025.
        Achieving detailed medial temporal lobe segmentation with upsampled
        isotropic training from implicit neural representation.
        arXiv preprint arXiv:2508.17171."""
    nnunet_ref = """\
        Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2021).
        nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation.
        Nature methods, 18(2), 203-211."""
    references = [hyperresashs_ref, nnunet_ref] + metadata.get('citations', [])

    print(rule)
    print(f"HyperResASHS {__version__} using atlas {metadata['id']} {metadata['version']}")
    print(f" - {metadata['name']}")
    print()
    print("Please cite the following papers if you use this atlas in your research:")

    for number, reference in enumerate(references, start=1):
        # Dedent first so source indentation does not leak into the wrapped text.
        wrapped = textwrap.fill(textwrap.dedent(reference), width=74, subsequent_indent=' ' * 5)
        print(f"\n[{number}] {wrapped}")
    print(rule)
302
+
303
+
304
class Logger(object):
    """Tee ``sys.stdout`` and ``sys.stderr`` into a log file.

    Creating a Logger opens *filename* for writing (truncating it). It then
    replaces ``sys.stdout`` and ``sys.stderr`` with wrappers that send every
    message to both the original stream and the file. ``close()`` restores
    the original streams and closes the file. ``__del__`` also calls it, and
    calling it more than once is harmless.
    """

    class Stream(object):
        """File-like wrapper that duplicates writes to a console stream and a log file."""

        def __init__(self, original_stream, log_file):
            self.original_stream = original_stream
            self.log_file = log_file

        def write(self, message):
            self.original_stream.write(message)  # Write to console
            # Something may have captured this wrapper while redirection was
            # active (e.g. a logging handler) and keep writing after close().
            if not self.log_file.closed:
                self.log_file.write(message)  # Write to file

        def flush(self):
            # This flush method is important for proper buffering behavior
            self.original_stream.flush()
            if not self.log_file.closed:
                self.log_file.flush()

        def __getattr__(self, name):
            # Delegate everything else (isatty, fileno, encoding, ...) to the
            # wrapped console stream; third-party code commonly probes
            # sys.stdout / sys.stderr for these attributes.
            original = self.__dict__.get('original_stream')
            if original is None:
                raise AttributeError(name)
            return getattr(original, name)

    def __init__(self, filename):
        # Mark as closed until fully set up, so close()/__del__ are no-ops if
        # open() below raises.
        self._closed = True
        self.log_file = open(filename, "w", encoding="utf-8")
        self.out = self.Stream(sys.stdout, self.log_file)
        self.err = self.Stream(sys.stderr, self.log_file)
        sys.stdout = self.out  # Redirect all print statements to this object
        sys.stderr = self.err  # Redirect all error messages to this object
        self._closed = False

    def close(self):
        """Restore the original streams and close the log file (idempotent)."""
        if getattr(self, '_closed', True):
            return
        self._closed = True
        # Only restore if our wrappers are still installed, so a redirection
        # set up after ours is not clobbered.
        if sys.stdout is self.out:
            sys.stdout = self.out.original_stream
        if sys.stderr is self.err:
            sys.stderr = self.err.original_stream
        self.log_file.close()

    def __del__(self):
        self.close()
334
+
335
+
336
+ def _fetch_template(atlas_config : Dict[str,Any]) -> str:
337
+ # Fetch the template from Hugging Face Hub if specified in the atlas config
338
+ template_hf = atlas_config.get('template', {}).get('hf', None)
339
+ if template_hf:
340
+ template_path = hf.snapshot_download(template_hf)
341
+ else:
342
+ template_path = atlas_config.get('template', {}).get('local', None)
343
+ if template_path is None:
344
+ raise ValueError("No template specified in atlas configuration. Please check the atlas configuration and try again.")
345
+
346
+ return template_path
347
+
348
+
349
def _setup_config(atlas_config: Dict[str, Any], args: argparse.Namespace, atlas_local_path: str, training: bool = False):
    """Write the pipeline configuration into ``<workdir>/config`` and load it back.

    Steps:
      1. Fill run-specific entries into ``atlas_config['config']``: working
         dir, template and atlas paths, label file, thread counts and the
         filename-scheme path.
      2. Write that section to ``configtest_<EXP_NUM>_<MODEL_NAME>.yaml``.
      3. Copy the packaged filename scheme next to it.
      4. Return the config as re-read from the written YAML file.

    Note: ``atlas_config['config']`` is updated in place, so callers that keep
    using ``atlas_config`` afterwards see the run-specific entries.

    Args:
        atlas_config: parsed atlas configuration containing a ``config`` section.
        args: parsed CLI namespace; ``workdir`` and ``threads`` are read.
        atlas_local_path: atlas directory (``ATLAS_PATH``). The
            ``itksnap_labels.txt`` label file is expected inside it.
        training: not used by this function; kept for interface compatibility.

    Returns:
        dict: the configuration loaded from the written YAML file.
    """
    config_dir = os.path.join(args.workdir, 'config')

    # Set up the atlas configuration the way the pipeline code expects it
    config_src = atlas_config['config']
    config_src['TEST_PATH'] = args.workdir
    config_src['TEMPLATE_PATH'] = os.path.join(_fetch_template(atlas_config), 'hyperashs-template')
    config_src['ATLAS_PATH'] = atlas_local_path
    config_src['ITKSNAP_LABEL_FILE'] = os.path.join(atlas_local_path, 'itksnap_labels.txt')
    config_src['GREEDY_NUM_THREADS'] = args.threads
    config_src['NNUNET_NUM_THREADS'] = args.threads
    config_src['FILE_NAME_CONFIG'] = os.path.join(config_dir, 'ashs_filename_scheme.yaml')

    # Write the config to the working directory
    os.makedirs(config_dir, exist_ok=True)
    fn_config_yaml = os.path.join(config_dir, f'configtest_{config_src["EXP_NUM"]}_{config_src["MODEL_NAME"]}.yaml')
    with open(fn_config_yaml, 'wt') as f:
        yaml.dump(config_src, f)

    # Copy the packaged naming scheme into the working directory. A distinct
    # local name avoids shadowing the module-level _ashs_naming_scheme.
    scheme_resource = files('hyperresashs').joinpath('config_templates/ashs_filename_scheme.yaml')
    with scheme_resource.open('r') as f:
        naming_scheme = f.read()
    with open(config_src['FILE_NAME_CONFIG'], 'wt') as f:
        f.write(naming_scheme)

    # Re-read the written file; the context manager ensures the handle is closed.
    with open(fn_config_yaml, 'r') as f:
        config = yaml.safe_load(f)

    return config
378
+
379
+
380
+
381
def run_segmentation(args):
    """Handle ``hrashs run``: segment one T1/T2 image pair with a HyperResASHS atlas.

    ``args.atlas`` is either a local directory containing ``atlas.yaml`` or a
    case-insensitive regular expression matched against atlas ids on the
    Hugging Face Hub. If the pattern matches several atlases, one whose id
    equals ``args.atlas`` (ignoring case) is selected.

    The input images are linked or copied into ``args.workdir`` as
    ``mprage.nii.gz`` (T1) and ``tse.nii.gz`` (T2). Console output is also
    written to ``<workdir>/logs/hyperashs_log.txt``.

    Raises:
        ValueError: if no atlas is found on the Hub, or if several are found
            and none has an exact id match.
    """
    logs_dir = os.path.join(args.workdir, 'logs')
    os.makedirs(logs_dir, exist_ok=True)
    # Keep a reference for the duration of the run: the Logger tees
    # stdout/stderr into the log file until it is closed or garbage-collected.
    logger = Logger(os.path.join(logs_dir, 'hyperashs_log.txt'))

    # The atlas may be either a local atlas directory or a Hugging Face Hub id.
    if os.path.isdir(args.atlas):
        with open(os.path.join(args.atlas, 'atlas.yaml'), 'r') as f:
            atlas_config = yaml.safe_load(f)
        atlas_local_path = args.atlas

        # Print the header with atlas information and citations
        print_header(atlas_config['metadata'])
    else:
        df_meta = get_atlas_listing(match=args.atlas)

        # re.match is a prefix match, so an exact id such as "foo" would also
        # match "foo_v2". Prefer the exact id when it disambiguates.
        if len(df_meta) > 1 and 'id' in df_meta.columns:
            exact = df_meta[df_meta['id'].astype(str).str.lower() == args.atlas.lower()]
            if len(exact) == 1:
                df_meta = exact

        if len(df_meta) == 0:
            raise ValueError(f"Error: No atlases found matching '{args.atlas}'. Please check the atlas name and try again.")
        elif len(df_meta) > 1:
            raise ValueError(f"Error: Multiple atlases found matching '{args.atlas}'. Please be more specific.")

        # Read the atlas configuration from the JSON metadata
        atlas_config = json.loads(df_meta.iloc[0]['json'])

        # Print the header with atlas information and citations
        print_header(atlas_config['metadata'])

        # Download the entire atlas snapshot
        atlas_local_path = hf.snapshot_download(df_meta.iloc[0]['repo'])

    # Set up the atlas configuration the way the pipeline code expects it
    config = _setup_config(atlas_config, args, atlas_local_path, training=False)

    # Create the inferencer
    tester = HyperASHSInference(config)

    # Create links (or copies) of the inputs in the working directory
    create_links, overwrite_existing = not args.no_links, not args.no_overwrite
    for img_type, img_path in [('mprage', args.t1), ('tse', args.t2)]:
        dest = os.path.join(args.workdir, f'{img_type}.nii.gz')
        copy_or_link_file(img_path, dest, create_links=create_links, force_overwrite=overwrite_existing, relative_links=False)

    # Run the segmentation
    print('-' * 60)
    print("Running HyperResASHS with:")
    print(f" Atlas: {args.atlas}")
    print(f" T1: {args.t1}")
    print(f" T2: {args.t2}")
    print(f" Workdir: {args.workdir}")
    print('-' * 60)
    tester.run_inference_for_one_case(case_path=args.workdir,
                                      save_intermediates=not args.tidy,
                                      overwrite_existing=overwrite_existing,
                                      create_links=create_links,
                                      device=args.device)
440
+
441
+
442
def _parse_stages(spec, valid_stages):
    """Parse a ``-s/--stage`` specification into a set of stage numbers.

    *spec* is a comma-separated list of single stages (``"2"``) and inclusive
    ranges (``"1-3"``). An empty *spec* selects every stage in *valid_stages*.

    Raises:
        ValueError: if *spec* is malformed, selects no stage, or names a stage
            that is not in *valid_stages*.
    """
    valid = set(valid_stages)
    if not spec:
        return valid

    selected = set()
    for part in spec.split(','):
        if '-' in part:
            start, end = map(int, part.split('-'))
            selected.update(range(start, end + 1))
        else:
            selected.add(int(part))

    if not selected:
        raise ValueError(f"stage specification '{spec}' selects no stages")
    unknown = selected - valid
    if unknown:
        raise ValueError(f"unknown stage(s) {sorted(unknown)}; valid stages are {sorted(valid)}")
    return selected


def run_training(args) -> int:
    """Handle ``hrashs train``: run the selected stages of the training pipeline.

    Stages are:
      1. preprocessing
      2. INR training
      3. nnU-Net preparation
      4. nnU-Net training
      5. finalization

    After each requested stage except the highest one, the stage's validity
    check (if any) is run. A failure stops the pipeline.

    Returns:
        int: 0 when the requested stages complete, 1 when a validity check
        fails, and -1 when the configuration file cannot be loaded or the
        ``--stage`` specification is invalid.
    """
    create_links, overwrite_existing = not args.no_links, not args.no_overwrite

    # Create a logger (tees stdout/stderr into the log file while referenced)
    logs_dir = os.path.join(args.workdir, 'logs')
    os.makedirs(logs_dir, exist_ok=True)
    logger = Logger(os.path.join(logs_dir, 'hyperashs_log.txt'))

    # Load the training configuration from the specified YAML file
    try:
        with open(args.config, 'r') as f:
            atlas_config = yaml.safe_load(f)

        # Print the header with atlas information and citations
        print_header(atlas_config['metadata'])

    except Exception as e:
        print(f"Error loading atlas configuration: {e}")
        return -1

    # Validate the stage selection before any configuration files are written.
    # These numbers must match the `stages` table below.
    all_stages = range(1, 6)
    try:
        stage_no = _parse_stages(args.stage, all_stages)
    except ValueError as e:
        print(f"Error: invalid --stage value '{args.stage}': {e}")
        return -1

    # Set up the atlas configuration the way the pipeline code expects it
    config = _setup_config(atlas_config, args, args.workdir, training=True)

    # Create the training object
    trainer = HyperASHSTraining(config,
                                manifest_file=args.manifest,
                                label_file=args.labels,
                                xval_file=args.xval,
                                output_dir=args.workdir,
                                overwrite_existing=overwrite_existing,
                                save_intermediates=not args.tidy,
                                create_links=create_links)

    # (stage number, callable, kwargs, optional validity check)
    stages = [
        (1, trainer.preprocess, {'filter': args.filter}, None),
        (2, trainer.train_inr, {'filter': args.filter, 'device': args.device, 'random_seed': args.inr_random_seed, 'batch_size': args.inr_batch_size}, trainer.validity_check_inr_results),
        (3, trainer.prepare_nnunet, {}, None),
        (4, trainer.train_nnunet, {'filter': args.filter, 'device': args.device}, trainer.validity_check_nnunet_results),
        (5, trainer.finalize, {'full_metadata': atlas_config}, None),
    ]

    last_stage = max(stage_no)
    for i_stage, stage_func, stage_kwargs, validity_check in stages:
        if i_stage not in stage_no:
            continue

        stage_func(**stage_kwargs)

        # Stop after the highest requested stage; its validity check is not run.
        if i_stage == last_stage:
            return 0

        # Gate the next stage on this stage's results, so failures here do
        # not propagate into later stages.
        if validity_check is not None and validity_check() is False:
            return 1

    return 0