torch-geopooling 1.2.0__cp311-cp311-manylinux_2_28_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torch-geopooling might be problematic. Click here for more details.

@@ -0,0 +1,102 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from typing import Union, Tuple, cast
17
+
18
+ import torch
19
+ from torch import Tensor, nn
20
+
21
+ from torch_geopooling import functional as F
22
+ from torch_geopooling.tiling import Exterior, ExteriorTuple
23
+
24
+
25
# Public interface of this module.
__all__ = [
    "Embedding2d",
]


# Accept both the structured Exterior type and its plain-tuple form.
_Exterior = Union[Exterior, ExteriorTuple]
31
+
32
+
33
class Embedding2d(nn.Module):
    """
    Retrieves spatial embeddings from a fixed-size lookup table based on 2D coordinates.

    This module accepts a tensor of (x, y) coordinates and retrieves the corresponding
    spatial embeddings from a provided embedding matrix. The embeddings are selected
    based on the input coordinates, with an optional padding to include neighboring cells.

    Args:
        manifold: The size of the 2-dimensional embedding in a form (W, H, N), where
            W is a width, H is a height, and N is a feature dimension of the embedding.
        exterior: The geometric boundary of the learning space, specified as a tuple (X, Y, W, H),
            where X and Y represent the origin, and W and H represent the width and height of the
            space, respectively.
        padding: The size of the neighborhood to query. Default is (0, 0), meaning only the
            embedding for the exact input coordinate is retrieved.
        reflection: When true, kernel is wrapped around the exterior space, otherwise kernel is
            squeezed into borders.

    Shape:
        - Input: :math:`(*, 2)`, where 2 comprises x and y coordinates.
        - Output: :math:`(*, X_{out}, Y_{out}, N)`, where * is the input shape, \\
          :math:`N = \\text{manifold[2]}`, and

          :math:`X_{out} = \\text{padding}[0] \\times 2 + 1`

          :math:`Y_{out} = \\text{padding}[1] \\times 2 + 1`

    Examples:

        >>> # Create an embedding of EPSG:4326 rectangle into 1024x1024 embedding
        >>> # with 3 features in each cell.
        >>> embedding = nn.Embedding2d(
        ...     (1024, 1024, 3),
        ...     exterior=(-180.0, -90.0, 360.0, 180.0),
        ...     padding=(2, 2),
        ... )
        >>> input = torch.rand((100, 2), dtype=torch.float64) * 60.0
        >>> output = embedding(input)
    """

    def __init__(
        self,
        manifold: Tuple[int, int, int],
        exterior: _Exterior,
        padding: Tuple[int, int] = (0, 0),
        reflection: bool = True,
    ) -> None:
        super().__init__()
        self.manifold = manifold
        # Normalize the exterior into a plain tuple of floats, whether an
        # Exterior instance or a bare tuple was passed in.
        self.exterior = cast(ExteriorTuple, tuple(map(float, exterior)))
        self.padding = padding
        self.reflection = reflection

        # Embeddings are stored as a dense (W, H, N) float64 parameter,
        # initialized to zeros.
        self.weight = nn.Parameter(torch.empty(manifold, dtype=torch.float64))
        nn.init.zeros_(self.weight)

    def extra_repr(self) -> str:
        return "{manifold}, exterior={exterior}, padding={padding}, reflection={reflection}".format(
            **self.__dict__
        )

    def forward(self, input: Tensor) -> Tensor:
        # Delegates the lookup (and optional neighborhood padding) to the
        # functional implementation.
        return F.embedding2d(
            input,
            self.weight,
            exterior=self.exterior,
            padding=self.padding,
            reflection=self.reflection,
        )
@@ -0,0 +1,48 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ import pytest
17
+ import torch
18
+ from torch import nn
19
+ from torch.optim import SGD
20
+
21
+ from torch_geopooling.nn.embedding import Embedding2d
22
+
23
+
24
def test_embedding2d_optimize() -> None:
    """SGD on a 2x2x1 embedding should drive the L1 loss close to zero."""
    embedding = Embedding2d(
        (2, 2, 1),
        padding=(0, 0),
        exterior=(-180.0, -90.0, 360.0, 180.0),
    )

    x_true = torch.tensor(
        [[90.0, 45.0], [90.0, -45.0], [-90.0, -45.0], [-90.0, 45.0]], dtype=torch.float64
    )
    y_true = torch.tensor([[10.0], [20.0], [30.0], [40.0]], dtype=torch.float64)

    optim = SGD(embedding.parameters(), lr=0.1)
    loss_fn = nn.L1Loss()

    loss = None
    for _ in range(10000):
        optim.zero_grad()
        # With zero padding the output is (*, 1, 1, 1); select the center cell.
        loss = loss_fn(embedding(x_true)[:, 0, 0, :], y_true)
        loss.backward()
        optim.step()

    assert loss.detach().item() == pytest.approx(0.0, abs=1e-1)
