octopi 1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic. Click here for more details.

Files changed (59) hide show
  1. octopi/__init__.py +0 -0
  2. octopi/datasets/__init__.py +0 -0
  3. octopi/datasets/augment.py +84 -0
  4. octopi/datasets/cached_datset.py +113 -0
  5. octopi/datasets/dataset.py +19 -0
  6. octopi/datasets/generators.py +429 -0
  7. octopi/datasets/mixup.py +49 -0
  8. octopi/datasets/multi_config_generator.py +253 -0
  9. octopi/entry_points/__init__.py +0 -0
  10. octopi/entry_points/common.py +80 -0
  11. octopi/entry_points/create_slurm_submission.py +243 -0
  12. octopi/entry_points/run_create_targets.py +281 -0
  13. octopi/entry_points/run_evaluate.py +65 -0
  14. octopi/entry_points/run_extract_mb_picks.py +141 -0
  15. octopi/entry_points/run_extract_midpoint.py +143 -0
  16. octopi/entry_points/run_localize.py +222 -0
  17. octopi/entry_points/run_optuna.py +139 -0
  18. octopi/entry_points/run_segment_predict.py +166 -0
  19. octopi/entry_points/run_train.py +201 -0
  20. octopi/extract/__init__.py +0 -0
  21. octopi/extract/localize.py +254 -0
  22. octopi/extract/membranebound_extract.py +262 -0
  23. octopi/extract/midpoint_extract.py +193 -0
  24. octopi/io.py +457 -0
  25. octopi/losses.py +86 -0
  26. octopi/main.py +101 -0
  27. octopi/models/AttentionUnet.py +56 -0
  28. octopi/models/MedNeXt.py +111 -0
  29. octopi/models/ModelTemplate.py +36 -0
  30. octopi/models/SegResNet.py +92 -0
  31. octopi/models/Unet.py +59 -0
  32. octopi/models/UnetPlusPlus.py +47 -0
  33. octopi/models/__init__.py +0 -0
  34. octopi/models/common.py +62 -0
  35. octopi/processing/__init__.py +0 -0
  36. octopi/processing/create_targets_from_picks.py +106 -0
  37. octopi/processing/downsample.py +129 -0
  38. octopi/processing/evaluate.py +289 -0
  39. octopi/processing/importers.py +213 -0
  40. octopi/processing/my_metrics.py +26 -0
  41. octopi/processing/segmentation_from_picks.py +167 -0
  42. octopi/processing/writers.py +102 -0
  43. octopi/pytorch/__init__.py +0 -0
  44. octopi/pytorch/hyper_search.py +243 -0
  45. octopi/pytorch/model_search_submitter.py +290 -0
  46. octopi/pytorch/segmentation.py +317 -0
  47. octopi/pytorch/trainer.py +438 -0
  48. octopi/pytorch_lightning/__init__.py +0 -0
  49. octopi/pytorch_lightning/optuna_pl_ddp.py +273 -0
  50. octopi/pytorch_lightning/train_pl.py +244 -0
  51. octopi/stopping_criteria.py +143 -0
  52. octopi/submit_slurm.py +95 -0
  53. octopi/utils.py +238 -0
  54. octopi/visualization_tools.py +201 -0
  55. octopi-1.0.dist-info/LICENSE +41 -0
  56. octopi-1.0.dist-info/METADATA +209 -0
  57. octopi-1.0.dist-info/RECORD +59 -0
  58. octopi-1.0.dist-info/WHEEL +4 -0
  59. octopi-1.0.dist-info/entry_points.txt +4 -0
