predtiler 0.0.1__tar.gz
Sign up to get free protection for your applications and to get access to all the features.
- predtiler-0.0.1/.github/pull_request_template.md +39 -0
- predtiler-0.0.1/.github/workflows/ci.yml +91 -0
- predtiler-0.0.1/.github/workflows/coverage.yml +48 -0
- predtiler-0.0.1/.gitignore +168 -0
- predtiler-0.0.1/LICENSE +21 -0
- predtiler-0.0.1/PKG-INFO +122 -0
- predtiler-0.0.1/README.md +93 -0
- predtiler-0.0.1/pyproject.toml +165 -0
- predtiler-0.0.1/src/predtiler/dataset.py +53 -0
- predtiler-0.0.1/src/predtiler/tile_manager.py +210 -0
- predtiler-0.0.1/src/predtiler/tile_stitcher.py +65 -0
- predtiler-0.0.1/tests/test_full_tiling_setup.py +92 -0
@@ -0,0 +1,39 @@
|
|
1
|
+
### Description
|
2
|
+
|
3
|
+
Please provide a brief description of the changes in this PR. Include any relevant context or background information.
|
4
|
+
|
5
|
+
- **What**: Clearly and concisely describe what changes you have made.
|
6
|
+
- **Why**: Explain the reasoning behind these changes.
|
7
|
+
- **How**: Describe how you implemented these changes.
|
8
|
+
|
9
|
+
### Changes Made
|
10
|
+
|
11
|
+
- **Added**: List new features or files added.
|
12
|
+
- **Modified**: Describe existing features or files modified.
|
13
|
+
- **Removed**: Detail features or files that were removed.
|
14
|
+
|
15
|
+
### Related Issues
|
16
|
+
|
17
|
+
Link to any related issues or discussions. Use keywords like "Fixes", "Resolves", or "Closes" to link to issues automatically.
|
18
|
+
|
19
|
+
- Fixes #
|
20
|
+
- Resolves #
|
21
|
+
- Closes #
|
22
|
+
|
23
|
+
### Breaking changes
|
24
|
+
|
25
|
+
Describe any breaking change.
|
26
|
+
|
27
|
+
|
28
|
+
### Additional Notes and Examples
|
29
|
+
|
30
|
+
Include any additional notes or context that reviewers should be aware of, including snippets of code illustrating your new feature.
|
31
|
+
|
32
|
+
---
|
33
|
+
|
34
|
+
**Please ensure your PR meets the following requirements:**
|
35
|
+
|
36
|
+
- [ ] Code builds and passes tests locally, including doctests
|
37
|
+
- [ ] New tests have been added (for bug fixes/features)
|
38
|
+
- [ ] Pre-commit passes
|
39
|
+
- [ ] PR to the documentation exists (for bug fixes / features)
|
@@ -0,0 +1,91 @@
|
|
1
|
+
name: CI
|
2
|
+
|
3
|
+
on:
|
4
|
+
push:
|
5
|
+
branches:
|
6
|
+
- main
|
7
|
+
tags:
|
8
|
+
- "v*"
|
9
|
+
pull_request:
|
10
|
+
workflow_dispatch:
|
11
|
+
schedule:
|
12
|
+
# run every week (for --pre release tests)
|
13
|
+
- cron: "0 0 * * 0"
|
14
|
+
|
15
|
+
jobs:
|
16
|
+
check-manifest:
|
17
|
+
# check-manifest is a tool that checks that all files in version control are
|
18
|
+
# included in the sdist (unless explicitly excluded)
|
19
|
+
runs-on: ubuntu-latest
|
20
|
+
steps:
|
21
|
+
- uses: actions/checkout@v3
|
22
|
+
- run: pipx run check-manifest
|
23
|
+
|
24
|
+
test:
|
25
|
+
name: ${{ matrix.platform }} (${{ matrix.python-version }})
|
26
|
+
runs-on: ${{ matrix.platform }}
|
27
|
+
strategy:
|
28
|
+
fail-fast: false
|
29
|
+
matrix:
|
30
|
+
python-version: ["3.9", "3.10", "3.11", "3.12"]
|
31
|
+
# https://docs.github.com/en/actions/using-github-hosted-runners/about-github-hosted-runners/about-github-hosted-runners#standard-github-hosted-runners-for-public-repositories
|
32
|
+
platform: [ubuntu-latest, macos-13, windows-latest]
|
33
|
+
|
34
|
+
steps:
|
35
|
+
- name: 🛑 Cancel Previous Runs
|
36
|
+
uses: styfle/cancel-workflow-action@0.11.0
|
37
|
+
with:
|
38
|
+
access_token: ${{ github.token }}
|
39
|
+
|
40
|
+
- uses: actions/checkout@v3
|
41
|
+
|
42
|
+
- name: 🐍 Set up Python ${{ matrix.python-version }}
|
43
|
+
uses: actions/setup-python@v4
|
44
|
+
with:
|
45
|
+
python-version: ${{ matrix.python-version }}
|
46
|
+
cache-dependency-path: "pyproject.toml"
|
47
|
+
cache: "pip"
|
48
|
+
|
49
|
+
- name: Install Dependencies
|
50
|
+
run: |
|
51
|
+
python -m pip install -U pip
|
52
|
+
# if running a cron job, we add the --pre flag to test against pre-releases
|
53
|
+
python -m pip install ".[dev]" ${{ github.event_name == 'schedule' && '--pre' || '' }}
|
54
|
+
|
55
|
+
- name: 🧪 Run Tests
|
56
|
+
run: pytest
|
57
|
+
|
58
|
+
deploy:
|
59
|
+
name: Release
|
60
|
+
needs: test
|
61
|
+
if: success() && startsWith(github.ref, 'refs/tags/') && github.event_name != 'schedule'
|
62
|
+
runs-on: ubuntu-latest
|
63
|
+
|
64
|
+
permissions:
|
65
|
+
# IMPORTANT: this permission is mandatory for trusted publishing
|
66
|
+
id-token: write
|
67
|
+
|
68
|
+
# This permission allows writing releases
|
69
|
+
contents: write
|
70
|
+
|
71
|
+
steps:
|
72
|
+
- uses: actions/checkout@v4
|
73
|
+
with:
|
74
|
+
fetch-depth: 0
|
75
|
+
|
76
|
+
- name: Set up Python
|
77
|
+
uses: actions/setup-python@v5
|
78
|
+
with:
|
79
|
+
python-version: "3.9"
|
80
|
+
|
81
|
+
- name: Build
|
82
|
+
run: |
|
83
|
+
python -m pip install build
|
84
|
+
python -m build
|
85
|
+
|
86
|
+
- name: Publish to PyPI
|
87
|
+
uses: pypa/gh-action-pypi-publish@release/v1
|
88
|
+
|
89
|
+
- uses: softprops/action-gh-release@v2
|
90
|
+
with:
|
91
|
+
generate_release_notes: true
|
@@ -0,0 +1,48 @@
|
|
1
|
+
name: Coverage
|
2
|
+
|
3
|
+
on:
|
4
|
+
push:
|
5
|
+
branches:
|
6
|
+
- main
|
7
|
+
tags:
|
8
|
+
- "v*"
|
9
|
+
pull_request:
|
10
|
+
|
11
|
+
jobs:
|
12
|
+
|
13
|
+
test:
|
14
|
+
name: ${{ matrix.platform }} (${{ matrix.python-version }})
|
15
|
+
runs-on: ${{ matrix.platform }}
|
16
|
+
strategy:
|
17
|
+
fail-fast: false
|
18
|
+
matrix:
|
19
|
+
python-version: ["3.10"]
|
20
|
+
platform: [ubuntu-latest]
|
21
|
+
|
22
|
+
steps:
|
23
|
+
- name: 🛑 Cancel Previous Runs
|
24
|
+
uses: styfle/cancel-workflow-action@0.11.0
|
25
|
+
with:
|
26
|
+
access_token: ${{ github.token }}
|
27
|
+
|
28
|
+
- uses: actions/checkout@v3
|
29
|
+
|
30
|
+
- name: 🐍 Set up Python ${{ matrix.python-version }}
|
31
|
+
uses: actions/setup-python@v4
|
32
|
+
with:
|
33
|
+
python-version: ${{ matrix.python-version }}
|
34
|
+
cache-dependency-path: "pyproject.toml"
|
35
|
+
cache: "pip"
|
36
|
+
|
37
|
+
- name: Install Dependencies
|
38
|
+
run: |
|
39
|
+
python -m pip install -U pip
|
40
|
+
python -m pip install -e ".[dev]"
|
41
|
+
|
42
|
+
- name: 🧪 Run Tests
|
43
|
+
run: pytest --color=yes --cov --cov-config=pyproject.toml --cov-report=xml --cov-report=term-missing
|
44
|
+
|
45
|
+
- name: Coverage
|
46
|
+
uses: codecov/codecov-action@v3
|
47
|
+
with:
|
48
|
+
version: v0.7.3
|
@@ -0,0 +1,168 @@
|
|
1
|
+
# VSCode
|
2
|
+
.vscode
|
3
|
+
|
4
|
+
# Byte-compiled / optimized / DLL files
|
5
|
+
__pycache__/
|
6
|
+
*.py[cod]
|
7
|
+
*$py.class
|
8
|
+
|
9
|
+
# C extensions
|
10
|
+
*.so
|
11
|
+
|
12
|
+
# Distribution / packaging
|
13
|
+
.Python
|
14
|
+
build/
|
15
|
+
develop-eggs/
|
16
|
+
dist/
|
17
|
+
downloads/
|
18
|
+
eggs/
|
19
|
+
.eggs/
|
20
|
+
lib/
|
21
|
+
lib64/
|
22
|
+
parts/
|
23
|
+
sdist/
|
24
|
+
var/
|
25
|
+
wheels/
|
26
|
+
share/python-wheels/
|
27
|
+
*.egg-info/
|
28
|
+
.installed.cfg
|
29
|
+
*.egg
|
30
|
+
MANIFEST
|
31
|
+
|
32
|
+
# PyInstaller
|
33
|
+
# Usually these files are written by a python script from a template
|
34
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
35
|
+
*.manifest
|
36
|
+
*.spec
|
37
|
+
|
38
|
+
# Installer logs
|
39
|
+
pip-log.txt
|
40
|
+
pip-delete-this-directory.txt
|
41
|
+
|
42
|
+
# Unit test / coverage reports
|
43
|
+
htmlcov/
|
44
|
+
.tox/
|
45
|
+
.nox/
|
46
|
+
.coverage
|
47
|
+
.coverage.*
|
48
|
+
.cache
|
49
|
+
nosetests.xml
|
50
|
+
coverage.xml
|
51
|
+
*.cover
|
52
|
+
*.py,cover
|
53
|
+
.hypothesis/
|
54
|
+
.pytest_cache/
|
55
|
+
cover/
|
56
|
+
|
57
|
+
# Translations
|
58
|
+
*.mo
|
59
|
+
*.pot
|
60
|
+
|
61
|
+
# Django stuff:
|
62
|
+
*.log
|
63
|
+
local_settings.py
|
64
|
+
db.sqlite3
|
65
|
+
db.sqlite3-journal
|
66
|
+
|
67
|
+
# Flask stuff:
|
68
|
+
instance/
|
69
|
+
.webassets-cache
|
70
|
+
|
71
|
+
# Scrapy stuff:
|
72
|
+
.scrapy
|
73
|
+
|
74
|
+
# Sphinx documentation
|
75
|
+
docs/_build/
|
76
|
+
|
77
|
+
# PyBuilder
|
78
|
+
.pybuilder/
|
79
|
+
target/
|
80
|
+
|
81
|
+
# Jupyter Notebook
|
82
|
+
.ipynb_checkpoints
|
83
|
+
|
84
|
+
# IPython
|
85
|
+
profile_default/
|
86
|
+
ipython_config.py
|
87
|
+
|
88
|
+
# pyenv
|
89
|
+
# For a library or package, you might want to ignore these files since the code is
|
90
|
+
# intended to run in multiple environments; otherwise, check them in:
|
91
|
+
# .python-version
|
92
|
+
|
93
|
+
# pipenv
|
94
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
95
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
96
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
97
|
+
# install all needed dependencies.
|
98
|
+
#Pipfile.lock
|
99
|
+
|
100
|
+
# poetry
|
101
|
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
102
|
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
103
|
+
# commonly ignored for libraries.
|
104
|
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
105
|
+
#poetry.lock
|
106
|
+
|
107
|
+
# pdm
|
108
|
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
109
|
+
#pdm.lock
|
110
|
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
111
|
+
# in version control.
|
112
|
+
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
113
|
+
.pdm.toml
|
114
|
+
.pdm-python
|
115
|
+
.pdm-build/
|
116
|
+
|
117
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118
|
+
__pypackages__/
|
119
|
+
|
120
|
+
# Celery stuff
|
121
|
+
celerybeat-schedule
|
122
|
+
celerybeat.pid
|
123
|
+
|
124
|
+
# SageMath parsed files
|
125
|
+
*.sage.py
|
126
|
+
|
127
|
+
# Environments
|
128
|
+
.env
|
129
|
+
.venv
|
130
|
+
env/
|
131
|
+
venv/
|
132
|
+
ENV/
|
133
|
+
env.bak/
|
134
|
+
venv.bak/
|
135
|
+
|
136
|
+
# Spyder project settings
|
137
|
+
.spyderproject
|
138
|
+
.spyproject
|
139
|
+
|
140
|
+
# Rope project settings
|
141
|
+
.ropeproject
|
142
|
+
|
143
|
+
# mkdocs documentation
|
144
|
+
/site
|
145
|
+
|
146
|
+
# mypy
|
147
|
+
.mypy_cache/
|
148
|
+
.dmypy.json
|
149
|
+
dmypy.json
|
150
|
+
|
151
|
+
# Pyre type checker
|
152
|
+
.pyre/
|
153
|
+
|
154
|
+
# pytype static type analyzer
|
155
|
+
.pytype/
|
156
|
+
|
157
|
+
# Cython debug symbols
|
158
|
+
cython_debug/
|
159
|
+
|
160
|
+
# PyCharm
|
161
|
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162
|
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163
|
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164
|
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165
|
+
#.idea/
|
166
|
+
|
167
|
+
# Ruff
|
168
|
+
.ruff_cache
|
predtiler-0.0.1/LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2024 ashesh
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
predtiler-0.0.1/PKG-INFO
ADDED
@@ -0,0 +1,122 @@
|
|
1
|
+
Metadata-Version: 2.3
|
2
|
+
Name: predtiler
|
3
|
+
Version: 0.0.1
|
4
|
+
Summary: Converting your dataset class into a class that can be used for tiled prediction and eventually obtain stitched prediction.
|
5
|
+
Project-URL: homepage, https://github.com/ashesh-0/PredTiler
|
6
|
+
Project-URL: repository, https://github.com/ashesh-0/PredTiler
|
7
|
+
Author: Ashesh
|
8
|
+
License: MIT
|
9
|
+
License-File: LICENSE
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
17
|
+
Classifier: Typing :: Typed
|
18
|
+
Requires-Python: >=3.9
|
19
|
+
Requires-Dist: numpy
|
20
|
+
Provides-Extra: dev
|
21
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
22
|
+
Requires-Dist: pytest; extra == 'dev'
|
23
|
+
Requires-Dist: pytest-cov; extra == 'dev'
|
24
|
+
Requires-Dist: sybil; extra == 'dev'
|
25
|
+
Provides-Extra: examples
|
26
|
+
Requires-Dist: jupyter; extra == 'examples'
|
27
|
+
Requires-Dist: matplotlib; extra == 'examples'
|
28
|
+
Description-Content-Type: text/markdown
|
29
|
+
|
30
|
+
A lean wrapper around your dataset class to enable tiled prediction.
|
31
|
+
|
32
|
+
[![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/ashesh-0/PredTiler/blob/main/LICENSE)
|
33
|
+
[![CI](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml/badge.svg)](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml)
|
34
|
+
[![codecov](https://codecov.io/gh/ashesh-0/PredTiler/graph/badge.svg?token=M655MOS7EL)](https://codecov.io/gh/ashesh-0/PredTiler)
|
35
|
+
|
36
|
+
## Objective
|
37
|
+
This package subclasses the dataset class you use to train your network.
|
38
|
+
With PredTiler, you can use your dataset class as is, and PredTiler will take care of the tiling logic for you.
|
39
|
+
It will automatically generate patches in such a way that they can be tiled with the overlap of `(patch_size - tile_size)//2`.
|
40
|
+
We also provide a function to stitch the tiles back together to get the final prediction.
|
41
|
+
|
42
|
+
If you run into any problems, feel free to open an issue and I will be happy to help you out!
|
43
|
+
In the future, I plan to add detailed instructions for:
|
44
|
+
1. multi-channel data
|
45
|
+
2. 3D data
|
46
|
+
3. Data being a list of numpy arrays, each possibly having a different shape.
|
47
|
+
|
48
|
+
## Installation
|
49
|
+
|
50
|
+
```bash
|
51
|
+
pip install predtiler
|
52
|
+
```
|
53
|
+
|
54
|
+
## Usage
|
55
|
+
To work with PredTiler, the only requirement is that your dataset class must have a **patch_location(self, index)** method that returns the location of the patch at the given index.
|
56
|
+
Your dataset class should only use the location information returned by this method to return the patch.
|
57
|
+
PredTiler will override this method to return the location of the patches needed for tiled prediction.
|
58
|
+
|
59
|
+
Note that your dataset class could be arbitrarily complex (augmentations, returning multiple patches, working with 3D data, etc.). The only requirement is that it should use the crop present at the location returned by **patch_location** method. Below is an example of a simple dataset class that can be used with PredTiler.
|
60
|
+
|
61
|
+
```python
|
62
|
+
class YourDataset:
|
63
|
+
def __init__(self, data_path, patch_size=64) -> None:
|
64
|
+
self.patch_size = patch_size
|
65
|
+
self.data = load_data(data_path) # shape: (N, H, W)
|
66
|
+
|
67
|
+
def patch_location(self, index:int)-> Tuple[int, int, int]:
|
68
|
+
# it just ignores the index and returns a random location
|
69
|
+
n_idx = np.random.randint(0,len(self.data))
|
70
|
+
h = np.random.randint(0, self.data.shape[1]-self.patch_size)
|
71
|
+
w = np.random.randint(0, self.data.shape[2]-self.patch_size)
|
72
|
+
return (n_idx, h, w)
|
73
|
+
|
74
|
+
def __len__(self):
|
75
|
+
return len(self.data)
|
76
|
+
|
77
|
+
def __getitem__(self, index):
|
78
|
+
n_idx, h, w = self.patch_location(index)
|
79
|
+
# return the patch at the location (patch_size, patch_size)
|
80
|
+
return self.data[n_idx, h:h+self.patch_size, w:w+self.patch_size]
|
81
|
+
```
|
82
|
+
|
83
|
+
## Getting overlapping patches needed for tiled prediction
|
84
|
+
To use PredTiler, we need to get a new class that wraps around your dataset class.
|
85
|
+
For this we also need a tile manager that will manage the tiles.
|
86
|
+
|
87
|
+
```python
|
88
|
+
|
89
|
+
from predtiler.dataset import get_tiling_dataset, get_tile_manager
|
90
|
+
patch_size = 256
|
91
|
+
tile_size = 128
|
92
|
+
data_shape = (10, 2048, 2048) # size of the data you are working with
|
93
|
+
manager = get_tile_manager(data_shape=data_shape, tile_shape=(1,tile_size,tile_size),
|
94
|
+
patch_shape=(1,patch_size,patch_size))
|
95
|
+
|
96
|
+
dset_class = get_tiling_dataset(YourDataset, manager)
|
97
|
+
```
|
98
|
+
|
99
|
+
At this point, you can use the `dset_class` as you would use `YourDataset` class.
|
100
|
+
|
101
|
+
```python
|
102
|
+
data_path = ... # path to your data
|
103
|
+
dset = dset_class(data_path, patch_size=patch_size)
|
104
|
+
```
|
105
|
+
|
106
|
+
## Stitching the predictions
|
107
|
+
The benefit of using PredTiler is that it will automatically generate the patches in such a way that they can be tiled with the overlap of `(patch_size - tile_size)//2`. This allows you to use your dataset class as is, without worrying about the tiling logic.
|
108
|
+
|
109
|
+
```python
|
110
|
+
model = ... # your model
|
111
|
+
predictions = []
|
112
|
+
for i in range(len(dset)):
|
113
|
+
inp = dset[i]
|
114
|
+
inp = torch.Tensor(inp)[None,None]
|
115
|
+
pred = model(inp)
|
116
|
+
predictions.append(pred[0].numpy())
|
117
|
+
|
118
|
+
predictions = np.stack(predictions) # shape: (number_of_patches, C, patch_size, patch_size)
|
119
|
+
from predtiler.tile_stitcher import stitch_predictions
stitched_pred = stitch_predictions(predictions, dset.tile_manager)
|
120
|
+
```
|
121
|
+
|
122
|
+
|
@@ -0,0 +1,93 @@
|
|
1
|
+
A lean wrapper around your dataset class to enable tiled prediction.
|
2
|
+
|
3
|
+
[![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/ashesh-0/PredTiler/blob/main/LICENSE)
|
4
|
+
[![CI](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml/badge.svg)](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml)
|
5
|
+
[![codecov](https://codecov.io/gh/ashesh-0/PredTiler/graph/badge.svg?token=M655MOS7EL)](https://codecov.io/gh/ashesh-0/PredTiler)
|
6
|
+
|
7
|
+
## Objective
|
8
|
+
This package subclasses the dataset class you use to train your network.
|
9
|
+
With PredTiler, you can use your dataset class as is, and PredTiler will take care of the tiling logic for you.
|
10
|
+
It will automatically generate patches in such a way that they can be tiled with the overlap of `(patch_size - tile_size)//2`.
|
11
|
+
We also provide a function to stitch the tiles back together to get the final prediction.
|
12
|
+
|
13
|
+
If you run into any problems, feel free to open an issue and I will be happy to help you out!
|
14
|
+
In the future, I plan to add detailed instructions for:
|
15
|
+
1. multi-channel data
|
16
|
+
2. 3D data
|
17
|
+
3. Data being a list of numpy arrays, each possibly having a different shape.
|
18
|
+
|
19
|
+
## Installation
|
20
|
+
|
21
|
+
```bash
|
22
|
+
pip install predtiler
|
23
|
+
```
|
24
|
+
|
25
|
+
## Usage
|
26
|
+
To work with PredTiler, the only requirement is that your dataset class must have a **patch_location(self, index)** method that returns the location of the patch at the given index.
|
27
|
+
Your dataset class should only use the location information returned by this method to return the patch.
|
28
|
+
PredTiler will override this method to return the location of the patches needed for tiled prediction.
|
29
|
+
|
30
|
+
Note that your dataset class could be arbitrarily complex (augmentations, returning multiple patches, working with 3D data, etc.). The only requirement is that it should use the crop present at the location returned by **patch_location** method. Below is an example of a simple dataset class that can be used with PredTiler.
|
31
|
+
|
32
|
+
```python
|
33
|
+
class YourDataset:
|
34
|
+
def __init__(self, data_path, patch_size=64) -> None:
|
35
|
+
self.patch_size = patch_size
|
36
|
+
self.data = load_data(data_path) # shape: (N, H, W)
|
37
|
+
|
38
|
+
def patch_location(self, index:int)-> Tuple[int, int, int]:
|
39
|
+
# it just ignores the index and returns a random location
|
40
|
+
n_idx = np.random.randint(0,len(self.data))
|
41
|
+
h = np.random.randint(0, self.data.shape[1]-self.patch_size)
|
42
|
+
w = np.random.randint(0, self.data.shape[2]-self.patch_size)
|
43
|
+
return (n_idx, h, w)
|
44
|
+
|
45
|
+
def __len__(self):
|
46
|
+
return len(self.data)
|
47
|
+
|
48
|
+
def __getitem__(self, index):
|
49
|
+
n_idx, h, w = self.patch_location(index)
|
50
|
+
# return the patch at the location (patch_size, patch_size)
|
51
|
+
return self.data[n_idx, h:h+self.patch_size, w:w+self.patch_size]
|
52
|
+
```
|
53
|
+
|
54
|
+
## Getting overlapping patches needed for tiled prediction
|
55
|
+
To use PredTiler, we need to get a new class that wraps around your dataset class.
|
56
|
+
For this we also need a tile manager that will manage the tiles.
|
57
|
+
|
58
|
+
```python
|
59
|
+
|
60
|
+
from predtiler.dataset import get_tiling_dataset, get_tile_manager
|
61
|
+
patch_size = 256
|
62
|
+
tile_size = 128
|
63
|
+
data_shape = (10, 2048, 2048) # size of the data you are working with
|
64
|
+
manager = get_tile_manager(data_shape=data_shape, tile_shape=(1,tile_size,tile_size),
|
65
|
+
patch_shape=(1,patch_size,patch_size))
|
66
|
+
|
67
|
+
dset_class = get_tiling_dataset(YourDataset, manager)
|
68
|
+
```
|
69
|
+
|
70
|
+
At this point, you can use the `dset_class` as you would use `YourDataset` class.
|
71
|
+
|
72
|
+
```python
|
73
|
+
data_path = ... # path to your data
|
74
|
+
dset = dset_class(data_path, patch_size=patch_size)
|
75
|
+
```
|
76
|
+
|
77
|
+
## Stitching the predictions
|
78
|
+
The benefit of using PredTiler is that it will automatically generate the patches in such a way that they can be tiled with the overlap of `(patch_size - tile_size)//2`. This allows you to use your dataset class as is, without worrying about the tiling logic.
|
79
|
+
|
80
|
+
```python
|
81
|
+
model = ... # your model
|
82
|
+
predictions = []
|
83
|
+
for i in range(len(dset)):
|
84
|
+
inp = dset[i]
|
85
|
+
inp = torch.Tensor(inp)[None,None]
|
86
|
+
pred = model(inp)
|
87
|
+
predictions.append(pred[0].numpy())
|
88
|
+
|
89
|
+
predictions = np.stack(predictions) # shape: (number_of_patches, C, patch_size, patch_size)
|
90
|
+
from predtiler.tile_stitcher import stitch_predictions
stitched_pred = stitch_predictions(predictions, dset.tile_manager)
|
91
|
+
```
|
92
|
+
|
93
|
+
|
@@ -0,0 +1,165 @@
|
|
1
|
+
# https://peps.python.org/pep-0517/
|
2
|
+
[build-system]
|
3
|
+
requires = ["hatchling", "hatch-vcs"]
|
4
|
+
build-backend = "hatchling.build"
|
5
|
+
# read more about configuring hatch at:
|
6
|
+
# https://hatch.pypa.io/latest/config/build/
|
7
|
+
|
8
|
+
# https://hatch.pypa.io/latest/config/metadata/
|
9
|
+
[tool.hatch.version]
|
10
|
+
source = "vcs" # use tags for versioning (e.g. tag v0.1.0, v0.2.0 etc.)
|
11
|
+
|
12
|
+
[tool.hatch.build.targets.wheel]
|
13
|
+
only-include = ["src"]
|
14
|
+
sources = ["src"]
|
15
|
+
|
16
|
+
# https://peps.python.org/pep-0621/
|
17
|
+
[project]
|
18
|
+
name = "predtiler"
|
19
|
+
version = "0.0.1"
|
20
|
+
description = "Converting your dataset class into a class that can be used for tiled prediction and eventually obtain stitched prediction."
|
21
|
+
readme = "README.md"
|
22
|
+
requires-python = ">=3.9"
|
23
|
+
license = { text = "MIT" }
|
24
|
+
authors = [
|
25
|
+
{ name = 'Ashesh' },
|
26
|
+
]
|
27
|
+
classifiers = [
|
28
|
+
"Development Status :: 3 - Alpha",
|
29
|
+
"Programming Language :: Python :: 3",
|
30
|
+
"Programming Language :: Python :: 3.9",
|
31
|
+
"Programming Language :: Python :: 3.10",
|
32
|
+
"Programming Language :: Python :: 3.11",
|
33
|
+
"Programming Language :: Python :: 3.12",
|
34
|
+
"License :: OSI Approved :: MIT License",
|
35
|
+
"Typing :: Typed",
|
36
|
+
]
|
37
|
+
dependencies = [
|
38
|
+
"numpy",
|
39
|
+
]
|
40
|
+
|
41
|
+
[project.optional-dependencies]
|
42
|
+
# development dependencies and tooling
|
43
|
+
dev = [
|
44
|
+
"pre-commit",
|
45
|
+
"pytest",
|
46
|
+
"pytest-cov",
|
47
|
+
"sybil", # doctesting
|
48
|
+
]
|
49
|
+
|
50
|
+
# examples
|
51
|
+
examples = ["jupyter", "matplotlib"]
|
52
|
+
|
53
|
+
[project.urls]
|
54
|
+
homepage = "https://github.com/ashesh-0/PredTiler"
|
55
|
+
repository = "https://github.com/ashesh-0/PredTiler"
|
56
|
+
|
57
|
+
# https://beta.ruff.rs/docs
|
58
|
+
[tool.ruff]
|
59
|
+
line-length = 88
|
60
|
+
target-version = "py39"
|
61
|
+
src = ["src"]
|
62
|
+
lint.select = [
|
63
|
+
"E", # style errors
|
64
|
+
"W", # style warnings
|
65
|
+
"F", # flakes
|
66
|
+
"D", # pydocstyle
|
67
|
+
"I", # isort
|
68
|
+
"UP", # pyupgrade
|
69
|
+
# "S", # bandit
|
70
|
+
"C4", # flake8-comprehensions
|
71
|
+
"B", # flake8-bugbear
|
72
|
+
"A001", # flake8-builtins
|
73
|
+
"RUF", # ruff-specific rules
|
74
|
+
]
|
75
|
+
lint.ignore = [
|
76
|
+
"D100", # Missing docstring in public module
|
77
|
+
"D107", # Missing docstring in __init__
|
78
|
+
"D203", # 1 blank line required before class docstring
|
79
|
+
"D212", # Multi-line docstring summary should start at the first line
|
80
|
+
"D213", # Multi-line docstring summary should start at the second line
|
81
|
+
"D401", # First line should be in imperative mood
|
82
|
+
"D413", # Missing blank line after last section
|
83
|
+
"D416", # Section name should end with a colon
|
84
|
+
|
85
|
+
# incompatibility with mypy
|
86
|
+
"RUF005", # collection-literal-concatenation, in prediction_utils.py:30
|
87
|
+
|
88
|
+
# version specific
|
89
|
+
"UP007", # Replace Union by |, mandatory for py3.9
|
90
|
+
]
|
91
|
+
show-fixes = true
|
92
|
+
|
93
|
+
[tool.ruff.lint.pydocstyle]
|
94
|
+
convention = "numpy"
|
95
|
+
|
96
|
+
[tool.ruff.lint.per-file-ignores]
|
97
|
+
"tests/*.py" = ["D", "S"]
|
98
|
+
"setup.py" = ["D"]
|
99
|
+
|
100
|
+
[tool.black]
|
101
|
+
line-length = 88
|
102
|
+
|
103
|
+
# https://mypy.readthedocs.io/en/stable/config_file.html
|
104
|
+
[tool.mypy]
|
105
|
+
files = "src/**/"
|
106
|
+
strict = false
|
107
|
+
# allow_untyped_defs = false
|
108
|
+
# allow_untyped_calls = false
|
109
|
+
# disallow_any_generics = false
|
110
|
+
# ignore_missing_imports = false
|
111
|
+
|
112
|
+
|
113
|
+
# https://docs.pytest.org/en/6.2.x/customize.html
|
114
|
+
[tool.pytest.ini_options]
|
115
|
+
minversion = "6.0"
|
116
|
+
testpaths = ["tests"] # add src/predtiler for doctest discovery
|
117
|
+
filterwarnings = [
|
118
|
+
# "error",
|
119
|
+
# "ignore::UserWarning",
|
120
|
+
]
|
121
|
+
addopts = "-p no:doctest"
|
122
|
+
|
123
|
+
|
124
|
+
# https://coverage.readthedocs.io/en/6.4/config.html
|
125
|
+
[tool.coverage.report]
|
126
|
+
exclude_lines = [
|
127
|
+
"pragma: no cover",
|
128
|
+
"if TYPE_CHECKING:",
|
129
|
+
"@overload",
|
130
|
+
"except ImportError",
|
131
|
+
"\\.\\.\\.",
|
132
|
+
"raise NotImplementedError()",
|
133
|
+
]
|
134
|
+
|
135
|
+
|
136
|
+
[tool.coverage.run]
|
137
|
+
source = ["src/predtiler"]
|
138
|
+
|
139
|
+
# https://github.com/mgedmin/check-manifest#configuration
|
140
|
+
# add files that you want check-manifest to explicitly ignore here
|
141
|
+
# (files that are in the repo but shouldn't go in the package)
|
142
|
+
[tool.check-manifest]
|
143
|
+
ignore = [
|
144
|
+
".github_changelog_generator",
|
145
|
+
".pre-commit-config.yaml",
|
146
|
+
".ruff_cache/**/*",
|
147
|
+
"setup.py",
|
148
|
+
"tests/**/*",
|
149
|
+
]
|
150
|
+
|
151
|
+
[tool.numpydoc_validation]
|
152
|
+
checks = [
|
153
|
+
"all", # report on all checks, except the below
|
154
|
+
"EX01", # Example section not found
|
155
|
+
"SA01", # See Also section not found
|
156
|
+
"ES01", # Extended Summar not found
|
157
|
+
"GL01", # Docstring text (summary) should start in the line immediately
|
158
|
+
# after the opening quotes
|
159
|
+
"GL02", # Closing quotes should be placed in the line after the last text
|
160
|
+
# in the docstring
|
161
|
+
"GL03", # Double line break found
|
162
|
+
]
|
163
|
+
exclude = [ # don't report on objects that match any of these regex
|
164
|
+
"test_*",
|
165
|
+
]
|
@@ -0,0 +1,53 @@
|
|
1
|
+
|
2
|
+
from predtiler.tile_manager import TileIndexManager, TilingMode
|
3
|
+
|
4
|
+
# class TilingDataset:
|
5
|
+
# def __init_subclass__(cls, parent_class=None, tile_manager=None, **kwargs):
|
6
|
+
# super().__init_subclass__(**kwargs)
|
7
|
+
# assert tile_manager is not None, 'tile_manager must be provided'
|
8
|
+
# cls.tile_manager = tile_manager
|
9
|
+
# if parent_class is not None:
|
10
|
+
# has_callable_method = callable(getattr(parent_class, 'patch_location', None))
|
11
|
+
# assert has_callable_method, f'{parent_class.__name__} must have a callable method with following signature: def patch_location(self, index)'
|
12
|
+
# cls.__bases__ = (parent_class,) + cls.__bases__
|
13
|
+
|
14
|
+
# def __len__(self):
|
15
|
+
# return self.tile_manager.total_grid_count()
|
16
|
+
|
17
|
+
# def patch_location(self, index):
|
18
|
+
# print('Calling patch_location')
|
19
|
+
# patch_loc_list = self.tile_manager.get_patch_location_from_dataset_idx(index)
|
20
|
+
# return patch_loc_list
|
21
|
+
|
22
|
+
|
23
|
+
# def get_tiling_dataset(dataset_class, tile_manager) -> type:
|
24
|
+
# class CorrespondingTilingDataset(TilingDataset, parent_class=dataset_class, tile_manager=tile_manager):
|
25
|
+
# pass
|
26
|
+
|
27
|
+
# return CorrespondingTilingDataset
|
28
|
+
|
29
|
+
def get_tiling_dataset(dataset_class, tile_manager) -> type:
    """
    Wrap ``dataset_class`` so that it yields the patches needed for tiled prediction.

    The returned class subclasses ``dataset_class`` and overrides only
    ``__len__`` and ``patch_location``, delegating both to ``tile_manager``.
    Everything else (construction, ``__getitem__``, augmentations, ...) is
    inherited unchanged, so ``__getitem__`` must rely on ``patch_location``
    to decide which crop to return.

    Parameters
    ----------
    dataset_class : type
        Dataset class to wrap. Must define a callable
        ``patch_location(self, index)`` method.
    tile_manager : TileIndexManager
        Object providing ``total_grid_count()`` (used as the dataset length)
        and ``get_patch_location_from_dataset_idx(index)`` (used as the patch
        location for ``index``).

    Returns
    -------
    type
        Subclass of ``dataset_class``, constructed with the same arguments as
        ``dataset_class``. Instances expose the manager as ``tile_manager``.

    Raises
    ------
    AssertionError
        If ``dataset_class`` has no callable ``patch_location`` attribute.
    """
    if not callable(getattr(dataset_class, 'patch_location', None)):
        # Explicit raise instead of ``assert`` so the check is not stripped
        # under ``python -O``; AssertionError is kept for backward compatibility.
        raise AssertionError(
            f'{dataset_class.__name__} must have a callable method with following '
            'signature: def patch_location(self, index)'
        )

    class TilingDataset(dataset_class):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Re-assign on the instance in case the parent constructor bound
            # its own ``tile_manager`` attribute.
            self.tile_manager = tile_manager

        def __len__(self):
            return self.tile_manager.total_grid_count()

        def patch_location(self, index):
            return self.tile_manager.get_patch_location_from_dataset_idx(index)

    # Bind at class level too, so the manager is already reachable while
    # ``dataset_class.__init__`` runs. A parent constructor calling
    # ``len(self)`` or ``self.patch_location(...)`` dispatches to the overrides
    # above, which read ``self.tile_manager``.
    # This is done here rather than with ``tile_manager = tile_manager`` in the
    # class body, because a class body does not see enclosing-function locals
    # for names it assigns.
    TilingDataset.tile_manager = tile_manager
    return TilingDataset
|
46
|
+
|
47
|
+
|
48
|
+
|
49
|
+
|
50
|
+
def get_tile_manager(data_shape, tile_shape, patch_shape, tiling_mode=TilingMode.ShiftBoundary):
    """
    Build the ``TileIndexManager`` describing how ``data_shape`` is tiled.

    Parameters
    ----------
    data_shape : tuple
        Shape of the full data.
    tile_shape : tuple
        Shape of the inner tile; stored on the manager as ``grid_shape``.
    patch_shape : tuple
        Shape of each patch fed to the model. It must have the same number of
        dimensions as ``data_shape``. Per dimension it must be at least the
        tile size and differ from it by an even amount; ``TileIndexManager``
        validates this.
    tiling_mode : TilingMode, optional
        Boundary handling strategy, ``TilingMode.ShiftBoundary`` by default.

    Returns
    -------
    TileIndexManager
        The configured manager.
    """
    manager = TileIndexManager(
        data_shape=data_shape,
        grid_shape=tile_shape,
        patch_shape=patch_shape,
        tiling_mode=tiling_mode,
    )
    return manager
|
52
|
+
|
53
|
+
|
@@ -0,0 +1,210 @@
|
|
1
|
+
from dataclasses import dataclass
from enum import IntEnum

import numpy as np
|
4
|
+
|
5
|
+
|
6
|
+
class TilingMode(IntEnum):
    """
    Enum for the tiling mode.

    Members subclass ``int`` and keep their previous values, so comparisons
    against plain integers (e.g. ``mode == 2``) keep working.

    Along each dimension, the mode determines how many grid positions are
    counted by ``TileIndexManager.get_individual_dim_grid_count``:

    TrimBoundary
        Only grids that fit completely are counted (the count is rounded down).
    PadBoundary
        The whole data extent is covered: ``ceil(data / grid)`` grids.
        Presumably the boundary is padded; confirm in the location/stitching code.
    ShiftBoundary
        The count is rounded up after removing the patch/grid margin.
        Presumably the last patch is shifted inwards; confirm in the location code.
    """

    TrimBoundary = 0
    PadBoundary = 1
    ShiftBoundary = 2
|
13
|
+
|
14
|
+
@dataclass
class TileIndexManager:
    """
    Bookkeeping for tiled prediction.

    Converts between flat dataset indices, per-dimension grid indices, and grid/patch
    start coordinates. The data (``data_shape``) is covered by grid cells of
    ``grid_shape``. Each grid cell is predicted from a ``patch_shape`` window that extends
    ``(patch_shape - grid_shape) // 2`` beyond the cell on every side. Grid cells are
    numbered in row-major order (the last dimension varies fastest). How the data
    borders are handled depends on ``tiling_mode``.
    """

    data_shape: tuple
    grid_shape: tuple
    patch_shape: tuple
    # Forward-reference (string) annotation: dataclasses never evaluate it, so the
    # class definition does not require ``TilingMode`` to be resolvable.
    tiling_mode: "TilingMode"

    def __post_init__(self):
        """Validate that all shapes have the same rank and that the patch padding is non-negative and even."""
        assert len(self.data_shape) == len(self.grid_shape), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
        assert len(self.data_shape) == len(self.patch_shape), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}")
            if pad % 2 != 0:
                # The patch must extend symmetrically around the grid cell.
                raise ValueError(f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}")

    def _check_dim(self, dim: int):
        """Assert that ``dim`` is a valid axis of ``data_shape``."""
        assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

    def patch_offset(self):
        """Per-dimension distance between a patch start and its grid start (as a numpy array)."""
        return (np.array(self.patch_shape) - np.array(self.grid_shape))//2

    def get_individual_dim_grid_count(self, dim:int):
        """
        Returns the number of the grid in the specified dimension, ignoring all other dimensions.

        Raises:
            ValueError: if ``tiling_mode`` is not a supported mode.
        """
        self._check_dim(dim)

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            # Untiled dimension: one grid per element.
            return self.data_shape[dim]
        if self.tiling_mode == TilingMode.PadBoundary:
            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))

        # Total (both sides) margin by which the patch exceeds the grid.
        excess_size = self.patch_shape[dim] - self.grid_shape[dim]
        if self.tiling_mode == TilingMode.ShiftBoundary:
            # Partial last grid is kept (and later shifted inwards).
            return int(np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim]))
        if self.tiling_mode == TilingMode.TrimBoundary:
            # Only grids whose full patch fits inside the data.
            return int(np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim]))
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def total_grid_count(self):
        """
        Returns the total number of grids in the dataset.
        """
        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

    def grid_count(self, dim:int):
        """
        Returns the total number of grids for one value in the specified dimension.

        This is the product of the per-dimension grid counts of all dimensions after
        ``dim``, i.e. the row-major stride of ``dim`` in the flat dataset index.
        """
        self._check_dim(dim)
        if dim == len(self.data_shape) - 1:
            return 1

        return self.get_individual_dim_grid_count(dim + 1) * self.grid_count(dim + 1)

    def get_grid_index(self, dim:int, coordinate:int):
        """
        Returns the index (as an ``int``) of the grid containing ``coordinate`` in the specified dimension.

        Raises:
            ValueError: if ``tiling_mode`` is not a supported mode.
        """
        self._check_dim(dim)
        assert coordinate < self.data_shape[dim], f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return coordinate
        if self.tiling_mode == TilingMode.PadBoundary:
            return int(coordinate // self.grid_shape[dim])

        excess_size = (self.patch_shape[dim] - self.grid_shape[dim])//2
        if self.tiling_mode == TilingMode.TrimBoundary:
            # can be <0 if coordinate lies in the leading margin; clamp to the first grid
            return max(0, int((coordinate - excess_size) // self.grid_shape[dim]))
        if self.tiling_mode == TilingMode.ShiftBoundary:
            if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
                # Start of the shifted last grid (see get_gridstart_location_from_dim_index).
                return self.get_individual_dim_grid_count(dim) - 1
            # can be <0 if coordinate lies in the leading margin; clamp to the first grid
            return max(0, int((coordinate - excess_size) // self.grid_shape[dim]))
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def dataset_idx_from_grid_idx(self, grid_idx:tuple):
        """
        Returns the flat (row-major) dataset index of the per-dimension grid index ``grid_idx``.
        """
        assert len(grid_idx) == len(self.data_shape), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
        return sum(idx * self.grid_count(dim) for dim, idx in enumerate(grid_idx))

    def get_patch_location_from_dataset_idx(self, dataset_idx:int):
        """
        Returns the patch start location of the grid with index ``dataset_idx``.

        The patch start is the grid start minus ``patch_offset()`` and may be negative
        (e.g. in PadBoundary mode).
        """
        grid_location = self.get_location_from_dataset_idx(dataset_idx)
        offset = self.patch_offset()
        return tuple(np.array(grid_location) - np.array(offset))

    def get_dataset_idx_from_grid_location(self, location:tuple):
        """Returns the dataset index of the grid containing the coordinate ``location``."""
        assert len(location) == len(self.data_shape), f"Location {location} must have the same dimension as data shape {self.data_shape}"
        grid_idx = [self.get_grid_index(dim, location[dim]) for dim in range(len(location))]
        return self.dataset_idx_from_grid_idx(tuple(grid_idx))

    def get_gridstart_location_from_dim_index(self, dim:int, dim_index:int):
        """
        Returns the grid-start coordinate of the grid in the specified dimension.

        Raises:
            ValueError: if ``tiling_mode`` is not a supported mode.
        """
        self._check_dim(dim)
        assert dim_index < self.get_individual_dim_grid_count(dim), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim] == 1 and self.patch_shape[dim] == 1:
            return dim_index
        if self.tiling_mode == TilingMode.PadBoundary:
            return dim_index * self.grid_shape[dim]

        excess_size = (self.patch_shape[dim] - self.grid_shape[dim])//2
        if self.tiling_mode == TilingMode.TrimBoundary:
            return dim_index * self.grid_shape[dim] + excess_size
        if self.tiling_mode == TilingMode.ShiftBoundary:
            if dim_index < self.get_individual_dim_grid_count(dim) - 1:
                return dim_index * self.grid_shape[dim] + excess_size
            # on boundary. grid should be placed such that the patch covers the entire data.
            return self.data_shape[dim] - self.grid_shape[dim] - excess_size
        raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def get_location_from_dataset_idx(self, dataset_idx:int):
        """
        Returns the start location of the grid in the dataset.
        """
        grid_idx = []
        for dim in range(len(self.data_shape)):
            # Peel off the row-major digit for this dimension.
            grid_idx.append(dataset_idx // self.grid_count(dim))
            dataset_idx = dataset_idx % self.grid_count(dim)
        location = [self.get_gridstart_location_from_dim_index(dim, grid_idx[dim]) for dim in range(len(self.data_shape))]
        return tuple(location)

    def _dim_index(self, dataset_idx: int, dim: int):
        """Return the grid index along ``dim`` encoded in the flat ``dataset_idx``."""
        return (dataset_idx // self.grid_count(dim)) % self.get_individual_dim_grid_count(dim)

    def on_boundary(self, dataset_idx:int, dim:int, only_end:bool=False):
        """
        Returns True if the grid is on the boundary in the specified dimension.

        With ``only_end=True`` only the last grid along ``dim`` counts as boundary.
        """
        self._check_dim(dim)

        if dim > 0:
            dataset_idx = dataset_idx % self.grid_count(dim-1)

        dim_index = dataset_idx // self.grid_count(dim)
        if only_end:
            return dim_index == self.get_individual_dim_grid_count(dim) - 1

        return dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1

    def next_grid_along_dim(self, dataset_idx:int, dim:int):
        """
        Returns the dataset index of the next grid along ``dim``.

        Returns None when the grid is already the last one along ``dim`` (instead of
        wrapping into the neighbouring row) or the index would leave the dataset.
        """
        self._check_dim(dim)
        if self._dim_index(dataset_idx, dim) == self.get_individual_dim_grid_count(dim) - 1:
            return None
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx:int, dim:int):
        """
        Returns the dataset index of the previous grid along ``dim``.

        Returns None when the grid is already the first one along ``dim`` (instead of
        wrapping into the neighbouring row) or the index would become negative.
        """
        self._check_dim(dim)
        if self._dim_index(dataset_idx, dim) == 0:
            return None
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
        return new_idx
|
192
|
+
|
193
|
+
if __name__ == '__main__':
    # Smoke test: every dataset index must round-trip through its grid-start location.
    tiling_mode = TilingMode.ShiftBoundary
    manager = TileIndexManager(
        (5, 5, 64, 64, 2),   # data shape
        (1, 1, 8, 8, 2),     # grid shape
        (1, 3, 16, 16, 2),   # patch shape
        tiling_mode,
    )
    for idx in range(manager.total_grid_count()):
        grid_start = manager.get_location_from_dataset_idx(idx)
        print(idx, grid_start)
        recovered = manager.get_dataset_idx_from_grid_location(grid_start)
        assert idx == recovered, f"Index mismatch: {idx} != {recovered}"

    for axis in range(5):
        print(manager.on_boundary(40, axis))
|
@@ -0,0 +1,65 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from predtiler.tile_manager import TilingMode
|
6
|
+
|
7
|
+
|
8
|
+
def stitch_predictions(predictions:np.ndarray, manager):
    """
    Stitch per-tile predictions back into one array covering the full data.

    Args:
        predictions: Array of shape (N, C, *spatial), e.g. N*C*H*W or N*C*D*H*W, where N is
            the number of tiles (entry ``i`` belongs to dataset index ``i`` of ``manager``),
            C is the number of channels, and ``spatial`` are the patch extents along
            ``manager.data_shape[1:]``.
        manager: The ``TileIndexManager`` that produced the tiles. Only the grid part of
            each patch (plus the outer margins in ShiftBoundary mode) is written.

    Returns:
        Array of shape ``(*manager.data_shape, C)`` (channels last) with the dtype of
        ``predictions``. Regions not covered by any grid stay zero.

    Raises:
        ValueError: if the rank of ``predictions`` does not match ``manager.data_shape``.
        AssertionError: if a grid spans more than one element along the first dimension.
    """
    mng = manager
    n_dims = len(mng.data_shape)
    if predictions.ndim != n_dims + 1:
        raise ValueError(
            f'Unsupported shape {predictions.shape} for data shape {tuple(mng.data_shape)}; '
            f'expected (N, C, *spatial) with {n_dims - 1} spatial axes'
        )
    n_channels = predictions.shape[1]

    output = np.zeros((*mng.data_shape, n_channels), dtype=predictions.dtype)
    data_shape = np.array(mng.data_shape, dtype=int)
    patch_offset = mng.patch_offset()

    for dset_idx in range(predictions.shape[0]):
        # grid start, grid end
        gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
        ge = gs + mng.grid_shape

        # patch start, patch end
        ps = gs - patch_offset
        pe = ps + mng.patch_shape

        # valid grid start/end, clipped to the data: e.g. the last grid in PadBoundary
        # mode overhangs the data when it is not a multiple of the grid size.
        vgs = np.maximum(gs, 0)
        vge = np.minimum(ge, data_shape)

        if mng.tiling_mode == TilingMode.ShiftBoundary:
            # A patch touching the data border also owns the margin between that border
            # and its grid, which no other grid covers.
            vgs[ps == 0] = 0
            at_end = pe == data_shape
            vge[at_end] = data_shape[at_end]

        # relative start/end inside the predicted patch
        rel_start = vgs - ps
        rel_end = rel_start + (vge - vgs)

        assert vge[0] - vgs[0] == 1, 'Only one frame is supported'
        out_index = (vgs[0], *(slice(s, e) for s, e in zip(vgs[1:], vge[1:])))
        pred_index = (slice(None), *(slice(s, e) for s, e in zip(rel_start[1:], rel_end[1:])))
        # predictions are channel-first; the output is channel-last
        output[out_index] = np.moveaxis(predictions[dset_idx][pred_index], 0, -1)

    return output
|
@@ -0,0 +1,92 @@
|
|
1
|
+
from unittest.mock import Mock
|
2
|
+
import numpy as np
|
3
|
+
from predtiler.dataset import get_tiling_dataset, get_tile_manager
|
4
|
+
from predtiler.tile_stitcher import stitch_predictions
|
5
|
+
|
6
|
+
def get_data_3D(n=5,Z=9, H=512,W=512,C=2):
    """Return an (n, Z, H, W, C) array holding the consecutive integers 0 .. n*Z*H*W*C - 1."""
    size = n * Z * H * W * C
    shape = (n, Z, H, W, C)
    return np.arange(size).reshape(shape)
|
9
|
+
|
10
|
+
def get_data_2D(n=5,H=512,W=512,C=2):
    """Return an (n, H, W, C) array holding the consecutive integers 0 .. n*H*W*C - 1."""
    size = n * H * W * C
    shape = (n, H, W, C)
    return np.arange(size).reshape(shape)
|
13
|
+
|
14
|
+
class DummDataset:
    """
    Toy dataset for the tiling tests.

    Holds a synthetic ``arange`` array (from ``get_data_2D``/``get_data_3D``, channels
    last) and returns channel-first patches starting at ``patch_location(index)``.
    """
    def __init__(self, datatype ='2D', patch_size=64, z_patch_size=5) -> None:
        assert datatype in ['2D', '3D'], 'datatype must be either 2D or 3D'
        self.datatype = datatype
        self.z_patch_size = z_patch_size
        self.patch_size = patch_size
        if datatype == '2D':
            self.data = get_data_2D()
        elif datatype == '3D':
            self.data = get_data_3D()

    def patch_location(self, index):
        """
        Return a random patch start: ``(n, h, w)`` for 2D data, ``(n, z, h, w)`` for 3D data.

        ``index`` is ignored; the tiling wrapper overrides this method with
        deterministic locations.
        """
        # np.random.randint's upper bound is exclusive, hence the ``+ 1``: the last valid
        # start (patch ending exactly at the border) must be drawable, and a patch as
        # large as the data must yield start 0 instead of raising.
        if self.datatype == '2D':
            n_idx = np.random.randint(0,len(self.data))
            h = np.random.randint(0, self.data.shape[1] - self.patch_size + 1)
            w = np.random.randint(0, self.data.shape[2] - self.patch_size + 1)
            return (n_idx, h, w)
        elif self.datatype == '3D':
            n_idx = np.random.randint(0,len(self.data))
            z = np.random.randint(0, self.data.shape[1] - self.z_patch_size + 1)
            h = np.random.randint(0, self.data.shape[2] - self.patch_size + 1)
            w = np.random.randint(0, self.data.shape[3] - self.patch_size + 1)
            return (n_idx, z, h, w)

    def __len__(self):
        # Number of non-overlapping H x W patches per sample (the z extent is ignored).
        return len(self.data) * (self.data.shape[-2]//self.patch_size) * (self.data.shape[-3]//self.patch_size)

    def __getitem__(self, index):
        """Return the patch at ``patch_location(index)`` with channels moved to the front."""
        if self.datatype == '2D':
            n_idx, h, w = self.patch_location(index)
            return self.data[n_idx, h:h+self.patch_size, w:w+self.patch_size].transpose(2,0,1)
        elif self.datatype == '3D':
            n_idx, z, h, w = self.patch_location(index)
            return self.data[n_idx, z:z+self.z_patch_size, h:h+self.patch_size, w:w+self.patch_size].transpose(3,0,1,2)
|
48
|
+
|
49
|
+
|
50
|
+
def test_stich_prediction_2D():
    """Tiling 2D data and stitching the raw tiles back must reproduce the input exactly."""
    data_type = '2D'
    patch_size = 256
    tile_size = 128
    data = get_data_2D()
    manager = get_tile_manager(
        data_shape=data.shape[:-1],
        tile_shape=(1, tile_size, tile_size),
        patch_shape=(1, patch_size, patch_size),
    )

    dset = get_tiling_dataset(DummDataset, manager)(data_type, patch_size)

    predictions = np.stack([dset[idx] for idx in range(len(dset))])
    stitched_pred = stitch_predictions(predictions, dset.tile_manager)
    assert (stitched_pred == data).all()
|
69
|
+
|
70
|
+
|
71
|
+
|
72
|
+
def test_stich_prediction_3D():
    """Tiling 3D data and stitching the raw tiles back must reproduce the input exactly."""
    data_type = '3D'
    patch_size = 256
    tile_size = 128
    data = get_data_3D()
    z_patch_size = 5
    z_tile_size = 3
    manager = get_tile_manager(
        data_shape=data.shape[:-1],
        tile_shape=(1, z_tile_size, tile_size, tile_size),
        patch_shape=(1, z_patch_size, patch_size, patch_size),
    )

    dset_class = get_tiling_dataset(DummDataset, manager)
    # Pass the z patch size explicitly so the dataset's patches always match the
    # manager's patch_shape, instead of relying on DummDataset's default.
    dset = dset_class(data_type, patch_size, z_patch_size)

    predictions = np.stack([dset[idx] for idx in range(len(dset))])
    stitched_pred = stitch_predictions(predictions, dset.tile_manager)
    assert (stitched_pred == data).all()
|