geoai-py 0.25.0__py2.py3-none-any.whl → 0.26.0__py2.py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
geoai/__init__.py CHANGED
@@ -2,7 +2,7 @@
  
  __author__ = """Qiusheng Wu"""
  __email__ = "giswqs@gmail.com"
- __version__ = "0.25.0"
+ __version__ = "0.26.0"
  
  
  import os
@@ -162,3 +162,14 @@ try:
  except ImportError:
      # Moondream not available (missing dependency)
      pass
+
+ # Prithvi EO 2.0 Geospatial Foundation Model
+ try:
+     from .prithvi import (
+         PrithviProcessor,
+         load_prithvi_model,
+         prithvi_inference,
+     )
+ except ImportError:
+     # Prithvi not available (missing dependency)
+     pass
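
With this change, the Prithvi EO 2.0 API becomes importable straight from the top-level package, behind the same optional-dependency guard already used for Moondream. A minimal usage sketch, assuming the optional dependencies (torch, timm, einops, rasterio, huggingface_hub) are installed and using placeholder file names:

import geoai

# Placeholder multi-temporal GeoTIFF paths; real inputs must match the band
# layout expected by the downloaded model config.
scenes = ["t1.tif", "t2.tif", "t3.tif", "t4.tif"]

# One-shot convenience wrapper added in this release; the model config and
# checkpoint are fetched from the Hugging Face Hub on first use.
geoai.prithvi_inference(scenes, output_dir="prithvi_output")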
geoai/auto.py CHANGED
@@ -54,7 +54,6 @@ from transformers.utils import logging as hf_logging
  
  from .utils import get_device
  
-
  hf_logging.set_verbosity_error() # silence HF load reports
  
  
geoai/change_detection.py CHANGED
@@ -249,7 +249,7 @@ class ChangeDetection:
              dict: Detailed results if return_detailed_results=True
          """
          # Read and align images
-         (img1, img2, transform, crs, original_shape) = self._read_and_align_images(
+         img1, img2, transform, crs, original_shape = self._read_and_align_images(
              image1_path, image2_path, target_size
          )
  
geoai/moondream.py CHANGED
@@ -23,7 +23,6 @@ from transformers.utils import logging as hf_logging
  
  from .utils import get_device
  
-
  hf_logging.set_verbosity_error() # silence HF load reports
  
  
geoai/prithvi.py ADDED
@@ -0,0 +1,1253 @@
1
+ """Prithvi EO 2.0 module for geospatial foundation model inference.
2
+
3
+ This module provides tools for using NASA-IBM's Prithvi EO 2.0 geospatial foundation model
4
+ for masked autoencoding and feature extraction on multi-temporal satellite imagery.
5
+ """
6
+
7
+ import json
8
+ import logging
9
+ import os
10
+ from typing import Dict, List, Optional, Tuple, Union
11
+
12
+ import numpy as np
13
+ import rasterio
14
+ import torch
15
+ import torch.nn as nn
16
+ from einops import rearrange
17
+ from huggingface_hub import hf_hub_download
18
+ from timm.layers import to_2tuple
19
+ from timm.models.vision_transformer import Block
20
+
21
+ from .utils import get_device
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
+ # Constants
26
+ NO_DATA = -9999
27
+ NO_DATA_FLOAT = 0.0001
28
+ OFFSET = 0
29
+ PERCENTILE = 99.9
30
+
31
+
32
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
33
+ """Create 3D sin/cos positional embeddings.
34
+
35
+ Args:
36
+ embed_dim (int): Embedding dimension.
37
+ grid_size (tuple[int, int, int] | list[int]): The grid depth, height and width.
38
+ add_cls_token (bool, optional): Whether or not to add a classification (CLS) token.
39
+
40
+ Returns:
41
+ Position embeddings (with or without cls token)
42
+ """
43
+ assert embed_dim % 16 == 0
44
+
45
+ t_size, h_size, w_size = grid_size
46
+
47
+ w_embed_dim = embed_dim // 16 * 6
48
+ h_embed_dim = embed_dim // 16 * 6
49
+ t_embed_dim = embed_dim // 16 * 4
50
+
51
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
52
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
53
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
54
+
55
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
56
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
57
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
58
+
59
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
60
+
61
+ if add_cls_token:
62
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
63
+ return pos_embed
64
+
65
+
66
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
67
+ """Generate 1D sincos position embeddings."""
68
+ if embed_dim % 2 != 0:
69
+ raise ValueError("embed_dim must be even")
70
+
71
+ omega = np.arange(embed_dim // 2, dtype=float)
72
+ omega /= embed_dim / 2.0
73
+ omega = 1.0 / 10000**omega
74
+
75
+ pos = pos.reshape(-1)
76
+ out = np.einsum("m,d->md", pos, omega)
77
+
78
+ emb_sin = np.sin(out)
79
+ emb_cos = np.cos(out)
80
+
81
+ emb = np.concatenate([emb_sin, emb_cos], axis=1)
82
+ return emb
83
+
84
+
85
+ def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
86
+ """Modified torch version of get_1d_sincos_pos_embed_from_grid()."""
87
+ assert embed_dim % 2 == 0
88
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
89
+
90
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
91
+ omega /= embed_dim / 2.0
92
+ omega = 1.0 / 10000**omega
93
+
94
+ pos = pos.reshape(-1)
95
+ out = torch.einsum("m,d->md", pos, omega)
96
+
97
+ emb_sin = torch.sin(out)
98
+ emb_cos = torch.cos(out)
99
+
100
+ emb = torch.cat([emb_sin, emb_cos], dim=1)
101
+ return emb
102
+
103
+
104
+ def _init_weights(module):
105
+ """Initialize the weights."""
106
+ if isinstance(module, nn.Linear):
107
+ nn.init.xavier_uniform_(module.weight)
108
+ if module.bias is not None:
109
+ nn.init.constant_(module.bias, 0)
110
+ elif isinstance(module, nn.LayerNorm):
111
+ module.bias.data.zero_()
112
+ module.weight.data.fill_(1.0)
113
+
114
+
115
+ class PatchEmbed(nn.Module):
116
+ """3D Patch Embedding."""
117
+
118
+ def __init__(
119
+ self,
120
+ input_size: tuple[int, int, int] = (1, 224, 224),
121
+ patch_size: tuple[int, int, int] = (1, 16, 16),
122
+ in_chans: int = 3,
123
+ embed_dim: int = 768,
124
+ norm_layer: nn.Module | None = None,
125
+ flatten: bool = True,
126
+ bias: bool = True,
127
+ ):
128
+ super().__init__()
129
+ self.input_size = input_size
130
+ self.patch_size = patch_size
131
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
132
+ assert all(
133
+ g >= 1 for g in self.grid_size
134
+ ), "Patch size is bigger than input size."
135
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
136
+ self.flatten = flatten
137
+
138
+ self.proj = nn.Conv3d(
139
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias
140
+ )
141
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
142
+
143
+ def forward(self, x):
144
+ B, C, T, H, W = x.shape
145
+ x = self.proj(x)
146
+ if self.flatten:
147
+ x = x.flatten(2).transpose(1, 2)
148
+ x = self.norm(x)
149
+ return x
150
+
151
+
152
+ class TemporalEncoder(nn.Module):
153
+ """Temporal coordinate encoder."""
154
+
155
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
156
+ super().__init__()
157
+ self.embed_dim = embed_dim
158
+ self.year_embed_dim = embed_dim // 2
159
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
160
+
161
+ if trainable_scale:
162
+ self.scale = nn.Parameter(torch.tensor(0.1))
163
+ else:
164
+ self.scale = 1.0
165
+
166
+ def forward(
167
+ self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None
168
+ ):
169
+ """
170
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
171
+ """
172
+ shape = temporal_coords.shape[:2] + (-1,)
173
+
174
+ year = _get_1d_sincos_embed_from_grid_torch(
175
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()
176
+ ).reshape(shape)
177
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
178
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()
179
+ ).reshape(shape)
180
+
181
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
182
+
183
+ if tokens_per_frame is not None:
184
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
185
+
186
+ return embedding
187
+
188
+
189
+ class LocationEncoder(nn.Module):
190
+ """Location coordinate encoder."""
191
+
192
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
193
+ super().__init__()
194
+ self.embed_dim = embed_dim
195
+ self.lat_embed_dim = embed_dim // 2
196
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
197
+
198
+ if trainable_scale:
199
+ self.scale = nn.Parameter(torch.tensor(0.1))
200
+ else:
201
+ self.scale = 1.0
202
+
203
+ def forward(self, location_coords: torch.Tensor):
204
+ """
205
+ location_coords: lat and lon info with shape (B, 2).
206
+ """
207
+ shape = location_coords.shape[:1] + (1, -1)
208
+
209
+ lat = _get_1d_sincos_embed_from_grid_torch(
210
+ self.lat_embed_dim, location_coords[:, 0].flatten()
211
+ ).reshape(shape)
212
+ lon = _get_1d_sincos_embed_from_grid_torch(
213
+ self.lon_embed_dim, location_coords[:, 1].flatten()
214
+ ).reshape(shape)
215
+
216
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
217
+
218
+ return embedding
219
+
220
+
221
+ class PrithviViT(nn.Module):
222
+ """Prithvi ViT Encoder."""
223
+
224
+ def __init__(
225
+ self,
226
+ img_size: int | tuple[int, int] = 224,
227
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
228
+ num_frames: int = 1,
229
+ in_chans: int = 3,
230
+ embed_dim: int = 1024,
231
+ depth: int = 24,
232
+ num_heads: int = 16,
233
+ mlp_ratio: float = 4.0,
234
+ norm_layer: nn.Module = nn.LayerNorm,
235
+ coords_encoding: list[str] | None = None,
236
+ coords_scale_learn: bool = False,
237
+ drop_path: float = 0.0,
238
+ **kwargs,
239
+ ):
240
+ super().__init__()
241
+
242
+ self.in_chans = in_chans
243
+ self.num_frames = num_frames
244
+ self.embed_dim = embed_dim
245
+ self.img_size = to_2tuple(img_size)
246
+ if isinstance(patch_size, int):
247
+ patch_size = (1, patch_size, patch_size)
248
+
249
+ self.patch_embed = PatchEmbed(
250
+ input_size=(num_frames,) + self.img_size,
251
+ patch_size=patch_size,
252
+ in_chans=in_chans,
253
+ embed_dim=embed_dim,
254
+ )
255
+
256
+ coords_encoding = coords_encoding or []
257
+ self.temporal_encoding = "time" in coords_encoding
258
+ self.location_encoding = "location" in coords_encoding
259
+
260
+ if self.temporal_encoding:
261
+ self.temporal_encoder = TemporalEncoder(embed_dim, coords_scale_learn)
262
+ if self.location_encoding:
263
+ self.location_encoder = LocationEncoder(embed_dim, coords_scale_learn)
264
+
265
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
266
+ self.register_buffer(
267
+ "pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim)
268
+ )
269
+
270
+ self.blocks = nn.ModuleList(
271
+ [
272
+ Block(
273
+ embed_dim,
274
+ num_heads,
275
+ mlp_ratio,
276
+ qkv_bias=True,
277
+ norm_layer=norm_layer,
278
+ )
279
+ for _ in range(depth)
280
+ ]
281
+ )
282
+
283
+ self.norm = norm_layer(embed_dim)
284
+ self.initialize_weights()
285
+
286
+ def initialize_weights(self):
287
+ pos_embed = get_3d_sincos_pos_embed(
288
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
289
+ )
290
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
291
+
292
+ w = self.patch_embed.proj.weight.data
293
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
294
+
295
+ torch.nn.init.normal_(self.cls_token, std=0.02)
296
+ self.apply(_init_weights)
297
+
298
+ def random_masking(self, sequence, mask_ratio, noise=None):
299
+ N, L, D = sequence.shape
300
+ len_keep = int(L * (1 - mask_ratio))
301
+
302
+ if noise is None:
303
+ noise = torch.rand(N, L, device=sequence.device)
304
+
305
+ ids_shuffle = torch.argsort(noise, dim=1)
306
+ ids_restore = torch.argsort(ids_shuffle, dim=1)
307
+
308
+ ids_keep = ids_shuffle[:, :len_keep]
309
+ sequence_masked = torch.gather(
310
+ sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D)
311
+ )
312
+
313
+ mask = torch.ones([N, L], device=sequence.device)
314
+ mask[:, :len_keep] = 0
315
+ mask = torch.gather(mask, dim=1, index=ids_restore)
316
+
317
+ return sequence_masked, mask, ids_restore
318
+
319
+ def forward(
320
+ self,
321
+ x: torch.Tensor,
322
+ temporal_coords: None | torch.Tensor = None,
323
+ location_coords: None | torch.Tensor = None,
324
+ mask_ratio=0.75,
325
+ ):
326
+ x = self.patch_embed(x)
327
+ x = x + self.pos_embed[:, 1:, :]
328
+
329
+ if self.temporal_encoding and temporal_coords is not None:
330
+ x = x + self.temporal_encoder(
331
+ temporal_coords, x.shape[1] // self.num_frames
332
+ )
333
+
334
+ if self.location_encoding and location_coords is not None:
335
+ x = x + self.location_encoder(location_coords)
336
+
337
+ cls_token = self.cls_token + self.pos_embed[:, :1, :]
338
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
339
+
340
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
341
+ x = torch.cat((cls_tokens, x), dim=1)
342
+
343
+ for blk in self.blocks:
344
+ x = blk(x)
345
+
346
+ x = self.norm(x)
347
+ return x, mask, ids_restore
348
+
349
+
350
+ class MAEDecoder(nn.Module):
351
+ """Transformer Decoder used in the Prithvi MAE."""
352
+
353
+ def __init__(
354
+ self,
355
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
356
+ grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
357
+ in_chans: int = 3,
358
+ encoder_embed_dim: int = 1024,
359
+ decoder_embed_dim: int = 512,
360
+ depth: int = 8,
361
+ num_heads: int = 16,
362
+ mlp_ratio: float = 4.0,
363
+ norm_layer: nn.Module = nn.LayerNorm,
364
+ coords_encoding: list[str] | None = None,
365
+ coords_scale_learn: bool = False,
366
+ ):
367
+ super().__init__()
368
+
369
+ self.patch_size = patch_size
370
+ self.grid_size = grid_size
371
+ self.in_chans = in_chans
372
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
373
+
374
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
375
+
376
+ self.register_buffer(
377
+ "decoder_pos_embed",
378
+ torch.zeros(
379
+ 1, grid_size[0] * grid_size[1] * grid_size[2] + 1, decoder_embed_dim
380
+ ),
381
+ )
382
+
383
+ coords_encoding = coords_encoding or []
384
+ self.temporal_encoding = "time" in coords_encoding
385
+ self.location_encoding = "location" in coords_encoding
386
+
387
+ if self.temporal_encoding:
388
+ self.temporal_encoder = TemporalEncoder(
389
+ decoder_embed_dim, coords_scale_learn
390
+ )
391
+ if self.location_encoding:
392
+ self.location_encoder = LocationEncoder(
393
+ decoder_embed_dim, coords_scale_learn
394
+ )
395
+
396
+ self.decoder_blocks = nn.ModuleList(
397
+ [
398
+ Block(
399
+ decoder_embed_dim,
400
+ num_heads,
401
+ mlp_ratio,
402
+ qkv_bias=True,
403
+ norm_layer=norm_layer,
404
+ )
405
+ for _ in range(depth)
406
+ ]
407
+ )
408
+
409
+ self.decoder_norm = norm_layer(decoder_embed_dim)
410
+ self.decoder_pred = nn.Linear(
411
+ decoder_embed_dim,
412
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
413
+ bias=True,
414
+ )
415
+
416
+ self.initialize_weights()
417
+
418
+ def initialize_weights(self):
419
+ pos_embed = get_3d_sincos_pos_embed(
420
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
421
+ )
422
+ self.decoder_pos_embed.data.copy_(
423
+ torch.from_numpy(pos_embed).float().unsqueeze(0)
424
+ )
425
+
426
+ torch.nn.init.normal_(self.mask_token, std=0.02)
427
+ self.apply(_init_weights)
428
+
429
+ def forward(
430
+ self,
431
+ hidden_states: torch.Tensor,
432
+ ids_restore: torch.Tensor,
433
+ temporal_coords: None | torch.Tensor = None,
434
+ location_coords: None | torch.Tensor = None,
435
+ input_size: list[int] = None,
436
+ ):
437
+ x = self.decoder_embed(hidden_states)
438
+
439
+ mask_tokens = self.mask_token.repeat(
440
+ x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1
441
+ )
442
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)
443
+ x_ = torch.gather(
444
+ x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2])
445
+ )
446
+ x = torch.cat([x[:, :1, :], x_], dim=1)
447
+
448
+ x = x + self.decoder_pos_embed
449
+
450
+ if self.temporal_encoding and temporal_coords is not None:
451
+ num_frames = temporal_coords.shape[1]
452
+ tokens_per_frame = (x.shape[1] - 1) // num_frames
453
+ temp_embed = self.temporal_encoder(temporal_coords, tokens_per_frame)
454
+ x[:, 1:, :] = x[:, 1:, :] + temp_embed
455
+
456
+ if self.location_encoding and location_coords is not None:
457
+ x[:, 1:, :] = x[:, 1:, :] + self.location_encoder(location_coords)
458
+
459
+ for blk in self.decoder_blocks:
460
+ x = blk(x)
461
+
462
+ x = self.decoder_norm(x)
463
+ x = self.decoder_pred(x)
464
+ x = x[:, 1:, :]
465
+
466
+ return x
467
+
468
+
469
+ class PrithviMAE(nn.Module):
470
+ """Prithvi Masked Autoencoder."""
471
+
472
+ def __init__(
473
+ self,
474
+ img_size: int | tuple[int, int] = 224,
475
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
476
+ num_frames: int = 4,
477
+ in_chans: int = 6,
478
+ embed_dim: int = 768,
479
+ depth: int = 12,
480
+ num_heads: int = 12,
481
+ decoder_embed_dim: int = 512,
482
+ decoder_depth: int = 8,
483
+ decoder_num_heads: int = 16,
484
+ mlp_ratio: float = 4.0,
485
+ norm_layer: nn.Module = nn.LayerNorm,
486
+ norm_pix_loss: bool = False,
487
+ coords_encoding: list[str] | None = None,
488
+ coords_scale_learn: bool = False,
489
+ drop_path: float = 0.0,
490
+ mask_ratio: float = 0.75,
491
+ **kwargs,
492
+ ):
493
+ super().__init__()
494
+
495
+ self.img_size = to_2tuple(img_size)
496
+ self.patch_size = (
497
+ patch_size if isinstance(patch_size, tuple) else (1, patch_size, patch_size)
498
+ )
499
+ self.num_frames = num_frames
500
+ self.in_chans = in_chans
501
+ self.norm_pix_loss = norm_pix_loss
502
+
503
+ self.encoder = PrithviViT(
504
+ img_size=img_size,
505
+ patch_size=patch_size,
506
+ num_frames=num_frames,
507
+ in_chans=in_chans,
508
+ embed_dim=embed_dim,
509
+ depth=depth,
510
+ num_heads=num_heads,
511
+ mlp_ratio=mlp_ratio,
512
+ norm_layer=norm_layer,
513
+ coords_encoding=coords_encoding,
514
+ coords_scale_learn=coords_scale_learn,
515
+ drop_path=drop_path,
516
+ )
517
+
518
+ self.decoder = MAEDecoder(
519
+ patch_size=self.patch_size,
520
+ grid_size=self.encoder.patch_embed.grid_size,
521
+ in_chans=in_chans,
522
+ encoder_embed_dim=embed_dim,
523
+ decoder_embed_dim=decoder_embed_dim,
524
+ depth=decoder_depth,
525
+ num_heads=decoder_num_heads,
526
+ mlp_ratio=mlp_ratio,
527
+ norm_layer=norm_layer,
528
+ coords_encoding=coords_encoding,
529
+ coords_scale_learn=coords_scale_learn,
530
+ )
531
+
532
+ def patchify(self, pixel_values):
533
+ B, C, T, H, W = pixel_values.shape
534
+ pH = H // self.patch_size[1]
535
+ pW = W // self.patch_size[2]
536
+
537
+ x = pixel_values.reshape(
538
+ B,
539
+ C,
540
+ T // self.patch_size[0],
541
+ self.patch_size[0],
542
+ pH,
543
+ self.patch_size[1],
544
+ pW,
545
+ self.patch_size[2],
546
+ )
547
+ x = x.permute(0, 2, 4, 6, 3, 5, 7, 1)
548
+ patchified_pixel_values = x.reshape(
549
+ B,
550
+ T // self.patch_size[0] * pH * pW,
551
+ self.patch_size[0] * self.patch_size[1] * self.patch_size[2] * C,
552
+ )
553
+
554
+ return patchified_pixel_values
555
+
556
+ def unpatchify(
557
+ self, patchified_pixel_values, image_size: tuple[int, int] | None = None
558
+ ):
559
+ if image_size is None:
560
+ H, W = self.img_size
561
+ else:
562
+ H, W = image_size
563
+
564
+ C = self.in_chans
565
+ pH = H // self.patch_size[1]
566
+ pW = W // self.patch_size[2]
567
+ T = self.num_frames
568
+
569
+ x = patchified_pixel_values.reshape(
570
+ patchified_pixel_values.shape[0],
571
+ T // self.patch_size[0],
572
+ pH,
573
+ pW,
574
+ self.patch_size[0],
575
+ self.patch_size[1],
576
+ self.patch_size[2],
577
+ C,
578
+ )
579
+ x = x.permute(0, 7, 1, 4, 2, 5, 3, 6)
580
+ pixel_values = x.reshape(
581
+ patchified_pixel_values.shape[0],
582
+ C,
583
+ T,
584
+ pH * self.patch_size[1],
585
+ pW * self.patch_size[2],
586
+ )
587
+
588
+ return pixel_values
589
+
590
+ def forward_loss(self, pixel_values, pred, mask):
591
+ target = self.patchify(pixel_values)
592
+
593
+ if self.norm_pix_loss:
594
+ mean = target.mean(dim=-1, keepdim=True)
595
+ var = target.var(dim=-1, keepdim=True)
596
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
597
+
598
+ loss = (pred - target) ** 2
599
+ loss = loss.mean(dim=-1)
600
+
601
+ loss = (loss * mask).sum() / mask.sum()
602
+ return loss
603
+
604
+ def forward(
605
+ self,
606
+ pixel_values: torch.Tensor,
607
+ temporal_coords: None | torch.Tensor = None,
608
+ location_coords: None | torch.Tensor = None,
609
+ mask_ratio: float = None,
610
+ ):
611
+ mask_ratio = mask_ratio if mask_ratio is not None else 0.75
612
+
613
+ latent, mask, ids_restore = self.encoder(
614
+ pixel_values, temporal_coords, location_coords, mask_ratio
615
+ )
616
+ pred = self.decoder(latent, ids_restore, temporal_coords, location_coords)
617
+ loss = self.forward_loss(pixel_values, pred, mask)
618
+
619
+ return loss, pred, mask
620
+
621
+
622
+ class PrithviProcessor:
623
+ """Prithvi EO 2.0 processor with GeoTIFF input/output support.
624
+
625
+ https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL
626
+ https://github.com/NASA-IMPACT/Prithvi-EO-2.0
627
+ """
628
+
629
+ def __init__(
630
+ self,
631
+ model_name: str = "Prithvi-EO-2.0-300M-TL",
632
+ config_path: Optional[str] = None,
633
+ checkpoint_path: Optional[str] = None,
634
+ device: Optional[torch.device] = None,
635
+ cache_dir: Optional[str] = None,
636
+ ):
637
+ """Initialize Prithvi processor.
638
+
639
+ Args:
640
+ model_name: Name of the Prithvi model to download from HuggingFace Hub
641
+ config_path: Path to config file (optional, downloads if not provided)
642
+ checkpoint_path: Path to checkpoint file (optional, downloads if not provided)
643
+ device: Torch device to use
644
+ cache_dir: Directory to cache downloaded files
645
+ """
646
+ self.device = device or get_device()
647
+ self.model_name = model_name
648
+ self.cache_dir = cache_dir
649
+
650
+ # Download or load config and checkpoint
651
+ if config_path is None or checkpoint_path is None:
652
+ config_path, checkpoint_path = self.download_model(model_name, cache_dir)
653
+
654
+ self.config_path = config_path
655
+ self.checkpoint_path = checkpoint_path
656
+
657
+ # Load config
658
+ with open(config_path, "r") as f:
659
+ config_data = json.load(f)
660
+ self.config = config_data["pretrained_cfg"]
661
+
662
+ # Extract parameters
663
+ self.bands = self.config["bands"]
664
+ self.mean = self.config["mean"]
665
+ self.std = self.config["std"]
666
+ self.img_size = self.config["img_size"]
667
+ self.patch_size = self.config["patch_size"]
668
+ self.mask_ratio = self.config["mask_ratio"]
669
+ self.num_frames = self.config.get("num_frames", 4)
670
+ self.coords_encoding = self.config.get("coords_encoding", [])
671
+
672
+ # Load model
673
+ self.model = self._load_model()
674
+
675
+ @staticmethod
676
+ def download_model(
677
+ model_name: str = "Prithvi-EO-2.0-300M-TL", cache_dir: str = None
678
+ ) -> Tuple[str, str]:
679
+ """Download Prithvi model from HuggingFace Hub.
680
+
681
+ Args:
682
+ model_name: Name of the model
683
+ cache_dir: Directory to cache files
684
+
685
+ Returns:
686
+ Tuple of (config_path, checkpoint_path)
687
+ """
688
+ repo_id = f"ibm-nasa-geospatial/{model_name}"
689
+
690
+ try:
691
+ # Download config
692
+ config_path = hf_hub_download(
693
+ repo_id=repo_id,
694
+ filename="config.json",
695
+ cache_dir=cache_dir,
696
+ )
697
+
698
+ # Download checkpoint
699
+ # Model name format: Prithvi-EO-2.0-300M-TL -> Prithvi_EO_V2_300M_TL.pt
700
+ checkpoint_filename = (
701
+ model_name.replace("-", "_").replace("_2.0_", "_V2_") + ".pt"
702
+ )
703
+ checkpoint_path = hf_hub_download(
704
+ repo_id=repo_id,
705
+ filename=checkpoint_filename,
706
+ cache_dir=cache_dir,
707
+ )
708
+
709
+ return config_path, checkpoint_path
710
+
711
+ except Exception as e:
712
+ raise RuntimeError(f"Failed to download model from HuggingFace Hub: {e}")
713
+
714
+ def _load_model(self) -> PrithviMAE:
715
+ """Load Prithvi MAE model."""
716
+ try:
717
+ # Convert patch_size to tuple if it's a list
718
+ patch_size = self.config["patch_size"]
719
+ if isinstance(patch_size, list):
720
+ patch_size = tuple(patch_size)
721
+
722
+ # Create model
723
+ model = PrithviMAE(
724
+ img_size=self.config["img_size"],
725
+ patch_size=patch_size,
726
+ num_frames=self.config["num_frames"],
727
+ in_chans=self.config["in_chans"],
728
+ embed_dim=self.config["embed_dim"],
729
+ depth=self.config["depth"],
730
+ num_heads=self.config["num_heads"],
731
+ decoder_embed_dim=self.config["decoder_embed_dim"],
732
+ decoder_depth=self.config["decoder_depth"],
733
+ decoder_num_heads=self.config["decoder_num_heads"],
734
+ mlp_ratio=self.config["mlp_ratio"],
735
+ coords_encoding=self.coords_encoding,
736
+ coords_scale_learn=self.config.get("coords_scale_learn", False),
737
+ mask_ratio=self.mask_ratio,
738
+ norm_pix_loss=self.config.get("norm_pix_loss", False),
739
+ )
740
+
741
+ # Load checkpoint
742
+ state_dict = torch.load(
743
+ self.checkpoint_path, map_location=self.device, weights_only=True
744
+ )
745
+
746
+ # Remove fixed pos_embed weights
747
+ for k in list(state_dict.keys()):
748
+ if "pos_embed" in k:
749
+ del state_dict[k]
750
+
751
+ model.load_state_dict(state_dict, strict=False)
752
+ model = model.to(self.device)
753
+ model.eval()
754
+
755
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
756
+
757
+ return model
758
+
759
+ except Exception as e:
760
+ raise RuntimeError(f"Failed to load Prithvi model: {e}")
761
+
762
+ def read_geotiff(self, file_path: str) -> Tuple[np.ndarray, dict, Optional[Tuple]]:
763
+ """Read GeoTIFF file.
764
+
765
+ Args:
766
+ file_path: Path to GeoTIFF file
767
+
768
+ Returns:
769
+ Tuple of (image array, metadata, coordinates)
770
+ """
771
+ with rasterio.open(file_path) as src:
772
+ img = src.read()
773
+ meta = src.meta
774
+ try:
775
+ coords = src.tags()
776
+ except:
777
+ coords = None
778
+
779
+ return img, meta, coords
780
+
781
+ def preprocess_image(
782
+ self,
783
+ img: np.ndarray,
784
+ indices: Optional[List[int]] = None,
785
+ ) -> np.ndarray:
786
+ """Preprocess image for model input.
787
+
788
+ Args:
789
+ img: Image array with shape (C, H, W)
790
+ indices: Optional band indices to select
791
+
792
+ Returns:
793
+ Preprocessed image
794
+ """
795
+ # Move channels to last dimension
796
+ img = np.moveaxis(img, 0, -1)
797
+
798
+ # Select bands if specified
799
+ if indices is not None:
800
+ img = img[..., indices]
801
+
802
+ # Normalize (handle nodata)
803
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - self.mean) / self.std)
804
+
805
+ return img
806
+
807
+ def load_images(
808
+ self,
809
+ file_paths: List[str],
810
+ indices: Optional[List[int]] = None,
811
+ ) -> Tuple[np.ndarray, List[dict], List, List]:
812
+ """Load and preprocess multiple images.
813
+
814
+ Args:
815
+ file_paths: List of GeoTIFF file paths
816
+ indices: Optional band indices
817
+
818
+ Returns:
819
+ Tuple of (images, metadata, temporal_coords, location_coords)
820
+ """
821
+ # Check if we need to pad to num_frames
822
+ if len(file_paths) < self.num_frames:
823
+ # Pad file_paths by repeating the last file
824
+ file_paths = list(file_paths) + [file_paths[-1]] * (
825
+ self.num_frames - len(file_paths)
826
+ )
827
+ elif len(file_paths) > self.num_frames:
828
+ file_paths = file_paths[: self.num_frames]
829
+
830
+ imgs = []
831
+ metas = []
832
+ temporal_coords = []
833
+ location_coords = []
834
+
835
+ for file in file_paths:
836
+ img, meta, coords = self.read_geotiff(file)
837
+
838
+ # Preprocess
839
+ img = self.preprocess_image(img, indices)
840
+
841
+ imgs.append(img)
842
+ metas.append(meta)
843
+
844
+ # Stack images: (T, H, W, C)
845
+ imgs = np.stack(imgs, axis=0)
846
+ # Rearrange to: (C, T, H, W)
847
+ imgs = np.moveaxis(imgs, -1, 0).astype("float32")
848
+ # Add batch dimension: (1, C, T, H, W)
849
+ imgs = np.expand_dims(imgs, axis=0)
850
+
851
+ return imgs, metas, temporal_coords, location_coords
852
+
853
+ def run_inference(
854
+ self,
855
+ input_data: torch.Tensor,
856
+ temporal_coords: Optional[torch.Tensor] = None,
857
+ location_coords: Optional[torch.Tensor] = None,
858
+ mask_ratio: Optional[float] = None,
859
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
860
+ """Run model inference.
861
+
862
+ Args:
863
+ input_data: Input tensor with shape (B, C, T, H, W)
864
+ temporal_coords: Optional temporal coordinates
865
+ location_coords: Optional location coordinates
866
+ mask_ratio: Mask ratio (default: from config)
867
+
868
+ Returns:
869
+ Tuple of (reconstructed_image, mask_image)
870
+ """
871
+ mask_ratio = mask_ratio or self.mask_ratio
872
+
873
+ # Check if input dimensions match model expectations
874
+ B, C, T, H, W = input_data.shape
875
+ if H % self.img_size != 0 or W % self.img_size != 0:
876
+ raise ValueError(
877
+ f"Input spatial dimensions ({H}x{W}) must be divisible by model image size ({self.img_size}). "
878
+ f"Use process_files() method which handles padding automatically, or pad your input to multiples of {self.img_size}."
879
+ )
880
+
881
+ with torch.no_grad():
882
+ x = input_data.to(self.device)
883
+ _, pred, mask = self.model(x, temporal_coords, location_coords, mask_ratio)
884
+
885
+ # Create mask and prediction images
886
+ mask_img = (
887
+ self.model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1]))
888
+ .detach()
889
+ .cpu()
890
+ )
891
+ pred_img = self.model.unpatchify(pred).detach().cpu()
892
+
893
+ # Mix visible and predicted patches
894
+ rec_img = input_data.clone()
895
+ rec_img[mask_img == 1] = pred_img[mask_img == 1]
896
+
897
+ # Invert mask for better visualization
898
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
899
+
900
+ return rec_img, mask_img
901
+
902
+ def process_images(
903
+ self,
904
+ file_paths: List[str],
905
+ mask_ratio: Optional[float] = None,
906
+ indices: Optional[List[int]] = None,
907
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
908
+ """Process multiple GeoTIFF files and return tensors (without saving).
909
+
910
+ This method handles large images using sliding windows and returns tensors
911
+ for visualization, unlike process_files() which saves to disk.
912
+
913
+ Args:
914
+ file_paths: List of input file paths
915
+ mask_ratio: Optional mask ratio
916
+ indices: Optional band indices
917
+
918
+ Returns:
919
+ Tuple of (input_tensor, reconstructed_tensor, mask_tensor)
920
+ """
921
+ # Load images
922
+ input_data, metas, temporal_coords, location_coords = self.load_images(
923
+ file_paths, indices
924
+ )
925
+
926
+ # Handle padding
927
+ original_h, original_w = input_data.shape[-2:]
928
+ pad_h = (self.img_size - (original_h % self.img_size)) % self.img_size
929
+ pad_w = (self.img_size - (original_w % self.img_size)) % self.img_size
930
+
931
+ if pad_h > 0 or pad_w > 0:
932
+ input_data = np.pad(
933
+ input_data,
934
+ ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)),
935
+ mode="reflect",
936
+ )
937
+
938
+ # Convert to tensor
939
+ batch = torch.tensor(input_data, device="cpu")
940
+
941
+ # Create sliding windows
942
+ windows = batch.unfold(3, self.img_size, self.img_size).unfold(
943
+ 4, self.img_size, self.img_size
944
+ )
945
+ h1, w1 = windows.shape[3:5]
946
+ windows = rearrange(
947
+ windows,
948
+ "b c t h1 w1 h w -> (b h1 w1) c t h w",
949
+ h=self.img_size,
950
+ w=self.img_size,
951
+ )
952
+
953
+ # Split into batches
954
+ num_batches = max(1, windows.shape[0])
955
+ windows_list = torch.tensor_split(windows, num_batches, dim=0)
956
+
957
+ # Process each window
958
+ rec_imgs = []
959
+ mask_imgs = []
960
+
961
+ for i, x in enumerate(windows_list):
962
+ rec_img, mask_img = self.run_inference(x, None, None, mask_ratio)
963
+ rec_imgs.append(rec_img)
964
+ mask_imgs.append(mask_img)
965
+
966
+ # Concatenate results
967
+ rec_imgs = torch.cat(rec_imgs, dim=0)
968
+ mask_imgs = torch.cat(mask_imgs, dim=0)
969
+
970
+ # Rearrange patches back to image
971
+ num_frames = len(file_paths)
972
+ rec_imgs = rearrange(
973
+ rec_imgs,
974
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
975
+ h=self.img_size,
976
+ w=self.img_size,
977
+ b=1,
978
+ c=len(self.bands),
979
+ t=num_frames,
980
+ h1=h1,
981
+ w1=w1,
982
+ )
983
+ mask_imgs = rearrange(
984
+ mask_imgs,
985
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
986
+ h=self.img_size,
987
+ w=self.img_size,
988
+ b=1,
989
+ c=len(self.bands),
990
+ t=num_frames,
991
+ h1=h1,
992
+ w1=w1,
993
+ )
994
+
995
+ # Remove padding
996
+ rec_imgs = rec_imgs[..., :original_h, :original_w]
997
+ mask_imgs = mask_imgs[..., :original_h, :original_w]
998
+ input_imgs = batch[..., :original_h, :original_w]
999
+
1000
+ return input_imgs, rec_imgs, mask_imgs
1001
+
1002
+ def visualize_rgb(
1003
+ self,
1004
+ input_tensor: torch.Tensor,
1005
+ rec_tensor: torch.Tensor,
1006
+ mask_tensor: torch.Tensor,
1007
+ ) -> Tuple[List[np.ndarray], List[np.ndarray], List[np.ndarray]]:
1008
+ """Extract RGB images from tensors for visualization.
1009
+
1010
+ Args:
1011
+ input_tensor: Input tensor (B, C, T, H, W)
1012
+ rec_tensor: Reconstructed tensor (B, C, T, H, W)
1013
+ mask_tensor: Mask tensor (B, C, T, H, W)
1014
+
1015
+ Returns:
1016
+ Tuple of (original_rgb, masked_rgb, reconstructed_rgb) lists
1017
+ """
1018
+ # Get RGB channel indices (B04=Red, B03=Green, B02=Blue)
1019
+ rgb_channels = [
1020
+ self.bands.index("B04"),
1021
+ self.bands.index("B03"),
1022
+ self.bands.index("B02"),
1023
+ ]
1024
+
1025
+ # Remove batch dimension
1026
+ if input_tensor.dim() == 5:
1027
+ input_tensor = input_tensor[0]
1028
+ if rec_tensor.dim() == 5:
1029
+ rec_tensor = rec_tensor[0]
1030
+ if mask_tensor.dim() == 5:
1031
+ mask_tensor = mask_tensor[0]
1032
+
1033
+ mean = torch.tensor(self.mean)
1034
+ std = torch.tensor(self.std)
1035
+
1036
+ original_rgb = []
1037
+ masked_rgb = []
1038
+ reconstructed_rgb = []
1039
+
1040
+ num_frames = input_tensor.shape[1]
1041
+
1042
+ for t in range(num_frames):
1043
+ # Extract and denormalize original RGB
1044
+ rgb_orig = input_tensor[rgb_channels, t, :, :].clone()
1045
+ for i, c in enumerate(rgb_channels):
1046
+ rgb_orig[i] = rgb_orig[i] * std[c] + mean[c]
1047
+ rgb_orig_np = rgb_orig.numpy()
1048
+ rgb_orig_np = np.clip(rgb_orig_np, 0, 10000)
1049
+ rgb_orig_np = (rgb_orig_np / 10000 * 255).astype(np.uint8)
1050
+ rgb_orig_np = np.transpose(rgb_orig_np, (1, 2, 0))
1051
+ original_rgb.append(rgb_orig_np)
1052
+
1053
+ # Extract and denormalize reconstructed RGB
1054
+ rgb_rec = rec_tensor[rgb_channels, t, :, :].clone()
1055
+ for i, c in enumerate(rgb_channels):
1056
+ rgb_rec[i] = rgb_rec[i] * std[c] + mean[c]
1057
+ rgb_rec_np = rgb_rec.numpy()
1058
+ rgb_rec_np = np.clip(rgb_rec_np, 0, 10000)
1059
+ rgb_rec_np = (rgb_rec_np / 10000 * 255).astype(np.uint8)
1060
+ rgb_rec_np = np.transpose(rgb_rec_np, (1, 2, 0))
1061
+ reconstructed_rgb.append(rgb_rec_np)
1062
+
1063
+ # Create masked RGB (visible patches only)
1064
+ mask_t = mask_tensor[rgb_channels, t, :, :].numpy()
1065
+ masked_np = rgb_orig_np.astype(np.float32) * np.transpose(mask_t, (1, 2, 0))
1066
+ masked_rgb.append(masked_np.astype(np.uint8))
1067
+
1068
+ return original_rgb, masked_rgb, reconstructed_rgb
1069
+
1070
+ def process_files(
1071
+ self,
1072
+ file_paths: List[str],
1073
+ output_dir: str,
1074
+ mask_ratio: Optional[float] = None,
1075
+ indices: Optional[List[int]] = None,
1076
+ ):
1077
+ """Process multiple GeoTIFF files.
1078
+
1079
+ Args:
1080
+ file_paths: List of input file paths
1081
+ output_dir: Output directory for results
1082
+ mask_ratio: Optional mask ratio
1083
+ indices: Optional band indices
1084
+ """
1085
+ os.makedirs(output_dir, exist_ok=True)
1086
+
1087
+ # Load images
1088
+ input_data, metas, temporal_coords, location_coords = self.load_images(
1089
+ file_paths, indices
1090
+ )
1091
+
1092
+ # Handle padding
1093
+ original_h, original_w = input_data.shape[-2:]
1094
+ pad_h = self.img_size - (original_h % self.img_size)
1095
+ pad_w = self.img_size - (original_w % self.img_size)
1096
+ input_data = np.pad(
1097
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
1098
+ )
1099
+
1100
+ # Convert to tensor
1101
+ batch = torch.tensor(input_data, device="cpu")
1102
+
1103
+ # Create sliding windows
1104
+ windows = batch.unfold(3, self.img_size, self.img_size).unfold(
1105
+ 4, self.img_size, self.img_size
1106
+ )
1107
+ h1, w1 = windows.shape[3:5]
1108
+ windows = rearrange(
1109
+ windows,
1110
+ "b c t h1 w1 h w -> (b h1 w1) c t h w",
1111
+ h=self.img_size,
1112
+ w=self.img_size,
1113
+ )
1114
+
1115
+ # Split into batches
1116
+ num_batches = max(1, windows.shape[0])
1117
+ windows_list = torch.tensor_split(windows, num_batches, dim=0)
1118
+
1119
+ # Process each window
1120
+ rec_imgs = []
1121
+ mask_imgs = []
1122
+
1123
+ for i, x in enumerate(windows_list):
1124
+ rec_img, mask_img = self.run_inference(x, None, None, mask_ratio)
1125
+ rec_imgs.append(rec_img)
1126
+ mask_imgs.append(mask_img)
1127
+
1128
+ # Concatenate results
1129
+ rec_imgs = torch.cat(rec_imgs, dim=0)
1130
+ mask_imgs = torch.cat(mask_imgs, dim=0)
1131
+
1132
+ # Rearrange patches back to image
1133
+ num_frames = len(file_paths)
1134
+ rec_imgs = rearrange(
1135
+ rec_imgs,
1136
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
1137
+ h=self.img_size,
1138
+ w=self.img_size,
1139
+ b=1,
1140
+ c=len(self.bands),
1141
+ t=num_frames,
1142
+ h1=h1,
1143
+ w1=w1,
1144
+ )
1145
+ mask_imgs = rearrange(
1146
+ mask_imgs,
1147
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
1148
+ h=self.img_size,
1149
+ w=self.img_size,
1150
+ b=1,
1151
+ c=len(self.bands),
1152
+ t=num_frames,
1153
+ h1=h1,
1154
+ w1=w1,
1155
+ )
1156
+
1157
+ # Remove padding
1158
+ rec_imgs = rec_imgs[..., :original_h, :original_w]
1159
+ mask_imgs = mask_imgs[..., :original_h, :original_w]
1160
+
1161
+ # Save results
1162
+ self.save_results(rec_imgs[0], mask_imgs[0], metas, output_dir)
1163
+
1164
+ def save_results(
1165
+ self,
1166
+ rec_img: torch.Tensor,
1167
+ mask_img: torch.Tensor,
1168
+ metas: List[dict],
1169
+ output_dir: str,
1170
+ ):
1171
+ """Save reconstruction results.
1172
+
1173
+ Args:
1174
+ rec_img: Reconstructed image with shape (C, T, H, W)
1175
+ mask_img: Mask image with shape (C, T, H, W)
1176
+ metas: List of metadata dicts
1177
+ output_dir: Output directory
1178
+ """
1179
+ mean = torch.tensor(np.asarray(self.mean)[:, None, None])
1180
+ std = torch.tensor(np.asarray(self.std)[:, None, None])
1181
+
1182
+ for t in range(rec_img.shape[1]):
1183
+ # Denormalize
1184
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
1185
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
1186
+
1187
+ # Update metadata
1188
+ meta = metas[t].copy()
1189
+ meta.update(compress="lzw", nodata=0)
1190
+
1191
+ # Save files
1192
+ self._save_geotiff(
1193
+ rec_img_t.numpy(),
1194
+ os.path.join(output_dir, f"reconstructed_t{t}.tif"),
1195
+ meta,
1196
+ )
1197
+ self._save_geotiff(
1198
+ mask_img_t.numpy(),
1199
+ os.path.join(output_dir, f"mask_t{t}.tif"),
1200
+ meta,
1201
+ )
1202
+
1203
+ @staticmethod
1204
+ def _save_geotiff(image: np.ndarray, output_path: str, meta: dict):
1205
+ """Save GeoTIFF file."""
1206
+ with rasterio.open(output_path, "w", **meta) as dest:
1207
+ for i in range(image.shape[0]):
1208
+ dest.write(image[i], i + 1)
1209
+
1210
+
1211
+ def load_prithvi_model(
1212
+ model_name: str = "Prithvi-EO-2.0-300M-TL",
1213
+ device: Optional[str] = None,
1214
+ cache_dir: Optional[str] = None,
1215
+ ) -> PrithviProcessor:
1216
+ """Load Prithvi model (convenience function).
1217
+
1218
+ Args:
1219
+ model_name: Name of the model
1220
+ device: Device to use ('cuda' or 'cpu')
1221
+ cache_dir: Cache directory
1222
+
1223
+ Returns:
1224
+ PrithviProcessor instance
1225
+ """
1226
+ if device is not None:
1227
+ device = torch.device(device)
1228
+
1229
+ return PrithviProcessor(
1230
+ model_name=model_name,
1231
+ device=device,
1232
+ cache_dir=cache_dir,
1233
+ )
1234
+
1235
+
1236
+ def prithvi_inference(
1237
+ file_paths: List[str],
1238
+ output_dir: str = "output",
1239
+ model_name: str = "Prithvi-EO-2.0-300M-TL",
1240
+ mask_ratio: Optional[float] = None,
1241
+ device: Optional[str] = None,
1242
+ ):
1243
+ """Run Prithvi inference on GeoTIFF files (convenience function).
1244
+
1245
+ Args:
1246
+ file_paths: List of input GeoTIFF files
1247
+ output_dir: Output directory
1248
+ model_name: Name of the model
1249
+ mask_ratio: Optional mask ratio
1250
+ device: Device to use
1251
+ """
1252
+ processor = load_prithvi_model(model_name, device)
1253
+ processor.process_files(file_paths, output_dir, mask_ratio)
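
For reference, a usage sketch of the new module's public entry points as defined above, based only on the signatures in this file; the scene paths are placeholders, and the first call downloads the config and checkpoint from the ibm-nasa-geospatial Hugging Face repository:

from geoai.prithvi import PrithviProcessor

# Placeholder multi-temporal GeoTIFF paths; bands must match the model config.
scenes = ["scene_t1.tif", "scene_t2.tif", "scene_t3.tif", "scene_t4.tif"]

processor = PrithviProcessor(model_name="Prithvi-EO-2.0-300M-TL")

# Keep the tensors in memory for visualization instead of saving GeoTIFFs.
inputs, recons, masks = processor.process_images(scenes, mask_ratio=0.75)
orig_rgb, masked_rgb, rec_rgb = processor.visualize_rgb(inputs, recons, masks)

# Or write reconstructed_t*.tif and mask_t*.tif to an output directory.
processor.process_files(scenes, output_dir="prithvi_output", mask_ratio=0.75)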
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: geoai-py
- Version: 0.25.0
+ Version: 0.26.0
  Summary: A Python package for using Artificial Intelligence (AI) with geospatial data
  Author-email: Qiusheng Wu <giswqs@gmail.com>
  License: MIT License
@@ -1,6 +1,6 @@
- geoai/__init__.py,sha256=4ppPdZeGdKBljCj6ioyNlC53SNi8N-e_GoSEJ2pMr0U,5334
- geoai/auto.py,sha256=OPj_wsSlMCvVNmpPu28k3255AbMX-a5_DkAO6VZ_edA,69220
- geoai/change_detection.py,sha256=pdQofnPRiwoES8vMln2vHghRnpeTdsmqLir74dnqZYU,60389
+ geoai/__init__.py,sha256=i5Kxxvdo0uANDhtMtzYoRWKhn0vIVOEAAH4Yy5ozXvY,5577
+ geoai/auto.py,sha256=-jlyk9mO9n8BiUnrN9OZ1GU9B4PFwINSoYMfEeDVDNk,69219
+ geoai/change_detection.py,sha256=g9-cVHO2hPdJf-VSAN2NckX2v0EfNgCUdiCRVnhjLyc,60387
  geoai/classify.py,sha256=0DcComVR6vKU4qWtH2oHVeXc7ZTcV0mFvdXRtlNmolo,35637
  geoai/detectron2.py,sha256=dOOFM9M9-6PV8q2A4-mnIPrz7yTo-MpEvDiAW34nl0w,14610
  geoai/dinov3.py,sha256=u4Lulihhvs4wTgi84RjRw8jWQpB8omQSl-dVVryNVus,40377
@@ -11,7 +11,8 @@ geoai/hf.py,sha256=HbfJfpO6XnANKhmFOBvpwULiC65TeMlnLNtyQHHmlKA,17248
  geoai/landcover_train.py,sha256=QuuTOAVCfNqVarkrKVZRFQRcPM5poA_3_PRc9GduKpc,24713
  geoai/landcover_utils.py,sha256=ETNhPORZk84BUvpzZvTiJ85AgAJ4fhPlIf4WgQsHYJU,14231
  geoai/map_widgets.py,sha256=HECIBcLbDBtLaKKNt-xA5HqLNbJAjOpuVZnfN_MaQAE,25521
- geoai/moondream.py,sha256=mogBWdvviEp0J_xV2GijTD8Z_C9tf6MH-XsP4SqHO-8,60923
+ geoai/moondream.py,sha256=sU8tXomWzRsjJqw9dhqtZfrQTkvDjbo5ARLP4AVfSQA,60922
+ geoai/prithvi.py,sha256=hrMFB9oIcR2UIHwpbijuXrNGWniYO8hEW_RTGqxkRDM,40515
  geoai/sam.py,sha256=O6S-kGiFn7YEcFbfWFItZZQOhnsm6-GlunxQLY0daEs,34345
  geoai/segment.py,sha256=yBGTxA-ti8lBpk7WVaBOp6yP23HkaulKJQk88acrmZ0,43788
  geoai/segmentation.py,sha256=7yEzBSKCyHW1dNssoK0rdvhxi2IXsIQIFSga817KdI4,11535
@@ -30,9 +31,9 @@ geoai/tools/__init__.py,sha256=DE1kZOxzaW8C89TfOWQ3Vdv4WsBLV_SFiFWNWVWc4Ys,1922
  geoai/tools/cloudmask.py,sha256=qzvqVa8FAEgd8mePXBaV5Ptx4fHhwfS1BsYL0JAZBjM,14500
  geoai/tools/multiclean.py,sha256=TVwmWgeQyGIyUuCe10b6pGCtgIl8TkZmcgVXPimn9uM,11949
  geoai/tools/sr.py,sha256=kg6Zkq2wB2Ve7c1WblCfbDgd7hFkY65wWgFiD1zC7Vg,7018
- geoai_py-0.25.0.dist-info/licenses/LICENSE,sha256=TlBm8mRusRVB9yF2NTg-STcb71v69-XZaKaPdshqP2I,1074
- geoai_py-0.25.0.dist-info/METADATA,sha256=HEbmXbhVDkoaHfCORrv4Hb-ozmMxOKs-G0jlgaKy8xc,12065
- geoai_py-0.25.0.dist-info/WHEEL,sha256=JNWh1Fm1UdwIQV075glCn4MVuCRs0sotJIq-J6rbxCU,109
- geoai_py-0.25.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
- geoai_py-0.25.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
- geoai_py-0.25.0.dist-info/RECORD,,
+ geoai_py-0.26.0.dist-info/licenses/LICENSE,sha256=TlBm8mRusRVB9yF2NTg-STcb71v69-XZaKaPdshqP2I,1074
+ geoai_py-0.26.0.dist-info/METADATA,sha256=5kR92lHD9NwnZEdLkn42tWtwDAtjwH4nWfFYmA4E2ww,12065
+ geoai_py-0.26.0.dist-info/WHEEL,sha256=Q6xS052dXadQWXcEVKSI037R6NoyqhUlJ5BcYz2iMP4,110
+ geoai_py-0.26.0.dist-info/entry_points.txt,sha256=uGp3Az3HURIsRHP9v-ys0hIbUuBBNUfXv6VbYHIXeg4,41
+ geoai_py-0.26.0.dist-info/top_level.txt,sha256=1YkCUWu-ii-0qIex7kbwAvfei-gos9ycyDyUCJPNWHY,6
+ geoai_py-0.26.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (80.9.0)
+ Generator: setuptools (80.10.1)
  Root-Is-Purelib: true
  Tag: py2-none-any
  Tag: py3-none-any