torch-grid-utils 0.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torch_grid_utils/__init__.py +13 -0
- torch_grid_utils/coordinate_grid.py +61 -0
- torch_grid_utils/fftfreq_grid.py +204 -0
- torch_grid_utils-0.0.1.dist-info/METADATA +43 -0
- torch_grid_utils-0.0.1.dist-info/RECORD +7 -0
- torch_grid_utils-0.0.1.dist-info/WHEEL +4 -0
- torch_grid_utils-0.0.1.dist-info/licenses/LICENSE +28 -0
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
"""Grids for 2D/3D image manipulations in PyTorch"""
|
|
2
|
+
|
|
3
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
4
|
+
|
|
5
|
+
# Fall back to a sentinel version when the distribution metadata cannot be
# found, e.g. when the package is imported from a source tree that was never
# installed.
__version__ = "uninstalled"
try:
    __version__ = version("torch-grid-utils")
except PackageNotFoundError:
    pass  # keep the "uninstalled" sentinel assigned above
__author__ = "Alister Burt"
__email__ = "alisterburt@gmail.com"
|
|
11
|
+
|
|
12
|
+
from .fftfreq_grid import fftfreq_grid
|
|
13
|
+
from .coordinate_grid import coordinate_grid
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from typing import Sequence
|
|
2
|
+
|
|
3
|
+
import einops
|
|
4
|
+
import numpy as np
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
def coordinate_grid(
|
|
9
|
+
image_shape: Sequence[int],
|
|
10
|
+
center: torch.Tensor | tuple[float, ...] | None = False,
|
|
11
|
+
norm: bool = False,
|
|
12
|
+
device: torch.device | None = None,
|
|
13
|
+
) -> torch.FloatTensor:
|
|
14
|
+
"""Get a dense grid of array coordinates from grid dimensions.
|
|
15
|
+
|
|
16
|
+
For input `image_shape` of `(d, h, w)`, this function produces a
|
|
17
|
+
`(d, h, w, 3)` grid of coordinates. Coordinate order matches the order of
|
|
18
|
+
dimensions in `image_shape`.
|
|
19
|
+
|
|
20
|
+
Parameters
|
|
21
|
+
----------
|
|
22
|
+
image_shape: Sequence[int]
|
|
23
|
+
Shape of the image for which coordinates should be returned.
|
|
24
|
+
center: torch.Tensor | tuple[float, ...] | None
|
|
25
|
+
Array of center points relative to which coordinates will be calculated.
|
|
26
|
+
If `None`, default to the array origin `[0, ...]` of zero in all dimensions.
|
|
27
|
+
norm: bool
|
|
28
|
+
Whether to compute the Euclidean norm of the coordinate grid.
|
|
29
|
+
device: torch.device
|
|
30
|
+
PyTorch device on which to put the coordinate grid.
|
|
31
|
+
|
|
32
|
+
Returns
|
|
33
|
+
-------
|
|
34
|
+
grid: torch.LongTensor
|
|
35
|
+
`(*image_shape, image_ndim)` array of coordinates if `norm` is `False`
|
|
36
|
+
else `(*image_shape, )`.
|
|
37
|
+
"""
|
|
38
|
+
grid = torch.tensor(
|
|
39
|
+
np.indices(image_shape),
|
|
40
|
+
device=device,
|
|
41
|
+
dtype=torch.float32
|
|
42
|
+
) # (coordinates, *image_shape)
|
|
43
|
+
grid = einops.rearrange(grid, 'coords ... -> ... coords')
|
|
44
|
+
ndim = len(image_shape)
|
|
45
|
+
if center is not None:
|
|
46
|
+
center = torch.as_tensor(center, dtype=grid.dtype, device=grid.device)
|
|
47
|
+
center = torch.atleast_1d(center)
|
|
48
|
+
center, ps = einops.pack([center], pattern='* coords')
|
|
49
|
+
ones = ' '.join('1' * ndim)
|
|
50
|
+
axis_ids = ' '.join(_unique_characters(ndim))
|
|
51
|
+
center = einops.rearrange(center, f"b coords -> b {ones} coords")
|
|
52
|
+
grid = grid - center
|
|
53
|
+
[grid] = einops.unpack(grid, packed_shapes=ps, pattern=f'* {axis_ids} coords')
|
|
54
|
+
if norm is True:
|
|
55
|
+
grid = einops.reduce(grid ** 2, '... coords -> ...', reduction='sum') ** 0.5
|
|
56
|
+
return grid
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def _unique_characters(n: int) -> str:
|
|
60
|
+
chars = "abcdefghijklmnopqrstuvwxyz"
|
|
61
|
+
return chars[:n]
|
|
@@ -0,0 +1,204 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
from typing import Sequence
|
|
3
|
+
|
|
4
|
+
import einops
|
|
5
|
+
import torch
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
@functools.lru_cache(maxsize=1)
def fftfreq_grid(
    image_shape: tuple[int, int] | tuple[int, int, int],
    rfft: bool,
    fftshift: bool = False,
    spacing: float | tuple[float, float] | tuple[float, float, float] = 1,
    norm: bool = False,
    device: torch.device | None = None,
) -> torch.Tensor:
    """Construct a 2D or 3D grid of DFT sample frequencies.

    For a 2D image with shape `(h, w)` and `rfft=False` this function will produce
    a `(h, w, 2)` array of DFT sample frequencies in the `h` and `w` dimensions.
    With `rfft=True` the last image dimension has `w // 2 + 1` entries.
    If `norm` is True the Euclidean norm will be calculated over the last dimension
    leaving a `(h, w)` grid.

    Results are memoised with `functools.lru_cache(maxsize=1)`: a repeated call
    with identical arguments returns the *same* tensor object. Clone the result
    before modifying it in place, otherwise later calls see the modified grid.
    Arguments are used as cache keys and must therefore be hashable, so pass
    `image_shape` and `spacing` as tuples rather than lists.

    Parameters
    ----------
    image_shape: tuple[int, int] | tuple[int, int, int]
        Shape of the 2D or 3D image before computing the DFT.
    rfft: bool
        Whether the output should contain frequencies for a real-valued DFT.
    fftshift: bool
        Whether to fftshift the output grid.
    spacing: float | tuple[float, float] | tuple[float, float, float]
        Spacing between samples in each dimension. Sampling is considered to be
        isotropic if a single value is passed.
    norm: bool
        Whether to compute the Euclidean norm over the last dimension.
    device: torch.device | None
        PyTorch device on which the returned grid will be stored.

    Returns
    -------
    frequency_grid: torch.Tensor
        `(*dft_shape, ndim)` array of DFT sample frequencies in each
        image dimension if `norm` is `False` else `(*dft_shape, )`, where
        `dft_shape` equals `image_shape` except for the last dimension of
        an rfft.

    Raises
    ------
    NotImplementedError
        If `image_shape` is neither 2D nor 3D.
    """
    # Normalise once so numpy/torch booleans reach the helpers as plain bools,
    # regardless of how each helper tests its `rfft` flag.
    rfft = bool(rfft)
    ndim = len(image_shape)
    if ndim == 2:
        construct_grid, shift_grid = _construct_fftfreq_grid_2d, fftshift_2d
    elif ndim == 3:
        construct_grid, shift_grid = _construct_fftfreq_grid_3d, fftshift_3d
    else:
        raise NotImplementedError(
            "Construction of fftfreq grids is currently only supported for "
            "2D and 3D images."
        )

    frequency_grid = construct_grid(
        image_shape=image_shape,
        rfft=rfft,
        spacing=spacing,
        device=device,
    )
    if fftshift:
        # The shift helpers operate on the trailing spatial dimensions, so the
        # frequency-component axis is moved to the front for the shift.
        frequency_grid = einops.rearrange(frequency_grid, '... freq -> freq ...')
        frequency_grid = shift_grid(frequency_grid, rfft=rfft)
        frequency_grid = einops.rearrange(frequency_grid, 'freq ... -> ... freq')
    if norm:
        frequency_grid = einops.reduce(
            frequency_grid ** 2, '... squared_freqs -> ...', reduction='sum'
        ) ** 0.5
    return frequency_grid
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
def _construct_fftfreq_grid_2d(
|
|
81
|
+
image_shape: tuple[int, int],
|
|
82
|
+
rfft: bool,
|
|
83
|
+
spacing: float | tuple[float, float] = 1,
|
|
84
|
+
device: torch.device = None
|
|
85
|
+
) -> torch.Tensor:
|
|
86
|
+
"""Construct a grid of DFT sample freqs for a 2D image.
|
|
87
|
+
|
|
88
|
+
Parameters
|
|
89
|
+
----------
|
|
90
|
+
image_shape: Sequence[int]
|
|
91
|
+
A 2D shape `(h, w)` of the input image for which a grid of DFT sample freqs
|
|
92
|
+
should be calculated.
|
|
93
|
+
rfft: bool
|
|
94
|
+
Whether the frequency grid is for a real fft (rfft).
|
|
95
|
+
spacing: float | tuple[float, float]
|
|
96
|
+
Sample spacing in `h` and `w` dimensions of the grid.
|
|
97
|
+
device: torch.device
|
|
98
|
+
Torch device for the resulting grid.
|
|
99
|
+
|
|
100
|
+
Returns
|
|
101
|
+
-------
|
|
102
|
+
frequency_grid: torch.Tensor
|
|
103
|
+
`(h, w, 2)` array of DFT sample freqs.
|
|
104
|
+
Order of freqs in the last dimension corresponds to the order of
|
|
105
|
+
the two dimensions of the grid.
|
|
106
|
+
"""
|
|
107
|
+
dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 2
|
|
108
|
+
last_axis_frequency_func = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq
|
|
109
|
+
h, w = image_shape
|
|
110
|
+
freq_y = torch.fft.fftfreq(h, d=dh, device=device)
|
|
111
|
+
freq_x = last_axis_frequency_func(w, d=dw, device=device)
|
|
112
|
+
h, w = rfft_shape(image_shape) if rfft is True else image_shape
|
|
113
|
+
freq_yy = einops.repeat(freq_y, 'h -> h w', w=w)
|
|
114
|
+
freq_xx = einops.repeat(freq_x, 'w -> h w', h=h)
|
|
115
|
+
return einops.rearrange([freq_yy, freq_xx], 'freq h w -> h w freq')
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def _construct_fftfreq_grid_3d(
|
|
119
|
+
image_shape: Sequence[int],
|
|
120
|
+
rfft: bool,
|
|
121
|
+
spacing: float | tuple[float, float, float] = 1,
|
|
122
|
+
device: torch.device = None
|
|
123
|
+
) -> torch.Tensor:
|
|
124
|
+
"""Construct a grid of DFT sample freqs for a 3D image.
|
|
125
|
+
|
|
126
|
+
Parameters
|
|
127
|
+
----------
|
|
128
|
+
image_shape: Sequence[int]
|
|
129
|
+
A 3D shape `(d, h, w)` of the input image for which a grid of DFT sample freqs
|
|
130
|
+
should be calculated.
|
|
131
|
+
rfft: bool
|
|
132
|
+
Controls Whether the frequency grid is for a real fft (rfft).
|
|
133
|
+
spacing: float | tuple[float, float, float]
|
|
134
|
+
Sample spacing in `d`, `h` and `w` dimensions of the grid.
|
|
135
|
+
device: torch.device
|
|
136
|
+
Torch device for the resulting grid.
|
|
137
|
+
|
|
138
|
+
Returns
|
|
139
|
+
-------
|
|
140
|
+
frequency_grid: torch.Tensor
|
|
141
|
+
`(h, w, 3)` array of DFT sample freqs.
|
|
142
|
+
Order of freqs in the last dimension corresponds to the order of dimensions
|
|
143
|
+
of the grid.
|
|
144
|
+
"""
|
|
145
|
+
dd, dh, dw = spacing if isinstance(spacing, Sequence) else [spacing] * 3
|
|
146
|
+
last_axis_frequency_func = torch.fft.rfftfreq if rfft is True else torch.fft.fftfreq
|
|
147
|
+
d, h, w = image_shape
|
|
148
|
+
freq_z = torch.fft.fftfreq(d, d=dd, device=device)
|
|
149
|
+
freq_y = torch.fft.fftfreq(h, d=dh, device=device)
|
|
150
|
+
freq_x = last_axis_frequency_func(w, d=dw, device=device)
|
|
151
|
+
d, h, w = rfft_shape(image_shape) if rfft is True else image_shape
|
|
152
|
+
freq_zz = einops.repeat(freq_z, 'd -> d h w', h=h, w=w)
|
|
153
|
+
freq_yy = einops.repeat(freq_y, 'h -> d h w', d=d, w=w)
|
|
154
|
+
freq_xx = einops.repeat(freq_x, 'w -> d h w', d=d, h=h)
|
|
155
|
+
return einops.rearrange([freq_zz, freq_yy, freq_xx], 'freq ... -> ... freq')
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def rfft_shape(input_shape: Sequence[int]) -> tuple[int, ...]:
    """Get the output shape of an rfft on an input with `input_shape`.

    An rfft over the last dimension of length `n` yields `n // 2 + 1`
    frequencies; all other dimensions are unchanged.
    """
    output_shape = list(input_shape)
    # Integer floor division: `n / 2 + 1` goes through a float and loses
    # precision for very large sizes.
    output_shape[-1] = int(output_shape[-1] // 2 + 1)
    return tuple(output_shape)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def dft_center(
|
|
166
|
+
image_shape: tuple[int, ...],
|
|
167
|
+
rfft: bool,
|
|
168
|
+
fftshifted: bool,
|
|
169
|
+
device: torch.device | None = None,
|
|
170
|
+
) -> torch.LongTensor:
|
|
171
|
+
"""Return the position of the DFT center for a given input shape."""
|
|
172
|
+
fft_center = torch.zeros(size=(len(image_shape),), device=device)
|
|
173
|
+
image_shape = torch.as_tensor(image_shape).float()
|
|
174
|
+
if rfft is True:
|
|
175
|
+
image_shape = torch.tensor(rfft_shape(image_shape))
|
|
176
|
+
if fftshifted is True:
|
|
177
|
+
fft_center = torch.divide(image_shape, 2, rounding_mode='floor')
|
|
178
|
+
if rfft is True:
|
|
179
|
+
fft_center[-1] = 0
|
|
180
|
+
return fft_center.long()
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def fftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
    """fftshift the last two dimensions of a 2D DFT.

    Parameters
    ----------
    input: torch.Tensor
        `(..., h, w)` DFT (or DFT-shaped grid) to shift.
    rfft: bool
        Whether `input` has rfft layout. If truthy, only the `h` dimension is
        shifted, because the rfft's last dimension holds only non-negative
        frequencies.

    Returns
    -------
    output: torch.Tensor
        The shifted tensor, with the same shape as `input`.
    """
    # Truthiness (not identity with `False`) so numpy booleans behave correctly.
    shift_dims = (-2,) if rfft else (-2, -1)
    return torch.fft.fftshift(input, dim=shift_dims)
|
|
189
|
+
|
|
190
|
+
|
|
191
|
+
def ifftshift_2d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
    """Undo an fftshift on the last two dimensions of a 2D DFT.

    Parameters
    ----------
    input: torch.Tensor
        `(..., h, w)` fftshifted DFT (or DFT-shaped grid).
    rfft: bool
        Whether `input` has rfft layout. If truthy, only the `h` dimension is
        inverse-shifted, because the rfft's last dimension is never shifted.

    Returns
    -------
    output: torch.Tensor
        The inverse-shifted tensor, with the same shape as `input`.
    """
    # Truthiness (not identity with `False`) so numpy booleans behave correctly.
    shift_dims = (-2,) if rfft else (-2, -1)
    return torch.fft.ifftshift(input, dim=shift_dims)
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
def fftshift_3d(input: torch.Tensor, rfft: bool) -> torch.Tensor:
    """fftshift the last three dimensions of a 3D DFT.

    Parameters
    ----------
    input: torch.Tensor
        `(..., d, h, w)` DFT (or DFT-shaped grid) to shift.
    rfft: bool
        Whether `input` has rfft layout. If truthy, only the `d` and `h`
        dimensions are shifted, because the rfft's last dimension holds only
        non-negative frequencies.

    Returns
    -------
    output: torch.Tensor
        The shifted tensor, with the same shape as `input`.
    """
    # Truthiness (not identity with `False`) so numpy booleans behave correctly.
    shift_dims = (-3, -2) if rfft else (-3, -2, -1)
    return torch.fft.fftshift(input, dim=shift_dims)
|
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
Metadata-Version: 2.3
|
|
2
|
+
Name: torch-grid-utils
|
|
3
|
+
Version: 0.0.1
|
|
4
|
+
Summary: Grid utilities for 2D/3D image manipulations in PyTorch
|
|
5
|
+
Project-URL: homepage, https://github.com/alisterburt/torch-grids
|
|
6
|
+
Project-URL: repository, https://github.com/alisterburt/torch-grids
|
|
7
|
+
Author-email: Alister Burt <alisterburt@gmail.com>
|
|
8
|
+
License: BSD-3-Clause
|
|
9
|
+
License-File: LICENSE
|
|
10
|
+
Classifier: Development Status :: 3 - Alpha
|
|
11
|
+
Classifier: License :: OSI Approved :: BSD License
|
|
12
|
+
Classifier: Programming Language :: Python :: 3
|
|
13
|
+
Classifier: Programming Language :: Python :: 3.10
|
|
14
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
15
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
16
|
+
Classifier: Typing :: Typed
|
|
17
|
+
Requires-Python: >=3.10
|
|
18
|
+
Requires-Dist: einops
|
|
19
|
+
Requires-Dist: numpy
|
|
20
|
+
Requires-Dist: torch
|
|
21
|
+
Provides-Extra: dev
|
|
22
|
+
Requires-Dist: ipython; extra == 'dev'
|
|
23
|
+
Requires-Dist: mypy; extra == 'dev'
|
|
24
|
+
Requires-Dist: pdbpp; extra == 'dev'
|
|
25
|
+
Requires-Dist: pre-commit; extra == 'dev'
|
|
26
|
+
Requires-Dist: rich; extra == 'dev'
|
|
27
|
+
Requires-Dist: ruff; extra == 'dev'
|
|
28
|
+
Provides-Extra: test
|
|
29
|
+
Requires-Dist: pytest; extra == 'test'
|
|
30
|
+
Requires-Dist: pytest-cov; extra == 'test'
|
|
31
|
+
Description-Content-Type: text/markdown
|
|
32
|
+
|
|
33
|
+
# torch-grid-utils
|
|
34
|
+
|
|
35
|
+
[License: BSD-3-Clause](https://github.com/alisterburt/torch-grids/raw/main/LICENSE)
|
|
36
|
+
[PyPI](https://pypi.org/project/torch-grid-utils)
|
|
37
|
+
[Python version](https://python.org)
|
|
38
|
+
[CI](https://github.com/alisterburt/torch-grids/actions/workflows/ci.yml)
|
|
39
|
+
[codecov](https://codecov.io/gh/alisterburt/torch-grids)
|
|
40
|
+
|
|
41
|
+
*torch-grid-utils* provides grids for 2D/3D image manipulations in PyTorch.
|
|
42
|
+
|
|
43
|
+
|
|
@@ -0,0 +1,7 @@
|
|
|
1
|
+
torch_grid_utils/__init__.py,sha256=RLM2DiwY2CDqg8pZn_JF49o7WswV6_tETK6BOKglTUE,377
|
|
2
|
+
torch_grid_utils/coordinate_grid.py,sha256=IN7uofztuRwBT4bS03_0cb7_lHoonatqP4uwIdVxflQ,2170
|
|
3
|
+
torch_grid_utils/fftfreq_grid.py,sha256=KKruQEfPUNFI0CZGW-7PjCN9TIkqHT8QBLyMQpJKYug,7473
|
|
4
|
+
torch_grid_utils-0.0.1.dist-info/METADATA,sha256=4uRLXmghp7bvwOrkaW_PfuP8VpTRfttTvQjjQYo0k-0,1894
|
|
5
|
+
torch_grid_utils-0.0.1.dist-info/WHEEL,sha256=zEMcRr9Kr03x1ozGwg5v9NQBKn3kndp6LSoSlVg-jhU,87
|
|
6
|
+
torch_grid_utils-0.0.1.dist-info/licenses/LICENSE,sha256=Kbo_h3sPum8rDAhMerH9fl4hzFn-QUCekJf05zk2epY,1499
|
|
7
|
+
torch_grid_utils-0.0.1.dist-info/RECORD,,
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
BSD 3-Clause License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2023, Alister Burt
|
|
4
|
+
|
|
5
|
+
Redistribution and use in source and binary forms, with or without
|
|
6
|
+
modification, are permitted provided that the following conditions are met:
|
|
7
|
+
|
|
8
|
+
1. Redistributions of source code must retain the above copyright notice, this
|
|
9
|
+
list of conditions and the following disclaimer.
|
|
10
|
+
|
|
11
|
+
2. Redistributions in binary form must reproduce the above copyright notice,
|
|
12
|
+
this list of conditions and the following disclaimer in the documentation
|
|
13
|
+
and/or other materials provided with the distribution.
|
|
14
|
+
|
|
15
|
+
3. Neither the name of the copyright holder nor the names of its
|
|
16
|
+
contributors may be used to endorse or promote products derived from
|
|
17
|
+
this software without specific prior written permission.
|
|
18
|
+
|
|
19
|
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
|
20
|
+
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
|
21
|
+
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
|
22
|
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
|
23
|
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
|
24
|
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
|
25
|
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
|
26
|
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
|
27
|
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
|
28
|
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|