torch-geopooling 1.3.0__cp310-cp310-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
+ // SPDX-License-Identifier: GPL-3.0-or-later
+ // SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ // SPDX-FileType: SOURCE
+
+ #include <torch_geopooling/torch_geopooling.h>
+
+ #include <pybind11/pybind11.h>
+ #include <torch/extension.h>
+
+ #include "python_tuples.h"
+
+
+ namespace torch_geopooling {
+
+
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
+ {
+     m.def("avg_quad_pool2d", &avg_quad_pool2d);
+     m.def("avg_quad_pool2d_backward", &avg_quad_pool2d_backward);
+
+     m.def("max_quad_pool2d", &max_quad_pool2d);
+     m.def("max_quad_pool2d_backward", &max_quad_pool2d_backward);
+
+     m.def("quad_pool2d", &quad_pool2d);
+     m.def("quad_pool2d_backward", &quad_pool2d_backward);
+
+     m.def("embedding2d", &embedding2d);
+     m.def("embedding2d_backward", &embedding2d_backward);
+ }
+
+
+ } // namespace torch_geopooling
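
The module above registers the C++ kernels with pybind11 under torch's TORCH_EXTENSION_NAME, and the Python layer in the following hunks imports them as torch_geopooling._C. A minimal sketch of inspecting the raw bindings from Python, assuming the wheel is installed (the printed output is illustrative, not taken from the package):

import torch  # import torch first so the extension can resolve libtorch symbols
import torch_geopooling._C as _C

# Every name registered above with m.def(...) is exposed as a plain callable.
print(_C.quad_pool2d)
print(_C.embedding2d_backward)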
@@ -0,0 +1,7 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ from typing import Final
+
+ __version__: Final[str] = "1.3.0"
@@ -0,0 +1,6 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ from torch_geopooling.functional.embedding import * # noqa
+ from torch_geopooling.functional.pooling import * # noqa
@@ -0,0 +1,81 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ from typing import Optional, Tuple
+
+ from torch import Tensor, autograd
+ from torch.autograd.function import FunctionCtx
+
+ import torch_geopooling._C as _C
+ from torch_geopooling.tiling import ExteriorTuple
+
+
+ __all__ = ["embedding2d"]
+
+
+ class Function(autograd.Function):
+     @staticmethod
+     def forward(
+         input: Tensor,
+         weight: Tensor,
+         padding: Tuple[int, int],
+         exterior: ExteriorTuple,
+         reflection: bool,
+     ) -> Tensor:
+         return _C.embedding2d(input, weight, padding, exterior, reflection)
+
+     @staticmethod
+     def setup_context(ctx: FunctionCtx, inputs: Tuple, outputs: Tuple) -> None:
+         input, weight, padding, exterior, reflection = inputs
+
+         ctx.save_for_backward(input, weight)
+         ctx.padding = padding
+         ctx.exterior = exterior
+         ctx.reflection = reflection
+
+     @staticmethod
+     def backward(ctx: FunctionCtx, grad: Tensor) -> Tuple[Optional[Tensor], ...]:
+         input, weight = ctx.saved_tensors
+         grad_weight = _C.embedding2d_backward(
+             grad, input, weight, ctx.padding, ctx.exterior, ctx.reflection
+         )
+         return None, grad_weight, None, None, None
+
+
+ def embedding2d(
+     input: Tensor,
+     weight: Tensor,
+     *,
+     padding: Tuple[int, int] = (0, 0),
+     exterior: ExteriorTuple,
+     reflection: bool = True,
+ ) -> Tensor:
+     """
+     Retrieves spatial embeddings from a fixed-size lookup table based on 2D coordinates.
+
+     This function accepts a tensor of (x, y) coordinates and retrieves the corresponding
+     spatial embeddings from a provided embedding matrix. The embeddings are selected
+     based on the input coordinates, with an optional padding to include neighboring cells.
+     See :class:`torch_geopooling.nn.Embedding2d` for more details.
+
+     Args:
+         input: A tensor of 2D coordinates, where each coordinate is a pair (x, y) with x the
+             longitude and y the latitude.
+         weight: A 3D tensor representing the embedding matrix. The first dimension corresponds to
+             the maximum possible bucket for the x coordinate, the second dimension corresponds to
+             the maximum possible bucket for the y coordinate, and the third dimension corresponds
+             to the embedding size.
+         padding: The size of the neighborhood to query. Default is (0, 0), meaning only the
+             embedding for the exact input coordinate is retrieved.
+         exterior: The geometric boundary of the learning space, specified as a tuple (X, Y, W, H),
+             where X and Y represent the origin, and W and H represent the width and height of the
+             space, respectively.
+         reflection: When True, the kernel is wrapped around the exterior space; otherwise, the
+             kernel is squeezed into the borders.
+
+     Returns:
+         Tensor: The retrieved spatial embeddings corresponding to the input coordinates.
+     """
+
+     return Function.apply(input, weight, padding, exterior, reflection)
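
The padding argument determines how large a neighborhood is returned for each coordinate: the output window spans 2 * padding + 1 cells along each axis, so the result has shape (N, 2 * padding[0] + 1, 2 * padding[1] + 1, features). A short sketch of that shape arithmetic, mirroring the test in the next hunk (the exterior and tensor sizes are illustrative only):

import torch
from torch_geopooling.functional import embedding2d

weight = torch.rand((1024, 1024, 3), dtype=torch.float64)  # 1024x1024 buckets, 3 features per cell
input = torch.rand((100, 2), dtype=torch.float64) * 10.0   # 100 (x, y) coordinates

output = embedding2d(
    input,
    weight,
    padding=(3, 2),                       # 2*3+1 = 7 cells along x, 2*2+1 = 5 cells along y
    exterior=(-10.0, -10.0, 20.0, 20.0),  # (X, Y, W, H) boundary that contains the inputs
)
assert output.shape == (100, 7, 5, 3)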
@@ -0,0 +1,21 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ import torch
+ from torch_geopooling.functional.embedding import embedding2d
+
+
+ def test_embedding2d() -> None:
+     input = torch.rand((100, 2), dtype=torch.float64) * 10.0
+     weight = torch.rand((1024, 1024, 3), dtype=torch.float64)
+
+     result = embedding2d(
+         input,
+         weight,
+         padding=(3, 2),
+         exterior=(-10.0, -10.0, 20.0, 20.0),
+         reflection=True,
+     )
+
+     assert result.size() == torch.Size([100, 7, 5, 3])
@@ -0,0 +1,343 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ from textwrap import dedent, indent
+ from functools import partial
+ from inspect import signature
+ from typing import Callable, NamedTuple, Optional, Tuple
+
+ import torch
+ from torch import Tensor, autograd
+ from torch.autograd.function import FunctionCtx
+
+ import torch_geopooling._C as _C
+ from torch_geopooling import return_types
+ from torch_geopooling.tiling import ExteriorTuple
+
+ __all__ = [
+     "adaptive_avg_quad_pool2d",
+     "adaptive_max_quad_pool2d",
+     "adaptive_quad_pool2d",
+     "avg_quad_pool2d",
+     "max_quad_pool2d",
+     "quad_pool2d",
+ ]
+
+
+ def __def__(fn: Callable, doc: str) -> Callable:
+     f = partial(fn)
+     f.__doc__ = doc + indent(dedent(fn.__doc__ or ""), " ")
+     f.__module__ = fn.__module__
+     f.__annotations__ = fn.__annotations__
+     f.__signature__ = signature(fn) # type: ignore
+     f.__defaults__ = fn.__defaults__ # type: ignore
+     f.__kwdefaults__ = fn.__kwdefaults__ # type: ignore
+     return f
+
+
+ class FunctionParams(NamedTuple):
+     max_terminal_nodes: Optional[int] = None
+     max_depth: Optional[int] = None
+     capacity: Optional[int] = None
+     precision: Optional[int] = None
+
+
+ ForwardType = Callable[
+     [
+         Tensor, # tiles
+         Tensor, # weight
+         Tensor, # input
+         ExteriorTuple, # exterior
+         bool, # training
+         Optional[int], # max_terminal_nodes
+         Optional[int], # max_depth
+         Optional[int], # capacity
+         Optional[int], # precision
+     ],
+     Tuple[Tensor, Tensor, Tensor], # (tiles, weight, values)
+ ]
+
+
+ BackwardType = Callable[
+     [
+         Tensor, # grad_output
+         Tensor, # tiles
+         Tensor, # weight
+         Tensor, # input
+         ExteriorTuple, # exterior
+         Optional[int], # max_terminal_nodes
+         Optional[int], # max_depth
+         Optional[int], # capacity
+         Optional[int], # precision
+     ],
+     Tensor, # (grad_weight)
+ ]
+
+
+ class Function(autograd.Function):
+     forward_impl: ForwardType
+     backward_impl: BackwardType
+
+     @classmethod
+     def forward(
+         cls,
+         tiles: Tensor,
+         weight: Tensor,
+         input: Tensor,
+         exterior: ExteriorTuple,
+         training: bool,
+         params: FunctionParams,
+     ) -> Tuple[Tensor, Tensor, Tensor]:
+         return cls.forward_impl(tiles, weight, input, exterior, training, *params)
+
+     @staticmethod
+     def setup_context(ctx: FunctionCtx, inputs: Tuple, outputs: Tuple) -> None:
+         _, _, input, exterior, _, params = inputs
+         tiles, weight, _ = outputs
+
+         ctx.save_for_backward(tiles.view_as(tiles), weight.view_as(weight), input.view_as(input))
+         ctx.exterior = exterior
+         ctx.params = params
+
+     @classmethod
+     def backward(
+         cls, ctx: FunctionCtx, grad_tiles: Tensor, grad_weight: Tensor, grad_values: Tensor
+     ) -> Tuple[Optional[Tensor], ...]:
+         grad_weight_out = cls.backward_impl(
+             grad_values, *ctx.saved_tensors, ctx.exterior, *ctx.params
+         ) # type: ignore
+         # Drop the gradient for tiles; they should not be changed by an optimizer.
+         return None, grad_weight_out, None, None, None, None
+
+     @classmethod
+     def func(
+         cls,
+         tiles: Tensor,
+         weight: Tensor,
+         input: Tensor,
+         exterior: Tuple[float, ...],
+         *,
+         training: bool = True,
+         max_terminal_nodes: Optional[int] = None,
+         max_depth: Optional[int] = None,
+         capacity: Optional[int] = None,
+         precision: Optional[int] = None,
+     ) -> return_types.quad_pool2d:
+         """
+         Args:
+             tiles: Tiles tensor representing tiles of a quadtree (both internal and terminal).
+             weight: Weights tensor associated with each tile of a quadtree.
+             input: Input 2D coordinates as pairs of x (longitude) and y (latitude).
+             exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
+             training: True when executed during training, and False otherwise.
+             max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once the
+                 maximum is reached, internal nodes are no longer subdivided and the tree stops
+                 growing.
+             max_depth: Maximum depth of the quadtree. Default: 17.
+             capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
+                 the depth of the tree grows. Default: 1.
+             precision: Optional rounding of the input coordinates. Default: 7.
+         """
+         params = FunctionParams(
+             max_terminal_nodes=max_terminal_nodes,
+             max_depth=max_depth,
+             capacity=capacity,
+             precision=precision,
+         )
+
+         result = cls.apply(tiles, weight, input, exterior, training, params)
+         return return_types.quad_pool2d(*result)
+
+
+ class QuadPool2d(Function):
+     forward_impl = _C.quad_pool2d
+     backward_impl = _C.quad_pool2d_backward
+
+
+ class MaxQuadPool2d(Function):
+     forward_impl = _C.max_quad_pool2d
+     backward_impl = _C.max_quad_pool2d_backward
+
+
+ class AvgQuadPool2d(Function):
+     forward_impl = _C.avg_quad_pool2d
+     backward_impl = _C.avg_quad_pool2d_backward
+
+
+ quad_pool2d = __def__(
+     QuadPool2d.func,
+     """Lookup index over quadtree decomposition of input 2D coordinates.
+
+     See :class:`torch_geopooling.nn.QuadPool2d` for more details.
+     """,
+ )
+ max_quad_pool2d = __def__(
+     MaxQuadPool2d.func,
+     """Maximum pooling over quadtree decomposition of input 2D coordinates.
+
+     See :class:`torch_geopooling.nn.MaxQuadPool2d` for more details.
+     """,
+ )
+ avg_quad_pool2d = __def__(
+     AvgQuadPool2d.func,
+     """Average pooling over quadtree decomposition of input 2D coordinates.
+
+     See :class:`torch_geopooling.nn.AvgQuadPool2d` for more details.
+     """,
+ )
+
+
+ class AdaptiveFunction(autograd.Function):
+     forward_impl: ForwardType
+     backward_impl: BackwardType
+
+     @staticmethod
+     def sparse_ravel(weight: Tensor) -> Tuple[Tensor, Tensor]:
+         """Transform a sparse COO weight tensor into a tuple of tiles and a feature tensor.
+
+         The method transforms the sparse encoding of a quadtree (where the first 3 dimensions
+         are the coordinates of a tile and the 4th dimension is the index of a feature in the
+         feature vector) into a tuple of coordinates (tiles) and a dense weight tensor.
+
+         Effectively: (17, 131072, 131072, 5) -> (nnz, 3), (nnz, 5), where nnz is the number of
+         non-zero elements in the sparse tensor.
+         """
+         feature_dim = weight.size(-1)
+         weight = weight.coalesce()
+
+         # Transform sparse tensor into a tuple of (tiles, weight) that are directly usable
+         # by the C++ extension functions.
+         tiles = weight.indices().t()[::feature_dim, :-1]
+         w = weight.values().reshape((-1, feature_dim))
+         return tiles, w
+
+     @staticmethod
+     def sparse_unravel(tiles: Tensor, weight: Tensor, size: torch.Size) -> Tensor:
+         """Perform the inverse operation of `sparse_ravel`.
+
+         The method packs tiles (coordinates) and weight (values) into a coordinate sparse tensor.
+         """
+         feature_dim = weight.size(-1)
+         feature_indices = torch.arange(0, feature_dim).repeat(tiles.size(0))
+
+         indices = tiles.repeat_interleave(feature_dim, dim=0)
+         indices = torch.column_stack((indices, feature_indices))
+
+         return torch.sparse_coo_tensor(indices.t(), weight.ravel(), size=size)
+
+     @classmethod
+     def forward(
+         cls,
+         weight: Tensor,
+         input: Tensor,
+         exterior: ExteriorTuple,
+         training: bool,
+         params: FunctionParams,
+     ) -> Tuple[Tensor, Tensor]:
+         tiles, w = cls.sparse_ravel(weight)
+
+         tiles_out, w_out, values_out = cls.forward_impl(
+             tiles, w, input, exterior, training, *params
+         )
+
+         weight_out = cls.sparse_unravel(tiles_out, w_out, size=weight.size())
+         return weight_out.coalesce(), values_out
+
+     @staticmethod
+     def setup_context(ctx: FunctionCtx, inputs: Tuple, outputs: Tuple) -> None:
+         _, input, exterior, _, params = inputs
+         weight, _ = outputs
+
+         ctx.save_for_backward(weight, input)
+         ctx.exterior = exterior
+         ctx.params = params
+
+     @classmethod
+     def backward(
+         cls, ctx: FunctionCtx, grad_weight: Tensor, grad_values: Tensor
+     ) -> Tuple[Optional[Tensor], ...]:
+         weight, input = ctx.saved_tensors
+         tiles, w = cls.sparse_ravel(weight)
+
+         grad_weight_dense = cls.backward_impl(
+             grad_values, tiles, w, input, ctx.exterior, *ctx.params
+         ) # type: ignore
+         grad_weight_sparse = cls.sparse_unravel(tiles, grad_weight_dense, size=weight.size())
+
+         return grad_weight_sparse.coalesce(), None, None, None, None
+
+     @classmethod
+     def func(
+         cls,
+         weight: Tensor,
+         input: Tensor,
+         exterior: Tuple[float, ...],
+         *,
+         training: bool = True,
+         max_terminal_nodes: Optional[int] = None,
+         max_depth: Optional[int] = None,
+         capacity: Optional[int] = None,
+         precision: Optional[int] = None,
+     ) -> return_types.adaptive_quad_pool2d:
+         """
+         Args:
+             weight: Weights tensor associated with each tile of a quadtree.
+             input: Input 2D coordinates as pairs of x (longitude) and y (latitude).
+             exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
+             training: True when executed during training, and False otherwise.
+             max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once the
+                 maximum is reached, internal nodes are no longer subdivided and the tree stops
+                 growing.
+             max_depth: Maximum depth of the quadtree. Default: 17.
+             capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
+                 the depth of the tree grows. Default: 1.
+             precision: Optional rounding of the input coordinates. Default: 7.
+         """
+         params = FunctionParams(
+             max_terminal_nodes=max_terminal_nodes,
+             max_depth=max_depth,
+             capacity=capacity,
+             precision=precision,
+         )
+
+         result = cls.apply(weight, input, exterior, training, params)
+         return return_types.adaptive_quad_pool2d(*result)
+
+
+ class AdaptiveQuadPool2d(AdaptiveFunction):
+     forward_impl = _C.quad_pool2d
+     backward_impl = _C.quad_pool2d_backward
+
+
+ class AdaptiveMaxQuadPool2d(AdaptiveFunction):
+     forward_impl = _C.max_quad_pool2d
+     backward_impl = _C.max_quad_pool2d_backward
+
+
+ class AdaptiveAvgQuadPool2d(AdaptiveFunction):
+     forward_impl = _C.avg_quad_pool2d
+     backward_impl = _C.avg_quad_pool2d_backward
+
+
+ adaptive_quad_pool2d = __def__(
+     AdaptiveQuadPool2d.func,
+     """Adaptive lookup index over quadtree decomposition of input 2D coordinates.
+
+     See :class:`torch_geopooling.nn.AdaptiveQuadPool2d` for more details.
+     """,
+ )
+ adaptive_max_quad_pool2d = __def__(
+     AdaptiveMaxQuadPool2d.func,
+     """Adaptive maximum pooling over quadtree decomposition of input 2D coordinates.
+
+     See :class:`torch_geopooling.nn.AdaptiveMaxQuadPool2d` for more details.
+     """,
+ )
+ adaptive_avg_quad_pool2d = __def__(
+     AdaptiveAvgQuadPool2d.func,
+     """Adaptive average pooling over quadtree decomposition of input 2D coordinates.
+
+     See :class:`torch_geopooling.nn.AdaptiveAvgQuadPool2d` for more details.
+     """,
+ )
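
Both families of wrappers return a named tuple from torch_geopooling.return_types: the dense variants (quad_pool2d, max_quad_pool2d, avg_quad_pool2d) return (tiles, weight, values), while the adaptive variants take a sparse COO weight and return (weight, values). A minimal sketch of a dense call that starts from an empty quadtree, mirroring the test in the next hunk (the exterior and tensor sizes are illustrative only):

import torch
from torch_geopooling.functional import quad_pool2d

tiles = torch.empty((0, 3), dtype=torch.int64)     # empty quadtree: it grows in training mode
weight = torch.randn((0, 5), dtype=torch.float64)  # one 5-feature weight row per tile
input = torch.rand((100, 2), dtype=torch.float64) * 10.0

result = quad_pool2d(tiles, weight, input, (0.0, 0.0, 10.0, 10.0), training=True)

# The grown quadtree is returned alongside the pooled values.
print(result.tiles.shape)   # (num_tiles, 3)
print(result.weight.shape)  # (num_tiles, 5)
print(result.values.shape)  # (100, 5)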
@@ -0,0 +1,89 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ import pytest
+ import torch
+
+ from torch_geopooling.functional.pooling import (
+     AdaptiveFunction,
+     adaptive_quad_pool2d,
+     adaptive_avg_quad_pool2d,
+     adaptive_max_quad_pool2d,
+     avg_quad_pool2d,
+     max_quad_pool2d,
+     quad_pool2d,
+ )
+
+
+ def test_adaptive_function_ravel() -> None:
+     size = (2, 2, 2, 1)
+     tiles = torch.tensor([[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]], dtype=torch.int64)
+
+     weight = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=torch.float64)
+
+     sparse = AdaptiveFunction.sparse_unravel(tiles, weight, size=size)
+     torch.testing.assert_close(sparse.to_dense().to_sparse_coo(), sparse)
+
+     tiles_out, weight_out = AdaptiveFunction.sparse_ravel(sparse)
+     torch.testing.assert_close(tiles_out, tiles)
+     torch.testing.assert_close(weight_out, weight)
+
+
+ @pytest.mark.parametrize(
+     "function",
+     [
+         quad_pool2d,
+         max_quad_pool2d,
+         avg_quad_pool2d,
+     ],
+     ids=["id", "max", "avg"],
+ )
+ def test_quad_pool2d(function) -> None:
+     tiles = torch.empty((0, 3), dtype=torch.int64)
+     input = torch.rand((100, 2), dtype=torch.float64) * 10.0
+     weight = torch.randn([0, 5], dtype=torch.float64)
+
+     result = function(
+         tiles,
+         weight,
+         input,
+         (0.0, 0.0, 10.0, 10.0),
+         training=True,
+         max_depth=16,
+         capacity=1,
+         precision=6,
+     )
+     assert result.tiles.size(0) > 0
+     assert result.tiles.size(1) == 3
+
+     assert result.weight.size(0) == result.tiles.size(0)
+     assert result.values.size() == torch.Size([input.size(0), weight.size(1)])
+
+
+ @pytest.mark.parametrize(
+     "function",
+     [
+         adaptive_quad_pool2d,
+         adaptive_max_quad_pool2d,
+         adaptive_avg_quad_pool2d,
+     ],
+     ids=["id", "max", "avg"],
+ )
+ def test_adaptive_quad_pool2d(function) -> None:
+     input = torch.rand((100, 2), dtype=torch.float64) * 10.0
+     weight = torch.sparse_coo_tensor(size=(10, 1 << 10, 1 << 10, 4), dtype=torch.float64)
+
+     result = function(
+         weight,
+         input,
+         (0.0, 0.0, 10.0, 10.0),
+         training=True,
+         max_depth=16,
+         capacity=1,
+         precision=6,
+     )
+
+     assert result.weight.layout == torch.sparse_coo
+     assert result.weight.indices().size(0) > 0
+     assert result.values.size() == torch.Size([input.size(0), weight.size(-1)])
@@ -0,0 +1,6 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ from torch_geopooling.nn.embedding import * # noqa
+ from torch_geopooling.nn.pooling import * # noqa
@@ -0,0 +1,91 @@
+ # SPDX-License-Identifier: GPL-3.0-or-later
+ # SPDX-FileCopyrightText: 2025 Yakau Bubnou
+ # SPDX-FileType: SOURCE
+
+ from typing import Union, Tuple, cast
+
+ import torch
+ from torch import Tensor, nn
+
+ from torch_geopooling import functional as F
+ from torch_geopooling.tiling import Exterior, ExteriorTuple
+
+
+ __all__ = [
+     "Embedding2d",
+ ]
+
+
+ _Exterior = Union[Exterior, ExteriorTuple]
+
+
+ class Embedding2d(nn.Module):
+     """
+     Retrieves spatial embeddings from a fixed-size lookup table based on 2D coordinates.
+
+     This module accepts a tensor of (x, y) coordinates and retrieves the corresponding
+     spatial embeddings from a provided embedding matrix. The embeddings are selected
+     based on the input coordinates, with an optional padding to include neighboring cells.
+
+     Args:
+         manifold: The size of the 2-dimensional embedding in the form (W, H, N), where
+             W is the width, H is the height, and N is the feature dimension of the embedding.
+         padding: The size of the neighborhood to query. Default is (0, 0), meaning only the
+             embedding for the exact input coordinate is retrieved.
+         exterior: The geometric boundary of the learning space, specified as a tuple (X, Y, W, H),
+             where X and Y represent the origin, and W and H represent the width and height of the
+             space, respectively.
+         reflection: When True, the kernel is wrapped around the exterior space; otherwise, the
+             kernel is squeezed into the borders.
+
+     Shape:
+         - Input: :math:`(*, 2)`, where 2 comprises the x and y coordinates.
+         - Output: :math:`(*, X_{out}, Y_{out}, N)`, where * is the input shape, \
+           :math:`N = \\text{manifold[2]}`, and
+
+           :math:`X_{out} = \\text{padding}[0] \\times 2 + 1`
+
+           :math:`Y_{out} = \\text{padding}[1] \\times 2 + 1`
+
+     Examples:
+
+         >>> # Create an embedding of the EPSG:4326 rectangle into a 1024x1024 grid
+         >>> # with 3 features in each cell.
+         >>> embedding = nn.Embedding2d(
+         ...     (1024, 1024, 3),
+         ...     exterior=(-180.0, -90.0, 360.0, 180.0),
+         ...     padding=(2, 2),
+         ... )
+         >>> input = torch.rand((100, 2), dtype=torch.float64) * 60.0
+         >>> output = embedding(input)
+     """
+
+     def __init__(
+         self,
+         manifold: Tuple[int, int, int],
+         exterior: _Exterior,
+         padding: Tuple[int, int] = (0, 0),
+         reflection: bool = True,
+     ) -> None:
+         super().__init__()
+         self.manifold = manifold
+         self.exterior = cast(ExteriorTuple, tuple(map(float, exterior)))
+         self.padding = padding
+         self.reflection = reflection
+
+         self.weight = nn.Parameter(torch.empty(manifold, dtype=torch.float64))
+         nn.init.zeros_(self.weight)
+
+     def extra_repr(self) -> str:
+         return "{manifold}, exterior={exterior}, padding={padding}, reflection={reflection}".format(
+             **self.__dict__
+         )
+
+     def forward(self, input: Tensor) -> Tensor:
+         return F.embedding2d(
+             input,
+             self.weight,
+             exterior=self.exterior,
+             padding=self.padding,
+             reflection=self.reflection,
+         )
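
Because Embedding2d stores the lookup table as an nn.Parameter and its forward pass goes through F.embedding2d (whose autograd function produces a gradient only for the weight), the module can be trained with any standard torch optimizer. A hedged sketch of a single optimization step; the target tensor and the MSE loss are placeholders rather than anything shipped with the package:

import torch
from torch_geopooling.nn import Embedding2d

model = Embedding2d((1024, 1024, 3), exterior=(-180.0, -90.0, 360.0, 180.0), padding=(1, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

coords = torch.rand((64, 2), dtype=torch.float64) * 60.0  # (x, y) points inside the exterior
target = torch.rand((64, 3, 3, 3), dtype=torch.float64)   # placeholder target per 3x3 window

output = model(coords)  # shape: (64, 3, 3, 3), since padding=(1, 1) yields a 3x3 window
loss = torch.nn.functional.mse_loss(output, target)

optimizer.zero_grad()
loss.backward()
optimizer.step()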