predtiler-0.0.1.tar.gz

@@ -0,0 +1,39 @@
+ ### Description
+
+ Please provide a brief description of the changes in this PR. Include any relevant context or background information.
+
+ - **What**: Clearly and concisely describe what changes you have made.
+ - **Why**: Explain the reasoning behind these changes.
+ - **How**: Describe how you implemented these changes.
+
+ ### Changes Made
+
+ - **Added**: List new features or files added.
+ - **Modified**: Describe existing features or files modified.
+ - **Removed**: Detail features or files that were removed.
+
+ ### Related Issues
+
+ Link to any related issues or discussions. Use keywords like "Fixes", "Resolves", or "Closes" to link to issues automatically.
+
+ - Fixes #
+ - Resolves #
+ - Closes #
+
+ ### Breaking changes
+
+ Describe any breaking changes.
+
+
+ ### Additional Notes and Examples
+
+ Include any additional notes or context that reviewers should be aware of, including snippets of code illustrating your new feature.
+
+ ---
+
+ **Please ensure your PR meets the following requirements:**
+
+ - [ ] Code builds and passes tests locally, including doctests
+ - [ ] New tests have been added (for bug fixes/features)
+ - [ ] Pre-commit passes
+ - [ ] PR to the documentation exists (for bug fixes/features)
@@ -0,0 +1,91 @@
+ name: CI
+
+ on:
+   push:
+     branches:
+       - main
+     tags:
+       - "v*"
+   pull_request:
+   workflow_dispatch:
+   schedule:
+     # run every week (for --pre release tests)
+     - cron: "0 0 * * 0"
+
+ jobs:
+   check-manifest:
+     # check-manifest is a tool that checks that all files in version control are
+     # included in the sdist (unless explicitly excluded)
+     runs-on: ubuntu-latest
+     steps:
+       - uses: actions/checkout@v3
+       - run: pipx run check-manifest
+
+   test:
+     name: ${{ matrix.platform }} (${{ matrix.python-version }})
+     runs-on: ${{ matrix.platform }}
+     strategy:
+       fail-fast: false
+       matrix:
+         python-version: ["3.9", "3.10", "3.11", "3.12"]
+         # https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
+         platform: [ubuntu-latest, macos-13, windows-latest]
+
+     steps:
+       - name: 🛑 Cancel Previous Runs
+         uses: styfle/cancel-workflow-action@0.11.0
+         with:
+           access_token: ${{ github.token }}
+
+       - uses: actions/checkout@v3
+
+       - name: 🐍 Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python-version }}
+           cache-dependency-path: "pyproject.toml"
+           cache: "pip"
+
+       - name: Install Dependencies
+         run: |
+           python -m pip install -U pip
+           # if running a cron job, we add the --pre flag to test against pre-releases
+           python -m pip install ".[dev]" ${{ github.event_name == 'schedule' && '--pre' || '' }}
+
+       - name: 🧪 Run Tests
+         run: pytest
+
+   deploy:
+     name: Release
+     needs: test
+     if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule'
+     runs-on: ubuntu-latest
+
+     permissions:
+       # IMPORTANT: this permission is mandatory for trusted publishing
+       id-token: write
+       # This permission allows writing releases
+       contents: write
+
+     steps:
+       - uses: actions/checkout@v4
+         with:
+           fetch-depth: 0
+
+       - name: Set up Python
+         uses: actions/setup-python@v5
+         with:
+           python-version: "3.9"
+
+       - name: Build
+         run: |
+           python -m pip install build
+           python -m build
+
+       - name: Publish to PyPI
+         uses: pypa/gh-action-pypi-publish@release/v1
+
+       - uses: softprops/action-gh-release@v2
+         with:
+           generate_release_notes: true
@@ -0,0 +1,48 @@
+ name: Coverage
+
+ on:
+   push:
+     branches:
+       - main
+     tags:
+       - "v*"
+   pull_request:
+
+ jobs:
+   test:
+     name: ${{ matrix.platform }} (${{ matrix.python-version }})
+     runs-on: ${{ matrix.platform }}
+     strategy:
+       fail-fast: false
+       matrix:
+         python-version: ["3.10"]
+         platform: [ubuntu-latest]
+
+     steps:
+       - name: 🛑 Cancel Previous Runs
+         uses: styfle/cancel-workflow-action@0.11.0
+         with:
+           access_token: ${{ github.token }}
+
+       - uses: actions/checkout@v3
+
+       - name: 🐍 Set up Python ${{ matrix.python-version }}
+         uses: actions/setup-python@v4
+         with:
+           python-version: ${{ matrix.python-version }}
+           cache-dependency-path: "pyproject.toml"
+           cache: "pip"
+
+       - name: Install Dependencies
+         run: |
+           python -m pip install -U pip
+           python -m pip install -e ".[dev]"
+
+       - name: 🧪 Run Tests
+         run: pytest --color=yes --cov --cov-config=pyproject.toml --cov-report=xml --cov-report=term-missing
+
+       - name: Coverage
+         uses: codecov/codecov-action@v3
+         with:
+           version: v0.7.3
@@ -0,0 +1,168 @@
+ # VSCode
+ .vscode
+
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/latest/usage/project/#working-with-version-control
+ .pdm.toml
+ .pdm-python
+ .pdm-build/
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
+ # Ruff
+ .ruff_cache
@@ -0,0 +1,21 @@
+ MIT License
+
+ Copyright (c) 2024 ashesh
+
+ Permission is hereby granted, free of charge, to any person obtaining a copy
+ of this software and associated documentation files (the "Software"), to deal
+ in the Software without restriction, including without limitation the rights
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ copies of the Software, and to permit persons to whom the Software is
+ furnished to do so, subject to the following conditions:
+
+ The above copyright notice and this permission notice shall be included in all
+ copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ SOFTWARE.
@@ -0,0 +1,122 @@
+ Metadata-Version: 2.3
+ Name: predtiler
+ Version: 0.0.1
+ Summary: Convert your dataset class into a class that can be used for tiled prediction and eventually obtain a stitched prediction.
+ Project-URL: homepage, https://github.com/ashesh-0/PredTiler
+ Project-URL: repository, https://github.com/ashesh-0/PredTiler
+ Author: Ashesh
+ License: MIT
+ License-File: LICENSE
+ Classifier: Development Status :: 3 - Alpha
+ Classifier: License :: OSI Approved :: MIT License
+ Classifier: Programming Language :: Python :: 3
+ Classifier: Programming Language :: Python :: 3.9
+ Classifier: Programming Language :: Python :: 3.10
+ Classifier: Programming Language :: Python :: 3.11
+ Classifier: Programming Language :: Python :: 3.12
+ Classifier: Typing :: Typed
+ Requires-Python: >=3.9
+ Requires-Dist: numpy
+ Provides-Extra: dev
+ Requires-Dist: pre-commit; extra == 'dev'
+ Requires-Dist: pytest; extra == 'dev'
+ Requires-Dist: pytest-cov; extra == 'dev'
+ Requires-Dist: sybil; extra == 'dev'
+ Provides-Extra: examples
+ Requires-Dist: jupyter; extra == 'examples'
+ Requires-Dist: matplotlib; extra == 'examples'
+ Description-Content-Type: text/markdown
+
+ A lean wrapper around your dataset class to enable tiled prediction.
+
+ [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/ashesh-0/PredTiler/blob/main/LICENSE)
+ [![CI](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml/badge.svg)](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml)
+ [![codecov](https://codecov.io/gh/ashesh-0/PredTiler/graph/badge.svg?token=M655MOS7EL)](https://codecov.io/gh/ashesh-0/PredTiler)
+
+ ## Objective
+ This package subclasses the dataset class you use to train your network.
+ With PredTiler, you can use your dataset class as is, and PredTiler will take care of the tiling logic for you.
+ It automatically generates patches in such a way that they can be tiled with an overlap of `(patch_size - tile_size)//2`.
+ We also provide a function to stitch the tiles back together to get the final prediction.
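+
+ For example, with the sizes used later in this README, the overlap works out as follows (a quick arithmetic check, not an API call):
+
+ ```python
+ patch_size = 256
+ tile_size = 128
+ overlap = (patch_size - tile_size) // 2
+ print(overlap)  # 64 pixels on each side of a tile
+ ```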
+
+ In case you are facing issues, feel free to raise an issue and I will be happy to help you out!
+ In the future, I plan to add detailed instructions for:
+ 1. multi-channel data
+ 2. 3D data
+ 3. data provided as a list of numpy arrays, each possibly having a different shape.
+
+ ## Installation
+
+ ```bash
+ pip install predtiler
+ ```
+
+ ## Usage
+ To work with PredTiler, the only requirement is that your dataset class must have a **patch_location(self, index)** method that returns the location of the patch at the given index.
+ Your dataset class should use only the location information returned by this method to produce the patch.
+ PredTiler will override this method to return the locations of the patches needed for tiled prediction.
+
+ Note that your dataset class can be arbitrarily complex (augmentations, returning multiple patches, working with 3D data, etc.). The only requirement is that it uses the crop present at the location returned by the **patch_location** method. Below is an example of a simple dataset class that can be used with PredTiler.
+
+ ```python
+ from typing import Tuple
+
+ import numpy as np
+
+ class YourDataset:
+     def __init__(self, data_path, patch_size=64) -> None:
+         self.patch_size = patch_size
+         self.data = load_data(data_path)  # your own loading routine; shape: (N, H, W, C)
+
+     def patch_location(self, index: int) -> Tuple[int, int, int]:
+         # during training, the index is ignored and a random location is returned
+         n_idx = np.random.randint(0, len(self.data))
+         h = np.random.randint(0, self.data.shape[1] - self.patch_size)
+         w = np.random.randint(0, self.data.shape[2] - self.patch_size)
+         return (n_idx, h, w)
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, index):
+         n_idx, h, w = self.patch_location(index)
+         # return the patch at that location, of shape (patch_size, patch_size)
+         return self.data[n_idx, h:h + self.patch_size, w:w + self.patch_size]
+ ```
+
+ ## Getting overlapping patches needed for tiled prediction
+ To use PredTiler, we need a new class that wraps around your dataset class.
+ For this, we also need a tile manager that keeps track of the tiles.
+
+ ```python
+ from predtiler.dataset import get_tiling_dataset, get_tile_manager
+
+ patch_size = 256
+ tile_size = 128
+ data_shape = (10, 2048, 2048)  # shape of the data you are working with
+ manager = get_tile_manager(data_shape=data_shape, tile_shape=(1, tile_size, tile_size),
+                            patch_shape=(1, patch_size, patch_size))
+
+ dset_class = get_tiling_dataset(YourDataset, manager)
+ ```
+
+ At this point, you can use `dset_class` just as you would use the `YourDataset` class.
+
+ ```python
+ data_path = ...  # path to your data
+ dset = dset_class(data_path, patch_size=patch_size)
+ ```
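+
+ Note that the length of the wrapped dataset now comes from the tile manager (one index per tile), not from your own sampling logic:
+
+ ```python
+ # Sanity check: one dataset index per tile position.
+ assert len(dset) == manager.total_grid_count()
+ ```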
+
+ ## Stitching the predictions
+ The benefit of using PredTiler is that it automatically generates the patches in such a way that they can be tiled with an overlap of `(patch_size - tile_size)//2`. This allows you to use your dataset class as is, without worrying about the tiling logic.
+
+ ```python
+ import numpy as np
+ import torch
+
+ from predtiler.tile_stitcher import stitch_predictions
+
+ model = ...  # your model
+ predictions = []
+ for i in range(len(dset)):
+     inp = dset[i]
+     inp = torch.Tensor(inp)[None, None]  # add batch and channel dimensions
+     pred = model(inp)
+     predictions.append(pred[0].numpy())
+
+ predictions = np.stack(predictions)  # shape: (number_of_patches, C, patch_size, patch_size)
+ stitched_pred = stitch_predictions(predictions, dset.tile_manager)
+ ```
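+
+ The stitched output has the spatial shape you passed as `data_shape`, with the channel dimension appended as the last axis. As a quick check (assuming the single-channel model output produced by the loop above):
+
+ ```python
+ print(stitched_pred.shape)  # e.g. (10, 2048, 2048, 1)
+ ```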
@@ -0,0 +1,93 @@
+ A lean wrapper around your dataset class to enable tiled prediction.
+
+ [![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/ashesh-0/PredTiler/blob/main/LICENSE)
+ [![CI](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml/badge.svg)](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml)
+ [![codecov](https://codecov.io/gh/ashesh-0/PredTiler/graph/badge.svg?token=M655MOS7EL)](https://codecov.io/gh/ashesh-0/PredTiler)
+
+ ## Objective
+ This package subclasses the dataset class you use to train your network.
+ With PredTiler, you can use your dataset class as is, and PredTiler will take care of the tiling logic for you.
+ It automatically generates patches in such a way that they can be tiled with an overlap of `(patch_size - tile_size)//2`.
+ We also provide a function to stitch the tiles back together to get the final prediction.
+
13
+ In case you are facing issues, feel free to raise an issue and I will be happy to help you out !
14
+ In future, I plan to add detailed instructions for:
15
+ 1. multi-channel data
16
+ 2. 3D data
17
+ 3. Data being a list of numpy arrays, each poissibly having different shapes.
18
+
19
+ ## Installation
20
+
21
+ ```bash
22
+ pip install predtiler
23
+ ```
24
+
25
+ ## Usage
26
+ To work with PredTiler, the only requirement is that your dataset class must have a **patch_location(self, index)** method that returns the location of the patch at the given index.
27
+ Your dataset class should only use the location information returned by this method to return the patch.
28
+ PredTiler will override this method to return the location of the patches needed for tiled prediction.
29
+
30
+ Note that your dataset class could be arbitrarily complex (augmentations, returning multiple patches, working with 3D data, etc.). The only requirement is that it should use the crop present at the location returned by **patch_location** method. Below is an example of a simple dataset class that can be used with PredTiler.
31
+
32
+ ```python
33
+ class YourDataset:
34
+ def __init__(self, data_path, patch_size=64) -> None:
35
+ self.patch_size = patch_size
36
+ self.data = load_data(data_path) # shape: (N, H, W, C)
37
+
38
+ def patch_location(self, index:int)-> Tuple[int, int, int]:
39
+ # it just ignores the index and returns a random location
40
+ n_idx = np.random.randint(0,len(self.data))
41
+ h = np.random.randint(0, self.data.shape[1]-self.patch_size)
42
+ w = np.random.randint(0, self.data.shape[2]-self.patch_size)
43
+ return (n_idx, h, w)
44
+
45
+ def __len__(self):
46
+ return len(self.data)
47
+
48
+ def __getitem__(self, index):
49
+ n_idx, h, w = self.patch_location(index)
50
+ # return the patch at the location (patch_size, patch_size)
51
+ return self.data[n_idx, h:h+self.patch_size, w:w+self.patch_size]
52
+ ```
53
+
54
+ ## Getting overlapping patches needed for tiled prediction
55
+ To use PredTiler, we need to get a new class that wraps around your dataset class.
56
+ For this we also need a tile manager that will manage the tiles.
57
+
58
+ ```python
59
+
60
+ from predtiler.dataset import get_tiling_dataset, get_tile_manager
61
+ patch_size = 256
62
+ tile_size = 128
63
+ data_shape = (10, 2048, 2048) # size of the data you are working with
64
+ manager = get_tile_manager(data_shape=data_shape, tile_shape=(1,tile_size,tile_size),
65
+ patch_shape=(1,patch_size,patch_size))
66
+
67
+ dset_class = get_tiling_dataset(YourDataset, manager)
68
+ ```
69
+
70
+ At this point, you can use the `dset_class` as you would use `YourDataset` class.
71
+
72
+ ```python
73
+ data_path = ... # path to your data
74
+ dset = dset_class(data_path, patch_size=patch_size)
75
+ ```
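+
+ Note that the length of the wrapped dataset now comes from the tile manager (one index per tile), not from your own sampling logic:
+
+ ```python
+ # Sanity check: one dataset index per tile position.
+ assert len(dset) == manager.total_grid_count()
+ ```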
+
+ ## Stitching the predictions
+ The benefit of using PredTiler is that it automatically generates the patches in such a way that they can be tiled with an overlap of `(patch_size - tile_size)//2`. This allows you to use your dataset class as is, without worrying about the tiling logic.
+
+ ```python
+ import numpy as np
+ import torch
+
+ from predtiler.tile_stitcher import stitch_predictions
+
+ model = ...  # your model
+ predictions = []
+ for i in range(len(dset)):
+     inp = dset[i]
+     inp = torch.Tensor(inp)[None, None]  # add batch and channel dimensions
+     pred = model(inp)
+     predictions.append(pred[0].numpy())
+
+ predictions = np.stack(predictions)  # shape: (number_of_patches, C, patch_size, patch_size)
+ stitched_pred = stitch_predictions(predictions, dset.tile_manager)
+ ```
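+
+ The stitched output has the spatial shape you passed as `data_shape`, with the channel dimension appended as the last axis. As a quick check (assuming the single-channel model output produced by the loop above):
+
+ ```python
+ print(stitched_pred.shape)  # e.g. (10, 2048, 2048, 1)
+ ```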
@@ -0,0 +1,165 @@
+ # https://peps.python.org/pep-0517/
+ [build-system]
+ requires = ["hatchling", "hatch-vcs"]
+ build-backend = "hatchling.build"
+ # read more about configuring hatch at:
+ # https://hatch.pypa.io/latest/config/build/
+
+ # https://hatch.pypa.io/latest/config/metadata/
+ [tool.hatch.version]
+ source = "vcs" # use tags for versioning (e.g. tag v0.1.0, v0.2.0 etc.)
+
+ [tool.hatch.build.targets.wheel]
+ only-include = ["src"]
+ sources = ["src"]
+
+ # https://peps.python.org/pep-0621/
+ [project]
+ name = "predtiler"
+ version = "0.0.1"
+ description = "Convert your dataset class into a class that can be used for tiled prediction and eventually obtain a stitched prediction."
+ readme = "README.md"
+ requires-python = ">=3.9"
+ license = { text = "MIT" }
+ authors = [
+     { name = 'Ashesh' },
+ ]
+ classifiers = [
+     "Development Status :: 3 - Alpha",
+     "Programming Language :: Python :: 3",
+     "Programming Language :: Python :: 3.9",
+     "Programming Language :: Python :: 3.10",
+     "Programming Language :: Python :: 3.11",
+     "Programming Language :: Python :: 3.12",
+     "License :: OSI Approved :: MIT License",
+     "Typing :: Typed",
+ ]
+ dependencies = [
+     "numpy",
+ ]
+
+ [project.optional-dependencies]
+ # development dependencies and tooling
+ dev = [
+     "pre-commit",
+     "pytest",
+     "pytest-cov",
+     "sybil", # doctesting
+ ]
+
+ # examples
+ examples = ["jupyter", "matplotlib"]
+
+ [project.urls]
+ homepage = "https://github.com/ashesh-0/PredTiler"
+ repository = "https://github.com/ashesh-0/PredTiler"
+
+ # https://beta.ruff.rs/docs
+ [tool.ruff]
+ line-length = 88
+ target-version = "py39"
+ src = ["src"]
+ lint.select = [
+     "E",    # style errors
+     "W",    # style warnings
+     "F",    # flakes
+     "D",    # pydocstyle
+     "I",    # isort
+     "UP",   # pyupgrade
+     # "S",  # bandit
+     "C4",   # flake8-comprehensions
+     "B",    # flake8-bugbear
+     "A001", # flake8-builtins
+     "RUF",  # ruff-specific rules
+ ]
+ lint.ignore = [
+     "D100", # Missing docstring in public module
+     "D107", # Missing docstring in __init__
+     "D203", # 1 blank line required before class docstring
+     "D212", # Multi-line docstring summary should start at the first line
+     "D213", # Multi-line docstring summary should start at the second line
+     "D401", # First line should be in imperative mood
+     "D413", # Missing blank line after last section
+     "D416", # Section name should end with a colon
+
+     # incompatibility with mypy
+     "RUF005", # collection-literal-concatenation, in prediction_utils.py:30
+
+     # version specific
+     "UP007", # Replace Union by |, mandatory for py3.9
+ ]
+ show-fixes = true
+
+ [tool.ruff.lint.pydocstyle]
+ convention = "numpy"
+
+ [tool.ruff.lint.per-file-ignores]
+ "tests/*.py" = ["D", "S"]
+ "setup.py" = ["D"]
+
+ [tool.black]
+ line-length = 88
+
+ # https://mypy.readthedocs.io/en/stable/config_file.html
+ [tool.mypy]
+ files = "src/**/"
+ strict = false
+ # allow_untyped_defs = false
+ # allow_untyped_calls = false
+ # disallow_any_generics = false
+ # ignore_missing_imports = false
+
+
+ # https://docs.pytest.org/en/6.2.x/customize.html
+ [tool.pytest.ini_options]
+ minversion = "6.0"
+ testpaths = ["tests"] # add src/predtiler for doctest discovery
+ filterwarnings = [
+     # "error",
+     # "ignore::UserWarning",
+ ]
+ addopts = "-p no:doctest"
+
+
+ # https://coverage.readthedocs.io/en/6.4/config.html
+ [tool.coverage.report]
+ exclude_lines = [
+     "pragma: no cover",
+     "if TYPE_CHECKING:",
+     "@overload",
+     "except ImportError",
+     "\\.\\.\\.",
+     "raise NotImplementedError()",
+ ]
+
+
+ [tool.coverage.run]
+ source = ["src/predtiler"]
+
+ # https://github.com/mgedmin/check-manifest#configuration
+ # add files that you want check-manifest to explicitly ignore here
+ # (files that are in the repo but shouldn't go in the package)
+ [tool.check-manifest]
+ ignore = [
+     ".github_changelog_generator",
+     ".pre-commit-config.yaml",
+     ".ruff_cache/**/*",
+     "setup.py",
+     "tests/**/*",
+ ]
+
+ [tool.numpydoc_validation]
+ checks = [
+     "all",  # report on all checks, except the below
+     "EX01", # Example section not found
+     "SA01", # See Also section not found
+     "ES01", # Extended Summary not found
+     "GL01", # Docstring text (summary) should start in the line immediately
+             # after the opening quotes
+     "GL02", # Closing quotes should be placed in the line after the last text
+             # in the docstring
+     "GL03", # Double line break found
+ ]
+ exclude = [ # don't report on objects that match any of these regex
+     "test_*",
+ ]
@@ -0,0 +1,53 @@
+ from predtiler.tile_manager import TileIndexManager, TilingMode
+
+ # class TilingDataset:
+ #     def __init_subclass__(cls, parent_class=None, tile_manager=None, **kwargs):
+ #         super().__init_subclass__(**kwargs)
+ #         assert tile_manager is not None, 'tile_manager must be provided'
+ #         cls.tile_manager = tile_manager
+ #         if parent_class is not None:
+ #             has_callable_method = callable(getattr(parent_class, 'patch_location', None))
+ #             assert has_callable_method, f'{parent_class.__name__} must have a callable method with the following signature: def patch_location(self, index)'
+ #             cls.__bases__ = (parent_class,) + cls.__bases__
+
+ #     def __len__(self):
+ #         return self.tile_manager.total_grid_count()
+
+ #     def patch_location(self, index):
+ #         print('Calling patch_location')
+ #         patch_loc_list = self.tile_manager.get_patch_location_from_dataset_idx(index)
+ #         return patch_loc_list
+
+
+ # def get_tiling_dataset(dataset_class, tile_manager) -> type:
+ #     class CorrespondingTilingDataset(TilingDataset, parent_class=dataset_class, tile_manager=tile_manager):
+ #         pass
+
+ #     return CorrespondingTilingDataset
+
+
+ def get_tiling_dataset(dataset_class, tile_manager) -> type:
+     """Wrap `dataset_class` so that it yields the patches needed for tiled prediction."""
+     has_callable_method = callable(getattr(dataset_class, 'patch_location', None))
+     assert has_callable_method, f'{dataset_class.__name__} must have a callable method with the following signature: def patch_location(self, index)'
+
+     class TilingDataset(dataset_class):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+             self.tile_manager = tile_manager
+
+         def __len__(self):
+             return self.tile_manager.total_grid_count()
+
+         def patch_location(self, index):
+             patch_loc_list = self.tile_manager.get_patch_location_from_dataset_idx(index)
+             return patch_loc_list
+
+     return TilingDataset
+
+
+ def get_tile_manager(data_shape, tile_shape, patch_shape, tiling_mode=TilingMode.ShiftBoundary):
+     """Return a TileIndexManager for the given data, tile and patch shapes."""
+     return TileIndexManager(data_shape, tile_shape, patch_shape, tiling_mode)
@@ -0,0 +1,210 @@
+ from dataclasses import dataclass
+
+ import numpy as np
+
+
+ class TilingMode:
+     """
+     Enum-like holder for the tiling modes.
+     """
+     TrimBoundary = 0
+     PadBoundary = 1
+     ShiftBoundary = 2
+
+
+ @dataclass
+ class TileIndexManager:
+     data_shape: tuple
+     grid_shape: tuple
+     patch_shape: tuple
+     tiling_mode: TilingMode
+
+     def __post_init__(self):
+         assert len(self.data_shape) == len(self.grid_shape), f"Data shape:{self.data_shape} and grid shape:{self.grid_shape} must have the same dimension"
+         assert len(self.data_shape) == len(self.patch_shape), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
+         innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
+         for dim, pad in enumerate(innerpad):
+             if pad < 0:
+                 raise ValueError(f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}")
+             if pad % 2 != 0:
+                 raise ValueError(f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}")
+
+     def patch_offset(self):
+         return (np.array(self.patch_shape) - np.array(self.grid_shape)) // 2
+
+     def get_individual_dim_grid_count(self, dim: int):
+         """
+         Returns the number of grids in the specified dimension, ignoring all other dimensions.
+         """
+         assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return self.data_shape[dim]
+         elif self.tiling_mode == TilingMode.PadBoundary:
+             return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
+         elif self.tiling_mode == TilingMode.ShiftBoundary:
+             excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+             return int(np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim]))
+         else:
+             excess_size = self.patch_shape[dim] - self.grid_shape[dim]
+             return int(np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim]))
+
+     def total_grid_count(self):
+         """
+         Returns the total number of grids in the dataset.
+         """
+         return self.grid_count(0) * self.get_individual_dim_grid_count(0)
+
+     def grid_count(self, dim: int):
+         """
+         Returns the total number of grids for one value in the specified dimension.
+         """
+         assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         if dim == len(self.data_shape) - 1:
+             return 1
+
+         return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)
+
+     def get_grid_index(self, dim: int, coordinate: int):
+         """
+         Returns the index of the grid in the specified dimension.
+         """
+         assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         assert coordinate < self.data_shape[dim], f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return coordinate
+         elif self.tiling_mode == TilingMode.PadBoundary:
+             return np.floor(coordinate / self.grid_shape[dim])
+         elif self.tiling_mode == TilingMode.TrimBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             # can be <0 if coordinate is in [0, grid_shape[dim]]
+             return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
+         elif self.tiling_mode == TilingMode.ShiftBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
+                 return self.get_individual_dim_grid_count(dim) - 1
+             else:
+                 # can be <0 if coordinate is in [0, grid_shape[dim]]
+                 return max(0, np.floor((coordinate - excess_size) / self.grid_shape[dim]))
+         else:
+             raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
+
+     def dataset_idx_from_grid_idx(self, grid_idx: tuple):
+         """
+         Returns the dataset index of the grid with the given per-dimension indices.
+         """
+         assert len(grid_idx) == len(self.data_shape), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
+         index = 0
+         for dim in range(len(grid_idx)):
+             index += grid_idx[dim] * self.grid_count(dim)
+         return index
+
+     def get_patch_location_from_dataset_idx(self, dataset_idx: int):
+         """
+         Returns the patch start location for the grid at the given dataset index.
+         """
+         grid_location = self.get_location_from_dataset_idx(dataset_idx)
+         offset = self.patch_offset()
+         return tuple(np.array(grid_location) - np.array(offset))
+
+     def get_dataset_idx_from_grid_location(self, location: tuple):
+         assert len(location) == len(self.data_shape), f"Location {location} must have the same dimension as data shape {self.data_shape}"
+         grid_idx = [self.get_grid_index(dim, location[dim]) for dim in range(len(location))]
+         return self.dataset_idx_from_grid_idx(tuple(grid_idx))
+
+     def get_gridstart_location_from_dim_index(self, dim: int, dim_index: int):
+         """
+         Returns the grid-start coordinate of the grid in the specified dimension.
+         """
+         assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         assert dim_index < self.get_individual_dim_grid_count(dim), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"
+
+         if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
+             return dim_index
+         elif self.tiling_mode == TilingMode.PadBoundary:
+             return dim_index * self.grid_shape[dim]
+         elif self.tiling_mode == TilingMode.TrimBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             return dim_index * self.grid_shape[dim] + excess_size
+         elif self.tiling_mode == TilingMode.ShiftBoundary:
+             excess_size = (self.patch_shape[dim] - self.grid_shape[dim]) // 2
+             if dim_index < self.get_individual_dim_grid_count(dim) - 1:
+                 return dim_index * self.grid_shape[dim] + excess_size
+             else:
+                 # on the boundary: the grid is placed such that the patch covers the end of the data.
+                 return self.data_shape[dim] - self.grid_shape[dim] - excess_size
+         else:
+             raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")
+
+     def get_location_from_dataset_idx(self, dataset_idx: int):
+         """
+         Returns the grid-start location for the given dataset index.
+         """
+         grid_idx = []
+         for dim in range(len(self.data_shape)):
+             grid_idx.append(dataset_idx // self.grid_count(dim))
+             dataset_idx = dataset_idx % self.grid_count(dim)
+         location = [self.get_gridstart_location_from_dim_index(dim, grid_idx[dim]) for dim in range(len(self.data_shape))]
+         return tuple(location)
+
+     def on_boundary(self, dataset_idx: int, dim: int, only_end: bool = False):
+         """
+         Returns True if the grid is on the boundary in the specified dimension.
+         """
+         assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+
+         if dim > 0:
+             dataset_idx = dataset_idx % self.grid_count(dim - 1)
+
+         dim_index = dataset_idx // self.grid_count(dim)
+         if only_end:
+             return dim_index == self.get_individual_dim_grid_count(dim) - 1
+
+         return dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1
+
+     def next_grid_along_dim(self, dataset_idx: int, dim: int):
+         """
+         Returns the dataset index of the next grid along the specified dimension, or None if it falls outside the data.
+         """
+         assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         new_idx = dataset_idx + self.grid_count(dim)
+         if new_idx >= self.total_grid_count():
+             return None
+         return new_idx
+
+     def prev_grid_along_dim(self, dataset_idx: int, dim: int):
+         """
+         Returns the dataset index of the previous grid along the specified dimension, or None if it falls outside the data.
+         """
+         assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
+         assert dim >= 0, "Dimension must be greater than or equal to 0"
+         new_idx = dataset_idx - self.grid_count(dim)
+         if new_idx < 0:
+             return None
+         return new_idx
+
+
+ if __name__ == '__main__':
+     # data_shape = (1, 5, 103, 103, 2)
+     # grid_shape = (1, 1, 16, 16, 2)
+     # patch_shape = (1, 3, 32, 32, 2)
+     data_shape = (5, 5, 64, 64, 2)
+     grid_shape = (1, 1, 8, 8, 2)
+     patch_shape = (1, 3, 16, 16, 2)
+     tiling_mode = TilingMode.ShiftBoundary
+     manager = TileIndexManager(data_shape, grid_shape, patch_shape, tiling_mode)
+     gc = manager.total_grid_count()
+     for i in range(gc):
+         loc = manager.get_location_from_dataset_idx(i)
+         print(i, loc)
+         inferred_i = manager.get_dataset_idx_from_grid_location(loc)
+         assert i == inferred_i, f"Index mismatch: {i} != {inferred_i}"
+
+     for i in range(5):
+         print(manager.on_boundary(40, i))
@@ -0,0 +1,65 @@
+ import numpy as np
+
+ from predtiler.tile_manager import TilingMode
+
+
+ def stitch_predictions(predictions: np.ndarray, manager):
+     """
+     Stitch tiled predictions back into full-size arrays.
+
+     Args:
+         predictions: N*C*H*W or N*C*D*H*W numpy array, where N is the number of patches, C is the number of channels, D is the depth, H is the height and W is the width.
+         manager: the TileIndexManager that was used to generate the patches.
+
+     Returns:
+         A numpy array of shape manager.data_shape with the channel dimension appended as the last axis.
+     """
+     mng = manager
+     shape = list(mng.data_shape)
+     shape.append(predictions.shape[1])
+
+     output = np.zeros(shape, dtype=predictions.dtype)
+     for dset_idx in range(predictions.shape[0]):
+         # grid start, grid end
+         gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
+         ge = gs + mng.grid_shape
+
+         # patch start, patch end
+         ps = gs - mng.patch_offset()
+         pe = ps + mng.patch_shape
+
+         # valid grid start, valid grid end
+         vgs = np.array([max(0, x) for x in gs], dtype=int)
+         vge = np.array([min(x, y) for x, y in zip(ge, mng.data_shape)], dtype=int)
+         assert np.all(vgs == gs)
+         assert np.all(vge == ge)
+
+         if mng.tiling_mode == TilingMode.ShiftBoundary:
+             for dim in range(len(vgs)):
+                 if ps[dim] == 0:
+                     vgs[dim] = 0
+                 if pe[dim] == mng.data_shape[dim]:
+                     vge[dim] = mng.data_shape[dim]
+
+         # relative start, relative end. This will be used on the tiled predictions.
+         rs = vgs - ps
+         re = rs + (vge - vgs)
+
+         for ch_idx in range(predictions.shape[1]):
+             if len(output.shape) == 4:
+                 # channel dimension is the last one.
+                 output[vgs[0]:vge[0],
+                        vgs[1]:vge[1],
+                        vgs[2]:vge[2],
+                        ch_idx] = predictions[dset_idx][ch_idx, rs[1]:re[1], rs[2]:re[2]]
+             elif len(output.shape) == 5:
+                 # channel dimension is the last one.
+                 assert vge[0] - vgs[0] == 1, 'Only one frame is supported'
+                 output[vgs[0],
+                        vgs[1]:vge[1],
+                        vgs[2]:vge[2],
+                        vgs[3]:vge[3],
+                        ch_idx] = predictions[dset_idx][ch_idx, rs[1]:re[1], rs[2]:re[2], rs[3]:re[3]]
+             else:
+                 raise ValueError(f'Unsupported shape {output.shape}')
+
+     return output
@@ -0,0 +1,92 @@
+ import numpy as np
+
+ from predtiler.dataset import get_tiling_dataset, get_tile_manager
+ from predtiler.tile_stitcher import stitch_predictions
+
+
+ def get_data_3D(n=5, Z=9, H=512, W=512, C=2):
+     data = np.arange(n * Z * H * W * C).reshape(n, Z, H, W, C)
+     return data
+
+
+ def get_data_2D(n=5, H=512, W=512, C=2):
+     data = np.arange(n * H * W * C).reshape(n, H, W, C)
+     return data
+
+
+ class DummyDataset:
+     def __init__(self, datatype='2D', patch_size=64, z_patch_size=5) -> None:
+         assert datatype in ['2D', '3D'], 'datatype must be either 2D or 3D'
+         self.datatype = datatype
+         self.z_patch_size = z_patch_size
+         self.patch_size = patch_size
+         if datatype == '2D':
+             self.data = get_data_2D()
+         elif datatype == '3D':
+             self.data = get_data_3D()
+
+     def patch_location(self, index):
+         if self.datatype == '2D':
+             n_idx = np.random.randint(0, len(self.data))
+             h = np.random.randint(0, self.data.shape[1] - self.patch_size)
+             w = np.random.randint(0, self.data.shape[2] - self.patch_size)
+             return (n_idx, h, w)
+         elif self.datatype == '3D':
+             n_idx = np.random.randint(0, len(self.data))
+             z = np.random.randint(0, self.data.shape[1] - self.z_patch_size)
+             h = np.random.randint(0, self.data.shape[2] - self.patch_size)
+             w = np.random.randint(0, self.data.shape[3] - self.patch_size)
+             return (n_idx, z, h, w)
+
+     def __len__(self):
+         return len(self.data) * (self.data.shape[-2] // self.patch_size) * (self.data.shape[-3] // self.patch_size)
+
+     def __getitem__(self, index):
+         if self.datatype == '2D':
+             n_idx, h, w = self.patch_location(index)
+             return self.data[n_idx, h:h + self.patch_size, w:w + self.patch_size].transpose(2, 0, 1)
+         elif self.datatype == '3D':
+             n_idx, z, h, w = self.patch_location(index)
+             return self.data[n_idx, z:z + self.z_patch_size, h:h + self.patch_size, w:w + self.patch_size].transpose(3, 0, 1, 2)
+
+
+ def test_stitch_prediction_2D():
+     data_type = '2D'
+     data_fn = get_data_2D
+     patch_size = 256
+     tile_size = 128
+     data = data_fn()
+     manager = get_tile_manager(data_shape=data.shape[:-1], tile_shape=(1, tile_size, tile_size),
+                                patch_shape=(1, patch_size, patch_size))
+
+     dset_class = get_tiling_dataset(DummyDataset, manager)
+     dset = dset_class(data_type, patch_size)
+
+     predictions = []
+     for i in range(len(dset)):
+         predictions.append(dset[i])
+
+     predictions = np.stack(predictions)
+     stitched_pred = stitch_predictions(predictions, dset.tile_manager)
+     assert (stitched_pred == data).all()
+
+
+ def test_stitch_prediction_3D():
+     data_type = '3D'
+     data_fn = get_data_3D
+     patch_size = 256
+     tile_size = 128
+     data = data_fn()
+     z_patch_size = 5
+     z_tile_size = 3
+     manager = get_tile_manager(data_shape=data.shape[:-1], tile_shape=(1, z_tile_size, tile_size, tile_size),
+                                patch_shape=(1, z_patch_size, patch_size, patch_size))
+
+     dset_class = get_tiling_dataset(DummyDataset, manager)
+     dset = dset_class(data_type, patch_size)
+
+     predictions = []
+     for i in range(len(dset)):
+         predictions.append(dset[i])
+
+     predictions = np.stack(predictions)
+     stitched_pred = stitch_predictions(predictions, dset.tile_manager)
+     assert (stitched_pred == data).all()