predtiler 0.0.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- predtiler/dataset.py +53 -0
- predtiler/tile_manager.py +210 -0
- predtiler/tile_stitcher.py +65 -0
- predtiler-0.0.1.dist-info/METADATA +122 -0
- predtiler-0.0.1.dist-info/RECORD +7 -0
- predtiler-0.0.1.dist-info/WHEEL +4 -0
- predtiler-0.0.1.dist-info/licenses/LICENSE +21 -0
predtiler/dataset.py
ADDED
@@ -0,0 +1,53 @@
|
|
1
|
+
|
2
|
+
from predtiler.tile_manager import TileIndexManager, TilingMode
|
3
|
+
|
4
|
+
# class TilingDataset:
|
5
|
+
# def __init_subclass__(cls, parent_class=None, tile_manager=None, **kwargs):
|
6
|
+
# super().__init_subclass__(**kwargs)
|
7
|
+
# assert tile_manager is not None, 'tile_manager must be provided'
|
8
|
+
# cls.tile_manager = tile_manager
|
9
|
+
# if parent_class is not None:
|
10
|
+
# has_callable_method = callable(getattr(parent_class, 'patch_location', None))
|
11
|
+
# assert has_callable_method, f'{parent_class.__name__} must have a callable method with following signature: def patch_location(self, index)'
|
12
|
+
# cls.__bases__ = (parent_class,) + cls.__bases__
|
13
|
+
|
14
|
+
# def __len__(self):
|
15
|
+
# return self.tile_manager.total_grid_count()
|
16
|
+
|
17
|
+
# def patch_location(self, index):
|
18
|
+
# print('Calling patch_location')
|
19
|
+
# patch_loc_list = self.tile_manager.get_patch_location_from_dataset_idx(index)
|
20
|
+
# return patch_loc_list
|
21
|
+
|
22
|
+
|
23
|
+
# def get_tiling_dataset(dataset_class, tile_manager) -> type:
|
24
|
+
# class CorrespondingTilingDataset(TilingDataset, parent_class=dataset_class, tile_manager=tile_manager):
|
25
|
+
# pass
|
26
|
+
|
27
|
+
# return CorrespondingTilingDataset
|
28
|
+
|
29
|
+
def get_tiling_dataset(dataset_class, tile_manager) -> type:
    """
    Build a subclass of ``dataset_class`` that enumerates the patches of a tiled prediction.

    Args:
        dataset_class: the user's dataset class. It must define a callable
            ``patch_location(self, index)``; the returned subclass overrides it.
        tile_manager: object providing ``total_grid_count()`` and
            ``get_patch_location_from_dataset_idx(index)`` (e.g. a ``TileIndexManager``).

    Returns:
        A new class, constructed with the same arguments as ``dataset_class``, whose
        ``__len__`` returns ``tile_manager.total_grid_count()`` and whose
        ``patch_location(index)`` returns
        ``tile_manager.get_patch_location_from_dataset_idx(index)``.

    Raises:
        AssertionError: if ``dataset_class`` has no callable ``patch_location``.
    """
    has_callable_method = callable(getattr(dataset_class, 'patch_location', None))
    assert has_callable_method, f'{dataset_class.__name__} must have a callable method with following signature: def patch_location(self, index)'

    class TilingDataset(dataset_class):
        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Re-assign on the instance after the parent's __init__ so that the
            # provided manager wins even if the parent set its own ``tile_manager``.
            self.tile_manager = tile_manager

        def __len__(self):
            return self.tile_manager.total_grid_count()

        def patch_location(self, index):
            patch_loc_list = self.tile_manager.get_patch_location_from_dataset_idx(index)
            return patch_loc_list

    # Class-level default: ``len(self)`` and ``self.patch_location`` must already work
    # while ``dataset_class.__init__`` runs (before the instance attribute is set).
    TilingDataset.tile_manager = tile_manager
    return TilingDataset
