rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/models/clay/configs/metadata.yaml ADDED
@@ -0,0 +1,295 @@
+ sentinel-2-l2a:
+   band_order:
+     - blue
+     - green
+     - red
+     - rededge1
+     - rededge2
+     - rededge3
+     - nir
+     - nir08
+     - swir16
+     - swir22
+   rgb_indices:
+     - 2
+     - 1
+     - 0
+   gsd: 10
+   bands:
+     mean:
+       blue: 1105.
+       green: 1355.
+       red: 1552.
+       rededge1: 1887.
+       rededge2: 2422.
+       rededge3: 2630.
+       nir: 2743.
+       nir08: 2785.
+       swir16: 2388.
+       swir22: 1835.
+     std:
+       blue: 1809.
+       green: 1757.
+       red: 1888.
+       rededge1: 1870.
+       rededge2: 1732.
+       rededge3: 1697.
+       nir: 1742.
+       nir08: 1648.
+       swir16: 1470.
+       swir22: 1379.
+     wavelength:
+       blue: 0.493
+       green: 0.56
+       red: 0.665
+       rededge1: 0.704
+       rededge2: 0.74
+       rededge3: 0.783
+       nir: 0.842
+       nir08: 0.865
+       swir16: 1.61
+       swir22: 2.19
+ planetscope-sr:
+   band_order:
+     - coastal_blue
+     - blue
+     - green_i
+     - green
+     - yellow
+     - red
+     - rededge
+     - nir
+   rgb_indices:
+     - 5
+     - 3
+     - 1
+   gsd: 5
+   bands:
+     mean:
+       coastal_blue: 1720.
+       blue: 1715.
+       green_i: 1913.
+       green: 2088.
+       yellow: 2274.
+       red: 2290.
+       rededge: 2613.
+       nir: 3970.
+     std:
+       coastal_blue: 747.
+       blue: 698.
+       green_i: 739.
+       green: 768.
+       yellow: 849.
+       red: 868.
+       rededge: 849.
+       nir: 914.
+     wavelength:
+       coastal_blue: 0.443
+       blue: 0.490
+       green_i: 0.531
+       green: 0.565
+       yellow: 0.610
+       red: 0.665
+       rededge: 0.705
+       nir: 0.865
+ landsat-c2l1:
+   band_order:
+     - red
+     - green
+     - blue
+     - nir08
+     - swir16
+     - swir22
+   rgb_indices:
+     - 0
+     - 1
+     - 2
+   gsd: 30
+   bands:
+     mean:
+       red: 10678.
+       green: 10563.
+       blue: 11083.
+       nir08: 14792.
+       swir16: 12276.
+       swir22: 10114.
+     std:
+       red: 6025.
+       green: 5411.
+       blue: 5468.
+       nir08: 6746.
+       swir16: 5897.
+       swir22: 4850.
+     wavelength:
+       red: 0.65
+       green: 0.56
+       blue: 0.48
+       nir08: 0.86
+       swir16: 1.6
+       swir22: 2.2
+ landsat-c2l2-sr:
+   band_order:
+     - red
+     - green
+     - blue
+     - nir08
+     - swir16
+     - swir22
+   rgb_indices:
+     - 0
+     - 1
+     - 2
+   gsd: 30
+   bands:
+     mean:
+       red: 13705.
+       green: 13310.
+       blue: 12474.
+       nir08: 17801.
+       swir16: 14615.
+       swir22: 12701.
+     std:
+       red: 9578.
+       green: 9408.
+       blue: 10144.
+       nir08: 8277.
+       swir16: 5300.
+       swir22: 4522.
+     wavelength:
+       red: 0.65
+       green: 0.56
+       blue: 0.48
+       nir08: 0.86
+       swir16: 1.6
+       swir22: 2.2
+ naip:
+   band_order:
+     - red
+     - green
+     - blue
+     - nir
+   rgb_indices:
+     - 0
+     - 1
+     - 2
+   gsd: 1.0
+   bands:
+     mean:
+       red: 110.16
+       green: 115.41
+       blue: 98.15
+       nir: 139.04
+     std:
+       red: 47.23
+       green: 39.82
+       blue: 35.43
+       nir: 49.86
+     wavelength:
+       red: 0.65
+       green: 0.56
+       blue: 0.48
+       nir: 0.842
+ linz:
+   band_order:
+     - red
+     - green
+     - blue
+   rgb_indices:
+     - 0
+     - 1
+     - 2
+   gsd: 0.5
+   bands:
+     mean:
+       red: 89.96
+       green: 99.46
+       blue: 89.51
+     std:
+       red: 41.83
+       green: 36.96
+       blue: 31.45
+     wavelength:
+       red: 0.635
+       green: 0.555
+       blue: 0.465
+ sentinel-1-rtc:
+   band_order:
+     - vv
+     - vh
+   gsd: 10
+   bands:
+     mean:
+       vv: -12.113
+       vh: -18.673
+     std:
+       vv: 8.314
+       vh: 8.017
+     wavelength:
+       vv: 3.5
+       vh: 4.0
+ modis:
+   band_order:
+     - sur_refl_b01
+     - sur_refl_b02
+     - sur_refl_b03
+     - sur_refl_b04
+     - sur_refl_b05
+     - sur_refl_b06
+     - sur_refl_b07
+   rgb_indices:
+     - 0
+     - 3
+     - 2
+   gsd: 500
+   bands:
+     mean:
+       sur_refl_b01: 1072.
+       sur_refl_b02: 1624.
+       sur_refl_b03: 931.
+       sur_refl_b04: 1023.
+       sur_refl_b05: 1599.
+       sur_refl_b06: 1404.
+       sur_refl_b07: 1051.
+     std:
+       sur_refl_b01: 1643.
+       sur_refl_b02: 1878.
+       sur_refl_b03: 1449.
+       sur_refl_b04: 1538.
+       sur_refl_b05: 1763.
+       sur_refl_b06: 1618.
+       sur_refl_b07: 1396.
+     wavelength:
+       sur_refl_b01: .645
+       sur_refl_b02: .858
+       sur_refl_b03: .469
+       sur_refl_b04: .555
+       sur_refl_b05: 1.240
+       sur_refl_b06: 1.640
+       sur_refl_b07: 2.130
+ satellogic-MSI-L1D:
+   band_order:
+     - red
+     - green
+     - blue
+     - nir
+   rgb_indices:
+     - 0
+     - 1
+     - 2
+   gsd: 1.0
+   bands:
+     mean:
+       red: 1451.54
+       green: 1456.54
+       blue: 1543.22
+       nir: 2132.68
+     std:
+       red: 995.48
+       green: 771.29
+       blue: 708.86
+       nir: 1236.71
+     wavelength:
+       red: 0.640
+       green: 0.545
+       blue: 0.480
+       nir: 0.825
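
The per-sensor mean/std entries above are the normalization constants applied to each band before imagery enters the Clay backbone, with channels taken in band_order. A minimal sketch of how such a file can be consumed, assuming PyYAML; the normalize_bands helper and the file path are illustrative, not rslearn APIs:

    import numpy as np
    import yaml

    # Hypothetical path to a checkout containing the config shown above.
    with open("rslearn/models/clay/configs/metadata.yaml") as f:
        metadata = yaml.safe_load(f)

    def normalize_bands(image: np.ndarray, sensor: str) -> np.ndarray:
        """Standardize a CxHxW image whose channels follow the sensor's band_order."""
        info = metadata[sensor]
        mean = np.array([info["bands"]["mean"][b] for b in info["band_order"]])
        std = np.array([info["bands"]["std"][b] for b in info["band_order"]])
        return (image - mean[:, None, None]) / std[:, None, None]

    # 10-band Sentinel-2 L2A patch with raw reflectance values (illustrative).
    patch = np.random.randint(0, 10000, size=(10, 64, 64)).astype(np.float32)
    normalized = normalize_bands(patch, "sentinel-2-l2a")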
rslearn/models/clip.py ADDED
@@ -0,0 +1,68 @@
+ """OpenAI CLIP models."""
+
+ from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
+
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureExtractor, FeatureMaps
+
+
+ class CLIP(FeatureExtractor):
+     """CLIP image encoder."""
+
+     def __init__(
+         self,
+         model_name: str,
+     ):
+         """Instantiate a new CLIP instance.
+
+         Args:
+             model_name: the model name, like "openai/clip-vit-large-patch14-336".
+         """
+         super().__init__()
+
+         self.processor = AutoProcessor.from_pretrained(model_name)  # nosec
+         model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)  # nosec
+         self.encoder = model.vision_model
+
+         # Get the number of features and token map size from encoder attributes.
+         self.num_features = self.encoder.post_layernorm.normalized_shape[0]
+         crop_size = self.processor.image_processor.crop_size
+         stride = self.encoder.embeddings.patch_embedding.stride
+         self.height = crop_size["height"] // stride[0]
+         self.width = crop_size["width"] // stride[1]
+
+     def forward(self, context: ModelContext) -> FeatureMaps:
+         """Compute outputs from the backbone.
+
+         Args:
+             context: the model context. Input dicts must include an "image" key
+                 containing the image to process. The images should have values 0-255.
+
+         Returns:
+             a FeatureMaps with one BxCxHxW feature map from the ViT; Bx1024x24x24 for clip-vit-large-patch14-336.
+         """
+         inputs = context.inputs
+         device = inputs[0]["image"].image.device
+         clip_inputs = self.processor(
+             images=[
+                 inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+                 for inp in inputs
+             ],
+             return_tensors="pt",
+             padding=True,
+         )
+         pixel_values = clip_inputs["pixel_values"].to(device)
+         output = self.encoder(pixel_values=pixel_values)
+         # Ignore the class token output, which comes before the patch tokens.
+         image_features = output.last_hidden_state[:, 1:, :]
+         batch_size = image_features.shape[0]
+
+         # Bx(H*W)xC -> BxHxWxC -> BxCxHxW
+         return FeatureMaps(
+             [
+                 image_features.reshape(
+                     batch_size, self.height, self.width, self.num_features
+                 ).permute(0, 3, 1, 2)
+             ]
+         )
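
The constructor reads the token-grid geometry off the Hugging Face processor and vision model rather than hard-coding it. A quick standalone check of that arithmetic for the checkpoint named in the docstring (this downloads the model; only the transformers package is assumed):

    from transformers import AutoModelForZeroShotImageClassification, AutoProcessor

    model_name = "openai/clip-vit-large-patch14-336"
    processor = AutoProcessor.from_pretrained(model_name)
    encoder = AutoModelForZeroShotImageClassification.from_pretrained(model_name).vision_model

    crop_size = processor.image_processor.crop_size            # {"height": 336, "width": 336}
    stride = encoder.embeddings.patch_embedding.stride         # (14, 14)
    num_features = encoder.post_layernorm.normalized_shape[0]  # 1024

    # 336 // 14 = 24 tokens per side, so forward() returns a single Bx1024x24x24 map.
    print(crop_size["height"] // stride[0], crop_size["width"] // stride[1], num_features)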
rslearn/models/component.py ADDED
@@ -0,0 +1,111 @@
+ """Model component API."""
+
+ import abc
+ from dataclasses import dataclass
+ from typing import Any
+
+ import torch
+
+ from rslearn.train.model_context import ModelContext, ModelOutput
+
+
+ class FeatureExtractor(torch.nn.Module, abc.ABC):
+     """A feature extractor that performs initial processing of the inputs.
+
+     The FeatureExtractor is the first component in the encoders list for
+     SingleTaskModel and MultiTaskModel.
+     """
+
+     @abc.abstractmethod
+     def forward(self, context: ModelContext) -> Any:
+         """Extract an initial intermediate from the model context.
+
+         Args:
+             context: the model context.
+
+         Returns:
+             any intermediate to pass to downstream components. Oftentimes this is
+             a FeatureMaps.
+         """
+         raise NotImplementedError
+
+
+ class IntermediateComponent(torch.nn.Module, abc.ABC):
+     """An intermediate component in the model.
+
+     In SingleTaskModel and MultiTaskModel, modules after the first module in the
+     encoders list are IntermediateComponents, as are modules before the last
+     module in the decoders list(s).
+     """
+
+     @abc.abstractmethod
+     def forward(self, intermediates: Any, context: ModelContext) -> Any:
+         """Process the given intermediate into another intermediate.
+
+         Args:
+             intermediates: the output from the previous component (either a
+                 FeatureExtractor or another IntermediateComponent).
+             context: the model context.
+
+         Returns:
+             any intermediate to pass to downstream components.
+         """
+         raise NotImplementedError
+
+
+ class Predictor(torch.nn.Module, abc.ABC):
+     """A predictor that computes task-specific outputs and a loss dict.
+
+     In SingleTaskModel and MultiTaskModel, the last module(s) in the decoders
+     list(s) are Predictors.
+     """
+
+     @abc.abstractmethod
+     def forward(
+         self,
+         intermediates: Any,
+         context: ModelContext,
+         targets: list[dict[str, torch.Tensor]] | None = None,
+     ) -> ModelOutput:
+         """Compute task-specific outputs and loss dict.
+
+         Args:
+             intermediates: the output from the previous component.
+             context: the model context.
+             targets: the training targets, or None during prediction.
+
+         Returns:
+             a ModelOutput containing the task-specific outputs (which should be
+             compatible with the configured Task) and the loss dict, which maps
+             from a name for each loss to a scalar tensor.
+         """
+         raise NotImplementedError
+
+
+ @dataclass
+ class FeatureMaps:
+     """An intermediate output type for multi-resolution feature maps."""
+
+     # List of BxCxHxW feature maps at different scales, ordered from highest
+     # resolution (most fine-grained) to lowest resolution (coarsest).
+     feature_maps: list[torch.Tensor]
+
+
+ @dataclass
+ class TokenFeatureMaps:
+     """An intermediate output type for multi-resolution BxCxHxWxN feature maps.
+
+     Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
+     """
+
+     # List of BxCxHxWxN feature maps at different scales, ordered from highest
+     # resolution (most fine-grained) to lowest resolution (coarsest).
+     feature_maps: list[torch.Tensor]
+
+
+ @dataclass
+ class FeatureVector:
+     """An intermediate output type for a flat feature vector."""
+
+     # Flat BxC feature vector.
+     feature_vector: torch.Tensor
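
Together these classes define the component pipeline: a FeatureExtractor consumes the ModelContext, zero or more IntermediateComponents transform the intermediate, and a Predictor produces the final ModelOutput. As a sketch of the IntermediateComponent contract, here is a hypothetical component (not part of the package) that rescales every feature map:

    from typing import Any

    import torch

    from rslearn.models.component import FeatureMaps, IntermediateComponent
    from rslearn.train.model_context import ModelContext

    class ScaleFeatures(IntermediateComponent):
        """Multiply every feature map by a fixed scalar (illustration only)."""

        def __init__(self, factor: float):
            super().__init__()
            self.factor = factor

        def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
            # Components validate the intermediate type they expect.
            if not isinstance(intermediates, FeatureMaps):
                raise ValueError("ScaleFeatures expects a FeatureMaps input")
            return FeatureMaps([f * self.factor for f in intermediates.feature_maps])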
rslearn/models/concatenate_features.py ADDED
@@ -0,0 +1,103 @@
+ """Concatenate feature map with features from input data."""
+
+ from typing import Any
+
+ import torch
+
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureMaps, IntermediateComponent
+
+
+ class ConcatenateFeatures(IntermediateComponent):
+     """Concatenate feature map with additional raw data inputs."""
+
+     def __init__(
+         self,
+         key: str,
+         in_channels: int | None = None,
+         conv_channels: int = 64,
+         out_channels: int | None = None,
+         num_conv_layers: int = 1,
+         kernel_size: int = 3,
+         final_relu: bool = False,
+     ):
+         """Create a new ConcatenateFeatures.
+
+         Args:
+             key: the key of the input_dict to concatenate.
+             in_channels: number of input channels of the additional features.
+             conv_channels: number of channels of the convolutional layers.
+             out_channels: number of output channels of the additional features.
+             num_conv_layers: number of convolutional layers to apply to the additional features.
+             kernel_size: kernel size of the convolutional layers.
+             final_relu: whether to apply a ReLU activation to the final output, default False.
+         """
+         super().__init__()
+         self.key = key
+
+         conv_layers = []
+         if num_conv_layers > 0:
+             if in_channels is None or out_channels is None:
+                 raise ValueError(
+                     "in_channels and out_channels must be specified if num_conv_layers > 0"
+                 )
+
+             for i in range(num_conv_layers):
+                 conv_in = in_channels if i == 0 else conv_channels
+                 conv_out = out_channels if i == num_conv_layers - 1 else conv_channels
+                 conv_layers.append(
+                     torch.nn.Conv2d(
+                         in_channels=conv_in,
+                         out_channels=conv_out,
+                         kernel_size=kernel_size,
+                         padding="same",
+                     )
+                 )
+                 if i < num_conv_layers - 1 or final_relu:
+                     conv_layers.append(torch.nn.ReLU(inplace=True))
+
+         # With num_conv_layers == 0 this is an empty Sequential, i.e. the identity.
+         self.conv_layers = torch.nn.Sequential(*conv_layers)
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Concatenate the feature map with the raw data inputs.
+
+         Args:
+             intermediates: the previous output, which must be a FeatureMaps.
+             context: the model context. The input dicts must have a key matching
+                 the configured key.
+
+         Returns:
+             the concatenated feature maps.
+         """
+         if (
+             not isinstance(intermediates, FeatureMaps)
+             or len(intermediates.feature_maps) == 0
+         ):
+             raise ValueError(
+                 "Expected input to be FeatureMaps with at least one feature map"
+             )
+
+         add_data = torch.stack(
+             [input_data[self.key] for input_data in context.inputs], dim=0
+         )
+         add_features = self.conv_layers(add_data)
+
+         new_features: list[torch.Tensor] = []
+         for feature_map in intermediates.feature_maps:
+             # Shape of feature map: BxCxHxW.
+             feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
+
+             resized_add_features = add_features
+             # Resize the additional features to match each feature map size if needed.
+             if add_features.shape[2] != feat_h or add_features.shape[3] != feat_w:
+                 resized_add_features = torch.nn.functional.interpolate(
+                     add_features,
+                     size=(feat_h, feat_w),
+                     mode="bilinear",
+                     align_corners=False,
+                 )
+
+             new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
+
+         return FeatureMaps(new_features)
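
For the channel bookkeeping: with num_conv_layers=2 the conv stack maps in_channels -> conv_channels -> out_channels, and each feature map gains out_channels extra channels after concatenation. A small sketch, assuming rslearn is importable; the "dem" key and all shapes are illustrative:

    import torch

    from rslearn.models.concatenate_features import ConcatenateFeatures

    module = ConcatenateFeatures(
        key="dem",         # hypothetical auxiliary input, e.g. a digital elevation model
        in_channels=1,     # the raw auxiliary data has one channel
        conv_channels=64,  # hidden width of the conv stack
        out_channels=16,   # channels appended to each feature map
        num_conv_layers=2,
    )
    # The stack maps 1 -> 64 -> 16 channels (ReLU between layers, none at the end).
    aux = torch.randn(2, 1, 128, 128)
    print(module.conv_layers(aux).shape)  # torch.Size([2, 16, 128, 128])
    # In forward(), a Bx256xHxW feature map then becomes Bx(256+16)xHxW.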
rslearn/models/conv.py ADDED
@@ -0,0 +1,63 @@
+ """A single convolutional layer."""
+
+ from typing import Any
+
+ import torch
+
+ from rslearn.train.model_context import ModelContext
+
+ from .component import FeatureMaps, IntermediateComponent
+
+
+ class Conv(IntermediateComponent):
+     """A single convolutional layer.
+
+     It inputs a set of feature maps; the conv layer is applied to each feature map
+     independently, and a list of outputs is returned.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         padding: str | int = "same",
+         stride: int = 1,
+         activation: torch.nn.Module = torch.nn.ReLU(inplace=True),
+     ):
+         """Initialize a Conv.
+
+         Args:
+             in_channels: number of input channels.
+             out_channels: number of output channels.
+             kernel_size: kernel size, see torch.nn.Conv2d.
+             padding: padding to apply, see torch.nn.Conv2d.
+             stride: stride to apply, see torch.nn.Conv2d.
+             activation: activation to apply after the convolution.
+         """
+         super().__init__()
+
+         self.layer = torch.nn.Conv2d(
+             in_channels, out_channels, kernel_size, padding=padding, stride=stride
+         )
+         self.activation = activation
+
+     def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+         """Apply the conv layer on each feature map.
+
+         Args:
+             intermediates: the previous output, which must be a FeatureMaps.
+             context: the model context.
+
+         Returns:
+             the resulting feature maps after applying the same Conv2d on each one.
+         """
+         if not isinstance(intermediates, FeatureMaps):
+             raise ValueError("input to Conv must be FeatureMaps")
+
+         new_features = []
+         for feat_map in intermediates.feature_maps:
+             feat_map = self.layer(feat_map)
+             feat_map = self.activation(feat_map)
+             new_features.append(feat_map)
+         return FeatureMaps(new_features)
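
Because the same Conv2d is applied to every entry of a FeatureMaps, all input maps must share the configured in_channels, while their spatial sizes may differ. A standalone sketch, assuming rslearn is importable; context is unused by Conv.forward, so None suffices here, though in a real pipeline a ModelContext is supplied:

    import torch

    from rslearn.models.component import FeatureMaps
    from rslearn.models.conv import Conv

    conv = Conv(in_channels=256, out_channels=128, kernel_size=3)
    # Two feature maps at different scales, both with 256 channels.
    fms = FeatureMaps([torch.randn(2, 256, 64, 64), torch.randn(2, 256, 32, 32)])
    out = conv(fms, None)
    print([f.shape for f in out.feature_maps])
    # [torch.Size([2, 128, 64, 64]), torch.Size([2, 128, 32, 32])]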