geoai-py 0.25.0__py2.py3-none-any.whl → 0.27.0__py2.py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- geoai/__init__.py +38 -1
- geoai/auto.py +4 -2
- geoai/change_detection.py +2 -2
- geoai/detectron2.py +4 -1
- geoai/extract.py +4 -1
- geoai/moondream.py +0 -1
- geoai/prithvi.py +1338 -0
- geoai/sam.py +2 -1
- geoai/segment.py +10 -1
- geoai/timm_regress.py +1652 -0
- geoai/utils.py +3 -1
- {geoai_py-0.25.0.dist-info → geoai_py-0.27.0.dist-info}/METADATA +4 -4
- {geoai_py-0.25.0.dist-info → geoai_py-0.27.0.dist-info}/RECORD +17 -15
- {geoai_py-0.25.0.dist-info → geoai_py-0.27.0.dist-info}/WHEEL +1 -1
- {geoai_py-0.25.0.dist-info → geoai_py-0.27.0.dist-info}/entry_points.txt +0 -0
- {geoai_py-0.25.0.dist-info → geoai_py-0.27.0.dist-info}/licenses/LICENSE +0 -0
- {geoai_py-0.25.0.dist-info → geoai_py-0.27.0.dist-info}/top_level.txt +0 -0
geoai/prithvi.py
ADDED
|
@@ -0,0 +1,1338 @@
|
|
|
1
|
+
"""Prithvi EO 2.0 module for geospatial foundation model inference.
|
|
2
|
+
|
|
3
|
+
This module provides tools for using NASA-IBM's Prithvi EO 2.0 geospatial foundation model
|
|
4
|
+
for masked autoencoding and feature extraction on multi-temporal satellite imagery.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
import json
|
|
8
|
+
import logging
|
|
9
|
+
import os
|
|
10
|
+
from typing import Dict, List, Optional, Tuple, Union
|
|
11
|
+
|
|
12
|
+
import numpy as np
|
|
13
|
+
import rasterio
|
|
14
|
+
import torch
|
|
15
|
+
import torch.nn as nn
|
|
16
|
+
from einops import rearrange
|
|
17
|
+
from huggingface_hub import hf_hub_download
|
|
18
|
+
from timm.layers import to_2tuple
|
|
19
|
+
from timm.models.vision_transformer import Block
|
|
20
|
+
|
|
21
|
+
from .utils import get_device
|
|
22
|
+
|
|
23
|
+
logger = logging.getLogger(__name__)

# Constants
# Raw pixel value treated as "no data" by PrithviProcessor.preprocess_image.
NO_DATA = -9999
# Value written in place of NO_DATA pixels after normalization.
NO_DATA_FLOAT = 0.0001
OFFSET = 0
PERCENTILE = 99.9

# Available Prithvi models (Hugging Face repos under "ibm-nasa-geospatial/").
# The comments summarize each variant; the actual architecture is always read
# from the repository's config.json when a model is loaded.
AVAILABLE_MODELS = [
    "Prithvi-EO-2.0-tiny-TL",  # tiny transfer learning, embed_dim=192, depth=12, with coords
    "Prithvi-EO-2.0-100M-TL",  # 100M transfer learning, embed_dim=768, depth=12, with coords
    "Prithvi-EO-2.0-300M",  # 300M base model, embed_dim=1024, depth=24, no coords
    "Prithvi-EO-2.0-300M-TL",  # 300M transfer learning, embed_dim=1024, depth=24, with coords
    "Prithvi-EO-2.0-600M",  # 600M base model, embed_dim=1280, depth=32, no coords
    "Prithvi-EO-2.0-600M-TL",  # 600M transfer learning, embed_dim=1280, depth=32, with coords
]
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
    """Create 3D sin/cos positional embeddings.

    The embedding width is split 6/16 for width, 6/16 for height and 4/16 for
    time. Each part is a 1D sin/cos embedding of the corresponding grid index,
    and the three parts are concatenated per token.

    Args:
        embed_dim (int): Embedding dimension. Must be divisible by 16.
        grid_size (tuple[int, int, int] | list[int]): The grid depth (time),
            height and width.
        add_cls_token (bool, optional): Whether to prepend an all-zero row for
            a classification (CLS) token.

    Returns:
        np.ndarray: Array of shape ``(T * H * W, embed_dim)``, or
        ``(1 + T * H * W, embed_dim)`` when ``add_cls_token`` is True. Tokens
        are ordered time-major, then height, then width.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by 16.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, which would let a mismatched width reach the concatenation.
    if embed_dim % 16 != 0:
        raise ValueError(f"embed_dim must be divisible by 16, got {embed_dim}")

    t_size, h_size, w_size = grid_size

    w_embed_dim = embed_dim // 16 * 6
    h_embed_dim = embed_dim // 16 * 6
    t_embed_dim = embed_dim // 16 * 4

    w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
    h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
    t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))

    # Broadcast each axis' table over the flattened (T, H, W) token grid:
    # width varies fastest, then height, then time.
    w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
    h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
    t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)

    pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)

    if add_cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
