monai-weekly 1.5.dev2507__py3-none-any.whl → 1.5.dev2509__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- monai/__init__.py +1 -1
- monai/_version.py +3 -3
- monai/bundle/scripts.py +39 -16
- monai/handlers/__init__.py +1 -0
- monai/handlers/average_precision.py +53 -0
- monai/inferers/inferer.py +39 -18
- monai/metrics/__init__.py +1 -0
- monai/metrics/average_precision.py +187 -0
- monai/transforms/utility/array.py +2 -12
- monai/transforms/utils_pytorch_numpy_unification.py +2 -4
- monai/utils/enums.py +3 -2
- monai/utils/module.py +6 -6
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2509.dist-info}/METADATA +20 -16
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2509.dist-info}/RECORD +24 -20
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2509.dist-info}/WHEEL +1 -1
- tests/bundle/test_bundle_trt_export.py +2 -2
- tests/handlers/test_handler_average_precision.py +79 -0
- tests/inferers/test_controlnet_inferers.py +89 -2
- tests/inferers/test_latent_diffusion_inferer.py +61 -1
- tests/metrics/test_compute_average_precision.py +162 -0
- tests/networks/test_convert_to_onnx.py +1 -1
- tests/transforms/test_gibbs_noise.py +3 -5
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2509.dist-info}/LICENSE +0 -0
- {monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2509.dist-info}/top_level.txt +0 -0
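
The headline addition in this window is an Average Precision (AP) metric: a new `monai/metrics/average_precision.py`, a matching handler in `monai/handlers/average_precision.py`, and accompanying tests. A minimal usage sketch, following the update/compute cycle and softmax/one-hot preparation shown in the new `tests/handlers/test_handler_average_precision.py` included in full below:

```python
import torch

from monai.handlers import AveragePrecision
from monai.transforms import Activations, AsDiscrete

ap_metric = AveragePrecision()           # the new Ignite-style handler
act = Activations(softmax=True)          # logits -> class probabilities
to_onehot = AsDiscrete(to_onehot=2)      # class index -> one-hot target

y_pred = [act(p) for p in (torch.tensor([0.1, 0.9]), torch.tensor([0.3, 1.4]))]
y = [to_onehot(t) for t in (torch.tensor([0.0]), torch.tensor([1.0]))]
ap_metric.update([y_pred, y])
print(ap_metric.compute())
```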
{monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2509.dist-info}/METADATA

@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: monai-weekly
-Version: 1.5.dev2507
+Version: 1.5.dev2509
 Summary: AI Toolkit for Healthcare Imaging
 Home-page: https://monai.io/
 Author: MONAI Consortium
@@ -176,12 +176,13 @@ Requires-Dist: pyamg>=5.0.0; extra == "pyamg"
 
 MONAI is a [PyTorch](https://pytorch.org/)-based, [open-source](https://github.com/Project-MONAI/MONAI/blob/dev/LICENSE) framework for deep learning in healthcare imaging, part of the [PyTorch Ecosystem](https://pytorch.org/ecosystem/).
 Its ambitions are as follows:
+
 - Developing a community of academic, industrial and clinical researchers collaborating on a common foundation;
 - Creating state-of-the-art, end-to-end training workflows for healthcare imaging;
 - Providing researchers with the optimized and standardized way to create and evaluate deep learning models.
 
-
 ## Features
+
 > _Please see [the technical highlights](https://docs.monai.io/en/latest/highlights.html) and [What's New](https://docs.monai.io/en/latest/whatsnew.html) of the milestone releases._
 
 - flexible pre-processing for multi-dimensional medical imaging data;
@@ -190,7 +191,6 @@ Its ambitions are as follows:
 - customizable design for varying user expertise;
 - multi-GPU multi-node data parallelism support.
 
-
 ## Installation
 
 To install [the current release](https://pypi.org/project/monai/), you can simply run:
@@ -211,30 +211,34 @@ Technical documentation is available at [docs.monai.io](https://docs.monai.io).
 
 ## Citation
 
-If you have used MONAI in your research, please cite us! The citation can be exported from: https://arxiv.org/abs/2211.02701
+If you have used MONAI in your research, please cite us! The citation can be exported from: <https://arxiv.org/abs/2211.02701>.
 
 ## Model Zoo
+
 [The MONAI Model Zoo](https://github.com/Project-MONAI/model-zoo) is a place for researchers and data scientists to share the latest and great models from the community.
 Utilizing [the MONAI Bundle format](https://docs.monai.io/en/latest/bundle_intro.html) makes it easy to [get started](https://github.com/Project-MONAI/tutorials/tree/main/model_zoo) building workflows with MONAI.
 
 ## Contributing
+
 For guidance on making a contribution to MONAI, see the [contributing guidelines](https://github.com/Project-MONAI/MONAI/blob/dev/CONTRIBUTING.md).
 
 ## Community
+
 Join the conversation on Twitter/X [@ProjectMONAI](https://twitter.com/ProjectMONAI) or join our [Slack channel](https://forms.gle/QTxJq3hFictp31UM9).
 
 Ask and answer questions over on [MONAI's GitHub Discussions tab](https://github.com/Project-MONAI/MONAI/discussions).
 
 ## Links
-
--
-- API documentation (
--
--
--
--
--
--
--
--
--
+
+- Website: <https://monai.io/>
+- API documentation (milestone): <https://docs.monai.io/>
+- API documentation (latest dev): <https://docs.monai.io/en/latest/>
+- Code: <https://github.com/Project-MONAI/MONAI>
+- Project tracker: <https://github.com/Project-MONAI/MONAI/projects>
+- Issue tracker: <https://github.com/Project-MONAI/MONAI/issues>
+- Wiki: <https://github.com/Project-MONAI/MONAI/wiki>
+- Test status: <https://github.com/Project-MONAI/MONAI/actions>
+- PyPI package: <https://pypi.org/project/monai/>
+- conda-forge: <https://anaconda.org/conda-forge/monai>
+- Weekly previews: <https://pypi.org/project/monai-weekly/>
+- Docker Hub: <https://hub.docker.com/r/projectmonai/monai>
{monai_weekly-1.5.dev2507.dist-info → monai_weekly-1.5.dev2509.dist-info}/RECORD

@@ -1,5 +1,5 @@
-monai/__init__.py,sha256=
-monai/_version.py,sha256=
+monai/__init__.py,sha256=2QSN66gMNzIDVAeBWVrsS3xgXmpc90Ksxr0j3D3KLiQ,4095
+monai/_version.py,sha256=3pISgTcfhG3j_LA8zhH9EcyDi6PgzKxbNALoD_5HCps,503
 monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 monai/_extensions/__init__.py,sha256=NEBPreRhQ8H9gVvgrLr_y52_TmqB96u_u4VQmeNT93I,642
 monai/_extensions/loader.py,sha256=7SiKw36q-nOzH8CRbBurFrz7GM40GCu7rc93Tm8XpnI,3643
@@ -114,7 +114,7 @@ monai/bundle/config_item.py,sha256=rMjXSGkjJZdi04BwSHwCcIwzIb_TflmC3xDhC3SVJRs,1
 monai/bundle/config_parser.py,sha256=cGyEn-cqNk0rEEZ1Qiv6UydmIDvtWZcMVljyfVm5i50,23025
 monai/bundle/properties.py,sha256=iN3K4FVmN9ny1Hw9p5j7_ULcCdSD8PmrR7qXxbNz49k,11582
 monai/bundle/reference_resolver.py,sha256=GXCMK4iogxxE6VocsmAbUrcXosmC5arnjeG9zYhHKpg,16748
-monai/bundle/scripts.py,sha256=
+monai/bundle/scripts.py,sha256=p7wlT0BplTIdW4DbxRPotf_tLsgddvtklW1kcAEPBZQ,91016
 monai/bundle/utils.py,sha256=t-22uFvLn7Yy-dr1v1U33peNOxgAmU4TJiGAbsBrUKs,10108
 monai/bundle/workflows.py,sha256=CuhmFq1AWsN3ATiYJCSakPOxrOdGutl6vkpo9sxe8gU,34369
 monai/config/__init__.py,sha256=CN28CfTdsp301gv8YXfVvkbztCfbAqrLKrJi_C8oP9s,1048
@@ -160,7 +160,8 @@ monai/fl/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,57
 monai/fl/utils/constants.py,sha256=OjMAE17niYqQh7nz45SC6CXvkMa4-XZsIuoHUHqP7W0,1784
 monai/fl/utils/exchange_object.py,sha256=q41trOwBdog_g3k_Eh2EFnLufHJ1mj7nGyQ-ShuW5Mo,3527
 monai/fl/utils/filters.py,sha256=InXplYes52JJqtsNbePAPPAYS8am_uRO7UkBHyYyJCo,1633
-monai/handlers/__init__.py,sha256=
+monai/handlers/__init__.py,sha256=m6SDdtXAZ4ONLCCYrSgONuPaJOz7lewOAzOvZ3J9r14,2442
+monai/handlers/average_precision.py,sha256=FkIUP2mKqGvybnc_HxuuOdqPeq06wnZP_vwb8K-IhUg,2753
 monai/handlers/checkpoint_loader.py,sha256=Y0qNBq5b-GJ-XOJNjuslegCpIGPZYOdNs3PxzNYCCm8,7432
 monai/handlers/checkpoint_saver.py,sha256=z_w5HtNSeRM3QwHQIgQKqVodSYNy8dhL8KTBUzHuF0g,16047
 monai/handlers/classification_saver.py,sha256=CNzdU9GrKj8KEC42jaBy2rEgpd3mqgz-YZg4dr61Jyg,7605
@@ -194,7 +195,7 @@ monai/handlers/trt_handler.py,sha256=uWFdgC8QKRkcNwWfKIbQMdK6-MX_1ON0mKabeIn1ltI
 monai/handlers/utils.py,sha256=Ib1u-PLrtIkiLqTfREnrCWpN4af1btdNzkyMZuuuYyU,10239
 monai/handlers/validation_handler.py,sha256=NZO21c6zzXbmAgJZHkkdoZQSQIHwuxh94QD3PLUldGU,3674
 monai/inferers/__init__.py,sha256=K74t_RCeUPdEZvHzIPzVAwZ9DtmouLqhb3qDEmFBWs4,1107
-monai/inferers/inferer.py,sha256=
+monai/inferers/inferer.py,sha256=_VPnBIErwYzbrJIA9eMMalSso1pSsc_8cONVUUvPJOw,93549
 monai/inferers/merger.py,sha256=dZm-FVyXPlFb59q4DG52mbtPm8Iy4cNFWv3un0Z8k0M,16262
 monai/inferers/splitter.py,sha256=_hTnFdvDNRckkA7ZGQehVsNZw83oXoGFWyk5VXNqgJg,21149
 monai/inferers/utils.py,sha256=dvZBCAjaPa8xXcJuXRzNQ-fBzteauzkKbxE5YZdGBGY,20374
@@ -220,8 +221,9 @@ monai/losses/sure_loss.py,sha256=PDDNNeZm8SLPRCDUPbc8o4--ribHnY4nbo8y55nRo0w,817
 monai/losses/tversky.py,sha256=uLuqCvsac8OabTJzKQEzAfAvlwrflYCh0s76rgbcVJ0,6955
 monai/losses/unified_focal_loss.py,sha256=rCj8IpueYH_UMrOUXU0tjbXIN4Uix3bGnRZQtRvl7Sg,10224
 monai/losses/utils.py,sha256=wrpKcEO0XhbFOHz_jJRqeAeIgpMiMxmepnRf31_DNRU,2786
-monai/metrics/__init__.py,sha256=
+monai/metrics/__init__.py,sha256=rIRTn5dsXPzGoRv7tZ2ipZ7IiHlNJ4TrZOG_aDDhw28,2255
 monai/metrics/active_learning_metrics.py,sha256=uKID2O4mnY-9P2ZzyT4sqJd2NfgzjSpNKpAwulWCozU,8211
+monai/metrics/average_precision.py,sha256=rQYfPAmE78np8E4UoDPk-DSVRtEVC2hAcj5w9Q6ZIqk,8454
 monai/metrics/confusion_matrix.py,sha256=Spb20jYPnbgGZfPKDQI36ePznPf1xujxhboNnW8HxdQ,15064
 monai/metrics/cumulative_average.py,sha256=8GGjHmiBboBikprg1380SsNn7RgzFIrHGWBYDBv6ebE,5636
 monai/metrics/f_beta_score.py,sha256=urI0J_tvl0qQ5-l2fgWV_jChbgpzLmgpRq125B3yxpw,3984
@@ -360,7 +362,7 @@ monai/transforms/transform.py,sha256=0eC_Gw7T2jBb589-3EHLh-8gJD687k2OVmrnMxaKs3o
 monai/transforms/utils.py,sha256=t4TMksfSzozyNqP-HJK-ZydvmImLFzxhks0yJnZTOYM,106430
 monai/transforms/utils_create_transform_ims.py,sha256=QEJVHsCZX7ZxsBArk6NjgCzSZuuokf8l1uFqiUZBBys,31155
 monai/transforms/utils_morphological_ops.py,sha256=tt0lRLLxmlnn9roUuPEBtqah6t7BH8ittxyDFuskkUI,6767
-monai/transforms/utils_pytorch_numpy_unification.py,sha256=
+monai/transforms/utils_pytorch_numpy_unification.py,sha256=pM6-x-TAGVcQohSYirfTqiy2SQnPixcKKHTmTqtBbg0,18706
 monai/transforms/croppad/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 monai/transforms/croppad/array.py,sha256=WeSAs4JNtNafFaIMLPi3-9NuuyCiTm19cq2oEOonKWQ,74632
 monai/transforms/croppad/batch.py,sha256=5ukcYk3VCDpk62AL5Q_jTqpXmSNTlw0UCUhDeAB4aV0,6138
@@ -396,17 +398,17 @@ monai/transforms/spatial/array.py,sha256=5EKivdPYCP4i4qYUlkK1RpYQFzaU_baYyzgubid
 monai/transforms/spatial/dictionary.py,sha256=t0SvEDSVNFUEw2fK66OVF20sqSzCNxil17HmvsMFBt8,133752
 monai/transforms/spatial/functional.py,sha256=IwS0witCqbGkyuxzu_R4Ztp90S0pg9hY1irG7feXqig,33886
 monai/transforms/utility/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
-monai/transforms/utility/array.py,sha256=
+monai/transforms/utility/array.py,sha256=Du3QA6m0io7mR51gUgaMwHBFNStdFmRxhaYmBCVy7BY,81215
 monai/transforms/utility/dictionary.py,sha256=iOFdTSekvkAsBbbfHeffcRsOKRtNcnt3N1cVuUarZ1s,80549
 monai/utils/__init__.py,sha256=2_AIpb1wqGMkmgoZ3r43muFTEsnMTCkPu3LtckipYHg,3793
 monai/utils/component_store.py,sha256=Fe9jbHgwwBBAeJAw0nI02Ae13v17wlwF6N9uUue8tJg,4525
 monai/utils/decorators.py,sha256=qhhdmJMjMfZIUM6x_VGUGF7kaq2cBUAam8WymAU_mhw,3156
 monai/utils/deprecate_utils.py,sha256=gKeEV4MsI51qeQ5gci2me_C-0e-tDwa3VZzd3XPQqLk,14759
 monai/utils/dist.py,sha256=7brB42CvdS8Jvr8Y7hfqov1uk6NNnYea9dYfgMYy0BY,8578
-monai/utils/enums.py,sha256=
+monai/utils/enums.py,sha256=jXtLaNDxG3BRBgLG2t13_S_G4iVWYHZO_GztykAtmXg,19594
 monai/utils/jupyter_utils.py,sha256=BYtj80LWQAYg5RWPj5g4j2AMCzLECvAcnZdXns0Ruw8,15651
 monai/utils/misc.py,sha256=R-sCS5u7SA8hX6e7x6WSc8FgLcNpqKFRRDMWxUd2wCo,31759
-monai/utils/module.py,sha256=
+monai/utils/module.py,sha256=R37PpCNCcHQvjjZFbNjNyzWb3FURaKLxQucjhzQk0eU,26087
 monai/utils/nvtx.py,sha256=i9JBxR1uhW1ZCgLPLlTx8b907QlXkFzJyTBLMlFjhtU,6876
 monai/utils/ordering.py,sha256=0nlA5b5QpVCHbtiCbTC-YsqjTmjm0bub0IeJhGFBOes,8270
 monai/utils/profiling.py,sha256=V2_cSHgrcmVF48_G3nUi2-O6fnXsS89nSlb8jj58YLo,15937
@@ -504,7 +506,7 @@ tests/bundle/test_bundle_ckpt_export.py,sha256=VnpigCoBAAc2lo0rWOpVMg0IYGB6vbHXL
 tests/bundle/test_bundle_download.py,sha256=4wpnCXNYTwTHWNjuSZqnXpVzadxNRabmFaFM3LZ_TJU,20072
 tests/bundle/test_bundle_get_data.py,sha256=lQh321mev_7fsLXRg0Tq5uEjuQILethDHRKzB6VV0o4,3667
 tests/bundle/test_bundle_push_to_hf_hub.py,sha256=Zjl6xDwRKgkS6jvO5dzMBaTLEd4EXyMXp0_wzDNSY3g,1740
-tests/bundle/test_bundle_trt_export.py,sha256=
+tests/bundle/test_bundle_trt_export.py,sha256=png-2SGjBSt46LXSz-PLprOXwJ0WkC_3YLR3Ibk_WBc,6344
 tests/bundle/test_bundle_utils.py,sha256=GTTS_5tEvV5qLad-aHeZXHDQLZcsDwi56Ldn5FnK2RE,4573
 tests/bundle/test_bundle_verify_metadata.py,sha256=OmcERLA5ht91cUDK9yYKXhpk-96yZcj4EBwZBk7zW3w,2660
 tests/bundle/test_bundle_verify_net.py,sha256=guCsyjb5op216AUUUQo97YY3p1-XcQEWINouxNX6F84,3383
@@ -529,6 +531,7 @@ tests/fl/monai_algo/test_fl_monai_algo_dist.py,sha256=Tq560TGvTmafEa5sDGax_chRlD
 tests/fl/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/fl/utils/test_fl_exchange_object.py,sha256=rddodowFMAdNT9wquI0NHg0CSm5Xvk_v9Si-eJqyiow,2571
 tests/handlers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
+tests/handlers/test_handler_average_precision.py,sha256=0cmgjzWxlfdZsUJB1NnSXfx3dmmDI6CbvIqggtc5rTY,2814
 tests/handlers/test_handler_checkpoint_loader.py,sha256=1dA4WYp-L6KxtzZIqUs--lNM4O-Anw2-s29QSdIOReU,8443
 tests/handlers/test_handler_checkpoint_saver.py,sha256=K3bxelElfETpQSXRovWZlxZZmkjY3hm_cJo8kjYCJ3I,6256
 tests/handlers/test_handler_classification_saver.py,sha256=vesCfTcAPkDAR7oAB_8kyeQrXpkrPQmdME9YBwPV7EE,2355
@@ -567,9 +570,9 @@ tests/handlers/test_trt_compile.py,sha256=p8Gr2CJmBo6gG8w7bGlAO--nDHtQvy9Ld3jtua
 tests/handlers/test_write_metrics_reports.py,sha256=oKGYR1plj1hSAu-ntbxkw_TD4O5hKPwVH_BS3MdHIbs,3027
 tests/inferers/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/inferers/test_avg_merger.py,sha256=lMR2PcNGFD6sfF6CjJTkahrAiMA5m5LUs5A11P6h8n0,5952
-tests/inferers/test_controlnet_inferers.py,sha256=
+tests/inferers/test_controlnet_inferers.py,sha256=sWs5vkZHa-D0V3tWJ6149Z-RNq0for_XngDYxZRl_Ao,50285
 tests/inferers/test_diffusion_inferer.py,sha256=1O2V_bEmifOZ4RvpbZgYUCooiJ97T73avaBuMJPpBs0,9992
-tests/inferers/test_latent_diffusion_inferer.py,sha256=
+tests/inferers/test_latent_diffusion_inferer.py,sha256=atJjmfVznUq8z9EjohFIMyA0Q1XT1Ly0Zepf_1xPz5I,32274
 tests/inferers/test_patch_inferer.py,sha256=LkYXWVn71vWinP-OJsIvq3FPH3jr36T7nKRIH5PzaqY,9878
 tests/inferers/test_saliency_inferer.py,sha256=7miHRbA4yb_WGcxql6za9uXXoZlql_7y23f7IzsyIps,1949
 tests/inferers/test_slice_inferer.py,sha256=kzaJjjTnf2rAiR75l8A_J-Kie4NaLj2bogi-aJ5L5mk,1897
@@ -645,6 +648,7 @@ tests/losses/image_dissimilarity/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4o
 tests/losses/image_dissimilarity/test_global_mutual_information_loss.py,sha256=9xEX5BCEQ1s004QgcwYaAFwKTmlZjuVG8cIbK7Giwts,5692
 tests/losses/image_dissimilarity/test_local_normalized_cross_correlation_loss.py,sha256=Gs3zHnGWNZ50liU_tya4Z_6tCRKIWCtG59imAxXdKPI,6070
 tests/metrics/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
+tests/metrics/test_compute_average_precision.py,sha256=o5gYko4Ow87Ix1n_z6_HmfuTKmkZM__fDZQpjKNJNrA,4743
 tests/metrics/test_compute_confusion_matrix.py,sha256=dwiqMnp7T6KJLJ7qv6J5g_RDDrB6UiLAe-pgmVNSz7I,10669
 tests/metrics/test_compute_f_beta.py,sha256=xbCipeICoAXWZLgDFeDAa1KjDQxDTMVArNbtUYiCG3c,3286
 tests/metrics/test_compute_fid_metric.py,sha256=B9OZECl3CT1JKzG-2C_YaPFjgfvlFoS9vI1j8vBzWZg,1328
@@ -670,7 +674,7 @@ tests/metrics/test_surface_dice.py,sha256=CGCQt-ydMzaT2q1fFnzpKb6E-TPydym4vE_kdp
 tests/metrics/test_surface_distance.py,sha256=gkW0dai3vHjXubLNBilqFnV5Up-abSMgQ53v0iCTVeE,6237
 tests/networks/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/networks/test_bundle_onnx_export.py,sha256=_lEnAJhq7D2IOuVEdgBVsA8vySgs34FkfMrvNsCLfUg,2853
-tests/networks/test_convert_to_onnx.py,sha256=
+tests/networks/test_convert_to_onnx.py,sha256=h1Sjb0SZmiwwbx0_PrzeFDOE3-JRSp18qDS6G_PdD6g,3673
 tests/networks/test_convert_to_torchscript.py,sha256=NhrJMCfQtC0sftrhDjL28omS7VKzg_niRK0KtY5Mr_A,1636
 tests/networks/test_convert_to_trt.py,sha256=5TkuUvCPgW5mAvYUysRRrSjtSbDoDDAoJb2kJtuXOVk,2656
 tests/networks/test_save_state.py,sha256=_glX4irpJVqk2jnOJaVqYxsOQNX3oCauxlEXe2ly8Cg,2354
@@ -881,7 +885,7 @@ tests/transforms/test_generate_label_classes_crop_centers.py,sha256=E5DtL2s1sio1
 tests/transforms/test_generate_pos_neg_label_crop_centers.py,sha256=DdCbdYaTHL40crC5o440cpEt0xNLXzT-rVphaBH11HM,2516
 tests/transforms/test_generate_spatial_bounding_box.py,sha256=JxHt4BHmtGYIqyzGhWgkCB5_oJU2ro_737upVxWBPvI,3510
 tests/transforms/test_get_extreme_points.py,sha256=881LZMTms1tXRDtODIheZbKDXMVQ69ff78IvukoabGc,1700
-tests/transforms/test_gibbs_noise.py,sha256=
+tests/transforms/test_gibbs_noise.py,sha256=9TgOYhGz1P6-VJUXszuV9NgqhjF5FKCVcQuG_7o3jUI,2658
 tests/transforms/test_gibbs_noised.py,sha256=o9ZQVAyuHATbV9JHkeTy_pDLz5Mqg5ctMQawMmP71RQ,3228
 tests/transforms/test_grid_distortion.py,sha256=8dTQjWQ2_euNKN00xxZXqZk-cFSsKfpVpkNm-1-WytA,4472
 tests/transforms/test_grid_distortiond.py,sha256=bSLhB_LGQKXo5VqP9RCyJDSyiZi2er2W2Qdw7qDep9s,3492
@@ -1174,8 +1178,8 @@ tests/visualize/test_vis_gradcam.py,sha256=WpA-pvTB75eZs7JoIc5qyvOV9PwgkzWI8-Vow
 tests/visualize/utils/__init__.py,sha256=s9djSd6kvViPnFvMR11Dgd30Lv4oY6FaPJr4ZZJZLq0,573
 tests/visualize/utils/test_blend_images.py,sha256=RVs2p_8RWQDfhLHDNNtZaMig27v8o0km7XxNa-zWjKE,2274
 tests/visualize/utils/test_matshow3d.py,sha256=wXYj77L5Jvnp0f6DvL1rsi_-YlCxS0HJ9hiPmrbpuP8,5021
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
-monai_weekly-1.5.
+monai_weekly-1.5.dev2509.dist-info/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+monai_weekly-1.5.dev2509.dist-info/METADATA,sha256=h7L3w9XhzSfoxC5yRoqgKS_NeECPEORKyEX4E1WS6Vc,11909
+monai_weekly-1.5.dev2509.dist-info/WHEEL,sha256=jB7zZ3N9hIM9adW7qlTAyycLYW9npaWKLRzaoVcLKcM,91
+monai_weekly-1.5.dev2509.dist-info/top_level.txt,sha256=hn2Y6P9xBf2R8faMeVMHhPMvrdDKxMsIOwMDYI0yTjs,12
+monai_weekly-1.5.dev2509.dist-info/RECORD,,
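
Each RECORD row above follows the wheel convention `path,sha256=<digest>,<size>`: the digest is the URL-safe base64 encoding of the file's SHA-256 hash with `=` padding stripped. A quick sanity check against the zero-byte `monai/py.typed` entry above:

```python
import base64
import hashlib

data = b""  # monai/py.typed is an empty file (size 0 in RECORD)
digest = base64.urlsafe_b64encode(hashlib.sha256(data).digest()).rstrip(b"=").decode()
print(f"monai/py.typed,sha256={digest},{len(data)}")
# -> monai/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
```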
tests/bundle/test_bundle_trt_export.py

@@ -70,7 +70,7 @@ class TestTRTExport(unittest.TestCase):
     @parameterized.expand([TEST_CASE_1, TEST_CASE_2, TEST_CASE_3, TEST_CASE_4])
     @unittest.skipUnless(has_torchtrt and has_tensorrt, "Torch-TensorRT is required for conversion!")
     def test_trt_export(self, convert_precision, input_shape, dynamic_batch):
-        tests_dir = Path(__file__).resolve().
+        tests_dir = Path(__file__).resolve().parents[1]
         meta_file = os.path.join(tests_dir, "testing_data", "metadata.json")
         config_file = os.path.join(tests_dir, "testing_data", "inference.json")
         with tempfile.TemporaryDirectory() as tempdir:
@@ -108,7 +108,7 @@ class TestTRTExport(unittest.TestCase):
         has_onnx and has_torchtrt and has_tensorrt, "Onnx and TensorRT are required for onnx-trt conversion!"
     )
     def test_onnx_trt_export(self, convert_precision, input_shape, dynamic_batch):
-        tests_dir = Path(__file__).resolve().
+        tests_dir = Path(__file__).resolve().parents[1]
         meta_file = os.path.join(tests_dir, "testing_data", "metadata.json")
         config_file = os.path.join(tests_dir, "testing_data", "inference.json")
         with tempfile.TemporaryDirectory() as tempdir:
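
Both fixes above point `tests_dir` at the repository's `tests/` directory via `Path.parents` indexing: `parents[0]` is the file's own directory and `parents[1]` its grandparent. A small illustration (the `/repo` prefix is hypothetical; only the `tests/bundle/` and `tests/testing_data/` layout is taken from the diff):

```python
from pathlib import Path

p = Path("/repo/tests/bundle/test_bundle_trt_export.py")
print(p.parents[0])  # /repo/tests/bundle (the file's own directory)
print(p.parents[1])  # /repo/tests, which contains testing_data/metadata.json
```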
tests/handlers/test_handler_average_precision.py

@@ -0,0 +1,79 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import unittest
+
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from monai.handlers import AveragePrecision
+from monai.transforms import Activations, AsDiscrete
+from tests.test_utils import DistCall, DistTestCase
+
+
+class TestHandlerAveragePrecision(unittest.TestCase):
+
+    def test_compute(self):
+        ap_metric = AveragePrecision()
+        act = Activations(softmax=True)
+        to_onehot = AsDiscrete(to_onehot=2)
+
+        y_pred = [torch.Tensor([0.1, 0.9]), torch.Tensor([0.3, 1.4])]
+        y = [torch.Tensor([0]), torch.Tensor([1])]
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+        ap_metric.update([y_pred, y])
+
+        y_pred = [torch.Tensor([0.2, 0.1]), torch.Tensor([0.1, 0.5])]
+        y = [torch.Tensor([0]), torch.Tensor([1])]
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+
+        ap_metric.update([y_pred, y])
+
+        ap = ap_metric.compute()
+        np.testing.assert_allclose(0.8333333, ap)
+
+
+class DistributedAveragePrecision(DistTestCase):
+
+    @DistCall(nnodes=1, nproc_per_node=2, node_rank=0)
+    def test_compute(self):
+        ap_metric = AveragePrecision()
+        act = Activations(softmax=True)
+        to_onehot = AsDiscrete(to_onehot=2)
+
+        device = f"cuda:{dist.get_rank()}" if torch.cuda.is_available() else "cpu"
+        if dist.get_rank() == 0:
+            y_pred = [torch.tensor([0.1, 0.9], device=device), torch.tensor([0.3, 1.4], device=device)]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device)]
+
+        if dist.get_rank() == 1:
+            y_pred = [
+                torch.tensor([0.2, 0.1], device=device),
+                torch.tensor([0.1, 0.5], device=device),
+                torch.tensor([0.3, 0.4], device=device),
+            ]
+            y = [torch.tensor([0], device=device), torch.tensor([1], device=device), torch.tensor([1], device=device)]
+
+        y_pred = [act(p) for p in y_pred]
+        y = [to_onehot(y_) for y_ in y]
+        ap_metric.update([y_pred, y])
+
+        result = ap_metric.compute()
+        np.testing.assert_allclose(0.7778, result, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    unittest.main()
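
The expected value `0.8333333` in `test_compute` can be verified independently: after the softmax, the class-1 scores rank the two positive samples first and third, so AP = (1/1 + 2/3) / 2 = 5/6. A cross-check sketch using scikit-learn, which is not used by the test itself and is assumed available here:

```python
import torch
from sklearn.metrics import average_precision_score

logits = torch.tensor([[0.1, 0.9], [0.3, 1.4], [0.2, 0.1], [0.1, 0.5]])
scores = logits.softmax(dim=1)[:, 1]  # class-1 probabilities
labels = [0, 1, 0, 1]                 # same targets as the test above
print(average_precision_score(labels, scores.numpy()))  # 0.8333333...
```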
tests/inferers/test_controlnet_inferers.py

@@ -550,6 +550,8 @@ class ControlNetTestDiffusionSamplingInferer(unittest.TestCase):
     def test_sampler_conditioned(self, model_params, controlnet_params, input_shape):
         model_params["with_conditioning"] = True
         model_params["cross_attention_dim"] = 3
+        controlnet_params["with_conditioning"] = True
+        controlnet_params["cross_attention_dim"] = 3
         model = DiffusionModelUNet(**model_params)
         controlnet = ControlNet(**controlnet_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
@@ -619,8 +621,11 @@
         model_params = model_params.copy()
         n_concat_channel = 2
         model_params["in_channels"] = model_params["in_channels"] + n_concat_channel
+        controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         model_params["cross_attention_dim"] = None
+        controlnet_params["cross_attention_dim"] = None
         model_params["with_conditioning"] = False
+        controlnet_params["with_conditioning"] = False
         model = DiffusionModelUNet(**model_params)
         device = "cuda:0" if torch.cuda.is_available() else "cpu"
         model.to(device)
@@ -722,7 +727,7 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES)
     @skipUnless(has_einops, "Requires einops")
-    def
+    def test_pred_shape(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1023,8 +1028,10 @@
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
             stage_2_params = stage_2_params.copy()
+            controlnet_params = controlnet_params.copy()
             n_concat_channel = 3
             stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+            controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
@@ -1106,8 +1113,10 @@
         if ae_model_type == "SPADEAutoencoderKL":
             stage_1 = SPADEAutoencoderKL(**autoencoder_params)
             stage_2_params = stage_2_params.copy()
+            controlnet_params = controlnet_params.copy()
             n_concat_channel = 3
             stage_2_params["in_channels"] = stage_2_params["in_channels"] + n_concat_channel
+            controlnet_params["in_channels"] = controlnet_params["in_channels"] + n_concat_channel
         if dm_model_type == "SPADEDiffusionModelUNet":
             stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
         else:
@@ -1165,7 +1174,7 @@
 
     @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def
+    def test_shape_different_latents(
         self,
         ae_model_type,
         autoencoder_params,
@@ -1242,6 +1251,84 @@ class LatentControlNetTestDiffusionSamplingInferer(unittest.TestCase):
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(LATENT_CNDM_TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self,
+        ae_model_type,
+        autoencoder_params,
+        dm_model_type,
+        stage_2_params,
+        controlnet_params,
+        input_shape,
+        latent_shape,
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+        controlnet = ControlNet(**controlnet_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        controlnet.to(device)
+        stage_1.eval()
+        stage_2.eval()
+        controlnet.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        mask = torch.randn(input_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = ControlNetLatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                controlnet=controlnet,
+                cn_cond=mask,
+                input_noise=noise,
+                seg=input_seg,
+                save_intermediates=True,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                controlnet=controlnet,
+                cn_cond=mask,
+                save_intermediates=False,
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(
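
The pattern running through these ControlNet test fixes is that the ControlNet must be constructed with the same conditioning and channel configuration as the `DiffusionModelUNet` it accompanies; previously only the UNet's parameter dict was adjusted. A minimal sketch of the invariant (plain dicts only, with illustrative keys taken from the hunks above):

```python
model_params = {"in_channels": 1, "with_conditioning": True, "cross_attention_dim": 3}
controlnet_params = {"in_channels": 1, "with_conditioning": False, "cross_attention_dim": None}

# Mirror every conditioning-related tweak onto the ControlNet config before
# instantiating the networks, as the fixed tests now do.
for key in ("with_conditioning", "cross_attention_dim", "in_channels"):
    controlnet_params[key] = model_params[key]

assert controlnet_params == {"in_channels": 1, "with_conditioning": True, "cross_attention_dim": 3}
```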
tests/inferers/test_latent_diffusion_inferer.py

@@ -714,7 +714,7 @@ class TestDiffusionSamplingInferer(unittest.TestCase):
 
     @parameterized.expand(TEST_CASES_DIFF_SHAPES)
     @skipUnless(has_einops, "Requires einops")
-    def
+    def test_shape_different_latents(
         self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
     ):
         stage_1 = None
@@ -772,6 +772,66 @@
         )
         self.assertEqual(prediction.shape, latent_shape)
 
+    @parameterized.expand(TEST_CASES_DIFF_SHAPES)
+    @skipUnless(has_einops, "Requires einops")
+    def test_sample_shape_different_latents(
+        self, ae_model_type, autoencoder_params, dm_model_type, stage_2_params, input_shape, latent_shape
+    ):
+        stage_1 = None
+
+        if ae_model_type == "AutoencoderKL":
+            stage_1 = AutoencoderKL(**autoencoder_params)
+        if ae_model_type == "VQVAE":
+            stage_1 = VQVAE(**autoencoder_params)
+        if ae_model_type == "SPADEAutoencoderKL":
+            stage_1 = SPADEAutoencoderKL(**autoencoder_params)
+        if dm_model_type == "SPADEDiffusionModelUNet":
+            stage_2 = SPADEDiffusionModelUNet(**stage_2_params)
+        else:
+            stage_2 = DiffusionModelUNet(**stage_2_params)
+
+        device = "cuda:0" if torch.cuda.is_available() else "cpu"
+        stage_1.to(device)
+        stage_2.to(device)
+        stage_1.eval()
+        stage_2.eval()
+
+        noise = torch.randn(latent_shape).to(device)
+        scheduler = DDPMScheduler(num_train_timesteps=10)
+        # We infer the VAE shape
+        if ae_model_type == "VQVAE":
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]))) for i in input_shape[2:]]
+        else:
+            autoencoder_latent_shape = [i // (2 ** (len(autoencoder_params["channels"]) - 1)) for i in input_shape[2:]]
+
+        inferer = LatentDiffusionInferer(
+            scheduler=scheduler,
+            scale_factor=1.0,
+            ldm_latent_shape=list(latent_shape[2:]),
+            autoencoder_latent_shape=autoencoder_latent_shape,
+        )
+        scheduler.set_timesteps(num_inference_steps=10)
+
+        if dm_model_type == "SPADEDiffusionModelUNet" or ae_model_type == "SPADEAutoencoderKL":
+            input_shape_seg = list(input_shape)
+            if "label_nc" in stage_2_params.keys():
+                input_shape_seg[1] = stage_2_params["label_nc"]
+            else:
+                input_shape_seg[1] = autoencoder_params["label_nc"]
+            input_seg = torch.randn(input_shape_seg).to(device)
+            prediction, _ = inferer.sample(
+                autoencoder_model=stage_1,
+                diffusion_model=stage_2,
+                input_noise=noise,
+                save_intermediates=True,
+                seg=input_seg,
+            )
+        else:
+            prediction = inferer.sample(
+                autoencoder_model=stage_1, diffusion_model=stage_2, input_noise=noise, save_intermediates=False
+            )
+        self.assertEqual(prediction.shape, input_shape)
+
     @skipUnless(has_einops, "Requires einops")
     def test_incompatible_spade_setup(self):
         stage_1 = SPADEAutoencoderKL(
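
Both new `test_sample_shape_different_latents` tests derive the autoencoder's latent spatial size from the input size with the same arithmetic: the tests assume a VQVAE applies one factor-of-2 downsampling per entry of `channels`, while the KL autoencoder applies one fewer (hence the `- 1`). A worked example with assumed shapes:

```python
input_shape = (1, 1, 16, 16)  # (batch, channels, H, W) -- illustrative values
channels = (4, 8)             # two stages in the autoencoder config

# VQVAE: one factor-of-2 reduction per entry of "channels" -> 16 // 4
vqvae_latent = [i // (2 ** len(channels)) for i in input_shape[2:]]     # [4, 4]
# AutoencoderKL: one reduction fewer -> 16 // 2
kl_latent = [i // (2 ** (len(channels) - 1)) for i in input_shape[2:]]  # [8, 8]
print(vqvae_latent, kl_latent)
```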