rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,99 @@
1
+ instrument: 'MSI' # You may want to update this based on the actual instrument
2
+ processing_level: 'L1C' # You may want to update this based on the actual processing level
3
+
4
+ bands:
5
+ B01:
6
+ name: 'Band 1' # You may want to provide a descriptive name
7
+ gaussian:
8
+ mu: 411.633593
9
+ sigma: 14.652
10
+
11
+ B02:
12
+ name: 'Band 2' # coastal aerosol
13
+ gaussian:
14
+ mu: 442.155
15
+ sigma: 8.434933
16
+
17
+ B03:
18
+ name: 'Band 3'
19
+ gaussian:
20
+ mu: 466.122
21
+ sigma: 18.894
22
+
23
+ B04:
24
+ name: 'Band 4' # ~ blue
25
+ gaussian:
26
+ mu: 487.078
27
+ sigma: 10.633
28
+
29
+ B05:
30
+ name: 'Band 5'
31
+ gaussian:
32
+ mu: 529.783
33
+ sigma: 10.782657
34
+
35
+ B06:
36
+ name: 'Band 6'
37
+ gaussian:
38
+ mu: 546.981
39
+ sigma: 10.331
40
+
41
+ B07:
42
+ name: 'Band 7' # green
43
+ gaussian:
44
+ mu: 554.026
45
+ sigma: 17.766939
46
+
47
+ B08:
48
+ name: 'Band 8'
49
+ gaussian:
50
+ mu: 644.898
51
+ sigma: 34.650989
52
+
53
+ B09:
54
+ name: 'Band 9' # red
55
+ gaussian:
56
+ mu: 665.695
57
+ sigma: 10.117
58
+
59
+ B10:
60
+ name: 'Band 10'
61
+ gaussian:
62
+ mu: 677.068
63
+ sigma: 9.007702
64
+
65
+ B11:
66
+ name: 'Band 11'
67
+ gaussian:
68
+ mu: 746.736
69
+ sigma: 9.952
70
+
71
+ B12:
72
+ name: 'Band 12'
73
+ gaussian:
74
+ mu: 857.323
75
+ sigma: 34.696136
76
+
77
+ B13:
78
+ name: 'Band 13'
79
+ gaussian:
80
+ mu: 866.55
81
+ sigma: 13.421421
82
+
83
+ B14:
84
+ name: 'Band 14'
85
+ gaussian:
86
+ mu: 1241.597
87
+ sigma: 23.356
88
+
89
+ B15:
90
+ name: 'Band 15'
91
+ gaussian:
92
+ mu: 1627.972
93
+ sigma: 27.593
94
+
95
+ B16:
96
+ name: 'Band 16'
97
+ gaussian:
98
+ mu: 2113.124
99
+ sigma: 47.341375
@@ -0,0 +1,34 @@
1
+ instrument: '4-band'
2
+ processing_level: 'SR'
3
+
4
+ # QuickBird-2 and GeoEye-1
5
+ # In the corresponding dataset, either qb2 or ge1 is used. However, since
6
+ # their bands have the same properties, we use the same file for both
7
+ bands:
8
+ B01:
9
+ name: 'blue'
10
+ gaussian:
11
+ mu: 482.417803
12
+ sigma: 45.029733
13
+ GSD: 1.24
14
+
15
+ B02:
16
+ name: 'green'
17
+ gaussian:
18
+ mu: 547.272289
19
+ sigma: 58.295452
20
+ GSD: 1.24
21
+
22
+ B03:
23
+ name: 'red'
24
+ gaussian:
25
+ mu: 665.031763
26
+ sigma: 35.775908
27
+ GSD: 1.24
28
+
29
+ B04:
30
+ name: 'nir'
31
+ gaussian:
32
+ mu: 840.773239
33
+ sigma: 83.700732
34
+ GSD: 1.24
@@ -0,0 +1,85 @@
1
+ bands:
2
+ # Modify the band numbering to be the Helios band strings in uppercase
3
+ VV:
4
+ name: VV
5
+ gaussian:
6
+ mu: -1
7
+ sigma: -1
8
+ orbit: BOTH
9
+
10
+ VH:
11
+ name: VH
12
+ gaussian:
13
+ mu: -2
14
+ sigma: -1
15
+ orbit: BOTH
16
+
17
+ HH:
18
+ name: HH
19
+ gaussian:
20
+ mu: -3
21
+ sigma: -1
22
+ orbit: BOTH
23
+
24
+ HV:
25
+ name: HV
26
+ gaussian:
27
+ mu: -4
28
+ sigma: -1
29
+ orbit: BOTH
30
+
31
+ VV_ASCENDING:
32
+ name: VV
33
+ gaussian:
34
+ mu: -5
35
+ sigma: -1
36
+ orbit: ASCENDING
37
+
38
+ VH_ASCENDING:
39
+ name: VH
40
+ gaussian:
41
+ mu: -6
42
+ sigma: -1
43
+ orbit: ASCENDING
44
+
45
+ HH_ASCENDING:
46
+ name: HH
47
+ gaussian:
48
+ mu: -7
49
+ sigma: -1
50
+ orbit: ASCENDING
51
+
52
+ HV_ASCENDING:
53
+ name: HV
54
+ gaussian:
55
+ mu: -8
56
+ sigma: -1
57
+ orbit: ASCENDING
58
+
59
+ VV_DESCENDING:
60
+ name: VV
61
+ gaussian:
62
+ mu: -9
63
+ sigma: -1
64
+ orbit: DESCENDING
65
+
66
+ VH_DESCENDING:
67
+ name: VH
68
+ gaussian:
69
+ mu: -10
70
+ sigma: -1
71
+ orbit: DESCENDING
72
+
73
+ HH_DESCENDING:
74
+ name: HH
75
+ gaussian:
76
+ mu: -11
77
+ sigma: -1
78
+ orbit: DESCENDING
79
+
80
+ HV_DESCENDING:
81
+ name: HV
82
+ gaussian:
83
+ mu: -12
84
+ sigma: -1
85
+ orbit: DESCENDING
@@ -0,0 +1,97 @@
1
+ instrument: MSI
2
+ level: L1C
3
+
4
+ srf_filename: rfs_sentinel2_a_13b.npy
5
+
6
+ bands:
7
+ B01:
8
+ name: '01 - Coastal aerosol'
9
+ gaussian:
10
+ mu: 442.922568734037
11
+ sigma: 7.248330717861807
12
+ GSD: 60
13
+
14
+ B02:
15
+ name: '02 - Blue'
16
+ gaussian:
17
+ mu: 492.9971095687347
18
+ sigma: 23.810316659477703
19
+ GSD: 10
20
+
21
+ B03:
22
+ name: '03 - Green'
23
+ gaussian:
24
+ mu: 559.5987534818435
25
+ sigma: 12.768882177939654
26
+ GSD: 10
27
+
28
+ B04:
29
+ name: '04 - Red'
30
+ gaussian:
31
+ mu: 664.6300422881802
32
+ sigma: 11.757355524910432
33
+ GSD: 10
34
+
35
+ B05:
36
+ name: '05 - Vegetation Red Edge'
37
+ gaussian:
38
+ mu: 704.0059319834206
39
+ sigma: 5.362493403740522
40
+ GSD: 20
41
+
42
+ B06:
43
+ name: '06 - Vegetation Red Edge'
44
+ gaussian:
45
+ mu: 740.5521320760564
46
+ sigma: 5.2330999827526155
47
+ GSD: 20
48
+
49
+ B07:
50
+ name: '07 - Vegetation Red Edge'
51
+ gaussian:
52
+ mu: 782.4190761493182
53
+ sigma: 7.212484180540051
54
+ GSD: 20
55
+
56
+ B08:
57
+ name: '08 - NIR'
58
+ gaussian:
59
+ mu: 827.5394062383036
60
+ sigma: 36.79409520400872
61
+ GSD: 10
62
+
63
+ B8A:
64
+ name: '08A - Vegetation Red Edge'
65
+ gaussian:
66
+ mu: 864.7801257644385
67
+ sigma: 8.07210759526792
68
+ GSD: 20
69
+
70
+ B09:
71
+ name: '09 - Water vapour'
72
+ gaussian:
73
+ mu: 945.0294901407692
74
+ sigma: 7.518965324285279
75
+ GSD: 60
76
+
77
+ B10:
78
+ name: '10 - SWIR - Cirrus'
79
+ gaussian:
80
+ mu: 1373.3636762095748
81
+ sigma: 11.163498916290587
82
+ GSD: 60
83
+ p: 2
84
+
85
+ B11:
86
+ name: '11 - SWIR'
87
+ gaussian:
88
+ mu: 1613.8624163477282
89
+ sigma: 34.4986558584479
90
+ GSD: 20
91
+
92
+ B12:
93
+ name: '12 - SWIR'
94
+ gaussian:
95
+ mu: 2203.6182057820033
96
+ sigma: 64.60648125885301
97
+ GSD: 20
@@ -0,0 +1,60 @@
1
+ instrument: 'PSB.SD'
2
+ processing_level: 'NA'
3
+
4
+ bands:
5
+ B01:
6
+ name: 'Coastal Blue'
7
+ gaussian:
8
+ mu: 443.704
9
+ sigma: 7.9672
10
+ GSD: 3.7
11
+
12
+ B02:
13
+ name: 'Blue'
14
+ gaussian:
15
+ mu: 490.973
16
+ sigma: 20.5096
17
+ GSD: 3.7
18
+
19
+ B03:
20
+ name: 'Green I'
21
+ gaussian:
22
+ mu: 532.719
23
+ sigma: 14.4789
24
+ GSD: 3.7
25
+
26
+ B04:
27
+ name: 'Green'
28
+ gaussian:
29
+ mu: 565.811
30
+ sigma: 15.2825
31
+ GSD: 3.7
32
+
33
+ B05:
34
+ name: 'Yellow'
35
+ gaussian:
36
+ mu: 611.587
37
+ sigma: 9.33594
38
+ GSD: 3.7
39
+
40
+ B06:
41
+ name: 'Red'
42
+ gaussian:
43
+ mu: 665.751
44
+ sigma: 12.6253
45
+ GSD: 3.7
46
+
47
+
48
+ B07:
49
+ name: 'Red Edge'
50
+ gaussian:
51
+ mu: 706.918
52
+ sigma: 6.92817
53
+ GSD: 3.7
54
+
55
+ B08:
56
+ name: 'NIR'
57
+ gaussian:
58
+ mu: 864.831
59
+ sigma: 15.2059
60
+ GSD: 3.7
@@ -0,0 +1,63 @@
1
+ processing_level: 'SR'
2
+
3
+ srf_filename: rfs_wv23_recon.npy
4
+
5
+ # Worldview-2 and Worldview-3_VNIR
6
+ # In the corresponding dataset, either of the two satellites above is used. However, since
7
+ # their bands have the same properties, we use the same file for both
8
+ bands:
9
+ B01:
10
+ name: 'coastal'
11
+ gaussian:
12
+ mu: 427.911967712222
13
+ sigma: 17.620786889126904
14
+ GSD: 1.24
15
+
16
+ B02:
17
+ name: 'blue'
18
+ gaussian:
19
+ mu: 482.40648216687816
20
+ sigma: 22.189227543486883
21
+ GSD: 1.24
22
+
23
+ B03:
24
+ name: 'green'
25
+ gaussian:
26
+ mu: 545.1346759174888
27
+ sigma: 27.270655243664613
28
+ GSD: 1.24
29
+
30
+ B04:
31
+ name: 'yellow'
32
+ gaussian:
33
+ mu: 604.6891589644367
34
+ sigma: 15.166919163740687
35
+ GSD: 1.24
36
+
37
+ B05:
38
+ name: 'red'
39
+ gaussian:
40
+ mu: 660.5315665213377
41
+ sigma: 23.075009737550587
42
+ GSD: 1.24
43
+
44
+ B06:
45
+ name: 'red edge'
46
+ gaussian:
47
+ mu: 723.1823149413602
48
+ sigma: 15.151759763702627
49
+ GSD: 1.24
50
+
51
+ B07:
52
+ name: 'nir1'
53
+ gaussian:
54
+ mu: 823.9274208290032
55
+ sigma: 42.09302701870739
56
+ GSD: 1.24
57
+
58
+ B08:
59
+ name: 'nir2'
60
+ gaussian:
61
+ mu: 906.4611534199017
62
+ sigma: 36.61665833552878
63
+ GSD: 1.24
@@ -2,10 +2,15 @@
2
2
 