|
|
74
|
+
|
|
75
|
+
|
|
76
|
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Generate 1D sin/cos position embeddings.

    Args:
        embed_dim (int): Output width per position. Must be even.
        pos: Positions to encode. Any array-like is accepted (NumPy array,
            list, scalar); it is flattened, so the output has one row per
            element.

    Returns:
        np.ndarray: Float array of shape ``(num_positions, embed_dim)``. The
        first ``embed_dim // 2`` columns are sines and the rest are cosines.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")

    # Frequencies 1 / 10000^(2i / embed_dim) for i in [0, embed_dim / 2).
    omega = np.arange(embed_dim // 2, dtype=float)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    # np.asarray lets callers pass plain lists or scalars, not only ndarrays.
    pos = np.asarray(pos).reshape(-1)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2) outer product

    return np.concatenate([np.sin(out), np.cos(out)], axis=1)
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
    """Torch version of ``get_1d_sincos_pos_embed_from_grid``.

    Args:
        embed_dim: Output width per position. Must be even.
        pos: Tensor of positions; it is flattened before encoding. Floating
            tensors are encoded in their own dtype. Non-floating tensors
            (e.g. integer years or days) are converted to ``float32`` first.

    Returns:
        Tensor of shape ``(pos.numel(), embed_dim)`` on ``pos.device``, with
        the (floating) dtype of ``pos``. Sines fill the first half of the
        columns and cosines the second half.

    Raises:
        ValueError: If ``embed_dim`` is odd.
    """
    # Explicit check instead of ``assert`` (stripped under ``python -O``).
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be even")
    if not pos.is_floating_point():
        pos = pos.float()

    # Build frequencies directly on the target device/dtype.
    omega = torch.arange(embed_dim // 2, dtype=pos.dtype, device=pos.device)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega

    pos = pos.reshape(-1)
    out = torch.einsum("m,d->md", pos, omega)  # (M, D/2) outer product

    return torch.cat([torch.sin(out), torch.cos(out)], dim=1)
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def _init_weights(module):
    """Initialize the weights of one submodule; intended for ``nn.Module.apply``.

    * ``nn.Linear``: Xavier-uniform weight and zero bias (when a bias exists).
    * ``nn.LayerNorm``: weight set to 1 and bias set to 0, each only when
      present. LayerNorms built with ``elementwise_affine=False`` therefore
      pass through unchanged instead of failing on a ``None`` parameter.

    Other module types are not modified.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        if module.bias is not None:
            nn.init.constant_(module.bias, 0)
    elif isinstance(module, nn.LayerNorm):
        if module.bias is not None:
            nn.init.zeros_(module.bias)
        if module.weight is not None:
            nn.init.ones_(module.weight)
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class PatchEmbed(nn.Module):
|
|
126
|
+
"""3D Patch Embedding."""
|
|
127
|
+
|
|
128
|
+
def __init__(
|
|
129
|
+
self,
|
|
130
|
+
input_size: tuple[int, int, int] = (1, 224, 224),
|
|
131
|
+
patch_size: tuple[int, int, int] = (1, 16, 16),
|
|
132
|
+
in_chans: int = 3,
|
|
133
|
+
embed_dim: int = 768,
|
|
134
|
+
norm_layer: nn.Module | None = None,
|
|
135
|
+
flatten: bool = True,
|
|
136
|
+
bias: bool = True,
|
|
137
|
+
):
|
|
138
|
+
super().__init__()
|
|
139
|
+
self.input_size = input_size
|
|
140
|
+
self.patch_size = patch_size
|
|
141
|
+
self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
|
|
142
|
+
assert all(
|
|
143
|
+
g >= 1 for g in self.grid_size
|
|
144
|
+
), "Patch size is bigger than input size."
|
|
145
|
+
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
|
146
|
+
self.flatten = flatten
|
|
147
|
+
|
|
148
|
+
self.proj = nn.Conv3d(
|
|
149
|
+
in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
|
|
150
|
+
)
|
|
151
|
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
|
152
|
+
|
|
153
|
+
def forward(self, x):
|
|
154
|
+
B, C, T, H, W = x.shape
|
|
155
|
+
x = self.proj(x)
|
|
156
|
+
if self.flatten:
|
|
157
|
+
x = x.flatten(2).transpose(1, 2)
|
|
158
|
+
x = self.norm(x)
|
|
159
|
+
return x
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
class TemporalEncoder(nn.Module):
|
|
163
|
+
"""Temporal coordinate encoder."""
|
|
164
|
+
|
|
165
|
+
def __init__(self, embed_dim: int, trainable_scale: bool = False):
|
|
166
|
+
super().__init__()
|
|
167
|
+
self.embed_dim = embed_dim
|
|
168
|
+
self.year_embed_dim = embed_dim // 2
|
|
169
|
+
self.julian_day_embed_dim = embed_dim - self.year_embed_dim
|
|
170
|
+
|
|
171
|
+
if trainable_scale:
|
|
172
|
+
self.scale = nn.Parameter(torch.tensor(0.1))
|
|
173
|
+
else:
|
|
174
|
+
self.scale = 1.0
|
|
175
|
+
|
|
176
|
+
def forward(
|
|
177
|
+
self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None
|
|
178
|
+
):
|
|
179
|
+
"""
|
|
180
|
+
temporal_coords: year and day-of-year info with shape (B, T, 2).
|
|
181
|
+
"""
|
|
182
|
+
shape = temporal_coords.shape[:2] + (-1,)
|
|
183
|
+
|
|
184
|
+
year = _get_1d_sincos_embed_from_grid_torch(
|
|
185
|
+
self.year_embed_dim, temporal_coords[:, :, 0].flatten()
|
|
186
|
+
).reshape(shape)
|
|
187
|
+
julian_day = _get_1d_sincos_embed_from_grid_torch(
|
|
188
|
+
self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
|
|
189
|
+
).reshape(shape)
|
|
190
|
+
|
|
191
|
+
embedding = self.scale * torch.cat([year, julian_day], dim=-1)
|
|
192
|
+
|
|
193
|
+
if tokens_per_frame is not None:
|
|
194
|
+
embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
|
|
195
|
+
|
|
196
|
+
return embedding
|
|
197
|
+
|
|
198
|
+
|
|
199
|
+
class LocationEncoder(nn.Module):
    """Sin/cos encoder for per-sample (latitude, longitude) coordinates.

    ``embed_dim // 2`` channels encode latitude and the remaining channels
    encode longitude.
    """

    def __init__(self, embed_dim: int, trainable_scale: bool = False):
        super().__init__()
        self.embed_dim = embed_dim
        self.lat_embed_dim = embed_dim // 2
        self.lon_embed_dim = embed_dim - self.lat_embed_dim
        # Learnable multiplier starting at 0.1, or a constant 1.0.
        self.scale = nn.Parameter(torch.tensor(0.1)) if trainable_scale else 1.0

    def forward(self, location_coords: torch.Tensor):
        """
        location_coords: lat and lon info with shape (B, 2).

        Returns a (B, 1, embed_dim) tensor that broadcasts over all tokens.
        """
        batch = location_coords.shape[0]

        lat_part = _get_1d_sincos_embed_from_grid_torch(
            self.lat_embed_dim, location_coords[:, 0].flatten()
        ).reshape(batch, 1, -1)
        lon_part = _get_1d_sincos_embed_from_grid_torch(
            self.lon_embed_dim, location_coords[:, 1].flatten()
        ).reshape(batch, 1, -1)

        return self.scale * torch.cat([lat_part, lon_part], dim=-1)
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
class PrithviViT(nn.Module):
    """Prithvi ViT encoder.

    Turns a ``(B, C, T, H, W)`` input into 3D patch tokens, adds fixed sin/cos
    positional embeddings and optional temporal/location encodings, randomly
    masks a fraction of the tokens, prepends a CLS token and runs the sequence
    through ``depth`` transformer blocks.
    """

    def __init__(
        self,
        img_size: int | tuple[int, int] = 224,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        num_frames: int = 1,
        in_chans: int = 3,
        embed_dim: int = 1024,
        depth: int = 24,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
        drop_path: float = 0.0,
        **kwargs,
    ):
        """Build the encoder.

        Args:
            img_size: Spatial input size (an int means a square image).
            patch_size: Patch size as (t, h, w); an int ``p`` means (1, p, p).
            num_frames: Number of time steps in the input.
            in_chans: Number of input bands.
            embed_dim: Token embedding width.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_ratio: MLP hidden width relative to ``embed_dim``.
            norm_layer: Normalization layer class.
            coords_encoding: May contain ``"time"`` and/or ``"location"`` to
                enable the corresponding coordinate encoders.
            coords_scale_learn: Whether the coordinate encoders use a
                learnable scale.
            drop_path: Stochastic-depth rate of the last block. Per-block
                rates rise linearly from 0 for the first block to this value.
            **kwargs: Accepted and ignored.
        """
        super().__init__()

        self.in_chans = in_chans
        self.num_frames = num_frames
        self.embed_dim = embed_dim
        self.img_size = to_2tuple(img_size)
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        self.patch_embed = PatchEmbed(
            input_size=(num_frames,) + self.img_size,
            patch_size=patch_size,
            in_chans=in_chans,
            embed_dim=embed_dim,
        )

        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding

        if self.temporal_encoding:
            self.temporal_encoder = TemporalEncoder(embed_dim, coords_scale_learn)
        if self.location_encoding:
            self.location_encoder = LocationEncoder(embed_dim, coords_scale_learn)

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Fixed positional table (buffer, not a parameter); slot 0 is for CLS.
        # Filled with sin/cos values in initialize_weights().
        self.register_buffer(
            "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
        )

        # Stochastic depth decay rule: per-block drop-path rate rising
        # linearly from 0 to ``drop_path`` (all zeros with the default).
        drop_path_rates = torch.linspace(0, drop_path, depth).tolist()
        self.blocks = nn.ModuleList(
            [
                Block(
                    embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                    drop_path=rate,
                )
                for rate in drop_path_rates
            ]
        )

        self.norm = norm_layer(embed_dim)
        self.initialize_weights()

    def initialize_weights(self):
        """Fill the sin/cos positional table and initialize trainable weights."""
        pos_embed = get_3d_sincos_pos_embed(
            self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
        )
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Treat the Conv3d patch projection like a Linear layer for Xavier init.
        w = self.patch_embed.proj.weight.data
        torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        torch.nn.init.normal_(self.cls_token, std=0.02)
        self.apply(_init_weights)

    def random_masking(self, sequence, mask_ratio, noise=None):
        """Randomly keep ``int(L * (1 - mask_ratio))`` tokens per sample.

        Args:
            sequence: Tokens of shape (N, L, D).
            mask_ratio: Fraction of the L tokens to remove.
            noise: Optional (N, L) scores; the tokens with the smallest scores
                are kept. Uniform random noise is used when omitted.

        Returns:
            Tuple of:
            * kept tokens, shape (N, len_keep, D), in shuffled (noise) order;
            * binary mask (N, L) in original token order, 0 = kept, 1 = removed;
            * ``ids_restore`` (N, L), the permutation that undoes the shuffle.
        """
        N, L, D = sequence.shape
        len_keep = int(L * (1 - mask_ratio))

        if noise is None:
            noise = torch.rand(N, L, device=sequence.device)

        ids_shuffle = torch.argsort(noise, dim=1)
        ids_restore = torch.argsort(ids_shuffle, dim=1)

        ids_keep = ids_shuffle[:, :len_keep]
        sequence_masked = torch.gather(
            sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
        )

        # Build the mask in shuffled order, then map it back to token order.
        mask = torch.ones([N, L], device=sequence.device)
        mask[:, :len_keep] = 0
        mask = torch.gather(mask, dim=1, index=ids_restore)

        return sequence_masked, mask, ids_restore

    def forward(
        self,
        x: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        mask_ratio=0.75,
    ):
        """Encode a randomly masked version of ``x``.

        Args:
            x: Input of shape (B, C, T, H, W).
            temporal_coords: Optional (B, T, 2) year/day-of-year tensor; used
                only if "time" encoding is enabled.
            location_coords: Optional (B, 2) lat/lon tensor; used only if
                "location" encoding is enabled.
            mask_ratio: Fraction of patch tokens to drop.

        Returns:
            Tuple ``(latent, mask, ids_restore)``. ``latent`` is
            (B, 1 + len_keep, embed_dim) with the CLS token first; ``mask``
            and ``ids_restore`` are as returned by :meth:`random_masking`.
        """
        x = self.patch_embed(x)
        x = x + self.pos_embed[:, 1:, :]

        if self.temporal_encoding and temporal_coords is not None:
            x = x + self.temporal_encoder(
                temporal_coords, x.shape[1] // self.num_frames
            )

        if self.location_encoding and location_coords is not None:
            x = x + self.location_encoder(location_coords)

        cls_token = self.cls_token + self.pos_embed[:, :1, :]
        cls_tokens = cls_token.expand(x.shape[0], -1, -1)

        x, mask, ids_restore = self.random_masking(x, mask_ratio)
        x = torch.cat((cls_tokens, x), dim=1)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x, mask, ids_restore
|
|
358
|
+
|
|
359
|
+
|
|
360
|
+
class MAEDecoder(nn.Module):
    """Transformer decoder used in the Prithvi MAE.

    Projects encoder tokens to ``decoder_embed_dim``, re-inserts a learned mask
    token at every removed position, adds positional (and optional coordinate)
    encodings, and predicts the pixel values of every patch.
    """

    def __init__(
        self,
        patch_size: int | tuple[int, int, int] = (1, 16, 16),
        grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
        in_chans: int = 3,
        encoder_embed_dim: int = 1024,
        decoder_embed_dim: int = 512,
        depth: int = 8,
        num_heads: int = 16,
        mlp_ratio: float = 4.0,
        norm_layer: nn.Module = nn.LayerNorm,
        coords_encoding: list[str] | None = None,
        coords_scale_learn: bool = False,
    ):
        """Build the decoder.

        Args:
            patch_size: Patch size as (t, h, w); an int ``p`` means (1, p, p).
            grid_size: Number of patches along (T, H, W).
            in_chans: Number of bands to reconstruct.
            encoder_embed_dim: Width of the incoming encoder tokens.
            decoder_embed_dim: Internal decoder width.
            depth: Number of transformer blocks.
            num_heads: Attention heads per block.
            mlp_ratio: MLP hidden width relative to ``decoder_embed_dim``.
            norm_layer: Normalization layer class.
            coords_encoding: May contain ``"time"`` and/or ``"location"``.
            coords_scale_learn: Whether coordinate encoders use a learnable
                scale.
        """
        super().__init__()

        # Accept an int patch size, as PrithviViT does: p -> (1, p, p).
        # Indexing patch_size[0..2] below would fail on a bare int.
        if isinstance(patch_size, int):
            patch_size = (1, patch_size, patch_size)

        self.patch_size = patch_size
        self.grid_size = grid_size
        self.in_chans = in_chans
        self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)

        self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))

        # Fixed positional table (buffer); slot 0 is for the CLS token.
        self.register_buffer(
            "decoder_pos_embed",
            torch.zeros(
                1, grid_size[0] * grid_size[1] * grid_size[2] + 1, decoder_embed_dim
            ),
        )

        coords_encoding = coords_encoding or []
        self.temporal_encoding = "time" in coords_encoding
        self.location_encoding = "location" in coords_encoding

        if self.temporal_encoding:
            self.temporal_encoder = TemporalEncoder(
                decoder_embed_dim, coords_scale_learn
            )
        if self.location_encoding:
            self.location_encoder = LocationEncoder(
                decoder_embed_dim, coords_scale_learn
            )

        self.decoder_blocks = nn.ModuleList(
            [
                Block(
                    decoder_embed_dim,
                    num_heads,
                    mlp_ratio,
                    qkv_bias=True,
                    norm_layer=norm_layer,
                )
                for _ in range(depth)
            ]
        )

        self.decoder_norm = norm_layer(decoder_embed_dim)
        self.decoder_pred = nn.Linear(
            decoder_embed_dim,
            patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
            bias=True,
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Fill the sin/cos positional table and initialize trainable weights."""
        pos_embed = get_3d_sincos_pos_embed(
            self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
        )
        self.decoder_pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0)
        )

        torch.nn.init.normal_(self.mask_token, std=0.02)
        self.apply(_init_weights)

    def forward(
        self,
        hidden_states: torch.Tensor,
        ids_restore: torch.Tensor,
        temporal_coords: None | torch.Tensor = None,
        location_coords: None | torch.Tensor = None,
        input_size: list[int] = None,
    ):
        """Reconstruct all patches from the visible-token encoding.

        Args:
            hidden_states: Encoder output (B, 1 + len_keep, encoder_embed_dim),
                CLS token first.
            ids_restore: (B, L) permutation that restores original token order.
            temporal_coords: Optional (B, T, 2) year/day-of-year tensor.
            location_coords: Optional (B, 2) lat/lon tensor.
            input_size: Not used by this implementation.

        Returns:
            Tensor (B, L, patch_t * patch_h * patch_w * in_chans) of per-patch
            pixel predictions (CLS position removed).
        """
        x = self.decoder_embed(hidden_states)

        # Append one mask token per removed patch, then unshuffle to
        # original token order (CLS token kept aside).
        mask_tokens = self.mask_token.repeat(
            x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
        )
        x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
        x_ = torch.gather(
            x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
        )
        x = torch.cat([x[:, :1, :], x_], dim=1)

        x = x + self.decoder_pos_embed

        if self.temporal_encoding and temporal_coords is not None:
            num_frames = temporal_coords.shape[1]
            tokens_per_frame = (x.shape[1] - 1) // num_frames
            temp_embed = self.temporal_encoder(temporal_coords, tokens_per_frame)
            x[:, 1:, :] = x[:, 1:, :] + temp_embed

        if self.location_encoding and location_coords is not None:
            x[:, 1:, :] = x[:, 1:, :] + self.location_encoder(location_coords)

        for blk in self.decoder_blocks:
            x = blk(x)

        x = self.decoder_norm(x)
        x = self.decoder_pred(x)
        x = x[:, 1:, :]  # drop the CLS position

        return x