@@ -0,0 +1,396 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from typing import Optional, Union
17
+
18
+ import torch
19
+ from shapely.geometry import Polygon
20
+ from torch import Tensor, nn
21
+
22
+ from torch_geopooling import functional as F
23
+ from torch_geopooling.tiling import Exterior, ExteriorTuple, regular_tiling
24
+
25
# Public interface of this module, sorted alphabetically.
__all__ = [
    "AdaptiveAvgQuadPool2d",
    "AdaptiveQuadPool2d",
    "AdaptiveMaxQuadPool2d",
    "AvgQuadPool2d",
    "MaxQuadPool2d",
    "QuadPool2d",
]


# Accept both the structured Exterior type and its plain-tuple form.
_Exterior = Union[Exterior, ExteriorTuple]
36
+
37
+
38
# Shared docstring fragments interpolated into the pooling modules' __doc__ below.
# Fixed typo: "outsize" -> "outside".
_exterior_doc = """
    Note:
        Input coordinates must be within a specified exterior geometry (including boundaries).
        For input coordinates outside of the specified exterior, module throws an exception.
    """


_terminal_group_doc = """
    Note:
        A **terminal group** refers to a collection of terminal nodes within the quadtree that
        share the same parent tile.
    """
50
+
51
+
52
class _AdaptiveQuadPool(nn.Module):
    # Base class shared by the adaptive pooling modules; subclasses add forward().
    __doc__ = f"""
    Args:
        feature_dim: Size of each feature vector.
        exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
        max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
            maximum is reached, internal nodes are no longer sub-divided and tree stops growing.
        max_depth: Maximum depth of the quadtree. Default: 17.
        capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
            depth of the tree grows. Default: 1.
        precision: Optional rounding of the input coordinates. Default: 7.

    Shape:
        - Input: :math:`(*, 2)`, where 2 comprises longitude and latitude coordinates.
        - Output: :math:`(*, H)`, where * is the input shape and :math:`H = \\text{{feature_dim}}`.

    {_exterior_doc}
    {_terminal_group_doc}
    """

    def __init__(
        self,
        feature_dim: int,
        exterior: _Exterior,
        max_terminal_nodes: Optional[int] = None,
        max_depth: int = 17,
        capacity: int = 1,
        precision: Optional[int] = 7,
    ) -> None:
        super().__init__()
        self.feature_dim = feature_dim
        # Normalize the exterior to a plain tuple of floats.
        self.exterior = tuple(map(float, exterior))
        self.max_terminal_nodes = max_terminal_nodes
        self.max_depth = max_depth
        self.capacity = capacity
        self.precision = precision

        self.initialize_parameters()

    def initialize_parameters(self) -> None:
        # The weight for adaptive operation should be sparse, since training
        # dynamically changes the shape of the underlying quadtree.
        side = 1 << self.max_depth
        shape = (self.max_depth + 1, side, side, self.feature_dim)
        self.weight = nn.Parameter(torch.sparse_coo_tensor(size=shape, dtype=torch.float64))

    @property
    def tiles(self) -> torch.Tensor:
        """Return tiles of the quadtree."""
        indices = self.weight.coalesce().detach().indices()
        # Drop the trailing feature index, keeping one (z, x, y) row per tile.
        return indices.t()[:, :-1]

    def extra_repr(self) -> str:
        template = (
            "{feature_dim}, "
            "exterior={exterior}, capacity={capacity}, max_depth={max_depth}, "
            "precision={precision}"
        )
        return template.format(**self.__dict__)
