octopi 1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of octopi might be problematic. Click here for more details.
- octopi/__init__.py +1 -0
- octopi/datasets/cached_datset.py +1 -1
- octopi/datasets/generators.py +1 -1
- octopi/datasets/io.py +200 -0
- octopi/datasets/multi_config_generator.py +1 -1
- octopi/entry_points/common.py +9 -9
- octopi/entry_points/create_slurm_submission.py +16 -8
- octopi/entry_points/run_create_targets.py +6 -6
- octopi/entry_points/run_evaluate.py +4 -3
- octopi/entry_points/run_extract_mb_picks.py +22 -45
- octopi/entry_points/run_localize.py +37 -54
- octopi/entry_points/run_optuna.py +7 -7
- octopi/entry_points/run_segment_predict.py +4 -4
- octopi/entry_points/run_train.py +7 -8
- octopi/extract/localize.py +19 -12
- octopi/extract/membranebound_extract.py +11 -10
- octopi/extract/midpoint_extract.py +3 -3
- octopi/main.py +1 -1
- octopi/models/common.py +1 -1
- octopi/processing/create_targets_from_picks.py +11 -5
- octopi/processing/downsample.py +6 -10
- octopi/processing/evaluate.py +24 -11
- octopi/processing/importers.py +4 -4
- octopi/pytorch/hyper_search.py +2 -3
- octopi/pytorch/model_search_submitter.py +15 -15
- octopi/pytorch/segmentation.py +147 -192
- octopi/pytorch/segmentation_multigpu.py +162 -0
- octopi/pytorch/trainer.py +9 -3
- octopi/utils/__init__.py +0 -0
- octopi/utils/config.py +57 -0
- octopi/utils/io.py +128 -0
- octopi/{utils.py → utils/parsers.py} +10 -84
- octopi/{stopping_criteria.py → utils/stopping_criteria.py} +3 -3
- octopi/{visualization_tools.py → utils/visualization_tools.py} +4 -4
- octopi/workflows.py +236 -0
- octopi-1.2.0.dist-info/METADATA +120 -0
- octopi-1.2.0.dist-info/RECORD +62 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info}/WHEEL +1 -1
- octopi-1.2.0.dist-info/entry_points.txt +3 -0
- {octopi-1.0.dist-info → octopi-1.2.0.dist-info/licenses}/LICENSE +3 -3
- octopi/io.py +0 -457
- octopi/processing/my_metrics.py +0 -26
- octopi/processing/writers.py +0 -102
- octopi-1.0.dist-info/METADATA +0 -209
- octopi-1.0.dist-info/RECORD +0 -59
- octopi-1.0.dist-info/entry_points.txt +0 -4
- /octopi/{losses.py → utils/losses.py} +0 -0
- /octopi/{submit_slurm.py → utils/submit_slurm.py} +0 -0
|
@@ -0,0 +1,120 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: octopi
|
|
3
|
+
Version: 1.2.0
|
|
4
|
+
Summary: Model architecture exploration for cryoET particle picking
|
|
5
|
+
Project-URL: Homepage, https://github.com/chanzuckerberg/octopi
|
|
6
|
+
Project-URL: Documentation, https://chanzuckerberg.github.io/octopi/
|
|
7
|
+
Project-URL: Issues, https://github.com/chanzuckerberg/octopi/issues
|
|
8
|
+
Author: Jonathan Schwartz, Kevin Zhao, Daniel Ji, Utz Ermel
|
|
9
|
+
License: MIT
|
|
10
|
+
License-File: LICENSE
|
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
17
|
+
Classifier: Programming Language :: Python :: 3.13
|
|
18
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
19
|
+
Classifier: Topic :: Scientific/Engineering :: Image Processing
|
|
20
|
+
Classifier: Topic :: Scientific/Engineering :: Image Recognition
|
|
21
|
+
Requires-Python: >=3.9
|
|
22
|
+
Requires-Dist: copick
|
|
23
|
+
Requires-Dist: copick-utils
|
|
24
|
+
Requires-Dist: ipywidgets
|
|
25
|
+
Requires-Dist: kaleido
|
|
26
|
+
Requires-Dist: matplotlib
|
|
27
|
+
Requires-Dist: mlflow
|
|
28
|
+
Requires-Dist: monai
|
|
29
|
+
Requires-Dist: mrcfile
|
|
30
|
+
Requires-Dist: multiprocess
|
|
31
|
+
Requires-Dist: nibabel
|
|
32
|
+
Requires-Dist: optuna
|
|
33
|
+
Requires-Dist: optuna-integration[botorch,pytorch-lightning]
|
|
34
|
+
Requires-Dist: pandas
|
|
35
|
+
Requires-Dist: python-dotenv
|
|
36
|
+
Requires-Dist: requests
|
|
37
|
+
Requires-Dist: torch-ema
|
|
38
|
+
Requires-Dist: tqdm
|
|
39
|
+
Provides-Extra: dev
|
|
40
|
+
Requires-Dist: black>=24.8.0; extra == 'dev'
|
|
41
|
+
Requires-Dist: pre-commit>=3.8.0; extra == 'dev'
|
|
42
|
+
Requires-Dist: pytest>=6.2.3; extra == 'dev'
|
|
43
|
+
Requires-Dist: ruff>=0.6.4; extra == 'dev'
|
|
44
|
+
Provides-Extra: docs
|
|
45
|
+
Requires-Dist: mkdocs; extra == 'docs'
|
|
46
|
+
Requires-Dist: mkdocs-awesome-pages-plugin; extra == 'docs'
|
|
47
|
+
Requires-Dist: mkdocs-git-authors-plugin; extra == 'docs'
|
|
48
|
+
Requires-Dist: mkdocs-git-committers-plugin-2; extra == 'docs'
|
|
49
|
+
Requires-Dist: mkdocs-git-revision-date-localized-plugin; extra == 'docs'
|
|
50
|
+
Requires-Dist: mkdocs-material; extra == 'docs'
|
|
51
|
+
Requires-Dist: mkdocs-minify-plugin; extra == 'docs'
|
|
52
|
+
Requires-Dist: mkdocs-redirects; extra == 'docs'
|
|
53
|
+
Description-Content-Type: text/markdown
|
|
54
|
+
|
|
55
|
+
# OCTOPI 🐙🐙🐙
|
|
56
|
+
|
|
57
|
+
[License: MIT](https://github.com/chanzuckerberg/octopi/raw/main/LICENSE)
|
|
58
|
+
[PyPI: octopi](https://pypi.org/project/octopi)
|
|
59
|
+
[Python 3.9+](https://www.python.org/)
|
|
60
|
+
|
|
61
|
+
**O**bject dete**CT**ion **O**f **P**rote**I**ns. A deep learning framework for Cryo-ET 3D particle picking with autonomous model exploration capabilities.
|
|
62
|
+
|
|
63
|
+
## 🚀 Introduction
|
|
64
|
+
|
|
65
|
+
octopi addresses a critical bottleneck in cryo-electron tomography (cryo-ET) research: the efficient identification and extraction of proteins within complex cellular environments. As advances in cryo-ET enable the collection of thousands of tomograms, the need for automated, accurate particle picking has become increasingly urgent.
|
|
66
|
+
|
|
67
|
+
Our deep learning-based pipeline streamlines the training and execution of 3D autoencoder models specifically designed for cryo-ET particle picking. Built on [copick](https://github.com/copick/copick), a storage-agnostic API, octopi seamlessly accesses tomograms and segmentations across local and remote environments.
|
|
68
|
+
|
|
69
|
+
## 🧩 Core Features
|
|
70
|
+
|
|
71
|
+
- **3D U-Net Training**: Train and evaluate custom 3D U-Net models for particle segmentation
|
|
72
|
+
- **Automatic Architecture Search**: Explore optimal model configurations using Bayesian optimization via Optuna
|
|
73
|
+
- **Flexible Data Access**: Seamlessly work with tomograms from local storage or remote data portals
|
|
74
|
+
- **HPC Ready**: Built-in support for SLURM-based clusters
|
|
75
|
+
- **Experiment Tracking**: Integrated MLflow support for monitoring training and optimization
|
|
76
|
+
- **Dual Interface**: Use via command-line or Python API
|
|
77
|
+
|
|
78
|
+
## 🚀 Quick Start
|
|
79
|
+
|
|
80
|
+
### Installation
|
|
81
|
+
|
|
82
|
+
Octopi is available on PyPI and can be installed using pip:
|
|
83
|
+
```bash
|
|
84
|
+
pip install octopi
|
|
85
|
+
```
|
|
86
|
+
|
|
87
|
+
⚠️ **Note**: One of octopi's dependencies does not currently work with pip 25. To temporarily downgrade pip, run:
|
|
88
|
+
```bash
|
|
89
|
+
pip install --upgrade "pip<25"
|
|
90
|
+
```
|
|
91
|
+
|
|
92
|
+
### Basic Usage
|
|
93
|
+
|
|
94
|
+
octopi provides two main command-line interfaces:
|
|
95
|
+
|
|
96
|
+
```bash
|
|
97
|
+
# Main CLI for training, inference, and data processing
|
|
98
|
+
octopi --help
|
|
99
|
+
|
|
100
|
+
# HPC-specific CLI for submitting jobs to SLURM clusters
|
|
101
|
+
octopi-slurm --help
|
|
102
|
+
```
|
|
103
|
+
|
|
104
|
+
## 📚 Documentation
|
|
105
|
+
|
|
106
|
+
For detailed documentation, tutorials, CLI and API reference, visit our [documentation](https://chanzuckerberg.github.io/octopi/).
|
|
107
|
+
|
|
108
|
+
## 🤝 Contributing
|
|
109
|
+
|
|
110
|
+
## Code of Conduct
|
|
111
|
+
|
|
112
|
+
This project adheres to the Contributor Covenant [code of conduct](https://github.com/chanzuckerberg/.github/blob/master/CODE_OF_CONDUCT.md).
|
|
113
|
+
By participating, you are expected to uphold this code.
|
|
114
|
+
Please report unacceptable behavior to [opensource@chanzuckerberg.com](mailto:opensource@chanzuckerberg.com).
|
|
115
|
+
|
|
116
|
+
## 🔒 Security
|
|
117
|
+
|
|
118
|
+
If you believe you have found a security issue, please responsibly disclose by contacting us at security@chanzuckerberg.com.
|
|
119
|
+
|
|
120
|
+
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
octopi/__init__.py,sha256=Btl_98iBIXFtvGx47MqpfbaEVYoOMPBQn9bao7UASkQ,21
|
|
2
|
+
octopi/main.py,sha256=ef_zpvopl6JiN4gOT_x2QghRJriqg4iwRzBnvuiBeTo,4864
|
|
3
|
+
octopi/workflows.py,sha256=bRP5uPqu7PtUtQUOgtVjBZ6FulhG9uAdlcEH_kBYXPE,9235
|
|
4
|
+
octopi/datasets/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
5
|
+
octopi/datasets/augment.py,sha256=k7UXQzidPANaPuLoBzer_ZHc4_vF-kKTOyZisEnAbNw,3203
|
|
6
|
+
octopi/datasets/cached_datset.py,sha256=z20Ldfve1nNfLLPNx6nUE2p6i4IXx4V2SHmt99aVN_k,3918
|
|
7
|
+
octopi/datasets/dataset.py,sha256=9C3AuD83FMswj4MYLFn5XPqvAP1a7MQSG7UNiStg090,511
|
|
8
|
+
octopi/datasets/generators.py,sha256=qL43vd6gxpPrySZmt8RWujtJo3eO6sFcpS92CmNkqSA,18883
|
|
9
|
+
octopi/datasets/io.py,sha256=COjGoCLVbPU_xQC2UCaHGPd24a8AXtDvUqyd8kOc-Eo,6978
|
|
10
|
+
octopi/datasets/mixup.py,sha256=BJUAmM7ItZWFChs8glnd8RNSXR5qGW7DHscbcVc3TsU,1575
|
|
11
|
+
octopi/datasets/multi_config_generator.py,sha256=_eAZOzWLo6bI_H8x0quoslyjfuERVRXEVHZ4aW7j_Ok,10814
|
|
12
|
+
octopi/entry_points/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
13
|
+
octopi/entry_points/common.py,sha256=pHYQFQKdrtbYYU96klTQV4u1acUBfApkbEOc_TbuIzA,5367
|
|
14
|
+
octopi/entry_points/create_slurm_submission.py,sha256=rW1hr59mLLvKzmHG-uozYvJvTcqnHNZdmy2nMr5EBDc,8914
|
|
15
|
+
octopi/entry_points/run_create_targets.py,sha256=FPpsgVQHbepJvE2ZdC1VQKdT5RpALdhDXk7jU1RUXKo,11236
|
|
16
|
+
octopi/entry_points/run_evaluate.py,sha256=ab_VEDlh3RxMHfSyoMCdFH6z9G3fpHIfAKvyoYmTP2M,2835
|
|
17
|
+
octopi/entry_points/run_extract_mb_picks.py,sha256=u7mVqYKEvvkSN-L7v27UJVURxl74q7yiUZwTs-3zfYE,4866
|
|
18
|
+
octopi/entry_points/run_extract_midpoint.py,sha256=O6GdkSD7oXIpSVqizOc6PHhv9nunz3j0RucmYQ2yryM,5742
|
|
19
|
+
octopi/entry_points/run_localize.py,sha256=-CF9w3XKLjiGIAN-kWgeobpfFYfIvZxRpxR0D8fyb3M,8264
|
|
20
|
+
octopi/entry_points/run_optuna.py,sha256=U1VYzwE0eYV8DNnms0gbjytqsDwSvHJ5BsMj48iZz8k,5784
|
|
21
|
+
octopi/entry_points/run_segment_predict.py,sha256=5RvnxUDDNzBfRv4dOGJl6SieAlF5SR5stHL_Z6BgUh0,5439
|
|
22
|
+
octopi/entry_points/run_train.py,sha256=O9uDunSyQGpEXqTTOYNSgO3LvTZt49K9HzcN0sGuD14,8095
|
|
23
|
+
octopi/extract/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
24
|
+
octopi/extract/localize.py,sha256=SQY1p2Q9sYrB_yjOZySYqoWwwlfgcvBTR1f_Fpi2Wxc,9528
|
|
25
|
+
octopi/extract/membranebound_extract.py,sha256=kTV_QFBbvzd3STJsC55QKoRTYcvg_raV3C757Y_9HBo,11087
|
|
26
|
+
octopi/extract/midpoint_extract.py,sha256=iEUZJShlhBC5qAdP5w-zjjEOMz_nu9JZF6jhbJvcV94,7036
|
|
27
|
+
octopi/models/AttentionUnet.py,sha256=r185aXRtfXhN-n8FxA-Sjz18nqpxHH_2t2uadrH-Mgs,1991
|
|
28
|
+
octopi/models/MedNeXt.py,sha256=9q0FsyrqTx211hCbDv0Lm2XflzXL_bGA4-76BscziGk,4875
|
|
29
|
+
octopi/models/ModelTemplate.py,sha256=X80EOXwSovCjmVb7x-0_JmRjHfDfLByDdd60MrgFTyw,1084
|
|
30
|
+
octopi/models/SegResNet.py,sha256=1dK8dy_7hHHKYZLsTYafl__7MxQOlWGbBqQWPGxHSXg,3609
|
|
31
|
+
octopi/models/Unet.py,sha256=7RGT_nl7ogsNlS3Y3Qexe305Ym9NlK9CV0y45g2DEU4,2171
|
|
32
|
+
octopi/models/UnetPlusPlus.py,sha256=fnV-SvJV8B432KJXQAtdwLy8Va6DJ4fRB_7a1mZiqTU,1529
|
|
33
|
+
octopi/models/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
34
|
+
octopi/models/common.py,sha256=qxx1ak2ayVC9fPbpjs6XtfUNPSX-PGywq3ySsXGdY74,2313
|
|
35
|
+
octopi/processing/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
36
|
+
octopi/processing/create_targets_from_picks.py,sha256=5dbCCjFXq55A-zD9SX4mXUKRq_lPWriTlnRYb958hBc,4377
|
|
37
|
+
octopi/processing/downsample.py,sha256=u3V2HULdKTLEthzI-MqJmI-NqpiyX0Vg61i77-j2MKQ,5392
|
|
38
|
+
octopi/processing/evaluate.py,sha256=Xxx8hJrWzZ_pSxEs89CC_6AxGUch3_RSHpH7UtEBs1c,14139
|
|
39
|
+
octopi/processing/importers.py,sha256=Rt-Y3QtFJz_ewCkdqtJpmiiI4xFSFeXaRcL1ZBn0bd0,8876
|
|
40
|
+
octopi/processing/segmentation_from_picks.py,sha256=jah1gAXEn09LIok1Cb8IeVN-fT3jktcVPfjbOFHkgg0,7089
|
|
41
|
+
octopi/pytorch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
42
|
+
octopi/pytorch/hyper_search.py,sha256=ywwGEtTHJUkrBCOYSycrpWccUvxJlxny50q5GnWoWLA,9376
|
|
43
|
+
octopi/pytorch/model_search_submitter.py,sha256=zEfvlcB9AthtanUaqV3a8f2HGG_eo5gvIOZDOvUziJo,11267
|
|
44
|
+
octopi/pytorch/segmentation.py,sha256=nPW4tkQfgA4fC02j_bcLyKvS7SL9ghQnENKT7_YC_gc,11266
|
|
45
|
+
octopi/pytorch/segmentation_multigpu.py,sha256=KK8BZIdB1PJBs_jCv7WBS0xORR_XlpcKmmPf3AgIfYY,6781
|
|
46
|
+
octopi/pytorch/trainer.py,sha256=NAoOf00oh_-Sv5FZUaq_L5v7Y_VV_VtuExIaZCrvxfs,17760
|
|
47
|
+
octopi/pytorch_lightning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
48
|
+
octopi/pytorch_lightning/optuna_pl_ddp.py,sha256=ynD5i81IP-awr6OD9GDurjrQK-5Kc079qPaukphTHnA,11924
|
|
49
|
+
octopi/pytorch_lightning/train_pl.py,sha256=igOHzU_mUdZRQGhoOGW5vmxJcHFcw5fAPHfVCIZ0eG4,10220
|
|
50
|
+
octopi/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
51
|
+
octopi/utils/config.py,sha256=RLrHmyjnQ0tXX2_bls--ZecbgRTNlnGtkvLl7YWtYtc,1833
|
|
52
|
+
octopi/utils/io.py,sha256=-PMzzBiIKbrUAIw7NcmeAj2V0yy9r4N2QRMIGlm2SII,4623
|
|
53
|
+
octopi/utils/losses.py,sha256=fs9yR80Hs-hL07WgVMkRy5N81vzP5m9XBCvzO76bIPU,3097
|
|
54
|
+
octopi/utils/parsers.py,sha256=InHLN-iYq0bWoKAHCX-hX82HO8b4pcZZ1on1LcwcsXo,5709
|
|
55
|
+
octopi/utils/stopping_criteria.py,sha256=q-tkcqyZ3GfqR8BzWu1eKE8dt2pNZ3MOD9dcMe2T7fU,6112
|
|
56
|
+
octopi/utils/submit_slurm.py,sha256=cRbJTESbPFCt6Cq4Hat2uPOQKFYMPcQxuNs0jc1ygUA,1945
|
|
57
|
+
octopi/utils/visualization_tools.py,sha256=I-fmXKk9eyaLTbvsYxx8H6nlEvUOH94NJ_l2dayXXe4,7364
|
|
58
|
+
octopi-1.2.0.dist-info/METADATA,sha256=UWZ2lfjUNDd34EDtL47LdpIp4p_TbkemNWhqjsehQ6U,5062
|
|
59
|
+
octopi-1.2.0.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
|
60
|
+
octopi-1.2.0.dist-info/entry_points.txt,sha256=n4I7aA9RS52l1eYPidQx63VhxT9pPsMK4xitlBm7X7k,90
|
|
61
|
+
octopi-1.2.0.dist-info/licenses/LICENSE,sha256=DRoLUEjEHgru_AwKHhii9UIOt6QA65am4WQd7UFzqnE,1822
|
|
62
|
+
octopi-1.2.0.dist-info/RECORD,,
|
|
@@ -17,15 +17,15 @@ copies of the Software, and to permit persons to whom the Software is
|
|
|
17
17
|
furnished to do so, subject to the following conditions:
|
|
18
18
|
|
|
19
19
|
The above copyright notice and this permission notice shall be included in all
|
|
20
|
-
copies or substantial portions of the Software.
|
|
20
|
+
copies or substantial portions of the "Software".
|
|
21
21
|
|
|
22
22
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
23
23
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
24
24
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
25
25
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
26
26
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
27
|
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
28
|
-
SOFTWARE.
|
|
27
|
+
OUT OF OR IN CONNECTION WITH THE "SOFTWARE" OR THE USE OR OTHER DEALINGS IN THE
|
|
28
|
+
"SOFTWARE".
|
|
29
29
|
```
|
|
30
30
|
|
|
31
31
|
## License Notice for Dependencies
|
octopi/io.py
DELETED
|
@@ -1,457 +0,0 @@
|
|
|
1
|
-
from monai.data import DataLoader, CacheDataset, Dataset
|
|
2
|
-
from monai.transforms import (
|
|
3
|
-
Compose,
|
|
4
|
-
NormalizeIntensityd,
|
|
5
|
-
EnsureChannelFirstd,
|
|
6
|
-
)
|
|
7
|
-
from sklearn.model_selection import train_test_split
|
|
8
|
-
import copick, torch, os, json, random
|
|
9
|
-
from collections import defaultdict
|
|
10
|
-
from octopi import utils
|
|
11
|
-
from typing import List
|
|
12
|
-
from tqdm import tqdm
|
|
13
|
-
import numpy as np
|
|
14
|
-
|
|
15
|
-
##############################################################################################################################
|
|
16
|
-
|
|
17
|
-
def load_training_data(root,
|
|
18
|
-
runIDs: List[str],
|
|
19
|
-
voxel_spacing: float,
|
|
20
|
-
tomo_algorithm: str,
|
|
21
|
-
segmenation_name: str,
|
|
22
|
-
segmentation_session_id: str = None,
|
|
23
|
-
segmentation_user_id: str = None,
|
|
24
|
-
progress_update: bool = True):
|
|
25
|
-
|
|
26
|
-
data_dicts = []
|
|
27
|
-
# Use tqdm for progress tracking only if progress_update is True
|
|
28
|
-
iterable = tqdm(runIDs, desc="Loading Training Data") if progress_update else runIDs
|
|
29
|
-
for runID in iterable:
|
|
30
|
-
run = root.get_run(str(runID))
|
|
31
|
-
tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
|
|
32
|
-
segmentation = get_segmentation_array(run,
|
|
33
|
-
voxel_spacing,
|
|
34
|
-
segmenation_name,
|
|
35
|
-
segmentation_session_id,
|
|
36
|
-
segmentation_user_id)
|
|
37
|
-
data_dicts.append({"image": tomogram, "label": segmentation})
|
|
38
|
-
|
|
39
|
-
return data_dicts
|
|
40
|
-
|
|
41
|
-
##############################################################################################################################
|
|
42
|
-
|
|
43
|
-
def load_predict_data(root,
|
|
44
|
-
runIDs: List[str],
|
|
45
|
-
voxel_spacing: float,
|
|
46
|
-
tomo_algorithm: str):
|
|
47
|
-
|
|
48
|
-
data_dicts = []
|
|
49
|
-
for runID in tqdm(runIDs):
|
|
50
|
-
run = root.get_run(str(runID))
|
|
51
|
-
tomogram = get_tomogram_array(run, voxel_spacing, tomo_algorithm)
|
|
52
|
-
data_dicts.append({"image": tomogram})
|
|
53
|
-
|
|
54
|
-
return data_dicts
|
|
55
|
-
|
|
56
|
-
##############################################################################################################################
|
|
57
|
-
|
|
58
|
-
def create_predict_dataloader(
|
|
59
|
-
root,
|
|
60
|
-
voxel_spacing: float,
|
|
61
|
-
tomo_algorithm: str,
|
|
62
|
-
runIDs: str = None,
|
|
63
|
-
):
|
|
64
|
-
|
|
65
|
-
# define pre transforms
|
|
66
|
-
pre_transforms = Compose(
|
|
67
|
-
[ EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
|
|
68
|
-
NormalizeIntensityd(keys=["image"]),
|
|
69
|
-
])
|
|
70
|
-
|
|
71
|
-
# Split trainRunIDs, validateRunIDs, testRunIDs
|
|
72
|
-
if runIDs is None:
|
|
73
|
-
runIDs = [run.name for run in root.runs]
|
|
74
|
-
test_files = load_predict_data(root, runIDs, voxel_spacing, tomo_algorithm)
|
|
75
|
-
|
|
76
|
-
bs = min( len(test_files), 4)
|
|
77
|
-
test_ds = CacheDataset(data=test_files, transform=pre_transforms)
|
|
78
|
-
test_loader = DataLoader(test_ds,
|
|
79
|
-
batch_size=bs,
|
|
80
|
-
shuffle=False,
|
|
81
|
-
num_workers=4,
|
|
82
|
-
pin_memory=torch.cuda.is_available())
|
|
83
|
-
return test_loader, test_ds
|
|
84
|
-
|
|
85
|
-
##############################################################################################################################
|
|
86
|
-
|
|
87
|
-
def get_tomogram_array(run,
|
|
88
|
-
voxel_size: float = 10,
|
|
89
|
-
tomo_type: str = 'wbp',
|
|
90
|
-
raise_error: bool = True):
|
|
91
|
-
|
|
92
|
-
voxel_spacing_obj = run.get_voxel_spacing(voxel_size)
|
|
93
|
-
|
|
94
|
-
if voxel_spacing_obj is None:
|
|
95
|
-
# Query Avaiable Voxel Spacings
|
|
96
|
-
availableVoxelSpacings = [tomo.voxel_size for tomo in run.voxel_spacings]
|
|
97
|
-
|
|
98
|
-
# Report to the user which voxel spacings they can use
|
|
99
|
-
message = (f"\n[Warning] No tomogram found for {run.name} with voxel size {voxel_size} and tomogram type {tomo_type}"
|
|
100
|
-
f"\nAvailable spacings are: {', '.join(map(str, availableVoxelSpacings))}\n" )
|
|
101
|
-
if raise_error:
|
|
102
|
-
raise ValueError(message)
|
|
103
|
-
else:
|
|
104
|
-
print(message)
|
|
105
|
-
return None
|
|
106
|
-
|
|
107
|
-
tomogram = voxel_spacing_obj.get_tomogram(tomo_type)
|
|
108
|
-
if tomogram is None:
|
|
109
|
-
# Get available algorithms
|
|
110
|
-
availableAlgorithms = [tomo.tomo_type for tomo in run.get_voxel_spacing(voxel_size).tomograms]
|
|
111
|
-
|
|
112
|
-
# Report to the user which algorithms are available
|
|
113
|
-
message = (f"\n[Warning] No tomogram found for {run.name} with voxel size {voxel_size} and tomogram type {tomo_type}"
|
|
114
|
-
f"\nAvailable algorithms are: {', '.join(availableAlgorithms)}\n")
|
|
115
|
-
if raise_error:
|
|
116
|
-
raise ValueError(message)
|
|
117
|
-
else:
|
|
118
|
-
print(message)
|
|
119
|
-
return None
|
|
120
|
-
|
|
121
|
-
return tomogram.numpy().astype(np.float32)
|
|
122
|
-
|
|
123
|
-
##############################################################################################################################
|
|
124
|
-
|
|
125
|
-
def get_segmentation_array(run,
|
|
126
|
-
voxel_spacing: float,
|
|
127
|
-
segmentation_name: str,
|
|
128
|
-
session_id=None,
|
|
129
|
-
user_id=None,
|
|
130
|
-
raise_error: bool = True):
|
|
131
|
-
|
|
132
|
-
seg = run.get_segmentations(name=segmentation_name,
|
|
133
|
-
session_id = session_id,
|
|
134
|
-
user_id = user_id,
|
|
135
|
-
voxel_size=float(voxel_spacing))
|
|
136
|
-
|
|
137
|
-
# No Segmentations Are Available, Result in Error
|
|
138
|
-
if len(seg) == 0:
|
|
139
|
-
# Get all available segmentations with their metadata
|
|
140
|
-
available_segs = run.get_segmentations(voxel_size=voxel_spacing)
|
|
141
|
-
seg_info = [(s.name, s.user_id, s.session_id) for s in available_segs]
|
|
142
|
-
|
|
143
|
-
# Format the information for display
|
|
144
|
-
seg_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
|
|
145
|
-
for name, uid, sid in seg_info]
|
|
146
|
-
|
|
147
|
-
message = ( f'\nNo segmentation found matching:\n'
|
|
148
|
-
f' name: {segmentation_name}, user_id: {user_id}, session_id: {session_id}\n'
|
|
149
|
-
f'Available segmentations in {run.name} are:\n ' +
|
|
150
|
-
'\n '.join(seg_details) )
|
|
151
|
-
if raise_error:
|
|
152
|
-
raise ValueError(message)
|
|
153
|
-
else:
|
|
154
|
-
print(message)
|
|
155
|
-
return None
|
|
156
|
-
|
|
157
|
-
# No Segmentations Are Available, Result in Error
|
|
158
|
-
if len(seg) > 1:
|
|
159
|
-
print(f'[Warning] More Than 1 Segmentation is Available for the Query Information. '
|
|
160
|
-
f'Available Segmentations are: {seg} '
|
|
161
|
-
f'Defaulting to Loading: {seg[0]}\n')
|
|
162
|
-
seg = seg[0]
|
|
163
|
-
|
|
164
|
-
return seg.numpy().astype(np.int8)
|
|
165
|
-
|
|
166
|
-
##############################################################################################################################
|
|
167
|
-
|
|
168
|
-
def get_copick_coordinates(run, # CoPick run object containing the segmentation data
|
|
169
|
-
name: str, # Name of the object or protein for which coordinates are being extracted
|
|
170
|
-
user_id: str, # Identifier of the user that generated the picks
|
|
171
|
-
session_id: str = None, # Identifier of the session that generated the picks
|
|
172
|
-
voxel_size: float = 10, # Voxel size of the tomogram, used for scaling the coordinates
|
|
173
|
-
raise_error: bool = True):
|
|
174
|
-
|
|
175
|
-
# Retrieve the pick points associated with the specified object and user ID
|
|
176
|
-
picks = run.get_picks(object_name=name, user_id=user_id, session_id=session_id)
|
|
177
|
-
|
|
178
|
-
if len(picks) == 0:
|
|
179
|
-
# Get all available segmentations with their metadata
|
|
180
|
-
|
|
181
|
-
available_picks = run.get_picks()
|
|
182
|
-
picks_info = [(s.pickable_object_name, s.user_id, s.session_id) for s in available_picks]
|
|
183
|
-
|
|
184
|
-
# Format the information for display
|
|
185
|
-
picks_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
|
|
186
|
-
for name, uid, sid in picks_info]
|
|
187
|
-
|
|
188
|
-
message = ( f'\nNo picks found matching:\n'
|
|
189
|
-
f' name: {name}, user_id: {user_id}, session_id: {session_id}\n'
|
|
190
|
-
f'Available picks are:\n '
|
|
191
|
-
+ '\n '.join(picks_details) )
|
|
192
|
-
if raise_error:
|
|
193
|
-
raise ValueError(message)
|
|
194
|
-
else:
|
|
195
|
-
print(message)
|
|
196
|
-
return None
|
|
197
|
-
elif len(picks) > 1:
|
|
198
|
-
# Format pick information for display
|
|
199
|
-
picks_info = [(p.pickable_object_name, p.user_id, p.session_id) for p in picks]
|
|
200
|
-
picks_details = [f"(name: {name}, user_id: {uid}, session_id: {sid})"
|
|
201
|
-
for name, uid, sid in picks_info]
|
|
202
|
-
|
|
203
|
-
print(f'[Warning] More than 1 pick is available for the query information.'
|
|
204
|
-
f'\nAvailable picks are:\n ' +
|
|
205
|
-
'\n '.join(picks_details) +
|
|
206
|
-
f'\nDefaulting to loading:\n {picks[0]}\n')
|
|
207
|
-
points = picks[0].points
|
|
208
|
-
|
|
209
|
-
# Initialize an array to store the coordinates
|
|
210
|
-
nPoints = len(picks[0].points) # Number of points retrieved
|
|
211
|
-
coordinates = np.zeros([len(picks[0].points), 3]) # Create an empty array to hold the (z, y, x) coordinates
|
|
212
|
-
|
|
213
|
-
# Iterate over all points and convert their locations to coordinates in voxel space
|
|
214
|
-
for ii in range(nPoints):
|
|
215
|
-
coordinates[ii,] = [points[ii].location.z / voxel_size, # Scale z-coordinate by voxel size
|
|
216
|
-
points[ii].location.y / voxel_size, # Scale y-coordinate by voxel size
|
|
217
|
-
points[ii].location.x / voxel_size] # Scale x-coordinate by voxel size
|
|
218
|
-
|
|
219
|
-
# Return the array of coordinates
|
|
220
|
-
return coordinates
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
##############################################################################################################################
|
|
224
|
-
|
|
225
|
-
def adjust_to_multiple(value, multiple = 16):
|
|
226
|
-
return int((value // multiple) * multiple)
|
|
227
|
-
|
|
228
|
-
def get_input_dimensions(dataset, crop_size: int):
|
|
229
|
-
nx = dataset[0]['image'].shape[1]
|
|
230
|
-
if crop_size > nx:
|
|
231
|
-
first_dim = adjust_to_multiple(nx/2)
|
|
232
|
-
return first_dim, crop_size, crop_size
|
|
233
|
-
else:
|
|
234
|
-
return crop_size, crop_size, crop_size
|
|
235
|
-
|
|
236
|
-
def get_num_classes(copick_config_path: str):
|
|
237
|
-
|
|
238
|
-
root = copick.from_file(copick_config_path)
|
|
239
|
-
return len(root.pickable_objects) + 1
|
|
240
|
-
|
|
241
|
-
def split_multiclass_dataset(runIDs,
|
|
242
|
-
train_ratio: float = 0.7,
|
|
243
|
-
val_ratio: float = 0.15,
|
|
244
|
-
test_ratio: float = 0.15,
|
|
245
|
-
return_test_dataset: bool = True,
|
|
246
|
-
random_state: int = 42):
|
|
247
|
-
"""
|
|
248
|
-
Splits a given dataset into three subsets: training, validation, and testing. If the dataset
|
|
249
|
-
has categories (as tuples), splits are balanced across all categories. If the dataset is a 1D
|
|
250
|
-
list, it is split without categorization.
|
|
251
|
-
|
|
252
|
-
Parameters:
|
|
253
|
-
- runIDs: A list of items to split. It can be a 1D list or a list of tuples (category, value).
|
|
254
|
-
- train_ratio: Proportion of the dataset for training.
|
|
255
|
-
- val_ratio: Proportion of the dataset for validation.
|
|
256
|
-
- test_ratio: Proportion of the dataset for testing.
|
|
257
|
-
- return_test_dataset: Whether to return the test dataset.
|
|
258
|
-
- random_state: Random state for reproducibility.
|
|
259
|
-
|
|
260
|
-
Returns:
|
|
261
|
-
- trainRunIDs: Training subset.
|
|
262
|
-
- valRunIDs: Validation subset.
|
|
263
|
-
- testRunIDs: Testing subset (if return_test_dataset is True, otherwise None).
|
|
264
|
-
"""
|
|
265
|
-
|
|
266
|
-
# Ensure the ratios add up to 1
|
|
267
|
-
assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
|
|
268
|
-
|
|
269
|
-
# Check if the dataset has categories
|
|
270
|
-
if isinstance(runIDs[0], tuple) and len(runIDs[0]) == 2:
|
|
271
|
-
# Group by category
|
|
272
|
-
grouped = defaultdict(list)
|
|
273
|
-
for item in runIDs:
|
|
274
|
-
grouped[item[0]].append(item)
|
|
275
|
-
|
|
276
|
-
# Split each category
|
|
277
|
-
trainRunIDs, valRunIDs, testRunIDs = [], [], []
|
|
278
|
-
for category, items in grouped.items():
|
|
279
|
-
# Shuffle for randomness
|
|
280
|
-
random.shuffle(items)
|
|
281
|
-
# Split into train and remaining
|
|
282
|
-
train_items, remaining = train_test_split(items, test_size=(1 - train_ratio), random_state=random_state)
|
|
283
|
-
trainRunIDs.extend(train_items)
|
|
284
|
-
|
|
285
|
-
if return_test_dataset:
|
|
286
|
-
# Split remaining into validation and test
|
|
287
|
-
val_items, test_items = train_test_split(
|
|
288
|
-
remaining,
|
|
289
|
-
test_size=(test_ratio / (val_ratio + test_ratio)),
|
|
290
|
-
random_state=random_state,
|
|
291
|
-
)
|
|
292
|
-
valRunIDs.extend(val_items)
|
|
293
|
-
testRunIDs.extend(test_items)
|
|
294
|
-
else:
|
|
295
|
-
valRunIDs.extend(remaining)
|
|
296
|
-
testRunIDs = []
|
|
297
|
-
else:
|
|
298
|
-
# If no categories, split as a 1D list
|
|
299
|
-
trainRunIDs, remaining = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
|
|
300
|
-
if return_test_dataset:
|
|
301
|
-
valRunIDs, testRunIDs = train_test_split(
|
|
302
|
-
remaining,
|
|
303
|
-
test_size=(test_ratio / (val_ratio + test_ratio)),
|
|
304
|
-
random_state=random_state,
|
|
305
|
-
)
|
|
306
|
-
else:
|
|
307
|
-
valRunIDs = remaining
|
|
308
|
-
testRunIDs = []
|
|
309
|
-
|
|
310
|
-
return trainRunIDs, valRunIDs, testRunIDs
|
|
311
|
-
|
|
312
|
-
##############################################################################################################################
|
|
313
|
-
|
|
314
|
-
def load_copick_config(path: str):
|
|
315
|
-
|
|
316
|
-
if os.path.isfile(path):
|
|
317
|
-
root = copick.from_file(path)
|
|
318
|
-
else:
|
|
319
|
-
raise FileNotFoundError(f"Copick Config Path does not exist: {path}")
|
|
320
|
-
|
|
321
|
-
return root
|
|
322
|
-
|
|
323
|
-
##############################################################################################################################
|
|
324
|
-
|
|
325
|
-
# Helper function to flatten and serialize nested parameters
|
|
326
|
-
def flatten_params(params, parent_key=''):
|
|
327
|
-
flattened = {}
|
|
328
|
-
for key, value in params.items():
|
|
329
|
-
new_key = f"{parent_key}.{key}" if parent_key else key
|
|
330
|
-
if isinstance(value, dict):
|
|
331
|
-
flattened.update(flatten_params(value, new_key))
|
|
332
|
-
elif isinstance(value, list):
|
|
333
|
-
flattened[new_key] = ', '.join(map(str, value)) # Convert list to a comma-separated string
|
|
334
|
-
else:
|
|
335
|
-
flattened[new_key] = value
|
|
336
|
-
return flattened
|
|
337
|
-
|
|
338
|
-
# Manually join specific lists into strings for inline display
|
|
339
|
-
def prepare_for_inline_json(data):
|
|
340
|
-
for key in ["trainRunIDs", "valRunIDs", "testRunIDs"]:
|
|
341
|
-
if key in data['dataloader']:
|
|
342
|
-
data['dataloader'][key] = f"[{', '.join(map(repr, data['dataloader'][key]))}]"
|
|
343
|
-
|
|
344
|
-
for key in ['channels', 'strides']:
|
|
345
|
-
if key in data['model']:
|
|
346
|
-
data['model'][key] = f"[{', '.join(map(repr, data['model'][key]))}]"
|
|
347
|
-
return data
|
|
348
|
-
|
|
349
|
-
def get_optimizer_parameters(trainer):
    """Summarize the trainer's optimization settings as a plain dict.

    Always records the sample count, validation interval, the first parameter
    group's learning rate, and the class names of the optimizer, metric and
    loss function. For the loss classes listed below, their hyperparameters
    (``alpha``/``gamma``/``weight_tversky``) are appended as well.

    Args:
        trainer: Object exposing ``num_samples``, ``val_interval``,
            ``optimizer`` (with ``param_groups``), ``metrics_function`` and
            ``loss_function``.

    Returns:
        dict of parameter name -> value.
    """
    optimizer = trainer.optimizer
    summary = {
        'my_num_samples': trainer.num_samples,
        'val_interval': trainer.val_interval,
        'lr': optimizer.param_groups[0]['lr'],
        'optimizer': optimizer.__class__.__name__,
        'metrics_function': trainer.metrics_function.__class__.__name__,
        'loss_function': trainer.loss_function.__class__.__name__,
    }

    # Extra loss hyperparameters to log, keyed by the loss class name.
    loss_extras = {
        'TverskyLoss': ('alpha',),
        'FocalLoss': ('gamma',),
        'WeightedFocalTverskyLoss': ('alpha', 'gamma', 'weight_tversky'),
        'FocalTverskyLoss': ('alpha', 'gamma'),
    }
    loss_fn = trainer.loss_function
    for attr_name in loss_extras.get(loss_fn.__class__.__name__, ()):
        summary[attr_name] = getattr(loss_fn, attr_name)

    return summary
|
|
374
|
-
|
|
375
|
-
def save_parameters_to_yaml(model, trainer, dataloader, filename: str):
    """Gather model, optimizer and dataloader settings and write them as YAML.

    Args:
        model: Object providing ``get_model_parameters()``.
        trainer: Trainer passed to ``get_optimizer_parameters``.
        dataloader: Object providing ``get_dataloader_parameters()``.
        filename: Destination path handed to ``utils.save_parameters_yaml``.
    """
    model_section = model.get_model_parameters()
    optimizer_section = get_optimizer_parameters(trainer)
    dataloader_section = dataloader.get_dataloader_parameters()

    parameters = {
        'model': model_section,
        'optimizer': optimizer_section,
        'dataloader': dataloader_section,
    }
    utils.save_parameters_yaml(parameters, filename)
    print(f"Training Parameters saved to {filename}")
|
|
385
|
-
|
|
386
|
-
def prepare_inline_results_json(results):
    """Convert ``[[epoch, value], ...]`` series into single-line JSON strings.

    Every value in ``results`` that is a list whose items are all 2-element
    lists (including an empty list) is replaced by its ``json.dumps`` string,
    so an indented JSON dump keeps each series on one line. ``results`` is
    modified in place and also returned.
    """
    pair_series_keys = [
        name for name, series in results.items()
        if isinstance(series, list)
        and all(isinstance(pair, list) and len(pair) == 2 for pair in series)
    ]
    for name in pair_series_keys:
        results[name] = json.dumps(results[name])
    return results
|
|
394
|
-
|
|
395
|
-
# TODO: revisit this results format; consider saving as an HDF5 (.h5) file instead of JSON.
|
|
396
|
-
def save_results_to_json(results, filename: str):
    """Write training results to ``filename`` as indented JSON.

    Series of ``[epoch, value]`` pairs are serialized as single-line JSON
    strings via ``prepare_inline_results_json`` so the file stays compact.

    Args:
        results: Mapping of metric name -> value/series. It is NOT modified.
        filename: Destination path of the JSON file.
    """
    # prepare_inline_results_json replaces list values in place; format a
    # shallow copy so saving does not turn the caller's lists into strings.
    formatted = prepare_inline_results_json(dict(results))
    with open(filename, "w") as json_file:
        json.dump(formatted, json_file, indent=4)
    print(f"Training Results saved to {filename}")
|
|
402
|
-
|
|
403
|
-
##############################################################################################################################
|
|
404
|
-
|
|
405
|
-
# def save_parameters_to_json(model, trainer, dataloader, filename: str):
|
|
406
|
-
|
|
407
|
-
# parameters = {
|
|
408
|
-
# 'model': model.get_model_parameters(),
|
|
409
|
-
# 'optimizer': get_optimizer_parameters(trainer),
|
|
410
|
-
# 'dataloader': dataloader.get_dataloader_parameters()
|
|
411
|
-
# }
|
|
412
|
-
# parameters = prepare_for_inline_json(parameters)
|
|
413
|
-
|
|
414
|
-
# with open(os.path.join(filename), "w") as json_file:
|
|
415
|
-
# json.dump( parameters, json_file, indent=4 )
|
|
416
|
-
# print(f"Training Parameters saved to {filename}")
|
|
417
|
-
|
|
418
|
-
# def split_datasets(runIDs,
|
|
419
|
-
# train_ratio: float = 0.7,
|
|
420
|
-
# val_ratio: float = 0.15,
|
|
421
|
-
# test_ratio: float = 0.15,
|
|
422
|
-
# return_test_dataset: bool = True,
|
|
423
|
-
# random_state: int = 42):
|
|
424
|
-
# """
|
|
425
|
-
# Splits a given dataset into three subsets: training, validation, and testing. The proportions
|
|
426
|
-
# of each subset are determined by the provided ratios, ensuring that they add up to 1. The
|
|
427
|
-
# function uses a fixed random state for reproducibility.
|
|
428
|
-
|
|
429
|
-
# Parameters:
|
|
430
|
-
# - runIDs: The complete dataset that needs to be split.
|
|
431
|
-
# - train_ratio: The proportion of the dataset to be used for training.
|
|
432
|
-
# - val_ratio: The proportion of the dataset to be used for validation.
|
|
433
|
-
# - test_ratio: The proportion of the dataset to be used for testing.
|
|
434
|
-
|
|
435
|
-
# Returns:
|
|
436
|
-
# - trainRunIDs: The subset of the dataset used for training.
|
|
437
|
-
# - valRunIDs: The subset of the dataset used for validation.
|
|
438
|
-
# - testRunIDs: The subset of the dataset used for testing.
|
|
439
|
-
# """
|
|
440
|
-
|
|
441
|
-
# # Ensure the ratios add up to 1
|
|
442
|
-
# assert train_ratio + val_ratio + test_ratio == 1.0, "Ratios must add up to 1.0"
|
|
443
|
-
|
|
444
|
-
# # First, split into train and remaining (1 - train_ratio, i.e. 30% by default)
|
|
445
|
-
# trainRunIDs, valRunIDs = train_test_split(runIDs, test_size=(1 - train_ratio), random_state=random_state)
|
|
446
|
-
|
|
447
|
-
# # (Optional) split the remaining into validation and test
|
|
448
|
-
# if return_test_dataset:
|
|
449
|
-
# valRunIDs, testRunIDs = train_test_split(
|
|
450
|
-
# valRunIDs,
|
|
451
|
-
# test_size=(test_ratio / (val_ratio + test_ratio)),
|
|
452
|
-
# random_state=random_state,
|
|
453
|
-
# )
|
|
454
|
-
# else:
|
|
455
|
-
# testRunIDs = None
|
|
456
|
-
|
|
457
|
-
# return trainRunIDs, valRunIDs, testRunIDs
|