torch-geopooling 1.0.0rc3__cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torch-geopooling might be problematic. Click here for more details.

torch_geopooling/nn.py ADDED
@@ -0,0 +1,391 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from typing import Optional, Union
17
+
18
+ import torch
19
+ from shapely.geometry import Polygon
20
+ from torch import Tensor, nn
21
+
22
+ from torch_geopooling import functional as F
23
+ from torch_geopooling.tiling import Exterior, ExteriorTuple, regular_tiling
24
+
25
+ __all__ = [
26
+ "AdaptiveAvgQuadPool2d",
27
+ "AdaptiveQuadPool2d",
28
+ "AdaptiveMaxQuadPool2d",
29
+ "AvgQuadPool2d",
30
+ "MaxQuadPool2d",
31
+ "QuadPool2d",
32
+ ]
33
+
34
+
35
+ _Exterior = Union[Exterior, ExteriorTuple]
36
+
37
+
38
+ _exterior_doc = """
39
+ Note:
40
+ Input coordinates must be within a specified exterior geometry (including boundaries).
41
+ For input coordinates outsize of the specified exterior, module throws an exception.
42
+ """
43
+
44
+
45
+ _terminal_group_doc = """
46
+ Note:
47
+ A **terminal group** refers to a collection of terminal nodes within the quadtree that
48
+ share the same parent tile.
49
+ """
50
+
51
+
52
class _AdaptiveQuadPool(nn.Module):
    __doc__ = f"""
    Args:
        feature_dim: Size of each feature vector.
        exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
        max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
            maximum is reached, internal nodes are no longer sub-divided and tree stops growing.
        max_depth: Maximum depth of the quadtree. Default: 17.
        capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
            depth of the tree grows. Default: 1.
        precision: Optional rounding of the input coordinates. Default: 7.

    Shape:
        - Input: :math:`(*, 2)`, where 2 comprises longitude and latitude coordinates.
        - Output: :math:`(*, H)`, where * is the input shape and :math:`H = \\text{{feature_dim}}`.

    {_exterior_doc}
    {_terminal_group_doc}
    """

    def __init__(
        self,
        feature_dim: int,
        exterior: _Exterior,
        max_terminal_nodes: Optional[int] = None,
        max_depth: int = 17,
        capacity: int = 1,
        precision: Optional[int] = 7,
    ) -> None:
        """Store the quadtree configuration and allocate the sparse weight."""
        super().__init__()

        # Quadtree growth limits.
        self.max_terminal_nodes = max_terminal_nodes
        self.max_depth = max_depth
        self.capacity = capacity

        # Geometry of the learning space, normalized to a tuple of floats.
        self.exterior = tuple(float(value) for value in exterior)
        self.precision = precision
        self.feature_dim = feature_dim

        self.initialize_parameters()

    def initialize_parameters(self) -> None:
        """Create an empty sparse float64 weight indexed by (level, x, y, feature)."""
        # A sparse weight is used because training changes the underlying
        # quadtree dynamically, so the set of populated tiles is not known
        # ahead of time.
        # NOTE(review): the first (level) axis has ``max_depth`` entries; if a
        # tile at level ``max_depth`` itself can be stored, this would need
        # ``max_depth + 1`` -- confirm against the native operator.
        side = 1 << self.max_depth
        shape = (self.max_depth, side, side, self.feature_dim)
        empty_weight = torch.sparse_coo_tensor(size=shape, dtype=torch.float64)
        self.weight = nn.Parameter(empty_weight)

    @property
    def tiles(self) -> torch.Tensor:
        """Return tiles of the quadtree."""
        # One row per stored weight entry; the trailing index column is
        # dropped (presumably the feature index, leaving tile coordinates --
        # verify against the layout produced by the native operator).
        indices = self.weight.coalesce().detach().indices()
        return indices.t()[:, :-1]

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's ``repr``."""
        template = (
            "{feature_dim}, "
            "exterior={exterior}, capacity={capacity}, max_depth={max_depth}, "
            "precision={precision}"
        )
        return template.format_map(vars(self))