|
|
477
|
+
|
|
478
|
+
|
|
479
|
+
class PrithviMAE(nn.Module):
|
|
480
|
+
"""Prithvi Masked Autoencoder."""
|
|
481
|
+
|
|
482
|
+
def __init__(
|
|
483
|
+
self,
|
|
484
|
+
img_size: int | tuple[int, int] = 224,
|
|
485
|
+
patch_size: int | tuple[int, int, int] = (1, 16, 16),
|
|
486
|
+
num_frames: int = 4,
|
|
487
|
+
in_chans: int = 6,
|
|
488
|
+
embed_dim: int = 768,
|
|
489
|
+
depth: int = 12,
|
|
490
|
+
num_heads: int = 12,
|
|
491
|
+
decoder_embed_dim: int = 512,
|
|
492
|
+
decoder_depth: int = 8,
|
|
493
|
+
decoder_num_heads: int = 16,
|
|
494
|
+
mlp_ratio: float = 4.0,
|
|
495
|
+
norm_layer: nn.Module = nn.LayerNorm,
|
|
496
|
+
norm_pix_loss: bool = False,
|
|
497
|
+
coords_encoding: list[str] | None = None,
|
|
498
|
+
coords_scale_learn: bool = False,
|
|
499
|
+
drop_path: float = 0.0,
|
|
500
|
+
mask_ratio: float = 0.75,
|
|
501
|
+
**kwargs,
|
|
502
|
+
):
|
|
503
|
+
super().__init__()
|
|
504
|
+
|
|
505
|
+
self.img_size = to_2tuple(img_size)
|
|
506
|
+
self.patch_size = (
|
|
507
|
+
patch_size if isinstance(patch_size, tuple) else (1, patch_size, patch_size)
|
|
508
|
+
)
|
|
509
|
+
self.num_frames = num_frames
|
|
510
|
+
self.in_chans = in_chans
|
|
511
|
+
self.norm_pix_loss = norm_pix_loss
|
|
512
|
+
|
|
513
|
+
self.encoder = PrithviViT(
|
|
514
|
+
img_size=img_size,
|
|
515
|
+
patch_size=patch_size,
|
|
516
|
+
num_frames=num_frames,
|
|
517
|
+
in_chans=in_chans,
|
|
518
|
+
embed_dim=embed_dim,
|
|
519
|
+
depth=depth,
|
|
520
|
+
num_heads=num_heads,
|
|
521
|
+
mlp_ratio=mlp_ratio,
|
|
522
|
+
norm_layer=norm_layer,
|
|
523
|
+
coords_encoding=coords_encoding,
|
|
524
|
+
coords_scale_learn=coords_scale_learn,
|
|
525
|
+
drop_path=drop_path,
|
|
526
|
+
)
|
|
527
|
+
|
|
528
|
+
self.decoder = MAEDecoder(
|
|
529
|
+
patch_size=self.patch_size,
|
|
530
|
+
grid_size=self.encoder.patch_embed.grid_size,
|
|
531
|
+
in_chans=in_chans,
|
|
532
|
+
encoder_embed_dim=embed_dim,
|
|
533
|
+
decoder_embed_dim=decoder_embed_dim,
|
|
534
|
+
depth=decoder_depth,
|
|
535
|
+
num_heads=decoder_num_heads,
|
|
536
|
+
mlp_ratio=mlp_ratio,
|
|
537
|
+
norm_layer=norm_layer,
|
|
538
|
+
coords_encoding=coords_encoding,
|
|
539
|
+
coords_scale_learn=coords_scale_learn,
|
|
540
|
+
)
|
|
541
|
+
|
|
542
|
+
def patchify(self, pixel_values):
|
|
543
|
+
B, C, T, H, W = pixel_values.shape
|
|
544
|
+
pH = H // self.patch_size[1]
|
|
545
|
+
pW = W // self.patch_size[2]
|
|
546
|
+
|
|
547
|
+
x = pixel_values.reshape(
|
|
548
|
+
B,
|
|
549
|
+
C,
|
|
550
|
+
T // self.patch_size[0],
|
|
551
|
+
self.patch_size[0],
|
|
552
|
+
pH,
|
|
553
|
+
self.patch_size[1],
|
|
554
|
+
pW,
|
|
555
|
+
self.patch_size[2],
|
|
556
|
+
)
|
|
557
|
+
x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)
|
|
558
|
+
patchified_pixel_values = x.reshape(
|
|
559
|
+
B,
|
|
560
|
+
T // self.patch_size[0] * pH * pW,
|
|
561
|
+
self.patch_size[0] * self.patch_size[1] * self.patch_size[2] * C,
|
|
562
|
+
)
|
|
563
|
+
|
|
564
|
+
return patchified_pixel_values
|
|
565
|
+
|
|
566
|
+
def unpatchify(
|
|
567
|
+
self, patchified_pixel_values, image_size: tuple[int, int] | None = None
|
|
568
|
+
):
|
|
569
|
+
if image_size is None:
|
|
570
|
+
H, W = self.img_size
|
|
571
|
+
else:
|
|
572
|
+
H, W = image_size
|
|
573
|
+
|
|
574
|
+
C = self.in_chans
|
|
575
|
+
pH = H // self.patch_size[1]
|
|
576
|
+
pW = W // self.patch_size[2]
|
|
577
|
+
T = self.num_frames
|
|
578
|
+
|
|
579
|
+
x = patchified_pixel_values.reshape(
|
|
580
|
+
patchified_pixel_values.shape[0],
|
|
581
|
+
T // self.patch_size[0],
|
|
582
|
+
pH,
|
|
583
|
+
pW,
|
|
584
|
+
self.patch_size[0],
|
|
585
|
+
self.patch_size[1],
|
|
586
|
+
self.patch_size[2],
|
|
587
|
+
C,
|
|
588
|
+
)
|
|
589
|
+
x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
|
|
590
|
+
pixel_values = x.reshape(
|
|
591
|
+
patchified_pixel_values.shape[0],
|
|
592
|
+
C,
|
|
593
|
+
T,
|
|
594
|
+
pH * self.patch_size[1],
|
|
595
|
+
pW * self.patch_size[2],
|
|
596
|
+
)
|
|
597
|
+
|
|
598
|
+
return pixel_values
|
|
599
|
+
|
|
600
|
+
def forward_loss(self, pixel_values, pred, mask):
|
|
601
|
+
target = self.patchify(pixel_values)
|
|
602
|
+
|
|
603
|
+
if self.norm_pix_loss:
|
|
604
|
+
mean = target.mean(dim=-1, keepdim=True)
|
|
605
|
+
var = target.var(dim=-1, keepdim=True)
|
|
606
|
+
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
|
607
|
+
|
|
608
|
+
loss = (pred - target) ** 2
|
|
609
|
+
loss = loss.mean(dim=-1)
|
|
610
|
+
|
|
611
|
+
loss = (loss * mask).sum() / mask.sum()
|
|
612
|
+
return loss
|
|
613
|
+
|
|
614
|
+
def forward(
|
|
615
|
+
self,
|
|
616
|
+
pixel_values: torch.Tensor,
|
|
617
|
+
temporal_coords: None | torch.Tensor = None,
|
|
618
|
+
location_coords: None | torch.Tensor = None,
|
|
619
|
+
mask_ratio: float = None,
|
|
620
|
+
):
|
|
621
|
+
mask_ratio = mask_ratio if mask_ratio is not None else 0.75
|
|
622
|
+
|
|
623
|
+
latent, mask, ids_restore = self.encoder(
|
|
624
|
+
pixel_values, temporal_coords, location_coords, mask_ratio
|
|
625
|
+
)
|
|
626
|
+
pred = self.decoder(latent, ids_restore, temporal_coords, location_coords)
|
|
627
|
+
loss = self.forward_loss(pixel_values, pred, mask)
|
|
628
|
+
|
|
629
|
+
return loss, pred, mask
|
|
630
|
+
|
|
631
|
+
|
|
632
|
+
class PrithviProcessor:
|
|
633
|
+
"""Prithvi EO 2.0 processor with GeoTIFF input/output support.
|
|
634
|
+
|
|
635
|
+
Supports multiple model variants:
|
|
636
|
+
- Prithvi-EO-2.0-tiny-TL (tiny transfer learning)
|
|
637
|
+
- Prithvi-EO-2.0-100M-TL (100M transfer learning)
|
|
638
|
+
- Prithvi-EO-2.0-300M (300M base model)
|
|
639
|
+
- Prithvi-EO-2.0-300M-TL (300M transfer learning)
|
|
640
|
+
- Prithvi-EO-2.0-600M (600M base model)
|
|
641
|
+
- Prithvi-EO-2.0-600M-TL (600M transfer learning)
|
|
642
|
+
|
|
643
|
+
References:
|
|
644
|
+
- tiny-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL
|
|
645
|
+
- 100M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL
|
|
646
|
+
- 300M: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M
|
|
647
|
+
- 300M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL
|
|
648
|
+
- 600M: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M
|
|
649
|
+
- 600M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL
|
|
650
|
+
- GitHub: https://github.com/NASA-IMPACT/Prithvi-EO-2.0
|
|
651
|
+
"""
|
|
652
|
+
|
|
653
|
+
def __init__(
    self,
    model_name: str = "Prithvi-EO-2.0-300M-TL",
    config_path: Optional[str] = None,
    checkpoint_path: Optional[str] = None,
    device: Optional[torch.device] = None,
    cache_dir: Optional[str] = None,
):
    """Initialize Prithvi processor.

    Args:
        model_name: Name of the Prithvi model to download from HuggingFace Hub.
            Options:
            - "Prithvi-EO-2.0-tiny-TL" (tiny, 192 dim, 12 layers)
            - "Prithvi-EO-2.0-100M-TL" (100M, 768 dim, 12 layers)
            - "Prithvi-EO-2.0-300M" (base, 1024 dim, 24 layers)
            - "Prithvi-EO-2.0-300M-TL" (default, 1024 dim, 24 layers)
            - "Prithvi-EO-2.0-600M" (base, 1280 dim, 32 layers)
            - "Prithvi-EO-2.0-600M-TL" (1280 dim, 32 layers)
        config_path: Path to a local config.json. If omitted, it is
            downloaded for ``model_name``.
        checkpoint_path: Path to a local checkpoint. If omitted, it is
            downloaded for ``model_name``. A path that was supplied is always
            kept, even when the other file has to be downloaded.
        device: Torch device to use (defaults to ``get_device()``).
        cache_dir: Directory to cache downloaded files.

    Attributes set here include ``config`` (the "pretrained_cfg" section of
    the config file), ``bands``, ``mean``, ``std``, ``img_size``,
    ``patch_size``, ``mask_ratio``, ``num_frames`` (default 4),
    ``coords_encoding`` (default []) and the loaded ``model``.
    """
    self.device = device or get_device()
    self.model_name = model_name
    self.cache_dir = cache_dir

    # Download whatever is missing, but never overwrite a path the caller gave.
    if config_path is None or checkpoint_path is None:
        downloaded_config, downloaded_checkpoint = self.download_model(
            model_name, cache_dir
        )
        if config_path is None:
            config_path = downloaded_config
        if checkpoint_path is None:
            checkpoint_path = downloaded_checkpoint

    self.config_path = config_path
    self.checkpoint_path = checkpoint_path

    # Load config (JSON is UTF-8 by specification).
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    self.config = config_data["pretrained_cfg"]

    # Extract parameters
    self.bands = self.config["bands"]
    self.mean = self.config["mean"]
    self.std = self.config["std"]
    self.img_size = self.config["img_size"]
    self.patch_size = self.config["patch_size"]
    self.mask_ratio = self.config["mask_ratio"]
    self.num_frames = self.config.get("num_frames", 4)
    self.coords_encoding = self.config.get("coords_encoding", [])

    # Load model
    self.model = self._load_model()
