careamics 0.1.0rc1.tar.gz → 0.1.0rc2.tar.gz

This diff shows the changes between publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of careamics might be problematic.

Files changed (97)
  1. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.github/workflows/ci.yml +2 -2
  2. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.pre-commit-config.yaml +5 -5
  3. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/PKG-INFO +4 -3
  4. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/pyproject.toml +21 -41
  5. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/bioimage/__init__.py +3 -3
  6. careamics-0.1.0rc2/src/careamics/bioimage/io.py +182 -0
  7. careamics-0.1.0rc2/src/careamics/bioimage/rdf.py +105 -0
  8. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/config/config.py +4 -3
  9. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/config/data.py +2 -2
  10. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/dataset/dataset_utils.py +1 -5
  11. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/dataset/in_memory_dataset.py +3 -2
  12. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/dataset/patching.py +5 -6
  13. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/dataset/prepare_dataset.py +4 -3
  14. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/dataset/tiff_dataset.py +4 -3
  15. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/engine.py +170 -110
  16. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/losses/loss_factory.py +3 -2
  17. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/models/model_factory.py +30 -11
  18. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/prediction/prediction_utils.py +16 -12
  19. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/__init__.py +3 -4
  20. careamics-0.1.0rc2/src/careamics/utils/torch_utils.py +89 -0
  21. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/validators.py +34 -20
  22. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/wandb.py +1 -1
  23. careamics-0.1.0rc2/tests/bioimage/test_engine_bmz.py +123 -0
  24. careamics-0.1.0rc2/tests/bioimage/test_io.py +84 -0
  25. careamics-0.1.0rc2/tests/bioimage/test_rdf.py +37 -0
  26. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/conftest.py +5 -91
  27. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/dataset/test_patching.py +13 -0
  28. careamics-0.1.0rc2/tests/dataset/test_tiff_dataset.py +57 -0
  29. careamics-0.1.0rc2/tests/model/test_model_factory.py +2 -0
  30. careamics-0.1.0rc2/tests/smoke_test.py +134 -0
  31. careamics-0.1.0rc2/tests/test_conftest.py +35 -0
  32. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/test_engine.py +1 -1
  33. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/test_prediction_utils.py +46 -1
  34. careamics-0.1.0rc2/tests/utils/test_torch_utils.py +22 -0
  35. careamics-0.1.0rc1/tests/utils/test_axes.py → careamics-0.1.0rc2/tests/utils/test_validators.py +26 -1
  36. careamics-0.1.0rc1/.github/workflows/smoke_test.yml +0 -37
  37. careamics-0.1.0rc1/src/careamics/bioimage/io.py +0 -271
  38. careamics-0.1.0rc1/src/careamics/utils/torch_utils.py +0 -93
  39. careamics-0.1.0rc1/tests/bioimage/test_io.py +0 -160
  40. careamics-0.1.0rc1/tests/smoke_test.py +0 -53
  41. careamics-0.1.0rc1/tests/utils/test_torch_utils.py +0 -23
  42. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.github/ISSUE_TEMPLATE/bug_report.md +0 -0
  43. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.github/ISSUE_TEMPLATE/feature_request.md +0 -0
  44. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.github/TEST_FAIL_TEMPLATE.md +0 -0
  45. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.github/dependabot.yml +0 -0
  46. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.github/pull_request_template.md +0 -0
  47. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/.gitignore +0 -0
  48. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/LICENSE +0 -0
  49. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/README.md +0 -0
  50. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/2D/example_BSD68.ipynb +0 -0
  51. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/2D/example_SEM.ipynb +0 -0
  52. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/2D/n2v_2D_BSD.yml +0 -0
  53. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/2D/n2v_2D_SEM.yml +0 -0
  54. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/3D/example_flywing_3D.ipynb +0 -0
  55. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/3D/n2v_3D.yml +0 -0
  56. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/3D/n2v_flywing_3D.yml +0 -0
  57. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/examples/n2v_full_reference.yml +0 -0
  58. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/__init__.py +0 -0
  59. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/bioimage/docs/Noise2Void.md +0 -0
  60. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/bioimage/docs/__init__.py +0 -0
  61. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/config/__init__.py +0 -0
  62. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/config/algorithm.py +0 -0
  63. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/config/config_filter.py +0 -0
  64. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/config/torch_optim.py +0 -0
  65. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/config/training.py +0 -0
  66. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/dataset/__init__.py +0 -0
  67. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/dataset/extraction_strategy.py +0 -0
  68. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/losses/__init__.py +0 -0
  69. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/losses/losses.py +0 -0
  70. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/manipulation/__init__.py +0 -0
  71. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/manipulation/pixel_manipulation.py +0 -0
  72. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/models/__init__.py +0 -0
  73. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/models/layers.py +0 -0
  74. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/models/unet.py +0 -0
  75. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/prediction/__init__.py +0 -0
  76. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/py.typed +0 -0
  77. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/ascii_logo.txt +0 -0
  78. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/augment.py +0 -0
  79. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/context.py +0 -0
  80. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/logging.py +0 -0
  81. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/metrics.py +0 -0
  82. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/src/careamics/utils/normalization.py +0 -0
  83. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/__init__.py +0 -0
  84. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/config/test_algorithm.py +0 -0
  85. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/config/test_config.py +0 -0
  86. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/config/test_config_filters.py +0 -0
  87. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/config/test_data.py +0 -0
  88. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/config/test_torch_optimizer.py +0 -0
  89. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/config/test_training.py +0 -0
  90. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/dataset/test_dataset_utils.py +0 -0
  91. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/manipulation/test_pixel_manipulation.py +0 -0
  92. {careamics-0.1.0rc1/tests → careamics-0.1.0rc2/tests/model}/test_model.py +0 -0
  93. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/test_augment.py +0 -0
  94. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/test_metrics.py +0 -0
  95. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/utils/test_context.py +0 -0
  96. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/utils/test_logging.py +0 -0
  97. {careamics-0.1.0rc1 → careamics-0.1.0rc2}/tests/utils/test_wandb.py +0 -0
--- careamics-0.1.0rc1/.github/workflows/ci.yml
+++ careamics-0.1.0rc2/.github/workflows/ci.yml
@@ -40,7 +40,7 @@ jobs:
       - uses: actions/checkout@v4

       - name: 🐍 Set up Python ${{ matrix.python-version }}
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
           cache-dependency-path: pyproject.toml
@@ -91,7 +91,7 @@ jobs:
           fetch-depth: 0

       - name: 🐍 Set up Python
-        uses: actions/setup-python@v4
+        uses: actions/setup-python@v5
         with:
           python-version: 3.x

--- careamics-0.1.0rc1/.pre-commit-config.yaml
+++ careamics-0.1.0rc2/.pre-commit-config.yaml
@@ -10,23 +10,23 @@ ci:

 repos:
   - repo: https://github.com/abravalheri/validate-pyproject
-    rev: v0.14
+    rev: v0.15
     hooks:
       - id: validate-pyproject

   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.0.292
+    rev: v0.1.6
     hooks:
       - id: ruff
         args: [--fix, --target-version, py38]

   - repo: https://github.com/psf/black
-    rev: 23.9.1
+    rev: 23.11.0
     hooks:
       - id: black

   - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: v1.5.1
+    rev: v1.7.1
     hooks:
       - id: mypy
         files: ^src/
@@ -42,7 +42,7 @@ repos:

   # jupyter linting and formatting
   - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.7.0
+    rev: 1.7.1
     hooks:
       - id: nbqa-ruff
         args: [--fix]
--- careamics-0.1.0rc1/PKG-INFO
+++ careamics-0.1.0rc2/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: careamics
-Version: 0.1.0rc1
+Version: 0.1.0rc2
 Summary: Toolbox for running N2V and friends.
 Project-URL: homepage, https://careamics.github.io/
 Project-URL: repository, https://github.com/CAREamics/careamics
@@ -14,6 +14,7 @@ Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
+Classifier: Programming Language :: Python :: 3.12
 Classifier: Typing :: Typed
 Requires-Python: >=3.8
 Requires-Dist: bioimageio-core
@@ -26,6 +27,7 @@ Requires-Dist: torchvision
 Requires-Dist: zarr
 Provides-Extra: all
 Requires-Dist: careamics-portfolio; extra == 'all'
+Requires-Dist: ipython; extra == 'all'
 Requires-Dist: itkwidgets; extra == 'all'
 Requires-Dist: jupyter; extra == 'all'
 Requires-Dist: pre-commit; extra == 'all'
@@ -43,12 +45,11 @@ Requires-Dist: ipython; extra == 'notebooks'
 Requires-Dist: itkwidgets; extra == 'notebooks'
 Requires-Dist: jupyter; extra == 'notebooks'
 Requires-Dist: torchsummary; extra == 'notebooks'
+Requires-Dist: wandb; extra == 'notebooks'
 Provides-Extra: test
 Requires-Dist: pytest; extra == 'test'
 Requires-Dist: pytest-cov; extra == 'test'
 Requires-Dist: wandb; extra == 'test'
-Provides-Extra: wandb
-Requires-Dist: wandb; extra == 'wandb'
 Description-Content-Type: text/markdown

 <p align="center">
--- careamics-0.1.0rc1/pyproject.toml
+++ careamics-0.1.0rc2/pyproject.toml
@@ -1,15 +1,12 @@
-# https://peps.python.org/pep-0517/
-
 [build-system]
 requires = ["hatchling", "hatch-vcs"]
 build-backend = "hatchling.build"
-# read more about configuring hatch at:
-# https://hatch.pypa.io/latest/config/build/

 # https://hatch.pypa.io/latest/config/metadata/
 [tool.hatch.version]
 source = "vcs"

+# https://hatch.pypa.io/latest/config/build/
 [tool.hatch.build.targets.wheel]
 only-include = ["src"]
 sources = ["src"]
@@ -35,6 +32,7 @@ classifiers = [
     "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
+    "Programming Language :: Python :: 3.12",
     "License :: OSI Approved :: BSD License",
     "Typing :: Typed",
 ]
@@ -53,12 +51,9 @@ dependencies = [
 # development dependencies and tooling
 dev = ["pre-commit", "pytest", "pytest-cov"]

-# test for ci
+# for ci
 test = ["pytest", "pytest-cov", "wandb"]

-# use wandb for logging
-wandb = ["wandb"]
-
 # notebooks
 notebooks = [
     "jupyter",
@@ -66,6 +61,7 @@ notebooks = [
     "itkwidgets",
     "torchsummary",
     "ipython",
+    "wandb",
 ]

 # all
@@ -78,43 +74,32 @@ all = [
     "careamics-portfolio",
     "itkwidgets",
     "torchsummary",
+    "ipython",
 ]

 [project.urls]
 homepage = "https://careamics.github.io/"
 repository = "https://github.com/CAREamics/careamics"

-# https://beta.ruff.rs/docs
+# https://docs.astral.sh/ruff/
 [tool.ruff]
 line-length = 88
 target-version = "py38"
 src = ["src"]
 select = [
-    "E", # style errors
-    "W", # style warnings
-    "F", # flakes
-    "D", # pydocstyle
-    "I", # isort
-    "UP", # pyupgrade
-    # "S", # bandit
+    "E", # style errors
+    "W", # style warnings
+    "F", # flakes
+    "I", # isort
+    "UP", # pyupgrade
     "C4", # flake8-comprehensions
     "B", # flake8-bugbear
     "A001", # flake8-builtins
     "RUF", # ruff-specific rules
+    "TCH", # flake8-type-checking
+    "TID", # flake8-tidy-imports
 ]
 ignore = [
-    "D100", # Missing docstring in public module
-    "D107", # Missing docstring in __init__
-    "D203", # 1 blank line required before class docstring
-    "D212", # Multi-line docstring summary should start at the first line
-    "D213", # Multi-line docstring summary should start at the second line
-    "D401", # First line should be in imperative mood
-    "D413", # Missing blank line after last section
-    "D416", # Section name should end with a colon
-
-    # incompatibility with mypy
-    "RUF005", # collection-literal-concatenation, in prediction_utils.py:30
-
     # version specific
     "UP006", # Replace typing.List by list, mandatory for py3.8
     "UP007", # Replace Union by |, mandatory for py3.9
@@ -122,12 +107,8 @@ ignore = [
 ignore-init-module-imports = true
 show-fixes = true

-[tool.ruff.pydocstyle]
-convention = "numpy"
-
 [tool.ruff.per-file-ignores]
-"tests/*.py" = ["D", "S"]
-"setup.py" = ["D"]
+"tests/*.py" = ["S"]

 [tool.black]
 line-length = 88
@@ -141,10 +122,6 @@ allow_untyped_calls = false
 disallow_any_generics = false
 ignore_missing_imports = false

-# # module specific overrides
-# [[tool.mypy.overrides]]
-# module = ["numpy.*",]
-# ignore_errors = true
 # https://docs.pytest.org/en/6.2.x/customize.html
 [tool.pytest.ini_options]
 minversion = "6.0"
@@ -153,8 +130,7 @@ filterwarnings = [
     # "error",
     # "ignore::UserWarning",
 ]
-markers = ["gpu: marks tests as requiring gpu"]
-
+markers = ["gpu: mark tests as requiring gpu"]

 # https://coverage.readthedocs.io/en/6.4/config.html
 [tool.coverage.report]
@@ -162,9 +138,14 @@ exclude_lines = [
     "pragma: no cover",
     "if TYPE_CHECKING:",
     "@overload",
-    "except ImportError",
     "\\.\\.\\.",
+    "except ImportError:",
     "raise NotImplementedError()",
+    "except PackageNotFoundError:",
+    "if torch.cuda.is_available():",
+    "except UsageError as e:",
+    "except ModuleNotFoundError:",
+    "except KeyboardInterrupt:",
 ]

 [tool.coverage.run]
@@ -178,7 +159,6 @@ ignore = [
     ".github_changelog_generator",
     ".pre-commit-config.yaml",
     ".ruff_cache/**/*",
-    "setup.py",
     "tests/**/*",
 ]

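As an aside, the renamed pytest marker above ("gpu: mark tests as requiring gpu") would be used as sketched below; the test name and body are hypothetical examples, not taken from the careamics test suite.

# Illustrative use of the "gpu" marker declared in [tool.pytest.ini_options].
import pytest
import torch


@pytest.mark.gpu
def test_runs_on_gpu():
    # would be deselected on CPU-only runners with: pytest -m "not gpu"
    assert torch.cuda.is_available()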
--- careamics-0.1.0rc1/src/careamics/bioimage/__init__.py
+++ careamics-0.1.0rc2/src/careamics/bioimage/__init__.py
@@ -1,7 +1,7 @@
 """Provide utilities for exporting models to BioImage model zoo."""

 __all__ = [
-    "build_zip_model",
+    "save_bioimage_model",
     "import_bioimage_model",
     "get_default_model_specs",
     "PYTORCH_STATE_DICT",
@@ -9,7 +9,7 @@ __all__ = [

 from .io import (
     PYTORCH_STATE_DICT,
-    build_zip_model,
-    get_default_model_specs,
     import_bioimage_model,
+    save_bioimage_model,
 )
+from .rdf import get_default_model_specs
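In user code, the rename above amounts to a one-line import change; a minimal sketch, with the old name shown only as a comment:

# careamics 0.1.0rc1 exposed the export helper as build_zip_model:
#     from careamics.bioimage import build_zip_model
# careamics 0.1.0rc2 renames it and moves the RDF helper to a dedicated module:
from careamics.bioimage import get_default_model_specs, save_bioimage_model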
--- /dev/null
+++ careamics-0.1.0rc2/src/careamics/bioimage/io.py
@@ -0,0 +1,182 @@
+"""Export to bioimage.io format."""
+from pathlib import Path
+from typing import Union
+
+import torch
+from bioimageio.core import load_resource_description
+from bioimageio.core.build_spec import build_model
+
+from careamics.config.config import Configuration
+from careamics.utils.context import cwd
+
+PYTORCH_STATE_DICT = "pytorch_state_dict"
+
+
+def save_bioimage_model(
+    path: Union[str, Path],
+    config: Configuration,
+    specs: dict,
+) -> None:
+    """
+    Build bioimage model zip file from model RDF data.
+
+    Parameters
+    ----------
+    path : Union[str, Path]
+        Path to the model zip file.
+    config : Configuration
+        Configuration object.
+    specs : dict
+        Model RDF dict.
+    """
+    workdir = config.working_directory
+
+    # temporary folder
+    temp_folder = Path.home().joinpath(".careamics", "bmz_tmp")
+    temp_folder.mkdir(exist_ok=True, parents=True)
+
+    # change working directory to the temp folder
+    with cwd(temp_folder):
+        # load best checkpoint
+        checkpoint_path = workdir.joinpath(
+            f"{config.experiment_name}_best.pth"
+        ).absolute()
+        checkpoint = torch.load(checkpoint_path, map_location="cpu")
+
+        # save chekpoint entries in separate files
+        weight_path = Path("model_weights.pth")
+        torch.save(checkpoint["model_state_dict"], weight_path)
+
+        optim_path = Path("optim.pth")
+        torch.save(checkpoint["optimizer_state_dict"], optim_path)
+
+        scheduler_path = Path("scheduler.pth")
+        torch.save(checkpoint["scheduler_state_dict"], scheduler_path)
+
+        grad_path = Path("grad.pth")
+        torch.save(checkpoint["grad_scaler_state_dict"], grad_path)
+
+        config_path = Path("config.pth")
+        torch.save(config.model_dump(), config_path)
+
+        # create attachments
+        attachments = [
+            str(optim_path),
+            str(scheduler_path),
+            str(grad_path),
+            str(config_path),
+        ]
+
+        # create requirements file
+        requirements = Path("requirements.txt")
+        with open(requirements, "w") as f:
+            f.write("git+https://github.com/CAREamics/careamics.git")
+
+        algo_config = config.algorithm
+        specs.update(
+            {
+                "weight_type": PYTORCH_STATE_DICT,
+                "weight_uri": str(weight_path),
+                "architecture": "careamics.models.unet.UNet",
+                "pytorch_version": torch.__version__,
+                "model_kwargs": {
+                    "conv_dim": algo_config.get_conv_dim(),
+                    "depth": algo_config.model_parameters.depth,
+                    "num_channels_init": algo_config.model_parameters.num_channels_init,
+                },
+                "dependencies": "pip:" + str(requirements),
+                "attachments": {"files": attachments},
+            }
+        )
+
+        if config.algorithm.is_3D:
+            specs["tags"].append("3D")
+        else:
+            specs["tags"].append("2D")
+
+        # build model zip
+        build_model(
+            output_path=Path(path).absolute(),
+            **specs,
+        )
+
+    # remove temporary files
+    for file in temp_folder.glob("*"):
+        file.unlink()
+
+    # delete temporary folder
+    temp_folder.rmdir()
+
+
+def import_bioimage_model(model_path: Union[str, Path]) -> Path:
+    """
+    Load configuration and weights from a bioimage zip model.
+
+    Parameters
+    ----------
+    model_path : Union[str, Path]
+        Path to the bioimage.io archive.
+
+    Returns
+    -------
+    Path
+        Path to the checkpoint.
+
+    Raises
+    ------
+    ValueError
+        If the model format is invalid.
+    FileNotFoundError
+        If the checkpoint file was not found.
+    """
+    model_path = Path(model_path)
+
+    # check the model extension (should be a zip file).
+    if model_path.suffix != ".zip":
+        raise ValueError("Invalid model format. Expected bioimage model zip file.")
+
+    # load the model
+    rdf = load_resource_description(model_path)
+
+    # create a valid checkpoint file from weights and attached files
+    basedir = model_path.parent.joinpath("rdf_model")
+    basedir.mkdir(exist_ok=True)
+    optim_path = None
+    scheduler_path = None
+    grad_path = None
+    config_path = None
+    weight_path = None
+
+    if rdf.weights.get(PYTORCH_STATE_DICT) is not None:
+        weight_path = rdf.weights.get(PYTORCH_STATE_DICT).source
+
+    for file in rdf.attachments.files:
+        if file.name.endswith("optim.pth"):
+            optim_path = file
+        elif file.name.endswith("scheduler.pth"):
+            scheduler_path = file
+        elif file.name.endswith("grad.pth"):
+            grad_path = file
+        elif file.name.endswith("config.pth"):
+            config_path = file
+
+    if (
+        weight_path is None
+        or optim_path is None
+        or scheduler_path is None
+        or grad_path is None
+        or config_path is None
+    ):
+        raise FileNotFoundError(f"No valid checkpoint file was found in {model_path}.")
+
+    checkpoint = {
+        "model_state_dict": torch.load(weight_path, map_location="cpu"),
+        "optimizer_state_dict": torch.load(optim_path, map_location="cpu"),
+        "scheduler_state_dict": torch.load(scheduler_path, map_location="cpu"),
+        "grad_scaler_state_dict": torch.load(grad_path, map_location="cpu"),
+        "config": torch.load(config_path, map_location="cpu"),
+    }
+    checkpoint_path = basedir.joinpath("checkpoint.pth")
+    torch.save(checkpoint, checkpoint_path)
+
+    return checkpoint_path
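To make the intended flow of the new module concrete, here is a minimal sketch of an export/import round trip. The `config` argument stands in for a careamics Configuration from a finished training run (with a `<experiment_name>_best.pth` checkpoint in its working directory), which is not shown in this diff; remaining RDF fields such as test inputs/outputs and axes are normally filled in by the caller before export.

# Hedged sketch of the round trip added in bioimage/io.py (illustrative only).
from pathlib import Path

from careamics.bioimage import (
    get_default_model_specs,
    import_bioimage_model,
    save_bioimage_model,  # replaces the former build_zip_model
)


def export_and_reload(config, mean: float, std: float, zip_path: str = "n2v.zip") -> Path:
    # `config` is assumed to be a careamics Configuration from a trained run.
    specs = get_default_model_specs(
        "Noise2Void", mean=mean, std=std, is_3D=config.algorithm.is_3D
    )
    save_bioimage_model(zip_path, config, specs)  # writes the bioimage.io zip
    return import_bioimage_model(zip_path)  # path to a rebuilt .pth checkpoint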
--- /dev/null
+++ careamics-0.1.0rc2/src/careamics/bioimage/rdf.py
@@ -0,0 +1,105 @@
+"""RDF related methods."""
+from pathlib import Path
+
+
+def _get_model_doc(name: str) -> str:
+    """
+    Return markdown documentation path for the provided model.
+
+    Parameters
+    ----------
+    name : str
+        Model's name.
+
+    Returns
+    -------
+    str
+        Path to the model's markdown documentation.
+
+    Raises
+    ------
+    FileNotFoundError
+        If the documentation file was not found.
+    """
+    doc = Path(__file__).parent.joinpath("docs").joinpath(f"{name}.md")
+    if doc.exists():
+        return str(doc.absolute())
+    else:
+        raise FileNotFoundError(f"Documentation for {name} was not found.")
+
+
+def get_default_model_specs(
+    name: str, mean: float, std: float, is_3D: bool = False
+) -> dict:
+    """
+    Return the default bioimage.io specs for the provided model's name.
+
+    Currently only supports `Noise2Void` model.
+
+    Parameters
+    ----------
+    name : str
+        Algorithm's name.
+    mean : float
+        Mean of the dataset.
+    std : float
+        Std of the dataset.
+    is_3D : bool, optional
+        Whether the model is 3D or not, by default False.
+
+    Returns
+    -------
+    dict
+        Model specs compatible with bioimage.io export.
+    """
+    rdf = {
+        "name": "Noise2Void",
+        "description": "Self-supervised denoising.",
+        "license": "BSD-3-Clause",
+        "authors": [
+            {"name": "Alexander Krull"},
+            {"name": "Tim-Oliver Buchholz"},
+            {"name": "Florian Jug"},
+        ],
+        "cite": [
+            {
+                "doi": "10.48550/arXiv.1811.10980",
+                "text": 'A. Krull, T.-O. Buchholz and F. Jug, "Noise2Void - Learning '
+                'Denoising From Single Noisy Images," 2019 IEEE/CVF '
+                "Conference on Computer Vision and Pattern Recognition "
+                "(CVPR), 2019, pp. 2124-2132",
+            }
+        ],
+        # "input_axes": ["bcyx"], <- overriden in save_as_bioimage
+        "preprocessing": [  # for multiple inputs
+            [  # multiple processes per input
+                {
+                    "kwargs": {
+                        "axes": "zyx" if is_3D else "yx",
+                        "mean": [mean],
+                        "mode": "fixed",
+                        "std": [std],
+                    },
+                    "name": "zero_mean_unit_variance",
+                }
+            ]
+        ],
+        # "output_axes": ["bcyx"], <- overriden in save_as_bioimage
+        "postprocessing": [  # for multiple outputs
+            [  # multiple processes per input
+                {
+                    "kwargs": {
+                        "axes": "zyx" if is_3D else "yx",
+                        "gain": [std],
+                        "offset": [mean],
+                    },
+                    "name": "scale_linear",
+                }
+            ]
+        ],
+        "tags": ["unet", "denoising", "Noise2Void", "tensorflow", "napari"],
+    }
+
+    rdf["documentation"] = _get_model_doc(name)
+
+    return rdf
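A short sketch of calling the new helper; the mean/std values below are placeholders for illustration, not numbers taken from this release.

from careamics.bioimage import get_default_model_specs

specs = get_default_model_specs("Noise2Void", mean=128.0, std=32.0, is_3D=False)
print(specs["name"])                           # "Noise2Void"
print(specs["preprocessing"][0][0]["kwargs"])  # axes "yx", fixed mean/std
specs["tags"].append("fluorescence")           # fields can be adjusted before export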
--- careamics-0.1.0rc1/src/careamics/config/config.py
+++ careamics-0.1.0rc2/src/careamics/config/config.py
@@ -13,10 +13,11 @@ from pydantic import (
     model_validator,
 )

-from .algorithm import Algorithm
+# ignore typing-only-first-party-import in this file (flake8)
+from .algorithm import Algorithm  # noqa: TCH001
 from .config_filter import paths_to_str
-from .data import Data
-from .training import Training
+from .data import Data  # noqa: TCH001
+from .training import Training  # noqa: TCH001


 class Configuration(BaseModel):
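For context, the `# noqa: TCH001` markers above opt out of ruff's newly enabled flake8-type-checking rule, which would otherwise suggest the pattern sketched below; the runtime imports are kept because pydantic needs these classes available when the model is built. The snippet is illustrative, not part of the package.

# Pattern suggested by TCH001 (typing-only first-party import), for comparison.
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated only by static type checkers, never at runtime
    from .algorithm import Algorithm
    from .data import Data
    from .training import Training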
--- careamics-0.1.0rc1/src/careamics/config/data.py
+++ careamics-0.1.0rc2/src/careamics/config/data.py
@@ -12,7 +12,7 @@ from pydantic import (
     model_validator,
 )

-from ..utils import check_axes_validity
+from careamics.utils import check_axes_validity


 class SupportedExtension(str, Enum):
@@ -23,7 +23,7 @@ class SupportedExtension(str, Enum):
     - tif/tiff: .tiff files.
     """

-    TIFF = "tiff"
+    TIFF = "tiff"  # TODO these should be a single one
     TIF = "tif"

     @classmethod
--- careamics-0.1.0rc1/src/careamics/dataset/dataset_utils.py
+++ careamics-0.1.0rc2/src/careamics/dataset/dataset_utils.py
@@ -46,15 +46,11 @@ def _update_axes(array: np.ndarray, axes: str) -> np.ndarray:
         Updated array.
     """
     # concatenate ST axes to N, return NCZYX
-    if ("S" in axes or "T" in axes) and array.dtype != "O":
+    if "S" in axes or "T" in axes:
        new_axes_len = len(axes.replace("Z", "").replace("YX", ""))
        # TODO test reshape as it can scramble data, moveaxis is probably better
        array = array.reshape(-1, *array.shape[new_axes_len:]).astype(np.float32)

-    elif array.dtype == "O":
-        for i in range(len(array)):
-            array[i] = np.expand_dims(array[i], axis=0).astype(np.float32)
-
     else:
         array = np.expand_dims(array, axis=0).astype(np.float32)

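To illustrate the branch kept above, a small self-contained example of how S and T axes are collapsed into a single leading sample dimension; the shapes are made up for the illustration.

import numpy as np

axes = "STYX"                                   # sample and time axes present
array = np.zeros((2, 3, 64, 64), dtype=np.uint16)

# number of leading axes once Z and YX are stripped from the axes string
new_axes_len = len(axes.replace("Z", "").replace("YX", ""))  # -> 2
reshaped = array.reshape(-1, *array.shape[new_axes_len:]).astype(np.float32)
print(reshaped.shape)                           # (6, 64, 64)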
--- careamics-0.1.0rc1/src/careamics/dataset/in_memory_dataset.py
+++ careamics-0.1.0rc2/src/careamics/dataset/in_memory_dataset.py
@@ -5,8 +5,9 @@ from typing import Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch

-from ..utils import normalize
-from ..utils.logging import get_logger
+from careamics.utils import normalize
+from careamics.utils.logging import get_logger
+
 from .dataset_utils import (
     list_files,
     read_tiff,
--- careamics-0.1.0rc1/src/careamics/dataset/patching.py
+++ careamics-0.1.0rc2/src/careamics/dataset/patching.py
@@ -9,7 +9,8 @@ from typing import Generator, List, Optional, Tuple, Union
 import numpy as np
 from skimage.util import view_as_windows

-from ..utils.logging import get_logger
+from careamics.utils.logging import get_logger
+
 from .extraction_strategy import ExtractionStrategy

 logger = get_logger(__name__)
@@ -481,13 +482,11 @@ def generate_patches(
         elif patch_extraction_method == ExtractionStrategy.SEQUENTIAL:
             patches = _extract_patches_sequential(sample, patch_size=patch_size)

-        elif patch_extraction_method == ExtractionStrategy.RANDOM:
+        else:
+            # random patching
             patches = _extract_patches_random(sample, patch_size=patch_size)

-            if patches is None:
-                raise ValueError("No patch generated")
-
         return patches
     else:
-        # no patching
+        # no patching, return a generator for the sample
         return (sample for _ in range(1))
--- careamics-0.1.0rc1/src/careamics/dataset/prepare_dataset.py
+++ careamics-0.1.0rc2/src/careamics/dataset/prepare_dataset.py
@@ -6,9 +6,10 @@ Methods to set up the datasets for training, validation and prediction.
 from pathlib import Path
 from typing import List, Optional, Union

-from ..config import Configuration
-from ..manipulation import default_manipulate
-from ..utils import check_tiling_validity
+from careamics.config import Configuration
+from careamics.manipulation import default_manipulate
+from careamics.utils import check_tiling_validity
+
 from .extraction_strategy import ExtractionStrategy
 from .in_memory_dataset import InMemoryDataset
 from .tiff_dataset import TiffDataset
--- careamics-0.1.0rc1/src/careamics/dataset/tiff_dataset.py
+++ careamics-0.1.0rc2/src/careamics/dataset/tiff_dataset.py
@@ -10,8 +10,9 @@ from typing import Callable, Dict, Generator, List, Optional, Tuple, Union
 import numpy as np
 import torch

-from ..utils import normalize
-from ..utils.logging import get_logger
+from careamics.utils import normalize
+from careamics.utils.logging import get_logger
+
 from .dataset_utils import (
     list_files,
     read_tiff,
@@ -53,7 +54,7 @@ class TiffDataset(torch.utils.data.IterableDataset):
     def __init__(
         self,
         data_path: Union[str, Path],
-        data_format: str,
+        data_format: str,  # TODO: TiffDataset should not know that they are tiff
         axes: str,
         patch_extraction_method: Union[ExtractionStrategy, None],
         patch_size: Optional[Union[List[int], Tuple[int]]] = None,