108
+
109
+
110
class AdaptiveQuadPool2d(_AdaptiveQuadPool):
    __doc__ = f"""Adaptive lookup index over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup quadtree to organize closely situated 2D points.
    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves the corresponding terminal node and returns its
    associated weight.

    {_AdaptiveQuadPool.__doc__}

    Examples:

    >>> # Feature vectors of size 4 over a 2d space.
    >>> pool = nn.AdaptiveQuadPool2d(4, (-10, -5, 20, 10))
    >>> # Grow tree up to 4-th level and sub-divides a node after 8 coordinates in a quad.
    >>> pool = nn.AdaptiveQuadPool2d(4, (-10, -5, 20, 10), max_depth=4, capacity=8)
    >>> # Create 2D coordinates and query associated weights.
    >>> input = torch.rand((1024, 2), dtype=torch.float64) * 10 - 5
    >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Look up the weights for ``input`` and return the pooled values.

        In training mode the weight returned by the functional call replaces
        the module's weight data.
        """
        options = dict(
            training=self.training,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            capacity=self.capacity,
            precision=self.precision,
        )
        pooled = F.adaptive_quad_pool2d(self.weight, input, self.exterior, **options)

        # Keep the weight returned by the operator only while training.
        if self.training:
            self.weight.data = pooled.weight
        return pooled.values
145
+
146
+
147
class AdaptiveMaxQuadPool2d(_AdaptiveQuadPool):
    __doc__ = f"""Adaptive maximum pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup quadtree to organize closely situated 2D points.
    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates the
    maximum value for each ``feature_dim``.

    {_AdaptiveQuadPool.__doc__}

    Examples:

    >>> pool = nn.AdaptiveMaxQuadPool2d(3, (-10, -5, 20, 10), max_depth=5)
    >>> # Create 2D coordinates and feature vector associated with them.
    >>> input = torch.rand((2048, 2), dtype=torch.float64) * 10 - 5
    >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Apply adaptive max pooling to ``input`` and return the pooled values.

        In training mode the weight returned by the functional call replaces
        the module's weight data.
        """
        options = dict(
            training=self.training,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            capacity=self.capacity,
            precision=self.precision,
        )
        pooled = F.adaptive_max_quad_pool2d(self.weight, input, self.exterior, **options)

        # Keep the weight returned by the operator only while training.
        if self.training:
            self.weight.data = pooled.weight
        return pooled.values
179
+
180
+
181
class AdaptiveAvgQuadPool2d(_AdaptiveQuadPool):
    __doc__ = f"""Adaptive average pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup quadtree to organize closely situated 2D points.
    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates an
    average value for each ``feature_dim``.

    {_AdaptiveQuadPool.__doc__}

    Examples:

    >>> # Create pool with 7 features.
    >>> pool = nn.AdaptiveAvgQuadPool2d(7, (0, 0, 1, 1), max_depth=12)
    >>> input = torch.rand((2048, 2), dtype=torch.float64)
    >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Apply adaptive average pooling to ``input`` and return the pooled values.

        In training mode the weight returned by the functional call replaces
        the module's weight data.
        """
        options = dict(
            training=self.training,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            capacity=self.capacity,
            precision=self.precision,
        )
        pooled = F.adaptive_avg_quad_pool2d(self.weight, input, self.exterior, **options)

        # Keep the weight returned by the operator only while training.
        if self.training:
            self.weight.data = pooled.weight
        return pooled.values