|
|
705
|
+
|
|
706
|
+
@staticmethod
def download_model(
    model_name: str = "Prithvi-EO-2.0-300M-TL", cache_dir: Optional[str] = None
) -> Tuple[str, str]:
    """Download Prithvi model from HuggingFace Hub.

    Args:
        model_name: Name of the model (repository under
            ``ibm-nasa-geospatial/``). Options:
            - "Prithvi-EO-2.0-tiny-TL"
            - "Prithvi-EO-2.0-100M-TL"
            - "Prithvi-EO-2.0-300M" (base model)
            - "Prithvi-EO-2.0-300M-TL" (default)
            - "Prithvi-EO-2.0-600M" (base model)
            - "Prithvi-EO-2.0-600M-TL"
        cache_dir: Directory to cache files

    Returns:
        Tuple of local (config_path, checkpoint_path).

    Raises:
        RuntimeError: If either download fails. The underlying exception is
            chained as ``__cause__`` so the original traceback is preserved.
    """
    repo_id = f"ibm-nasa-geospatial/{model_name}"

    try:
        config_path = hf_hub_download(
            repo_id=repo_id,
            filename="config.json",
            cache_dir=cache_dir,
        )

        # Checkpoint naming: Prithvi-EO-2.0-300M-TL -> Prithvi_EO_V2_300M_TL.pt
        checkpoint_filename = (
            model_name.replace("-", "_").replace("_2.0_", "_V2_") + ".pt"
        )
        checkpoint_path = hf_hub_download(
            repo_id=repo_id,
            filename=checkpoint_filename,
            cache_dir=cache_dir,
        )
    except Exception as e:
        raise RuntimeError(
            f"Failed to download model from HuggingFace Hub: {e}"
        ) from e

    return config_path, checkpoint_path