3
3
  from typing import Any
4
4
 
5
- import torch
5
+ from rslearn.train.model_context import ModelContext
6
6
 
7
+ from .component import (
8
+ FeatureMaps,
9
+ IntermediateComponent,
10
+ )
7
11
 
8
- class PickFeatures(torch.nn.Module):
12
+
13
+ class PickFeatures(IntermediateComponent):
9
14
  """Picks a subset of feature maps in a multi-scale feature map list."""
10
15
 
11
16
  def __init__(self, indexes: list[int]):
@@ -19,15 +24,17 @@ class PickFeatures(torch.nn.Module):
19
24
 
20
25
  def forward(
21
26
  self,
22
- features: list[torch.Tensor],
23
- inputs: list[dict[str, Any]] | None = None,
24
- targets: list[dict[str, Any]] | None = None,
25
- ) -> list[torch.Tensor]:
27
+ intermediates: Any,
28
+ context: ModelContext,
29
+ ) -> FeatureMaps:
26
30
  """Pick a subset of the features.
27
31
 
28
32
  Args:
29
- features: input features
30
- inputs: raw inputs, not used
31
- targets: targets, not used
33
+ intermediates: the output from the previous component, which must be a FeatureMaps.
34
+ context: the model context.
32
35
  """
33
- return [features[idx] for idx in self.indexes]
36
+ if not isinstance(intermediates, FeatureMaps):
37
+ raise ValueError("input to PickFeatures must be FeatureMaps")
38
+
39
+ new_features = [intermediates.feature_maps[idx] for idx in self.indexes]
40
+ return FeatureMaps(new_features)
@@ -4,8 +4,16 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import ModelContext
7
8
 