113
+
114
+
115
class AdaptiveQuadPool2d(_AdaptiveQuadPool):
    __doc__ = f"""Adaptive lookup index over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup quadtree to organize closely situated 2D points.
    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves the corresponding terminal node and returns its
    associated weight.

    {_AdaptiveQuadPool.__doc__}

    Examples:

        >>> # Feature vectors of size 4 over a 2d space.
        >>> pool = nn.AdaptiveQuadPool2d(4, (-10, -5, 20, 10))
        >>> # Grow tree up to 4-th level and sub-divides a node after 8 coordinates in a quad.
        >>> pool = nn.AdaptiveQuadPool2d(4, (-10, -5, 20, 10), max_depth=4, capacity=8)
        >>> # Create 2D coordinates and query associated weights.
        >>> input = torch.rand((1024, 2), dtype=torch.float64) * 10 - 5
        >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Collect the tree-growth options shared by every adaptive pooling call.
        options = dict(
            training=self.training,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            capacity=self.capacity,
            precision=self.precision,
        )
        result = F.adaptive_quad_pool2d(self.weight, input, self.exterior, **options)
        # While training, the functional call may have grown the quadtree;
        # adopt the (possibly re-shaped) sparse weight it returns.
        if self.training:
            self.weight.data = result.weight
        return result.values
150
+
151
+
152
class AdaptiveMaxQuadPool2d(_AdaptiveQuadPool):
    __doc__ = f"""Adaptive maximum pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup quadtree to organize closely situated 2D points.
    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates the
    maximum value for each ``feature_dim``.

    {_AdaptiveQuadPool.__doc__}

    Examples:

        >>> pool = nn.AdaptiveMaxQuadPool2d(3, (-10, -5, 20, 10), max_depth=5)
        >>> # Create 2D coordinates and feature vector associated with them.
        >>> input = torch.rand((2048, 2), dtype=torch.float64) * 10 - 5
        >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Collect the tree-growth options shared by every adaptive pooling call.
        options = dict(
            training=self.training,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            capacity=self.capacity,
            precision=self.precision,
        )
        result = F.adaptive_max_quad_pool2d(self.weight, input, self.exterior, **options)
        # While training, the functional call may have grown the quadtree;
        # adopt the (possibly re-shaped) sparse weight it returns.
        if self.training:
            self.weight.data = result.weight
        return result.values
184
+
185
+
186
class AdaptiveAvgQuadPool2d(_AdaptiveQuadPool):
    __doc__ = f"""Adaptive average pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup quadtree to organize closely situated 2D points.
    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates an
    average value for each ``feature_dim``.

    {_AdaptiveQuadPool.__doc__}

    Examples:

        >>> # Create pool with 7 features.
        >>> pool = nn.AdaptiveAvgQuadPool2d(7, (0, 0, 1, 1), max_depth=12)
        >>> input = torch.rand((2048, 2), dtype=torch.float64)
        >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Collect the tree-growth options shared by every adaptive pooling call.
        options = dict(
            training=self.training,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            capacity=self.capacity,
            precision=self.precision,
        )
        result = F.adaptive_avg_quad_pool2d(self.weight, input, self.exterior, **options)
        # While training, the functional call may have grown the quadtree;
        # adopt the (possibly re-shaped) sparse weight it returns.
        if self.training:
            self.weight.data = result.weight
        return result.values
218
+
219
+
220
class _QuadPool(nn.Module):
    # Base class shared by the static (polygon-bounded) pooling modules.
    __doc__ = f"""
    Args:
        feature_dim: Size of each feature vector.
        polygon: Polygon that resembles boundary for the terminal nodes of a quadtree.
        exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
        max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
            maximum is reached, internal nodes are no longer sub-divided and tree stops growing.
        max_depth: Maximum depth of the quadtree. Default: 17.
        precision: Optional rounding of the input coordinates. Default: 7.

    Shape:
        - Input: :math:`(*, 2)`, where 2 comprises longitude and latitude coordinates.
        - Output: :math:`(*, H)`, where * is the input shape and :math:`H = \\text{{feature_dim}}`.

    {_exterior_doc}
    {_terminal_group_doc}

    Note:
        All terminal nodes that have an intersection with the specified polygon boundary are
        included into the quadtree.
    """

    def __init__(
        self,
        feature_dim: int,
        polygon: Polygon,
        exterior: _Exterior,
        max_terminal_nodes: Optional[int] = None,
        max_depth: int = 17,
        precision: Optional[int] = 7,
    ) -> None:
        super().__init__()
        self.feature_dim = feature_dim
        self.polygon = polygon
        # Normalize the exterior to a plain tuple of floats.
        self.exterior = tuple(map(float, exterior))
        self.max_terminal_nodes = max_terminal_nodes
        self.max_depth = max_depth
        self.precision = precision

        # Cover the polygon with a regular tiling at depth `max_depth`, and keep
        # every tile from the terminal nodes all the way up to the root node.
        tiles_iter = regular_tiling(
            polygon, Exterior.from_tuple(exterior), z=max_depth, internal=True
        )
        tiles = torch.tensor(list(tiles_iter), dtype=torch.int64)

        self.register_buffer("tiles", tiles)
        self.tiles: Tensor

        self.initialize_parameters()
        self.reset_parameters()

    def initialize_parameters(self) -> None:
        # One weight vector per tile of the (static) quadtree.
        shape = (self.tiles.size(0), self.feature_dim)
        self.weight = nn.Parameter(torch.empty(shape, dtype=torch.float64))

    def reset_parameters(self) -> None:
        # Uniform initialization over [0, 1).
        nn.init.uniform_(self.weight)

    def extra_repr(self) -> str:
        template = (
            "{feature_dim}, exterior={exterior}, max_depth={max_depth}, "
            "precision={precision}"
        )
        return template.format(**self.__dict__)
285
+
286
+
287
class QuadPool2d(_QuadPool):
    __doc__ = f"""Lookup index over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup tree to organize closely situated 2D points using
    a specified polygon and exterior, where polygon is treated as a *boundary* of terminal
    nodes of a quadtree.

    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves the corresponding terminal node and returns its
    associated weight.

    {_QuadPool.__doc__}

    Examples:

        >>> from shapely.geometry import Polygon
        >>> # Create a pool for squared exterior 100x100 and use only a portion of that
        >>> # exterior isolated by a square 10x10.
        >>> poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
        >>> pool = nn.QuadPool2d(5, poly, exterior=(0, 0, 100, 100))
        >>> input = torch.rand((2048, 2), dtype=torch.float64)
        >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # The quadtree shape is fixed at construction time, therefore the
        # lookup never has to learn (grow) the tree: training stays False.
        options = dict(
            training=False,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            precision=self.precision,
        )
        result = F.quad_pool2d(self.tiles, self.weight, input, self.exterior, **options)
        return result.values