|
|
750
|
+
|
|
751
|
+
def _load_model(self) -> PrithviMAE:
    """Build a PrithviMAE from ``self.config`` and load checkpoint weights.

    Checkpoint entries whose key contains "pos_embed" are discarded: the
    model fills its positional-embedding buffers with sin/cos tables when it
    is constructed. The checkpoint is loaded with ``strict=False``; any other
    missing or unexpected keys are logged as warnings instead of being
    silently ignored.

    Returns:
        The model on ``self.device`` in eval mode.

    Raises:
        RuntimeError: If the model cannot be built or the checkpoint cannot
            be loaded (the original exception is chained).
    """
    try:
        # JSON configs store patch_size as a list; the model expects a tuple.
        patch_size = self.config["patch_size"]
        if isinstance(patch_size, list):
            patch_size = tuple(patch_size)

        model = PrithviMAE(
            img_size=self.config["img_size"],
            patch_size=patch_size,
            # Use the attribute set in __init__ (which defaults to 4) so the
            # two code paths agree when "num_frames" is absent from the config.
            num_frames=self.num_frames,
            in_chans=self.config["in_chans"],
            embed_dim=self.config["embed_dim"],
            depth=self.config["depth"],
            num_heads=self.config["num_heads"],
            decoder_embed_dim=self.config["decoder_embed_dim"],
            decoder_depth=self.config["decoder_depth"],
            decoder_num_heads=self.config["decoder_num_heads"],
            mlp_ratio=self.config["mlp_ratio"],
            coords_encoding=self.coords_encoding,
            coords_scale_learn=self.config.get("coords_scale_learn", False),
            mask_ratio=self.mask_ratio,
            norm_pix_loss=self.config.get("norm_pix_loss", False),
        )

        state_dict = torch.load(
            self.checkpoint_path, map_location=self.device, weights_only=True
        )

        # Drop stored positional tables; the model computes its own.
        state_dict = {k: v for k, v in state_dict.items() if "pos_embed" not in k}

        result = model.load_state_dict(state_dict, strict=False)
        missing = [k for k in result.missing_keys if "pos_embed" not in k]
        if missing:
            logger.warning(
                "Prithvi checkpoint is missing %d key(s), e.g. %s",
                len(missing),
                missing[:5],
            )
        if result.unexpected_keys:
            logger.warning(
                "Prithvi checkpoint has %d unexpected key(s), e.g. %s",
                len(result.unexpected_keys),
                result.unexpected_keys[:5],
            )

        model = model.to(self.device)
        model.eval()

        if logger.isEnabledFor(logging.DEBUG):
            total_params = sum(
                p.numel() for p in model.parameters() if p.requires_grad
            )
            logger.debug(
                "Loaded %s with %d trainable parameters", self.model_name, total_params
            )

        return model

    except Exception as e:
        raise RuntimeError(f"Failed to load Prithvi model: {e}") from e
