octopi-1.4.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- octopi/__init__.py +7 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +83 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +458 -0
- octopi/datasets/io.py +200 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +252 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +119 -0
- octopi/entry_points/create_slurm_submission.py +251 -0
- octopi/entry_points/groups.py +152 -0
- octopi/entry_points/run_create_targets.py +234 -0
- octopi/entry_points/run_evaluate.py +99 -0
- octopi/entry_points/run_extract_mb_picks.py +191 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +176 -0
- octopi/entry_points/run_optuna.py +161 -0
- octopi/entry_points/run_segment.py +154 -0
- octopi/entry_points/run_train.py +189 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +217 -0
- octopi/extract/membranebound_extract.py +263 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/main.py +33 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +72 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +224 -0
- octopi/processing/downloader.py +138 -0
- octopi/processing/downsample.py +125 -0
- octopi/processing/evaluate.py +302 -0
- octopi/processing/importers.py +116 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +244 -0
- octopi/pytorch/model_search_submitter.py +291 -0
- octopi/pytorch/segmentation.py +363 -0
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +465 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +215 -0
- octopi/utils/losses.py +86 -0
- octopi/utils/parsers.py +162 -0
- octopi/utils/progress.py +78 -0
- octopi/utils/stopping_criteria.py +143 -0
- octopi/utils/submit_slurm.py +95 -0
- octopi/utils/visualization_tools.py +290 -0
- octopi/workflows.py +262 -0
- octopi-1.4.0.dist-info/METADATA +119 -0
- octopi-1.4.0.dist-info/RECORD +65 -0
- octopi-1.4.0.dist-info/WHEEL +4 -0
- octopi-1.4.0.dist-info/entry_points.txt +3 -0
- octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/utils/io.py
ADDED
@@ -0,0 +1,215 @@
"""
File I/O utilities for YAML and JSON operations.
"""

import os, json, yaml, copick, glob
import pandas as pd


# Create a custom dumper that uses flow style for lists only.
class InlineListDumper(yaml.SafeDumper):
    def represent_list(self, data):
        node = super().represent_list(data)
        node.flow_style = True  # Use inline style for lists
        return node


def save_parameters_yaml(params: dict, output_path: str):
    """
    Save parameters to a YAML file.
    """
    InlineListDumper.add_representer(list, InlineListDumper.represent_list)
    with open(output_path, 'w') as f:
        yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)


def load_yaml(path: str) -> dict:
    """
    Load a YAML file and return the contents as a dictionary.
    """
    if os.path.exists(path):
        with open(path, 'r') as f:
            return yaml.safe_load(f)
    else:
        raise FileNotFoundError(f"File not found: {path}")

def save_results_to_csv(results, filename: str):
    """Save training results to a CSV file (aligned to validation steps)."""
    data = {}

    # Get validation steps (these are the steps we want to keep)
    val_steps = set()
    for key, value in results.items():
        if isinstance(value, list) and value and isinstance(value[0], tuple):
            if key.startswith(('val_', 'avg_', 'f1_', 'recall_', 'precision_', 'fbeta_')):
                val_steps.update([item[0] for item in value])
                break

    val_steps = sorted(val_steps)
    data['step'] = val_steps

    # Extract all metrics, filtering to validation steps only
    for key, value in results.items():
        if isinstance(value, list) and value and isinstance(value[0], tuple):
            step_to_value = {item[0]: item[1] for item in value}
            data[key] = [step_to_value.get(step, None) for step in val_steps]

    df = pd.DataFrame(data)
    df.to_csv(filename, index=False)
    print(f"📊 Training Results saved to {filename}")

def prepare_inline_results_json(results):
    """
    Prepare results for inline JSON formatting.
    """
    # Traverse the dictionary and format lists of lists as inline JSON
    for key, value in results.items():
        # Check if the value is a list of lists (like [[epoch, value], ...])
        if isinstance(value, list) and all(isinstance(item, list) and len(item) == 2 for item in value):
            # Format the list of lists as a single-line JSON string
            results[key] = json.dumps(value)
    return results

def get_optimizer_parameters(trainer):
    """
    Extract optimizer parameters from a trainer object.
    """
    optimizer_parameters = {
        'my_num_samples': trainer.num_samples,
        'val_interval': trainer.val_interval,
        'lr': trainer.optimizer.param_groups[0]['lr'],
        'optimizer': trainer.optimizer.__class__.__name__,
        'metrics_function': trainer.metrics_function.__class__.__name__,
        'loss_function': trainer.loss_function.__class__.__name__,
    }

    # Log Tversky Loss Parameters
    if trainer.loss_function.__class__.__name__ == 'TverskyLoss':
        optimizer_parameters['alpha'] = trainer.loss_function.alpha
    elif trainer.loss_function.__class__.__name__ == 'FocalLoss':
        optimizer_parameters['gamma'] = trainer.loss_function.gamma
    elif trainer.loss_function.__class__.__name__ == 'WeightedFocalTverskyLoss':
        optimizer_parameters['alpha'] = trainer.loss_function.alpha
        optimizer_parameters['gamma'] = trainer.loss_function.gamma
        optimizer_parameters['weight_tversky'] = trainer.loss_function.weight_tversky
    elif trainer.loss_function.__class__.__name__ == 'FocalTverskyLoss':
        optimizer_parameters['alpha'] = trainer.loss_function.alpha
        optimizer_parameters['gamma'] = trainer.loss_function.gamma

    return optimizer_parameters


def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
    """
    Save training parameters to a YAML file.
    """

    # Check for the target configuration file for model labels
    target_config = check_target_config_path(dataloader)

    # Extract and flatten parameters
    parameters = {
        'model': model.get_model_parameters(),
        'labels': target_config['input']['labels'],
        'optimizer': get_optimizer_parameters(trainer),
        'dataloader': dataloader.get_dataloader_parameters()
    }

    save_parameters_yaml(parameters, filename)
    print(f"⚙️ Training Parameters saved to {filename}")

def flatten_params(params, parent_key=''):
    """
    Helper function to flatten and serialize nested parameters.
    """
    flattened = {}
    for key, value in params.items():
        new_key = f"{parent_key}.{key}" if parent_key else key
        if isinstance(value, dict):
            flattened.update(flatten_params(value, new_key))
        elif isinstance(value, list):
            flattened[new_key] = ', '.join(map(str, value))  # Convert list to a comma-separated string
        else:
            flattened[new_key] = value
    return flattened


def prepare_for_inline_json(data):
    """
    Manually join specific lists into strings for inline display.
    """
    for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
        if key in data['dataloader']:
            data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"

    for key in ['channels', 'strides']:
        if key in data['model']:
            data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
    return data

def check_target_config_path(data_generator):
    """
    Check for the target configuration file in the CoPick overlay or static root directories.
    If the session_id is not provided, search for the most recent file matching the target name.
    """

    # Open the Copick Project for MultiConfig or SingleConfig Workflow
    if isinstance(data_generator.config, dict):
        sessions = list(data_generator.config.keys())
        config_path = data_generator.config[sessions[0]]
    else:
        config_path = data_generator.config

    # Get the Target Config File
    return get_config(
        config_path, data_generator.target_name, 'targets',
        data_generator.target_user_id, data_generator.target_session_id
    )

def get_config(config_path, name, process, user_id=None, session_id=None):
    """
    Get the configuration for a specific process and target name.
    """

    # Get the Overlay and Static Roots
    root = copick.from_file(config_path)

    # Remove the local:// prefix from static_root if it exists
    overlay_root = remove_prefix(root.config.overlay_root)
    try:     # Check if static_root is available
        static_root = remove_prefix(root.config.static_root)
    except:  # If not, set it to None
        static_root = None

    # Two Search Patterns: either only a name is provided, or name, user_id, session_id
    if session_id is None:
        pattern = glob.glob(os.path.join(overlay_root, 'logs', f"{process}_*{name}.yaml"))
        if len(pattern) == 0 and static_root is not None:
            pattern = glob.glob(os.path.join(static_root, 'logs', f"{process}_*{name}.yaml"))
        fname = pattern[-1]
    else:
        fname = f"{process}-{user_id}_{session_id}_{name}.yaml"

    # The Target Config File Should Be in Either the Overlay or the Static Root
    if os.path.exists(os.path.join(overlay_root, 'logs', fname)):
        path = os.path.join(overlay_root, 'logs', fname)
    elif static_root is not None and os.path.exists(os.path.join(static_root, 'logs', fname)):
        path = os.path.join(static_root, 'logs', fname)
    else:
        raise FileNotFoundError(f"Target config file not found: {fname}")

    # Load the Target Config File
    with open(path, 'r') as f:
        target_config = yaml.safe_load(f)
    return target_config

def remove_prefix(text: str) -> str:
    """
    Remove the 'local://' prefix from a string if present.
    """
    # Check if the text is None
    if text is None:
        return None
    elif text[:8] == 'local://':
        text = text[8:]
    return text
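
Usage sketch for the YAML helpers above (editor's illustration, not part of the released diff; it assumes octopi 1.4.0 is installed so that octopi.utils.io is importable, and the parameter values are made up):

# Illustrative round trip through save_parameters_yaml / load_yaml.
from octopi.utils import io

params = {
    "model": {"architecture": "Unet", "channels": [32, 64, 128]},  # made-up values
    "optimizer": {"lr": 1e-3, "optimizer": "AdamW"},
}

# Lists are written inline (e.g. [32, 64, 128]) because of InlineListDumper.
io.save_parameters_yaml(params, "train_parameters.yaml")

# load_yaml raises FileNotFoundError if the path does not exist.
loaded = io.load_yaml("train_parameters.yaml")
assert loaded == params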
octopi/utils/losses.py
ADDED
@@ -0,0 +1,86 @@
from monai.losses import FocalLoss, TverskyLoss
import torch

class WeightedFocalTverskyLoss(torch.nn.Module):
    def __init__(
        self, gamma=1.0, alpha=0.7, beta=0.3,
        weight_tversky=0.5, weight_focal=0.5,
        smooth=1e-5, **kwargs):
        """
        Weighted combination of Focal and Tversky loss.

        Args:
            gamma (float): Focus parameter for Focal Loss.
            alpha (float): Weight for false positives in Tversky Loss.
            beta (float): Weight for false negatives in Tversky Loss.
            weight_tversky (float): Weight of Tversky loss in the combination.
            weight_focal (float): Weight of Focal loss in the combination.
            smooth (float): Smoothing factor to avoid division by zero.
        """
        super().__init__()
        self.tversky_loss = TverskyLoss(
            alpha=alpha, beta=beta, include_background=True,
            to_onehot_y=True, softmax=True,
            smooth_nr=smooth, smooth_dr=smooth, **kwargs
        )
        self.focal_loss = FocalLoss(
            include_background=True, to_onehot_y=True,
            use_softmax=True, gamma=gamma
        )
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.weight_tversky = weight_tversky
        self.weight_focal = weight_focal

    def forward(self, y_pred, y_true):
        """
        Compute the combined loss.

        Args:
            y_pred (Tensor): Predicted probabilities (B, C, ...).
            y_true (Tensor): Ground truth labels (B, C, ...), one-hot encoded.

        Returns:
            Tensor: Weighted combination of Tversky and Focal losses.
        """
        tversky = self.tversky_loss(y_pred, y_true)
        focal = self.focal_loss(y_pred, y_true)
        return self.weight_tversky * tversky + self.weight_focal * focal

class FocalTverskyLoss(TverskyLoss):
    def __init__(
        self,
        alpha=0.7, beta=0.3, gamma=1.0, smooth=1e-5, **kwargs):
        """
        Focal Tversky Loss with an additional power term for harder samples.

        From https://arxiv.org/abs/1810.07842

        Args:
            alpha (float): Weight for false positives.
            beta (float): Weight for false negatives.
            gamma (float): Focus parameter (like Focal Loss).
            smooth (float): Smoothing factor to avoid division by zero.
        """
        super().__init__(
            alpha=alpha, beta=beta,
            include_background=True,
            to_onehot_y=True, softmax=True,
            smooth_nr=smooth, smooth_dr=smooth, **kwargs)
        self.gamma = gamma
        self.alpha = alpha
        self.beta = beta

    def forward(self, y_pred, y_true):
        """
        Args:
            y_pred (Tensor): Predicted probabilities (B, C, ...).
            y_true (Tensor): Ground truth labels (B, C, ...), one-hot encoded.

        Returns:
            Tensor: Loss value.
        """
        tversky_loss = super().forward(y_pred, y_true)
        modified_loss = torch.pow(tversky_loss, 1 / self.gamma)
        return modified_loss
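
A minimal sketch of calling these losses (editor's illustration, not part of the diff; it requires torch and monai, and the shapes assume a 3-class 3D segmentation task with integer class-index masks, matching the to_onehot_y=True setting used above):

import torch
from octopi.utils.losses import WeightedFocalTverskyLoss, FocalTverskyLoss

# (B, C, D, H, W) raw network outputs and (B, 1, D, H, W) integer class labels
y_pred = torch.randn(2, 3, 16, 16, 16)
y_true = torch.randint(0, 3, (2, 1, 16, 16, 16))

loss_fn = WeightedFocalTverskyLoss(gamma=1.5, alpha=0.7, beta=0.3,
                                   weight_tversky=0.6, weight_focal=0.4)
loss = loss_fn(y_pred, y_true)            # scalar tensor

focal_tversky = FocalTverskyLoss(alpha=0.7, beta=0.3, gamma=1.3)
loss2 = focal_tversky(y_pred, y_true)     # Tversky loss raised to 1/gamma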
octopi/utils/parsers.py
ADDED
@@ -0,0 +1,162 @@
"""
Argument parsing and configuration utilities.
"""

from typing import List, Tuple, Union
from dotenv import load_dotenv
import argparse

def parse_list(value: str) -> List[str]:
    """
    Parse a string representing a list of items.
    Supports formats like '[item1,item2,item3]' or 'item1,item2,item3'.
    """
    value = value.strip("[]")  # Remove brackets if present
    return [x.strip() for x in value.split(",")]


def parse_int_list(value: str) -> List[int]:
    """
    Parse a string representing a list of integers.
    Supports formats like '[1,2,3]' or '1,2,3'.
    """
    return [int(x) for x in parse_list(value)]


def string2bool(value: str):
    """
    Custom function to convert string values to boolean.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in {'true', 't', '1', 'yes'}:
        return True
    elif value.lower() in {'false', 'f', '0', 'no'}:
        return False
    else:
        raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")


def parse_target(value: str) -> Tuple[str, Union[str, None], Union[str, None]]:
    """
    Parse a single target string.
    Expected formats:
        - "name"
        - "name,user_id,session_id"
    """
    parts = value.split(',')
    if len(parts) == 1:
        obj_name = parts[0]
        return obj_name, None, None
    elif len(parts) == 3:
        obj_name, user_id, session_id = parts
        return obj_name, user_id, session_id
    else:
        raise argparse.ArgumentTypeError(
            f"Invalid target format: '{value}'. Expected 'name' or 'name,user_id,session_id'."
        )


def parse_seg_target(value: str) -> List[Tuple[str, Union[str, None], Union[str, None]]]:
    """
    Parse segmentation targets. Each target can have the format:
        - "name"
        - "name,user_id,session_id"
    Multiple targets are separated by ';'.
    """
    targets = []
    for target in value.split(';'):  # Use ';' as a separator for multiple targets
        parts = target.split(',')
        if len(parts) == 1:
            name = parts[0]
            targets.append((name, None, None))
        elif len(parts) == 3:
            name, user_id, session_id = parts
            targets.append((name, user_id, session_id))
        else:
            raise argparse.ArgumentTypeError(
                f"Invalid seg-target format: '{target}'. Expected 'name' or 'name,user_id,session_id'."
            )
    return targets


def parse_copick_configs(config_entries: List[str]):
    """
    Parse --config entries into a mapping of session names to CoPick config paths,
    or a single config path when no session name is given.
    """
    # Process the --config arguments into a dictionary
    copick_configs = {}

    for config_entry in config_entries:
        if ',' in config_entry:
            # Entry has a session name and a config path
            try:
                session_name, config_path = config_entry.split(',', 1)
                copick_configs[session_name] = config_path
            except ValueError:
                raise argparse.ArgumentTypeError(
                    f"Invalid format for --config entry: '{config_entry}'. Expected 'session_name,/path/to/config.json'."
                )
        else:
            # Single configuration path without a session name
            # if "default" in copick_configs:
            #     raise argparse.ArgumentTypeError(
            #         f"Only one single-path --config entry is allowed when using default configurations. "
            #         f"Detected duplicate: {config_entry}"
            #     )
            # copick_configs["default"] = config_entry
            copick_configs = config_entry

        # if ',' in config_entry:
        #     parts = config_entry.split(',')
        #     if len(parts) == 2:
        #         # Entry with session name and config path
        #         session_name, config_path = parts
        #         copick_configs[session_name] = {"path": config_path, "algorithm": None}
        #     elif len(parts) == 3:
        #         # Entry with session name, config path, and algorithm
        #         session_name, config_path, algorithm = parts
        #         copick_configs[session_name] = {"path": config_path, "algorithm": algorithm}
        #     else:
        #         copick_configs = config_entry

    return copick_configs


def parse_data_split(value: str) -> Tuple[float, float, float]:
    """
    Parse data split ratios from string input.

    Args:
        value: Either a single float (e.g., "0.8") or two comma-separated floats (e.g., "0.7,0.1")

    Returns:
        Tuple of (train_ratio, val_ratio, test_ratio)

    Examples:
        "0.8" -> (0.8, 0.2, 0.0)
        "0.7,0.1" -> (0.7, 0.1, 0.2)
    """
    parts = value.split(',')

    if len(parts) == 1:
        # Single value provided - use it as train ratio
        train_ratio = float(parts[0])
        val_ratio = 1.0 - train_ratio
        test_ratio = 0.0
    elif len(parts) == 2:
        # Two values provided - use as train and val ratios
        train_ratio = float(parts[0])
        val_ratio = float(parts[1])
        test_ratio = 1.0 - train_ratio - val_ratio
    else:
        raise ValueError("Data split must be either a single value or two comma-separated values")

    # Validate ratios
    if train_ratio < 0 or val_ratio < 0 or test_ratio < 0:
        raise ValueError("All ratios must be non-negative")

    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")

    return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
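
These parsers are the kind of callable typically passed as an argparse type= argument; the expected inputs and outputs follow directly from the docstrings above (editor's illustration, not part of the diff; the object and user names are made up):

from octopi.utils.parsers import parse_int_list, parse_target, parse_data_split

parse_int_list("[32,64,128]")       # -> [32, 64, 128]
parse_target("ribosome")            # -> ('ribosome', None, None)
parse_target("ribosome,octopi,1")   # -> ('ribosome', 'octopi', '1')
parse_data_split("0.8")             # -> (0.8, 0.2, 0.0)
parse_data_split("0.7,0.1")         # -> (0.7, 0.1, 0.2)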
octopi/utils/progress.py
ADDED
@@ -0,0 +1,78 @@
from rich.progress import (
    Progress,
    SpinnerColumn,
    TextColumn,
    BarColumn,
    MofNCompleteColumn,
    TimeElapsedColumn,
    TimeRemainingColumn,
)
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
import json

# Minimal helper to get a shared Console instance.
def get_console():
    return Console()

def _progress(iterable, description="Processing"):
    """
    Wrap an iterable with a Rich progress bar.

    Args:
        iterable: A sized iterable (e.g., list); len() is used for the total.
        description: Text label to display above the progress bar.

    Yields:
        Each item from the iterable, while updating the progress bar.

    Example:
        for x in _progress(range(10), "Doing work"):
            time.sleep(0.5)
    """

    console = Console()

    # The generator itself yields items while advancing the progress bar
    with Progress(
        SpinnerColumn(),
        TextColumn(f"[bold blue]{description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn(),
        transient=False,
        console=console,
    ) as progress:
        task = progress.add_task(description, total=len(iterable))
        for item in iterable:
            yield item
            progress.advance(task)

def print_summary(process: str, **kwargs):
    """
    Pretty-print process parameters using Rich in a clean table with green highlights.
    """

    console = Console()

    # ---- Header ----
    console.rule(f"[bold green]{process} Parameters Summary[/bold green]")

    # ---- Table (no outer box) ----
    table = Table(
        show_header=True,
        header_style="bold magenta",
        expand=False,
        border_style="green",  # table borders only
    )
    table.add_column("Parameter", style="cyan", no_wrap=True)
    table.add_column("Value", style="white")

    for key, value in kwargs.items():
        if isinstance(value, (dict, list)):
            value = json.dumps(value, indent=2)
        table.add_row(str(key), str(value))

    console.print(table)  # Print the table directly, no Panel wrapper
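
An editor's sketch of the progress helpers in use (illustrative only, not part of the diff; the run IDs and parameters are made up, and _progress needs a sized iterable because it uses len() for the total):

import time
from octopi.utils.progress import _progress, print_summary

run_ids = ["TS_001", "TS_002", "TS_003"]       # hypothetical run identifiers
for run_id in _progress(run_ids, description="Segmenting runs"):
    time.sleep(0.1)                            # stand-in for per-run work

print_summary("Localize", config="config.json", voxel_size=10, runs=run_ids)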
octopi/utils/stopping_criteria.py
ADDED
@@ -0,0 +1,143 @@
import numpy as np

class EarlyStoppingChecker:
    """
    A class to manage various early stopping criteria for model training.
    """

    def __init__(self,
                 max_nan_epochs=15,
                 plateau_patience=20,
                 plateau_min_delta=0.005,
                 stagnation_patience=30,
                 convergence_window=5,
                 convergence_threshold=0.005,
                 val_interval=15,
                 monitor_metric='avg_fbeta'):
        """
        Initialize early stopping parameters.

        Args:
            max_nan_epochs: Maximum number of epochs with NaN loss before stopping
            plateau_patience: Number of validation checks to wait for plateau detection
            plateau_min_delta: Minimum change to qualify as improvement
            stagnation_patience: Number of validation intervals to wait for best metric improvement
            convergence_window: Window size for calculating improvement rate
            convergence_threshold: Minimum improvement rate threshold
            val_interval: Number of epochs between validation runs
            monitor_metric: Primary metric to monitor for early stopping criteria
        """
        self.max_nan_epochs = max_nan_epochs
        self.plateau_patience = plateau_patience
        self.plateau_min_delta = plateau_min_delta
        self.stagnation_patience = stagnation_patience
        self.convergence_window = convergence_window
        self.convergence_threshold = convergence_threshold
        self.val_interval = val_interval
        self.monitor_metric = monitor_metric

        # Counters
        self.nan_counter = 0

        # Flags for detailed reporting
        self.stopped_reason = None

    def check_for_nan(self, epoch_loss):
        """Check for NaN in the loss."""
        if np.isnan(epoch_loss):
            self.nan_counter += 1
            if self.nan_counter > self.max_nan_epochs:
                self.stopped_reason = f"NaN values in loss for more than {self.max_nan_epochs} epochs"
                return True
        else:
            self.nan_counter = 0  # Reset the counter if loss is valid
        return False

    def check_for_plateau(self, results):
        """Detect plateaus in validation metrics."""
        if len(results[self.monitor_metric]) < self.plateau_patience + 1:
            return False

        # Get the last 'patience' number of validation points
        recent_values = [x[1] for x in results[self.monitor_metric][-self.plateau_patience:]]
        # Find the max value in the window
        max_value = max(recent_values)
        # Find the min value in the window
        min_value = min(recent_values)

        # If the range of values is small, consider it a plateau
        if max_value - min_value < self.plateau_min_delta:
            self.stopped_reason = f"{self.monitor_metric} plateaued for {self.plateau_patience} validations"
            return True

        return False

    def check_best_metric_stagnation(self, results):
        """Stop if best metric hasn't improved for a number of validation intervals."""
        if "best_metric_epoch" not in results or len(results[self.monitor_metric]) < self.stagnation_patience + 1:
            return False

        # Get epoch of the best metric so far
        best_epoch = results["best_metric_epoch"]
        current_epoch = results[self.monitor_metric][-1][0]

        # Check if it's been more than 'patience' validation intervals
        if (current_epoch - best_epoch) >= (self.stagnation_patience * self.val_interval):
            self.stopped_reason = f"No improvement for {self.stagnation_patience} validation intervals"
            return True

        return False

    # def check_convergence_rate(self, results):
    #     """Stop when improvement rate slows below threshold."""
    #     if len(results[self.monitor_metric]) < self.convergence_window + 1:
    #         return False
    #
    #     # Calculate average improvement rate over window
    #     recent_values = [x[1] for x in results[self.monitor_metric][-(self.convergence_window+1):]]
    #     improvements = [recent_values[i+1] - recent_values[i] for i in range(self.convergence_window)]
    #     avg_improvement = sum(improvements) / self.convergence_window
    #
    #     if avg_improvement < self.convergence_threshold and avg_improvement > 0:
    #         self.stopped_reason = f"Convergence rate ({avg_improvement:.6f}) below threshold"
    #         return True
    #
    #     return False

    def should_stop_training(self, epoch_loss, results=None, check_metrics=False):
        """
        Comprehensive check for whether training should stop.

        Args:
            epoch_loss: Current epoch's loss value
            results: Dictionary containing training metrics history
            check_metrics: Whether to also check validation metrics-based criteria

        Returns:
            bool: True if training should stop, False otherwise
        """
        # Check for NaN in loss (can be done every epoch)
        if self.check_for_nan(epoch_loss):
            return True

        # Only check metric-based criteria if requested and results are provided
        if check_metrics and results:
            # Check for plateau in validation metrics
            if self.check_for_plateau(results):
                return True

            # Check if best metric hasn't improved for a while
            if self.check_best_metric_stagnation(results):
                return True

            # # Check if convergence rate has slowed down
            # if self.check_convergence_rate(results):
            #     return True

        return False

    def get_stopped_reason(self):
        """Get the reason for stopping, if any."""
        if self.stopped_reason:
            return f"Early stopping triggered: {self.stopped_reason}"
        return "No early stopping criteria met."