octopi/submit_slurm.py ADDED
@@ -0,0 +1,95 @@
1
def create_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    num_gpus = 1,
    gpu_constraint = 'h100'):
    """
    Write a single-job SLURM submission script to ``shell_name``.

    The script requests 18 hours, 4 CPUs and 16G of memory per CPU, loads the
    ``anaconda`` module, activates the conda environment ``conda_path`` and
    finally runs ``command``.

    Args:
        job_name: Value for ``#SBATCH --job-name``.
        output_file: Value for ``#SBATCH --output``.
        shell_name: Path of the script file to (over)write.
        conda_path: Environment handed to ``conda activate``.
        command: Command line placed on the last line of the script.
        num_gpus: GPUs to request; any value > 0 selects the ``gpu`` partition,
            otherwise the ``cpu`` partition is used and no GPUs are requested.
        gpu_constraint: GPU type used in ``--gpus=<type>:<count>``.
    """
    # Partition (and GPU request) depends on whether GPUs are wanted at all.
    partition_lines = ['#SBATCH --partition=cpu']
    if num_gpus > 0:
        partition_lines = [
            '#SBATCH --partition=gpu',
            f'#SBATCH --gpus={gpu_constraint}:{num_gpus}',
        ]
    slurm_gpus = '\n'.join(partition_lines)

    # Trailing '' makes the joined script end with a newline.
    script_lines = [
        '#!/bin/bash',
        '',
        slurm_gpus,
        '#SBATCH --time=18:00:00',
        '#SBATCH --cpus-per-task=4',
        '#SBATCH --mem-per-cpu=16G',
        f'#SBATCH --job-name={job_name}',
        f'#SBATCH --output={output_file}',
        '',
        'ml anaconda',
        f'conda activate {conda_path}',
        f'{command}',
        '',
    ]
    shell_script_content = '\n'.join(script_lines)

    with open(shell_name, "w") as handle:
        handle.write(shell_script_content)

    print(f"\nShell script has been created successfully as {shell_name}\n")
34
+
35
def create_shellsubmit_array(
    job_name,
    output_file,
    shell_name,
    conda_path,
    command,
    job_array = None):
    """
    Write a SLURM job-array submission script to ``shell_name``.

    The script requests 18 hours, 4 CPUs and 16G of memory per CPU (no
    partition or GPU directive is emitted), loads the ``anaconda`` module,
    activates the conda environment ``conda_path`` and runs ``command``.

    Args:
        job_name: Value for ``#SBATCH --job-name``.
        output_file: Value for ``#SBATCH --output``.
        shell_name: Path of the script file to (over)write.
        conda_path: Environment handed to ``conda activate``.
        command: Command line placed on the last line of the script.
        job_array: Indexable ``[first, last]``; its first two entries are
            rendered as ``#SBATCH --array=first-last``. Required.

    Raises:
        ValueError: If ``job_array`` is not provided.
    """
    # There is no sensible default range (a ``[min, max]`` default would hold
    # the builtin functions and render as '<built-in function min>-...'),
    # so the caller must supply one explicitly.
    if job_array is None:
        raise ValueError("job_array must be provided as [first_index, last_index]")
    first_index, last_index = job_array[0], job_array[1]

    shell_script_content = f"""#!/bin/bash

#SBATCH --time=18:00:00
#SBATCH --cpus-per-task=4
#SBATCH --mem-per-cpu=16G
#SBATCH --job-name={job_name}
#SBATCH --output={output_file}
#SBATCH --array={first_index}-{last_index}

ml anaconda
conda activate {conda_path}
{command}
"""

    # Save to file
    with open(shell_name, "w") as file:
        file.write(shell_script_content)

    print(f"\nShell script has been created successfully as {shell_name}\n")