213
+
214
+
215
class _QuadPool(nn.Module):
    __doc__ = f"""
    Args:
        feature_dim: Size of each feature vector.
        polygon: Polygon that resembles boundary for the terminal nodes of a quadtree.
        exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
        max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
            maximum is reached, internal nodes are no longer sub-divided and tree stops growing.
        max_depth: Maximum depth of the quadtree. Default: 17.
        precision: Optional rounding of the input coordinates. Default: 7.

    Shape:
        - Input: :math:`(*, 2)`, where 2 comprises longitude and latitude coordinates.
        - Output: :math:`(*, H)`, where * is the input shape and :math:`H = \\text{{feature_dim}}`.

    {_exterior_doc}
    {_terminal_group_doc}

    Note:
        All terminal nodes that have an intersection with the specified polygon boundary are
        included into the quadtree.
    """

    def __init__(
        self,
        feature_dim: int,
        polygon: Polygon,
        exterior: _Exterior,
        max_terminal_nodes: Optional[int] = None,
        max_depth: int = 17,
        precision: Optional[int] = 7,
    ) -> None:
        """Build the fixed quadtree for ``polygon`` and allocate one weight row per tile."""
        super().__init__()
        self.feature_dim = feature_dim
        self.polygon = polygon
        self.exterior = tuple(map(float, exterior))
        self.max_terminal_nodes = max_terminal_nodes
        self.max_depth = max_depth
        self.precision = precision

        # Generate regular tiling for the provided polygon and build from those
        # tiles a quadtree from terminal nodes all way up to the root node.
        # The normalized ``self.exterior`` is used (not the raw argument) so a
        # one-shot iterable passed as ``exterior`` is not consumed twice.
        tiles_iter = regular_tiling(
            polygon, Exterior.from_tuple(self.exterior), z=max_depth, internal=True
        )
        tile_list = list(tiles_iter)

        # Force an (N, 3) layout so the buffer stays two-dimensional even when
        # the polygon does not intersect the exterior and no tile is produced.
        tiles = torch.tensor(tile_list, dtype=torch.int64).reshape(len(tile_list), 3)

        self.register_buffer("tiles", tiles)
        self.tiles: Tensor

        self.initialize_parameters()
        self.reset_parameters()

    def initialize_parameters(self) -> None:
        """Allocate an uninitialized float64 weight with one row per tile."""
        weight_size = [self.tiles.size(0), self.feature_dim]
        self.weight = nn.Parameter(torch.empty(weight_size, dtype=torch.float64))

    def reset_parameters(self) -> None:
        """Fill the weight in place using ``nn.init.uniform_`` with its default bounds."""
        nn.init.uniform_(self.weight)

    def extra_repr(self) -> str:
        """Return the configuration summary shown in the module's ``repr``."""
        return (
            "{feature_dim}, exterior={exterior}, max_depth={max_depth}, "
            "precision={precision}".format(**self.__dict__)
        )
280
+
281
+
282
class QuadPool2d(_QuadPool):
    __doc__ = f"""Lookup index over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup tree to organize closely situated 2D points using
    a specified polygon and exterior, where polygon is treated as a *boundary* of terminal
    nodes of a quadtree.

    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves the corresponding terminal node and returns its
    associated weight.

    {_QuadPool.__doc__}

    Examples:

    >>> from shapely.geometry import Polygon
    >>> # Create a pool for squared exterior 100x100 and use only a portion of that
    >>> # exterior isolated by a square 10x10.
    >>> poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
    >>> pool = nn.QuadPool2d(5, poly, exterior=(0, 0, 100, 100))
    >>> input = torch.rand((2048, 2), dtype=torch.float64)
    >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Look up the weights for ``input`` in the fixed quadtree and return the values."""
        options = dict(
            # ``training=False`` is intentional: the shape of the quadtree is
            # already known from the polygon, so there is nothing to learn
            # about its structure.
            training=False,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            precision=self.precision,
        )
        pooled = F.quad_pool2d(self.tiles, self.weight, input, self.exterior, **options)
        return pooled.values
320
+
321
+
322
class MaxQuadPool2d(_QuadPool):
    __doc__ = f"""Maximum pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup tree to organize closely situated 2D points using
    a specified polygon and exterior, where polygon is treated as a *boundary* of terminal nodes
    of a quadtree.

    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates the
    maximum value for each ``feature_dim``.

    {_QuadPool.__doc__}

    Examples:

    >>> from shapely.geometry import Polygon
    >>> poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
    >>> pool = nn.MaxQuadPool2d(3, poly, exterior=(0, 0, 100, 100))
    >>> input = torch.rand((2048, 2), dtype=torch.float64)
    >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Apply max pooling over the fixed quadtree and return the pooled values."""
        options = dict(
            # The quadtree structure is fixed at construction time, so the
            # functional call always runs with ``training=False``.
            training=False,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            precision=self.precision,
        )
        pooled = F.max_quad_pool2d(self.tiles, self.weight, input, self.exterior, **options)
        return pooled.values