325
+
326
+
327
class MaxQuadPool2d(_QuadPool):
    __doc__ = f"""Maximum pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup tree to organize closely situated 2D points using
    a specified polygon and exterior, where polygon is treated as a *boundary* of terminal nodes
    of a quadtree.

    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates the
    maximum value for each ``feature_dim``.

    {_QuadPool.__doc__}

    Examples:

        >>> from shapely.geometry import Polygon
        >>> poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
        >>> pool = nn.MaxQuadPool2d(3, poly, exterior=(0, 0, 100, 100))
        >>> input = torch.rand((2048, 2), dtype=torch.float64)
        >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Static quadtree: no tree growth during training, hence training=False.
        options = dict(
            training=False,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            precision=self.precision,
        )
        result = F.max_quad_pool2d(self.tiles, self.weight, input, self.exterior, **options)
        return result.values
361
+
362
+
363
class AvgQuadPool2d(_QuadPool):
    __doc__ = f"""Average pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup tree to organize closely situated 2D points using
    a specified polygon and exterior, where polygon is treated as a *boundary* of terminal
    nodes of a quadtree.

    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates an
    average value for each ``feature_dim``.

    {_QuadPool.__doc__}

    Examples:

        >>> from shapely.geometry import Polygon
        >>> poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
        >>> pool = nn.AvgQuadPool2d(4, poly, exterior=(0, 0, 100, 100))
        >>> input = torch.rand((2048, 2), dtype=torch.float64)
        >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        # Static quadtree: no tree growth during training, hence training=False.
        options = dict(
            training=False,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            precision=self.precision,
        )
        result = F.avg_quad_pool2d(self.tiles, self.weight, input, self.exterior, **options)
        return result.values
@@ -0,0 +1,158 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from typing import Type
17
+
18
+ import pytest
19
+ import torch
20
+ from shapely.geometry import Polygon
21
+ from torch import nn
22
+ from torch.optim import SGD
23
+ from torch.nn import L1Loss
24
+
25
+ from torch_geopooling.nn.pooling import (
26
+ AdaptiveAvgQuadPool2d,
27
+ AdaptiveMaxQuadPool2d,
28
+ AdaptiveQuadPool2d,
29
+ AvgQuadPool2d,
30
+ MaxQuadPool2d,
31
+ QuadPool2d,
32
+ )
33
+
34
+
35
@pytest.mark.parametrize(
    "module_class",
    [
        AdaptiveQuadPool2d,
        AdaptiveMaxQuadPool2d,
        AdaptiveAvgQuadPool2d,
    ],
    ids=["id", "max", "avg"],
)
def test_adaptive_quad_pool2d_gradient(module_class: Type[nn.Module]) -> None:
    """Backward through each adaptive pool must populate the weight gradient."""
    pool = module_class(5, (-180, -90, 360, 180))

    coordinates = torch.rand((100, 2), dtype=torch.float64) * 90
    output = pool(coordinates)

    # No backward pass has run yet, so the weight holds no gradient.
    assert pool.weight.grad is None

    loss = L1Loss()(output, torch.ones_like(output))
    loss.backward()

    assert pool.weight.grad is not None
    assert pool.weight.grad.sum().item() == pytest.approx(-1)