62
+
63
def create_multiconfig_shellsubmit(
    job_name,
    output_file,
    shell_name,
    conda_path,
    base_inputs,
    config_inputs,
    command):
    """
    Write a CPU-partition SLURM script that runs a multi-config command.

    The script requests 24 hours, 4 CPUs and 12G of memory per CPU on the
    ``cpu`` partition and loads the ``anaconda`` module. After that,
    ``conda_path``, ``base_inputs``, ``config_inputs`` and ``command`` are
    inserted verbatim, each separated by a blank line.

    NOTE(review): ``conda_path`` is written as-is (there is no ``conda activate``
    prefix, unlike the single/array variants). Presumably callers pass the
    full activation line; confirm against callers.

    Args:
        job_name: Value for ``#SBATCH --job-name``.
        output_file: Value for ``#SBATCH --output``.
        shell_name: Path of the script file to (over)write.
        conda_path: Text inserted right after ``ml anaconda``.
        base_inputs: Shell text inserted before ``config_inputs``.
        config_inputs: Shell text inserted before ``command``.
        command: Command line placed at the end of the script.
    """
    header = [
        '#! /bin/bash',
        '',
        f'#SBATCH --job-name={job_name}',
        '#SBATCH --time=24:00:00',
        '#SBATCH --cpus-per-task=4',
        '#SBATCH --mem-per-cpu=12G',
        '#SBATCH --partition=cpu',
        f'#SBATCH --output={output_file}',
        '',
        'ml anaconda',
    ]
    # Each user-supplied section is followed by a blank line except the last,
    # which is followed only by the terminating newline.
    body = [
        f'{conda_path}', '',
        f'{base_inputs}', '',
        f'{config_inputs}', '',
        f'{command}', '',
    ]
    multiconfig = '\n'.join(header + body)

    with open(shell_name, "w") as script_file:
        script_file.write(multiconfig)

    print(f"\nShell script has been created successfully as {shell_name}\n")
octopi/utils.py ADDED
@@ -0,0 +1,238 @@
1
+ from monai.networks.nets import UNet, AttentionUnet
2
+ from typing import List, Tuple, Union
3
+ from dotenv import load_dotenv
4
+ import argparse, octopi
5
+ import torch, random, os, yaml
6
+ from typing import List
7
+ import numpy as np
8
+
9
+ ##############################################################################################################################
10
+
11
def mlflow_setup():
    """
    Load MLflow credentials into the environment and return the tracking URI.

    A ``.env`` file is first loaded from a path derived from the installed
    ``octopi`` package directory: the ``src/octopi`` suffix of a source
    checkout is stripped, giving the repository root. If
    ``MLFLOW_TRACKING_USERNAME`` or ``MLFLOW_TRACKING_PASSWORD`` is still
    missing, ``load_dotenv()`` is called again with python-dotenv's default
    lookup. Both values are then written back to ``os.environ``.

    Returns:
        The value of ``MLFLOW_TRACKING_URI`` (``None`` if unset).

    Raises:
        ValueError: If the password or the username cannot be found.
    """
    module_root = os.path.dirname(octopi.__file__)
    dotenv_path = module_root.replace('src/octopi', '') + '.env'
    load_dotenv(dotenv_path=dotenv_path)

    # MLflow setup
    username = os.getenv('MLFLOW_TRACKING_USERNAME')
    password = os.getenv('MLFLOW_TRACKING_PASSWORD')
    if not password or not username:
        print("Password not found in environment, loading from .env file...")
        load_dotenv()  # Loads environment variables from a .env file
        username = os.getenv('MLFLOW_TRACKING_USERNAME')
        password = os.getenv('MLFLOW_TRACKING_PASSWORD')

    # Check again after loading .env file
    if not password:
        raise ValueError("Password is not set in environment variables or .env file!")
    # Without this check a missing username reaches the os.environ assignment
    # below as None and fails with an unhelpful TypeError.
    if not username:
        raise ValueError("Username is not set in environment variables or .env file!")

    print("Password loaded successfully")
    os.environ['MLFLOW_TRACKING_USERNAME'] = username
    os.environ['MLFLOW_TRACKING_PASSWORD'] = password

    return os.getenv('MLFLOW_TRACKING_URI')
