careamics 0.0.2__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of careamics might be problematic. Click here for more details.

Files changed (254) hide show
  1. {careamics-0.0.2 → careamics-0.0.3}/.pre-commit-config.yaml +3 -2
  2. {careamics-0.0.2 → careamics-0.0.3}/PKG-INFO +2 -2
  3. careamics-0.0.3/examples/example_training_LVAE_split.ipynb +439 -0
  4. careamics-0.0.3/mypy.ini +17 -0
  5. {careamics-0.0.2 → careamics-0.0.3}/pyproject.toml +4 -0
  6. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/careamist.py +14 -11
  7. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/__init__.py +7 -3
  8. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/__init__.py +2 -2
  9. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/architecture_model.py +1 -1
  10. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/custom_model.py +11 -8
  11. careamics-0.0.3/src/careamics/config/architectures/lvae_model.py +174 -0
  12. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/configuration_factory.py +11 -3
  13. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/configuration_model.py +7 -3
  14. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/data_model.py +33 -8
  15. careamics-0.0.2/src/careamics/config/algorithm_model.py → careamics-0.0.3/src/careamics/config/fcn_algorithm_model.py +28 -43
  16. careamics-0.0.3/src/careamics/config/likelihood_model.py +43 -0
  17. careamics-0.0.3/src/careamics/config/nm_model.py +101 -0
  18. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_activations.py +1 -0
  19. careamics-0.0.3/src/careamics/config/support/supported_algorithms.py +33 -0
  20. careamics-0.0.3/src/careamics/config/support/supported_architectures.py +17 -0
  21. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_losses.py +3 -1
  22. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/n2v_manipulate_model.py +1 -1
  23. careamics-0.0.3/src/careamics/config/vae_algorithm_model.py +171 -0
  24. careamics-0.0.3/src/careamics/dataset/tiling/lvae_tiled_patching.py +282 -0
  25. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/tiff.py +1 -1
  26. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/__init__.py +3 -2
  27. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/hyperparameters_callback.py +1 -1
  28. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy_factory.py +1 -1
  29. careamics-0.0.3/src/careamics/lightning/lightning_module.py +632 -0
  30. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/predict_data_module.py +2 -2
  31. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/train_data_module.py +2 -2
  32. careamics-0.0.3/src/careamics/losses/__init__.py +15 -0
  33. careamics-0.0.3/src/careamics/losses/fcn/__init__.py +1 -0
  34. {careamics-0.0.2/src/careamics/losses → careamics-0.0.3/src/careamics/losses/fcn}/losses.py +1 -1
  35. careamics-0.0.3/src/careamics/losses/loss_factory.py +155 -0
  36. careamics-0.0.3/src/careamics/losses/lvae/__init__.py +1 -0
  37. careamics-0.0.3/src/careamics/losses/lvae/loss_utils.py +83 -0
  38. careamics-0.0.3/src/careamics/losses/lvae/losses.py +445 -0
  39. {careamics-0.0.2/src/careamics/lvae_training → careamics-0.0.3/src/careamics/lvae_training/dataset}/data_utils.py +277 -194
  40. careamics-0.0.3/src/careamics/lvae_training/dataset/lc_dataset.py +259 -0
  41. careamics-0.0.3/src/careamics/lvae_training/dataset/lc_dataset_config.py +13 -0
  42. careamics-0.0.3/src/careamics/lvae_training/dataset/vae_data_config.py +179 -0
  43. careamics-0.0.2/src/careamics/lvae_training/data_modules.py → careamics-0.0.3/src/careamics/lvae_training/dataset/vae_dataset.py +306 -472
  44. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/get_config.py +1 -1
  45. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/train_lvae.py +6 -3
  46. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/bioimage_utils.py +1 -1
  47. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/model_description.py +2 -2
  48. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bmz_io.py +19 -6
  49. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/model_io_utils.py +16 -4
  50. careamics-0.0.3/src/careamics/models/__init__.py +5 -0
  51. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/activation.py +2 -0
  52. careamics-0.0.3/src/careamics/models/lvae/__init__.py +3 -0
  53. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/lvae/layers.py +21 -21
  54. careamics-0.0.3/src/careamics/models/lvae/likelihoods.py +364 -0
  55. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/lvae/lvae.py +52 -136
  56. careamics-0.0.3/src/careamics/models/lvae/noise_models.py +541 -0
  57. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/lvae/utils.py +2 -2
  58. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/model_factory.py +22 -7
  59. careamics-0.0.3/src/careamics/prediction_utils/lvae_prediction.py +158 -0
  60. careamics-0.0.3/src/careamics/prediction_utils/lvae_tiling_manager.py +362 -0
  61. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/prediction_utils/stitch_prediction.py +16 -2
  62. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/pixel_manipulation.py +1 -1
  63. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/metrics.py +74 -1
  64. {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_custom_model.py +14 -32
  65. careamics-0.0.3/tests/config/architectures/test_lvae_model.py +134 -0
  66. {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_register_model.py +1 -0
  67. {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_unet_model.py +6 -5
  68. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_configuration_model.py +2 -0
  69. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_data_model.py +22 -0
  70. careamics-0.0.3/tests/config/test_fcn_algorithm_model.py +125 -0
  71. careamics-0.0.3/tests/config/test_vae_algorithm_model.py +68 -0
  72. {careamics-0.0.2 → careamics-0.0.3}/tests/conftest.py +108 -11
  73. careamics-0.0.3/tests/dataset/tiling/test_lvae_tiled_patching.py +188 -0
  74. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_prediction_writer_callback.py +2 -0
  75. careamics-0.0.3/tests/lightning/test_LVAE_lightning_module.py +579 -0
  76. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_lightning_api.py +2 -0
  77. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_lightning_module.py +40 -25
  78. careamics-0.0.3/tests/likelihood_modules/test_likelihoods.py +97 -0
  79. careamics-0.0.3/tests/losses/test_lvae_losses.py +335 -0
  80. careamics-0.0.3/tests/models/lvae/test_dataset.py +131 -0
  81. careamics-0.0.3/tests/models/lvae/test_noise_model.py +141 -0
  82. {careamics-0.0.2 → careamics-0.0.3}/tests/models/test_model_factory.py +2 -13
  83. careamics-0.0.3/tests/prediction_utils/test_lvae_prediction.py +182 -0
  84. {careamics-0.0.2 → careamics-0.0.3}/tests/test_conftest.py +2 -2
  85. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_pixel_manipulation.py +1 -1
  86. careamics-0.0.2/src/careamics/config/architectures/vae_model.py +0 -42
  87. careamics-0.0.2/src/careamics/config/support/supported_algorithms.py +0 -20
  88. careamics-0.0.2/src/careamics/config/support/supported_architectures.py +0 -20
  89. careamics-0.0.2/src/careamics/lightning/lightning_module.py +0 -276
  90. careamics-0.0.2/src/careamics/losses/__init__.py +0 -5
  91. careamics-0.0.2/src/careamics/losses/loss_factory.py +0 -49
  92. careamics-0.0.2/src/careamics/models/__init__.py +0 -7
  93. careamics-0.0.2/src/careamics/models/lvae/likelihoods.py +0 -312
  94. careamics-0.0.2/src/careamics/models/lvae/noise_models.py +0 -409
  95. careamics-0.0.2/tests/config/test_algorithm_model.py +0 -101
  96. {careamics-0.0.2 → careamics-0.0.3}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  97. {careamics-0.0.2 → careamics-0.0.3}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  98. {careamics-0.0.2 → careamics-0.0.3}/.github/TEST_FAIL_TEMPLATE.md +0 -0
  99. {careamics-0.0.2 → careamics-0.0.3}/.github/pull_request_template.md +0 -0
  100. {careamics-0.0.2 → careamics-0.0.3}/.github/workflows/ci.yml +0 -0
  101. {careamics-0.0.2 → careamics-0.0.3}/.gitignore +0 -0
  102. {careamics-0.0.2 → careamics-0.0.3}/LICENSE +0 -0
  103. {careamics-0.0.2 → careamics-0.0.3}/README.md +0 -0
  104. {careamics-0.0.2 → careamics-0.0.3}/examples/3D/example_flywing_3D.ipynb +0 -0
  105. {careamics-0.0.2 → careamics-0.0.3}/examples/3D/n2v_flywing_3D.yml +0 -0
  106. {careamics-0.0.2 → careamics-0.0.3}/examples/evaluate_LVAE.ipynb +0 -0
  107. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/__init__.py +0 -0
  108. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/register_model.py +0 -0
  109. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/architectures/unet_model.py +0 -0
  110. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/callback_model.py +0 -0
  111. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/inference_model.py +0 -0
  112. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/optimizer_models.py +0 -0
  113. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/references/__init__.py +0 -0
  114. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/references/algorithm_descriptions.py +0 -0
  115. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/references/references.py +0 -0
  116. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/__init__.py +0 -0
  117. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_data.py +0 -0
  118. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_loggers.py +0 -0
  119. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_optimizers.py +0 -0
  120. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_pixel_manipulations.py +0 -0
  121. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_struct_axis.py +0 -0
  122. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/support/supported_transforms.py +0 -0
  123. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/tile_information.py +0 -0
  124. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/training_model.py +0 -0
  125. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/__init__.py +0 -0
  126. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/normalize_model.py +0 -0
  127. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/transform_model.py +0 -0
  128. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/xy_flip_model.py +0 -0
  129. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/transformations/xy_random_rotate90_model.py +0 -0
  130. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/validators/__init__.py +0 -0
  131. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/config/validators/validator_utils.py +0 -0
  132. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/conftest.py +0 -0
  133. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/__init__.py +0 -0
  134. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/__init__.py +0 -0
  135. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/dataset_utils.py +0 -0
  136. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/file_utils.py +0 -0
  137. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/iterate_over_files.py +0 -0
  138. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/dataset_utils/running_stats.py +0 -0
  139. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/in_memory_dataset.py +0 -0
  140. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/in_memory_pred_dataset.py +0 -0
  141. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/in_memory_tiled_pred_dataset.py +0 -0
  142. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/iterable_dataset.py +0 -0
  143. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/iterable_pred_dataset.py +0 -0
  144. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/iterable_tiled_pred_dataset.py +0 -0
  145. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/__init__.py +0 -0
  146. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/patching.py +0 -0
  147. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/random_patching.py +0 -0
  148. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/sequential_patching.py +0 -0
  149. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/patching/validate_patch_dimension.py +0 -0
  150. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/tiling/__init__.py +0 -0
  151. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/tiling/collate_tiles.py +0 -0
  152. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/tiling/tiled_patching.py +0 -0
  153. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/dataset/zarr_dataset.py +0 -0
  154. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/__init__.py +0 -0
  155. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/__init__.py +0 -0
  156. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/get_func.py +0 -0
  157. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/read/zarr.py +0 -0
  158. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/write/__init__.py +0 -0
  159. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/write/get_func.py +0 -0
  160. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/file_io/write/tiff.py +0 -0
  161. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/__init__.py +0 -0
  162. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/__init__.py +0 -0
  163. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/file_path_utils.py +0 -0
  164. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/prediction_writer_callback.py +0 -0
  165. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/prediction_writer_callback/write_strategy.py +0 -0
  166. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lightning/callbacks/progress_bar_callback.py +0 -0
  167. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/__init__.py +0 -0
  168. {careamics-0.0.2/src/careamics/models/lvae → careamics-0.0.3/src/careamics/lvae_training/dataset}/__init__.py +0 -0
  169. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/eval_utils.py +0 -0
  170. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/lightning_module.py +0 -0
  171. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/metrics.py +0 -0
  172. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/lvae_training/train_utils.py +0 -0
  173. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/__init__.py +0 -0
  174. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/__init__.py +0 -0
  175. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/model_io/bioimage/_readme_factory.py +0 -0
  176. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/layers.py +0 -0
  177. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/models/unet.py +0 -0
  178. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/prediction_utils/__init__.py +0 -0
  179. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/prediction_utils/prediction_outputs.py +0 -0
  180. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/py.typed +0 -0
  181. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/__init__.py +0 -0
  182. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/compose.py +0 -0
  183. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/n2v_manipulate.py +0 -0
  184. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/normalize.py +0 -0
  185. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/struct_mask_parameters.py +0 -0
  186. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/transform.py +0 -0
  187. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/tta.py +0 -0
  188. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/xy_flip.py +0 -0
  189. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/transforms/xy_random_rotate90.py +0 -0
  190. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/__init__.py +0 -0
  191. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/autocorrelation.py +0 -0
  192. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/base_enum.py +0 -0
  193. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/context.py +0 -0
  194. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/logging.py +0 -0
  195. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/path_utils.py +0 -0
  196. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/ram.py +0 -0
  197. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/receptive_field.py +0 -0
  198. {careamics-0.0.2 → careamics-0.0.3}/src/careamics/utils/torch_utils.py +0 -0
  199. {careamics-0.0.2 → careamics-0.0.3}/tests/config/architectures/test_architecture_model.py +0 -0
  200. {careamics-0.0.2 → careamics-0.0.3}/tests/config/support/test_supported_data.py +0 -0
  201. {careamics-0.0.2 → careamics-0.0.3}/tests/config/support/test_supported_optimizers.py +0 -0
  202. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_callback_models.py +0 -0
  203. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_configuration_factory.py +0 -0
  204. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_inference_model.py +0 -0
  205. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_optimizers_model.py +0 -0
  206. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_tile_information.py +0 -0
  207. {careamics-0.0.2 → careamics-0.0.3}/tests/config/test_training_model.py +0 -0
  208. {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_n2v_manipulate_model.py +0 -0
  209. {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_normalize_model.py +0 -0
  210. {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_xy_flip_model.py +0 -0
  211. {careamics-0.0.2 → careamics-0.0.3}/tests/config/transformations/test_xy_random_rotate90_model.py +0 -0
  212. {careamics-0.0.2 → careamics-0.0.3}/tests/config/validators/test_validator_utils.py +0 -0
  213. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/dataset_utils/test_compute_normalization_stats.py +0 -0
  214. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/dataset_utils/test_list_files.py +0 -0
  215. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/patching/test_patching_utils.py +0 -0
  216. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/patching/test_random_patching.py +0 -0
  217. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/patching/test_sequential_patching.py +0 -0
  218. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_in_memory_dataset.py +0 -0
  219. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_in_memory_pred_dataset.py +0 -0
  220. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_in_memory_tiled_pred_dataset.py +0 -0
  221. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_iterable_dataset.py +0 -0
  222. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_iterable_pred_dataset.py +0 -0
  223. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/test_iterable_tiled_pred_dataset.py +0 -0
  224. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/tiling/test_collate_tiles.py +0 -0
  225. {careamics-0.0.2 → careamics-0.0.3}/tests/dataset/tiling/test_tiled_patching.py +0 -0
  226. {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/read/test_get_read_func.py +0 -0
  227. {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/read/test_read_tiff.py +0 -0
  228. {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/write/test_get_write_func.py +0 -0
  229. {careamics-0.0.2 → careamics-0.0.3}/tests/file_io/write/test_write_tiff.py +0 -0
  230. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_cache_tiles_write_strategy.py +0 -0
  231. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_file_path_utils.py +0 -0
  232. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_write_image_write_strategy.py +0 -0
  233. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/callbacks/prediction_writer_callback/test_write_strategy_factory.py +0 -0
  234. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_predict_data_module.py +0 -0
  235. {careamics-0.0.2 → careamics-0.0.3}/tests/lightning/test_train_data_module.py +0 -0
  236. {careamics-0.0.2 → careamics-0.0.3}/tests/model_io/test_bmz_io.py +0 -0
  237. {careamics-0.0.2 → careamics-0.0.3}/tests/models/test_unet.py +0 -0
  238. {careamics-0.0.2 → careamics-0.0.3}/tests/prediction_utils/test_prediction_outputs.py +0 -0
  239. {careamics-0.0.2 → careamics-0.0.3}/tests/prediction_utils/test_stitch_prediction.py +0 -0
  240. {careamics-0.0.2 → careamics-0.0.3}/tests/test_careamist.py +0 -0
  241. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_compose.py +0 -0
  242. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_manipulate_n2v.py +0 -0
  243. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_normalize.py +0 -0
  244. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_supported_transforms.py +0 -0
  245. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_tta.py +0 -0
  246. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_xy_flip.py +0 -0
  247. {careamics-0.0.2 → careamics-0.0.3}/tests/transforms/test_xy_random_rotate90.py +0 -0
  248. {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_autocorrelation.py +0 -0
  249. {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_base_enum.py +0 -0
  250. {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_context.py +0 -0
  251. {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_logging.py +0 -0
  252. {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_metrics.py +0 -0
  253. {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_torch_utils.py +0 -0
  254. {careamics-0.0.2 → careamics-0.0.3}/tests/utils/test_wandb.py +0 -0
@@ -30,7 +30,8 @@ repos:
30
30
  hooks:
31
31
  - id: mypy
32
32
  files: "^src/"
33
- exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*"
33
+ exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/config/likelihood_model.py|^src/careamics/losses/loss_factory.py|^src/careamics/losses/lvae/losses.py"
34
+ args: ['--config-file', 'mypy.ini']
34
35
  additional_dependencies:
35
36
  - numpy
36
37
  - types-PyYAML
@@ -41,7 +42,7 @@ repos:
41
42
  rev: v1.8.0rc2
42
43
  hooks:
43
44
  - id: numpydoc-validation
44
- exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*"
45
+ exclude: "^src/careamics/lvae_training/.*|^src/careamics/models/lvae/.*|^src/careamics/losses/lvae/.*"
45
46
 
46
47
  # # jupyter linting and formatting
47
48
  # - repo: https://github.com/nbQA-dev/nbQA
@@ -1,10 +1,10 @@
1
1
  Metadata-Version: 2.3
2
2
  Name: careamics
3
- Version: 0.0.2
3
+ Version: 0.0.3
4
4
  Summary: Toolbox for running N2V and friends.
5
5
  Project-URL: homepage, https://careamics.github.io/
6
6
  Project-URL: repository, https://github.com/CAREamics/careamics
7
- Author-email: Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
7
+ Author-email: CAREamics team <rse@fht.org>, Ashesh <ashesh.ashesh@fht.org>, Federico Carrara <federico.carrara@fht.org>, Melisande Croft <melisande.croft@fht.org>, Joran Deschamps <joran.deschamps@fht.org>, Vera Galinova <vera.galinova@fht.org>, Igor Zubarev <igor.zubarev@fht.org>
8
8
  License: BSD-3-Clause
9
9
  License-File: LICENSE
10
10
  Classifier: Development Status :: 3 - Alpha
@@ -0,0 +1,439 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import os\n",
10
+ "import socket\n",
11
+ "from pathlib import Path\n",
12
+ "from typing import Optional, Literal, Union\n",
13
+ "\n",
14
+ "import numpy as np\n",
15
+ "import torch \n",
16
+ "from torch.utils.data import Dataset, DataLoader\n",
17
+ "from pytorch_lightning import Trainer\n",
18
+ "from pytorch_lightning.loggers import WandbLogger\n",
19
+ "from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint, EarlyStopping\n",
20
+ "\n",
21
+ "from careamics.config import VAEAlgorithmConfig\n",
22
+ "from careamics.config.architectures import LVAEModel\n",
23
+ "from careamics.config.callback_model import CheckpointModel, EarlyStoppingModel\n",
24
+ "from careamics.config.likelihood_model import (\n",
25
+ " GaussianLikelihoodConfig,\n",
26
+ " NMLikelihoodConfig,\n",
27
+ ")\n",
28
+ "from careamics.config.nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig\n",
29
+ "from careamics.lightning import VAEModule\n",
30
+ "from careamics.models.lvae.noise_models import noise_model_factory"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "markdown",
35
+ "metadata": {},
36
+ "source": [
37
+ "Set some parameters for the current training simulation"
38
+ ]
39
+ },
40
+ {
41
+ "cell_type": "code",
42
+ "execution_count": null,
43
+ "metadata": {},
44
+ "outputs": [],
45
+ "source": [
46
+ "img_size: int = 64\n",
47
+ "\"\"\"Spatial size of the input image.\"\"\"\n",
48
+ "target_channels: int = 2\n",
49
+ "\"\"\"Number of channels in the target image.\"\"\"\n",
50
+ "multiscale_count: int = 5\n",
51
+ "\"\"\"The number of LC inputs plus one (the actual input).\"\"\"\n",
52
+ "predict_logvar: Optional[Literal[\"pixelwise\"]] = \"pixelwise\"\n",
53
+ "\"\"\"Whether to compute also the log-variance as LVAE output.\"\"\"\n",
54
+ "loss_type: Optional[Literal[\"musplit\", \"denoisplit\", \"denoisplit_musplit\"]] = \"musplit\"\n",
55
+ "\"\"\"The type of reconstruction loss (i.e., likelihood) to use.\"\"\""
56
+ ]
57
+ },
58
+ {
59
+ "cell_type": "markdown",
60
+ "metadata": {},
61
+ "source": [
62
+ "## 1. Create `Dataset` and `Dataloader`"
63
+ ]
64
+ },
65
+ {
66
+ "cell_type": "markdown",
67
+ "metadata": {},
68
+ "source": [
69
+ "### 1.1. Dummy Data"
70
+ ]
71
+ },
72
+ {
73
+ "cell_type": "code",
74
+ "execution_count": null,
75
+ "metadata": {},
76
+ "outputs": [],
77
+ "source": [
78
+ "class DummyDataset(Dataset):\n",
79
+ " def __init__(\n",
80
+ " self, \n",
81
+ " img_size: int = 64, \n",
82
+ " target_ch: int = 1,\n",
83
+ " multiscale_count: int = 1,\n",
84
+ " ):\n",
85
+ " self.num_samples = 100\n",
86
+ " self.img_size = img_size\n",
87
+ " self.target_ch = target_ch\n",
88
+ " self.multiscale_count = multiscale_count\n",
89
+ " \n",
90
+ " def __len__(self):\n",
91
+ " return self.num_samples\n",
92
+ " \n",
93
+ " def __getitem__(self, idx: int):\n",
94
+ " input_ = torch.randn(self.multiscale_count, self.img_size, self.img_size)\n",
95
+ " target = torch.randn(self.target_ch, self.img_size, self.img_size)\n",
96
+ " return input_, target\n",
97
+ "\n",
98
+ "def dummy_dataloader(\n",
99
+ " batch_size: int = 1,\n",
100
+ " img_size: int = 64,\n",
101
+ " target_ch: int = 1,\n",
102
+ " multiscale_count: int = 1,\n",
103
+ "):\n",
104
+ " dataset = DummyDataset(\n",
105
+ " img_size=img_size,\n",
106
+ " target_ch=target_ch,\n",
107
+ " multiscale_count=multiscale_count,\n",
108
+ " )\n",
109
+ " return DataLoader(dataset, batch_size=batch_size, num_workers=3, shuffle=False)"
110
+ ]
111
+ },
112
+ {
113
+ "cell_type": "code",
114
+ "execution_count": null,
115
+ "metadata": {},
116
+ "outputs": [],
117
+ "source": [
118
+ "dloader = dummy_dataloader(\n",
119
+ " img_size=img_size,\n",
120
+ " target_ch=target_channels,\n",
121
+ " multiscale_count=multiscale_count,\n",
122
+ ")\n",
123
+ "input_, target = next(iter(dloader))\n",
124
+ "input_.shape, target.shape"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "markdown",
129
+ "metadata": {},
130
+ "source": [
131
+ "### 1.2. Real Data"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": []
140
+ },
141
+ {
142
+ "cell_type": "markdown",
143
+ "metadata": {},
144
+ "source": [
145
+ "## 2. Instantiate the lightning module"
146
+ ]
147
+ },
148
+ {
149
+ "cell_type": "code",
150
+ "execution_count": null,
151
+ "metadata": {},
152
+ "outputs": [],
153
+ "source": [
154
+ "def create_dummy_noise_model(\n",
155
+ " save_path: Optional[Union[Path, str]] = None,\n",
156
+ " n_gaussians: int = 3,\n",
157
+ " n_coeffs: int = 3,\n",
158
+ ") -> Path:\n",
159
+ " weights = np.random.rand(3*n_gaussians, n_coeffs)\n",
160
+ " nm_dict = {\n",
161
+ " \"trained_weight\": weights,\n",
162
+ " \"min_signal\": np.array([0]),\n",
163
+ " \"max_signal\": np.array([2**16 - 1]),\n",
164
+ " \"min_sigma\": 0.125,\n",
165
+ " }\n",
166
+ " out_path = Path(save_path) / \"dummy_noise_model.npz\"\n",
167
+ " np.savez(out_path, **nm_dict)\n",
168
+ " return out_path"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": [
177
+ "def create_split_lightning_model(\n",
178
+ " algorithm: str,\n",
179
+ " loss_type: str,\n",
180
+ " multiscale_count: int = 1,\n",
181
+ " predict_logvar: Optional[Literal[\"pixelwise\"]] = None,\n",
182
+ " target_ch: int = 1,\n",
183
+ " NM_path: Optional[Path] = None,\n",
184
+ ") -> VAEModule:\n",
185
+ " \"\"\"Instantiate the muSplit lightining model.\"\"\"\n",
186
+ " lvae_config = LVAEModel(\n",
187
+ " architecture=\"LVAE\",\n",
188
+ " input_shape=64,\n",
189
+ " multiscale_count=multiscale_count,\n",
190
+ " z_dims=[128, 128, 128, 128],\n",
191
+ " output_channels=target_ch,\n",
192
+ " predict_logvar=predict_logvar,\n",
193
+ " )\n",
194
+ "\n",
195
+ " # gaussian likelihood\n",
196
+ " if loss_type in [\"musplit\", \"denoisplit_musplit\"]:\n",
197
+ " gaussian_lik_config = GaussianLikelihoodConfig(\n",
198
+ " predict_logvar=predict_logvar,\n",
199
+ " logvar_lowerbound=0.0,\n",
200
+ " )\n",
201
+ " else:\n",
202
+ " gaussian_lik_config = None\n",
203
+ " # noise model likelihood\n",
204
+ " if loss_type in [\"denoisplit\", \"denoisplit_musplit\"]:\n",
205
+ " if NM_path is None:\n",
206
+ " NM_path = create_dummy_noise_model(Path(\"./\"), 3, 3)\n",
207
+ " gmm = GaussianMixtureNMConfig(\n",
208
+ " model_type=\"GaussianMixtureNoiseModel\",\n",
209
+ " path=NM_path,\n",
210
+ " )\n",
211
+ " noise_model_config = MultiChannelNMConfig(noise_models=[gmm] * target_ch)\n",
212
+ " nm = noise_model_factory(noise_model_config)\n",
213
+ " nm_lik_config = NMLikelihoodConfig(noise_model=nm)\n",
214
+ " else:\n",
215
+ " noise_model_config = None\n",
216
+ " nm_lik_config = None\n",
217
+ "\n",
218
+ " vae_config = VAEAlgorithmConfig(\n",
219
+ " algorithm_type=\"vae\",\n",
220
+ " algorithm=algorithm,\n",
221
+ " loss=loss_type,\n",
222
+ " model=lvae_config,\n",
223
+ " gaussian_likelihood_model=gaussian_lik_config,\n",
224
+ " noise_model=noise_model_config,\n",
225
+ " noise_model_likelihood_model=nm_lik_config,\n",
226
+ " )\n",
227
+ "\n",
228
+ " return VAEModule(\n",
229
+ " algorithm_config=vae_config,\n",
230
+ " )"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": null,
236
+ "metadata": {},
237
+ "outputs": [],
238
+ "source": [
239
+ "algo = \"musplit\" if loss_type == \"musplit\" else \"denoisplit\"\n",
240
+ "lightning_model = create_split_lightning_model(\n",
241
+ " algorithm=algo,\n",
242
+ " loss_type=loss_type,\n",
243
+ " multiscale_count=multiscale_count,\n",
244
+ " predict_logvar=predict_logvar,\n",
245
+ " target_ch=target_channels,\n",
246
+ " NM_path=None\n",
247
+ ")"
248
+ ]
249
+ },
250
+ {
251
+ "cell_type": "markdown",
252
+ "metadata": {},
253
+ "source": [
254
+ "## 3. Set utils for training"
255
+ ]
256
+ },
257
+ {
258
+ "cell_type": "code",
259
+ "execution_count": null,
260
+ "metadata": {},
261
+ "outputs": [],
262
+ "source": [
263
+ "from datetime import datetime\n",
264
+ "\n",
265
+ "from careamics.lvae_training.train_utils import get_new_model_version\n",
266
+ "\n",
267
+ "def get_new_model_version(model_dir: Union[Path, str]) -> int:\n",
268
+ " \"\"\"Create a unique version ID for a new model run.\"\"\"\n",
269
+ " versions = []\n",
270
+ " for version_dir in os.listdir(model_dir):\n",
271
+ " try:\n",
272
+ " versions.append(int(version_dir))\n",
273
+ " except:\n",
274
+ " print(\n",
275
+ " f\"Invalid subdirectory:{model_dir}/{version_dir}. Only integer versions are allowed\"\n",
276
+ " )\n",
277
+ " exit()\n",
278
+ " if len(versions) == 0:\n",
279
+ " return \"0\"\n",
280
+ " return f\"{max(versions) + 1}\"\n",
281
+ "\n",
282
+ "def get_workdir(\n",
283
+ " root_dir: str,\n",
284
+ " model_name: str,\n",
285
+ ") -> tuple[Path, Path]:\n",
286
+ " \"\"\"Get the workdir for the current model.\n",
287
+ " \n",
288
+ " It has the following structure: \"root_dir/YYMM/model_name/version\"\n",
289
+ " \"\"\"\n",
290
+ " rel_path = datetime.now().strftime(\"%y%m\")\n",
291
+ " cur_workdir = os.path.join(root_dir, rel_path)\n",
292
+ " Path(cur_workdir).mkdir(exist_ok=True)\n",
293
+ "\n",
294
+ " rel_path = os.path.join(rel_path, model_name)\n",
295
+ " cur_workdir = os.path.join(root_dir, rel_path)\n",
296
+ " Path(cur_workdir).mkdir(exist_ok=True)\n",
297
+ "\n",
298
+ " rel_path = os.path.join(rel_path, get_new_model_version(cur_workdir))\n",
299
+ " cur_workdir = os.path.join(root_dir, rel_path)\n",
300
+ " try:\n",
301
+ " Path(cur_workdir).mkdir(exist_ok=False)\n",
302
+ " except FileExistsError:\n",
303
+ " print(\n",
304
+ " f\"Workdir {cur_workdir} already exists.\"\n",
305
+ " )\n",
306
+ " return cur_workdir, rel_path"
307
+ ]
308
+ },
309
+ {
310
+ "cell_type": "code",
311
+ "execution_count": null,
312
+ "metadata": {},
313
+ "outputs": [],
314
+ "source": [
315
+ "ROOT_DIR = \"/group/jug/federico/careamics_training/refac_v2/\"\n",
316
+ "workdir, exp_tag = get_workdir(ROOT_DIR, \"dummy_debugging\")\n",
317
+ "print(f\"Current workdir: {workdir}\")"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "execution_count": null,
323
+ "metadata": {},
324
+ "outputs": [],
325
+ "source": [
326
+ "# Define the logger\n",
327
+ "custom_logger = WandbLogger(\n",
328
+ " name=os.path.join(socket.gethostname(), exp_tag),\n",
329
+ " save_dir=workdir,\n",
330
+ " project=\"careamics_debugging_LVAE\",\n",
331
+ ")"
332
+ ]
333
+ },
334
+ {
335
+ "cell_type": "code",
336
+ "execution_count": null,
337
+ "metadata": {},
338
+ "outputs": [],
339
+ "source": [
340
+ "# Define callbacks (e.g., ModelCheckpoint, EarlyStopping, etc.)\n",
341
+ "early_stopping_config = EarlyStoppingModel(\n",
342
+ " monitor=\"val_loss\",\n",
343
+ " min_delta=1e-6,\n",
344
+ " patience=10,\n",
345
+ " mode=\"min\",\n",
346
+ " verbose=True,\n",
347
+ ")\n",
348
+ "checkpoint_config = CheckpointModel(\n",
349
+ " monitor=\"val_loss\",\n",
350
+ " save_top_k=2,\n",
351
+ " mode=\"min\",\n",
352
+ ")\n",
353
+ "custom_callbacks = [\n",
354
+ " EarlyStopping(**early_stopping_config.model_dump()), \n",
355
+ " ModelCheckpoint(**checkpoint_config.model_dump()),\n",
356
+ " LearningRateMonitor(logging_interval=\"epoch\")\n",
357
+ "]"
358
+ ]
359
+ },
360
+ {
361
+ "cell_type": "code",
362
+ "execution_count": null,
363
+ "metadata": {},
364
+ "outputs": [],
365
+ "source": [
366
+ "# Save AlgorithmConfig\n",
367
+ "with open(os.path.join(workdir, \"algorithm_config.json\"), \"w\") as f:\n",
368
+ " f.write(lightning_model.algorithm_config.model_dump_json())\n",
369
+ "\n",
370
+ "custom_logger.experiment.config.update(\n",
371
+ " lightning_model.algorithm_config.model_dump() \n",
372
+ ")"
373
+ ]
374
+ },
375
+ {
376
+ "cell_type": "markdown",
377
+ "metadata": {},
378
+ "source": [
379
+ "## 4. Train the model"
380
+ ]
381
+ },
382
+ {
383
+ "cell_type": "code",
384
+ "execution_count": null,
385
+ "metadata": {},
386
+ "outputs": [],
387
+ "source": [
388
+ "trainer = Trainer(\n",
389
+ " max_epochs=10,\n",
390
+ " accelerator=\"cpu\",\n",
391
+ " enable_progress_bar=True,\n",
392
+ " logger=custom_logger,\n",
393
+ " callbacks=custom_callbacks,\n",
394
+ ")"
395
+ ]
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "execution_count": null,
400
+ "metadata": {},
401
+ "outputs": [],
402
+ "source": [
403
+ "trainer.fit(\n",
404
+ " model=lightning_model,\n",
405
+ " train_dataloaders=dloader,\n",
406
+ " val_dataloaders=dloader,\n",
407
+ ")"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": null,
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": []
416
+ }
417
+ ],
418
+ "metadata": {
419
+ "kernelspec": {
420
+ "display_name": "train_lvae",
421
+ "language": "python",
422
+ "name": "python3"
423
+ },
424
+ "language_info": {
425
+ "codemirror_mode": {
426
+ "name": "ipython",
427
+ "version": 3
428
+ },
429
+ "file_extension": ".py",
430
+ "mimetype": "text/x-python",
431
+ "name": "python",
432
+ "nbconvert_exporter": "python",
433
+ "pygments_lexer": "ipython3",
434
+ "version": "3.9.19"
435
+ }
436
+ },
437
+ "nbformat": 4,
438
+ "nbformat_minor": 2
439
+ }
@@ -0,0 +1,17 @@
1
+ [mypy]
2
+ ignore_missing_imports = True
3
+
4
+ [mypy-careamics.lvae_training.*]
5
+ follow_imports = skip
6
+
7
+ [mypy-careamics.models.lvae.*]
8
+ follow_imports = skip
9
+
10
+ [mypy-careamics.losses.loss_factory]
11
+ follow_imports = skip
12
+
13
+ [mypy-careamics.losses.lvae.losses]
14
+ follow_imports = skip
15
+
16
+ [mypy-careamics.config.likelihood_model]
17
+ follow_imports = skip
@@ -23,8 +23,12 @@ readme = "README.md"
23
23
  requires-python = ">=3.9"
24
24
  license = { text = "BSD-3-Clause" }
25
25
  authors = [
26
+ { name = 'CAREamics team', email = 'rse@fht.org' },
27
+ { name = 'Ashesh', email = 'ashesh.ashesh@fht.org' },
28
+ { name = 'Federico Carrara', email = 'federico.carrara@fht.org' },
26
29
  { name = 'Melisande Croft', email = 'melisande.croft@fht.org' },
27
30
  { name = 'Joran Deschamps', email = 'joran.deschamps@fht.org' },
31
+ { name = 'Vera Galinova', email = 'vera.galinova@fht.org' },
28
32
  { name = 'Igor Zubarev', email = 'igor.zubarev@fht.org' },
29
33
  ]
30
34
  classifiers = [
@@ -13,10 +13,7 @@ from pytorch_lightning.callbacks import (
13
13
  )
14
14
  from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
15
15
 
16
- from careamics.config import (
17
- Configuration,
18
- load_configuration,
19
- )
16
+ from careamics.config import Configuration, FCNAlgorithmConfig, load_configuration
20
17
  from careamics.config.support import (
21
18
  SupportedAlgorithm,
22
19
  SupportedArchitecture,
@@ -25,7 +22,7 @@ from careamics.config.support import (
25
22
  )
26
23
  from careamics.dataset.dataset_utils import reshape_array
27
24
  from careamics.lightning import (
28
- CAREamicsModule,
25
+ FCNModule,
29
26
  HyperParametersCallback,
30
27
  PredictDataModule,
31
28
  ProgressBarCallback,
@@ -148,9 +145,12 @@ class CAREamist:
148
145
  self.cfg = source
149
146
 
150
147
  # instantiate model
151
- self.model = CAREamicsModule(
152
- algorithm_config=self.cfg.algorithm_config,
153
- )
148
+ if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
149
+ self.model = FCNModule(
150
+ algorithm_config=self.cfg.algorithm_config,
151
+ )
152
+ else:
153
+ raise NotImplementedError("Architecture not supported.")
154
154
 
155
155
  # path to configuration file or model
156
156
  else:
@@ -164,9 +164,12 @@ class CAREamist:
164
164
  self.cfg = load_configuration(source)
165
165
 
166
166
  # instantiate model
167
- self.model = CAREamicsModule(
168
- algorithm_config=self.cfg.algorithm_config,
169
- )
167
+ if isinstance(self.cfg.algorithm_config, FCNAlgorithmConfig):
168
+ self.model = FCNModule(
169
+ algorithm_config=self.cfg.algorithm_config,
170
+ ) # type: ignore
171
+ else:
172
+ raise NotImplementedError("Architecture not supported.")
170
173
 
171
174
  # attempt loading a pre-trained model
172
175
  else:
@@ -1,7 +1,8 @@
1
1
  """Configuration module."""
2
2
 
3
3
  __all__ = [
4
- "AlgorithmConfig",
4
+ "FCNAlgorithmConfig",
5
+ "VAEAlgorithmConfig",
5
6
  "DataConfig",
6
7
  "Configuration",
7
8
  "CheckpointModel",
@@ -15,9 +16,9 @@ __all__ = [
15
16
  "register_model",
16
17
  "CustomModel",
17
18
  "clear_custom_models",
19
+ "GaussianMixtureNMConfig",
20
+ "MultiChannelNMConfig",
18
21
  ]
19
-
20
- from .algorithm_model import AlgorithmConfig
21
22
  from .architectures import CustomModel, clear_custom_models, register_model
22
23
  from .callback_model import CheckpointModel
23
24
  from .configuration_factory import (
@@ -31,5 +32,8 @@ from .configuration_model import (
31
32
  save_configuration,
32
33
  )
33
34
  from .data_model import DataConfig
35
+ from .fcn_algorithm_model import FCNAlgorithmConfig
34
36
  from .inference_model import InferenceConfig
37
+ from .nm_model import GaussianMixtureNMConfig, MultiChannelNMConfig
35
38
  from .training_model import TrainingConfig
39
+ from .vae_algorithm_model import VAEAlgorithmConfig
@@ -4,7 +4,7 @@ __all__ = [
4
4
  "ArchitectureModel",
5
5
  "CustomModel",
6
6
  "UNetModel",
7
- "VAEModel",
7
+ "LVAEModel",
8
8
  "clear_custom_models",
9
9
  "get_custom_model",
10
10
  "register_model",
@@ -12,6 +12,6 @@ __all__ = [
12
12
 
13
13
  from .architecture_model import ArchitectureModel
14
14
  from .custom_model import CustomModel
15
+ from .lvae_model import LVAEModel
15
16
  from .register_model import clear_custom_models, get_custom_model, register_model
16
17
  from .unet_model import UNetModel
17
- from .vae_model import VAEModel
@@ -27,7 +27,7 @@ class ArchitectureModel(BaseModel):
27
27
  Returns
28
28
  -------
29
29
  dict[str, Any]
30
- Model as a dictionnary.
30
+ Model as a dictionary.
31
31
  """
32
32
  model_dict = super().model_dump(**kwargs)
33
33
 
@@ -2,6 +2,7 @@
2
2
 
3
3
  from __future__ import annotations
4
4
 
5
+ import inspect
5
6
  from pprint import pformat
6
7
  from typing import Any, Literal
7
8
 
@@ -23,12 +24,13 @@ class CustomModel(ArchitectureModel):
23
24
 
24
25
  Attributes
25
26
  ----------
26
- architecture : Literal["Custom"]
27
- Discriminator for the custom model, must be set to "Custom".
27
+ architecture : Literal["custom"]
28
+ Discriminator for the custom model, must be set to "custom".
28
29
  name : str
29
30
  Name of the custom model.
30
31
  parameters : CustomParametersModel
31
- Parameters of the custom model.
32
+ All parameters, required for the initialization of the torch module have to be
33
+ passed here.
32
34
 
33
35
  Raises
34
36
  ------
@@ -57,7 +59,7 @@ class CustomModel(ArchitectureModel):
57
59
  ...
58
60
  >>> # Create a configuration
59
61
  >>> config_dict = {
60
- ... "architecture": "Custom",
62
+ ... "architecture": "custom",
61
63
  ... "name": "my_linear",
62
64
  ... "in_features": 10,
63
65
  ... "out_features": 5,
@@ -71,10 +73,9 @@ class CustomModel(ArchitectureModel):
71
73
  )
72
74
 
73
75
  # discriminator used for choosing the pydantic model in Model
74
- architecture: Literal["Custom"]
76
+ architecture: Literal["custom"]
75
77
  """Name of the architecture."""
76
78
 
77
- # name of the custom model
78
79
  name: str
79
80
  """Name of the custom model."""
80
81
 
@@ -120,10 +121,12 @@ class CustomModel(ArchitectureModel):
120
121
  get_custom_model(self.name)(**self.model_dump())
121
122
  except Exception as e:
122
123
  raise ValueError(
123
- f"error while passing parameters to the model {e}. Verify that all "
124
+ f"while passing parameters to the model {e}. Verify that all "
124
125
  f"mandatory parameters are provided, and that either the {e} accepts "
125
126
  f"*args and **kwargs in its __init__() method, or that no additional"
126
- f"parameter is provided."
127
+ f"parameter is provided. Trace: "
128
+ f"filename: {inspect.trace()[-1].filename}, function: "
129
+ f"{inspect.trace()[-1].function}, line: {inspect.trace()[-1].lineno}"
127
130
  ) from None
128
131
 
129
132
  return self