geoai-py 0.25.0__py2.py3-none-any.whl → 0.27.0__py2.py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
geoai/prithvi.py ADDED
@@ -0,0 +1,1338 @@
1
+ """Prithvi EO 2.0 module for geospatial foundation model inference.
2
+
3
+ This module provides tools for using NASA-IBM's Prithvi EO 2.0 geospatial foundation model
4
+ for masked autoencoding and feature extraction on multi-temporal satellite imagery.
5
+ """
6
+
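For orientation, a minimal end-to-end sketch using the convenience functions defined at the bottom of this module; the file paths and output directory below are placeholders, not files shipped with the package:

from geoai.prithvi import load_prithvi_model

# Downloads the matching config and checkpoint from the HuggingFace Hub on first use.
processor = load_prithvi_model("Prithvi-EO-2.0-300M-TL")

# Reconstruct a 4-frame time series and write GeoTIFF outputs to disk.
processor.process_files(
    ["t1.tif", "t2.tif", "t3.tif", "t4.tif"],
    output_dir="prithvi_output",
)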
7
+ import json
8
+ import logging
9
+ import os
10
+ from typing import Dict, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import rasterio
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import rearrange
17
+ from huggingface_hub import hf_hub_download
18
+ from timm.layers import to_2tuple
19
+ from timm.models.vision_transformer import Block
20
+
21
+ from .utils import get_device
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # Constants
26
+ NO_DATA = -9999
27
+ NO_DATA_FLOAT = 0.0001
28
+ OFFSET = 0
29
+ PERCENTILE = 99.9
30
+
31
+ # Available Prithvi models
32
+ AVAILABLE_MODELS = [
33
+ "Prithvi-EO-2.0-tiny-TL", # tiny transfer learning, embed_dim=192, depth=12, with coords
34
+ "Prithvi-EO-2.0-100M-TL", # 100M transfer learning, embed_dim=768, depth=12, with coords
35
+ "Prithvi-EO-2.0-300M", # 300M base model, embed_dim=1024, depth=24, no coords
36
+ "Prithvi-EO-2.0-300M-TL", # 300M transfer learning, embed_dim=768, depth=12, with coords
37
+ "Prithvi-EO-2.0-600M", # 600M base model, embed_dim=1280, depth=32, no coords
38
+ "Prithvi-EO-2.0-600M-TL", # 600M transfer learning, embed_dim=1280, depth=32, with coords
39
+ ]
40
+
41
+
42
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
43
+ """Create 3D sin/cos positional embeddings.
44
+
45
+ Args:
46
+ embed_dim (int): Embedding dimension.
47
+ grid_size (tuple[int, int, int] | list[int]): The grid depth, height and width.
48
+ add_cls_token (bool, optional): Whether or not to add a classification (CLS) token.
49
+
50
+ Returns:
51
+ Position embeddings (with or without cls token)
52
+ """
53
+ assert embed_dim % 16 == 0
54
+
55
+ t_size, h_size, w_size = grid_size
56
+
57
+ w_embed_dim = embed_dim // 16 * 6
58
+ h_embed_dim = embed_dim // 16 * 6
59
+ t_embed_dim = embed_dim // 16 * 4
60
+
61
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
62
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
63
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
64
+
65
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
66
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
67
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
68
+
69
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
70
+
71
+ if add_cls_token:
72
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
73
+ return pos_embed
74
+
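A quick shape check makes the 6/6/4 split concrete: with embed_dim=768, width and height each get 288 dimensions and time gets 192, and an assumed (4, 14, 14) grid yields 784 patch positions (785 rows with the CLS token). A sketch, using this module's own function:

pos = get_3d_sincos_pos_embed(embed_dim=768, grid_size=(4, 14, 14), add_cls_token=True)
assert pos.shape == (4 * 14 * 14 + 1, 768)  # one row per patch position plus the CLS row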
75
+
76
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
77
+ """Generate 1D sincos position embeddings."""
78
+ if embed_dim % 2 != 0:
79
+ raise ValueError("embed_dim must be even")
80
+
81
+ omega = np.arange(embed_dim // 2, dtype=float)
82
+ omega /= embed_dim / 2.0
83
+ omega = 1.0 / 10000**omega
84
+
85
+ pos = pos.reshape(-1)
86
+ out = np.einsum("m,d->md", pos, omega)
87
+
88
+ emb_sin = np.sin(out)
89
+ emb_cos = np.cos(out)
90
+
91
+ emb = np.concatenate([emb_sin, emb_cos], axis=1)
92
+ return emb
93
+
94
+
95
+ def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
96
+ """Modified torch version of get_1d_sincos_pos_embed_from_grid()."""
97
+ assert embed_dim % 2 == 0
98
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
99
+
100
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
101
+ omega /= embed_dim / 2.0
102
+ omega = 1.0 / 10000**omega
103
+
104
+ pos = pos.reshape(-1)
105
+ out = torch.einsum("m,d->md", pos, omega)
106
+
107
+ emb_sin = torch.sin(out)
108
+ emb_cos = torch.cos(out)
109
+
110
+ emb = torch.cat([emb_sin, emb_cos], dim=1)
111
+ return emb
112
+
113
+
114
+ def _init_weights(module):
115
+ """Initialize the weights."""
116
+ if isinstance(module, nn.Linear):
117
+ nn.init.xavier_uniform_(module.weight)
118
+ if module.bias is not None:
119
+ nn.init.constant_(module.bias, 0)
120
+ elif isinstance(module, nn.LayerNorm):
121
+ module.bias.data.zero_()
122
+ module.weight.data.fill_(1.0)
123
+
124
+
125
+ class PatchEmbed(nn.Module):
126
+ """3D Patch Embedding."""
127
+
128
+ def __init__(
129
+ self,
130
+ input_size: tuple[int, int, int] = (1, 224, 224),
131
+ patch_size: tuple[int, int, int] = (1, 16, 16),
132
+ in_chans: int = 3,
133
+ embed_dim: int = 768,
134
+ norm_layer: nn.Module | None = None,
135
+ flatten: bool = True,
136
+ bias: bool = True,
137
+ ):
138
+ super().__init__()
139
+ self.input_size = input_size
140
+ self.patch_size = patch_size
141
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
142
+ assert all(
143
+ g >= 1 for g in self.grid_size
144
+ ), "Patch size is bigger than input size."
145
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
146
+ self.flatten = flatten
147
+
148
+ self.proj = nn.Conv3d(
149
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
150
+ )
151
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
152
+
153
+ def forward(self, x):
154
+ B, C, T, H, W = x.shape
155
+ x = self.proj(x)
156
+ if self.flatten:
157
+ x = x.flatten(2).transpose(1, 2)
158
+ x = self.norm(x)
159
+ return x
160
+
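A sketch of the expected token shape for an assumed 6-band, 4-frame input at 224x224 (sizes are illustrative, not the model defaults):

pe = PatchEmbed(input_size=(4, 224, 224), patch_size=(1, 16, 16), in_chans=6, embed_dim=768)
x = torch.randn(2, 6, 4, 224, 224)   # (B, C, T, H, W)
tokens = pe(x)                       # Conv3d projection, then flattened to a token sequence
assert tokens.shape == (2, 4 * 14 * 14, 768)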
161
+
162
+ class TemporalEncoder(nn.Module):
163
+ """Temporal coordinate encoder."""
164
+
165
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
166
+ super().__init__()
167
+ self.embed_dim = embed_dim
168
+ self.year_embed_dim = embed_dim // 2
169
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
170
+
171
+ if trainable_scale:
172
+ self.scale = nn.Parameter(torch.tensor(0.1))
173
+ else:
174
+ self.scale = 1.0
175
+
176
+ def forward(
177
+ self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None
178
+ ):
179
+ """
180
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
181
+ """
182
+ shape = temporal_coords.shape[:2] + (-1,)
183
+
184
+ year = _get_1d_sincos_embed_from_grid_torch(
185
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()
186
+ ).reshape(shape)
187
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
188
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
189
+ ).reshape(shape)
190
+
191
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
192
+
193
+ if tokens_per_frame is not None:
194
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
195
+
196
+ return embedding
197
+
198
+
199
+ class LocationEncoder(nn.Module):
200
+ """Location coordinate encoder."""
201
+
202
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
203
+ super().__init__()
204
+ self.embed_dim = embed_dim
205
+ self.lat_embed_dim = embed_dim // 2
206
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
207
+
208
+ if trainable_scale:
209
+ self.scale = nn.Parameter(torch.tensor(0.1))
210
+ else:
211
+ self.scale = 1.0
212
+
213
+ def forward(self, location_coords: torch.Tensor):
214
+ """
215
+ location_coords: lat and lon info with shape (B, 2).
216
+ """
217
+ shape = location_coords.shape[:1] + (1, -1)
218
+
219
+ lat = _get_1d_sincos_embed_from_grid_torch(
220
+ self.lat_embed_dim, location_coords[:, 0].flatten()
221
+ ).reshape(shape)
222
+ lon = _get_1d_sincos_embed_from_grid_torch(
223
+ self.lon_embed_dim, location_coords[:, 1].flatten()
224
+ ).reshape(shape)
225
+
226
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
227
+
228
+ return embedding
229
+
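Both encoders map raw coordinate values straight through the 1D sin/cos embedding above. A sketch of the expected shapes; the coordinate values and tokens_per_frame are made up for illustration:

t_enc = TemporalEncoder(embed_dim=768)
l_enc = LocationEncoder(embed_dim=768)
temporal = torch.tensor([[[2024.0, 120.0], [2024.0, 150.0]]])  # (B, T, [year, day-of-year])
location = torch.tensor([[45.0, -93.0]])                       # (B, [lat, lon])
assert t_enc(temporal, tokens_per_frame=196).shape == (1, 2 * 196, 768)
assert l_enc(location).shape == (1, 1, 768)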
230
+
231
+ class PrithviViT(nn.Module):
232
+ """Prithvi ViT Encoder."""
233
+
234
+ def __init__(
235
+ self,
236
+ img_size: int | tuple[int, int] = 224,
237
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
238
+ num_frames: int = 1,
239
+ in_chans: int = 3,
240
+ embed_dim: int = 1024,
241
+ depth: int = 24,
242
+ num_heads: int = 16,
243
+ mlp_ratio: float = 4.0,
244
+ norm_layer: nn.Module = nn.LayerNorm,
245
+ coords_encoding: list[str] | None = None,
246
+ coords_scale_learn: bool = False,
247
+ drop_path: float = 0.0,
248
+ **kwargs,
249
+ ):
250
+ super().__init__()
251
+
252
+ self.in_chans = in_chans
253
+ self.num_frames = num_frames
254
+ self.embed_dim = embed_dim
255
+ self.img_size = to_2tuple(img_size)
256
+ if isinstance(patch_size, int):
257
+ patch_size = (1, patch_size, patch_size)
258
+
259
+ self.patch_embed = PatchEmbed(
260
+ input_size=(num_frames,) + self.img_size,
261
+ patch_size=patch_size,
262
+ in_chans=in_chans,
263
+ embed_dim=embed_dim,
264
+ )
265
+
266
+ coords_encoding = coords_encoding or []
267
+ self.temporal_encoding = "time" in coords_encoding
268
+ self.location_encoding = "location" in coords_encoding
269
+
270
+ if self.temporal_encoding:
271
+ self.temporal_encoder = TemporalEncoder(embed_dim, coords_scale_learn)
272
+ if self.location_encoding:
273
+ self.location_encoder = LocationEncoder(embed_dim, coords_scale_learn)
274
+
275
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
276
+ self.register_buffer(
277
+ "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
278
+ )
279
+
280
+ self.blocks = nn.ModuleList(
281
+ [
282
+ Block(
283
+ embed_dim,
284
+ num_heads,
285
+ mlp_ratio,
286
+ qkv_bias=True,
287
+ norm_layer=norm_layer,
288
+ )
289
+ for _ in range(depth)
290
+ ]
291
+ )
292
+
293
+ self.norm = norm_layer(embed_dim)
294
+ self.initialize_weights()
295
+
296
+ def initialize_weights(self):
297
+ pos_embed = get_3d_sincos_pos_embed(
298
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
299
+ )
300
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
301
+
302
+ w = self.patch_embed.proj.weight.data
303
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
304
+
305
+ torch.nn.init.normal_(self.cls_token, std=0.02)
306
+ self.apply(_init_weights)
307
+
308
+ def random_masking(self, sequence, mask_ratio, noise=None):
309
+ N, L, D = sequence.shape
310
+ len_keep = int(L * (1 - mask_ratio))
311
+
312
+ if noise is None:
313
+ noise = torch.rand(N, L, device=sequence.device)
314
+
315
+ ids_shuffle = torch.argsort(noise, dim=1)
316
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
317
+
318
+ ids_keep = ids_shuffle[:, :len_keep]
319
+ sequence_masked = torch.gather(
320
+ sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
321
+ )
322
+
323
+ mask = torch.ones([N, L], device=sequence.device)
324
+ mask[:, :len_keep] = 0
325
+ mask = torch.gather(mask, dim=1, index=ids_restore)
326
+
327
+ return sequence_masked, mask, ids_restore
328
+
329
+ def forward(
330
+ self,
331
+ x: torch.Tensor,
332
+ temporal_coords: None | torch.Tensor = None,
333
+ location_coords: None | torch.Tensor = None,
334
+ mask_ratio=0.75,
335
+ ):
336
+ x = self.patch_embed(x)
337
+ x = x + self.pos_embed[:, 1:, :]
338
+
339
+ if self.temporal_encoding and temporal_coords is not None:
340
+ x = x + self.temporal_encoder(
341
+ temporal_coords, x.shape[1] // self.num_frames
342
+ )
343
+
344
+ if self.location_encoding and location_coords is not None:
345
+ x = x + self.location_encoder(location_coords)
346
+
347
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
348
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
349
+
350
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
351
+ x = torch.cat((cls_tokens, x), dim=1)
352
+
353
+ for blk in self.blocks:
354
+ x = blk(x)
355
+
356
+ x = self.norm(x)
357
+ return x, mask, ids_restore
358
+
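With the default 0.75 mask ratio the encoder keeps a quarter of the 784 patch tokens plus the CLS token. A randomly initialized, deliberately shallow instance illustrates the output shapes (depth is reduced only to keep the sketch cheap):

vit = PrithviViT(img_size=224, patch_size=(1, 16, 16), num_frames=4, in_chans=6,
                 embed_dim=768, depth=2, num_heads=12)
latent, mask, ids_restore = vit(torch.randn(1, 6, 4, 224, 224), mask_ratio=0.75)
assert latent.shape == (1, 196 + 1, 768)                       # kept tokens + CLS
assert mask.shape == (1, 784) and ids_restore.shape == (1, 784)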
359
+
360
+ class MAEDecoder(nn.Module):
361
+ """Transformer Decoder used in the Prithvi MAE."""
362
+
363
+ def __init__(
364
+ self,
365
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
366
+ grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
367
+ in_chans: int = 3,
368
+ encoder_embed_dim: int = 1024,
369
+ decoder_embed_dim: int = 512,
370
+ depth: int = 8,
371
+ num_heads: int = 16,
372
+ mlp_ratio: float = 4.0,
373
+ norm_layer: nn.Module = nn.LayerNorm,
374
+ coords_encoding: list[str] | None = None,
375
+ coords_scale_learn: bool = False,
376
+ ):
377
+ super().__init__()
378
+
379
+ self.patch_size = patch_size
380
+ self.grid_size = grid_size
381
+ self.in_chans = in_chans
382
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
383
+
384
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
385
+
386
+ self.register_buffer(
387
+ "decoder_pos_embed",
388
+ torch.zeros(
389
+ 1, grid_size[0] * grid_size[1] * grid_size[2] + 1, decoder_embed_dim
390
+ ),
391
+ )
392
+
393
+ coords_encoding = coords_encoding or []
394
+ self.temporal_encoding = "time" in coords_encoding
395
+ self.location_encoding = "location" in coords_encoding
396
+
397
+ if self.temporal_encoding:
398
+ self.temporal_encoder = TemporalEncoder(
399
+ decoder_embed_dim, coords_scale_learn
400
+ )
401
+ if self.location_encoding:
402
+ self.location_encoder = LocationEncoder(
403
+ decoder_embed_dim, coords_scale_learn
404
+ )
405
+
406
+ self.decoder_blocks = nn.ModuleList(
407
+ [
408
+ Block(
409
+ decoder_embed_dim,
410
+ num_heads,
411
+ mlp_ratio,
412
+ qkv_bias=True,
413
+ norm_layer=norm_layer,
414
+ )
415
+ for _ in range(depth)
416
+ ]
417
+ )
418
+
419
+ self.decoder_norm = norm_layer(decoder_embed_dim)
420
+ self.decoder_pred = nn.Linear(
421
+ decoder_embed_dim,
422
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
423
+ bias=True,
424
+ )
425
+
426
+ self.initialize_weights()
427
+
428
+ def initialize_weights(self):
429
+ pos_embed = get_3d_sincos_pos_embed(
430
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
431
+ )
432
+ self.decoder_pos_embed.data.copy_(
433
+ torch.from_numpy(pos_embed).float().unsqueeze(0)
434
+ )
435
+
436
+ torch.nn.init.normal_(self.mask_token, std=0.02)
437
+ self.apply(_init_weights)
438
+
439
+ def forward(
440
+ self,
441
+ hidden_states: torch.Tensor,
442
+ ids_restore: torch.Tensor,
443
+ temporal_coords: None | torch.Tensor = None,
444
+ location_coords: None | torch.Tensor = None,
445
+ input_size: list[int] | None = None,
446
+ ):
447
+ x = self.decoder_embed(hidden_states)
448
+
449
+ mask_tokens = self.mask_token.repeat(
450
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
451
+ )
452
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
453
+ x_ = torch.gather(
454
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
455
+ )
456
+ x = torch.cat([x[:, :1, :], x_], dim=1)
457
+
458
+ x = x + self.decoder_pos_embed
459
+
460
+ if self.temporal_encoding and temporal_coords is not None:
461
+ num_frames = temporal_coords.shape[1]
462
+ tokens_per_frame = (x.shape[1] - 1) // num_frames
463
+ temp_embed = self.temporal_encoder(temporal_coords, tokens_per_frame)
464
+ x[:, 1:, :] = x[:, 1:, :] + temp_embed
465
+
466
+ if self.location_encoding and location_coords is not None:
467
+ x[:, 1:, :] = x[:, 1:, :] + self.location_encoder(location_coords)
468
+
469
+ for blk in self.decoder_blocks:
470
+ x = blk(x)
471
+
472
+ x = self.decoder_norm(x)
473
+ x = self.decoder_pred(x)
474
+ x = x[:, 1:, :]
475
+
476
+ return x
477
+
478
+
479
+ class PrithviMAE(nn.Module):
480
+ """Prithvi Masked Autoencoder."""
481
+
482
+ def __init__(
483
+ self,
484
+ img_size: int | tuple[int, int] = 224,
485
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
486
+ num_frames: int = 4,
487
+ in_chans: int = 6,
488
+ embed_dim: int = 768,
489
+ depth: int = 12,
490
+ num_heads: int = 12,
491
+ decoder_embed_dim: int = 512,
492
+ decoder_depth: int = 8,
493
+ decoder_num_heads: int = 16,
494
+ mlp_ratio: float = 4.0,
495
+ norm_layer: nn.Module = nn.LayerNorm,
496
+ norm_pix_loss: bool = False,
497
+ coords_encoding: list[str] | None = None,
498
+ coords_scale_learn: bool = False,
499
+ drop_path: float = 0.0,
500
+ mask_ratio: float = 0.75,
501
+ **kwargs,
502
+ ):
503
+ super().__init__()
504
+
505
+ self.img_size = to_2tuple(img_size)
506
+ self.patch_size = (
507
+ patch_size if isinstance(patch_size, tuple) else (1, patch_size, patch_size)
508
+ )
509
+ self.num_frames = num_frames
510
+ self.in_chans = in_chans
511
+ self.norm_pix_loss = norm_pix_loss
512
+
513
+ self.encoder = PrithviViT(
514
+ img_size=img_size,
515
+ patch_size=patch_size,
516
+ num_frames=num_frames,
517
+ in_chans=in_chans,
518
+ embed_dim=embed_dim,
519
+ depth=depth,
520
+ num_heads=num_heads,
521
+ mlp_ratio=mlp_ratio,
522
+ norm_layer=norm_layer,
523
+ coords_encoding=coords_encoding,
524
+ coords_scale_learn=coords_scale_learn,
525
+ drop_path=drop_path,
526
+ )
527
+
528
+ self.decoder = MAEDecoder(
529
+ patch_size=self.patch_size,
530
+ grid_size=self.encoder.patch_embed.grid_size,
531
+ in_chans=in_chans,
532
+ encoder_embed_dim=embed_dim,
533
+ decoder_embed_dim=decoder_embed_dim,
534
+ depth=decoder_depth,
535
+ num_heads=decoder_num_heads,
536
+ mlp_ratio=mlp_ratio,
537
+ norm_layer=norm_layer,
538
+ coords_encoding=coords_encoding,
539
+ coords_scale_learn=coords_scale_learn,
540
+ )
541
+
542
+ def patchify(self, pixel_values):
543
+ B, C, T, H, W = pixel_values.shape
544
+ pH = H // self.patch_size[1]
545
+ pW = W // self.patch_size[2]
546
+
547
+ x = pixel_values.reshape(
548
+ B,
549
+ C,
550
+ T // self.patch_size[0],
551
+ self.patch_size[0],
552
+ pH,
553
+ self.patch_size[1],
554
+ pW,
555
+ self.patch_size[2],
556
+ )
557
+ x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)
558
+ patchified_pixel_values = x.reshape(
559
+ B,
560
+ T // self.patch_size[0] * pH * pW,
561
+ self.patch_size[0] * self.patch_size[1] * self.patch_size[2] * C,
562
+ )
563
+
564
+ return patchified_pixel_values
565
+
566
+ def unpatchify(
567
+ self, patchified_pixel_values, image_size: tuple[int, int] | None = None
568
+ ):
569
+ if image_size is None:
570
+ H, W = self.img_size
571
+ else:
572
+ H, W = image_size
573
+
574
+ C = self.in_chans
575
+ pH = H // self.patch_size[1]
576
+ pW = W // self.patch_size[2]
577
+ T = self.num_frames
578
+
579
+ x = patchified_pixel_values.reshape(
580
+ patchified_pixel_values.shape[0],
581
+ T // self.patch_size[0],
582
+ pH,
583
+ pW,
584
+ self.patch_size[0],
585
+ self.patch_size[1],
586
+ self.patch_size[2],
587
+ C,
588
+ )
589
+ x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
590
+ pixel_values = x.reshape(
591
+ patchified_pixel_values.shape[0],
592
+ C,
593
+ T,
594
+ pH * self.patch_size[1],
595
+ pW * self.patch_size[2],
596
+ )
597
+
598
+ return pixel_values
599
+
600
+ def forward_loss(self, pixel_values, pred, mask):
601
+ target = self.patchify(pixel_values)
602
+
603
+ if self.norm_pix_loss:
604
+ mean = target.mean(dim=-1, keepdim=True)
605
+ var = target.var(dim=-1, keepdim=True)
606
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
607
+
608
+ loss = (pred - target) ** 2
609
+ loss = loss.mean(dim=-1)
610
+
611
+ loss = (loss * mask).sum() / mask.sum()
612
+ return loss
613
+
614
+ def forward(
615
+ self,
616
+ pixel_values: torch.Tensor,
617
+ temporal_coords: None | torch.Tensor = None,
618
+ location_coords: None | torch.Tensor = None,
619
+ mask_ratio: float | None = None,
620
+ ):
621
+ mask_ratio = mask_ratio if mask_ratio is not None else 0.75
622
+
623
+ latent, mask, ids_restore = self.encoder(
624
+ pixel_values, temporal_coords, location_coords, mask_ratio
625
+ )
626
+ pred = self.decoder(latent, ids_restore, temporal_coords, location_coords)
627
+ loss = self.forward_loss(pixel_values, pred, mask)
628
+
629
+ return loss, pred, mask
630
+
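A full masked-autoencoding round trip on random data with a small, untrained configuration (decoder depth reduced for the sketch): the decoder predicts one row per patch holding all patch pixels across channels, and unpatchify folds that back to the input shape.

mae = PrithviMAE(img_size=224, patch_size=(1, 16, 16), num_frames=4, in_chans=6,
                 embed_dim=768, depth=2, num_heads=12,
                 decoder_embed_dim=512, decoder_depth=1, decoder_num_heads=16)
pixels = torch.randn(1, 6, 4, 224, 224)
loss, pred, mask = mae(pixels)
assert pred.shape == (1, 784, 1 * 16 * 16 * 6)   # patches x (patch voxels * channels)
assert mae.unpatchify(pred).shape == pixels.shape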
631
+
632
+ class PrithviProcessor:
633
+ """Prithvi EO 2.0 processor with GeoTIFF input/output support.
634
+
635
+ Supports multiple model variants:
636
+ - Prithvi-EO-2.0-tiny-TL (tiny transfer learning)
637
+ - Prithvi-EO-2.0-100M-TL (100M transfer learning)
638
+ - Prithvi-EO-2.0-300M (300M base model)
639
+ - Prithvi-EO-2.0-300M-TL (300M transfer learning)
640
+ - Prithvi-EO-2.0-600M (600M base model)
641
+ - Prithvi-EO-2.0-600M-TL (600M transfer learning)
642
+
643
+ References:
644
+ - tiny-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL
645
+ - 100M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL
646
+ - 300M: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M
647
+ - 300M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL
648
+ - 600M: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M
649
+ - 600M-TL: https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL
650
+ - GitHub: https://github.com/NASA-IMPACT/Prithvi-EO-2.0
651
+ """
652
+
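Constructing the processor with only a model name downloads the matching config and checkpoint from the Hub; a minimal sketch (network access assumed):

processor = PrithviProcessor(model_name="Prithvi-EO-2.0-300M-TL")
print(processor.bands, processor.img_size, processor.num_frames)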
653
+ def __init__(
654
+ self,
655
+ model_name: str = "Prithvi-EO-2.0-300M-TL",
656
+ config_path: Optional[str] = None,
657
+ checkpoint_path: Optional[str] = None,
658
+ device: Optional[torch.device] = None,
659
+ cache_dir: Optional[str] = None,
660
+ ):
661
+ """Initialize Prithvi processor.
662
+
663
+ Args:
664
+ model_name: Name of the Prithvi model to download from HuggingFace Hub.
665
+ Options:
666
+ - "Prithvi-EO-2.0-tiny-TL" (tiny, 192 dim, 12 layers)
667
+ - "Prithvi-EO-2.0-100M-TL" (100M, 768 dim, 12 layers)
668
+ - "Prithvi-EO-2.0-300M" (base, 1024 dim, 24 layers)
669
+ - "Prithvi-EO-2.0-300M-TL" (default, 768 dim, 12 layers)
670
+ - "Prithvi-EO-2.0-600M" (base, 1280 dim, 32 layers)
671
+ - "Prithvi-EO-2.0-600M-TL" (1280 dim, 32 layers)
672
+ config_path: Path to config file (optional, downloads if not provided)
673
+ checkpoint_path: Path to checkpoint file (optional, downloads if not provided)
674
+ device: Torch device to use
675
+ cache_dir: Directory to cache downloaded files
676
+ """
677
+ self.device = device or get_device()
678
+ self.model_name = model_name
679
+ self.cache_dir = cache_dir
680
+
681
+ # Download or load config and checkpoint
682
+ if config_path is None or checkpoint_path is None:
683
+ config_path, checkpoint_path = self.download_model(model_name, cache_dir)
684
+
685
+ self.config_path = config_path
686
+ self.checkpoint_path = checkpoint_path
687
+
688
+ # Load config
689
+ with open(config_path, "r") as f:
690
+ config_data = json.load(f)
691
+ self.config = config_data["pretrained_cfg"]
692
+
693
+ # Extract parameters
694
+ self.bands = self.config["bands"]
695
+ self.mean = self.config["mean"]
696
+ self.std = self.config["std"]
697
+ self.img_size = self.config["img_size"]
698
+ self.patch_size = self.config["patch_size"]
699
+ self.mask_ratio = self.config["mask_ratio"]
700
+ self.num_frames = self.config.get("num_frames", 4)
701
+ self.coords_encoding = self.config.get("coords_encoding", [])
702
+
703
+ # Load model
704
+ self.model = self._load_model()
705
+
706
+ @staticmethod
707
+ def download_model(
708
+ model_name: str = "Prithvi-EO-2.0-300M-TL", cache_dir: str = None
709
+ ) -> Tuple[str, str]:
710
+ """Download Prithvi model from HuggingFace Hub.
711
+
712
+ Args:
713
+ model_name: Name of the model. Options:
714
+ - "Prithvi-EO-2.0-tiny-TL"
715
+ - "Prithvi-EO-2.0-100M-TL"
716
+ - "Prithvi-EO-2.0-300M" (base model)
717
+ - "Prithvi-EO-2.0-300M-TL" (default)
718
+ - "Prithvi-EO-2.0-600M" (base model)
719
+ - "Prithvi-EO-2.0-600M-TL"
720
+ cache_dir: Directory to cache files
721
+
722
+ Returns:
723
+ Tuple of (config_path, checkpoint_path)
724
+ """
725
+ repo_id = f"ibm-nasa-geospatial/{model_name}"
726
+
727
+ try:
728
+ # Download config
729
+ config_path = hf_hub_download(
730
+ repo_id=repo_id,
731
+ filename="config.json",
732
+ cache_dir=cache_dir,
733
+ )
734
+
735
+ # Download checkpoint
736
+ # Model name format: Prithvi-EO-2.0-300M-TL -> Prithvi_EO_V2_300M_TL.pt
737
+ checkpoint_filename = (
738
+ model_name.replace("-", "_").replace("_2.0_", "_V2_") + ".pt"
739
+ )
740
+ checkpoint_path = hf_hub_download(
741
+ repo_id=repo_id,
742
+ filename=checkpoint_filename,
743
+ cache_dir=cache_dir,
744
+ )
745
+
746
+ return config_path, checkpoint_path
747
+
748
+ except Exception as e:
749
+ raise RuntimeError(f"Failed to download model from HuggingFace Hub: {e}")
750
+
751
+ def _load_model(self) -> PrithviMAE:
752
+ """Load Prithvi MAE model."""
753
+ try:
754
+ # Convert patch_size to tuple if it's a list
755
+ patch_size = self.config["patch_size"]
756
+ if isinstance(patch_size, list):
757
+ patch_size = tuple(patch_size)
758
+
759
+ # Create model
760
+ model = PrithviMAE(
761
+ img_size=self.config["img_size"],
762
+ patch_size=patch_size,
763
+ num_frames=self.config["num_frames"],
764
+ in_chans=self.config["in_chans"],
765
+ embed_dim=self.config["embed_dim"],
766
+ depth=self.config["depth"],
767
+ num_heads=self.config["num_heads"],
768
+ decoder_embed_dim=self.config["decoder_embed_dim"],
769
+ decoder_depth=self.config["decoder_depth"],
770
+ decoder_num_heads=self.config["decoder_num_heads"],
771
+ mlp_ratio=self.config["mlp_ratio"],
772
+ coords_encoding=self.coords_encoding,
773
+ coords_scale_learn=self.config.get("coords_scale_learn", False),
774
+ mask_ratio=self.mask_ratio,
775
+ norm_pix_loss=self.config.get("norm_pix_loss", False),
776
+ )
777
+
778
+ # Load checkpoint
779
+ state_dict = torch.load(
780
+ self.checkpoint_path, map_location=self.device, weights_only=True
781
+ )
782
+
783
+ # Remove fixed pos_embed weights
784
+ for k in list(state_dict.keys()):
785
+ if "pos_embed" in k:
786
+ del state_dict[k]
787
+
788
+ model.load_state_dict(state_dict, strict=False)
789
+ model = model.to(self.device)
790
+ model.eval()
791
+
792
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ logger.info(f"Loaded {self.model_name}: {total_params:,} trainable parameters")
793
+
794
+ return model
795
+
796
+ except Exception as e:
797
+ raise RuntimeError(f"Failed to load Prithvi model: {e}")
798
+
799
+ def read_geotiff(self, file_path: str) -> Tuple[np.ndarray, dict, Optional[Tuple]]:
800
+ """Read GeoTIFF file.
801
+
802
+ Args:
803
+ file_path: Path to GeoTIFF file
804
+
805
+ Returns:
806
+ Tuple of (image array, metadata, coordinates)
807
+ """
808
+ with rasterio.open(file_path) as src:
809
+ img = src.read()
810
+ meta = src.meta
811
+ try:
812
+ coords = src.tags()
813
+ except Exception:
814
+ coords = None
815
+
816
+ return img, meta, coords
817
+
818
+ def preprocess_image(
819
+ self,
820
+ img: np.ndarray,
821
+ indices: Optional[List[int]] = None,
822
+ ) -> np.ndarray:
823
+ """Preprocess image for model input.
824
+
825
+ Args:
826
+ img: Image array with shape (C, H, W)
827
+ indices: Optional band indices to select
828
+
829
+ Returns:
830
+ Preprocessed image
831
+ """
832
+ # Move channels to last dimension
833
+ img = np.moveaxis(img, 0, -1)
834
+
835
+ # Select bands if specified
836
+ if indices is not None:
837
+ img = img[..., indices]
838
+
839
+ # Normalize (handle nodata)
840
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - self.mean) / self.std)
841
+
842
+ return img
843
+
844
+ def load_images(
845
+ self,
846
+ file_paths: List[str],
847
+ indices: Optional[List[int]] = None,
848
+ ) -> Tuple[np.ndarray, List[dict], List, List]:
849
+ """Load and preprocess multiple images.
850
+
851
+ Args:
852
+ file_paths: List of GeoTIFF file paths
853
+ indices: Optional band indices
854
+
855
+ Returns:
856
+ Tuple of (images, metadata, temporal_coords, location_coords)
857
+ """
858
+ # Check if we need to pad to num_frames
859
+ if len(file_paths) < self.num_frames:
860
+ # Pad file_paths by repeating the last file
861
+ file_paths = list(file_paths) + [file_paths[-1]] * (
862
+ self.num_frames - len(file_paths)
863
+ )
864
+ elif len(file_paths) > self.num_frames:
865
+ file_paths = file_paths[: self.num_frames]
866
+
867
+ imgs = []
868
+ metas = []
869
+ temporal_coords = []
870
+ location_coords = []
871
+
872
+ for file in file_paths:
873
+ img, meta, coords = self.read_geotiff(file)
874
+
875
+ # Preprocess
876
+ img = self.preprocess_image(img, indices)
877
+
878
+ imgs.append(img)
879
+ metas.append(meta)
880
+
881
+ # Stack images: (T, H, W, C)
882
+ imgs = np.stack(imgs, axis=0)
883
+ # Rearrange to: (C, T, H, W)
884
+ imgs = np.moveaxis(imgs, -1, 0).astype("float32")
885
+ # Add batch dimension: (1, C, T, H, W)
886
+ imgs = np.expand_dims(imgs, axis=0)
887
+
888
+ return imgs, metas, temporal_coords, location_coords
889
+
890
+ def run_inference(
891
+ self,
892
+ input_data: torch.Tensor,
893
+ temporal_coords: Optional[torch.Tensor] = None,
894
+ location_coords: Optional[torch.Tensor] = None,
895
+ mask_ratio: Optional[float] = None,
896
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
897
+ """Run model inference.
898
+
899
+ Args:
900
+ input_data: Input tensor with shape (B, C, T, H, W)
901
+ temporal_coords: Optional temporal coordinates
902
+ location_coords: Optional location coordinates
903
+ mask_ratio: Mask ratio (default: from config)
904
+
905
+ Returns:
906
+ Tuple of (reconstructed_image, mask_image)
907
+ """
908
+ mask_ratio = mask_ratio or self.mask_ratio
909
+
910
+ # Check if input dimensions match model expectations
911
+ B, C, T, H, W = input_data.shape
912
+ if H % self.img_size != 0 or W % self.img_size != 0:
913
+ raise ValueError(
914
+ f"Input spatial dimensions ({H}x{W}) must be divisible by model image size ({self.img_size}). "
915
+ f"Use process_files() method which handles padding automatically, or pad your input to multiples of {self.img_size}."
916
+ )
917
+
918
+ with torch.no_grad():
919
+ x = input_data.to(self.device)
920
+ _, pred, mask = self.model(x, temporal_coords, location_coords, mask_ratio)
921
+
922
+ # Create mask and prediction images
923
+ mask_img = (
924
+ self.model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1]))
925
+ .detach()
926
+ .cpu()
927
+ )
928
+ pred_img = self.model.unpatchify(pred).detach().cpu()
929
+
930
+ # Mix visible and predicted patches
931
+ rec_img = input_data.clone()
932
+ rec_img[mask_img == 1] = pred_img[mask_img == 1]
933
+
934
+ # Invert mask for better visualization
935
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
936
+
937
+ return rec_img, mask_img
938
+
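run_inference operates on model-sized windows. A sketch with a single random window of the right shape, reusing the hypothetical `processor` from the earlier sketch:

window = torch.randn(1, len(processor.bands), processor.num_frames,
                     processor.img_size, processor.img_size)
rec, msk = processor.run_inference(window, mask_ratio=0.75)
assert rec.shape == window.shape and msk.shape == window.shape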
939
+ def process_images(
940
+ self,
941
+ file_paths: List[str],
942
+ mask_ratio: Optional[float] = None,
943
+ indices: Optional[List[int]] = None,
944
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
945
+ """Process multiple GeoTIFF files and return tensors (without saving).
946
+
947
+ This method handles large images using sliding windows and returns tensors
948
+ for visualization, unlike process_files() which saves to disk.
949
+
950
+ Args:
951
+ file_paths: List of input file paths
952
+ mask_ratio: Optional mask ratio
953
+ indices: Optional band indices
954
+
955
+ Returns:
956
+ Tuple of (input_tensor, reconstructed_tensor, mask_tensor)
957
+ """
958
+ # Load images
959
+ input_data, metas, temporal_coords, location_coords = self.load_images(
960
+ file_paths, indices
961
+ )
962
+
963
+ # Handle padding
964
+ original_h, original_w = input_data.shape[-2:]
965
+ pad_h = (self.img_size - (original_h % self.img_size)) % self.img_size
966
+ pad_w = (self.img_size - (original_w % self.img_size)) % self.img_size
967
+
968
+ if pad_h > 0 or pad_w > 0:
969
+ input_data = np.pad(
970
+ input_data,
971
+ ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
972
+ mode="reflect",
973
+ )
974
+
975
+ # Convert to tensor
976
+ batch = torch.tensor(input_data, device="cpu")
977
+
978
+ # Create sliding windows
979
+ windows = batch.unfold(3, self.img_size, self.img_size).unfold(
980
+ 4, self.img_size, self.img_size
981
+ )
982
+ h1, w1 = windows.shape[3:5]
983
+ windows = rearrange(
984
+ windows,
985
+ "b c t h1 w1 h w -> (b h1 w1) c t h w",
986
+ h=self.img_size,
987
+ w=self.img_size,
988
+ )
989
+
990
+ # Split into batches
991
+ num_batches = max(1, windows.shape[0])
992
+ windows_list = torch.tensor_split(windows, num_batches, dim=0)
993
+
994
+ # Process each window
995
+ rec_imgs = []
996
+ mask_imgs = []
997
+
998
+ for i, x in enumerate(windows_list):
999
+ rec_img, mask_img = self.run_inference(x, None, None, mask_ratio)
1000
+ rec_imgs.append(rec_img)
1001
+ mask_imgs.append(mask_img)
1002
+
1003
+ # Concatenate results
1004
+ rec_imgs = torch.cat(rec_imgs, dim=0)
1005
+ mask_imgs = torch.cat(mask_imgs, dim=0)
1006
+
1007
+ # Rearrange patches back to image
1008
+ num_frames = self.num_frames # load_images() pads/truncates the series to this length
1009
+ rec_imgs = rearrange(
1010
+ rec_imgs,
1011
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
1012
+ h=self.img_size,
1013
+ w=self.img_size,
1014
+ b=1,
1015
+ c=len(self.bands),
1016
+ t=num_frames,
1017
+ h1=h1,
1018
+ w1=w1,
1019
+ )
1020
+ mask_imgs = rearrange(
1021
+ mask_imgs,
1022
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
1023
+ h=self.img_size,
1024
+ w=self.img_size,
1025
+ b=1,
1026
+ c=len(self.bands),
1027
+ t=num_frames,
1028
+ h1=h1,
1029
+ w1=w1,
1030
+ )
1031
+
1032
+ # Remove padding
1033
+ rec_imgs = rec_imgs[..., :original_h, :original_w]
1034
+ mask_imgs = mask_imgs[..., :original_h, :original_w]
1035
+ input_imgs = batch[..., :original_h, :original_w]
1036
+
1037
+ return input_imgs, rec_imgs, mask_imgs
1038
+
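For in-memory workflows, process_images pairs naturally with visualize_rgb below; the file paths are placeholders and the `processor` is the hypothetical instance from the earlier sketch:

inp, rec, msk = processor.process_images(["t1.tif", "t2.tif", "t3.tif", "t4.tif"])
orig_rgb, masked_rgb, rec_rgb = processor.visualize_rgb(inp, rec, msk)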
1039
+ def visualize_rgb(
1040
+ self,
1041
+ input_tensor: torch.Tensor,
1042
+ rec_tensor: torch.Tensor,
1043
+ mask_tensor: torch.Tensor,
1044
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
1045
+ """Extract RGB images from tensors for visualization.
1046
+
1047
+ Args:
1048
+ input_tensor: Input tensor (B, C, T, H, W)
1049
+ rec_tensor: Reconstructed tensor (B, C, T, H, W)
1050
+ mask_tensor: Mask tensor (B, C, T, H, W)
1051
+
1052
+ Returns:
1053
+ Tuple of (original_rgb, masked_rgb, reconstructed_rgb) lists
1054
+ """
1055
+ # Get RGB channel indices (B04=Red, B03=Green, B02=Blue)
1056
+ rgb_channels = [
1057
+ self.bands.index("B04"),
1058
+ self.bands.index("B03"),
1059
+ self.bands.index("B02"),
1060
+ ]
1061
+
1062
+ # Remove batch dimension
1063
+ if input_tensor.dim() == 5:
1064
+ input_tensor = input_tensor[0]
1065
+ if rec_tensor.dim() == 5:
1066
+ rec_tensor = rec_tensor[0]
1067
+ if mask_tensor.dim() == 5:
1068
+ mask_tensor = mask_tensor[0]
1069
+
1070
+ mean = torch.tensor(self.mean)
1071
+ std = torch.tensor(self.std)
1072
+
1073
+ original_rgb = []
1074
+ masked_rgb = []
1075
+ reconstructed_rgb = []
1076
+
1077
+ num_frames = input_tensor.shape[1]
1078
+
1079
+ for t in range(num_frames):
1080
+ # Extract and denormalize original RGB
1081
+ rgb_orig = input_tensor[rgb_channels, t, :, :].clone()
1082
+ for i, c in enumerate(rgb_channels):
1083
+ rgb_orig[i] = rgb_orig[i] * std[c] + mean[c]
1084
+ rgb_orig_np = rgb_orig.numpy()
1085
+ rgb_orig_np = np.clip(rgb_orig_np, 0, 10000)
1086
+ rgb_orig_np = (rgb_orig_np / 10000 * 255).astype(np.uint8)
1087
+ rgb_orig_np = np.transpose(rgb_orig_np, (1, 2, 0))
1088
+ original_rgb.append(rgb_orig_np)
1089
+
1090
+ # Extract and denormalize reconstructed RGB
1091
+ rgb_rec = rec_tensor[rgb_channels, t, :, :].clone()
1092
+ for i, c in enumerate(rgb_channels):
1093
+ rgb_rec[i] = rgb_rec[i] * std[c] + mean[c]
1094
+ rgb_rec_np = rgb_rec.numpy()
1095
+ rgb_rec_np = np.clip(rgb_rec_np, 0, 10000)
1096
+ rgb_rec_np = (rgb_rec_np / 10000 * 255).astype(np.uint8)
1097
+ rgb_rec_np = np.transpose(rgb_rec_np, (1, 2, 0))
1098
+ reconstructed_rgb.append(rgb_rec_np)
1099
+
1100
+ # Create masked RGB (visible patches only)
1101
+ mask_t = mask_tensor[rgb_channels, t, :, :].numpy()
1102
+ masked_np = rgb_orig_np.astype(np.float32) * np.transpose(mask_t, (1, 2, 0))
1103
+ masked_rgb.append(masked_np.astype(np.uint8))
1104
+
1105
+ return original_rgb, masked_rgb, reconstructed_rgb
1106
+
1107
+ def process_files(
1108
+ self,
1109
+ file_paths: List[str],
1110
+ output_dir: str,
1111
+ mask_ratio: Optional[float] = None,
1112
+ indices: Optional[List[int]] = None,
1113
+ ):
1114
+ """Process multiple GeoTIFF files.
1115
+
1116
+ Args:
1117
+ file_paths: List of input file paths
1118
+ output_dir: Output directory for results
1119
+ mask_ratio: Optional mask ratio
1120
+ indices: Optional band indices
1121
+ """
1122
+ os.makedirs(output_dir, exist_ok=True)
1123
+
1124
+ # Load images
1125
+ input_data, metas, temporal_coords, location_coords = self.load_images(
1126
+ file_paths, indices
1127
+ )
1128
+
1129
+ # Handle padding
1130
+ original_h, original_w = input_data.shape[-2:]
1131
+ pad_h = (self.img_size - (original_h % self.img_size)) % self.img_size
1132
+ pad_w = (self.img_size - (original_w % self.img_size)) % self.img_size
1133
+ input_data = np.pad(
1134
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
1135
+ )
1136
+
1137
+ # Convert to tensor
1138
+ batch = torch.tensor(input_data, device="cpu")
1139
+
1140
+ # Create sliding windows
1141
+ windows = batch.unfold(3, self.img_size, self.img_size).unfold(
1142
+ 4, self.img_size, self.img_size
1143
+ )
1144
+ h1, w1 = windows.shape[3:5]
1145
+ windows = rearrange(
1146
+ windows,
1147
+ "b c t h1 w1 h w -> (b h1 w1) c t h w",
1148
+ h=self.img_size,
1149
+ w=self.img_size,
1150
+ )
1151
+
1152
+ # Split into batches
1153
+ num_batches = max(1, windows.shape[0])
1154
+ windows_list = torch.tensor_split(windows, num_batches, dim=0)
1155
+
1156
+ # Process each window
1157
+ rec_imgs = []
1158
+ mask_imgs = []
1159
+
1160
+ for i, x in enumerate(windows_list):
1161
+ rec_img, mask_img = self.run_inference(x, None, None, mask_ratio)
1162
+ rec_imgs.append(rec_img)
1163
+ mask_imgs.append(mask_img)
1164
+
1165
+ # Concatenate results
1166
+ rec_imgs = torch.cat(rec_imgs, dim=0)
1167
+ mask_imgs = torch.cat(mask_imgs, dim=0)
1168
+
1169
+ # Rearrange patches back to image
1170
+ num_frames = self.num_frames # load_images() pads/truncates the series to this length
1171
+ rec_imgs = rearrange(
1172
+ rec_imgs,
1173
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
1174
+ h=self.img_size,
1175
+ w=self.img_size,
1176
+ b=1,
1177
+ c=len(self.bands),
1178
+ t=num_frames,
1179
+ h1=h1,
1180
+ w1=w1,
1181
+ )
1182
+ mask_imgs = rearrange(
1183
+ mask_imgs,
1184
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
1185
+ h=self.img_size,
1186
+ w=self.img_size,
1187
+ b=1,
1188
+ c=len(self.bands),
1189
+ t=num_frames,
1190
+ h1=h1,
1191
+ w1=w1,
1192
+ )
1193
+
1194
+ # Remove padding
1195
+ rec_imgs = rec_imgs[..., :original_h, :original_w]
1196
+ mask_imgs = mask_imgs[..., :original_h, :original_w]
1197
+
1198
+ # Save results
1199
+ self.save_results(rec_imgs[0], mask_imgs[0], metas, output_dir)
1200
+
1201
+ def save_results(
1202
+ self,
1203
+ rec_img: torch.Tensor,
1204
+ mask_img: torch.Tensor,
1205
+ metas: List[dict],
1206
+ output_dir: str,
1207
+ ):
1208
+ """Save reconstruction results.
1209
+
1210
+ Args:
1211
+ rec_img: Reconstructed image with shape (C, T, H, W)
1212
+ mask_img: Mask image with shape (C, T, H, W)
1213
+ metas: List of metadata dicts
1214
+ output_dir: Output directory
1215
+ """
1216
+ mean = torch.tensor(np.asarray(self.mean)[:, None, None])
1217
+ std = torch.tensor(np.asarray(self.std)[:, None, None])
1218
+
1219
+ for t in range(rec_img.shape[1]):
1220
+ # Denormalize
1221
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
1222
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
1223
+
1224
+ # Update metadata
1225
+ meta = metas[t].copy()
1226
+ meta.update(compress="lzw", nodata=0)
1227
+
1228
+ # Save files
1229
+ self._save_geotiff(
1230
+ rec_img_t.numpy(),
1231
+ os.path.join(output_dir, f"reconstructed_t{t}.tif"),
1232
+ meta,
1233
+ )
1234
+ self._save_geotiff(
1235
+ mask_img_t.numpy(),
1236
+ os.path.join(output_dir, f"mask_t{t}.tif"),
1237
+ meta,
1238
+ )
1239
+
1240
+ @staticmethod
1241
+ def _save_geotiff(image: np.ndarray, output_path: str, meta: dict):
1242
+ """Save GeoTIFF file."""
1243
+ with rasterio.open(output_path, "w", **meta) as dest:
1244
+ for i in range(image.shape[0]):
1245
+ dest.write(image[i], i + 1)
1246
+
1247
+
1248
+ def get_available_prithvi_models() -> List[str]:
1249
+ """Get list of available Prithvi model names.
1250
+
1251
+ Returns:
1252
+ List of available model names
1253
+
1254
+ Example:
1255
+ >>> models = get_available_prithvi_models()
1256
+ >>> print(models)
1257
+ ['Prithvi-EO-2.0-tiny-TL', 'Prithvi-EO-2.0-100M-TL', 'Prithvi-EO-2.0-300M', 'Prithvi-EO-2.0-300M-TL', 'Prithvi-EO-2.0-600M', 'Prithvi-EO-2.0-600M-TL']
1258
+ """
1259
+ return AVAILABLE_MODELS.copy()
1260
+
1261
+
1262
+ def load_prithvi_model(
1263
+ model_name: str = "Prithvi-EO-2.0-300M-TL",
1264
+ device: Optional[str] = None,
1265
+ cache_dir: Optional[str] = None,
1266
+ ) -> PrithviProcessor:
1267
+ """Load Prithvi model (convenience function).
1268
+
1269
+ Args:
1270
+ model_name: Name of the model. Options:
1271
+ - "Prithvi-EO-2.0-tiny-TL"
1272
+ - "Prithvi-EO-2.0-100M-TL"
1273
+ - "Prithvi-EO-2.0-300M" (base)
1274
+ - "Prithvi-EO-2.0-300M-TL" (default)
1275
+ - "Prithvi-EO-2.0-600M" (base)
1276
+ - "Prithvi-EO-2.0-600M-TL"
1277
+ device: Device to use ('cuda' or 'cpu')
1278
+ cache_dir: Cache directory
1279
+
1280
+ Returns:
1281
+ PrithviProcessor instance
1282
+
1283
+ Example:
1284
+ >>> # Load tiny-TL model
1285
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-tiny-TL")
1286
+ >>> # Load 100M-TL model
1287
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-100M-TL")
1288
+ >>> # Load 300M base model
1289
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-300M")
1290
+ >>> # Load 300M-TL model
1291
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-300M-TL")
1292
+ >>> # Load 600M base model
1293
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-600M")
1294
+ >>> # Load 600M-TL model
1295
+ >>> processor = load_prithvi_model("Prithvi-EO-2.0-600M-TL")
1296
+ """
1297
+ if device is not None:
1298
+ device = torch.device(device)
1299
+
1300
+ return PrithviProcessor(
1301
+ model_name=model_name,
1302
+ device=device,
1303
+ cache_dir=cache_dir,
1304
+ )
1305
+
1306
+
1307
+ def prithvi_inference(
1308
+ file_paths: List[str],
1309
+ output_dir: str = "output",
1310
+ model_name: str = "Prithvi-EO-2.0-300M-TL",
1311
+ mask_ratio: Optional[float] = None,
1312
+ device: Optional[str] = None,
1313
+ ):
1314
+ """Run Prithvi inference on GeoTIFF files (convenience function).
1315
+
1316
+ Args:
1317
+ file_paths: List of input GeoTIFF files
1318
+ output_dir: Output directory
1319
+ model_name: Name of the model. Options:
1320
+ - "Prithvi-EO-2.0-tiny-TL"
1321
+ - "Prithvi-EO-2.0-100M-TL"
1322
+ - "Prithvi-EO-2.0-300M" (base)
1323
+ - "Prithvi-EO-2.0-300M-TL" (default)
1324
+ - "Prithvi-EO-2.0-600M" (base)
1325
+ - "Prithvi-EO-2.0-600M-TL"
1326
+ mask_ratio: Optional mask ratio
1327
+ device: Device to use
1328
+
1329
+ Example:
1330
+ >>> # Use tiny-TL model
1331
+ >>> prithvi_inference(
1332
+ ... file_paths=["img1.tif", "img2.tif", "img3.tif", "img4.tif"],
1333
+ ... model_name="Prithvi-EO-2.0-tiny-TL",
1334
+ ... output_dir="output_tiny"
1335
+ ... )
1336
+ """
1337
+ processor = load_prithvi_model(model_name, device)
1338
+ processor.process_files(file_paths, output_dir, mask_ratio)