58
+
59
+
60
def test_adaptive_quad_pool2d_optimize() -> None:
    """SGD on a depth-1 adaptive pool should recover the per-quad target values."""
    pool = AdaptiveQuadPool2d(1, (-180, -90, 360, 180), max_depth=1)

    # Input coordinates are simply centers of the level-1 quads.
    x_true = torch.tensor(
        [[90.0, 45.0], [90.0, -45.0], [-90.0, -45.0], [-90.0, 45.0]], dtype=torch.float64
    )
    y_true = torch.tensor([[10.0], [20.0], [30.0], [40.0]], dtype=torch.float64)
    y_tile = [[1, 1, 1], [1, 1, 0], [1, 0, 0], [1, 0, 1]]

    optim = SGD(pool.parameters(), lr=0.01)
    loss_fn = nn.L1Loss()

    loss = None
    for _ in range(20000):
        optim.zero_grad()
        loss = loss_fn(pool(x_true), y_true)
        loss.backward()
        optim.step()

    # Ensure that model converged with a small loss.
    assert loss.detach().item() == pytest.approx(0.0, abs=1e-1)

    # The learned weights must match the optimization target (y_true) tile by tile.
    weight = pool.weight.to_dense()
    for i, tile in enumerate(y_tile):
        z, x, y = tile
        expect_weight = y_true[i].item()
        actual_weight = weight[z, x, y].detach().item()
        assert actual_weight == pytest.approx(expect_weight, abs=1e-1), f"tile {tile} is wrong"
95
+
96
+
97
@pytest.mark.parametrize(
    "module_class",
    [
        QuadPool2d,
        MaxQuadPool2d,
        AvgQuadPool2d,
    ],
    ids=["id", "max", "avg"],
)
def test_quad_pool2d_gradient(module_class: Type[nn.Module]) -> None:
    """Backward through each static pool must populate the weight gradient."""
    poly = Polygon([(0.0, 0.0), (1.0, 0.0), (1.0, 1.1), (0.0, 1.0)])
    exterior = (0.0, 0.0, 1.0, 1.0)

    pool = module_class(4, poly, exterior, max_depth=5)
    # One weight vector of size 4 per quadtree tile.
    assert pool.weight.size() == torch.Size([pool.tiles.size(0), 4])

    output = pool(torch.rand((100, 2), dtype=torch.float64))

    # No backward pass has run yet, so the weight holds no gradient.
    assert pool.weight.grad is None

    loss = L1Loss()(output, torch.ones_like(output))
    loss.backward()

    assert pool.weight.grad is not None
    assert pool.weight.grad.sum().item() == pytest.approx(-1)
124
+
125
+
126
def test_quad_pool2d_optimize() -> None:
    """SGD on a depth-1 static pool should recover the per-quad target values."""
    poly = Polygon([(-180, -90), (-180, 90), (180, 90), (180, -90)])
    pool = QuadPool2d(1, poly, (-180, -90, 360, 180), max_depth=1)

    # Input coordinates are simply centers of the level-1 quads.
    x_true = torch.tensor(
        [[90.0, 45.0], [90.0, -45.0], [-90.0, -45.0], [-90.0, 45.0]], dtype=torch.float64
    )
    y_true = torch.tensor([[10.0], [20.0], [30.0], [40.0]], dtype=torch.float64)
    y_tile = [(1, 1, 1), (1, 1, 0), (1, 0, 0), (1, 0, 1)]

    optim = SGD(pool.parameters(), lr=0.01)
    loss_fn = nn.L1Loss()

    loss = None
    for _ in range(20000):
        optim.zero_grad()
        loss = loss_fn(pool(x_true), y_true)
        loss.backward()
        optim.step()

    # Ensure that model converged with a small loss.
    assert loss.detach().item() == pytest.approx(0.0, abs=1e-1)

    # Map each learned tile to its weight, then compare against the targets.
    actual_tiles = {
        tuple(pool.tiles[i].detach().tolist()): pool.weight[i, 0].detach().item()
        for i in range(pool.tiles.size(0))
    }

    for tile, expect_weight in zip(y_tile, y_true[:, 0].tolist()):
        actual_weight = actual_tiles[tile]
        assert actual_weight == pytest.approx(expect_weight, abs=1e-1), f"tile {tile} is wrong"
File without changes