octopi-1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic.
- octopi/__init__.py +0 -0
- octopi/datasets/__init__.py +0 -0
- octopi/datasets/augment.py +84 -0
- octopi/datasets/cached_datset.py +113 -0
- octopi/datasets/dataset.py +19 -0
- octopi/datasets/generators.py +429 -0
- octopi/datasets/mixup.py +49 -0
- octopi/datasets/multi_config_generator.py +253 -0
- octopi/entry_points/__init__.py +0 -0
- octopi/entry_points/common.py +80 -0
- octopi/entry_points/create_slurm_submission.py +243 -0
- octopi/entry_points/run_create_targets.py +281 -0
- octopi/entry_points/run_evaluate.py +65 -0
- octopi/entry_points/run_extract_mb_picks.py +141 -0
- octopi/entry_points/run_extract_midpoint.py +143 -0
- octopi/entry_points/run_localize.py +222 -0
- octopi/entry_points/run_optuna.py +139 -0
- octopi/entry_points/run_segment_predict.py +166 -0
- octopi/entry_points/run_train.py +201 -0
- octopi/extract/__init__.py +0 -0
- octopi/extract/localize.py +254 -0
- octopi/extract/membranebound_extract.py +262 -0
- octopi/extract/midpoint_extract.py +193 -0
- octopi/io.py +457 -0
- octopi/losses.py +86 -0
- octopi/main.py +101 -0
- octopi/models/AttentionUnet.py +56 -0
- octopi/models/MedNeXt.py +111 -0
- octopi/models/ModelTemplate.py +36 -0
- octopi/models/SegResNet.py +92 -0
- octopi/models/Unet.py +59 -0
- octopi/models/UnetPlusPlus.py +47 -0
- octopi/models/__init__.py +0 -0
- octopi/models/common.py +62 -0
- octopi/processing/__init__.py +0 -0
- octopi/processing/create_targets_from_picks.py +106 -0
- octopi/processing/downsample.py +129 -0
- octopi/processing/evaluate.py +289 -0
- octopi/processing/importers.py +213 -0
- octopi/processing/my_metrics.py +26 -0
- octopi/processing/segmentation_from_picks.py +167 -0
- octopi/processing/writers.py +102 -0
- octopi/pytorch/__init__.py +0 -0
- octopi/pytorch/hyper_search.py +243 -0
- octopi/pytorch/model_search_submitter.py +290 -0
- octopi/pytorch/segmentation.py +317 -0
- octopi/pytorch/trainer.py +438 -0
- octopi/pytorch_lightning/__init__.py +0 -0
- octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
- octopi/pytorch_lightning/train_pl.py +244 -0
- octopi/stopping_criteria.py +143 -0
- octopi/submit_slurm.py +95 -0
- octopi/utils.py +238 -0
- octopi/visualization_tools.py +201 -0
- octopi-1.0.dist-info/LICENSE +41 -0
- octopi-1.0.dist-info/METADATA +209 -0
- octopi-1.0.dist-info/RECORD +59 -0
- octopi-1.0.dist-info/WHEEL +4 -0
- octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/submit_slurm.py
ADDED
@@ -0,0 +1,95 @@
def create_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    num_gpus = 1,
    gpu_constraint = 'h100'):

    # Request the GPU partition when at least one GPU is needed; otherwise use the CPU partition.
    if num_gpus > 0:
        slurm_gpus = f'#SBATCH --partition=gpu\n#SBATCH --gpus={gpu_constraint}:{num_gpus}'
    else:
        slurm_gpus = '#SBATCH --partition=cpu'

    shell_script_content = f"""#!/bin/bash

{slurm_gpus}
#SBATCH --time=18:00:00
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=16G
#SBATCH --job-name={job_name}
#SBATCH --output={output_file}

ml anaconda
conda activate {conda_path}
{command}
"""

    # Save to file
    with open(shell_name, "w") as file:
        file.write(shell_script_content)

    print(f"\nShell script has been created successfully as {shell_name}\n")

def create_shellsubmit_array(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    job_array = (0, 0)):  # (first, last) array indices; the original default [min, max] referenced the builtins and would render as garbage in the --array line

    shell_script_content = f"""#!/bin/bash

#SBATCH --time=18:00:00
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=16G
#SBATCH --job-name={job_name}
#SBATCH --output={output_file}
#SBATCH --array={job_array[0]}-{job_array[1]}

ml anaconda
conda activate {conda_path}
{command}
"""

    # Save to file
    with open(shell_name, "w") as file:
        file.write(shell_script_content)

    print(f"\nShell script has been created successfully as {shell_name}\n")

def create_multiconfig_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    base_inputs,
    config_inputs,
    command):

    multiconfig = f"""#!/bin/bash

#SBATCH --job-name={job_name}
#SBATCH --time=24:00:00
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=12G
#SBATCH --partition=cpu
#SBATCH --output={output_file}

ml anaconda
{conda_path}

{base_inputs}

{config_inputs}

{command}
"""

    # Save to file
    with open(shell_name, "w") as file:
        file.write(multiconfig)

    print(f"\nShell script has been created successfully as {shell_name}\n")
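For context, a minimal sketch of how these helpers might be invoked; the paths, conda environment, and training command below are hypothetical placeholders, not values shipped with octopi:

```python
# Hypothetical invocation of create_shellsubmit; all paths and the
# command are illustrative, not part of the package.
from octopi.submit_slurm import create_shellsubmit

create_shellsubmit(
    job_name='octopi-train',
    output_file='logs/train_%j.out',   # %j expands to the SLURM job ID
    shell_name='submit_train.sh',
    conda_path='/path/to/envs/octopi',
    command='python -m octopi.main train --config config.yaml',  # placeholder
    num_gpus=1,
    gpu_constraint='h100',
)
# Then submit with: sbatch submit_train.sh
```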
octopi/utils.py
ADDED
@@ -0,0 +1,238 @@
from monai.networks.nets import UNet, AttentionUnet
from typing import List, Tuple, Union
from dotenv import load_dotenv
import argparse, octopi
import torch, random, os, yaml
import numpy as np

##############################################################################################################################

def mlflow_setup():

    module_root = os.path.dirname(octopi.__file__)
    dotenv_path = module_root.replace('src/octopi','') + '.env'
    load_dotenv(dotenv_path=dotenv_path)

    # MLflow setup
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Credentials not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

        # Check again after loading .env file
        if not password:
            raise ValueError("Password is not set in environment variables or .env file!")
        else:
            print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    return os.getenv('MLFLOW_TRACKING_URI')

##############################################################################################################################

def set_seed(seed):
    # Set the seed for Python's random module
    random.seed(seed)

    # Set the seed for NumPy
    np.random.seed(seed)

    # Set the seed for PyTorch (both CPU and GPU)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # If using multi-GPU

    # Ensure reproducibility of operations by disabling certain optimizations
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

###############################################################################################################################

def parse_list(value: str) -> List[str]:
    """
    Parse a string representing a list of items.
    Supports formats like '[item1,item2,item3]' or 'item1,item2,item3'.
    """
    value = value.strip("[]")  # Remove brackets if present
    return [x.strip() for x in value.split(",")]

###############################################################################################################################

def parse_int_list(value: str) -> List[int]:
    """
    Parse a string representing a list of integers.
    Supports formats like '[1,2,3]' or '1,2,3'.
    """
    return [int(x) for x in parse_list(value)]

###############################################################################################################################

def string2bool(value: str):
    """
    Convert a string value to a boolean, for use as an argparse type.
    """
    if isinstance(value, bool):
        return value
    if value.lower() in {'true', 't', '1', 'yes'}:
        return True
    elif value.lower() in {'false', 'f', '0', 'no'}:
        return False
    else:
        raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")

###############################################################################################################################

def parse_target(value: str) -> Tuple[str, Union[str, None], Union[str, None]]:
    """
    Parse a single target string.
    Expected formats:
    - "name"
    - "name,user_id,session_id"
    """
    parts = value.split(',')
    if len(parts) == 1:
        obj_name = parts[0]
        return obj_name, None, None
    elif len(parts) == 3:
        obj_name, user_id, session_id = parts
        return obj_name, user_id, session_id
    else:
        raise argparse.ArgumentTypeError(
            f"Invalid target format: '{value}'. Expected 'name' or 'name,user_id,session_id'."
        )


def parse_seg_target(value: str) -> List[Tuple[str, Union[str, None], Union[str, None]]]:
    """
    Parse segmentation targets. Each target can have the format:
    - "name"
    - "name,user_id,session_id"
    Multiple targets are separated by ';'.
    """
    targets = []
    for target in value.split(';'):  # Use ';' as the separator between targets
        parts = target.split(',')
        if len(parts) == 1:
            name = parts[0]
            targets.append((name, None, None))
        elif len(parts) == 3:
            name, user_id, session_id = parts
            targets.append((name, user_id, session_id))
        else:
            raise argparse.ArgumentTypeError(
                f"Invalid seg-target format: '{target}'. Expected 'name' or 'name,user_id,session_id'."
            )
    return targets

def parse_copick_configs(config_entries: List[str]):
    """
    Parse a list of CoPick configuration file paths.
    """
    # Process the --config arguments into a dictionary
    copick_configs = {}

    for config_entry in config_entries:
        if ',' in config_entry:
            # Entry has a session name and a config path
            try:
                session_name, config_path = config_entry.split(',', 1)
                copick_configs[session_name] = config_path
            except ValueError:
                raise argparse.ArgumentTypeError(
                    f"Invalid format for --config entry: '{config_entry}'. Expected 'session_name,/path/to/config.json'."
                )
        else:
            # Single configuration path without a session name
            # (replaces the mapping with the bare path string)
            copick_configs = config_entry

    return copick_configs

def parse_data_split(value: str) -> Tuple[float, float, float]:
    """
    Parse data split ratios from string input.

    Args:
        value: Either a single float (e.g., "0.8") or two comma-separated floats (e.g., "0.7,0.1")

    Returns:
        Tuple of (train_ratio, val_ratio, test_ratio)

    Examples:
        "0.8" -> (0.8, 0.2, 0.0)
        "0.7,0.1" -> (0.7, 0.1, 0.2)
    """
    parts = value.split(',')

    if len(parts) == 1:
        # Single value provided - use it as the train ratio
        train_ratio = float(parts[0])
        val_ratio = 1.0 - train_ratio
        test_ratio = 0.0
    elif len(parts) == 2:
        # Two values provided - use as train and val ratios
        train_ratio = float(parts[0])
        val_ratio = float(parts[1])
        test_ratio = 1.0 - train_ratio - val_ratio
    else:
        raise ValueError("Data split must be either a single value or two comma-separated values")

    # Validate ratios
    if train_ratio < 0 or val_ratio < 0 or test_ratio < 0:
        raise ValueError("All ratios must be non-negative")

    if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
        raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")

    return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)

##############################################################################################################################

# Create a custom dumper that uses flow style for lists only.
class InlineListDumper(yaml.SafeDumper):
    def represent_list(self, data):
        node = super().represent_list(data)
        node.flow_style = True  # Use inline style for lists
        return node

def save_parameters_yaml(params: dict, output_path: str):
    """
    Save parameters to a YAML file.
    """
    InlineListDumper.add_representer(list, InlineListDumper.represent_list)
    with open(output_path, 'w') as f:
        yaml.dump(params, f, Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)

def load_yaml(path: str) -> dict:
    """
    Load a YAML file and return the contents as a dictionary.
    """
    if os.path.exists(path):
        with open(path, 'r') as f:
            return yaml.safe_load(f)
    else:
        raise FileNotFoundError(f"File not found: {path}")
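A short, hypothetical sketch of wiring these parsers into argparse; the flag names below are illustrative, not taken from octopi's own entry points:

```python
# Hypothetical argparse wiring for the utils parsers; flag names are
# illustrative and not octopi's actual CLI arguments.
import argparse
from octopi.utils import parse_int_list, parse_seg_target, parse_data_split, string2bool

parser = argparse.ArgumentParser()
parser.add_argument('--channels', type=parse_int_list)      # "[32,64,128]" -> [32, 64, 128]
parser.add_argument('--seg-target', type=parse_seg_target)  # "ribosome;membrane,user1,s1"
parser.add_argument('--data-split', type=parse_data_split)  # "0.7,0.1" -> (0.7, 0.1, 0.2)
parser.add_argument('--use-mixup', type=string2bool)        # "yes" -> True

args = parser.parse_args(['--channels', '[32,64,128]', '--data-split', '0.7,0.1'])
print(args.channels, args.data_split)  # [32, 64, 128] (0.7, 0.1, 0.2)
```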
octopi/visualization_tools.py
ADDED
@@ -0,0 +1,201 @@
import matplotlib.colors as mcolors
from typing import Optional, List
import matplotlib.pyplot as plt
from octopi import io
import numpy as np

# Define the plotting function
def show_tomo_segmentation(tomo, seg, vol_slice):

    plt.figure(figsize=(20, 10))

    # Tomogram
    plt.subplot(1, 3, 1)
    plt.title('Tomogram')
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    # Painted Segmentation
    plt.subplot(1, 3, 2)
    plt.title('Painted Segmentation from Picks')
    plt.imshow(seg[vol_slice], cmap='viridis')
    plt.axis('off')

    # Overlay
    plt.subplot(1, 3, 3)
    plt.title('Overlay')
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.imshow(seg[vol_slice], cmap='viridis', alpha=0.5)  # alpha=0.5 for 50% transparency
    plt.axis('off')

    plt.tight_layout()
    plt.show()

def show_labeled_tomo_segmentation(tomo, seg, seg_labels, unique_values, vol_slice):

    plt.figure(figsize=(20, 10))

    # Keep only the labels that actually occur in the segmentation
    seg_labels_filtered = {k: v for k, v in seg_labels.items() if k in unique_values}
    num_classes = len(seg_labels_filtered)

    # Create a discrete colormap (other colormaps like 'Set3' or 'tab20' also work)
    colors = plt.cm.tab20b(np.linspace(0, 1, num_classes))
    cmap = mcolors.ListedColormap(colors)

    # Tomogram plot
    plt.subplot(1, 2, 1)
    plt.title('Tomogram')
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    # Prediction segmentation plot
    plt.subplot(1, 2, 2)
    plt.title('Prediction Segmentation')
    im = plt.imshow(seg[vol_slice], cmap=cmap)
    plt.axis('off')

    # Add the labeled color bar
    cbar = plt.colorbar(im, ticks=list(seg_labels_filtered.keys()))
    cbar.ax.set_yticklabels([seg_labels_filtered[i] for i in seg_labels_filtered.keys()])  # Set custom labels

    plt.tight_layout()
    plt.show()


def show_tomo_points(tomo, run, objects, user_id, vol_slice, session_id=None, slice_proximity_threshold=3):
    plt.figure(figsize=(20, 10))

    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    for name, _, _ in objects:
        try:
            coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id, session_id=session_id)
            close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
            plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
        except Exception:
            # Skip objects without picks for this user/session
            pass

    plt.show()

def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
                        session_id1=None, session_id2=None, slice_proximity_threshold=3):
    plt.figure(figsize=(20, 10))

    plt.subplot(1, 2, 1)
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.title(f'{user_id1} Picks')
    plt.axis('off')

    for name, _, _ in objects:
        try:
            coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id1, session_id=session_id1)
            close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
            plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
        except Exception:
            pass

    plt.subplot(1, 2, 2)
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.title(f'{user_id2} Picks')

    for name, _, _ in objects:
        try:
            coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id2, session_id=session_id2)
            close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
            plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
        except Exception:
            pass

    plt.axis('off')
    plt.show()

def plot_training_results(
    results,
    class_names: Optional[List[str]] = None,
    save_plot: Optional[str] = None,
    fig=None, axs=None):

    # Create a 2x2 subplot layout, or reuse an existing one
    if fig is None:
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    else:
        # Clear previous plots
        for ax in axs.flatten():
            ax.clear()

    fig.suptitle("Metrics Over Epochs", fontsize=16)

    # Unpack the data for loss (logged every epoch)
    epochs_loss = [epoch for epoch, _ in results['loss']]
    loss = [value for _, value in results['loss']]
    val_epochs_loss = [epoch for epoch, _ in results['val_loss']]
    val_loss = [value for _, value in results['val_loss']]

    # Plot training and validation loss in the top-left
    axs[0, 0].plot(epochs_loss, loss, label="Training Loss")
    axs[0, 0].plot(val_epochs_loss, val_loss, label='Validation Loss')
    axs[0, 0].set_xlabel("Epochs")
    axs[0, 0].set_ylabel("Loss")
    axs[0, 0].set_title("Training Loss")
    axs[0, 0].legend()
    axs[0, 0].tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # For metrics that are logged every `val_interval` epochs
    epochs_metrics = [epoch for epoch, _ in results['avg_recall']]

    # Determine the number of classes and names
    num_classes = len([key for key in results.keys() if key.startswith('recall_class')])

    if class_names is None or len(class_names) != num_classes:
        class_names = [f"Class {i+1}" for i in range(num_classes)]

    # Plot Recall in the top-right
    for class_idx in range(num_classes):
        recall_class = [value for _, value in results[f'recall_class{class_idx+1}']]
        axs[0, 1].plot(epochs_metrics, recall_class, label=f"{class_names[class_idx]}")
    axs[0, 1].set_xlabel("Epochs")
    axs[0, 1].set_ylabel("Recall")
    axs[0, 1].set_title("Recall per Class")
    axs[0, 1].tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Plot Precision in the bottom-left
    for class_idx in range(num_classes):
        precision_class = [value for _, value in results[f'precision_class{class_idx+1}']]
        axs[1, 0].plot(epochs_metrics, precision_class, label=f"{class_names[class_idx]}")
    axs[1, 0].set_xlabel("Epochs")
    axs[1, 0].set_ylabel("Precision")
    axs[1, 0].set_title("Precision per Class")
    axs[1, 0].legend()
    axs[1, 0].tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Plot F1 Score in the bottom-right
    for class_idx in range(num_classes):
        f1_class = [value for _, value in results[f'f1_class{class_idx+1}']]
        axs[1, 1].plot(epochs_metrics, f1_class, label=f"{class_names[class_idx]}")
    axs[1, 1].set_xlabel("Epochs")
    axs[1, 1].set_ylabel("F1 Score")
    axs[1, 1].set_title("F1 Score per Class")
    axs[1, 1].tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Adjust layout, save if requested, and redraw
    plt.tight_layout(rect=[0, 0, 1, 0.96])  # Leave space for the main title

    if save_plot:
        fig.savefig(save_plot)
    fig.canvas.draw()

    return fig, axs
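For reference, a minimal made-up example of the `results` structure plot_training_results consumes, lists of (epoch, value) pairs keyed by metric name; the numbers below are illustrative, not real training output:

```python
# Made-up (epoch, value) metric lists matching the keys the plotting
# code reads; real values come from octopi's trainer.
from octopi.visualization_tools import plot_training_results

results = {
    'loss':             [(0, 0.92), (1, 0.61), (2, 0.44)],
    'val_loss':         [(0, 0.95), (1, 0.70), (2, 0.52)],
    'avg_recall':       [(1, 0.55), (2, 0.68)],
    'recall_class1':    [(1, 0.50), (2, 0.65)],
    'precision_class1': [(1, 0.48), (2, 0.62)],
    'f1_class1':        [(1, 0.49), (2, 0.63)],
}
fig, axs = plot_training_results(results, class_names=['ribosome'], save_plot='metrics.png')
```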
octopi-1.0.dist-info/LICENSE
ADDED
@@ -0,0 +1,41 @@
# Legal

## License for the octopi package

This package is licensed under the MIT License:

```
MIT License

Copyright (c) 2025 Chan Zuckerberg Initiative

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
```

## License Notice for Dependencies

```
This repository is licensed under the MIT License. It relies on third-party dependencies that carry their own licenses. Specifically:

- monai is licensed under the Apache License 2.0.
- pytorch-lightning is licensed under the Apache License 2.0.

All dependencies use permissive open-source licenses that are compatible with this project's MIT License. No GPL or other copyleft licensed dependencies are included.
For specific licensing information about any dependency, please refer to the respective package documentation or repository.
```