35
+
36
+ ##############################################################################################################################
37
+
38
def set_seed(seed):
    """
    Seed every random number generator used by the project.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (plus every CUDA device when CUDA is available). It then puts cuDNN in
    deterministic, non-benchmarking mode so that repeated runs match.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    if torch.cuda.is_available():
        # Seed the current device and every visible GPU (multi-GPU runs).
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)

    # Give up cuDNN autotuning in exchange for run-to-run reproducibility.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
54
+
55
+ ###############################################################################################################################
56
+
57
def parse_list(value: str) -> List[str]:
    """
    Parse a string representing a list of items.

    Supports formats like '[item1,item2,item3]' or 'item1,item2,item3'.
    Whitespace around the whole value (e.g. ' [a, b] ') and around each item
    is ignored.

    Returns:
        The list of item strings; an empty input yields ``['']``.
    """
    # Strip whitespace first; otherwise ' [a,b] ' would keep its brackets,
    # because strip("[]") only removes bracket characters at the very ends.
    value = value.strip().strip("[]")
    return [item.strip() for item in value.split(",")]
64
+
65
+ ###############################################################################################################################
66
+
67
def parse_int_list(value: str) -> List[int]:
    """
    Parse a string representing a list of integers.

    Accepts '[1,2,3]' or '1,2,3'. Tokenisation is delegated to ``parse_list``
    and each token is converted with ``int``.

    Raises:
        ValueError: If any token is not a valid integer literal.
    """
    tokens = parse_list(value)
    return list(map(int, tokens))
73
+
74
+ ###############################################################################################################################
75
+
76
def string2bool(value: str):
    """
    Convert a command-line string to a boolean (for use as an argparse ``type=``).

    Matching is case-insensitive and ignores surrounding whitespace:
    'true', 't', '1', 'yes' -> True; 'false', 'f', '0', 'no' -> False.
    Values that are already ``bool`` are returned unchanged.

    Raises:
        argparse.ArgumentTypeError: For any other value.
    """
    if isinstance(value, bool):
        return value
    # The value is lower-cased before lookup, so only lower-case spellings are
    # needed in the accepted sets.
    normalized = value.strip().lower()
    if normalized in {'true', 't', '1', 'yes'}:
        return True
    if normalized in {'false', 'f', '0', 'no'}:
        return False
    raise argparse.ArgumentTypeError(f"Invalid boolean value: {value}")
88
+
89
+ ###############################################################################################################################
90
+
91
def parse_target(value: str) -> Tuple[str, Union[str, None], Union[str, None]]:
    """
    Parse one target specification.

    Accepted forms:
        - "name"                     -> (name, None, None)
        - "name,user_id,session_id"  -> (name, user_id, session_id)

    Raises:
        argparse.ArgumentTypeError: If the value has neither 1 nor 3
            comma-separated fields.
    """
    fields = value.split(',')
    field_count = len(fields)

    if field_count == 3:
        return tuple(fields)
    if field_count == 1:
        return fields[0], None, None

    raise argparse.ArgumentTypeError(
        f"Invalid target format: '{value}'. Expected 'name' or 'name,user_id,session_id'."
    )
109
+
110
+
111
def parse_seg_target(value: str) -> List[Tuple[str, Union[str, None], Union[str, None]]]:
    """
    Parse one or more segmentation targets.

    Multiple targets are separated by ';'. Within a target, fields are
    separated by ',', and each target has one of these forms:
        - "name"                     -> (name, None, None)
        - "name,user_id,session_id"  -> (name, user_id, session_id)

    Example:
        "membrane;ribosome,user1,sess1"
        -> [("membrane", None, None), ("ribosome", "user1", "sess1")]

    Raises:
        argparse.ArgumentTypeError: If any target has neither 1 nor 3 fields.
    """
    targets = []
    for target in value.split(';'):  # ';' separates targets, ',' separates fields
        parts = target.split(',')
        if len(parts) == 1:
            targets.append((parts[0], None, None))
        elif len(parts) == 3:
            targets.append(tuple(parts))
        else:
            raise argparse.ArgumentTypeError(
                f"Invalid seg-target format: '{target}'. Expected 'name' or 'name,user_id,session_id'."
            )
    return targets
132
+
133
def parse_copick_configs(config_entries: List[str]):
    """
    Parse ``--config`` entries into CoPick configuration path(s).

    Two forms are supported, but they cannot be mixed:
        - 'session_name,/path/to/config.json': collected into a dict
          ``{session_name: path}``. Only the first comma separates the name
          from the path, and a repeated session name keeps its last path.
        - '/path/to/config.json': returned as a plain string. If several such
          entries are given, the last one wins.

    Returns:
        A dict mapping session names to config paths, or a single path string.
        An empty ``config_entries`` yields an empty dict.

    Raises:
        argparse.ArgumentTypeError: If session-named and plain entries are mixed.
    """
    session_configs = {}
    single_config = None

    for config_entry in config_entries:
        if ',' in config_entry:
            # The entry contains ',', so split(',', 1) always yields two parts.
            session_name, config_path = config_entry.split(',', 1)
            session_configs[session_name] = config_path
        else:
            single_config = config_entry

    # Mixing both forms cannot be represented by a single return value; fail
    # clearly instead of crashing or silently dropping the session entries.
    if single_config is not None and session_configs:
        raise argparse.ArgumentTypeError(
            "Cannot mix 'session_name,/path/to/config.json' --config entries with "
            f"a plain config path ('{single_config}')."
        )

    return session_configs if single_config is None else single_config
174
+
175
def parse_data_split(value: str) -> Tuple[float, float, float]:
    """
    Parse data split ratios from string input.

    Args:
        value: Either a single float (e.g., "0.8") or two comma-separated floats (e.g., "0.7,0.1")

    Returns:
        Tuple of (train_ratio, val_ratio, test_ratio), each rounded to 2 decimals.

    Raises:
        ValueError: On a malformed value, a negative ratio, or ratios that do
            not sum to 1.0.

    Examples:
        "0.8" -> (0.8, 0.2, 0.0)
        "0.7,0.1" -> (0.7, 0.1, 0.2)
        "0.8,0.2" -> (0.8, 0.2, 0.0)
    """
    parts = value.split(',')

    if len(parts) == 1:
        # Single value provided - use it as train ratio
        train_ratio = float(parts[0])
        val_ratio = 1.0 - train_ratio
        test_ratio = 0.0
    elif len(parts) == 2:
        # Two values provided - use as train and val ratios
        train_ratio = float(parts[0])
        val_ratio = float(parts[1])
        test_ratio = 1.0 - train_ratio - val_ratio
    else:
        raise ValueError("Data split must be either a single value or two comma-separated values")

    # The derived ratio comes from float subtraction, so exact splits such as
    # "0.8,0.2" produce a tiny negative remainder (1.0 - 0.8 - 0.2 ~ -5.6e-17).
    # Values within tolerance of zero are treated as zero instead of rejected.
    tolerance = 1e-6
    ratios = (train_ratio, val_ratio, test_ratio)
    if any(ratio < -tolerance for ratio in ratios):
        raise ValueError("All ratios must be non-negative")
    train_ratio, val_ratio, test_ratio = (max(ratio, 0.0) for ratio in ratios)

    if abs(train_ratio + val_ratio + test_ratio - 1.0) > tolerance:
        raise ValueError(f"Ratios must sum to 1.0, got {train_ratio + val_ratio + test_ratio}")

    return round(train_ratio, 2), round(val_ratio, 2), round(test_ratio, 2)
212
+
213
+ ##############################################################################################################################
214
+
215
+ # Create a custom dumper that uses flow style for lists only.
216
class InlineListDumper(yaml.SafeDumper):
    """
    ``yaml.SafeDumper`` variant that writes lists in inline (flow) style.

    Only list nodes are affected; other nodes keep whatever style the dump
    call selects. The overridden method is used only after it has been
    registered with ``add_representer(list, InlineListDumper.represent_list)``.
    """

    def represent_list(self, data):
        """Build the sequence node for ``data`` with flow style forced on."""
        sequence_node = super().represent_list(data)
        sequence_node.flow_style = True
        return sequence_node
221
+
222
def save_parameters_yaml(params: dict, output_path: str):
    """
    Write ``params`` to ``output_path`` as YAML.

    Keys keep their insertion order (``sort_keys=False``), and mappings use
    block style (``default_flow_style=False``). Lists are written inline
    through ``InlineListDumper``.
    """
    # Registration runs on every call; it stores the same representer each
    # time, so repeating it is harmless.
    InlineListDumper.add_representer(list, InlineListDumper.represent_list)

    dump_options = dict(Dumper=InlineListDumper, default_flow_style=False, sort_keys=False)
    with open(output_path, 'w') as stream:
        yaml.dump(params, stream, **dump_options)
229
+
230
def load_yaml(path: str) -> dict:
    """
    Read the YAML document at ``path`` with ``yaml.safe_load``.

    Returns:
        The parsed document (a dict for mapping-style files; ``safe_load``
        yields ``None`` for an empty file).

    Raises:
        FileNotFoundError: If ``os.path.exists(path)`` is False.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    with open(path, 'r') as stream:
        return yaml.safe_load(stream)
