octopi 1.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. octopi/__init__.py +7 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +83 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +458 -0
  7. octopi/datasets/io.py +200 -0
  8. octopi/datasets/mixup.py +49 -0
  9. octopi/datasets/multi_config_generator.py +252 -0
  10. octopi/entry_points/__init__.py +0 -0
  11. octopi/entry_points/common.py +119 -0
  12. octopi/entry_points/create_slurm_submission.py +251 -0
  13. octopi/entry_points/groups.py +152 -0
  14. octopi/entry_points/run_create_targets.py +234 -0
  15. octopi/entry_points/run_evaluate.py +99 -0
  16. octopi/entry_points/run_extract_mb_picks.py +191 -0
  17. octopi/entry_points/run_extract_midpoint.py +143 -0
  18. octopi/entry_points/run_localize.py +176 -0
  19. octopi/entry_points/run_optuna.py +161 -0
  20. octopi/entry_points/run_segment.py +154 -0
  21. octopi/entry_points/run_train.py +189 -0
  22. octopi/extract/__init__.py +0 -0
  23. octopi/extract/localize.py +217 -0
  24. octopi/extract/membranebound_extract.py +263 -0
  25. octopi/extract/midpoint_extract.py +193 -0
  26. octopi/main.py +33 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +72 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +224 -0
  37. octopi/processing/downloader.py +138 -0
  38. octopi/processing/downsample.py +125 -0
  39. octopi/processing/evaluate.py +302 -0
  40. octopi/processing/importers.py +116 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/pytorch/__init__.py +0 -0
  43. octopi/pytorch/hyper_search.py +244 -0
  44. octopi/pytorch/model_search_submitter.py +291 -0
  45. octopi/pytorch/segmentation.py +363 -0
  46. octopi/pytorch/segmentation_multigpu.py +162 -0
  47. octopi/pytorch/trainer.py +465 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/utils/__init__.py +0 -0
  52. octopi/utils/config.py +57 -0
  53. octopi/utils/io.py +215 -0
  54. octopi/utils/losses.py +86 -0
  55. octopi/utils/parsers.py +162 -0
  56. octopi/utils/progress.py +78 -0
  57. octopi/utils/stopping_criteria.py +143 -0
  58. octopi/utils/submit_slurm.py +95 -0
  59. octopi/utils/visualization_tools.py +290 -0
  60. octopi/workflows.py +262 -0
  61. octopi-1.4.0.dist-info/METADATA +119 -0
  62. octopi-1.4.0.dist-info/RECORD +65 -0
  63. octopi-1.4.0.dist-info/WHEEL +4 -0
  64. octopi-1.4.0.dist-info/entry_points.txt +3 -0
  65. octopi-1.4.0.dist-info/licenses/LICENSE +41 -0
octopi/utils/io.py ADDED
@@ -0,0 +1,215 @@
+ """
+ File I/O utilities for YAML and JSON operations.
+ """
+
+ import os, json, yaml, copick, glob
+ import pandas as pd
+
+
+ # Create a custom dumper that uses flow style for lists only.
+ class InlineListDumper(yaml.SafeDumper):
+     def represent_list(self, data):
+         node = super().represent_list(data)
+         node.flow_style = True  # Use inline style for lists
+         return node
+
+
+ def save_parameters_yaml(params: dict, output_path: str):
+     """
+     Save parameters to a YAML file.
+     """
+     InlineListDumper.add_representer(list, InlineListDumper.represent_list)
+     with open(output_path, 'w') as f:
+         yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
+
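A minimal usage sketch for the dumper above (parameter names and values are illustrative, not octopi defaults): nested lists are emitted inline while mappings stay in block style.

    from octopi.utils.io import save_parameters_yaml

    params = {
        "model": {"architecture": "Unet", "channels": [32, 64, 128], "strides": [2, 2]},
        "optimizer": {"lr": 1e-3},
    }
    save_parameters_yaml(params, "params.yaml")
    # Expected layout of params.yaml (lists inline, mappings in block style):
    # model:
    #   architecture: Unet
    #   channels: [32, 64, 128]
    #   strides: [2, 2]
    # optimizer:
    #   lr: 0.001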
+ def load_yaml(path: str) -> dict:
+     """
+     Load a YAML file and return the contents as a dictionary.
+     """
+     if os.path.exists(path):
+         with open(path, 'r') as f:
+             return yaml.safe_load(f)
+     else:
+         raise FileNotFoundError(f"File not found: {path}")
+
+ def save_results_to_csv(results, filename: str):
+     """Save training results to a CSV file (aligned to validation steps)."""
+     data = {}
+
+     # Get validation steps (these are the steps we want to keep)
+     val_steps = set()
+     for key, value in results.items():
+         if isinstance(value, list) and value and isinstance(value[0], tuple):
+             if key.startswith(('val_', 'avg_', 'f1_', 'recall_', 'precision_', 'fbeta_')):
+                 val_steps.update([item[0] for item in value])
+                 break
+
+     val_steps = sorted(val_steps)
+     data['step'] = val_steps
+
+     # Extract all metrics, filtering to validation steps only
+     for key, value in results.items():
+         if isinstance(value, list) and value and isinstance(value[0], tuple):
+             step_to_value = {item[0]: item[1] for item in value}
+             data[key] = [step_to_value.get(step, None) for step in val_steps]
+
+     df = pd.DataFrame(data)
+     df.to_csv(filename, index=False)
+     print(f"📊 Training Results saved to {filename}")
+
+ def prepare_inline_results_json(results):
+     """
+     Prepare results for inline JSON formatting.
+     """
+     # Traverse the dictionary and format lists of lists as inline JSON
+     for key, value in results.items():
+         # Check if the value is a list of lists (like [[epoch, value], ...])
+         if isinstance(value, list) and all(isinstance(item, list) and len(item) == 2 for item in value):
+             # Format the list of lists as a single-line JSON string
+             results[key] = json.dumps(value)
+     return results
+
+ def get_optimizer_parameters(trainer):
+     """
+     Extract optimizer parameters from a trainer object.
+     """
+     optimizer_parameters = {
+         'my_num_samples': trainer.num_samples,
+         'val_interval': trainer.val_interval,
+         'lr': trainer.optimizer.param_groups[0]['lr'],
+         'optimizer': trainer.optimizer.__class__.__name__,
+         'metrics_function': trainer.metrics_function.__class__.__name__,
+         'loss_function': trainer.loss_function.__class__.__name__,
+     }
+
+     # Log loss-function-specific parameters
+     if trainer.loss_function.__class__.__name__ == 'TverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+     elif trainer.loss_function.__class__.__name__ == 'FocalLoss':
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+     elif trainer.loss_function.__class__.__name__ == 'WeightedFocalTverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+         optimizer_parameters['weight_tversky'] = trainer.loss_function.weight_tversky
+     elif trainer.loss_function.__class__.__name__ == 'FocalTverskyLoss':
+         optimizer_parameters['alpha'] = trainer.loss_function.alpha
+         optimizer_parameters['gamma'] = trainer.loss_function.gamma
+
+     return optimizer_parameters
+
+
+ def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
+     """
+     Save training parameters to a YAML file.
+     """
+
+     # Check for the target configuration file for model labels
+     target_config = check_target_config_path(dataloader)
+
+     # Extract and flatten parameters
+     parameters = {
+         'model': model.get_model_parameters(),
+         'labels': target_config['input']['labels'],
+         'optimizer': get_optimizer_parameters(trainer),
+         'dataloader': dataloader.get_dataloader_parameters()
+     }
+
+     save_parameters_yaml(parameters, filename)
+     print(f"⚙️ Training Parameters saved to {filename}")
+
+ def flatten_params(params, parent_key=''):
+     """
+     Helper function to flatten and serialize nested parameters.
+     """
+     flattened = {}
+     for key, value in params.items():
+         new_key = f"{parent_key}.{key}" if parent_key else key
+         if isinstance(value, dict):
+             flattened.update(flatten_params(value, new_key))
+         elif isinstance(value, list):
+             flattened[new_key] = ', '.join(map(str, value))  # Convert list to a comma-separated string
+         else:
+             flattened[new_key] = value
+     return flattened
+
+
+ def prepare_for_inline_json(data):
+     """
+     Manually join specific lists into strings for inline display.
+     """
+     for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
+         if key in data['dataloader']:
+             data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"
+
+     for key in ['channels', 'strides']:
+         if key in data['model']:
+             data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
+     return data
+
+ def check_target_config_path(data_generator):
+     """
+     Check for the target configuration file in the CoPick overlay or static root directories.
+     If the session_id is not provided, search for the most recent file matching the target name.
+     """
+
+     # Open the Copick Project for MultiConfig or SingleConfig Workflow
+     if isinstance(data_generator.config, dict):
+         sessions = list(data_generator.config.keys())
+         config_path = data_generator.config[sessions[0]]
+     else:
+         config_path = data_generator.config
+
+     # Get the Target Config File
+     return get_config(
+         config_path, data_generator.target_name, 'targets',
+         data_generator.target_user_id, data_generator.target_session_id
+     )
+
+ def get_config(config_path, name, process, user_id=None, session_id=None):
+     """
+     Get the configuration for a specific process and target name.
+     """
+
+     # Get the Overlay and Static Roots
+     root = copick.from_file(config_path)
+
+     # Remove the local:// prefix from the roots if it exists
+     overlay_root = remove_prefix(root.config.overlay_root)
+     try:  # Check if static_root is available
+         static_root = remove_prefix(root.config.static_root)
+     except AttributeError:  # If not, set it to None
+         static_root = None
+
+     # Two search patterns: either only a name is provided, or a name, user_id and session_id
+     if session_id is None:
+         pattern = glob.glob(os.path.join(overlay_root, 'logs', f"{process}_*{name}.yaml"))
+         if len(pattern) == 0 and static_root is not None:
+             pattern = glob.glob(os.path.join(static_root, 'logs', f"{process}_*{name}.yaml"))
+         if len(pattern) == 0:
+             raise FileNotFoundError(f"No target config found matching '{process}_*{name}.yaml'")
+         fname = os.path.basename(pattern[-1])
+     else:
+         fname = f"{process}-{user_id}_{session_id}_{name}.yaml"
+
+     # The target config file should be in either the overlay or static root
+     if os.path.exists(os.path.join(overlay_root, 'logs', fname)):
+         path = os.path.join(overlay_root, 'logs', fname)
+     elif static_root is not None and os.path.exists(os.path.join(static_root, 'logs', fname)):
+         path = os.path.join(static_root, 'logs', fname)
+     else:
+         raise FileNotFoundError(f"Target config file not found: {fname}")
+
+     # Load the Target Config File
+     with open(path, 'r') as f:
+         target_config = yaml.safe_load(f)
+     return target_config
+
+ def remove_prefix(text: str) -> str:
+     """
+     Remove the 'local://' prefix from a path if it exists.
+     """
+     # Check if the text is None
+     if text is None:
+         return None
+     elif text[:8] == 'local://':
+         text = text[8:]
+     return text
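A small sketch of flatten_params on a hypothetical nested parameter dict: nested keys become dotted keys and lists are joined into comma-separated strings, which keeps logged parameters flat.

    from octopi.utils.io import flatten_params

    nested = {"model": {"channels": [32, 64], "dim": 3}, "lr": 1e-3}
    flat = flatten_params(nested)
    print(flat)
    # {'model.channels': '32, 64', 'model.dim': 3, 'lr': 0.001}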
octopi/utils/losses.py ADDED
@@ -0,0 +1,86 @@
+ from monai.losses import FocalLoss, TverskyLoss
+ import torch
+
+ class WeightedFocalTverskyLoss(torch.nn.Module):
+     def __init__(
+         self, gamma=1.0, alpha=0.7, beta=0.3,
+         weight_tversky=0.5, weight_focal=0.5,
+         smooth=1e-5, **kwargs):
+         """
+         Weighted combination of Focal and Tversky loss.
+
+         Args:
+             gamma (float): Focus parameter for Focal Loss.
+             alpha (float): Weight for false positives in Tversky Loss.
+             beta (float): Weight for false negatives in Tversky Loss.
+             weight_tversky (float): Weight of Tversky loss in the combination.
+             weight_focal (float): Weight of Focal loss in the combination.
+             smooth (float): Smoothing factor to avoid division by zero.
+         """
+         super().__init__()
+         self.tversky_loss = TverskyLoss(
+             alpha=alpha, beta=beta, include_background=True,
+             to_onehot_y=True, softmax=True,
+             smooth_nr=smooth, smooth_dr=smooth, **kwargs
+         )
+         self.focal_loss = FocalLoss(
+             include_background=True, to_onehot_y=True,
+             use_softmax=True, gamma=gamma
+         )
+         self.alpha = alpha
+         self.beta = beta
+         self.gamma = gamma
+         self.weight_tversky = weight_tversky
+         self.weight_focal = weight_focal
+
+     def forward(self, y_pred, y_true):
+         """
+         Compute the combined loss.
+
+         Args:
+             y_pred (Tensor): Predicted logits (B, C, ...).
+             y_true (Tensor): Ground truth class indices (B, 1, ...); converted to one-hot internally.
+
+         Returns:
+             Tensor: Weighted combination of Tversky and Focal losses.
+         """
+         tversky = self.tversky_loss(y_pred, y_true)
+         focal = self.focal_loss(y_pred, y_true)
+         return self.weight_tversky * tversky + self.weight_focal * focal
+
+ class FocalTverskyLoss(TverskyLoss):
+     def __init__(
+         self,
+         alpha=0.7, beta=0.3, gamma=1.0, smooth=1e-5, **kwargs):
+         """
+         Focal Tversky Loss with an additional power term for harder samples.
+
+         From https://arxiv.org/abs/1810.07842
+
+         Args:
+             alpha (float): Weight for false positives.
+             beta (float): Weight for false negatives.
+             gamma (float): Focus parameter (like Focal Loss).
+             smooth (float): Smoothing factor to avoid division by zero.
+         """
+         super().__init__(
+             alpha=alpha, beta=beta,
+             include_background=True,
+             to_onehot_y=True, softmax=True,
+             smooth_nr=smooth, smooth_dr=smooth, **kwargs)
+         self.gamma = gamma
+         self.alpha = alpha
+         self.beta = beta
+
+     def forward(self, y_pred, y_true):
+         """
+         Args:
+             y_pred (Tensor): Predicted logits (B, C, ...).
+             y_true (Tensor): Ground truth class indices (B, 1, ...); converted to one-hot internally.
+
+         Returns:
+             Tensor: Loss value.
+         """
+         tversky_loss = super().forward(y_pred, y_true)
+         modified_loss = torch.pow(tversky_loss, 1 / self.gamma)
+         return modified_loss
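A hedged usage sketch for the combined loss (tensor shapes and class count are illustrative): both losses apply softmax and one-hot conversion internally, so the model output is passed as raw logits and the target as integer class labels.

    import torch
    from octopi.utils.losses import WeightedFocalTverskyLoss

    loss_fn = WeightedFocalTverskyLoss(gamma=2.0, alpha=0.7, beta=0.3,
                                       weight_tversky=0.5, weight_focal=0.5)

    logits = torch.randn(2, 3, 32, 32, 32)            # (B, C, D, H, W) raw network output
    labels = torch.randint(0, 3, (2, 1, 32, 32, 32))  # (B, 1, D, H, W) class indices
    loss = loss_fn(logits, labels)
    print(loss.item())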
octopi/utils/parsers.py ADDED
@@ -0,0 +1,162 @@
+ """
+ Argument parsing and configuration utilities.
+ """
+
+ from typing import List, Tuple, Union
+ from dotenv import load_dotenv
+ import argparse
+
+ def parse_list(value: str) -> List[str]:
+     """
+     Parse a string representing a list of items.
+     Supports formats like '[item1,item2,item3]' or 'item1,item2,item3'.
+     """
+     value = value.strip("[]")  # Remove brackets if present
+     return [x.strip() for x in value.split(",")]
+
+
+ def parse_int_list(value: str) -> List[int]:
+     """
+     Parse a string representing a list of integers.
+     Supports formats like '[1,2,3]' or '1,2,3'.
+     """
+     return [int(x) for x in parse_list(value)]
+
+
+ def string2bool(value: str):
+     """
+     Convert a string value to a boolean.
+     """
+     if isinstance(value, bool):
+         return value
+     if value.lower() in {'true', 't', '1', 'yes'}:
+         return True
+     elif value.lower() in {'false', 'f', '0', 'no'}:
+         return False
+     else:
+         raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")
+
+
+ def parse_target(value: str) -> Tuple[str, Union[str, None], Union[str, None]]:
+     """
+     Parse a single target string.
+     Expected formats:
+       - "name"
+       - "name,user_id,session_id"
+     """
+     parts = value.split(',')
+     if len(parts) == 1:
+         obj_name = parts[0]
+         return obj_name, None, None
+     elif len(parts) == 3:
+         obj_name, user_id, session_id = parts
+         return obj_name, user_id, session_id
+     else:
+         raise argparse.ArgumentTypeError(
+             f"Invalid target format: '{value}'. Expected 'name' or 'name,user_id,session_id'."
+         )
+
+
+ def parse_seg_target(value: str) -> List[Tuple[str, Union[str, None], Union[str, None]]]:
+     """
+     Parse segmentation targets. Each target can have the format:
+       - "name"
+       - "name,user_id,session_id"
+     Multiple targets are separated by ';'.
+     """
+     targets = []
+     for target in value.split(';'):  # Use ';' as a separator for multiple targets
+         parts = target.split(',')
+         if len(parts) == 1:
+             name = parts[0]
+             targets.append((name, None, None))
+         elif len(parts) == 3:
+             name, user_id, session_id = parts
+             targets.append((name, user_id, session_id))
+         else:
+             raise argparse.ArgumentTypeError(
+                 f"Invalid seg-target format: '{target}'. Expected 'name' or 'name,user_id,session_id'."
+             )
+     return targets
+
+
+ def parse_copick_configs(config_entries: List[str]):
+     """
+     Parse --config entries into a dictionary mapping session names to CoPick config paths,
+     or return a single config path when no session name is given.
+     """
+     # Process the --config arguments into a dictionary
+     copick_configs = {}
+
+     for config_entry in config_entries:
+         if ',' in config_entry:
+             # Entry has a session name and a config path
+             try:
+                 session_name, config_path = config_entry.split(',', 1)
+                 copick_configs[session_name] = config_path
+             except ValueError:
+                 raise argparse.ArgumentTypeError(
+                     f"Invalid format for --config entry: '{config_entry}'. Expected 'session_name,/path/to/config.json'."
+                 )
+         else:
+             # Single configuration path without a session name
+             # if "default" in copick_configs:
+             #     raise argparse.ArgumentTypeError(
+             #         f"Only one single-path --config entry is allowed when using default configurations. "
+             #         f"Detected duplicate: {config_entry}"
+             #     )
+             # copick_configs["default"] = config_entry
+             copick_configs = config_entry
+
+         # if ',' in config_entry:
+         #     parts = config_entry.split(',')
+         #     if len(parts) == 2:
+         #         # Entry with session name and config path
+         #         session_name, config_path = parts
+         #         copick_configs[session_name] = {"path": config_path, "algorithm": None}
+         #     elif len(parts) == 3:
+         #         # Entry with session name, config path, and algorithm
+         #         session_name, config_path, algorithm = parts
+         #         copick_configs[session_name] = {"path": config_path, "algorithm": algorithm}
+         #     else:
+         #         copick_configs = config_entry
+
+     return copick_configs
+
+
+ def parse_data_split(value: str) -> Tuple[float, float, float]:
+     """
+     Parse data split ratios from string input.
+
+     Args:
+         value: Either a single float (e.g., "0.8") or two comma-separated floats (e.g., "0.7,0.1")
+
+     Returns:
+         Tuple of (train_ratio, val_ratio, test_ratio)
+
+     Examples:
+         "0.8" -> (0.8, 0.2, 0.0)
+         "0.7,0.1" -> (0.7, 0.1, 0.2)
+     """
+     parts = value.split(',')
+
+     if len(parts) == 1:
+         # Single value provided - use it as train ratio
+         train_ratio = float(parts[0])
+         val_ratio = 1.0 - train_ratio
+         test_ratio = 0.0
+     elif len(parts) == 2:
+         # Two values provided - use as train and val ratios
+         train_ratio = float(parts[0])
+         val_ratio = float(parts[1])
+         test_ratio = 1.0 - train_ratio - val_ratio
+     else:
+         raise ValueError("Data split must be either a single value or two comma-separated values")
+
+     # Validate ratios
+     if train_ratio < 0 or val_ratio < 0 or test_ratio < 0:
+         raise ValueError("All ratios must be non-negative")
+
+     if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
+         raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")
+
+     return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
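A short sketch of how these parsers plug into argparse (flag names and example values are hypothetical):

    import argparse
    from octopi.utils.parsers import parse_data_split, parse_int_list, parse_target

    parser = argparse.ArgumentParser()
    parser.add_argument("--target", type=parse_target)
    parser.add_argument("--data-split", type=parse_data_split, default=(0.8, 0.2, 0.0))
    parser.add_argument("--strides", type=parse_int_list)

    args = parser.parse_args(["--target", "ribosome,octopi,1",
                              "--data-split", "0.7,0.1",
                              "--strides", "[2,2,1]"])
    print(args.target)      # ('ribosome', 'octopi', '1')
    print(args.data_split)  # (0.7, 0.1, 0.2)
    print(args.strides)     # [2, 2, 1]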
octopi/utils/progress.py ADDED
@@ -0,0 +1,78 @@
+ from rich.progress import (
+     Progress,
+     SpinnerColumn,
+     TextColumn,
+     BarColumn,
+     MofNCompleteColumn,
+     TimeElapsedColumn,
+     TimeRemainingColumn,
+ )
+ from rich.console import Console
+ from rich.table import Table
+ from rich.panel import Panel
+ import json
+
+ # Minimal helper for obtaining a Rich Console instance.
+ def get_console():
+     return Console()
+
+ def _progress(iterable, description="Processing"):
+     """
+     Wrap an iterable with a Rich progress bar.
+
+     Args:
+         iterable: Any iterable object (e.g., list, generator).
+         description: Text label to display above the progress bar.
+
+     Yields:
+         Each item from the iterable, while updating the progress bar.
+
+     Example:
+         for x in _progress(range(10), "Doing work"):
+             time.sleep(0.5)
+     """
+
+     console = Console()
+
+     # The generator itself yields items while advancing the progress bar
+     with Progress(
+         SpinnerColumn(),
+         TextColumn(f"[bold blue]{description}"),
+         BarColumn(),
+         MofNCompleteColumn(),
+         TimeElapsedColumn(),
+         TimeRemainingColumn(),
+         transient=False,
+         console=console,
+     ) as progress:
+         # Iterables without __len__ (e.g., generators) fall back to an indeterminate bar
+         total = len(iterable) if hasattr(iterable, "__len__") else None
+         task = progress.add_task(description, total=total)
+         for item in iterable:
+             yield item
+             progress.advance(task)
+
+ def print_summary(process: str, **kwargs):
+     """
+     Pretty-print process parameters using Rich in a clean table with green highlights.
+     """
+
+     console = Console()
+
+     # ---- Header ----
+     console.rule(f"[bold green]{process} Parameters Summary[/bold green]")
+
+     # ---- Table (no outer box) ----
+     table = Table(
+         show_header=True,
+         header_style="bold magenta",
+         expand=False,
+         border_style="green",  # table borders only
+     )
+     table.add_column("Parameter", style="cyan", no_wrap=True)
+     table.add_column("Value", style="white")
+
+     for key, value in kwargs.items():
+         if isinstance(value, (dict, list)):
+             value = json.dumps(value, indent=2)
+         table.add_row(str(key), str(value))
+
+     console.print(table)  # Print table directly, no panel wrapper
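A brief sketch of the two helpers in use (run IDs and summary fields are made up for illustration):

    import time
    from octopi.utils.progress import _progress, print_summary

    # Iterate with a live progress bar; sized iterables get a determinate bar.
    for run_id in _progress(["TS_001", "TS_002", "TS_003"], "Segmenting runs"):
        time.sleep(0.1)  # placeholder for real work

    # Render keyword arguments as a parameter table.
    print_summary("Localize", config="config.json", voxel_size=10, runIDs=["TS_001", "TS_002"])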
octopi/utils/stopping_criteria.py ADDED
@@ -0,0 +1,143 @@
+ import numpy as np
+
+ class EarlyStoppingChecker:
+     """
+     A class to manage various early stopping criteria for model training.
+     """
+
+     def __init__(self,
+                  max_nan_epochs=15,
+                  plateau_patience=20,
+                  plateau_min_delta=0.005,
+                  stagnation_patience=30,
+                  convergence_window=5,
+                  convergence_threshold=0.005,
+                  val_interval=15,
+                  monitor_metric='avg_fbeta'):
+         """
+         Initialize early stopping parameters.
+
+         Args:
+             max_nan_epochs: Maximum number of epochs with NaN loss before stopping
+             plateau_patience: Number of validation checks to wait for plateau detection
+             plateau_min_delta: Minimum change to qualify as improvement
+             stagnation_patience: Number of validation intervals to wait for best metric improvement
+             convergence_window: Window size for calculating improvement rate
+             convergence_threshold: Minimum improvement rate threshold
+             val_interval: Number of epochs between validation runs
+             monitor_metric: Primary metric to monitor for early stopping criteria
+         """
+         self.max_nan_epochs = max_nan_epochs
+         self.plateau_patience = plateau_patience
+         self.plateau_min_delta = plateau_min_delta
+         self.stagnation_patience = stagnation_patience
+         self.convergence_window = convergence_window
+         self.convergence_threshold = convergence_threshold
+         self.val_interval = val_interval
+         self.monitor_metric = monitor_metric
+
+         # Counters
+         self.nan_counter = 0
+
+         # Flags for detailed reporting
+         self.stopped_reason = None
+
+     def check_for_nan(self, epoch_loss):
+         """Check for NaN in the loss."""
+         if np.isnan(epoch_loss):
+             self.nan_counter += 1
+             if self.nan_counter > self.max_nan_epochs:
+                 self.stopped_reason = f"NaN values in loss for more than {self.max_nan_epochs} epochs"
+                 return True
+         else:
+             self.nan_counter = 0  # Reset the counter if loss is valid
+         return False
+
+     def check_for_plateau(self, results):
+         """Detect plateaus in validation metrics."""
+         if len(results[self.monitor_metric]) < self.plateau_patience + 1:
+             return False
+
+         # Get the last 'patience' number of validation points
+         recent_values = [x[1] for x in results[self.monitor_metric][-self.plateau_patience:]]
+         # Find the max value in the window
+         max_value = max(recent_values)
+         # Find the min value in the window
+         min_value = min(recent_values)
+
+         # If the range of values is small, consider it a plateau
+         if max_value - min_value < self.plateau_min_delta:
+             self.stopped_reason = f"{self.monitor_metric} plateaued for {self.plateau_patience} validations"
+             return True
+
+         return False
+
+     def check_best_metric_stagnation(self, results):
+         """Stop if best metric hasn't improved for a number of validation intervals."""
+         if "best_metric_epoch" not in results or len(results[self.monitor_metric]) < self.stagnation_patience + 1:
+             return False
+
+         # Get epoch of the best metric so far
+         best_epoch = results["best_metric_epoch"]
+         current_epoch = results[self.monitor_metric][-1][0]
+
+         # Check if it's been more than 'patience' validation intervals
+         if (current_epoch - best_epoch) >= (self.stagnation_patience * self.val_interval):
+             self.stopped_reason = f"No improvement for {self.stagnation_patience} validation intervals"
+             return True
+
+         return False
+
+     # def check_convergence_rate(self, results):
+     #     """Stop when improvement rate slows below threshold."""
+     #     if len(results[self.monitor_metric]) < self.convergence_window + 1:
+     #         return False
+
+     #     # Calculate average improvement rate over window
+     #     recent_values = [x[1] for x in results[self.monitor_metric][-(self.convergence_window+1):]]
+     #     improvements = [recent_values[i+1] - recent_values[i] for i in range(self.convergence_window)]
+     #     avg_improvement = sum(improvements) / self.convergence_window
+
+     #     if avg_improvement < self.convergence_threshold and avg_improvement > 0:
+     #         self.stopped_reason = f"Convergence rate ({avg_improvement:.6f}) below threshold"
+     #         return True
+
+     #     return False
+
+     def should_stop_training(self, epoch_loss, results=None, check_metrics=False):
+         """
+         Comprehensive check for whether training should stop.
+
+         Args:
+             epoch_loss: Current epoch's loss value
+             results: Dictionary containing training metrics history
+             check_metrics: Whether to also check validation metrics-based criteria
+
+         Returns:
+             bool: True if training should stop, False otherwise
+         """
+         # Check for NaN in loss (can be done every epoch)
+         if self.check_for_nan(epoch_loss):
+             return True
+
+         # Only check metric-based criteria if requested and results are provided
+         if check_metrics and results:
+             # Check for plateau in validation metrics
+             if self.check_for_plateau(results):
+                 return True
+
+             # Check if best metric hasn't improved for a while
+             if self.check_best_metric_stagnation(results):
+                 return True
+
+             # # Check if convergence rate has slowed down
+             # if self.check_convergence_rate(results):
+             #     return True
+
+         return False
+
+     def get_stopped_reason(self):
+         """Get the reason for stopping, if any."""
+         if self.stopped_reason:
+             return f"Early stopping triggered: {self.stopped_reason}"
+         return "No early stopping criteria met."
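A minimal sketch of wiring EarlyStoppingChecker into a training loop; train_one_epoch and validate are hypothetical stand-ins for the real steps in octopi/pytorch/trainer.py, and the results layout mirrors the (epoch, value) tuples the checker expects:

    import random
    from octopi.utils.stopping_criteria import EarlyStoppingChecker

    def train_one_epoch():   # hypothetical stand-in for the real training step
        return random.random()

    def validate():          # hypothetical stand-in for the real validation metric
        return random.random()

    checker = EarlyStoppingChecker(val_interval=15, monitor_metric='avg_fbeta')
    results = {'avg_fbeta': [], 'best_metric_epoch': 0}
    best = -1.0

    for epoch in range(1, 1001):
        epoch_loss = train_one_epoch()
        ran_validation = epoch % checker.val_interval == 0
        if ran_validation:
            score = validate()
            results['avg_fbeta'].append((epoch, score))   # (epoch, value) tuples
            if score > best:
                best, results['best_metric_epoch'] = score, epoch

        if checker.should_stop_training(epoch_loss, results, check_metrics=ran_validation):
            print(checker.get_stopped_reason())
            break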