|
46
|
+
|
47
|
+
|
48
|
+
|
49
|
+
|
50
|
+
def get_tile_manager(data_shape, tile_shape, patch_shape, tiling_mode=TilingMode.ShiftBoundary):
    """
    Convenience constructor for a ``TileIndexManager``.

    ``tile_shape`` is forwarded as the manager's ``grid_shape``; ``data_shape``,
    ``patch_shape`` and ``tiling_mode`` are passed through unchanged. The tiling mode
    defaults to ``TilingMode.ShiftBoundary``.
    """
    manager = TileIndexManager(
        data_shape=data_shape,
        grid_shape=tile_shape,
        patch_shape=patch_shape,
        tiling_mode=tiling_mode,
    )
    return manager
|
52
|
+
|
53
|
+
|
@@ -0,0 +1,210 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
|
6
|
+
class TilingMode:
    """
    Integer constants selecting how ``TileIndexManager`` treats the data boundary.

    These are plain ``int`` class attributes (not an ``enum.Enum``):

    * ``TrimBoundary`` (0): only grids whose whole patch fits inside the data are used.
    * ``PadBoundary`` (1): grids start at 0 and cover the whole data; patches may reach
      past the data boundary.
    * ``ShiftBoundary`` (2): the last grid of each dimension is shifted so that its
      patch ends exactly at the data boundary.
    """
    # Consecutive codes 0, 1, 2 in declaration order.
    TrimBoundary, PadBoundary, ShiftBoundary = range(3)