octopi/visualization_tools.py ADDED
@@ -0,0 +1,201 @@
1
+ import matplotlib.colors as mcolors
2
+ from typing import Optional, List
3
+ import matplotlib.pyplot as plt
4
+ from octopi import io
5
+ import numpy as np
6
+
7
+ # Define the plotting function
8
def show_tomo_segmentation(tomo, seg, vol_slice):
    """
    Show a tomogram slice, its segmentation and an overlay side by side.

    Args:
        tomo: Volume indexed as ``tomo[vol_slice]`` to get a 2-D image.
        seg: Segmentation volume indexed the same way as ``tomo``.
        vol_slice: Index of the slice to display.
    """
    plt.figure(figsize=(20, 10))

    tomo_slice = tomo[vol_slice]
    seg_slice = seg[vol_slice]

    # (title, [(image, imshow kwargs), ...]) for each of the three panels.
    panels = (
        ('Tomogram', [(tomo_slice, {'cmap': 'gray'})]),
        ('Painted Segmentation from Picks', [(seg_slice, {'cmap': 'viridis'})]),
        ('Overlay', [(tomo_slice, {'cmap': 'gray'}),
                     (seg_slice, {'cmap': 'viridis', 'alpha': 0.5})]),  # 50% transparent labels
    )
    for position, (title, layers) in enumerate(panels, start=1):
        plt.subplot(1, 3, position)
        plt.title(title)
        for image, style in layers:
            plt.imshow(image, **style)
        plt.axis('off')

    plt.tight_layout()
    plt.show()