356
+
357
+
358
class AvgQuadPool2d(_QuadPool):
    __doc__ = f"""Average pooling over quadtree decomposition of input 2D coordinates.

    This module constructs an internal lookup tree to organize closely situated 2D points using
    a specified polygon and exterior, where polygon is treated as a *boundary* of terminal
    nodes of a quadtree.

    Each terminal node in the resulting quadtree is paired with a weight. Thus, when providing
    an input coordinate, the module retrieves a **terminal group** of nodes and calculates an
    average value for each ``feature_dim``.

    {_QuadPool.__doc__}

    Examples:

    >>> from shapely.geometry import Polygon
    >>> poly = Polygon([(0, 0), (10, 0), (10, 10), (0, 10)])
    >>> pool = nn.AvgQuadPool2d(4, poly, exterior=(0, 0, 100, 100))
    >>> input = torch.rand((2048, 2), dtype=torch.float64)
    >>> output = pool(input)
    """

    def forward(self, input: Tensor) -> Tensor:
        """Apply average pooling over the fixed quadtree and return the pooled values."""
        options = dict(
            # The quadtree structure is fixed at construction time, so the
            # functional call always runs with ``training=False``.
            training=False,
            max_terminal_nodes=self.max_terminal_nodes,
            max_depth=self.max_depth,
            precision=self.precision,
        )
        pooled = F.avg_quad_pool2d(self.tiles, self.weight, input, self.exterior, **options)
        return pooled.values
@@ -0,0 +1,85 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from typing import Type
17
+
18
+ import pytest
19
+ import torch
20
+ from shapely.geometry import Polygon
21
+ from torch import nn
22
+ from torch.nn import L1Loss
23
+
24
+ from torch_geopooling.nn import (
25
+ AdaptiveAvgQuadPool2d,
26
+ AdaptiveMaxQuadPool2d,
27
+ AdaptiveQuadPool2d,
28
+ AvgQuadPool2d,
29
+ MaxQuadPool2d,
30
+ QuadPool2d,
31
+ )
32
+
33
+
34
@pytest.mark.parametrize(
    "module_class",
    [
        AdaptiveQuadPool2d,
        AdaptiveMaxQuadPool2d,
        AdaptiveAvgQuadPool2d,
    ],
    ids=["id", "max", "avg"],
)
def test_adaptive_quad_pool2d_gradient(module_class: Type[nn.Module]) -> None:
    """Check that a backward pass populates the adaptive module's weight gradient."""
    world_exterior = (-180, -90, 360, 180)
    pool = module_class(5, world_exterior)

    # Random coordinates in [0, 90) on both axes.
    coords = torch.rand((100, 2), dtype=torch.float64) * 90
    prediction = pool(coords)

    # No gradient exists before the backward pass.
    assert pool.weight.grad is None

    target = torch.ones_like(prediction)
    L1Loss()(prediction, target).backward()

    grad = pool.weight.grad
    assert grad is not None
    assert grad.sum().item() == pytest.approx(-1)
57
+
58
+
59
@pytest.mark.parametrize(
    "module_class",
    [
        QuadPool2d,
        MaxQuadPool2d,
        AvgQuadPool2d,
    ],
    ids=["id", "max", "avg"],
)
def test_quad_pool2d_gradient(module_class: Type[nn.Module]) -> None:
    """Check weight shape and gradient flow for the polygon-based modules."""
    exterior = (0.0, 0.0, 1.0, 1.0)
    poly = Polygon([(0.0, 0.0), (1.0, 0.0), (1.0, 1.1), (0.0, 1.0)])

    pool = module_class(4, poly, exterior, max_depth=5)

    # One weight row per tile, four features per row.
    expected_size = torch.Size([pool.tiles.size(0), 4])
    assert pool.weight.size() == expected_size

    coords = torch.rand((100, 2), dtype=torch.float64)
    prediction = pool(coords)

    # No gradient exists before the backward pass.
    assert pool.weight.grad is None

    target = torch.ones_like(prediction)
    L1Loss()(prediction, target).backward()

    grad = pool.weight.grad
    assert grad is not None
    assert grad.sum().item() == pytest.approx(-1)
@@ -0,0 +1,29 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from typing import NamedTuple
17
+
18
+ from torch import Tensor
19
+
20
+
21
class quad_pool2d(NamedTuple):
    """Named result triple holding ``tiles``, ``weight`` and ``values`` tensors.

    NOTE(review): presumably the result type of the non-adaptive quad pooling
    functions -- confirm against the functional module.
    """

    # Tiles tensor.
    tiles: Tensor
    # Weight tensor.
    weight: Tensor
    # Values tensor.
    values: Tensor
