torch-geopooling 1.2.0__cp312-cp312-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of torch-geopooling might be problematic. Click here for more details.

@@ -0,0 +1,28 @@
1
+ #include <torch_geopooling/torch_geopooling.h>
2
+
3
+ #include <pybind11/pybind11.h>
4
+ #include <torch/extension.h>
5
+
6
+ #include "python_tuples.h"
7
+
8
+
9
+ namespace torch_geopooling {
10
+
11
+
12
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
13
+ {
14
+ m.def("avg_quad_pool2d", &avg_quad_pool2d);
15
+ m.def("avg_quad_pool2d_backward", &avg_quad_pool2d_backward);
16
+
17
+ m.def("max_quad_pool2d", &max_quad_pool2d);
18
+ m.def("max_quad_pool2d_backward", &max_quad_pool2d_backward);
19
+
20
+ m.def("quad_pool2d", &quad_pool2d);
21
+ m.def("quad_pool2d_backward", &quad_pool2d_backward);
22
+
23
+ m.def("embedding2d", &embedding2d);
24
+ m.def("embedding2d_backward", &embedding2d_backward);
25
+ }
26
+
27
+
28
+ } // namespace torch_geopooling
@@ -0,0 +1,3 @@
1
+ from typing import Final
2
+
3
# Version string of the torch_geopooling package. Declared ``Final`` so that
# type checkers reject any reassignment.
__version__: Final[str] = "1.2.0"
@@ -0,0 +1,2 @@
1
+ from torch_geopooling.functional.embedding import * # noqa
2
+ from torch_geopooling.functional.pooling import * # noqa
@@ -0,0 +1,92 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from typing import Optional, Tuple
17
+
18
+ from torch import Tensor, autograd
19
+ from torch.autograd.function import FunctionCtx
20
+
21
+ import torch_geopooling._C as _C
22
+ from torch_geopooling.tiling import ExteriorTuple
23
+
24
+
25
+ __all__ = ["embedding2d"]
26
+
27
+
28
class Function(autograd.Function):
    """Autograd bridge to the native 2D embedding lookup.

    The forward pass delegates to ``_C.embedding2d`` and the backward pass to
    ``_C.embedding2d_backward``. Only ``weight`` receives a gradient.
    """

    @staticmethod
    def forward(
        input: Tensor,
        weight: Tensor,
        padding: Tuple[int, int],
        exterior: ExteriorTuple,
        reflection: bool,
    ) -> Tensor:
        """Run the native lookup of ``input`` coordinates in ``weight``."""
        return _C.embedding2d(input, weight, padding, exterior, reflection)

    @staticmethod
    def setup_context(ctx: FunctionCtx, inputs: Tuple, outputs: Tuple) -> None:
        """Record on ``ctx`` everything the backward pass reads."""
        coords, table, pad, bounds, reflect = inputs

        # Tensors are registered through save_for_backward; the remaining
        # plain Python values are attached as attributes.
        ctx.save_for_backward(coords, table)
        ctx.padding = pad
        ctx.exterior = bounds
        ctx.reflection = reflect

    @staticmethod
    def backward(ctx: FunctionCtx, grad: Tensor) -> Tuple[Optional[Tensor], ...]:
        """Compute the gradient of the lookup with respect to ``weight``."""
        coords, table = ctx.saved_tensors
        grad_weight = _C.embedding2d_backward(
            grad,
            coords,
            table,
            ctx.padding,
            ctx.exterior,
            ctx.reflection,
        )
        # One slot per forward argument: (input, weight, padding, exterior, reflection).
        return (None, grad_weight, None, None, None)