|
13
|
+
|
14
|
+
@dataclass
class TileIndexManager:
    """
    Maps between flat dataset indices, per-dimension grid indices and data coordinates.

    The data (shape ``data_shape``) is covered by grids (tiles) of shape ``grid_shape``.
    Each grid is read through a larger patch of shape ``patch_shape`` that extends
    ``(patch_shape - grid_shape) // 2`` beyond the grid on every side. Grids are numbered
    in row-major order: the last dimension varies fastest.

    ``tiling_mode`` (a ``TilingMode`` constant) decides how the data boundary is handled:

    * ``TrimBoundary``: only grids whose whole patch lies inside the data are used.
    * ``PadBoundary``: grids start at 0 and tile the whole data; patches may extend past
      the data boundary.
    * ``ShiftBoundary``: like ``TrimBoundary``, but the last grid of each dimension is
      shifted so that its patch ends exactly at the data boundary.

    A dimension where both ``grid_shape`` and ``patch_shape`` are 1 is enumerated element
    by element, regardless of the mode.
    """
    data_shape: tuple  # full shape of the data being tiled
    grid_shape: tuple  # shape of one grid (tile)
    patch_shape: tuple  # shape of the patch read around each grid; >= grid_shape
    tiling_mode: TilingMode  # one of the TilingMode constants

    def __post_init__(self):
        """
        Validate the shapes.

        Raises:
            AssertionError: if the three shapes do not have the same number of dimensions.
            ValueError: if ``patch_shape - grid_shape`` is negative or odd in any dimension
                (the patch must extend equally on both sides of the grid).
        """
        assert len(self.data_shape) == len(self.grid_shape), f"Data shape:{self.data_shape} and grid size:{self.grid_shape} must have the same dimension"
        assert len(self.data_shape) == len(self.patch_shape), f"Data shape:{self.data_shape} and patch shape:{self.patch_shape} must have the same dimension"
        innerpad = np.array(self.patch_shape) - np.array(self.grid_shape)
        for dim, pad in enumerate(innerpad):
            if pad < 0:
                raise ValueError(f"Patch shape:{self.patch_shape} must be greater than or equal to grid shape:{self.grid_shape} in dimension {dim}")
            if pad % 2 != 0:
                raise ValueError(f"Patch shape:{self.patch_shape} must have even padding in dimension {dim}")

    def _validate_dim(self, dim: int):
        """Assert that ``dim`` is a valid axis index of ``data_shape``."""
        assert dim < len(self.data_shape), f"Dimension {dim} is out of bounds for data shape {self.data_shape}"
        assert dim >= 0, "Dimension must be greater than or equal to 0"

    def _dim_index(self, dataset_idx: int, dim: int) -> int:
        """Return the grid index along ``dim`` encoded in the flat ``dataset_idx``."""
        if dim > 0:
            # Drop the contribution of all outer dimensions.
            dataset_idx = dataset_idx % self.grid_count(dim - 1)
        return dataset_idx // self.grid_count(dim)

    def patch_offset(self):
        """Return, per dimension, how far the patch extends before the grid start."""
        return (np.array(self.patch_shape) - np.array(self.grid_shape))//2

    def get_individual_dim_grid_count(self, dim:int):
        """
        Returns the number of the grid in the specified dimension, ignoring all other dimensions.
        """
        self._validate_dim(dim)

        if self.grid_shape[dim]==1 and self.patch_shape[dim]==1:
            return self.data_shape[dim]
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.ceil(self.data_shape[dim] / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            # Round up: the last grid is shifted back to fit instead of being dropped.
            return int(np.ceil((self.data_shape[dim] - excess_size) / self.grid_shape[dim]))
        else:
            excess_size = self.patch_shape[dim] - self.grid_shape[dim]
            # TrimBoundary: round down so that every patch stays inside the data.
            return int(np.floor((self.data_shape[dim] - excess_size) / self.grid_shape[dim]))

    def total_grid_count(self):
        """
        Returns the total number of grids in the dataset.
        """
        return self.grid_count(0) * self.get_individual_dim_grid_count(0)

    def grid_count(self, dim:int):
        """
        Returns the total number of grids for one value in the specified dimension,
        i.e. the product of the per-dimension grid counts of all dimensions after ``dim``
        (the flat-index stride of ``dim``).
        """
        self._validate_dim(dim)
        count = 1
        for inner_dim in range(dim + 1, len(self.data_shape)):
            count *= self.get_individual_dim_grid_count(inner_dim)
        return count

    def get_grid_index(self, dim:int, coordinate:int):
        """
        Returns the index, along ``dim``, of the grid that starts at / contains ``coordinate``.

        In ShiftBoundary mode the (shifted) last grid is only recognized from its exact
        start coordinate. Raises ValueError for an unsupported tiling mode.
        """
        self._validate_dim(dim)
        assert coordinate < self.data_shape[dim], f"Coordinate {coordinate} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim]==1 and self.patch_shape[dim]==1:
            return coordinate
        elif self.tiling_mode == TilingMode.PadBoundary:
            return int(np.floor(coordinate / self.grid_shape[dim]))
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim])//2
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim])))
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim])//2
            if coordinate + self.grid_shape[dim] + excess_size == self.data_shape[dim]:
                return self.get_individual_dim_grid_count(dim) - 1
            # can be <0 if coordinate is in [0,grid_shape[dim]]
            return max(0, int(np.floor((coordinate - excess_size) / self.grid_shape[dim])))
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def dataset_idx_from_grid_idx(self, grid_idx:tuple):
        """
        Returns the flat dataset index of the grid with per-dimension indices ``grid_idx``.
        """
        assert len(grid_idx) == len(self.data_shape), f"Dimension indices {grid_idx} must have the same dimension as data shape {self.data_shape}"
        return sum(grid_idx[dim] * self.grid_count(dim) for dim in range(len(grid_idx)))

    def get_patch_location_from_dataset_idx(self, dataset_idx:int):
        """
        Returns the start location of the patch that surrounds grid ``dataset_idx``.
        Entries can be negative or reach past the data in PadBoundary mode.
        """
        grid_location = self.get_location_from_dataset_idx(dataset_idx)
        offset = self.patch_offset()
        return tuple(np.array(grid_location) - np.array(offset))

    def get_dataset_idx_from_grid_location(self, location:tuple):
        """Returns the flat dataset index of the grid starting at ``location``."""
        assert len(location) == len(self.data_shape), f"Location {location} must have the same dimension as data shape {self.data_shape}"
        grid_idx = [self.get_grid_index(dim, location[dim]) for dim in range(len(location))]
        return self.dataset_idx_from_grid_idx(tuple(grid_idx))

    def get_gridstart_location_from_dim_index(self, dim:int, dim_index:int):
        """
        Returns the grid-start coordinate of the grid in the specified dimension.
        Raises ValueError for an unsupported tiling mode.
        """
        self._validate_dim(dim)
        assert dim_index < self.get_individual_dim_grid_count(dim), f"Dimension index {dim_index} is out of bounds for data shape {self.data_shape}"

        if self.grid_shape[dim]==1 and self.patch_shape[dim]==1:
            return dim_index
        elif self.tiling_mode == TilingMode.PadBoundary:
            return dim_index * self.grid_shape[dim]
        elif self.tiling_mode == TilingMode.TrimBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim])//2
            return dim_index * self.grid_shape[dim] + excess_size
        elif self.tiling_mode == TilingMode.ShiftBoundary:
            excess_size = (self.patch_shape[dim] - self.grid_shape[dim])//2
            if dim_index < self.get_individual_dim_grid_count(dim) - 1:
                return dim_index * self.grid_shape[dim] + excess_size
            # on boundary. grid should be placed such that the patch covers the entire data.
            return self.data_shape[dim] - self.grid_shape[dim] - excess_size
        else:
            raise ValueError(f"Unsupported tiling mode {self.tiling_mode}")

    def get_location_from_dataset_idx(self, dataset_idx:int):
        """
        Returns the start location of the grid in the dataset.
        """
        grid_idx = []
        for dim in range(len(self.data_shape)):
            # Row-major decomposition of the flat index into per-dimension indices.
            grid_idx.append(dataset_idx // self.grid_count(dim))
            dataset_idx = dataset_idx % self.grid_count(dim)
        location = [self.get_gridstart_location_from_dim_index(dim, grid_idx[dim]) for dim in range(len(self.data_shape))]
        return tuple(location)

    def on_boundary(self, dataset_idx:int, dim:int, only_end:bool=False):
        """
        Returns True if the grid is the first or last one along ``dim``
        (only the last one when ``only_end`` is True).
        """
        self._validate_dim(dim)
        dim_index = self._dim_index(dataset_idx, dim)
        if only_end:
            return dim_index == self.get_individual_dim_grid_count(dim) - 1

        return dim_index == 0 or dim_index == self.get_individual_dim_grid_count(dim) - 1

    def next_grid_along_dim(self, dataset_idx:int, dim:int):
        """
        Returns the dataset index of the next grid along ``dim`` (one step forward in that
        dimension, all other grid indices unchanged), or None if ``dataset_idx`` is already
        the last grid along ``dim`` instead of wrapping into an outer dimension.
        """
        self._validate_dim(dim)
        if self._dim_index(dataset_idx, dim) >= self.get_individual_dim_grid_count(dim) - 1:
            return None
        new_idx = dataset_idx + self.grid_count(dim)
        if new_idx >= self.total_grid_count():
            return None
        return new_idx

    def prev_grid_along_dim(self, dataset_idx:int, dim:int):
        """
        Returns the dataset index of the previous grid along ``dim`` (one step back in that
        dimension, all other grid indices unchanged), or None if ``dataset_idx`` is already
        the first grid along ``dim`` instead of wrapping into an outer dimension.
        """
        self._validate_dim(dim)
        if self._dim_index(dataset_idx, dim) <= 0:
            return None
        new_idx = dataset_idx - self.grid_count(dim)
        if new_idx < 0:
            return None
        return new_idx
|
192
|
+
|
193
|
+
if __name__ == '__main__':
    # Demo / self-check of the index round trip.
    # Alternative configuration with a non-divisible spatial size:
    # data_shape = (1, 5, 103, 103,2)
    # grid_shape = (1, 1, 16,16, 2)
    # patch_shape = (1, 3, 32, 32, 2)
    demo_manager = TileIndexManager(
        (5, 5, 64, 64, 2),
        (1, 1, 8, 8, 2),
        (1, 3, 16, 16, 2),
        TilingMode.ShiftBoundary,
    )

    # Every dataset index must map to a grid start and back to the same index.
    for idx in range(demo_manager.total_grid_count()):
        start = demo_manager.get_location_from_dataset_idx(idx)
        print(idx, start)
        recovered = demo_manager.get_dataset_idx_from_grid_location(start)
        assert idx == recovered, f"Index mismatch: {idx} != {recovered}"

    # Report, for each dimension, whether dataset index 40 lies on a boundary.
    for axis in range(5):
        print(demo_manager.on_boundary(40, axis))
|
@@ -0,0 +1,65 @@
|
|
1
|
+
from typing import List
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
from predtiler.tile_manager import TilingMode
|
6
|
+
|
7
|
+
|
8
|
+
def stitch_predictions(predictions:np.ndarray, manager):
    """
    Stitch per-patch predictions back into a single array covering the whole data.

    Args:
        predictions: N*C*H*W or N*C*D*H*W numpy array where N is the number of patches
            (patch ``i`` belongs to dataset index ``i`` of ``manager``), C is the number of
            channels, and D/H/W are the depth, height and width of a patch.
        manager: the tile manager (e.g. ``TileIndexManager``) that produced the patches.
            Its ``data_shape`` must have 3 entries for N*C*H*W predictions and 4 entries
            for N*C*D*H*W predictions.

    Returns:
        Array of shape ``(*manager.data_shape, C)`` with the dtype of ``predictions``; the
        channel axis is last. Only the grid part of each patch is written (extended to the
        data boundary in ShiftBoundary mode); locations no grid covers stay 0.

    Raises:
        ValueError: if ``manager.data_shape`` does not have 3 or 4 entries.
        AssertionError: if a grid does not lie inside the data, or (4-entry ``data_shape``)
            if a grid spans more than one entry along dimension 0.
    """
    mng = manager
    n_channels = predictions.shape[1]
    output = np.zeros((*mng.data_shape, n_channels), dtype=predictions.dtype)
    if output.ndim not in (4, 5):
        raise ValueError(f'Unsupported shape {output.shape}')

    for dset_idx in range(predictions.shape[0]):
        # grid start, grid end
        gs = np.array(mng.get_location_from_dataset_idx(dset_idx), dtype=int)
        ge = gs + mng.grid_shape

        # patch start, patch end
        ps = gs - mng.patch_offset()
        pe = ps + mng.patch_shape

        # valid grid start, valid grid end: the grid clipped to the data extent.
        vgs = np.maximum(gs, 0)
        vge = np.minimum(ge, mng.data_shape)
        # The manager is expected to place every grid fully inside the data.
        assert np.all(vgs == gs)
        assert np.all(vge == ge)

        if mng.tiling_mode == TilingMode.ShiftBoundary:
            # Patches touching the data boundary also fill the margin between the
            # boundary and their grid, which no grid covers otherwise.
            for dim in range(len(vgs)):
                if ps[dim] == 0:
                    vgs[dim] = 0
                if pe[dim] == mng.data_shape[dim]:
                    vge[dim] = mng.data_shape[dim]

        # relative start, relative end of the valid region inside the patch.
        rs = vgs - ps
        rend = rs + (vge - vgs)

        # Move the patch's channel axis last so all channels are written at once.
        pred = np.moveaxis(predictions[dset_idx], 0, -1)
        if output.ndim == 4:
            # Dimension 0 is not present in the patch; the 2D prediction is broadcast over it.
            output[vgs[0]:vge[0],
                   vgs[1]:vge[1],
                   vgs[2]:vge[2]] = pred[rs[1]:rend[1], rs[2]:rend[2]]
        else:
            assert vge[0] - vgs[0] == 1, 'Only one frame is supported'
            output[vgs[0],
                   vgs[1]:vge[1],
                   vgs[2]:vge[2],
                   vgs[3]:vge[3]] = pred[rs[1]:rend[1], rs[2]:rend[2], rs[3]:rend[3]]

    return output
|
@@ -0,0 +1,122 @@
|
|
1
|
+
Metadata-Version: 2.3
|
2
|
+
Name: predtiler
|
3
|
+
Version: 0.0.1
|
4
|
+
Summary: Convert your dataset class into a class that can be used for tiled prediction and eventually obtain a stitched prediction.
|
5
|
+
Project-URL: homepage, https://github.com/ashesh-0/PredTiler
|
6
|
+
Project-URL: repository, https://github.com/ashesh-0/PredTiler
|
7
|
+
Author: Ashesh
|
8
|
+
License: MIT
|
9
|
+
License-File: LICENSE
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
11
|
+
Classifier: License :: OSI Approved :: MIT License
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
13
|
+
Classifier: Programming Language :: Python :: 3.9
|
14
|
+
Classifier: Programming Language :: Python :: 3.10
|
15
|
+
Classifier: Programming Language :: Python :: 3.11
|
16
|
+
Classifier: Programming Language :: Python :: 3.12
|
17
|
+
Classifier: Typing :: Typed
|
18
|
+
Requires-Python: >=3.9
|
19
|
+
Requires-Dist: numpy
|
20
|
+
Provides-Extra: dev
|
21
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
22
|
+
Requires-Dist: pytest; extra == 'dev'
|
23
|
+
Requires-Dist: pytest-cov; extra == 'dev'
|
24
|
+
Requires-Dist: sybil; extra == 'dev'
|
25
|
+
Provides-Extra: examples
|
26
|
+
Requires-Dist: jupyter; extra == 'examples'
|
27
|
+
Requires-Dist: matplotlib; extra == 'examples'
|
28
|
+
Description-Content-Type: text/markdown
|
29
|
+
|
30
|
+
A lean wrapper around your dataset class to enable tiled prediction.
|
31
|
+
|
32
|
+
[![License](https://img.shields.io/badge/license-MIT-blue)](https://github.com/ashesh-0/PredTiler/blob/main/LICENSE)
|
33
|
+
[![CI](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml/badge.svg)](https://github.com/ashesh-0/PredTiler/actions/workflows/ci.yml)
|
34
|
+
[![codecov](https://codecov.io/gh/ashesh-0/PredTiler/graph/badge.svg?token=M655MOS7EL)](https://codecov.io/gh/ashesh-0/PredTiler)
|
35
|
+
|
36
|
+
## Objective
|
37
|
+
This package subclasses the dataset class you use to train your network.
|
38
|
+
With PredTiler, you can use your dataset class as is, and PredTiler will take care of the tiling logic for you.
|
39
|
+
It will automatically generate patches in such a way that they can be tiled with the overlap of `(patch_size - tile_size)//2`.
|
40
|
+
We also provide a function to stitch the tiles back together to get the final prediction.
|
41
|
+
|
42
|
+
If you run into problems, feel free to open an issue and I will be happy to help you out!
|
43
|
+
In the future, I plan to add detailed instructions for:
|
44
|
+
1. multi-channel data
|
45
|
+
2. 3D data
|
46
|
+
3. Data given as a list of numpy arrays, each possibly having a different shape.
|
47
|
+
|
48
|
+
## Installation
|
49
|
+
|
50
|
+
```bash
|
51
|
+
pip install predtiler
|
52
|
+
```
|
53
|
+
|
54
|
+
## Usage
|
55
|
+
To work with PredTiler, the only requirement is that your dataset class must have a **patch_location(self, index)** method that returns the location of the patch at the given index.
|
56
|
+
Your dataset class should only use the location information returned by this method to return the patch.
|
57
|
+
PredTiler will override this method to return the location of the patches needed for tiled prediction.
|
58
|
+
|
59
|
+
Note that your dataset class could be arbitrarily complex (augmentations, returning multiple patches, working with 3D data, etc.). The only requirement is that it should use the crop present at the location returned by **patch_location** method. Below is an example of a simple dataset class that can be used with PredTiler.
|
60
|
+
|
61
|
+
```python
|
62
|
+
class YourDataset:
|
63
|
+
def __init__(self, data_path, patch_size=64) -> None:
|
64
|
+
self.patch_size = patch_size
|
65
|
+
self.data = load_data(data_path) # shape: (N, H, W, C)
|
66
|
+
|
67
|
+
def patch_location(self, index:int) -> tuple[int, int, int]:
|
68
|
+
# it just ignores the index and returns a random location
|
69
|
+
n_idx = np.random.randint(0,len(self.data))
|
70
|
+
h = np.random.randint(0, self.data.shape[1] - self.patch_size + 1)  # high is exclusive
|
71
|
+
w = np.random.randint(0, self.data.shape[2] - self.patch_size + 1)  # high is exclusive
|
72
|
+
return (n_idx, h, w)
|
73
|
+
|
74
|
+
def __len__(self):
|
75
|
+
return len(self.data)
|
76
|
+
|
77
|
+
def __getitem__(self, index):
|
78
|
+
n_idx, h, w = self.patch_location(index)
|
79
|
+
# return the patch at the location (patch_size, patch_size)
|
80
|
+
return self.data[n_idx, h:h+self.patch_size, w:w+self.patch_size]
|
81
|
+
```
|
82
|
+
|
83
|
+
## Getting overlapping patches needed for tiled prediction
|
84
|
+
To use PredTiler, we need to get a new class that wraps around your dataset class.
|
85
|
+
For this we also need a tile manager that will manage the tiles.
|
86
|
+
|
87
|
+
```python
|
88
|
+
|
89
|
+
from predtiler.dataset import get_tiling_dataset, get_tile_manager
|
90
|
+
patch_size = 256
|
91
|
+
tile_size = 128
|
92
|
+
data_shape = (10, 2048, 2048) # size of the data you are working with
|
93
|
+
manager = get_tile_manager(data_shape=data_shape, tile_shape=(1,tile_size,tile_size),
|
94
|
+
patch_shape=(1,patch_size,patch_size))
|
95
|
+
|
96
|
+
dset_class = get_tiling_dataset(YourDataset, manager)
|
97
|
+
```
|
98
|
+
|
99
|
+
At this point, you can use the `dset_class` as you would use `YourDataset` class.
|
100
|
+
|
101
|
+
```python
|
102
|
+
data_path = ... # path to your data
|
103
|
+
dset = dset_class(data_path, patch_size=patch_size)
|
104
|
+
```
|
105
|
+
|
106
|
+
## Stitching the predictions
|
107
|
+
The benefit of using PredTiler is that it will automatically generate the patches in such a way that they can be tiled with the overlap of `(patch_size - tile_size)//2`. This allows you to use your dataset class as is, without worrying about the tiling logic.
|
108
|
+
|
109
|
+
```python
|
110
|
+
model = ... # your model
|
111
|
+
predictions = []
|
112
|
+
for i in range(len(dset)):
|
113
|
+
inp = dset[i]
|
114
|
+
inp = torch.Tensor(inp)[None,None]
|
115
|
+
pred = model(inp)
|
116
|
+
predictions.append(pred[0].numpy())
|
117
|
+
|
118
|
+
predictions = np.stack(predictions) # shape: (number_of_patches, C, patch_size, patch_size)
|
119
|
+
from predtiler.tile_stitcher import stitch_predictions
stitched_pred = stitch_predictions(predictions, dset.tile_manager)
|
120
|
+
```
|
121
|
+
|
122
|
+
|
@@ -0,0 +1,7 @@
|
|
1
|
+
predtiler/dataset.py,sha256=tHgVZksqJ2h2sw212x2g8SOW4Tf2TDZ_f6ibTpQPUMM,2136
|
2
|
+
predtiler/tile_manager.py,sha256=V-iunzWcDw4_GytvkzhlrG8hLepf2_iSTBA3Da5nWgE,9958
|
3
|
+
predtiler/tile_stitcher.py,sha256=rAgEV0y_FyRzU9RC_W11AE9OaGOrm1CgWPA-hbTh2vE,2336
|
4
|
+
predtiler-0.0.1.dist-info/METADATA,sha256=cXyF36cQHVKxttq9zTUdbksOs9H4Pd5YYhqbF8NRPfA,5317
|
5
|
+
predtiler-0.0.1.dist-info/WHEEL,sha256=1yFddiXMmvYK7QYTqtRNtX66WJ0Mz8PYEiEUoOUUxRY,87
|
6
|
+
predtiler-0.0.1.dist-info/licenses/LICENSE,sha256=QyujMlqgcoASx9pVONwI8UO7xMSntGUEgQdlMqwVv4w,1063
|
7
|
+
predtiler-0.0.1.dist-info/RECORD,,
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2024 ashesh
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|