8
- class PoolingDecoder(torch.nn.Module):
9
+ from .component import (
10
+ FeatureMaps,
11
+ FeatureVector,
12
+ IntermediateComponent,
13
+ )
14
+
15
+
16
+ class PoolingDecoder(IntermediateComponent):
9
17
  """Decoder that computes flat vector from a 2D feature map.
10
18
 
11
19
  It inputs multi-scale features, but only uses the last feature map. Then applies a
@@ -21,7 +29,7 @@ class PoolingDecoder(torch.nn.Module):
21
29
  num_fc_layers: int = 0,
22
30
  conv_channels: int = 128,
23
31
  fc_channels: int = 512,
24
- ):
32
+ ) -> None:
25
33
  """Initialize a PoolingDecoder.
26
34
 
27
35
  Args:
@@ -57,20 +65,65 @@ class PoolingDecoder(torch.nn.Module):
57
65
 
58
66
  self.output_layer = torch.nn.Linear(prev_channels, out_channels)
59
67
 
60
- def forward(self, features: list[torch.Tensor], inputs: list[dict[str, Any]]):
68
+ def forward(self, intermediates: Any, context: ModelContext) -> Any:
61
69
  """Compute flat output vector from multi-scale feature map.
62
70
 
63
71
  Args:
64
- features: list of feature maps at different resolutions.
65
- inputs: original inputs (ignored).
72
+ intermediates: the output from the previous component, which must be a FeatureMaps.
73
+ context: the model context.
66
74
 
67
75
  Returns:
68
76
  flat feature vector
69
77
  """