32
+
33
def show_labeled_tomo_segmentation(tomo, seg, seg_labels, unique_values, vol_slice):
    """
    Plot a tomogram slice next to a labelled segmentation slice.

    Args:
        tomo: Volume indexed as ``tomo[vol_slice]`` to get a 2-D image.
        seg: Segmentation volume indexed the same way as ``tomo``.
        seg_labels: Mapping from label value to display name.
        unique_values: Label values to show; entries of ``seg_labels`` whose
            key is not in ``unique_values`` are left out of the colormap and
            colorbar.
        vol_slice: Index of the slice to display.

    Raises:
        ValueError: If none of the ``seg_labels`` keys occur in ``unique_values``.
    """
    # Keep only the labels that are present (dict order of seg_labels is kept).
    seg_labels_filtered = {k: v for k, v in seg_labels.items() if k in unique_values}
    if not seg_labels_filtered:
        raise ValueError("None of the seg_labels keys are present in unique_values.")
    num_classes = len(seg_labels_filtered)

    # Discrete colormap with one tab20b colour per remaining class.
    colors = plt.cm.tab20b(np.linspace(0, 1, num_classes))
    cmap = mcolors.ListedColormap(colors)

    plt.figure(figsize=(20, 10))

    # Tomogram plot
    plt.subplot(1, 2, 1)
    plt.title('Tomogram')
    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    # Prediction segmentation plot. No norm is passed, so imshow maps the
    # slice's own min..max range onto the listed colours.
    plt.subplot(1, 2, 2)
    plt.title('Prediction Segmentation')
    im = plt.imshow(seg[vol_slice], cmap=cmap)
    plt.axis('off')

    # Colour bar with one tick per label value, annotated with its name.
    cbar = plt.colorbar(im, ticks=list(seg_labels_filtered.keys()))
    cbar.ax.set_yticklabels(list(seg_labels_filtered.values()))

    plt.tight_layout()
    plt.show()