|
|
798
|
+
|
|
799
|
+
def read_geotiff(self, file_path: str) -> Tuple[np.ndarray, dict, Optional[dict]]:
    """Read GeoTIFF file.

    Args:
        file_path: Path to GeoTIFF file

    Returns:
        Tuple of (image array, metadata, tags):
        * image array: all bands as returned by ``rasterio`` ``read()``,
          i.e. (bands, rows, cols);
        * metadata: the dataset's ``meta`` dict;
        * tags: the dataset-level tag dict, or ``None`` if tags could not be
          read. The tags are returned as-is and are not parsed into
          coordinates here.
    """
    with rasterio.open(file_path) as src:
        img = src.read()
        meta = src.meta
        try:
            coords = src.tags()
        except Exception as exc:
            # Best effort: unreadable tags must not abort the image read, but
            # unlike a bare ``except:`` this lets KeyboardInterrupt/SystemExit
            # through and leaves a trace in the log.
            logger.debug("Could not read tags from %s: %s", file_path, exc)
            coords = None

    return img, meta, coords
|
|
817
|
+
|
|
818
|
+
def preprocess_image(
    self,
    img: np.ndarray,
    indices: Optional[List[int]] = None,
) -> np.ndarray:
    """Standardize one band-first image for the model.

    Args:
        img: Image array laid out as (C, H, W).
        indices: Optional channel indices to keep. They are applied after
            the channel axis has been moved to the end.

    Returns:
        Channels-last array (H, W, C'). Pixels equal to ``NO_DATA`` become
        ``NO_DATA_FLOAT``; every other value becomes
        ``(value - self.mean) / self.std``.
    """
    channels_last = np.moveaxis(img, 0, -1)
    if indices is not None:
        channels_last = channels_last[..., indices]

    # Standardize everything, then restore the nodata sentinel where needed.
    is_nodata = channels_last == NO_DATA
    standardized = (channels_last - self.mean) / self.std
    return np.where(is_nodata, NO_DATA_FLOAT, standardized)
|
|
843
|
+
|
|
844
|
+
def load_images(
    self,
    file_paths: List[str],
    indices: Optional[List[int]] = None,
) -> Tuple[np.ndarray, List[dict], List, List]:
    """Load and preprocess a temporal stack of GeoTIFF files.

    The stack is forced to exactly ``self.num_frames`` frames. Shorter inputs
    are padded by repeating the last file; longer inputs are truncated.

    Args:
        file_paths: Ordered GeoTIFF paths, one per time step.
        indices: Optional band indices forwarded to ``preprocess_image``.

    Returns:
        Tuple of ``(images, metas, temporal_coords, location_coords)``:

        - ``images`` is a float32 array of shape (1, C, T, H, W) with
          ``T == self.num_frames``.
        - ``metas`` holds one metadata dict per frame, including padded
          frames.
        - ``temporal_coords`` and ``location_coords`` are currently always
          empty lists.

    Raises:
        ValueError: If ``file_paths`` is empty.
    """
    file_paths = list(file_paths)
    if not file_paths:
        raise ValueError("file_paths must contain at least one GeoTIFF path.")

    # Force the sequence length to the number of frames the model expects.
    if len(file_paths) < self.num_frames:
        file_paths += [file_paths[-1]] * (self.num_frames - len(file_paths))
    else:
        file_paths = file_paths[: self.num_frames]

    imgs = []
    metas = []
    for path in file_paths:
        img, meta, _ = self.read_geotiff(path)
        imgs.append(self.preprocess_image(img, indices))
        metas.append(meta)

    # (T, H, W, C) -> (C, T, H, W) -> (1, C, T, H, W)
    stacked = np.stack(imgs, axis=0)
    stacked = np.moveaxis(stacked, -1, 0).astype("float32")
    stacked = np.expand_dims(stacked, axis=0)

    # Coordinates are not extracted yet. Empty lists keep the tuple shape
    # that callers unpack.
    temporal_coords: List = []
    location_coords: List = []
    return stacked, metas, temporal_coords, location_coords
