octopi 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of octopi might be problematic. Click here for more details.

Files changed (48) hide show
  1. octopi/__init__.py +1 -0
  2. octopi/datasets/cached_datset.py +1 -1
  3. octopi/datasets/generators.py +1 -1
  4. octopi/datasets/io.py +200 -0
  5. octopi/datasets/multi_config_generator.py +1 -1
  6. octopi/entry_points/common.py +9 -9
  7. octopi/entry_points/create_slurm_submission.py +16 -8
  8. octopi/entry_points/run_create_targets.py +6 -6
  9. octopi/entry_points/run_evaluate.py +4 -3
  10. octopi/entry_points/run_extract_mb_picks.py +22 -45
  11. octopi/entry_points/run_localize.py +37 -54
  12. octopi/entry_points/run_optuna.py +7 -7
  13. octopi/entry_points/run_segment_predict.py +4 -4
  14. octopi/entry_points/run_train.py +7 -8
  15. octopi/extract/localize.py +19 -12
  16. octopi/extract/membranebound_extract.py +11 -10
  17. octopi/extract/midpoint_extract.py +3 -3
  18. octopi/main.py +1 -1
  19. octopi/models/common.py +1 -1
  20. octopi/processing/create_targets_from_picks.py +11 -5
  21. octopi/processing/downsample.py +6 -10
  22. octopi/processing/evaluate.py +24 -11
  23. octopi/processing/importers.py +4 -4
  24. octopi/pytorch/hyper_search.py +2 -3
  25. octopi/pytorch/model_search_submitter.py +15 -15
  26. octopi/pytorch/segmentation.py +147 -192
  27. octopi/pytorch/segmentation_multigpu.py +162 -0
  28. octopi/pytorch/trainer.py +9 -3
  29. octopi/utils/__init__.py +0 -0
  30. octopi/utils/config.py +57 -0
  31. octopi/utils/io.py +128 -0
  32. octopi/{utils.py → utils/parsers.py} +10 -84
  33. octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
  34. octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
  35. octopi/workflows.py +236 -0
  36. octopi-1.2.0.dist-info/METADATA +120 -0
  37. octopi-1.2.0.dist-info/RECORD +62 -0
  38. {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
  39. octopi-1.2.0.dist-info/entry_points.txt +3 -0
  40. {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
  41. octopi/io.py +0 -457
  42. octopi/processing/my_metrics.py +0 -26
  43. octopi/processing/writers.py +0 -102
  44. octopi-1.0.dist-info/METADATA +0 -209
  45. octopi-1.0.dist-info/RECORD +0 -59
  46. octopi-1.0.dist-info/entry_points.txt +0 -4
  47. /octopi/{losses.py → utils/losses.py} +0 -0
  48. /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
@@ -0,0 +1,120 @@
1
+ Metadata-Version: 2.4
2
+ Name: octopi
3
+ Version: 1.2.0
4
+ Summary: Model architecture exploration for cryoET particle picking
5
+ Project-URL: Homepage, https://github.com/chanzuckerberg/octopi
6
+ Project-URL: Documentation, https://chanzuckerberg.github.io/octopi/
7
+ Project-URL: Issues, https://github.com/chanzuckerberg/octopi/issues
8
+ Author: Jonathan Schwartz, Kevin Zhao, Daniel Ji, Utz Ermel
9
+ License: MIT
10
+ License-File: LICENSE
11
+ Classifier: License :: OSI Approved :: MIT License
12
+ Classifier: Programming Language :: Python :: 3
13
+ Classifier: Programming Language :: Python :: 3.9
14
+ Classifier: Programming Language :: Python :: 3.10
15
+ Classifier: Programming Language :: Python :: 3.11
16
+ Classifier: Programming Language :: Python :: 3.12
17
+ Classifier: Programming Language :: Python :: 3.13
18
+ Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
19
+ Classifier: Topic :: Scientific/Engineering :: Image Processing
20
+ Classifier: Topic :: Scientific/Engineering :: Image Recognition
21
+ Requires-Python: >=3.9
22
+ Requires-Dist: copick
23
+ Requires-Dist: copick-utils
24
+ Requires-Dist: ipywidgets
25
+ Requires-Dist: kaleido
26
+ Requires-Dist: matplotlib
27
+ Requires-Dist: mlflow
28
+ Requires-Dist: monai
29
+ Requires-Dist: mrcfile
30
+ Requires-Dist: multiprocess
31
+ Requires-Dist: nibabel
32
+ Requires-Dist: optuna
33
+ Requires-Dist: optuna-integration[botorch,pytorch-lightning]
34
+ Requires-Dist: pandas
35
+ Requires-Dist: python-dotenv
36
+ Requires-Dist: requests
37
+ Requires-Dist: torch-ema
38
+ Requires-Dist: tqdm
39
+ Provides-Extra: dev
40
+ Requires-Dist: black>=24.8.0; extra == 'dev'
41
+ Requires-Dist: pre-commit>=3.8.0; extra == 'dev'
42
+ Requires-Dist: pytest>=6.2.3; extra == 'dev'
43
+ Requires-Dist: ruff>=0.6.4; extra == 'dev'
44
+ Provides-Extra: docs
45
+ Requires-Dist: mkdocs; extra == 'docs'
46
+ Requires-Dist: mkdocs-awesome-pages-plugin; extra == 'docs'
47
+ Requires-Dist: mkdocs-git-authors-plugin; extra == 'docs'
48
+ Requires-Dist: mkdocs-git-committers-plugin-2; extra == 'docs'
49
+ Requires-Dist: mkdocs-git-revision-date-localized-plugin; extra == 'docs'
50
+ Requires-Dist: mkdocs-material; extra == 'docs'
51
+ Requires-Dist: mkdocs-minify-plugin; extra == 'docs'
52
+ Requires-Dist: mkdocs-redirects; extra == 'docs'
53
+ Description-Content-Type: text/markdown
54
+
55
+ # OCTOPI 🐙🐙🐙
56
+
57
+ [![License](https://img.shields.io/pypi/l/octopi.svg?color=green)](https://github.com/chanzuckerberg/octopi/raw/main/LICENSE)
58
+ [![PyPI](https://img.shields.io/pypi/v/octopi.svg?color=green)](https://pypi.org/project/octopi)
59
+ [![Python Version](https://img.shields.io/pypi/pyversions/octopi.svg?color=green)](https://www.python.org/)
60
+
61
+ **O**bject dete**CT**ion **O**f **P**rote**I**ns. A deep learning framework for Cryo-ET 3D particle picking with autonomous model exploration capabilities.
62
+
63
+ ## 🚀 Introduction
64
+
65
+ octopi addresses a critical bottleneck in cryo-electron tomography (cryo-ET) research: the efficient identification and extraction of proteins within complex cellular environments. As advances in cryo-ET enable the collection of thousands of tomograms, the need for automated, accurate particle picking has become increasingly urgent.
66
+
67
+ Our deep learning-based pipeline streamlines the training and execution of 3D autoencoder models specifically designed for cryo-ET particle picking. Built on [copick](https://github.com/copick/copick), a storage-agnostic API, octopi seamlessly accesses tomograms and segmentations across local and remote environments.
68
+
69
+ ## 🧩 Core Features
70
+
71
+ - **3D U-Net Training**: Train and evaluate custom 3D U-Net models for particle segmentation
72
+ - **Automatic Architecture Search**: Explore optimal model configurations using Bayesian optimization via Optuna
73
+ - **Flexible Data Access**: Seamlessly work with tomograms from local storage or remote data portals
74
+ - **HPC Ready**: Built-in support for SLURM-based clusters
75
+ - **Experiment Tracking**: Integrated MLflow support for monitoring training and optimization
76
+ - **Dual Interface**: Use via command-line or Python API
77
+
78
+ ## 🚀 Quick Start
79
+
80
+ ### Installation
81
+
82
+ Octopi is available on PyPI and can be installed using pip:
83
+ ```bash
84
+ pip install octopi
85
+ ```
86
+
87
+ ⚠️ **Note**: One of octopi's dependencies does not currently work with pip 25. To temporarily downgrade pip, run:
88
+ ```bash
89
+ pip install --upgrade "pip<25"
90
+ ```
91
+
92
+ ### Basic Usage
93
+
94
+ octopi provides two main command-line interfaces:
95
+
96
+ ```bash
97
+ # Main CLI for training, inference, and data processing
98
+ octopi --help
99
+
100
+ # HPC-specific CLI for submitting jobs to SLURM clusters
101
+ octopi-slurm --help
102
+ ```
103
+
104
+ ## 📚 Documentation
105
+
106
+ For detailed documentation, tutorials, CLI and API reference, visit our [documentation](https://chanzuckerberg.github.io/octopi/).
107
+
108
+ ## 🤝 Contributing
109
+
110
+ ## Code of Conduct
111
+
112
+ This project adheres to the Contributor Covenant [code of conduct](https://github.com/chanzuckerberg/.github/blob/master/CODE_OF_CONDUCT.md).
113
+ By participating, you are expected to uphold this code.
114
+ Please report unacceptable behavior to [opensource@chanzuckerberg.com](mailto:opensource@chanzuckerberg.com).
115
+
116
+ ## 🔒 Security
117
+
118
+ If you believe you have found a security issue, please disclose it responsibly by contacting us at security@chanzuckerberg.com.
119
+
120
+
@@ -0,0 +1,62 @@
1
+ octopi/__init__.py,sha256=Btl_98iBIXFtvGx47MqpfbaEVYoOMPBQn9bao7UASkQ,21
2
+ octopi/main.py,sha256=ef_zpvopl6JiN4gOT_x2QghRJriqg4iwRzBnvuiBeTo,4864
3
+ octopi/workflows.py,sha256=bRP5uPqu7PtUtQUOgtVjBZ6FulhG9uAdlcEH_kBYXPE,9235
4
+ octopi/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ octopi/datasets/augment.py,sha256=k7UXQzidPANaPuLoBzer_ZHc4_vF-kKTOyZisEnAbNw,3203
6
+ octopi/datasets/cached_datset.py,sha256=z20Ldfve1nNfLLPNx6nUE2p6i4IXx4V2SHmt99aVN_k,3918
7
+ octopi/datasets/dataset.py,sha256=9C3AuD83FMswj4MYLFn5XPqvAP1a7MQSG7UNiStg090,511
8
+ octopi/datasets/generators.py,sha256=qL43vd6gxpPrySZmt8RWujtJo3eO6sFcpS92CmNkqSA,18883
9
+ octopi/datasets/io.py,sha256=COjGoCLVbPU_xQC2UCaHGPd24a8AXtDvUqyd8kOc-Eo,6978
10
+ octopi/datasets/mixup.py,sha256=BJUAmM7ItZWFChs8glnd8RNSXR5qGW7DHscbcVc3TsU,1575
11
+ octopi/datasets/multi_config_generator.py,sha256=_eAZOzWLo6bI_H8x0quoslyjfuERVRXEVHZ4aW7j_Ok,10814
12
+ octopi/entry_points/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
13
+ octopi/entry_points/common.py,sha256=pHYQFQKdrtbYYU96klTQV4u1acUBfApkbEOc_TbuIzA,5367
14
+ octopi/entry_points/create_slurm_submission.py,sha256=rW1hr59mLLvKzmHG-uozYvJvTcqnHNZdmy2nMr5EBDc,8914
15
+ octopi/entry_points/run_create_targets.py,sha256=FPpsgVQHbepJvE2ZdC1VQKdT5RpALdhDXk7jU1RUXKo,11236
16
+ octopi/entry_points/run_evaluate.py,sha256=ab_VEDlh3RxMHfSyoMCdFH6z9G3fpHIfAKvyoYmTP2M,2835
17
+ octopi/entry_points/run_extract_mb_picks.py,sha256=u7mVqYKEvvkSN-L7v27UJVURxl74q7yiUZwTs-3zfYE,4866
18
+ octopi/entry_points/run_extract_midpoint.py,sha256=O6GdkSD7oXIpSVqizOc6PHhv9nunz3j0RucmYQ2yryM,5742
19
+ octopi/entry_points/run_localize.py,sha256=-CF9w3XKLjiGIAN-kWgeobpfFYfIvZxRpxR0D8fyb3M,8264
20
+ octopi/entry_points/run_optuna.py,sha256=U1VYzwE0eYV8DNnms0gbjytqsDwSvHJ5BsMj48iZz8k,5784
21
+ octopi/entry_points/run_segment_predict.py,sha256=5RvnxUDDNzBfRv4dOGJl6SieAlF5SR5stHL_Z6BgUh0,5439
22
+ octopi/entry_points/run_train.py,sha256=O9uDunSyQGpEXqTTOYNSgO3LvTZt49K9HzcN0sGuD14,8095
23
+ octopi/extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
24
+ octopi/extract/localize.py,sha256=SQY1p2Q9sYrB_yjOZySYqoWwwlfgcvBTR1f_Fpi2Wxc,9528
25
+ octopi/extract/membranebound_extract.py,sha256=kTV_QFBbvzd3STJsC55QKoRTYcvg_raV3C757Y_9HBo,11087
26
+ octopi/extract/midpoint_extract.py,sha256=iEUZJShlhBC5qAdP5w-zjjEOMz_nu9JZF6jhbJvcV94,7036
27
+ octopi/models/AttentionUnet.py,sha256=r185aXRtfXhN-n8FxA-Sjz18nqpxHH_2t2uadrH-Mgs,1991
28
+ octopi/models/MedNeXt.py,sha256=9q0FsyrqTx211hCbDv0Lm2XflzXL_bGA4-76BscziGk,4875
29
+ octopi/models/ModelTemplate.py,sha256=X80EOXwSovCjmVb7x-0_JmRjHfDfLByDdd60MrgFTyw,1084
30
+ octopi/models/SegResNet.py,sha256=1dK8dy_7hHHKYZLsTYafl__7MxQOlWGbBqQWPGxHSXg,3609
31
+ octopi/models/Unet.py,sha256=7RGT_nl7ogsNlS3Y3Qexe305Ym9NlK9CV0y45g2DEU4,2171
32
+ octopi/models/UnetPlusPlus.py,sha256=fnV-SvJV8B432KJXQAtdwLy8Va6DJ4fRB_7a1mZiqTU,1529
33
+ octopi/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
34
+ octopi/models/common.py,sha256=qxx1ak2ayVC9fPbpjs6XtfUNPSX-PGywq3ySsXGdY74,2313
35
+ octopi/processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
+ octopi/processing/create_targets_from_picks.py,sha256=5dbCCjFXq55A-zD9SX4mXUKRq_lPWriTlnRYb958hBc,4377
37
+ octopi/processing/downsample.py,sha256=u3V2HULdKTLEthzI-MqJmI-NqpiyX0Vg61i77-j2MKQ,5392
38
+ octopi/processing/evaluate.py,sha256=Xxx8hJrWzZ_pSxEs89CC_6AxGUch3_RSHpH7UtEBs1c,14139
39
+ octopi/processing/importers.py,sha256=Rt-Y3QtFJz_ewCkdqtJpmiiI4xFSFeXaRcL1ZBn0bd0,8876
40
+ octopi/processing/segmentation_from_picks.py,sha256=jah1gAXEn09LIok1Cb8IeVN-fT3jktcVPfjbOFHkgg0,7089
41
+ octopi/pytorch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
42
+ octopi/pytorch/hyper_search.py,sha256=ywwGEtTHJUkrBCOYSycrpWccUvxJlxny50q5GnWoWLA,9376
43
+ octopi/pytorch/model_search_submitter.py,sha256=zEfvlcB9AthtanUaqV3a8f2HGG_eo5gvIOZDOvUziJo,11267
44
+ octopi/pytorch/segmentation.py,sha256=nPW4tkQfgA4fC02j_bcLyKvS7SL9ghQnENKT7_YC_gc,11266
45
+ octopi/pytorch/segmentation_multigpu.py,sha256=KK8BZIdB1PJBs_jCv7WBS0xORR_XlpcKmmPf3AgIfYY,6781
46
+ octopi/pytorch/trainer.py,sha256=NAoOf00oh_-Sv5FZUaq_L5v7Y_VV_VtuExIaZCrvxfs,17760
47
+ octopi/pytorch_lightning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
48
+ octopi/pytorch_lightning/optuna_pl_ddp.py,sha256=ynD5i81IP-awr6OD9GDurjrQK-5Kc079qPaukphTHnA,11924
49
+ octopi/pytorch_lightning/train_pl.py,sha256=igOHzU_mUdZRQGhoOGW5vmxJcHFcw5fAPHfVCIZ0eG4,10220
50
+ octopi/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
51
+ octopi/utils/config.py,sha256=RLrHmyjnQ0tXX2_bls--ZecbgRTNlnGtkvLl7YWtYtc,1833
52
+ octopi/utils/io.py,sha256=-PMzzBiIKbrUAIw7NcmeAj2V0yy9r4N2QRMIGlm2SII,4623
53
+ octopi/utils/losses.py,sha256=fs9yR80Hs-hL07WgVMkRy5N81vzP5m9XBCvzO76bIPU,3097
54
+ octopi/utils/parsers.py,sha256=InHLN-iYq0bWoKAHCX-hX82HO8b4pcZZ1on1LcwcsXo,5709
55
+ octopi/utils/stopping_criteria.py,sha256=q-tkcqyZ3GfqR8BzWu1eKE8dt2pNZ3MOD9dcMe2T7fU,6112
56
+ octopi/utils/submit_slurm.py,sha256=cRbJTESbPFCt6Cq4Hat2uPOQKFYMPcQxuNs0jc1ygUA,1945
57
+ octopi/utils/visualization_tools.py,sha256=I-fmXKk9eyaLTbvsYxx8H6nlEvUOH94NJ_l2dayXXe4,7364
58
+ octopi-1.2.0.dist-info/METADATA,sha256=UWZ2lfjUNDd34EDtL47LdpIp4p_TbkemNWhqjsehQ6U,5062
59
+ octopi-1.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
60
+ octopi-1.2.0.dist-info/entry_points.txt,sha256=n4I7aA9RS52l1eYPidQx63VhxT9pPsMK4xitlBm7X7k,90
61
+ octopi-1.2.0.dist-info/licenses/LICENSE,sha256=DRoLUEjEHgru_AwKHhii9UIOt6QA65am4WQd7UFzqnE,1822
62
+ octopi-1.2.0.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 2.1.3
2
+ Generator: hatchling 1.27.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
@@ -0,0 +1,3 @@
1
+ [console_scripts]
2
+ octopi = octopi.main:cli_main
3
+ octopi-slurm = octopi.main:cli_slurm_main
@@ -17,15 +17,15 @@ copies of the Software, and to permit persons to whom the Software is
17
17
  furnished to do so, subject to the following conditions:
18
18
 
19
19
  The above copyright notice and this permission notice shall be included in all
20
- copies or substantial portions of the Software.
20
+ copies or substantial portions of the "Software".
21
21
 
22
22
  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
23
23
  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
24
24
  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
25
25
  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
26
26
  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
27
- OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
28
- SOFTWARE.
27
+ OUT OF OR IN CONNECTION WITH THE "SOFTWARE" OR THE USE OR OTHER DEALINGS IN THE
28
+ "SOFTWARE".
29
29
  ```
30
30
 
31
31
  ## License Notice for Dependencies
octopi/io.py DELETED
@@ -1,457 +0,0 @@
1
- from monai.data import DataLoader, CacheDataset, Dataset
2
- from monai.transforms import (
3
- Compose,
4
- NormalizeIntensityd,
5
- EnsureChannelFirstd,
6
- )
7
- from sklearn.model_selection import train_test_split
8
- import copick, torch, os, json, random
9
- from collections import defaultdict
10
- from octopi import utils
11
- from typing import List
12
- from tqdm import tqdm
13
- import numpy as np
14
-
15
- ##############################################################################################################################
16
-
17
- def load_training_data(root,
18
- runIDs: List[str],
19
- voxel_spacing: float,
20
- tomo_algorithm: str,
21
- segmenation_name: str,
22
- segmentation_session_id: str = None,
23
- segmentation_user_id: str = None,
24
- progress_update: bool = True):
25
-
26
- data_dicts = []
27
- # Use tqdm for progress tracking only if progress_update is True
28
- iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
29
- for runID in iterable:
30
- run = root.get_run(str(runID))
31
- tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
32
- segmentation = get_segmentation_array(run,
33
- voxel_spacing,
34
- segmenation_name,
35
- segmentation_session_id,
36
- segmentation_user_id)
37
- data_dicts.append({"image": tomogram, "label": segmentation})
38
-
39
- return data_dicts
40
-
41
- ##############################################################################################################################
42
-
43
- def load_predict_data(root,
44
- runIDs: List[str],
45
- voxel_spacing: float,
46
- tomo_algorithm: str):
47
-
48
- data_dicts = []
49
- for runID in tqdm(runIDs):
50
- run = root.get_run(str(runID))
51
- tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
52
- data_dicts.append({"image": tomogram})
53
-
54
- return data_dicts
55
-
56
- ##############################################################################################################################
57
-
58
- def create_predict_dataloader(
59
- root,
60
- voxel_spacing: float,
61
- tomo_algorithm: str,
62
- runIDs: str = None,
63
- ):
64
-
65
- # define pre transforms
66
- pre_transforms = Compose(
67
- [ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
68
- NormalizeIntensityd(keys=["image"]),
69
- ])
70
-
71
- # Split trainRunIDs, validateRunIDs, testRunIDs
72
- if runIDs is None:
73
- runIDs = [run.name for run in root.runs]
74
- test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)
75
-
76
- bs = min( len(test_files), 4)
77
- test_ds = CacheDataset(data=test_files, transform=pre_transforms)
78
- test_loader = DataLoader(test_ds,
79
- batch_size=bs,
80
- shuffle=False,
81
- num_workers=4,
82
- pin_memory=torch.cuda.is_available())
83
- return test_loader, test_ds
84
-
85
- ##############################################################################################################################
86
-
87
- def get_tomogram_array(run,
88
- voxel_size: float = 10,
89
- tomo_type: str = 'wbp',
90
- raise_error: bool = True):
91
-
92
- voxel_spacing_obj = run.get_voxel_spacing(voxel_size)
93
-
94
- if voxel_spacing_obj is None:
95
- # Query Avaiable Voxel Spacings
96
- availableVoxelSpacings = [tomo.voxel_size for tomo in run.voxel_spacings]
97
-
98
- # Report to the user which voxel spacings they can use
99
- message = (f"\n[Warning] No tomogram found for {run.name} with voxel size {voxel_size} and tomogram type {tomo_type}"
100
- f"\nAvailable spacings are: {', '.join(map(str, availableVoxelSpacings))}\n" )
101
- if raise_error:
102
- raise ValueError(message)
103
- else:
104
- print(message)
105
- return None
106
-
107
- tomogram = voxel_spacing_obj.get_tomogram(tomo_type)
108
- if tomogram is None:
109
- # Get available algorithms
110
- availableAlgorithms = [tomo.tomo_type for tomo in run.get_voxel_spacing(voxel_size).tomograms]
111
-
112
- # Report to the user which algorithms are available
113
- message = (f"\n[Warning] No tomogram found for {run.name} with voxel size {voxel_size} and tomogram type {tomo_type}"
114
- f"\nAvailable algorithms are: {', '.join(availableAlgorithms)}\n")
115
- if raise_error:
116
- raise ValueError(message)
117
- else:
118
- print(message)
119
- return None
120
-
121
- return tomogram.numpy().astype(np.float32)
122
-
123
- ##############################################################################################################################
124
-
125
- def get_segmentation_array(run,
126
- voxel_spacing: float,
127
- segmentation_name: str,
128
- session_id=None,
129
- user_id=None,
130
- raise_error: bool = True):
131
-
132
- seg = run.get_segmentations(name=segmentation_name,
133
- session_id = session_id,
134
- user_id = user_id,
135
- voxel_size=float(voxel_spacing))
136
-
137
- # No Segmentations Are Available, Result in Error
138
- if len(seg) == 0:
139
- # Get all available segmentations with their metadata
140
- available_segs = run.get_segmentations(voxel_size=voxel_spacing)
141
- seg_info = [(s.name, s.user_id, s.session_id) for s in available_segs]
142
-
143
- # Format the information for display
144
- seg_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
145
- for name, uid, sid in seg_info]
146
-
147
- message = ( f'\nNo segmentation found matching:\n'
148
- f' name: {segmentation_name}, user_id: {user_id}, session_id: {session_id}\n'
149
- f'Available segmentations in {run.name} are:\n ' +
150
- '\n '.join(seg_details) )
151
- if raise_error:
152
- raise ValueError(message)
153
- else:
154
- print(message)
155
- return None
156
-
157
- # No Segmentations Are Available, Result in Error
158
- if len(seg) > 1:
159
- print(f'[Warning] More Than 1 Segmentation is Available for the Query Information. '
160
- f'Available Segmentations are: {seg} '
161
- f'Defaulting to Loading: {seg[0]}\n')
162
- seg = seg[0]
163
-
164
- return seg.numpy().astype(np.int8)
165
-
166
- ##############################################################################################################################
167
-
168
- def get_copick_coordinates(run, # CoPick run object containing the segmentation data
169
- name: str, # Name of the object or protein for which coordinates are being extracted
170
- user_id: str, # Identifier of the user that generated the picks
171
- session_id: str = None, # Identifier of the session that generated the picks
172
- voxel_size: float = 10, # Voxel size of the tomogram, used for scaling the coordinates
173
- raise_error: bool = True):
174
-
175
- # Retrieve the pick points associated with the specified object and user ID
176
- picks = run.get_picks(object_name=name, user_id=user_id, session_id=session_id)
177
-
178
- if len(picks) == 0:
179
- # Get all available segmentations with their metadata
180
-
181
- available_picks = run.get_picks()
182
- picks_info = [(s.pickable_object_name, s.user_id, s.session_id) for s in available_picks]
183
-
184
- # Format the information for display
185
- picks_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
186
- for name, uid, sid in picks_info]
187
-
188
- message = ( f'\nNo picks found matching:\n'
189
- f' name: {name}, user_id: {user_id}, session_id: {session_id}\n'
190
- f'Available picks are:\n '
191
- + '\n '.join(picks_details) )
192
- if raise_error:
193
- raise ValueError(message)
194
- else:
195
- print(message)
196
- return None
197
- elif len(picks) > 1:
198
- # Format pick information for display
199
- picks_info = [(p.pickable_object_name, p.user_id, p.session_id) for p in picks]
200
- picks_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
201
- for name, uid, sid in picks_info]
202
-
203
- print(f'[Warning] More than 1 pick is available for the query information.'
204
- f'\nAvailable picks are:\n ' +
205
- '\n '.join(picks_details) +
206
- f'\nDefaulting to loading:\n {picks[0]}\n')
207
- points = picks[0].points
208
-
209
- # Initialize an array to store the coordinates
210
- nPoints = len(picks[0].points) # Number of points retrieved
211
- coordinates = np.zeros([len(picks[0].points), 3]) # Create an empty array to hold the (z, y, x) coordinates
212
-
213
- # Iterate over all points and convert their locations to coordinates in voxel space
214
- for ii in range(nPoints):
215
- coordinates[ii,] = [points[ii].location.z / voxel_size, # Scale z-coordinate by voxel size
216
- points[ii].location.y / voxel_size, # Scale y-coordinate by voxel size
217
- points[ii].location.x / voxel_size] # Scale x-coordinate by voxel size
218
-
219
- # Return the array of coordinates
220
- return coordinates
221
-
222
-
223
- ##############################################################################################################################
224
-
225
- def adjust_to_multiple(value, multiple = 16):
226
- return int((value // multiple) * multiple)
227
-
228
- def get_input_dimensions(dataset, crop_size: int):
229
- nx = dataset[0]['image'].shape[1]
230
- if crop_size > nx:
231
- first_dim = adjust_to_multiple(nx/2)
232
- return first_dim, crop_size, crop_size
233
- else:
234
- return crop_size, crop_size, crop_size
235
-
236
- def get_num_classes(copick_config_path: str):
237
-
238
- root = copick.from_file(copick_config_path)
239
- return len(root.pickable_objects) + 1
240
-
241
- def split_multiclass_dataset(runIDs,
242
- train_ratio: float = 0.7,
243
- val_ratio: float = 0.15,
244
- test_ratio: float = 0.15,
245
- return_test_dataset: bool = True,
246
- random_state: int = 42):
247
- """
248
- Splits a given dataset into three subsets: training, validation, and testing. If the dataset
249
- has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
250
- list, it is split without categorization.
251
-
252
- Parameters:
253
- - runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
254
- - train_ratio: Proportion of the dataset for training.
255
- - val_ratio: Proportion of the dataset for validation.
256
- - test_ratio: Proportion of the dataset for testing.
257
- - return_test_dataset: Whether to return the test dataset.
258
- - random_state: Random state for reproducibility.
259
-
260
- Returns:
261
- - trainRunIDs: Training subset.
262
- - valRunIDs: Validation subset.
263
- - testRunIDs: Testing subset (if return_test_dataset is True, otherwise None).
264
- """
265
-
266
- # Ensure the ratios add up to 1
267
- assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
268
-
269
- # Check if the dataset has categories
270
- if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
271
- # Group by category
272
- grouped = defaultdict(list)
273
- for item in runIDs:
274
- grouped[item[0]].append(item)
275
-
276
- # Split each category
277
- trainRunIDs, valRunIDs, testRunIDs = [], [], []
278
- for category, items in grouped.items():
279
- # Shuffle for randomness
280
- random.shuffle(items)
281
- # Split into train and remaining
282
- train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
283
- trainRunIDs.extend(train_items)
284
-
285
- if return_test_dataset:
286
- # Split remaining into validation and test
287
- val_items, test_items = train_test_split(
288
- remaining,
289
- test_size=(test_ratio / (val_ratio + test_ratio)),
290
- random_state=random_state,
291
- )
292
- valRunIDs.extend(val_items)
293
- testRunIDs.extend(test_items)
294
- else:
295
- valRunIDs.extend(remaining)
296
- testRunIDs = []
297
- else:
298
- # If no categories, split as a 1D list
299
- trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
300
- if return_test_dataset:
301
- valRunIDs, testRunIDs = train_test_split(
302
- remaining,
303
- test_size=(test_ratio / (val_ratio + test_ratio)),
304
- random_state=random_state,
305
- )
306
- else:
307
- valRunIDs = remaining
308
- testRunIDs = []
309
-
310
- return trainRunIDs, valRunIDs, testRunIDs
311
-
312
- ##############################################################################################################################
313
-
314
- def load_copick_config(path: str):
315
-
316
- if os.path.isfile(path):
317
- root = copick.from_file(path)
318
- else:
319
- raise FileNotFoundError(f"Copick Config Path does not exist: {path}")
320
-
321
- return root
322
-
323
- ##############################################################################################################################
324
-
325
- # Helper function to flatten and serialize nested parameters
326
- def flatten_params(params, parent_key=''):
327
- flattened = {}
328
- for key, value in params.items():
329
- new_key = f"{parent_key}.{key}" if parent_key else key
330
- if isinstance(value, dict):
331
- flattened.update(flatten_params(value, new_key))
332
- elif isinstance(value, list):
333
- flattened[new_key] = ', '.join(map(str, value)) # Convert list to a comma-separated string
334
- else:
335
- flattened[new_key] = value
336
- return flattened
337
-
338
- # Manually join specific lists into strings for inline display
339
- def prepare_for_inline_json(data):
340
- for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
341
- if key in data['dataloader']:
342
- data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"
343
-
344
- for key in ['channels', 'strides']:
345
- if key in data['model']:
346
- data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
347
- return data
348
-
349
- def get_optimizer_parameters(trainer):
350
-
351
- optimizer_parameters = {
352
- 'my_num_samples': trainer.num_samples,
353
- 'val_interval': trainer.val_interval,
354
- 'lr': trainer.optimizer.param_groups[0]['lr'],
355
- 'optimizer': trainer.optimizer.__class__.__name__,
356
- 'metrics_function': trainer.metrics_function.__class__.__name__,
357
- 'loss_function': trainer.loss_function.__class__.__name__,
358
- }
359
-
360
- # Log Tversky Loss Parameters
361
- if trainer.loss_function.__class__.__name__ == 'TverskyLoss':
362
- optimizer_parameters['alpha'] = trainer.loss_function.alpha
363
- elif trainer.loss_function.__class__.__name__ == 'FocalLoss':
364
- optimizer_parameters['gamma'] = trainer.loss_function.gamma
365
- elif trainer.loss_function.__class__.__name__ == 'WeightedFocalTverskyLoss':
366
- optimizer_parameters['alpha'] = trainer.loss_function.alpha
367
- optimizer_parameters['gamma'] = trainer.loss_function.gamma
368
- optimizer_parameters['weight_tversky'] = trainer.loss_function.weight_tversky
369
- elif trainer.loss_function.__class__.__name__ == 'FocalTverskyLoss':
370
- optimizer_parameters['alpha'] = trainer.loss_function.alpha
371
- optimizer_parameters['gamma'] = trainer.loss_function.gamma
372
-
373
- return optimizer_parameters
374
-
375
- def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
376
-
377
- parameters = {
378
- 'model': model.get_model_parameters(),
379
- 'optimizer': get_optimizer_parameters(trainer),
380
- 'dataloader': dataloader.get_dataloader_parameters()
381
- }
382
-
383
- utils.save_parameters_yaml(parameters, filename)
384
- print(f"Training Parameters saved to {filename}")
385
-
386
- def prepare_inline_results_json(results):
387
- # Traverse the dictionary and format lists of lists as inline JSON
388
- for key, value in results.items():
389
- # Check if the value is a list of lists (like [[epoch, value], ...])
390
- if isinstance(value, list) and all(isinstance(item, list) and len(item) == 2 for item in value):
391
- # Format the list of lists as a single-line JSON string
392
- results[key] = json.dumps(value)
393
- return results
394
-
395
- # Check to See if I'm Happy with This... Maybe Save as H5 File?
396
- def save_results_to_json(results, filename: str):
397
-
398
- results = prepare_inline_results_json(results)
399
- with open(os.path.join(filename), "w") as json_file:
400
- json.dump( results, json_file, indent=4 )
401
- print(f"Training Results saved to {filename}")
402
-
403
- ##############################################################################################################################
404
-
405
- # def save_parameters_to_json(model, trainer, dataloader, filename: str):
406
-
407
- # parameters = {
408
- # 'model': model.get_model_parameters(),
409
- # 'optimizer': get_optimizer_parameters(trainer),
410
- # 'dataloader': dataloader.get_dataloader_parameters()
411
- # }
412
- # parameters = prepare_for_inline_json(parameters)
413
-
414
- # with open(os.path.join(filename), "w") as json_file:
415
- # json.dump( parameters, json_file, indent=4 )
416
- # print(f"Training Parameters saved to {filename}")
417
-
418
- # def split_datasets(runIDs,
419
- # train_ratio: float = 0.7,
420
- # val_ratio: float = 0.15,
421
- # test_ratio: float = 0.15,
422
- # return_test_dataset: bool = True,
423
- # random_state: int = 42):
424
- # """
425
- # Splits a given dataset into three subsets: training, validation, and testing. The proportions
426
- # of each subset are determined by the provided ratios, ensuring that they add up to 1. The
427
- # function uses a fixed random state for reproducibility.
428
-
429
- # Parameters:
430
- # - runIDs: The complete dataset that needs to be split.
431
- # - train_ratio: The proportion of the dataset to be used for training.
432
- # - val_ratio: The proportion of the dataset to be used for validation.
433
- # - test_ratio: The proportion of the dataset to be used for testing.
434
-
435
- # Returns:
436
- # - trainRunIDs: The subset of the dataset used for training.
437
- # - valRunIDs: The subset of the dataset used for validation.
438
- # - testRunIDs: The subset of the dataset used for testing.
439
- # """
440
-
441
- # # Ensure the ratios add up to 1
442
- # assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
443
-
444
- # # First, split into train and remaining (30%)
445
- # trainRunIDs, valRunIDs = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
446
-
447
- # # (Optional) split the remaining into validation and test
448
- # if return_test_dataset:
449
- # valRunIDs, testRunIDs = train_test_split(
450
- # valRunIDs,
451
- # test_size=(test_ratio / (val_ratio + test_ratio)),
452
- # random_state=random_state,
453
- # )
454
- # else:
455
- # testRunIDs = None
456
-
457
- # return trainRunIDs, valRunIDs, testRunIDs