69
+
70
+
71
def show_tomo_points(tomo, run, objects, user_id, vol_slice, session_id = None, slice_proximity_threshold = 3):
    """
    Show a tomogram slice with nearby picks of each object overlaid.

    For every ``(name, _, _)`` in ``objects``, coordinates are fetched with
    ``io.get_copick_coordinates``. Points whose column 0 lies within
    ``slice_proximity_threshold`` of ``vol_slice`` are scattered at
    (column 2, column 1). Presumably the columns are (z, y, x); confirm this
    against ``io.get_copick_coordinates``.

    Plotting is best-effort: an object whose picks cannot be loaded or drawn
    is reported and skipped rather than aborting the figure.
    """
    plt.figure(figsize=(20, 10))

    plt.imshow(tomo[vol_slice], cmap='gray')
    plt.axis('off')

    for name, _, _ in objects:
        try:
            coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id, session_id=session_id)
            close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
            plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
        except Exception as err:
            # A bare except would also swallow KeyboardInterrupt/SystemExit and
            # hide the reason a pick set is missing from the plot.
            print(f"Skipping picks for '{name}': {err}")

    plt.show()
86
+
87
def compare_tomo_points(tomo, run, objects, vol_slice, user_id1, user_id2,
                        session_id1 = None, session_id2 = None, slice_proximity_threshold = 3):
    """
    Compare two users' picks on the same tomogram slice, side by side.

    The left panel shows the picks of ``user_id1``/``session_id1`` and the
    right panel those of ``user_id2``/``session_id2``. In each panel, points
    whose column 0 lies within ``slice_proximity_threshold`` of ``vol_slice``
    are scattered at (column 2, column 1). Presumably the columns are
    (z, y, x); confirm this against ``io.get_copick_coordinates``.

    Loading is best-effort per object: failures are reported and skipped.
    """
    tomo_slice = tomo[vol_slice]

    def plot_panel(position, user_id, session_id):
        # Draw one subplot: the slice plus one user's/session's picks.
        plt.subplot(1, 2, position)
        plt.imshow(tomo_slice, cmap='gray')
        plt.title(f'{user_id} Picks')

        for name, _, _ in objects:
            try:
                coordinates = io.get_copick_coordinates(run, name=name, user_id=user_id, session_id=session_id)
                close_points = coordinates[np.abs(coordinates[:, 0] - vol_slice) <= slice_proximity_threshold]
                plt.scatter(close_points[:, 2], close_points[:, 1], label=name, s=15)
            except Exception as err:
                print(f"Skipping picks for '{name}' ({user_id}): {err}")

        # Hide axes on both panels (previously only the second one was hidden).
        plt.axis('off')

    plt.figure(figsize=(20, 10))
    plot_panel(1, user_id1, session_id1)
    plot_panel(2, user_id2, session_id2)
    plt.show()