|
|
889
|
+
|
|
890
|
+
def run_inference(
    self,
    input_data: torch.Tensor,
    temporal_coords: Optional[torch.Tensor] = None,
    location_coords: Optional[torch.Tensor] = None,
    mask_ratio: Optional[float] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run masked-autoencoder inference on one batch of windows.

    Args:
        input_data: Input tensor with shape (B, C, T, H, W). H and W must be
            multiples of ``self.img_size``.
        temporal_coords: Optional temporal coordinates forwarded to the
            model. Tensors are moved to ``self.device``.
        location_coords: Optional location coordinates forwarded to the
            model. Tensors are moved to ``self.device``.
        mask_ratio: Fraction of patches to mask. ``None`` uses
            ``self.mask_ratio``; an explicit ``0.0`` is honored.

    Returns:
        Tuple of ``(reconstructed_image, mask_image)``, both on the CPU.
        ``reconstructed_image`` keeps visible input values and fills masked
        patches with the model prediction. ``mask_image`` is 1.0 where the
        input was visible and 0.0 where it was masked.

    Raises:
        ValueError: If H or W is not divisible by ``self.img_size``.
    """
    # ``mask_ratio or default`` would silently turn an explicit 0.0 into
    # the default ratio.
    if mask_ratio is None:
        mask_ratio = self.mask_ratio

    # Check if input dimensions match model expectations
    _, _, _, H, W = input_data.shape
    if H % self.img_size != 0 or W % self.img_size != 0:
        raise ValueError(
            f"Input spatial dimensions ({H}x{W}) must be divisible by model image size ({self.img_size}). "
            f"Use process_files() method which handles padding automatically, or pad your input to multiples of {self.img_size}."
        )

    with torch.no_grad():
        x = input_data.to(self.device)
        if isinstance(temporal_coords, torch.Tensor):
            temporal_coords = temporal_coords.to(self.device)
        if isinstance(location_coords, torch.Tensor):
            location_coords = location_coords.to(self.device)
        _, pred, mask = self.model(x, temporal_coords, location_coords, mask_ratio)

        # Expand the per-patch mask to per-pixel values, then un-patchify
        # both the mask and the prediction back to image layout.
        mask_img = (
            self.model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1]))
            .detach()
            .cpu()
        )
        pred_img = self.model.unpatchify(pred).detach().cpu()

    # Keep visible input pixels and take the prediction for masked ones.
    # Work on a CPU copy so it shares a device with the masks above.
    rec_img = input_data.detach().cpu().clone()
    rec_img[mask_img == 1] = pred_img[mask_img == 1]

    # Invert the mask so 1 marks visible pixels (nicer for visualization).
    mask_img = (~(mask_img.to(torch.bool))).to(torch.float)

    return rec_img, mask_img
|
|
938
|
+
|
|
939
|
+
def process_images(
    self,
    file_paths: List[str],
    mask_ratio: Optional[float] = None,
    indices: Optional[List[int]] = None,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Process multiple GeoTIFF files and return tensors (without saving).

    The image is padded to a multiple of ``self.img_size``, cut into
    non-overlapping ``img_size`` x ``img_size`` windows, and each window is
    run through ``run_inference``. The tiles are stitched back and the
    padding is removed. Unlike ``process_files()``, nothing is written to
    disk.

    Args:
        file_paths: List of input file paths, one per time step.
        mask_ratio: Optional mask ratio (``None`` uses the processor default).
        indices: Optional band indices.

    Returns:
        Tuple of ``(input_tensor, reconstructed_tensor, mask_tensor)``, each
        shaped (1, C, T, H, W) at the original spatial size. T is the number
        of frames produced by ``load_images``, which may differ from
        ``len(file_paths)``.
    """
    input_data, _, _, _ = self.load_images(file_paths, indices)

    # ``load_images`` pads or truncates the stack to the model's frame count,
    # so C and T must come from the loaded array, not from ``file_paths``.
    _, num_channels, num_frames, original_h, original_w = input_data.shape
    tile = self.img_size

    # Pad H/W up to the next multiple of the tile size (no-op if aligned).
    pad_h = (tile - original_h % tile) % tile
    pad_w = (tile - original_w % tile) % tile
    if pad_h or pad_w:
        input_data = np.pad(
            input_data,
            ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
            mode="reflect",
        )

    batch = torch.tensor(input_data, device="cpu")

    # Cut non-overlapping tiles and flatten them into the batch axis.
    windows = batch.unfold(3, tile, tile).unfold(4, tile, tile)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows,
        "b c t h1 w1 h w -> (b h1 w1) c t h w",
        h=tile,
        w=tile,
    )

    # Each window is sent through the model on its own.
    rec_parts = []
    mask_parts = []
    for window in torch.tensor_split(windows, max(1, windows.shape[0]), dim=0):
        rec, mask = self.run_inference(window, None, None, mask_ratio)
        rec_parts.append(rec)
        mask_parts.append(mask)

    # Stitch tiles back into full frames.
    pattern = "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)"
    dims = dict(h=tile, w=tile, b=1, c=num_channels, t=num_frames, h1=h1, w1=w1)
    rec_imgs = rearrange(torch.cat(rec_parts, dim=0), pattern, **dims)
    mask_imgs = rearrange(torch.cat(mask_parts, dim=0), pattern, **dims)

    # Remove padding
    input_imgs = batch[..., :original_h, :original_w]
    rec_imgs = rec_imgs[..., :original_h, :original_w]
    mask_imgs = mask_imgs[..., :original_h, :original_w]

    return input_imgs, rec_imgs, mask_imgs
|
|
1038
|
+
|
|
1039
|
+
def visualize_rgb(
    self,
    input_tensor: torch.Tensor,
    rec_tensor: torch.Tensor,
    mask_tensor: torch.Tensor,
) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
    """Turn model tensors into displayable uint8 RGB frames.

    Args:
        input_tensor: Normalized input, (B, C, T, H, W) or (C, T, H, W).
        rec_tensor: Normalized reconstruction with the same layout.
        mask_tensor: Visibility mask with the same layout (1 = visible).

    Returns:
        Three lists ``(original_rgb, masked_rgb, reconstructed_rgb)``, each
        holding one (H, W, 3) uint8 array per frame.

        - Red, green and blue are taken from bands ``B04``, ``B03`` and
          ``B02``.
        - Values are de-normalized with ``self.mean`` / ``self.std``,
          clipped to [0, 10000] and scaled to [0, 255].
        - ``masked_rgb`` is the original RGB multiplied by the mask.
    """
    rgb_idx = [self.bands.index(name) for name in ("B04", "B03", "B02")]

    def _drop_batch(tensor):
        return tensor[0] if tensor.dim() == 5 else tensor

    frames_in = _drop_batch(input_tensor)
    frames_rec = _drop_batch(rec_tensor)
    frames_mask = _drop_batch(mask_tensor)

    mean = torch.tensor(self.mean)
    std = torch.tensor(self.std)

    def _to_display(source, t):
        # De-normalize each RGB channel, then map 0..10000 onto 0..255.
        rgb = source[rgb_idx, t, :, :].clone()
        for pos, band in enumerate(rgb_idx):
            rgb[pos] = rgb[pos] * std[band] + mean[band]
        scaled = np.clip(rgb.numpy(), 0, 10000)
        scaled = (scaled / 10000 * 255).astype(np.uint8)
        return np.transpose(scaled, (1, 2, 0))

    original_rgb, masked_rgb, reconstructed_rgb = [], [], []
    for t in range(frames_in.shape[1]):
        orig = _to_display(frames_in, t)
        original_rgb.append(orig)
        reconstructed_rgb.append(_to_display(frames_rec, t))

        visible = np.transpose(frames_mask[rgb_idx, t, :, :].numpy(), (1, 2, 0))
        masked_rgb.append((orig.astype(np.float32) * visible).astype(np.uint8))

    return original_rgb, masked_rgb, reconstructed_rgb
|
|
1106
|
+
|
|
1107
|
+
def process_files(
    self,
    file_paths: List[str],
    output_dir: str,
    mask_ratio: Optional[float] = None,
    indices: Optional[List[int]] = None,
):
    """Reconstruct a GeoTIFF time series and write the results to disk.

    The image is tiled into ``self.img_size`` windows and reconstructed
    window by window. The per-frame reconstruction and mask rasters are then
    written by ``save_results`` into ``output_dir``, which is created if
    missing.

    Args:
        file_paths: List of input file paths, one per time step.
        output_dir: Output directory for results.
        mask_ratio: Optional mask ratio (``None`` uses the processor default).
        indices: Optional band indices.
    """
    os.makedirs(output_dir, exist_ok=True)

    input_data, metas, _, _ = self.load_images(file_paths, indices)

    # ``load_images`` pads or truncates the stack to the model's frame count,
    # so C and T must come from the loaded array, not from ``file_paths``.
    _, num_channels, num_frames, original_h, original_w = input_data.shape
    tile = self.img_size

    # The outer ``% tile`` avoids padding a whole extra tile when the size
    # is already a multiple of the tile size.
    pad_h = (tile - original_h % tile) % tile
    pad_w = (tile - original_w % tile) % tile
    if pad_h or pad_w:
        input_data = np.pad(
            input_data,
            ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
            mode="reflect",
        )

    batch = torch.tensor(input_data, device="cpu")

    # Cut non-overlapping tiles and flatten them into the batch axis.
    windows = batch.unfold(3, tile, tile).unfold(4, tile, tile)
    h1, w1 = windows.shape[3:5]
    windows = rearrange(
        windows,
        "b c t h1 w1 h w -> (b h1 w1) c t h w",
        h=tile,
        w=tile,
    )

    # Each window is sent through the model on its own.
    rec_parts = []
    mask_parts = []
    for window in torch.tensor_split(windows, max(1, windows.shape[0]), dim=0):
        rec, mask = self.run_inference(window, None, None, mask_ratio)
        rec_parts.append(rec)
        mask_parts.append(mask)

    # Stitch tiles back into full frames.
    pattern = "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)"
    dims = dict(h=tile, w=tile, b=1, c=num_channels, t=num_frames, h1=h1, w1=w1)
    rec_imgs = rearrange(torch.cat(rec_parts, dim=0), pattern, **dims)
    mask_imgs = rearrange(torch.cat(mask_parts, dim=0), pattern, **dims)

    # Remove padding
    rec_imgs = rec_imgs[..., :original_h, :original_w]
    mask_imgs = mask_imgs[..., :original_h, :original_w]

    # Save results (drop the batch axis: save_results expects C, T, H, W).
    self.save_results(rec_imgs[0], mask_imgs[0], metas, output_dir)