78
+ if not isinstance(intermediates, FeatureMaps):
79
+ raise ValueError("input to PoolingDecoder must be a FeatureMaps")
80
+
70
81
  # Only use last feature map.
71
- features = features[-1]
82
+ features = intermediates.feature_maps[-1]
72
83
 
73
84
  features = self.conv_layers(features)
74
85
  features = torch.amax(features, dim=(2, 3))
75
86
  features = self.fc_layers(features)
76
- return self.output_layer(features)
87
+ return FeatureVector(self.output_layer(features))
88
+
89
+
90
+ class SegmentationPoolingDecoder(PoolingDecoder):
91
+ """Like PoolingDecoder, but copy output to all pixels.
92
+
93
+ This allows for the model to produce a global output while still being compatible
94
+ with SegmentationTask. This only makes sense for very small windows, since the
95
+ output probabilities will be the same at all pixels. The main use case is to train
96
+ for a classification-like task on small windows, but still produce a raster during
97
+ inference on large windows.
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ in_channels: int,
103
+ out_channels: int,
104
+ image_key: str = "image",
105
+ **kwargs: Any,
106
+ ):
107
+ """Create a new SegmentationPoolingDecoder.
108
+
109
+ Args:
110
+ in_channels: input channels (channels in the last feature map passed to
111
+ this module)
112
+ out_channels: channels for the output flat feature vector
113
+ image_key: the key in inputs for the image from which the expected width
114
+ and height is derived.
115
+ kwargs: other arguments to pass to PoolingDecoder.
116
+ """
117
+ super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
118
+ self.image_key = image_key
119
+
120
+ def forward(self, intermediates: Any, context: ModelContext) -> Any:
121
+ """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
122
+
123
+ This only works when all of the pixels have the same segmentation target.
124
+ """
125
+ output_probs = super().forward(intermediates, context)
126
+ # BC -> BCHW
127
+ h, w = context.inputs[0][self.image_key].shape[1:3]
128
+ feat_map = output_probs.feature_vector[:, :, None, None].repeat([1, 1, h, w])
129
+ return FeatureMaps([feat_map])
@@ -0,0 +1,5 @@
1
+ """Presto."""
2
+
3
+ from .presto import Presto
4
+
5
+ __all__ = ["Presto"]