55
+
56
+
57
def embedding2d(
    input: Tensor,
    weight: Tensor,
    *,
    padding: Tuple[int, int] = (0, 0),
    exterior: ExteriorTuple,
    reflection: bool = True,
) -> Tensor:
    """
    Look up spatial embeddings in a fixed-size table by 2D coordinates.

    Each (x, y) coordinate of the input selects a cell of the embedding matrix; with a
    non-zero padding the neighboring cells around that cell are returned as well.
    See :class:`torch_geopooling.nn.Embedding2d` for more details.

    Args:
        input: A list of 2D coordinates where each coordinate is represented as a tuple (x, y),
            where x is the longitude and y is the latitude.
        weight: A 3D tensor representing the embedding matrix. The first dimension corresponds
            to the maximum possible bucket for the x coordinate, the second dimension
            corresponds to the maximum possible bucket for the y coordinate, and the third
            dimension corresponds to the embedding size.
        padding: The size of the neighborhood to query along x and y. Default is (0, 0),
            meaning only the embedding for the exact input coordinate is retrieved.
        exterior: The geometric boundary of the learning space, specified as a tuple
            (X, Y, W, H), where X and Y represent the origin, and W and H represent the width
            and height of the space, respectively.
        reflection: When true, kernel is wrapped around the exterior space, otherwise kernel is
            squeezed into borders.

    Returns:
        Tensor: The retrieved spatial embeddings corresponding to the input coordinates.
    """
    # Positional order must match Function.forward.
    embeddings = Function.apply(input, weight, padding, exterior, reflection)
    return embeddings
@@ -0,0 +1,32 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ import torch
17
+ from torch_geopooling.functional.embedding import embedding2d
18
+
19
+
20
def test_embedding2d() -> None:
    """A padded lookup returns a (2 * px + 1) x (2 * py + 1) window per coordinate."""
    num_points, embedding_dim = 100, 3

    coords = torch.rand((num_points, 2), dtype=torch.float64) * 10.0
    table = torch.rand((1024, 1024, embedding_dim), dtype=torch.float64)

    out = embedding2d(
        coords,
        table,
        padding=(3, 2),
        exterior=(-10.0, -10.0, 20.0, 20.0),
        reflection=True,
    )

    expected = torch.Size([num_points, 7, 5, embedding_dim])
    assert out.size() == expected
@@ -0,0 +1,354 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ from textwrap import dedent, indent
17
+ from functools import partial
18
+ from inspect import signature
19
+ from typing import Callable, NamedTuple, Optional, Tuple
20
+
21
+ import torch
22
+ from torch import Tensor, autograd
23
+ from torch.autograd.function import FunctionCtx
24
+
25
+ import torch_geopooling._C as _C
26
+ from torch_geopooling import return_types
27
+ from torch_geopooling.tiling import ExteriorTuple
28
+
29
+ __all__ = [
30
+ "adaptive_avg_quad_pool2d",
31
+ "adaptive_max_quad_pool2d",
32
+ "adaptive_quad_pool2d",
33
+ "avg_quad_pool2d",
34
+ "max_quad_pool2d",
35
+ "quad_pool2d",
36
+ ]
37
+
38
+
39
def __def__(fn: Callable, doc: str) -> Callable:
    """Wrap ``fn`` in a partial that carries a composed docstring and ``fn``'s metadata.

    Args:
        fn: Callable to expose; its docstring supplies the trailing part of the result's
            documentation.
        doc: Leading text of the resulting docstring.

    Returns:
        A ``functools.partial`` of ``fn`` (no arguments bound) whose ``__doc__`` is ``doc``
        followed by the dedented and re-indented docstring of ``fn``, and whose module,
        annotations, signature and default values are copied from ``fn``.
    """
    wrapper = partial(fn)

    # NOTE(review): the re-indentation prefix is a single space as it appears in this
    # listing; verify it was not collapsed from a wider indent.
    inherited_doc = indent(dedent(fn.__doc__ or ""), " ")
    wrapper.__doc__ = doc + inherited_doc

    # Copy the introspection metadata so the wrapper presents itself like ``fn``.
    wrapper.__module__ = fn.__module__
    wrapper.__annotations__ = fn.__annotations__
    wrapper.__signature__ = signature(fn)  # type: ignore
    wrapper.__defaults__ = fn.__defaults__  # type: ignore
    wrapper.__kwdefaults__ = fn.__kwdefaults__  # type: ignore
    return wrapper