|
|
1200
|
+
|
|
1201
|
+
def save_results(
    self,
    rec_img: torch.Tensor,
    mask_img: torch.Tensor,
    metas: List[dict],
    output_dir: str,
):
    """Write per-frame reconstruction and mask GeoTIFFs.

    For each time step ``t`` two files are written into ``output_dir``:

    - ``reconstructed_t{t}.tif``: the de-normalized reconstruction.
    - ``mask_t{t}.tif``: the mask.

    Both are int16 with LZW compression and nodata=0, using a copy of
    ``metas[t]`` as the profile template.

    Args:
        rec_img: Normalized reconstruction with shape (C, T, H, W).
        mask_img: Mask image with shape (C, T, H, W).
        metas: One rasterio metadata dict per time step. The dicts are
            copied, not modified.
        output_dir: Existing directory to write into.
    """
    mean = torch.tensor(np.asarray(self.mean)[:, None, None])
    std = torch.tensor(np.asarray(self.std)[:, None, None])
    int16_info = torch.iinfo(torch.int16)

    for t in range(rec_img.shape[1]):
        # De-normalize, round to nearest and clamp, so out-of-range values
        # saturate instead of wrapping around when cast to int16.
        rec_values = rec_img[:, t, :, :] * std + mean
        rec_img_t = (
            rec_values.round()
            .clamp(int16_info.min, int16_info.max)
            .to(torch.int16)
        )
        mask_img_t = mask_img[:, t, :, :].to(torch.int16)

        # The arrays written are int16 with C bands. Make the profile agree,
        # because the source file may have another dtype or band count
        # (e.g. when ``indices`` selected a subset of bands).
        meta = metas[t].copy()
        meta.update(
            compress="lzw",
            nodata=0,
            dtype="int16",
            count=rec_img_t.shape[0],
        )

        self._save_geotiff(
            rec_img_t.numpy(),
            os.path.join(output_dir, f"reconstructed_t{t}.tif"),
            meta,
        )
        self._save_geotiff(
            mask_img_t.numpy(),
            os.path.join(output_dir, f"mask_t{t}.tif"),
            meta,
        )
|
|
1239
|
+
|
|
1240
|
+
@staticmethod
def _save_geotiff(image: np.ndarray, output_path: str, meta: dict):
    """Write a band-first array to a new GeoTIFF.

    Args:
        image: Array whose first axis indexes bands. Slice ``k`` (0-based)
            is written to raster band ``k + 1``.
        output_path: Destination file path.
        meta: Keyword arguments forwarded to ``rasterio.open`` in write mode.
    """
    with rasterio.open(output_path, "w", **meta) as dst:
        for band_number, band in enumerate(image, start=1):
            dst.write(band, band_number)
|
|
1246
|
+
|
|
1247
|
+
|
|
1248
|
+
def get_available_prithvi_models() -> List[str]:
    """Return the names of the Prithvi models this module can load.

    The result is a shallow copy of the module-level ``AVAILABLE_MODELS``, so
    callers may modify it without affecting the registry.

    Returns:
        List of available model names.

    Example:
        >>> models = get_available_prithvi_models()
    """
    available = AVAILABLE_MODELS.copy()
    return available
|
|
1260
|
+
|
|
1261
|
+
|
|
1262
|
+
def load_prithvi_model(
    model_name: str = "Prithvi-EO-2.0-300M-TL",
    device: Optional[str] = None,
    cache_dir: Optional[str] = None,
) -> PrithviProcessor:
    """Build a ``PrithviProcessor`` for the requested model (convenience wrapper).

    Args:
        model_name: Name of the model. Options:
            - "Prithvi-EO-2.0-tiny-TL"
            - "Prithvi-EO-2.0-100M-TL"
            - "Prithvi-EO-2.0-300M" (base)
            - "Prithvi-EO-2.0-300M-TL" (default)
            - "Prithvi-EO-2.0-600M" (base)
            - "Prithvi-EO-2.0-600M-TL"
        device: Device string such as 'cuda' or 'cpu'. It is converted to a
            ``torch.device``; ``None`` is passed through unchanged.
        cache_dir: Cache directory forwarded to ``PrithviProcessor``.

    Returns:
        PrithviProcessor instance.

    Example:
        >>> processor = load_prithvi_model("Prithvi-EO-2.0-tiny-TL")
        >>> processor = load_prithvi_model("Prithvi-EO-2.0-600M", device="cpu")
    """
    torch_device = None if device is None else torch.device(device)

    return PrithviProcessor(
        model_name=model_name,
        device=torch_device,
        cache_dir=cache_dir,
    )
|
|
1305
|
+
|
|
1306
|
+
|
|
1307
|
+
def prithvi_inference(
    file_paths: List[str],
    output_dir: str = "output",
    model_name: str = "Prithvi-EO-2.0-300M-TL",
    mask_ratio: Optional[float] = None,
    device: Optional[str] = None,
    indices: Optional[List[int]] = None,
    cache_dir: Optional[str] = None,
):
    """Run Prithvi inference on GeoTIFF files (convenience function).

    Loads the model with ``load_prithvi_model`` and writes the reconstruction
    and mask rasters via ``PrithviProcessor.process_files``.

    Args:
        file_paths: List of input GeoTIFF files, one per time step.
        output_dir: Output directory.
        model_name: Name of the model. Options:
            - "Prithvi-EO-2.0-tiny-TL"
            - "Prithvi-EO-2.0-100M-TL"
            - "Prithvi-EO-2.0-300M" (base)
            - "Prithvi-EO-2.0-300M-TL" (default)
            - "Prithvi-EO-2.0-600M" (base)
            - "Prithvi-EO-2.0-600M-TL"
        mask_ratio: Optional mask ratio (``None`` uses the model default).
        device: Device to use, e.g. 'cuda' or 'cpu'.
        indices: Optional band indices forwarded to ``process_files``.
        cache_dir: Optional cache directory forwarded to ``load_prithvi_model``.

    Example:
        >>> prithvi_inference(
        ...     file_paths=["img1.tif", "img2.tif", "img3.tif", "img4.tif"],
        ...     model_name="Prithvi-EO-2.0-tiny-TL",
        ...     output_dir="output_tiny"
        ... )
    """
    processor = load_prithvi_model(
        model_name=model_name,
        device=device,
        cache_dir=cache_dir,
    )
    processor.process_files(
        file_paths,
        output_dir,
        mask_ratio=mask_ratio,
        indices=indices,
    )
|