117
+
118
def plot_training_results(
    results,
    class_names: Optional[List[str]] = None,
    save_plot: str = None,
    fig = None, axs = None):
    """
    Plot loss and per-class recall / precision / F1 curves on a 2x2 grid.

    Args:
        results: Mapping of metric name to a list of ``(epoch, value)`` pairs.
            The keys read are 'loss', 'val_loss' and 'avg_recall', plus, for
            each class i = 1..N, 'recall_class{i}', 'precision_class{i}' and
            'f1_class{i}'. N is the number of keys starting with 'recall_class'.
        class_names: Optional legend names, one per class (length N). Any
            other length falls back to 'Class 1' ... 'Class N'.
        save_plot: Optional file path; when given, the figure is saved there.
        fig, axs: Existing figure and 2x2 axes array to redraw into. When
            ``fig`` is None, a new figure and axes are created.

    Returns:
        The ``(fig, axs)`` pair that was drawn on.
    """
    # Create a 2x2 subplot layout
    if fig is None:
        fig, axs = plt.subplots(2, 2, figsize=(12, 10))
    else:
        # Clear previous plots before redrawing into the caller's axes.
        for ax in axs.flatten():
            ax.clear()

    fig.suptitle("Metrics Over Epochs", fontsize=16)

    # Loss panel (top-left): training and validation loss have their own epochs.
    epochs_loss = [epoch for epoch, _ in results['loss']]
    loss = [value for _, value in results['loss']]
    val_epochs_loss = [epoch for epoch, _ in results['val_loss']]
    val_loss = [value for _, value in results['val_loss']]

    loss_ax = axs[0, 0]
    loss_ax.plot(epochs_loss, loss, label="Training Loss")
    loss_ax.plot(val_epochs_loss, val_loss, label='Validation Loss')
    loss_ax.set_xlabel("Epochs")
    loss_ax.set_ylabel("Loss")
    loss_ax.set_title("Training Loss")
    loss_ax.legend()
    loss_ax.tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Per-class metrics are logged every `val_interval` epochs; their epoch
    # axis is taken from 'avg_recall'.
    epochs_metrics = [epoch for epoch, _ in results['avg_recall']]
    num_classes = sum(1 for key in results.keys() if key.startswith('recall_class'))

    # Custom names are only usable when there is exactly one per class;
    # comparing against num_classes - 1 would drop valid names and let
    # short lists raise IndexError below.
    if class_names is None or len(class_names) != num_classes:
        class_names = [f"Class {i+1}" for i in range(num_classes)]

    # (axes, results key prefix, y label, title, show legend)
    per_class_panels = (
        (axs[0, 1], 'recall_class', "Recall", "Recall per Class", False),
        (axs[1, 0], 'precision_class', "Precision", "Precision per Class", True),
        (axs[1, 1], 'f1_class', "F1 Score", "F1 Score per Class", False),
    )
    for ax, key_prefix, ylabel, title, show_legend in per_class_panels:
        for class_idx in range(num_classes):
            values = [value for _, value in results[f'{key_prefix}{class_idx+1}']]
            ax.plot(epochs_metrics, values, label=f"{class_names[class_idx]}")
        ax.set_xlabel("Epochs")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        if show_legend:
            ax.legend()
        ax.tick_params(axis='both', direction='in', top=True, right=True, length=6, width=1)

    # Lay out this figure explicitly (plt.tight_layout would act on whichever
    # figure is current) and leave space for the main title.
    fig.tight_layout(rect=[0, 0, 1, 0.96])

    # savefig(None) is invalid, so only save when a path was provided.
    if save_plot:
        fig.savefig(save_plot)
    fig.canvas.draw()

    return fig, axs
octopi-1.0.dist-info/LICENSE ADDED
@@ -0,0 +1,41 @@
1
+ # Legal
2
+
3
+ ## License for the octopi package
4
+
5
+ This package is licensed under the MIT License:
6
+
7
+ ```
8
+ MIT License
9
+
10
+ Copyright (c) 2025 Chan Zuckerberg Initiative
11
+
12
+ Permission is hereby granted, free of charge, to any person obtaining a copy
13
+ of this software and associated documentation files (the "Software"), to deal
14
+ in the Software without restriction, including without limitation the rights
15
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
16
+ copies of the Software, and to permit persons to whom the Software is
17
+ furnished to do so, subject to the following conditions:
18
+
19
+ The above copyright notice and this permission notice shall be included in all
20
+ copies or substantial portions of the Software.
21
+
22
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
+ SOFTWARE.
29
+ ```
30
+
31
+ ## License Notice for Dependencies
32
+
33
+ ```
34
+ This repository is licensed under the MIT License; however, it relies on certain third-party dependencies that are distributed under their own licenses. Specifically:
35
+
36
+ - monai is licensed under the Apache License 2.0.
37
+ - pytorch-lightning is licensed under the Apache License 2.0.
38
+
39
+ All dependencies use permissive open-source licenses that are compatible with this project's MIT License. No GPL or other copyleft licensed dependencies are included.
40
+ For specific licensing information about any dependency, please refer to the respective package documentation or repository.
41
+ ```