25
+
26
+
27
class adaptive_quad_pool2d(NamedTuple):
    """Named result pair holding ``weight`` and ``values`` tensors.

    NOTE(review): presumably the result type of the adaptive quad pooling
    functions -- confirm against the functional module.
    """

    # Weight tensor.
    weight: Tensor
    # Values tensor.
    values: Tensor
@@ -0,0 +1,97 @@
1
+ from __future__ import annotations
2
+
3
+ from collections import deque
4
+ from itertools import product
5
+ from typing import Iterator, NamedTuple, Tuple
6
+
7
+ from shapely.geometry import Polygon
8
+
9
+ __all__ = ["Exterior", "ExteriorTuple", "Tile", "regular_tiling"]
10
+
11
+
12
class Tile(NamedTuple):
    """Quadtree tile addressed by level ``z`` and grid position ``(x, y)``."""

    z: int
    x: int
    y: int

    @classmethod
    def root(cls) -> Tile:
        """Return the root tile ``(0, 0, 0)``."""
        return cls(z=0, x=0, y=0)

    def child(self, x: int, y: int) -> Tile:
        """Return the child one level down at offset ``(x, y)`` within this tile."""
        next_level = self.z + 1
        child_x = self.x * 2 + x
        child_y = self.y * 2 + y
        return Tile(next_level, child_x, child_y)

    def children(self) -> Iterator[Tile]:
        """Yield the four children, iterating ``y`` fastest."""
        for dx in range(2):
            for dy in range(2):
                yield self.child(dx, dy)
27
+
28
+
29
+ ExteriorTuple = Tuple[float, float, float, float]
30
+
31
+
32
class Exterior(NamedTuple):
    """Axis-aligned bounding box given by its lower corner, width and height."""

    xmin: float
    ymin: float
    width: float
    height: float

    @classmethod
    def from_tuple(cls, exterior_tuple: ExteriorTuple) -> Exterior:
        """Build an exterior from an ``(xmin, ymin, width, height)`` tuple."""
        return cls(*exterior_tuple)

    @property
    def xmax(self) -> float:
        """Upper X bound: ``xmin + width``."""
        return self.xmin + self.width

    @property
    def ymax(self) -> float:
        """Upper Y bound: ``ymin + height``."""
        return self.ymin + self.height

    def slice(self, tile: Tile) -> Exterior:
        """Return the part of this exterior covered by ``tile``."""
        # Each axis is split into ``2 ** z`` equal cells at level ``z``.
        cells_per_axis = 1 << tile.z
        cell_width = self.width / cells_per_axis
        cell_height = self.height / cells_per_axis

        cell_xmin = self.xmin + tile.x * cell_width
        cell_ymin = self.ymin + tile.y * cell_height
        return Exterior(cell_xmin, cell_ymin, cell_width, cell_height)

    def as_polygon(self) -> Polygon:
        """Return the box as a polygon with corners starting at ``(xmin, ymin)``."""
        left, bottom = self.xmin, self.ymin
        right, top = self.xmax, self.ymax
        corners = [(left, bottom), (right, bottom), (right, top), (left, top)]
        return Polygon(corners)
64
+
65
+
66
def regular_tiling(
    polygon: Polygon, exterior: Exterior, z: int, internal: bool = False
) -> Iterator[Tile]:
    """Returns a regular quad-tiling (tiles of the same size).

    The quadtree is walked from the root tile; branches whose tile does not
    intersect the polygon are dropped, and the remaining tiles of level (z)
    are yielded.

    Args:
        polygon: A polygon to cover with tiles.
        exterior: Exterior (bounding box) of the quadtree. For example, for geospatial
            coordinates, this will be `(-180.0, -90.0, 360.0, 180.0)`.
        z: Zoom level of the tiles.
        internal: When `True`, also yields the internal tiles (nodes) of the quadtree,
            starting from the root tile (0,0,0).

    Returns:
        Iterator of tiles.
    """
    # Used as a LIFO stack: tiles are taken from the same end they are added to.
    pending = deque([Tile.root()])

    while pending:
        current = pending.pop()

        footprint = exterior.slice(current).as_polygon()
        if not footprint.intersects(polygon):
            # Nothing below this tile can touch the polygon.
            continue

        if current.z < z:
            # Internal node: emitted only on request, always expanded.
            if internal:
                yield current
            pending.extend(current.children())
        else:
            # Target level reached: emit and stop descending.
            yield current
