rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0

rslearn/models/clay/configs/metadata.yaml
ADDED
@@ -0,0 +1,295 @@
+sentinel-2-l2a:
+  band_order:
+    - blue
+    - green
+    - red
+    - rededge1
+    - rededge2
+    - rededge3
+    - nir
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 2
+    - 1
+    - 0
+  gsd: 10
+  bands:
+    mean:
+      blue: 1105.
+      green: 1355.
+      red: 1552.
+      rededge1: 1887.
+      rededge2: 2422.
+      rededge3: 2630.
+      nir: 2743.
+      nir08: 2785.
+      swir16: 2388.
+      swir22: 1835.
+    std:
+      blue: 1809.
+      green: 1757.
+      red: 1888.
+      rededge1: 1870.
+      rededge2: 1732.
+      rededge3: 1697.
+      nir: 1742.
+      nir08: 1648.
+      swir16: 1470.
+      swir22: 1379.
+    wavelength:
+      blue: 0.493
+      green: 0.56
+      red: 0.665
+      rededge1: 0.704
+      rededge2: 0.74
+      rededge3: 0.783
+      nir: 0.842
+      nir08: 0.865
+      swir16: 1.61
+      swir22: 2.19
+planetscope-sr:
+  band_order:
+    - coastal_blue
+    - blue
+    - green_i
+    - green
+    - yellow
+    - red
+    - rededge
+    - nir
+  rgb_indices:
+    - 5
+    - 3
+    - 1
+  gsd: 5
+  bands:
+    mean:
+      coastal_blue: 1720.
+      blue: 1715.
+      green_i: 1913.
+      green: 2088.
+      yellow: 2274.
+      red: 2290.
+      rededge: 2613.
+      nir: 3970.
+    std:
+      coastal_blue: 747.
+      blue: 698.
+      green_i: 739.
+      green: 768.
+      yellow: 849.
+      red: 868.
+      rededge: 849.
+      nir: 914.
+    wavelength:
+      coastal_blue: 0.443
+      blue: 0.490
+      green_i: 0.531
+      green: 0.565
+      yellow: 0.610
+      red: 0.665
+      rededge: 0.705
+      nir: 0.865
+landsat-c2l1:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 30
+  bands:
+    mean:
+      red: 10678.
+      green: 10563.
+      blue: 11083.
+      nir08: 14792.
+      swir16: 12276.
+      swir22: 10114.
+    std:
+      red: 6025.
+      green: 5411.
+      blue: 5468.
+      nir08: 6746.
+      swir16: 5897.
+      swir22: 4850.
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir08: 0.86
+      swir16: 1.6
+      swir22: 2.2
+landsat-c2l2-sr:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir08
+    - swir16
+    - swir22
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 30
+  bands:
+    mean:
+      red: 13705.
+      green: 13310.
+      blue: 12474.
+      nir08: 17801.
+      swir16: 14615.
+      swir22: 12701.
+    std:
+      red: 9578.
+      green: 9408.
+      blue: 10144.
+      nir08: 8277.
+      swir16: 5300.
+      swir22: 4522.
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir08: 0.86
+      swir16: 1.6
+      swir22: 2.2
+naip:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 1.0
+  bands:
+    mean:
+      red: 110.16
+      green: 115.41
+      blue: 98.15
+      nir: 139.04
+    std:
+      red: 47.23
+      green: 39.82
+      blue: 35.43
+      nir: 49.86
+    wavelength:
+      red: 0.65
+      green: 0.56
+      blue: 0.48
+      nir: 0.842
+linz:
+  band_order:
+    - red
+    - green
+    - blue
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 0.5
+  bands:
+    mean:
+      red: 89.96
+      green: 99.46
+      blue: 89.51
+    std:
+      red: 41.83
+      green: 36.96
+      blue: 31.45
+    wavelength:
+      red: 0.635
+      green: 0.555
+      blue: 0.465
+sentinel-1-rtc:
+  band_order:
+    - vv
+    - vh
+  gsd: 10
+  bands:
+    mean:
+      vv: -12.113
+      vh: -18.673
+    std:
+      vv: 8.314
+      vh: 8.017
+    wavelength:
+      vv: 3.5
+      vh: 4.0
+modis:
+  band_order:
+    - sur_refl_b01
+    - sur_refl_b02
+    - sur_refl_b03
+    - sur_refl_b04
+    - sur_refl_b05
+    - sur_refl_b06
+    - sur_refl_b07
+  rgb_indices:
+    - 0
+    - 3
+    - 2
+  gsd: 500
+  bands:
+    mean:
+      sur_refl_b01: 1072.
+      sur_refl_b02: 1624.
+      sur_refl_b03: 931.
+      sur_refl_b04: 1023.
+      sur_refl_b05: 1599.
+      sur_refl_b06: 1404.
+      sur_refl_b07: 1051.
+    std:
+      sur_refl_b01: 1643.
+      sur_refl_b02: 1878.
+      sur_refl_b03: 1449.
+      sur_refl_b04: 1538.
+      sur_refl_b05: 1763.
+      sur_refl_b06: 1618.
+      sur_refl_b07: 1396.
+    wavelength:
+      sur_refl_b01: .645
+      sur_refl_b02: .858
+      sur_refl_b03: .469
+      sur_refl_b04: .555
+      sur_refl_b05: 1.240
+      sur_refl_b06: 1.640
+      sur_refl_b07: 2.130
+satellogic-MSI-L1D:
+  band_order:
+    - red
+    - green
+    - blue
+    - nir
+  rgb_indices:
+    - 0
+    - 1
+    - 2
+  gsd: 1.0
+  bands:
+    mean:
+      red: 1451.54
+      green: 1456.54
+      blue: 1543.22
+      nir: 2132.68
+    std:
+      red: 995.48
+      green: 771.29
+      blue: 708.86
+      nir: 1236.71
+    wavelength:
+      red: 0.640
+      green: 0.545
+      blue: 0.480
+      nir: 0.825
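
Each sensor entry above pairs a band_order list with per-band mean, std, and center-wavelength values plus a nominal gsd. A minimal sketch (not rslearn code) of how such an entry could be turned into normalization constants, assuming a hypothetical local copy of the file and using only the structure shown above:

import numpy as np
import yaml

# Hypothetical local copy of the metadata shown above.
with open("metadata.yaml") as f:
    metadata = yaml.safe_load(f)

sensor = metadata["sentinel-2-l2a"]
band_order = sensor["band_order"]
# Per-band mean/std vectors in band_order order, broadcastable over a CxHxW raster.
mean = np.array([sensor["bands"]["mean"][band] for band in band_order])[:, None, None]
std = np.array([sensor["bands"]["std"][band] for band in band_order])[:, None, None]

image = np.random.rand(len(band_order), 256, 256).astype(np.float32)  # stand-in raster
normalized = (image - mean) / std
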
rslearn/models/clip.py
ADDED
@@ -0,0 +1,68 @@
+"""OpenAI CLIP models."""
+
+from transformers import AutoModelForZeroShotImageClassification, AutoProcessor
+
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureExtractor, FeatureMaps
+
+
+class CLIP(FeatureExtractor):
+    """CLIP image encoder."""
+
+    def __init__(
+        self,
+        model_name: str,
+    ):
+        """Instantiate a new CLIP instance.
+
+        Args:
+            model_name: the model name like "openai/clip-vit-large-patch14-336".
+        """
+        super().__init__()
+
+        self.processor = AutoProcessor.from_pretrained(model_name)  # nosec
+        model = AutoModelForZeroShotImageClassification.from_pretrained(model_name)  # nosec
+        self.encoder = model.vision_model
+
+        # Get number of features and token map size from encoder attributes.
+        self.num_features = self.encoder.post_layernorm.normalized_shape[0]
+        crop_size = self.processor.image_processor.crop_size
+        stride = self.encoder.embeddings.patch_embedding.stride
+        self.height = crop_size["height"] // stride[0]
+        self.width = crop_size["width"] // stride[1]
+
+    def forward(self, context: ModelContext) -> FeatureMaps:
+        """Compute outputs from the backbone.
+
+        Args:
+            context: the model context. Input dicts must include "image" key containing
+                the image to process. The images should have values 0-255.
+
+        Returns:
+            a FeatureMaps with one feature map from the ViT, which is always Bx24x24x1024.
+        """
+        inputs = context.inputs
+        device = inputs[0]["image"].image.device
+        clip_inputs = self.processor(
+            images=[
+                inp["image"].single_ts_to_chw_tensor().cpu().numpy().transpose(1, 2, 0)
+                for inp in inputs
+            ],
+            return_tensors="pt",
+            padding=True,
+        )
+        pixel_values = clip_inputs["pixel_values"].to(device)
+        output = self.encoder(pixel_values=pixel_values)
+        # Ignore class token output which is before the patch tokens.
+        image_features = output.last_hidden_state[:, 1:, :]
+        batch_size = image_features.shape[0]
+
+        # 576x1024 -> HxWxC
+        return FeatureMaps(
+            [
+                image_features.reshape(
+                    batch_size, self.height, self.width, self.num_features
+                ).permute(0, 3, 1, 2)
+            ]
+        )
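
A note on the grid-size arithmetic in the constructor above: the height and width of the returned feature map come from the processor crop size divided by the patch-embedding stride. A rough illustration (not part of the file), assuming the "openai/clip-vit-large-patch14-336" checkpoint named in the docstring:

# ViT-L/14 at 336x336: 336 // 14 = 24 tokens per side and a hidden size of 1024, so
# last_hidden_state is B x (1 + 24 * 24) x 1024; dropping the class token leaves the
# 576 patch tokens that forward() reshapes into a 24x24 spatial map.
crop_size = 336
patch_stride = 14
height = width = crop_size // patch_stride  # 24
num_patch_tokens = height * width  # 576
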

rslearn/models/component.py
ADDED
@@ -0,0 +1,111 @@
+"""Model component API."""
+
+import abc
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+
+from rslearn.train.model_context import ModelContext, ModelOutput
+
+
+class FeatureExtractor(torch.nn.Module, abc.ABC):
+    """A feature extractor that performs initial processing of the inputs.
+
+    The FeatureExtractor is the first component in the encoders list for
+    SingleTaskModel and MultiTaskModel.
+    """
+
+    @abc.abstractmethod
+    def forward(self, context: ModelContext) -> Any:
+        """Extract an initial intermediate from the model context.
+
+        Args:
+            context: the model context.
+
+        Returns:
+            any intermediate to pass to downstream components. Oftentimes this is a
+                FeatureMaps.
+        """
+        raise NotImplementedError
+
+
+class IntermediateComponent(torch.nn.Module, abc.ABC):
+    """An intermediate component in the model.
+
+    In SingleTaskModel and MultiTaskModel, modules after the first module
+    in the encoders list are IntermediateComponents, as are modules before the last
+    module in the decoders list(s).
+    """
+
+    @abc.abstractmethod
+    def forward(self, intermediates: Any, context: ModelContext) -> Any:
+        """Process the given intermediate into another intermediate.
+
+        Args:
+            intermediates: the output from the previous component (either a
+                FeatureExtractor or another IntermediateComponent).
+            context: the model context.
+
+        Returns:
+            any intermediate to pass to downstream components.
+        """
+        raise NotImplementedError
+
+
+class Predictor(torch.nn.Module, abc.ABC):
+    """A predictor that computes task-specific outputs and a loss dict.
+
+    In SingleTaskModel and MultiTaskModel, the last module(s) in the decoders list(s)
+    are Predictors.
+    """
+
+    @abc.abstractmethod
+    def forward(
+        self,
+        intermediates: Any,
+        context: ModelContext,
+        targets: list[dict[str, torch.Tensor]] | None = None,
+    ) -> ModelOutput:
+        """Compute task-specific outputs and loss dict.
+
+        Args:
+            intermediates: the output from the previous component.
+            context: the model context.
+            targets: the training targets, or None during prediction.
+
+        Returns:
+            a tuple of the task-specific outputs (which should be compatible with the
+                configured Task) and loss dict. The loss dict maps from a name for each
+                loss to a scalar tensor.
+        """
+        raise NotImplementedError
+
+
+@dataclass
+class FeatureMaps:
+    """An intermediate output type for multi-resolution feature maps."""
+
+    # List of BxCxHxW feature maps at different scales, ordered from highest resolution
+    # (most fine-grained) to lowest resolution (coarsest).
+    feature_maps: list[torch.Tensor]
+
+
+@dataclass
+class TokenFeatureMaps:
+    """An intermediate output type for multi-resolution BCHWN feature maps with a token dimension.
+
+    Unlike `FeatureMaps`, these include an additional dimension for unpooled tokens.
+    """
+
+    # List of BxCxHxWxN feature maps at different scales, ordered from highest resolution
+    # (most fine-grained) to lowest resolution (coarsest).
+    feature_maps: list[torch.Tensor]
+
+
+@dataclass
+class FeatureVector:
+    """An intermediate output type for a flat feature vector."""
+
+    # Flat BxC feature vector.
+    feature_vector: torch.Tensor
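
The three abstract classes above define the contract for encoder and decoder modules. A minimal sketch (not part of rslearn) of a custom IntermediateComponent that consumes and returns FeatureMaps, with import paths assuming the modules added in this diff:

from typing import Any

import torch

from rslearn.models.component import FeatureMaps, IntermediateComponent
from rslearn.train.model_context import ModelContext


class ScaleFeatures(IntermediateComponent):
    """Multiply every feature map by a learnable scalar (illustrative only)."""

    def __init__(self) -> None:
        super().__init__()
        self.scale = torch.nn.Parameter(torch.ones(1))

    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
        if not isinstance(intermediates, FeatureMaps):
            raise ValueError("expected FeatureMaps")
        return FeatureMaps([fm * self.scale for fm in intermediates.feature_maps])

Per the docstrings above, such a module would sit after the FeatureExtractor in the encoders list of SingleTaskModel or MultiTaskModel.
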

rslearn/models/concatenate_features.py
ADDED
@@ -0,0 +1,103 @@
+"""Concatenate feature map with features from input data."""
+
+from typing import Any
+
+import torch
+
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureMaps, IntermediateComponent
+
+
+class ConcatenateFeatures(IntermediateComponent):
+    """Concatenate feature map with additional raw data inputs."""
+
+    def __init__(
+        self,
+        key: str,
+        in_channels: int | None = None,
+        conv_channels: int = 64,
+        out_channels: int | None = None,
+        num_conv_layers: int = 1,
+        kernel_size: int = 3,
+        final_relu: bool = False,
+    ):
+        """Create a new ConcatenateFeatures.
+
+        Args:
+            key: the key of the input_dict to concatenate.
+            in_channels: number of input channels of the additional features.
+            conv_channels: number of channels of the convolutional layers.
+            out_channels: number of output channels of the additional features.
+            num_conv_layers: number of convolutional layers to apply to the additional features.
+            kernel_size: kernel size of the convolutional layers.
+            final_relu: whether to apply a ReLU activation to the final output, default False.
+        """
+        super().__init__()
+        self.key = key
+
+        if num_conv_layers > 0:
+            if in_channels is None or out_channels is None:
+                raise ValueError(
+                    "in_channels and out_channels must be specified if num_conv_layers > 0"
+                )
+
+        conv_layers = []
+        for i in range(num_conv_layers):
+            conv_in = in_channels if i == 0 else conv_channels
+            conv_out = out_channels if i == num_conv_layers - 1 else conv_channels
+            conv_layers.append(
+                torch.nn.Conv2d(
+                    in_channels=conv_in,
+                    out_channels=conv_out,
+                    kernel_size=kernel_size,
+                    padding="same",
+                )
+            )
+            if i < num_conv_layers - 1 or final_relu:
+                conv_layers.append(torch.nn.ReLU(inplace=True))
+
+        self.conv_layers = torch.nn.Sequential(*conv_layers)
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Concatenate the feature map with the raw data inputs.
+
+        Args:
+            intermediates: the previous output, which must be a FeatureMaps.
+            context: the model context. The input dicts must have a key matching the
+                configured key.
+
+        Returns:
+            concatenated feature maps.
+        """
+        if (
+            not isinstance(intermediates, FeatureMaps)
+            or len(intermediates.feature_maps) == 0
+        ):
+            raise ValueError(
+                "Expected input to be FeatureMaps with at least one feature map"
+            )
+
+        add_data = torch.stack(
+            [input_data[self.key] for input_data in context.inputs], dim=0
+        )
+        add_features = self.conv_layers(add_data)
+
+        new_features: list[torch.Tensor] = []
+        for feature_map in intermediates.feature_maps:
+            # Shape of feature map: BCHW
+            feat_h, feat_w = feature_map.shape[2], feature_map.shape[3]
+
+            resized_add_features = add_features
+            # Resize additional features to match each feature map size if needed
+            if add_features.shape[2] != feat_h or add_features.shape[3] != feat_w:
+                resized_add_features = torch.nn.functional.interpolate(
+                    add_features,
+                    size=(feat_h, feat_w),
+                    mode="bilinear",
+                    align_corners=False,
+                )
+
+            new_features.append(torch.cat([feature_map, resized_add_features], dim=1))
+
+        return FeatureMaps(new_features)
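
To make the resize-and-concatenate behavior above concrete, here is a rough shape walk-through (not rslearn code). The "dem" key is hypothetical, and SimpleNamespace stands in for ModelContext because forward only reads context.inputs:

from types import SimpleNamespace

import torch

from rslearn.models.component import FeatureMaps
from rslearn.models.concatenate_features import ConcatenateFeatures

module = ConcatenateFeatures(key="dem", in_channels=1, out_channels=8)
features = FeatureMaps([torch.randn(2, 32, 64, 64), torch.randn(2, 64, 32, 32)])
# Two samples, each with a 1x128x128 "dem" raster; forward stacks them to 2x1x128x128.
context = SimpleNamespace(inputs=[{"dem": torch.randn(1, 128, 128)} for _ in range(2)])
out = module(features, context)
# The 8 projected channels are bilinearly resized to each level and concatenated along
# the channel dimension: (2, 40, 64, 64) and (2, 72, 32, 32).
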
rslearn/models/conv.py
ADDED
@@ -0,0 +1,63 @@
+"""A single convolutional layer."""
+
+from typing import Any
+
+import torch
+
+from rslearn.train.model_context import ModelContext
+
+from .component import FeatureMaps, IntermediateComponent
+
+
+class Conv(IntermediateComponent):
+    """A single convolutional layer.
+
+    It inputs a set of feature maps; the conv layer is applied to each feature map
+    independently, and list of outputs is returned.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        kernel_size: int,
+        padding: str | int = "same",
+        stride: int = 1,
+        activation: torch.nn.Module = torch.nn.ReLU(inplace=True),
+    ):
+        """Initialize a Conv.
+
+        Args:
+            in_channels: number of input channels.
+            out_channels: number of output channels.
+            kernel_size: kernel size, see torch.nn.Conv2D.
+            padding: padding to apply, see torch.nn.Conv2D.
+            stride: stride to apply, see torch.nn.Conv2D.
+            activation: activation to apply after convolution
+        """
+        super().__init__()
+
+        self.layer = torch.nn.Conv2d(
+            in_channels, out_channels, kernel_size, padding=padding, stride=stride
+        )
+        self.activation = activation
+
+    def forward(self, intermediates: Any, context: ModelContext) -> FeatureMaps:
+        """Apply conv layer on each feature map.
+
+        Args:
+            intermediates: the previous output, which must be a FeatureMaps.
+            context: the model context.
+
+        Returns:
+            the resulting feature maps after applying the same Conv2d on each one.
+        """
+        if not isinstance(intermediates, FeatureMaps):
+            raise ValueError("input to Conv must be FeatureMaps")
+
+        new_features = []
+        for feat_map in intermediates.feature_maps:
+            feat_map = self.layer(feat_map)
+            feat_map = self.activation(feat_map)
+            new_features.append(feat_map)
+        return FeatureMaps(new_features)
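
A similarly minimal sketch (not rslearn code) of Conv applied to a two-level pyramid; the context argument can be None here because Conv.forward never reads it:

import torch

from rslearn.models.component import FeatureMaps
from rslearn.models.conv import Conv

conv = Conv(in_channels=32, out_channels=16, kernel_size=3)
pyramid = FeatureMaps([torch.randn(2, 32, 64, 64), torch.randn(2, 32, 32, 32)])
out = conv(pyramid, None)
# The same Conv2d weights are applied at each scale: (2, 16, 64, 64) and (2, 16, 32, 32).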