48
+
49
+
50
class FunctionParams(NamedTuple):
    """Optional quadtree settings bundled for the autograd functions.

    The tuple is unpacked positionally into the native forward and backward calls, so
    the field order is significant. Every field defaults to ``None``.
    """

    # Optional maximum number of terminal nodes in the quadtree.
    max_terminal_nodes: Optional[int] = None
    # Maximum depth of the quadtree.
    max_depth: Optional[int] = None
    # Maximum number of inputs in a node before the node is subdivided.
    capacity: Optional[int] = None
    # Optional rounding of the input coordinates.
    precision: Optional[int] = None
55
+
56
+
57
# Call signature of a native forward operation. It takes the current quadtree
# (tiles and weight), the input coordinates and the quadtree settings, and
# returns the triple (tiles, weight, values).
ForwardType = Callable[
    [
        Tensor,  # tiles
        Tensor,  # weight
        Tensor,  # input
        ExteriorTuple,  # exterior
        bool,  # training
        Optional[int],  # max_terminal_nodes
        Optional[int],  # max_depth
        Optional[int],  # capacity
        Optional[int],  # precision
    ],
    Tuple[Tensor, Tensor, Tensor],
]


# Call signature of a native backward operation. Compared with the forward
# signature it takes the output gradient first, has no `training` flag, and
# returns a single tensor: the gradient of the weight.
BackwardType = Callable[
    [
        Tensor,  # grad_output
        Tensor,  # tiles
        Tensor,  # weight
        Tensor,  # input
        ExteriorTuple,  # exterior
        Optional[int],  # max_terminal_nodes
        Optional[int],  # max_depth
        Optional[int],  # capacity
        Optional[int],  # precision
    ],
    Tensor,
]
87
+
88
+
89
class Function(autograd.Function):
    """Base autograd function shared by the quadtree pooling operations.

    Subclasses bind ``forward_impl`` and ``backward_impl`` to a pair of native operations.
    Context handling, gradient routing and the keyword-friendly :meth:`func` entry point
    are implemented here.
    """

    forward_impl: ForwardType
    backward_impl: BackwardType

    @classmethod
    def forward(
        cls,
        tiles: Tensor,
        weight: Tensor,
        input: Tensor,
        exterior: ExteriorTuple,
        training: bool,
        params: FunctionParams,
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Call the bound native forward operation with ``params`` unpacked positionally."""
        return cls.forward_impl(tiles, weight, input, exterior, training, *params)

    @staticmethod
    def setup_context(ctx: FunctionCtx, inputs: Tuple, outputs: Tuple) -> None:
        """Record on ``ctx`` the tensors and settings used by the backward pass."""
        _, _, input, exterior, _, params = inputs
        # The tiles and weight kept for backward are the ones returned by the forward
        # operation, not the ones passed in.
        tiles, weight, _ = outputs

        # NOTE(review): views of the tensors are saved rather than the tensors themselves;
        # the reason is not visible in this file, so keep the `view_as` calls.
        saved = tuple(t.view_as(t) for t in (tiles, weight, input))
        ctx.save_for_backward(*saved)
        ctx.exterior = exterior
        ctx.params = params

    @classmethod
    def backward(
        cls, ctx: FunctionCtx, grad_tiles: Tensor, grad_weight: Tensor, grad_values: Tensor
    ) -> Tuple[Optional[Tensor], ...]:
        """Compute the weight gradient from the gradient of the pooled values.

        ``grad_tiles`` and ``grad_weight`` are accepted because the forward pass has three
        outputs, but only ``grad_values`` is used.
        """
        tiles, weight, input = ctx.saved_tensors
        grad_weight_out = cls.backward_impl(
            grad_values, tiles, weight, input, ctx.exterior, *ctx.params
        )  # type: ignore
        # Drop gradient for tiles, this should not be changed by an optimizer.
        # One slot per forward argument: (tiles, weight, input, exterior, training, params).
        return (None, grad_weight_out, None, None, None, None)

    @classmethod
    def func(
        cls,
        tiles: Tensor,
        weight: Tensor,
        input: Tensor,
        exterior: Tuple[float, ...],
        *,
        training: bool = True,
        max_terminal_nodes: Optional[int] = None,
        max_depth: Optional[int] = None,
        capacity: Optional[int] = None,
        precision: Optional[int] = None,
    ) -> return_types.quad_pool2d:
        """
        Args:
            tiles: Tiles tensor representing tiles of a quadtree (both, internal and terminal).
            weight: Weights tensor associated with each tile of a quadtree.
            input: Input 2D coordinates as pairs of x (longitude) and y (latitude).
            exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
            training: True, when executed during training, and False otherwise.
            max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
                maximum is reached, internal nodes are no longer sub-divided and tree stops
                growing.
            max_depth: Maximum depth of the quadtree. Default: 17.
            capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
                depth of the tree grows. Default: 1.
            precision: Optional rounding of the input coordinates. Default: 7.
        """
        # Bundle the keyword settings so they travel through `apply` as one argument.
        settings = FunctionParams(
            max_terminal_nodes=max_terminal_nodes,
            max_depth=max_depth,
            capacity=capacity,
            precision=precision,
        )

        outputs = cls.apply(tiles, weight, input, exterior, training, settings)
        return return_types.quad_pool2d(*outputs)
162
+
163
+
164
class QuadPool2d(Function):
    """Pooling function bound to the native ``quad_pool2d`` forward/backward pair."""

    forward_impl = _C.quad_pool2d
    backward_impl = _C.quad_pool2d_backward
167
+
168
+
169
class MaxQuadPool2d(Function):
    """Pooling function bound to the native ``max_quad_pool2d`` forward/backward pair."""

    forward_impl = _C.max_quad_pool2d
    backward_impl = _C.max_quad_pool2d_backward
172
+
173
+
174
class AvgQuadPool2d(Function):
    """Pooling function bound to the native ``avg_quad_pool2d`` forward/backward pair."""

    forward_impl = _C.avg_quad_pool2d
    backward_impl = _C.avg_quad_pool2d_backward
177
+
178
+
179
# Public functional entry points. Each one is the `func` classmethod of the matching
# autograd function, wrapped so that its docstring starts with the summary given here.
#
# NOTE(review): the continuation lines of these summaries carry no leading indentation
# in this listing; the literals below reproduce that exactly. Confirm it matches the
# intended rendering together with the prefix applied by `__def__`.
quad_pool2d = __def__(
    QuadPool2d.func,
    "Lookup index over quadtree decomposition of input 2D coordinates.\n"
    "\n"
    "See :class:`torch_geopooling.nn.QuadPool2d` for more details.\n",
)
max_quad_pool2d = __def__(
    MaxQuadPool2d.func,
    "Maximum pooling over quadtree decomposition of input 2D coordinates.\n"
    "\n"
    "See :class:`torch_geopooling.nn.MaxQuadPool2d` for more details.\n",
)
avg_quad_pool2d = __def__(
    AvgQuadPool2d.func,
    "Average pooling over quadtree decomposition of input 2D coordinates.\n"
    "\n"
    "See :class:`torch_geopooling.nn.AvgQuadPool2d` for more details.\n",
)
200
+
201
+
202
class AdaptiveFunction(autograd.Function):
    """Base autograd function for the adaptive quadtree pooling operations.

    The quadtree is carried by a single sparse COO ``weight`` tensor instead of separate
    tiles and weight tensors. Around every native call the sparse tensor is split into
    tiles and dense features (:meth:`sparse_ravel`) and the result is packed back into a
    sparse tensor (:meth:`sparse_unravel`). Subclasses bind ``forward_impl`` and
    ``backward_impl`` to a pair of native operations.
    """

    forward_impl: ForwardType
    backward_impl: BackwardType

    @staticmethod
    def sparse_ravel(weight: Tensor) -> Tuple[Tensor, Tensor]:
        """Transform weight as coordinate sparse tensor into a tuple of tiles and feature tensor.

        The method transforms sparse encoding of quadtree (where all dimensions but the last
        are coordinates of a tile and the last dimension is an index of a feature in the
        feature vector), into tuple of coordinates (tiles) and dense weight tensor.

        Effectively: (17,131072,131072,5) -> (nnz,3), (nnz,5); where nnz is the number of
        tiles stored in the sparse tensor.

        The slicing below relies on every stored tile having all of its features stored, as
        produced by :meth:`sparse_unravel`.
        """
        feature_dim = weight.size(-1)
        # Coalescing makes `indices()` available and orders the entries by index.
        weight = weight.coalesce()

        # Transform sparse tensor into a tuple of (tiles, weight) that are directly usable
        # by the C++ extension functions. Every `feature_dim`-th index row starts a new
        # tile; the last index column (the feature position) is dropped.
        tiles = weight.indices().t()[::feature_dim, :-1]
        w = weight.values().reshape((-1, feature_dim))
        return tiles, w

    @staticmethod
    def sparse_unravel(tiles: Tensor, weight: Tensor, size: torch.Size) -> Tensor:
        """Perform inverse operation of `sparse_ravel`.

        Method packs tiles (coordinates) and weight (values) into a coordinate sparse tensor
        of the given ``size``. The result is not coalesced.
        """
        feature_dim = weight.size(-1)
        # Create the feature positions on the device of `tiles`; stacking tensors that live
        # on different devices fails.
        feature_indices = torch.arange(0, feature_dim, device=tiles.device).repeat(
            tiles.size(0)
        )

        # Repeat every tile once per feature and append the feature position as the last
        # index column.
        indices = tiles.repeat_interleave(feature_dim, dim=0)
        indices = torch.column_stack((indices, feature_indices))

        return torch.sparse_coo_tensor(indices.t(), weight.ravel(), size=size)

    @classmethod
    def forward(
        cls,
        weight: Tensor,
        input: Tensor,
        exterior: ExteriorTuple,
        training: bool,
        params: FunctionParams,
    ) -> Tuple[Tensor, Tensor]:
        """Run the bound native forward operation on a sparse ``weight``.

        Returns the coalesced sparse weight produced by the operation, with the same size
        as the input weight, together with the pooled values.
        """
        tiles, w = cls.sparse_ravel(weight)

        tiles_out, w_out, values_out = cls.forward_impl(
            tiles, w, input, exterior, training, *params
        )

        weight_out = cls.sparse_unravel(tiles_out, w_out, size=weight.size())
        return weight_out.coalesce(), values_out

    @staticmethod
    def setup_context(ctx: FunctionCtx, inputs: Tuple, outputs: Tuple) -> None:
        """Record on ``ctx`` the tensors and settings used by the backward pass."""
        _, input, exterior, _, params = inputs
        # The weight kept for backward is the one returned by the forward operation.
        weight, _ = outputs

        ctx.save_for_backward(weight, input)
        ctx.exterior = exterior
        ctx.params = params

    @classmethod
    def backward(
        cls, ctx: FunctionCtx, grad_weight: Tensor, grad_values: Tensor
    ) -> Tuple[Optional[Tensor], ...]:
        """Compute the sparse weight gradient from the gradient of the pooled values.

        ``grad_weight`` is accepted because the forward pass has two outputs, but only
        ``grad_values`` is used.
        """
        weight, input = ctx.saved_tensors
        tiles, w = cls.sparse_ravel(weight)

        grad_weight_dense = cls.backward_impl(
            grad_values, tiles, w, input, ctx.exterior, *ctx.params
        )  # type: ignore
        grad_weight_sparse = cls.sparse_unravel(tiles, grad_weight_dense, size=weight.size())

        # One slot per forward argument: (weight, input, exterior, training, params).
        return grad_weight_sparse.coalesce(), None, None, None, None

    @classmethod
    def func(
        cls,
        weight: Tensor,
        input: Tensor,
        exterior: Tuple[float, ...],
        *,
        training: bool = True,
        max_terminal_nodes: Optional[int] = None,
        max_depth: Optional[int] = None,
        capacity: Optional[int] = None,
        precision: Optional[int] = None,
    ) -> return_types.adaptive_quad_pool2d:
        """
        Args:
            weight: Weights tensor associated with each tile of a quadtree.
            input: Input 2D coordinates as pairs of x (longitude) and y (latitude).
            exterior: Geometrical boundary of the learning space in (X, Y, W, H) format.
            training: True, when executed during training, and False otherwise.
            max_terminal_nodes: Optional maximum number of terminal nodes in a quadtree. Once a
                maximum is reached, internal nodes are no longer sub-divided and tree stops
                growing.
            max_depth: Maximum depth of the quadtree. Default: 17.
            capacity: Maximum number of inputs, after which a quadtree's node is subdivided and
                depth of the tree grows. Default: 1.
            precision: Optional rounding of the input coordinates. Default: 7.
        """
        # Bundle the keyword settings so they travel through `apply` as one argument.
        params = FunctionParams(
            max_terminal_nodes=max_terminal_nodes,
            max_depth=max_depth,
            capacity=capacity,
            precision=precision,
        )

        result = cls.apply(weight, input, exterior, training, params)
        return return_types.adaptive_quad_pool2d(*result)
317
+
318
+
319
class AdaptiveQuadPool2d(AdaptiveFunction):
    """Adaptive function bound to the native ``quad_pool2d`` forward/backward pair."""

    forward_impl = _C.quad_pool2d
    backward_impl = _C.quad_pool2d_backward
322
+
323
+
324
class AdaptiveMaxQuadPool2d(AdaptiveFunction):
    """Adaptive function bound to the native ``max_quad_pool2d`` forward/backward pair."""

    forward_impl = _C.max_quad_pool2d
    backward_impl = _C.max_quad_pool2d_backward
327
+
328
+
329
class AdaptiveAvgQuadPool2d(AdaptiveFunction):
    """Adaptive function bound to the native ``avg_quad_pool2d`` forward/backward pair."""

    forward_impl = _C.avg_quad_pool2d
    backward_impl = _C.avg_quad_pool2d_backward
332
+
333
+
334
# Public adaptive entry points. Each one is the `func` classmethod of the matching
# adaptive autograd function, wrapped so that its docstring starts with the summary
# given here.
#
# NOTE(review): the continuation lines of these summaries carry no leading indentation
# in this listing; the literals below reproduce that exactly. Confirm it matches the
# intended rendering together with the prefix applied by `__def__`.
adaptive_quad_pool2d = __def__(
    AdaptiveQuadPool2d.func,
    "Adaptive lookup index over quadtree decomposition of input 2D coordinates.\n"
    "\n"
    "See :class:`torch_geopooling.nn.AdaptiveQuadPool2d` for more details.\n",
)
adaptive_max_quad_pool2d = __def__(
    AdaptiveMaxQuadPool2d.func,
    "Adaptive maximum pooling over quadtree decomposition of input 2D coordinates.\n"
    "\n"
    "See :class:`torch_geopooling.nn.AdaptiveMaxQuadPool2d` for more details.\n",
)
adaptive_avg_quad_pool2d = __def__(
    AdaptiveAvgQuadPool2d.func,
    "Adaptive average pooling over quadtree decomposition of input 2D coordinates.\n"
    "\n"
    "See :class:`torch_geopooling.nn.AdaptiveAvgQuadPool2d` for more details.\n",
)
@@ -0,0 +1,100 @@
1
+ # Copyright (C) 2024, Yakau Bubnou
2
+ #
3
+ # This program is free software: you can redistribute it and/or modify
4
+ # it under the terms of the GNU General Public License as published by
5
+ # the Free Software Foundation, either version 3 of the License, or
6
+ # (at your option) any later version.
7
+ #
8
+ # This program is distributed in the hope that it will be useful,
9
+ # but WITHOUT ANY WARRANTY; without even the implied warranty of
10
+ # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11
+ # GNU General Public License for more details.
12
+ #
13
+ # You should have received a copy of the GNU General Public License
14
+ # along with this program. If not, see <https://www.gnu.org/licenses/>.
15
+
16
+ import pytest
17
+ import torch
18
+
19
+ from torch_geopooling.functional.pooling import (
20
+ AdaptiveFunction,
21
+ adaptive_quad_pool2d,
22
+ adaptive_avg_quad_pool2d,
23
+ adaptive_max_quad_pool2d,
24
+ avg_quad_pool2d,
25
+ max_quad_pool2d,
26
+ quad_pool2d,
27
+ )
28
+
29
+
30
def test_adaptive_function_ravel() -> None:
    """Packing tiles and weight into a sparse tensor and unpacking it is lossless."""
    shape = (2, 2, 2, 1)
    expected_tiles = torch.tensor(
        [[0, 0, 0], [1, 0, 0], [1, 0, 1], [1, 1, 0], [1, 1, 1]],
        dtype=torch.int64,
    )
    expected_weight = torch.tensor([[1.0], [2.0], [3.0], [4.0], [5.0]], dtype=torch.float64)

    # Packing must produce the same tensor as a dense round trip of itself.
    packed = AdaptiveFunction.sparse_unravel(expected_tiles, expected_weight, size=shape)
    torch.testing.assert_close(packed.to_dense().to_sparse_coo(), packed)

    # Unpacking must return the original tiles and weight.
    actual_tiles, actual_weight = AdaptiveFunction.sparse_ravel(packed)
    torch.testing.assert_close(actual_tiles, expected_tiles)
    torch.testing.assert_close(actual_weight, expected_weight)
42
+
43
+
44
@pytest.mark.parametrize(
    "function",
    [
        pytest.param(quad_pool2d, id="id"),
        pytest.param(max_quad_pool2d, id="max"),
        pytest.param(avg_quad_pool2d, id="avg"),
    ],
)
def test_quad_pool2d(function) -> None:
    """Starting from an empty quadtree, a training call grows tiles and pools values."""
    # An empty quadtree: no tiles and no weight rows, 5 features per tile.
    tiles = torch.empty((0, 3), dtype=torch.int64)
    input = torch.rand((100, 2), dtype=torch.float64) * 10.0
    weight = torch.randn([0, 5], dtype=torch.float64)

    exterior = (0.0, 0.0, 10.0, 10.0)
    result = function(
        tiles,
        weight,
        input,
        exterior,
        training=True,
        max_depth=16,
        capacity=1,
        precision=6,
    )

    num_tiles, tile_width = result.tiles.size(0), result.tiles.size(1)
    assert num_tiles > 0
    assert tile_width == 3

    # One weight row per tile, one value row per input coordinate.
    assert result.weight.size(0) == num_tiles
    assert result.values.size() == torch.Size([input.size(0), weight.size(1)])
73
+
74
+
75
@pytest.mark.parametrize(
    "function",
    [
        pytest.param(adaptive_quad_pool2d, id="id"),
        pytest.param(adaptive_max_quad_pool2d, id="max"),
        pytest.param(adaptive_avg_quad_pool2d, id="avg"),
    ],
)
def test_adaptive_quad_pool2d(function) -> None:
    """Starting from an empty sparse weight, a training call fills it and pools values."""
    input = torch.rand((100, 2), dtype=torch.float64) * 10.0
    # A sparse weight with no stored entries and 4 features per tile.
    weight = torch.sparse_coo_tensor(size=(10, 1 << 10, 1 << 10, 4), dtype=torch.float64)

    exterior = (0.0, 0.0, 10.0, 10.0)
    result = function(
        weight,
        input,
        exterior,
        training=True,
        max_depth=16,
        capacity=1,
        precision=6,
    )

    assert result.weight.layout == torch.sparse_coo
    assert result.weight.indices().size(0) > 0
    # One value row per input coordinate, one column per feature.
    assert result.values.size() == torch.Size([input.size(0), weight.size(-1)])
@@ -0,0 +1,2 @@
1
+ from torch_geopooling.nn.embedding import * # noqa
2
+ from torch_geopooling.nn.pooling import * # noqa