@@ -0,0 +1,103 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Iterator, Tuple
19
+
20
+ from torch import Tensor
21
+
22
+ from torch_geopooling.tiling import Tile
23
+
24
+ __all__ = ["TileWKT"]
25
+
26
+
27
class _TileSet(set):
    """Set of ``Tile`` values built from the rows of a tiles tensor."""

    def __init__(self, tiles: Tensor) -> None:
        """Convert every row of ``tiles`` into a ``Tile`` and store it."""
        rows = tiles.detach().tolist()
        super().__init__(Tile(*row) for row in rows)

    def is_terminal(self, tile: Tile) -> bool:
        """Return ``True`` when none of the four children of ``tile`` is in the set."""
        has_child = any(
            tile.child(dx, dy) in self for dx in (0, 1) for dy in (0, 1)
        )
        return not has_child
38
+
39
+
40
class TileWKT:
    """Convert a Tile to a WKT polygon given the exterior of the whole geometry.

    Module returns a tile geometry in WKT format, which comprises a polygon.

    Args:
        exterior: Exterior coordinates in (X, Y, W, H) format. The exterior is used to calculate
            boundaries of a tile to produce a final WKT.
        precision: A precision of the resulting geometry, digits after the decimal point.
        internal: When `True`, output includes internal nodes of the quadtree tiles.
            Otherwise (default) returns only geometry of terminal nodes.

    Raises:
        ValueError: If the exterior width or height is not positive.
    """

    def __init__(
        self,
        exterior: Tuple[float, float, float, float],
        precision: int = 7,
        internal: bool = False,
    ) -> None:
        self.exterior = tuple(map(float, exterior))
        self.precision = precision
        self.internal = internal

        # Unpack the normalized tuple rather than the raw argument, so the
        # private fields are always floats and a one-shot iterable passed as
        # ``exterior`` is not consumed twice.
        self._xmin, self._ymin, self._width, self._height = self.exterior
        if self._width <= 0:
            raise ValueError(f"exterior width should be >0, got {self._width}")
        if self._height <= 0:
            raise ValueError(f"exterior height should be >0, got {self._height}")

    def __call__(self, tiles: Tensor) -> Iterator[str]:
        """Return an iterator of WKT polygons for the given ``(z, x, y)`` tiles.

        The tensor shape is validated immediately, at call time; the WKT
        strings themselves are produced lazily.

        Raises:
            ValueError: If ``tiles`` is not a 2D tensor with three columns.
        """
        if len(tiles.size()) != 2:
            raise ValueError(f"tiles tensor must be a 2D tensor, got {tiles.size()} shape")

        if tiles.size(1) != 3:
            raise ValueError(
                f"tiles should be triplets of (z, x, y), got tensor of shape {tiles.size()}"
            )

        # Validation lives outside the generator so that it is not deferred
        # until the first item is requested.
        return self._iter_wkt(tiles)

    def _iter_wkt(self, tiles: Tensor) -> Iterator[str]:
        """Yield one WKT polygon per selected tile of an already validated tensor."""
        tileset = _TileSet(tiles)

        for z, x, y in tiles.detach().tolist():
            # Internal nodes are skipped unless they were explicitly requested.
            if (not self.internal) and (not tileset.is_terminal(Tile(z, x, y))):
                continue

            width = self._width / (1 << z)
            height = self._height / (1 << z)

            # Upper bounds are derived from the unrounded lower bounds; the
            # lower bounds are rounded afterwards.
            xmin = self._xmin + width * x
            xmax = round(xmin + width, self.precision)
            xmin = round(xmin, self.precision)

            ymin = self._ymin + height * y
            ymax = round(ymin + height, self.precision)
            ymin = round(ymin, self.precision)

            yield (
                "POLYGON (("
                f"{xmin} {ymin}, {xmax} {ymin}, {xmax} {ymax}, {xmin} {ymax}, {xmin} {ymin}"
                "))"
            )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(exterior={self.exterior})"